Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| WimBERT Synth v0 Gradio Space | |
| Dual-head multi-label classifier for Dutch signal messages | |
| """ | |
| import json | |
| import importlib.util | |
| import torch | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
# Constants
MODEL_REPO = "UWV/wimbert-synth-v0"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision only on GPU; CPU inference stays in fp32.
DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32

print(f"🔧 Loading model from {MODEL_REPO}...")
print(f"🖥️ Device: {DEVICE} ({DTYPE})")

# Download model files (uses the default HF cache location).
model_dir = snapshot_download(MODEL_REPO)

# Dynamic import of model.py from the downloaded snapshot.
# NOTE(review): this executes Python code shipped inside the model repo at
# startup. Consider pinning `revision=` in snapshot_download to a reviewed
# commit so upstream changes cannot silently alter what runs here.
model_py_path = f"{model_dir}/model.py"
spec = importlib.util.spec_from_file_location("model", model_py_path)
if spec is None or spec.loader is None:
    raise ImportError(f"Could not create an import spec for {model_py_path}")
model_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_module)
DualHeadModel = model_module.DualHeadModel

# Load model + tokenizer + config. Below, config is read as a dict with
# "max_length" and "labels" -> {"onderwerp", "beleving"} keys.
model, tokenizer, config = DualHeadModel.from_pretrained(model_dir, device=DEVICE)

# Cast to target dtype
if DTYPE == torch.float16:
    model = model.half()

# Make sure layers such as dropout run in inference mode. from_pretrained may
# already do this; eval() is idempotent, and the isinstance guard keeps this
# safe if the returned object is not a torch Module.
if isinstance(model, torch.nn.Module):
    model.eval()

# Warm-up inference, presumably to absorb one-off startup costs before the
# first real request.
with torch.no_grad():
    dummy_input = tokenizer(
        "Warm-up",
        return_tensors="pt",
        truncation=True,
        max_length=config["max_length"],
    )
    _ = model.predict(
        dummy_input["input_ids"].to(DEVICE),
        dummy_input["attention_mask"].to(DEVICE),
    )
print(f"✅ Model loaded and warmed up (max_length: {config['max_length']})")

# Extract label names. Assumed to be ordered like the model's output columns
# (predict() zips them with the probabilities) — verify against model.py.
LABELS_ONDERWERP = config["labels"]["onderwerp"]
LABELS_BELEVING = config["labels"]["beleving"]
def prob_to_color(prob: float, threshold: float) -> str:
    """Return an inline CSS style string for one probability cell.

    The background is a green HSL colour (hue 145) whose saturation rises
    from 30% to 80% and whose lightness falls from 92% to 37% as *prob* goes
    from 0 to 1 (both truncated to whole percents). Text switches to white
    once prob exceeds 0.6 so it stays readable on the darker shades. A label
    whose prob is at or above *threshold* gets a thicker accent border and a
    soft shadow; the others get a thin grey border and no shadow.
    """
    is_predicted = prob >= threshold
    declarations = [
        f"background: hsl(145, {30 + int(prob * 50)}%, {92 - int(prob * 55)}%)",
        "color: " + ("#ffffff" if prob > 0.6 else "#1f2937"),
        "border: " + ("2px solid #059669" if is_predicted else "1px solid #d1d5db"),
        "box-shadow: " + ("0 1px 3px rgba(5, 150, 105, 0.3)" if is_predicted else "none"),
        "padding: 6px 12px",
        "border-radius: 4px",
        "margin: 2px 0",
        "font-weight: 500",
    ]
    return "; ".join(declarations) + ";"
def format_topk(labels: list, probs: list, threshold: float, topk: int) -> str:
    """Render the *topk* most probable labels as stacked, colour-coded HTML rows.

    Args:
        labels: Label names, parallel to *probs*.
        probs: Per-label probabilities.
        threshold: Labels with prob >= threshold get the "predicted" styling
            and a trailing ✓.
        topk: Maximum number of rows to show. Coerced to int because UI
            sliders may deliver floats (slicing with a float raises
            TypeError); negative values are treated as 0.

    Returns:
        An HTML fragment (a flex column of <div> rows), highest probability
        first; ties keep their original index order.
    """
    limit = max(int(topk), 0)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    rows = []
    for idx in ranked[:limit]:
        label = labels[idx]
        prob = probs[idx]
        style = prob_to_color(prob, threshold)
        predicted = " ✓" if prob >= threshold else ""
        rows.append(f"<div style='{style}'><b>{label}</b>: {prob:.3f}{predicted}</div>")

    return (
        "<div style='display: flex; flex-direction: column; gap: 6px;'>"
        + "".join(rows)
        + "</div>"
    )
def format_all_labels(head_name: str, labels: list, probs: list, threshold: float) -> str:
    """Render every label of one classification head as a scrollable HTML table.

    Rows are ordered by descending probability (ties keep their original
    index order). Each row shows the label name styled with prob_to_color,
    its probability to four decimals, and a ✓ when prob >= threshold. The
    table header is sticky so it stays visible while scrolling the
    500px-high container. *head_name* is used as the <h3> title.
    """
    ranking = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)

    parts = [
        f"<h3>{head_name}</h3><div style='max-height: 500px; overflow-y: auto; border: 1px solid #e5e7eb; border-radius: 4px;'>",
        "<table style='width: 100%; border-collapse: collapse;'>",
        "<thead style='position: sticky; top: 0; background: white; border-bottom: 2px solid #e5e7eb;'>",
        "<tr><th style='text-align: left; padding: 8px;'>Label</th><th style='text-align: right; padding: 8px;'>Probability</th><th style='padding: 8px;'>Predicted</th></tr>",
        "</thead><tbody>",
    ]
    for i in ranking:
        name = labels[i]
        p = probs[i]
        tick = "✓" if p >= threshold else ""
        parts.append(
            f"<tr><td style='{prob_to_color(p, threshold)}'><b>{name}</b></td>"
            f"<td style='text-align: right; padding: 8px;'>{p:.4f}</td>"
            f"<td style='text-align: center; padding: 8px;'>{tick}</td></tr>"
        )
    parts.append("</tbody></table></div>")
    return "".join(parts)
def predict(text: str, threshold: float, topk: int):
    """Classify *text* and build the three UI outputs.

    Args:
        text: Message entered by the user.
        threshold: Probability at or above which a label counts as predicted.
        topk: Number of labels per head shown in the summary tab. Coerced to
            int because Gradio sliders may deliver floats.

    Returns:
        A tuple ``(summary_html, all_labels_html, json_dict)``. For empty or
        whitespace-only input, both HTML outputs are a placeholder message
        and the dict is empty.
    """
    if not text or not text.strip():
        empty_msg = "<p style='color: #666; font-style: italic;'>Voer een bericht in om te classificeren...</p>"
        return empty_msg, empty_msg, {}

    topk = int(topk)
    max_length = config["max_length"]

    # Single sequence: no padding; truncate only if it exceeds the model limit.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    # Number of tokens actually fed to the model (after any truncation).
    actual_length = inputs["attention_mask"].sum().item()

    input_ids = inputs["input_ids"].to(DEVICE)
    attention_mask = inputs["attention_mask"].to(DEVICE)

    # Inference only: disable autograd so no graph or activations are kept
    # per request (matches the warm-up call at startup).
    with torch.no_grad():
        onderwerp_probs, beleving_probs = model.predict(input_ids, attention_mask)

    # Take the first (only) batch row and convert to plain Python floats.
    onderwerp_probs = onderwerp_probs[0].cpu().numpy().tolist()
    beleving_probs = beleving_probs[0].cpu().numpy().tolist()

    # Summary view: top-K for each head, side by side.
    summary_html = (
        "<div style='display: grid; grid-template-columns: 1fr 1fr; gap: 20px;'>"
        f"<div><h3>Onderwerp (Top-{topk})</h3>{format_topk(LABELS_ONDERWERP, onderwerp_probs, threshold, topk)}</div>"
        f"<div><h3>Beleving (Top-{topk})</h3>{format_topk(LABELS_BELEVING, beleving_probs, threshold, topk)}</div>"
        "</div>"
    )

    # Full view: every label of both heads.
    all_labels_html = (
        "<div style='display: grid; grid-template-columns: 1fr 1fr; gap: 20px;'>"
        f"<div>{format_all_labels('Onderwerp', LABELS_ONDERWERP, onderwerp_probs, threshold)}</div>"
        f"<div>{format_all_labels('Beleving', LABELS_BELEVING, beleving_probs, threshold)}</div>"
        "</div>"
    )

    # Machine-readable output for the JSON tab.
    json_output = {
        "text": text,
        "token_count": actual_length,
        "max_length": max_length,
        "threshold": threshold,
        "onderwerp": {
            "probabilities": {label: float(prob) for label, prob in zip(LABELS_ONDERWERP, onderwerp_probs)},
            "predicted": [label for label, prob in zip(LABELS_ONDERWERP, onderwerp_probs) if prob >= threshold],
        },
        "beleving": {
            "probabilities": {label: float(prob) for label, prob in zip(LABELS_BELEVING, beleving_probs)},
            "predicted": [label for label, prob in zip(LABELS_BELEVING, beleving_probs) if prob >= threshold],
        },
    }
    return summary_html, all_labels_html, json_output
def count_tokens(text: str) -> str:
    """Return an HTML badge showing how many tokens *text* uses.

    Used for live feedback while typing. The text is tokenized WITHOUT
    truncation so inputs longer than config["max_length"] are detected and
    flagged; predict() truncates such inputs to that limit.

    Colour coding: green up to 80% of the limit, orange above 80%, and red
    with a "(truncated)" warning above the limit.
    """
    max_length = config["max_length"]
    if not text or not text.strip():
        # Same markup as the textbox's initial badge; limit taken from config
        # instead of a hard-coded value.
        return f"<span style='color: #6b7280; font-size: 0.875rem;'>📏 Tokens: 0 / {max_length}</span>"

    # Quick tokenization just for counting. truncation=False is essential:
    # with truncation the count is capped at max_length and the "truncated"
    # branch below could never trigger. (Transformers may log a one-time
    # "sequence length is longer than..." warning for long inputs.)
    encoded = tokenizer(text, truncation=False)
    n_tokens = len(encoded["input_ids"])

    if n_tokens > max_length:
        color = "#dc2626"  # Red: will be truncated by predict()
        warning = " ⚠️ (truncated)"
    elif n_tokens > max_length * 0.8:
        color = "#f59e0b"  # Orange: getting long
        warning = ""
    else:
        color = "#059669"  # Green: all good
        warning = ""
    return f"<span style='color: {color}; font-size: 0.875rem; font-weight: 500;'>📏 Tokens: {n_tokens} / {max_length}{warning}</span>"
def load_examples(path: str = "examples.json") -> list:
    """Load example messages for the gr.Examples gallery.

    Args:
        path: JSON file to read. Defaults to "examples.json", resolved
            relative to the current working directory.

    Returns:
        The parsed JSON content, or ``[]`` when the file is missing or
        unreadable, or is not valid UTF-8 JSON. The gallery is optional, so
        these errors are reported but do not stop the app.
    """
    try:
        # Explicit UTF-8: examples are Dutch text, and the platform default
        # encoding is not guaranteed to be UTF-8.
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError) as err:
        # Stays best-effort, but unlike the former bare `except:` this no
        # longer swallows KeyboardInterrupt/SystemExit or programming errors.
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"⚠️ Could not load examples from {path}: {err}")
        return []
# Build Gradio interface
with gr.Blocks(title="WimBERT Synth v0", theme=gr.themes.Soft()) as demo:
    # Category counts come from the loaded config rather than being
    # hard-coded, so the header cannot drift from the actual label sets.
    gr.Markdown(f"""
    # 🏛️ WimBERT Synth v0: Multi-label Signaal Classifier
    Classificeert Nederlandse signaalberichten op **Onderwerp** ({len(LABELS_ONDERWERP)} categorieën) en **Beleving** ({len(LABELS_BELEVING)} categorieën).
    """)

    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="Signaalbericht (Nederlands)",
                lines=8,
                placeholder="Bijv: Ik kan niet parkeren bij mijn huis en de website voor vergunningen werkt niet...",
            )
            # Initial badge shows the configured limit (was hard-coded to 1408).
            token_counter = gr.HTML(
                value=f"<span style='color: #6b7280; font-size: 0.875rem;'>📏 Tokens: 0 / {config['max_length']}</span>"
            )
            with gr.Row():
                predict_btn = gr.Button("🔮 Voorspel", variant="primary", scale=2)
                clear_btn = gr.ClearButton([input_text], value="🗑️ Wissen", scale=1)
        with gr.Column(scale=1):
            threshold_slider = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.5,
                step=0.05,
                label="🎯 Drempel",
                info="Labels boven deze waarde worden als 'voorspeld' gemarkeerd",
            )
            topk_slider = gr.Slider(
                minimum=1,
                maximum=15,
                value=5,
                step=1,
                label="📊 Top-K",
                info="Aantal top labels om te tonen in samenvatting",
            )
            gr.Markdown(f"""
            **Hardware:** {DEVICE.type.upper()}
            **Dtype:** {DTYPE}
            **Max length:** {config['max_length']}
            """)

    with gr.Tabs():
        with gr.Tab("📋 Samenvatting"):
            summary_output = gr.HTML(label="Top voorspellingen per categorie")
        with gr.Tab("📊 Alle labels"):
            all_labels_output = gr.HTML(label="Volledige classificatie")
        with gr.Tab("💾 JSON"):
            json_output = gr.JSON(label="Ruwe output")

    gr.Examples(
        examples=load_examples(),
        inputs=input_text,
        label="📝 Voorbeelden",
    )

    gr.Markdown("""
    ---
    ### ℹ️ Over dit model
    - **Model:** `UWV/wimbert-synth-v0` (dual-head BERT)
    - **Licentie:** Apache-2.0
    - **Privacy:** Input wordt alleen in-memory verwerkt, niet opgeslagen
    [Model Card](https://huggingface.co/UWV/wimbert-synth-v0) • Gebouwd met Gradio
    """)

    # Event handlers
    # Live token counting as the user types.
    input_text.change(
        fn=count_tokens,
        inputs=input_text,
        outputs=token_counter,
    )

    # Run prediction on button click, and re-run it whenever the threshold or
    # top-K slider changes. This fires even without prior output; predict()
    # then returns its placeholder for an empty textbox.
    prediction_io = {
        "inputs": [input_text, threshold_slider, topk_slider],
        "outputs": [summary_output, all_labels_output, json_output],
    }
    for trigger in (predict_btn.click, threshold_slider.change, topk_slider.change):
        trigger(fn=predict, **prediction_io)

if __name__ == "__main__":
    demo.launch()