File size: 10,644 Bytes
8d71e21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
089e309
 
 
280c08c
 
 
 
 
 
089e309
280c08c
 
 
089e309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280c08c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e56b47
280c08c
 
 
 
 
 
 
1e56b47
280c08c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""
This script provides an interactive Gradio web application for visualizing token-level attributions in language model predictions using Integrated Gradients. It loads a small causal language model (meta-llama/Llama-3.2-1B, or the public distilgpt2 if the gated model cannot be loaded), computes how each input token contributes to the probability of a specified target token, and generates a color-coded visualization to explain model reasoning.
Features:
- Loads a causal language model and tokenizer (LLaMA-3.2-1B, falling back to distilgpt2).
- Computes Integrated Gradients attributions for a prompt and target token.
- Visualizes token contributions with a grid of colored boxes (green = positive, red = negative).
- Interactive Gradio UI for custom prompts and target tokens.
- Includes a Feynman-style explanation for interpretability concepts.

How to run:
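1. Ensure Python dependencies are installed: torch, transformers, captum, matplotlib, gradio.
   To use the gated LLaMA model, export HUGGINGFACE_HUB_TOKEN with a Hugging Face access token;
   without it (or without access) the script falls back to distilgpt2.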
2. Place this file in your project directory.
3. Run the script from the command line:
    python app.py
4. The app listens on port 7860 on all network interfaces and, because it is launched with share=True, also prints a public gradio.live URL. Open either URL in your browser.
5. Enter a prompt and target token to see the visualization and interpret model predictions.

Notes:
- The script saves the visualization as 'token_attributions.png'.
- For long prompts (>50 tokens), a warning is shown to prevent performance issues.
- Example prompts are provided for quick testing.
"""

import os
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from captum.attr import IntegratedGradients
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import gradio as gr  # Added for interactive UI

device = "cuda" if torch.cuda.is_available() else "cpu"

# Basic logger for helpful messages when loading gated models
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# ---------------- Load model (gated models handled safely) ----------------
# The default model, meta-llama/Llama-3.2-1B, is gated on the Hugging Face Hub.
# An access token is read from HF_TOKEN (the name huggingface_hub uses today) or
# the legacy HUGGINGFACE_HUB_TOKEN. If loading fails for any reason, fall back to
# a small public model so the demo still runs.
requested_model = "meta-llama/Llama-3.2-1B"
fallback_model = "distilgpt2"
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")


def _load_model_and_tokenizer(name, **kwargs):
    """Load the tokenizer and causal-LM weights for *name*, move the model to `device`, set eval mode."""
    tok = AutoTokenizer.from_pretrained(name, **kwargs)
    mdl = AutoModelForCausalLM.from_pretrained(name, **kwargs).to(device)
    mdl.eval()  # inference mode: disables dropout so attributions are repeatable
    return tok, mdl


model_name = requested_model
load_kwargs = {}
if hf_token:
    # `token` replaces the deprecated `use_auth_token` argument of from_pretrained.
    load_kwargs["token"] = hf_token
try:
    tokenizer, model = _load_model_and_tokenizer(model_name, **load_kwargs)
    logger.info("Loaded gated model: %s", model_name)
except Exception as e:  # deliberately broad: auth, network, missing files, OOM all fall back
    logger.warning("Could not load requested model '%s': %s", requested_model, e)
    logger.info("Falling back to public model: %s for demo purposes.", fallback_model)
    model_name = fallback_model
    tokenizer, model = _load_model_and_tokenizer(model_name)

# ---------------- Modularized Functions ----------------
def compute_attributions(prompt, target_token, n_steps=30):
    """
    Compute Integrated Gradients attributions of ``target_token`` w.r.t. the tokens of ``prompt``.

    The attributed quantity is the model's softmax probability of the target
    token at the last prompt position (i.e. the next-token prediction).
    Gradients are taken with respect to the input embeddings, integrated from
    Captum's default all-zeros baseline, and summed over the embedding
    dimension to give one score per prompt token.

    Args:
        prompt: Input text fed to the model.
        target_token: Text whose *first* sub-token id is the prediction target.
            For byte-level BPE tokenizers include the leading space (" Paris").
        n_steps: Number of interpolation steps used by Integrated Gradients.

    Returns:
        (tokens, scores): the prompt's token strings, and a 1-D NumPy array of
        the same length with attributions divided by the largest absolute value
        (so they lie in [-1, 1]).

    Raises:
        ValueError: if ``prompt`` or ``target_token`` tokenizes to nothing.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    if input_ids.shape[-1] == 0:
        raise ValueError("Prompt produced no tokens.")

    target_ids = tokenizer(target_token, add_special_tokens=False)["input_ids"]
    if not target_ids:
        raise ValueError(f"Target token {target_token!r} produced no tokens.")
    if len(target_ids) > 1:
        # Only one next-token probability is attributed; make the truncation visible.
        logger.warning(
            "Target %r spans %d tokens; attributing only the first one (%r).",
            target_token, len(target_ids), tokenizer.convert_ids_to_tokens(target_ids[0]),
        )
    target_id = target_ids[0]

    def forward_func(embeds):
        # Probability of the target token at the final position, one value per batch row.
        logits = model(inputs_embeds=embeds).logits[:, -1, :]
        return torch.softmax(logits, dim=-1)[:, target_id]

    # Detach so the IG input is a plain tensor rather than one still tied to the
    # embedding matrix's autograd graph.
    embeddings = model.get_input_embeddings()(input_ids).detach()
    embeddings.requires_grad_(True)

    ig = IntegratedGradients(forward_func)
    attributions, delta = ig.attribute(
        embeddings, n_steps=n_steps, return_convergence_delta=True
    )
    # A large delta means n_steps is too small for a faithful approximation.
    logger.debug("IG convergence delta: %s", delta.detach().cpu().tolist())

    # Index the single batch row instead of squeeze(): squeeze() collapses a
    # one-token prompt to a 0-d value that callers cannot iterate.
    token_attr_np = attributions.sum(-1)[0].detach().cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    # Scale into [-1, 1]; the epsilon guards against all-zero attributions.
    token_attr_np = token_attr_np / (abs(token_attr_np).max() + 1e-8)

    return tokens, token_attr_np

def create_visualization(tokens, token_attr_np, prompt, target_token,
                         out_path='token_attributions.png'):
    """
    Render per-token attributions as a grid of colored boxes and save it as a PNG.

    Colors use the 'RdYlGn' colormap on a fixed, symmetric [-1, 1] scale (after
    dividing by the largest absolute score): negative scores are red, scores
    near zero are yellow, positive scores are green. Because the scale is
    centred on zero, color reflects the sign of the contribution, matching the
    legend.

    Args:
        tokens: Token strings, in prompt order.
        token_attr_np: NumPy array of attribution scores, one per token.
        prompt: Prompt text (shown in the figure title).
        target_token: Target text (shown in the figure title).
        out_path: File the PNG is written to; defaults to
            'token_attributions.png' in the current working directory.

    Returns:
        ``out_path``.
    """
    from matplotlib.patches import FancyBboxPatch

    num_tokens = max(1, len(tokens))
    cols = min(max(3, int(num_tokens**0.5)), 8)
    rows = (num_tokens + cols - 1) // cols
    box_h = 0.18

    fig_h = max(4, rows * 0.7 + 2.0)  # extra height leaves room for the colorbar and caption
    fig = plt.figure(figsize=(12, fig_h))

    fig.suptitle(f"Token Contributions to Predicting '{target_token}' in: '{prompt}'",
                 fontsize=14, y=0.95, ha='center')

    ax = fig.add_axes([0, 0.30, 1, 0.60])  # token grid in the upper part of the figure
    ax.set_xlim(0, cols)
    ax.set_ylim(0, rows)
    ax.axis('off')

    # Symmetric scale centred on 0 so that sign, not rank, decides red vs green.
    scores = token_attr_np.reshape(-1)
    if scores.size:
        scores = scores / (abs(scores).max() + 1e-8)
    color_norm = plt.Normalize(vmin=-1.0, vmax=1.0)
    cmap = plt.get_cmap('RdYlGn')  # -1 -> red, 0 -> yellow, +1 -> green

    pad = 0.08
    for idx, (tok, score) in enumerate(zip(tokens, scores)):
        r, c = divmod(idx, cols)
        x, y = c, rows - 1 - r  # fill rows from the top down
        rect = FancyBboxPatch((x + pad*0.15, y + pad*0.15), 1 - pad, box_h - pad*0.3,
                              boxstyle='round,pad=0.02', linewidth=0.8,
                              facecolor=cmap(color_norm(score)), edgecolor='gray', alpha=0.95)
        ax.add_patch(rect)
        if isinstance(tok, str):
            # Byte-level BPE markers: 'Ġ' encodes a leading space, 'Ċ' a newline.
            display_tok = tok.replace('Ġ', ' ').replace('Ċ', '\\n')
        else:
            display_tok = str(tok)
        ax.text(x + 0.5, y + box_h/2, display_tok, ha='center', va='center',
                fontsize=10, fontweight='bold')

    # Colorbar shares the token boxes' norm, so its ticks show the signed [-1, 1] scale.
    sm = plt.cm.ScalarMappable(norm=color_norm, cmap=cmap)
    sm.set_array([])
    cax = fig.add_axes([0.1, 0.22, 0.8, 0.04])
    cb = fig.colorbar(sm, cax=cax, orientation='horizontal')
    cb.set_label('Contribution Strength', fontsize=11, fontweight='bold')

    # Legend labels sit under the matching colorbar ends (red on the left, green on the right).
    fig.text(0.1, 0.16, 'Red Negative (hinders prediction)', fontsize=10, ha='left')
    fig.text(0.9, 0.16, 'Green Positive (helps prediction)', fontsize=10, ha='right')

    caption = (
        "How input tokens influence the model's target prediction: Green supports (builds AI trust), "
        "red opposes. For debugging (devs), reasoning insights (ML), reliable decisions (business). Normalized."
    )
    fig.text(0.5, 0.08, caption, fontsize=9, ha='center', va='top', wrap=True)

    fig.savefig(out_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)  # release the figure; pyplot would otherwise keep it alive
    return out_path

# ---------------- Gradio Interface for Interactivity ----------------
def generate_attribution(prompt, target_token):
    """
    Gradio callback: compute attributions for the inputs and return the image path.

    Blank or missing inputs fall back to the default example
    ("The capital of France is" / " Paris").

    Returns:
        Path of the saved PNG, suitable for the ``gr.Image`` output.

    Raises:
        gr.Error: when the prompt is too long or attribution fails. Gradio shows
            the message to the user. Returning a message string here would not
            work, because ``gr.Image`` treats any returned string as a file path.
    """
    if not prompt or not prompt.strip():
        prompt = "The capital of France is"
    if not target_token or not target_token.strip():
        target_token = " Paris"

    # Cost scales with model tokens (not whitespace-separated words), so count those.
    if len(tokenizer(prompt)["input_ids"]) > 50:
        raise gr.Error("Warning: Prompt too long (>50 tokens). Shorten for better performance.")

    try:
        tokens, token_attr_np = compute_attributions(prompt, target_token)
        return create_visualization(tokens, token_attr_np, prompt, target_token)
    except Exception as e:
        logger.exception("Attribution failed")
        raise gr.Error(f"Error: {str(e)}") from e

# Build the Gradio interface; it is launched from the __main__ block.
iface = gr.Interface(
    fn=generate_attribution,
    inputs=[
        gr.Textbox(label="Prompt", value="The capital of France is", placeholder="Enter your prompt..."),
        gr.Textbox(label="Target Token", value=" Paris", placeholder="Enter target token (e.g., ' Paris')"),
    ],
    outputs=gr.Image(label="Token Attribution Visualization"),
    title="AI Interpretability Explorer: See How Tokens Influence Predictions",
    # Name the model that was actually loaded: the gated LLaMA may have fallen back.
    description=(
        "Input a prompt and target token to visualize token contributions using "
        "[Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients) "
        f"on `{model_name}`. Explore model reasoning interactively."
    ),
    # Markdown article that Gradio renders below the interface.
    article="""
### How it works — Feynman-style

This tool explains which input tokens most influence the model's next-token prediction using [Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients).

- What it does: Interpolates from a baseline (an all-zeros embedding) to the actual input in embedding space, accumulates gradients along the path, and attributes importance to each input token.
- Why it helps: Highlights which tokens push the model toward (green) or away from (red) the chosen target token. Useful for debugging, bias detection, and model transparency.
- How to read results: Higher positive values (green) mean the token increases the probability of the target; negative values (red) mean the token reduces it. Values are normalized per example.
- Watch-outs: IG depends on the baseline choice and number of interpolation steps. If the target text spans several tokens, only its first token is attributed. Subword markers (e.g., Ġ) are shown as spaces; long prompts may be noisy.
""",
    examples=[
        ["The capital of France is", " Paris"],
        ["I love this product because", " it's amazing"],
        ["The weather today is", " sunny"],
    ],
)

if __name__ == "__main__":
    # Render the default example once (the script's original behavior), then launch the UI.
    print("Generating default example...")
    try:
        default_img = generate_attribution("", "")
    except Exception as exc:  # report the failure but still launch the app
        default_img = None
        print(f"Default example failed: {exc}")
    if default_img and os.path.isfile(default_img):
        print(f"Default plot saved to: {default_img}")
    elif default_img:
        # generate_attribution may report a failure as a returned message instead of a path.
        print(f"Default example failed: {default_img}")
    print("\nLaunching interactive Gradio app... Open in browser for custom examples.")
    # share=True also creates a public *.gradio.live link; server_name="0.0.0.0"
    # listens on all network interfaces, not just localhost.
    iface.launch(share=True, server_name="0.0.0.0", server_port=7860)