"""
This script provides an interactive Gradio web application for visualizing token-level attributions in language model predictions using Integrated Gradients. It loads a small LLaMA model, computes how each input token contributes to the probability of a specified target token, and generates a color-coded visualization to explain model reasoning.
Features:
- Loads a causal language model and tokenizer (LLaMA-3.2-1B, with a small public fallback if the gated weights are unavailable).
- Computes Integrated Gradients attributions for a prompt and target token.
- Visualizes token contributions with a grid of colored boxes (green = positive, red = negative).
- Interactive Gradio UI for custom prompts and target tokens.
- Includes a Feynman-style explanation for interpretability concepts.
How to run:
1. Ensure Python dependencies are installed: torch, transformers, captum, matplotlib, gradio.
2. Place this file in your project directory.
3. Run the script from the command line:
python app.py
4. The app will launch locally (default port 7860). Open the provided URL in your browser.
5. Enter a prompt and target token to see the visualization and interpret model predictions.
Notes:
- The script saves the visualization as 'token_attributions.png'.
- Prompts longer than roughly 50 words are rejected with an error to prevent performance issues.
- Example prompts are provided for quick testing.
"""
import os
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from captum.attr import IntegratedGradients
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import gradio as gr # Added for interactive UI
device = "cuda" if torch.cuda.is_available() else "cpu"
# Basic logger for helpful messages when loading gated models
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# ---------------- Load model (gated models handled safely) ----------------
# By default we try to load LLaMA-3.2-1B, which is gated on the Hugging Face Hub. We use
# HUGGINGFACE_HUB_TOKEN if it is set, otherwise we fall back to a small public model for the demo.
requested_model = "meta-llama/Llama-3.2-1B"
fallback_model = "distilgpt2"
hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
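# Set HUGGINGFACE_HUB_TOKEN in the environment (or as a Space secret) to authenticate for the gated weights.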
model_name = requested_model
try:
load_kwargs = {}
if hf_token:
load_kwargs["use_auth_token"] = hf_token
tokenizer = AutoTokenizer.from_pretrained(model_name, **load_kwargs)
model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs).to(device)
model.eval()
logger.info(f"Loaded gated model: {model_name}")
except Exception as e:
logger.warning(f"Could not load requested model '{requested_model}': {e}")
logger.info(f"Falling back to public model: {fallback_model} for demo purposes.")
model_name = fallback_model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.eval()
# ---------------- Modularized Functions ----------------
def compute_attributions(prompt, target_token):
"""
    Compute Integrated Gradients attributions for a given prompt and target token.
    Returns the input tokens and normalized per-token scores indicating how strongly each
    token pushes the target token's predicted probability up or down.
"""
inputs = tokenizer(prompt, return_tensors="pt").to(device)
target_id = tokenizer(target_token, add_special_tokens=False)["input_ids"][0]
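    # The target is reduced to its first sub-token ID, so multi-token targets (e.g., " it's amazing")
    # are attributed against their first piece only.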
def forward_func(embeds):
outputs = model(inputs_embeds=embeds)
logits = outputs.logits[:, -1, :]
probs = torch.softmax(logits, dim=-1)
return probs[:, target_id]
embeddings = model.get_input_embeddings()(inputs["input_ids"])
embeddings.requires_grad_(True)
ig = IntegratedGradients(forward_func)
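    # No explicit baseline is passed, so Captum defaults to an all-zeros embedding baseline;
    # n_steps=30 keeps the path-integral approximation reasonably fast.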
attributions, delta = ig.attribute(
embeddings, n_steps=30, return_convergence_delta=True
)
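    # Sum over the embedding dimension to get a single attribution score per input token.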
token_attr = attributions.sum(-1).squeeze().detach().cpu()
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze())
# Normalize safely
token_attr_np = token_attr.numpy()
norm_denom = (abs(token_attr_np).max() + 1e-8)
token_attr_np = token_attr_np / norm_denom
return tokens, token_attr_np
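# Quick standalone usage sketch (assumes the model/tokenizer above loaded successfully):
#   tokens, scores = compute_attributions("The capital of France is", " Paris")
#   for tok, score in zip(tokens, scores):
#       print(f"{tok}: {score:+.2f}")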
def create_visualization(tokens, token_attr_np, prompt, target_token):
"""
    Render token attributions as a grid of colored boxes (green = positive, red = negative),
    with a colorbar, legend, and caption. Saves the figure as a PNG and returns its path.
"""
num_tokens = max(1, len(tokens))
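    # Arrange tokens on a roughly square grid, capped at 8 columns so the boxes stay readable.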
cols = min(max(3, int(num_tokens**0.5)), 8)
rows = (num_tokens + cols - 1) // cols
box_w = 1.0 / cols
box_h = 0.18
fig_h = max(4, rows * 0.7 + 2.0) # Increased height for more spacing
fig = plt.figure(figsize=(12, fig_h))
# Add title for context
fig.suptitle(f"Token Contributions to Predicting '{target_token}' in: '{prompt}'",
fontsize=14, y=0.95, ha='center')
ax = fig.add_axes([0, 0.30, 1, 0.60]) # Shift grid higher for more bottom space
ax.set_xlim(0, cols)
ax.set_ylim(0, rows)
ax.axis('off')
# Normalize for colormap (0-1 range)
minv, maxv = token_attr_np.min(), token_attr_np.max()
norm = (token_attr_np - minv) / (maxv - minv + 1e-8)
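    # Colors are rescaled to this example's min/max, so they show relative importance;
    # zero attribution does not necessarily land on the colormap's yellow midpoint.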
cmap = plt.get_cmap('RdYlGn') # Green positive, red negative
from matplotlib.patches import FancyBboxPatch
for idx, (tok, score_norm) in enumerate(zip(tokens, norm)):
r = idx // cols
c = idx % cols
x = c
y = rows - 1 - r
color = cmap(score_norm)
pad = 0.08
rect = FancyBboxPatch((x + pad*0.15, y + pad*0.15), 1 - pad, box_h - pad*0.3,
boxstyle='round,pad=0.02', linewidth=0.8,
facecolor=color, edgecolor='gray', alpha=0.95) # Softer edges
ax.add_patch(rect)
# Improved text: Larger font, wrap long tokens
display_tok = tok.replace('Ġ', ' ') if isinstance(tok, str) else str(tok) # Space for subwords
ax.text(x + 0.5, y + box_h/2, display_tok, ha='center', va='center',
fontsize=10, fontweight='bold') # Bold for readability
# Enhanced colorbar - lowered position
sm = plt.cm.ScalarMappable(cmap=cmap)
sm.set_array([0, 1])
cax = fig.add_axes([0.1, 0.22, 0.8, 0.04]) # Lowered from 0.18
cb = fig.colorbar(sm, cax=cax, orientation='horizontal')
cb.set_label('Contribution Strength', fontsize=11, fontweight='bold')
# Markers for audience-friendly explanation - lowered
    fig.text(0.05, 0.16, 'Green = positive (supports the prediction)', fontsize=10, ha='left')
    fig.text(0.95, 0.16, 'Red = negative (works against it)', fontsize=10, ha='right')
# Engaging caption for mixed audience - shortened and lowered with wrap
    caption = (
        "How each input token influences the model's prediction of the target token: green tokens support it, "
        "red tokens work against it. Scores are normalized per example. Useful for debugging, for insight into "
        "model reasoning, and for building trust in AI decisions."
    )
fig.text(0.5, 0.08, caption, fontsize=9, ha='center', va='top', wrap=True) # Smaller font, lower pos
# Save with higher quality
out_path = 'token_attributions.png'
fig.savefig(out_path, dpi=300, bbox_inches='tight', facecolor='white')
plt.close(fig) # Clean up
return out_path
# ---------------- Gradio Interface for Interactivity ----------------
def generate_attribution(prompt, target_token):
"""
Gradio wrapper: Compute and visualize for custom inputs.
Default example: France capital for quick demo.
"""
if not prompt.strip():
prompt = "The capital of France is"
if not target_token.strip():
target_token = " Paris"
# Add check for long prompts to prevent overload
if len(prompt.split()) > 50:
return "Warning: Prompt too long (>50 tokens). Shorten for better performance."
try:
tokens, token_attr_np = compute_attributions(prompt, target_token)
img_path = create_visualization(tokens, token_attr_np, prompt, target_token)
return img_path
    except Exception as e:
        raise gr.Error(f"Attribution failed: {e}")
# Launch interactive app
iface = gr.Interface(
fn=generate_attribution,
inputs=[
gr.Textbox(label="Prompt", value="The capital of France is", placeholder="Enter your prompt..."),
gr.Textbox(label="Target Token", value=" Paris", placeholder="Enter target token (e.g., ' Paris')")
],
outputs=gr.Image(label="Token Attribution Visualization"),
title="AI Interpretability Explorer: See How Tokens Influence Predictions",
description="Input a prompt and target token to visualize token contributions using [Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients) on LLaMA. "
"Explore model reasoning interactively.",
    # The article below is plain Markdown: a short Feynman-style explanation of Integrated
    # Gradients and how to read the visualization. Gradio renders it beneath the interface.
article="""
### How it works — Feynman-style
This tool explains which input tokens most influence the model's next-token prediction using [Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients).
- What it does: Interpolates from a baseline to the actual input in embedding space, accumulates gradients along the path, and attributes importance to each input token.
- Why it helps: Highlights which tokens push the model toward (green) or away from (red) the chosen target token. Useful for debugging, bias detection, and model transparency.
- How to read results: Higher positive values (green) mean the token increases the probability of the target; negative values (red) mean the token reduces it. Values are normalized per example.
- Watch-outs: IG depends on the baseline choice and number of interpolation steps. Subword tokens (e.g., Ġ) are shown with spaces; long prompts may be noisy.
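- The formula behind this, for token i with input x, baseline x′ (all zeros here), and target-token probability F: IG_i(x) = (x_i − x′_i) · ∫₀¹ ∂F(x′ + α(x − x′)) / ∂x_i dα, approximated in this app with 30 interpolation steps.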
"""
,
examples=[
["The capital of France is", " Paris"],
["I love this product because", " it's amazing"],
["The weather today is", " sunny"]
]
)
if __name__ == "__main__":
# Run the original example for backward compatibility, then launch Gradio
print("Generating default example...")
default_img = generate_attribution("", "")
print(f"Default plot saved to: token_attributions.png")
print("\nLaunching interactive Gradio app... Open in browser for custom examples.")
iface.launch(share=True, server_name="0.0.0.0", server_port=7860)