
RL Text Classification Agent - Design Document

Overview

This application combines Double DQN with an Actor-Critic approach for text classification. It classifies user messages into three actions: TRIP, GITHUB, or MAIL.

For Beginners: Think of this as a smart assistant that reads your message and decides what you want to do - book a trip, do something on GitHub, or send an email. Instead of using fixed rules, it learns from examples like a human would.


Architecture

User Message → DistilBERT Encoder → State (768-dim) → RL Agent → Action + Confidence

For Beginners: The flow is simple: your text message gets converted into numbers (768 of them) by a pre-trained language model. These numbers capture the meaning of your message. Then our RL agent looks at these numbers and picks the best action.
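
For illustration, here is a minimal sketch of that pipeline using torch and transformers; the mean-pooling step and the agent.select_action call at the end are assumptions about how the pieces fit together, not the exact implementation.

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
encoder = AutoModel.from_pretrained("distilbert-base-uncased")
encoder.eval()  # frozen: we only read its outputs

def encode(message: str) -> torch.Tensor:
    """Convert a message into a 768-dim state vector."""
    inputs = tokenizer(message, return_tensors="pt", truncation=True)
    with torch.no_grad():
        hidden = encoder(**inputs).last_hidden_state   # (1, seq_len, 768)
    return hidden.mean(dim=1).squeeze(0)               # (768,) pooled state

# state = encode("Book a flight to Paris")
# action, confidence = agent.select_action(state)     # hypothetical agent API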


Libraries

Library        Purpose
torch          Neural network framework
transformers   Text encoding (DistilBERT)
FastAPI        REST API server
pydantic       Request/response validation

For Beginners: These are the main tools we use. PyTorch builds the brain (neural networks), Transformers helps understand text, FastAPI creates a web server so other apps can talk to ours, and Pydantic makes sure data is in the right format.


RL Concepts Used

For Beginners: Reinforcement Learning (RL) is like training a dog - the agent tries actions, gets rewards for good ones, and learns to make better choices over time. Below are the specific techniques we use.

Double DQN

Separates action selection from evaluation to reduce overestimation:

# Select action with online network
best_actions = self.q_net(next_states).argmax(dim=1, keepdim=True)
# Evaluate with target network
next_q_values = self.target_q_net(next_states).gather(1, best_actions).squeeze(1)

For Beginners: Regular DQN tends to be overconfident about how good actions are. Double DQN uses two networks - one picks the action, another judges it. It's like having a friend double-check your decisions to avoid being too optimistic.
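
Putting the snippet above into context, the TD target and critic loss could look roughly like this; rewards, dones, gamma, states, and actions are assumed to come from the training batch, and the (1 - dones) term is an assumption about how end-of-episode steps are handled.

with torch.no_grad():
    # Online network picks the action, target network scores it
    best_actions = self.q_net(next_states).argmax(dim=1, keepdim=True)
    next_q_values = self.target_q_net(next_states).gather(1, best_actions).squeeze(1)
    td_targets = rewards + gamma * (1.0 - dones) * next_q_values

# Q-values of the actions actually taken, regressed toward the targets
q_values = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
q_loss = torch.nn.functional.mse_loss(q_values, td_targets)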

Actor-Critic

  • Critic (Q-Network): Estimates action values
  • Actor (Policy Network): Outputs action probabilities

For Beginners: Imagine a movie set - the Actor performs actions, while the Critic scores how good they were. The Actor learns to do better based on the Critic's feedback. Together, they improve faster than either would alone.
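
A minimal sketch of how the two networks could interact, assuming the advantage is the critic's Q-value for the taken action minus the expected value under the current policy; the exact formulation in the code may differ.

probs = self.policy_net(states)                        # actor: (batch, 3) probabilities
q_values = self.q_net(states)                          # critic: (batch, 3) action values

# Expected value under the current policy: V(s) = sum_a pi(a|s) * Q(s, a)
values = (probs * q_values).sum(dim=1)
# Advantage of the taken action: how much better it is than average
advantage = (q_values.gather(1, actions.unsqueeze(1)).squeeze(1) - values).detach()

log_probs = torch.log(probs.gather(1, actions.unsqueeze(1)).squeeze(1) + 1e-8)
# Advantage-weighted term that feeds the policy loss shown in the entropy snippet below
advantage_weighted_loss = (log_probs * advantage).mean()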

Soft Target Update

Gradual target network updates for stability:

target_param = tau * param + (1 - tau) * target_param  # tau=0.005

For Beginners: Instead of suddenly copying all knowledge to the target network (which can cause instability), we blend in just 0.5% of the new knowledge each time. It's like slowly adjusting to new information rather than making sudden changes.
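
In PyTorch this blend is typically applied parameter by parameter. A minimal sketch, assuming q_net and target_q_net share the same architecture:

tau = 0.005
with torch.no_grad():
    for param, target_param in zip(self.q_net.parameters(), self.target_q_net.parameters()):
        # Blend 0.5% of the online weights into the target weights
        target_param.mul_(1 - tau).add_(tau * param)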

Entropy Regularization

Encourages exploration by penalizing confident policies:

# Entropy of the policy distribution (high = uncertain, low = confident)
entropy = -(probs * log_probs).sum(dim=-1).mean()
# Subtracting 0.05 * entropy rewards keeping some randomness in the policy
policy_loss = -advantage_weighted_loss - 0.05 * entropy

For Beginners: We don't want the agent to become too stubborn and always pick the same action. Entropy measures "randomness" - by rewarding some entropy, we encourage the agent to stay open-minded and keep exploring different options.
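
To see what the entropy term measures, compare a uniform ("open-minded") policy with a near-certain one; with 3 actions the maximum possible entropy is ln(3) ≈ 1.099. The probabilities below are made-up examples.

import torch

def entropy(probs: torch.Tensor) -> float:
    return float(-(probs * torch.log(probs)).sum())

print(entropy(torch.tensor([1/3, 1/3, 1/3])))     # ~1.099: maximum, fully exploratory
print(entropy(torch.tensor([0.98, 0.01, 0.01])))  # ~0.112: almost no exploration left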

Epsilon-Greedy Exploration

During training: random actions with probability ε (decays from 1.0 → 0.05)

For Beginners: At the start, the agent picks random actions 100% of the time (exploring). As training progresses, it gradually shifts to using what it learned, eventually only being random 5% of the time. It's like a new employee who tries everything at first, then settles into what works.
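
A minimal sketch of ε-greedy selection with multiplicative decay; the decay factor 0.995, the use of the Q-network for the greedy choice, and the select_action wrapper are all assumptions for illustration.

import random
import torch

epsilon, epsilon_min, epsilon_decay = 1.0, 0.05, 0.995

def select_action(state: torch.Tensor, q_net: torch.nn.Module) -> int:
    if random.random() < epsilon:
        return random.randrange(3)                       # explore: random action
    with torch.no_grad():
        return int(q_net(state.unsqueeze(0)).argmax())   # exploit: best known action

# After each training step:
# epsilon = max(epsilon_min, epsilon * epsilon_decay)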

Confidence Scoring

Combines entropy and probability for uncertainty estimation:

confidence = (1 - entropy/max_entropy) * raw_probability

For Beginners: The agent tells us how sure it is about its decision. If it's torn between options (high entropy) or the chosen action has low probability, confidence drops. This helps us know when to trust the agent vs. when to ask a human.
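
A worked example of the formula with 3 actions (max_entropy = ln(3)); the probability vectors are made up to show a trusted prediction versus one that would fall below the 0.6 threshold.

import math
import torch

def confidence(probs: torch.Tensor) -> float:
    entropy = -(probs * torch.log(probs)).sum()
    max_entropy = math.log(len(probs))                   # ln(3) for 3 actions
    return float((1 - entropy / max_entropy) * probs.max())

print(confidence(torch.tensor([0.95, 0.03, 0.02])))      # ~0.75: confident, trusted
print(confidence(torch.tensor([0.60, 0.30, 0.10])))      # ~0.11: uncertain, return "NONE"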

Outlier Detection

Uses cosine similarity to class centroids to reject out-of-distribution inputs.

For Beginners: If someone asks "What's the weather?" (not TRIP, GITHUB, or MAIL), the agent shouldn't guess. We measure how similar the input is to known categories - if it's too different from anything we trained on, we return "NONE" instead of a wrong guess.
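
A minimal sketch of centroid-based rejection; the three *_centroid tensors (assumed to be the mean DistilBERT embedding of each class's training examples) and the is_outlier helper are hypothetical names.

import torch
import torch.nn.functional as F

# Hypothetical per-class centroids, each a 768-dim embedding
centroids = torch.stack([trip_centroid, github_centroid, mail_centroid])

def is_outlier(state: torch.Tensor, threshold: float = 0.93) -> bool:
    sims = F.cosine_similarity(state.unsqueeze(0), centroids, dim=1)  # similarity to each class
    return bool(sims.max() < threshold)        # too far from every class -> return "NONE"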


Network Architecture

For Beginners: Neural networks are layers of math operations stacked together. Data flows through each layer, getting transformed until we get our final answer. Below are our two networks - one decides what to do (Policy), the other evaluates how good actions are (Q-Network).

Policy Network:

Linear(768→128) → LayerNorm → ReLU → Dropout(0.1) →
Linear(128→128) → LayerNorm → ReLU → Dropout(0.1) →
Linear(128→3) → Softmax

For Beginners: Takes the 768 numbers from DistilBERT, shrinks them to 128, processes them, then outputs 3 probabilities (one for each action). Dropout randomly ignores some neurons during training to prevent overfitting. Softmax ensures probabilities sum to 1.

Q-Network:

Linear(768→128) → LayerNorm → ReLU →
Linear(128→128) → LayerNorm → ReLU →
Linear(128→3)

For Beginners: Similar structure but outputs raw scores (Q-values) for each action instead of probabilities. No Dropout here because we want stable value estimates.
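
Both diagrams translate directly into PyTorch; nn.Sequential is an assumption about how the layers are assembled in the actual code.

import torch.nn as nn

policy_net = nn.Sequential(
    nn.Linear(768, 128), nn.LayerNorm(128), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(128, 128), nn.LayerNorm(128), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(128, 3), nn.Softmax(dim=-1),    # 3 action probabilities
)

q_net = nn.Sequential(
    nn.Linear(768, 128), nn.LayerNorm(128), nn.ReLU(),
    nn.Linear(128, 128), nn.LayerNorm(128), nn.ReLU(),
    nn.Linear(128, 3),                        # raw Q-values, no Softmax or Dropout
)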


Training Loop

For Beginners: Training is how the agent learns. We show it examples, tell it what's right/wrong, and it adjusts its internal numbers (weights) to do better next time.

  1. Encode texts with DistilBERT (frozen)
  2. For each batch:
    • Create positive examples (correct action → reward +1)
    • Create negative examples (wrong action → reward -1)
    • Update Q-network via TD learning
    • Update policy via advantage-weighted loss
    • Soft update target network
  3. Decay epsilon

For Beginners: We freeze DistilBERT (don't change it) because it's already great at understanding text. We only train our smaller RL networks. "Frozen" means we use it as a fixed tool, like using a calculator without modifying it.
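
Step 2's positive and negative examples could be built like this; the make_training_batch helper and the shift-by-offset trick for picking a wrong action are illustrative assumptions, not the actual batching code.

import torch

def make_training_batch(states: torch.Tensor, labels: torch.Tensor, num_actions: int = 3):
    # Positive examples: the correct action, reward +1
    pos_actions, pos_rewards = labels, torch.ones(len(labels))
    # Negative examples: shift the label by 1 or 2 so the action is guaranteed wrong, reward -1
    offsets = torch.randint(1, num_actions, (len(labels),))
    neg_actions, neg_rewards = (labels + offsets) % num_actions, -torch.ones(len(labels))

    batch_states = torch.cat([states, states])
    batch_actions = torch.cat([pos_actions, neg_actions])
    batch_rewards = torch.cat([pos_rewards, neg_rewards])
    return batch_states, batch_actions, batch_rewards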


API Endpoints

For Beginners: APIs let other programs talk to ours over the internet. Think of endpoints as different "phone numbers" - you call the right one depending on what you need.

Endpoint   Method   Description
/health    GET      Check model status
/action    POST     Classify message → action + score

Request:

{"message": "Book a flight to Paris"}

Response:

{"action": "TRIP", "score": 0.87}

For Beginners: Send a message, get back an action and confidence score. The score (0.87 = 87%) tells you how confident the agent is. If confidence is too low, you'll get "NONE" instead.
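
A minimal FastAPI sketch of these two endpoints; the classify() helper that wraps the agent is a hypothetical name, not the actual implementation.

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class ActionRequest(BaseModel):
    message: str

class ActionResponse(BaseModel):
    action: str    # "TRIP", "GITHUB", "MAIL" or "NONE"
    score: float

@app.get("/health")
def health():
    return {"status": "ok"}

@app.post("/action", response_model=ActionResponse)
def action(request: ActionRequest):
    name, score = classify(request.message)   # hypothetical helper wrapping the RL agent
    return ActionResponse(action=name, score=score)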


Key Hyperparameters

For Beginners: Hyperparameters are settings we choose before training - they control how the learning happens. Think of them as recipe ingredients that affect the final result.

Parameter             Value   What it means
Learning rate         1e-3    How big each learning step is (0.001)
Gamma (discount)      0.95    How much future rewards matter vs immediate
Batch size            64      Examples processed together per step
Epochs                50      Times we go through the entire dataset
Confidence threshold  0.6     Below this, return "NONE"
Distance threshold    0.93    Similarity needed to not be an outlier
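
Grouped as a config object these might look like the sketch below; the TrainingConfig dataclass and its field names are hypothetical.

from dataclasses import dataclass

@dataclass
class TrainingConfig:
    learning_rate: float = 1e-3
    gamma: float = 0.95
    batch_size: int = 64
    epochs: int = 50
    confidence_threshold: float = 0.6     # below this, return "NONE"
    distance_threshold: float = 0.93      # minimum cosine similarity to a class centroid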

In General

Our OS assistant app communicates with the RL agent. The app tells the agent what is currently on the screen, and the agent decides which action should be performed. For the demo we introduced three types of actions:

  • TRIP
  • GITHUB
  • MAIL

For Beginners: This is the big picture - imagine an assistant watching your screen. When you type something, it figures out your intent and triggers the right action automatically. It's like having a smart helper that knows whether you want to travel, code, or email just from what you say.