import re
import unicodedata
from functools import lru_cache

import gradio as gr
import numpy as np
import pandas as pd
import requests
import spaces  # ZeroGPU helper; HF docs recommend importing it before torch
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# ===== Defaults (you can change from the UI too) =====
DEFAULT_MODEL_A_ID = "livekit/turn-detector"
DEFAULT_MODEL_A_REV = "v0.3.0-intl"
DEFAULT_MODEL_B_ID = "livekit/turn-detector"
DEFAULT_MODEL_B_REV = "v0.4.1-intl"  # adjust if there's a specific revision

# ===== Utilities =====
def normalize_text(text: str) -> str:
    """Lowercase, strip punctuation (keeping apostrophes and hyphens), collapse whitespace."""
    text = unicodedata.normalize("NFKC", text.lower())
    text = "".join(
        ch for ch in text
        if not (unicodedata.category(ch).startswith("P") and ch not in ["'", "-"])
    )
    text = re.sub(r"\s+", " ", text).strip()
    return text
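
# Quick sanity checks for the normalizer (hypothetical inputs, worked by hand):
#   normalize_text("Hello,   World!")  -> "hello world"
#   normalize_text("it's A-OK.")       -> "it's a-ok"
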
def log_odds(p, eps=0.0):
    """Map a probability to log-odds; `eps` pads the denominator so p == 1 stays finite."""
    return np.log(p / (1 - p + eps))
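
# Worked values (computed by hand): log_odds(0.5) == 0.0, and for the default
# threshold below, log_odds(0.0049) ≈ -5.31; subtracting log_odds(threshold)
# therefore centers per-token scores at zero right at the decision boundary.
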
def get_threshold(model_id: str, rev_id: str | None) -> float:
    """Fetch the English EoT threshold shipped with the model revision, with a safe fallback."""
    DEFAULT_THRESH = 0.0049
    url = f"https://huggingface.co/{model_id}/resolve/{rev_id or 'main'}/languages.json"
    try:
        config = requests.get(url, timeout=10).json().get("en") or {}
    except Exception as e:
        print(f"Error loading languages.json:\n{e}")
        config = {}
    return config.get("threshold", DEFAULT_THRESH)
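
# Assumed shape of languages.json (based on the livekit/turn-detector repo):
#   {"en": {"threshold": 0.0049, ...}, "fr": {...}, ...}
# If the file or its "en" entry is missing, DEFAULT_THRESH is used instead.
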
# ===== Per-model runner (keeps tokenizer/model and token ids) =====
class ModelRunner:
    def __init__(self, model_id: str, revision: str | None = None, dtype=torch.bfloat16):
        self.model_id = model_id
        self.revision = revision
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto",
        )
        self.model.eval()
        self.thresh = get_threshold(model_id, revision)

        # Pull commonly used tokens, falling back gracefully if not present
        self.START_TOKEN_ID = self._tok_id("<|im_start|>")
        self.EOU_TOKEN_ID = self._tok_id("<|im_end|>")
        self.NEWLINE_TOKEN_ID = self._tok_id("\n")

        # Common role tokens; include both legacy and chat-template variants
        self.USER_TOKEN_IDS = tuple(
            tid for tid in [
                self._tok_id("user"),
                self._tok_id("<|user|>"),
            ] if tid is not None
        )

        # Tokens we do not want to score on (specials / scaffolding)
        self.SPECIAL_TOKENS = set(
            tid for tid in [
                self.NEWLINE_TOKEN_ID,
                self.START_TOKEN_ID,
                self._tok_id("user"),
                self._tok_id("assistant"),
            ] if tid is not None
        )

        # For filtering in display
        self.CONTROL_TOKS = {"<|im_start|>", "<|im_end|>", "user", "assistant", "\n"}
    def _tok_id(self, tok: str) -> int | None:
        tid = self.tokenizer.convert_tokens_to_ids(tok)
        # Depending on the tokenizer, an unknown token comes back as None, a
        # negative id, or the unk token id; treat all of those as "not available".
        if tid is None or tid < 0 or tid == self.tokenizer.unk_token_id:
            return None
        return tid
    def format_input(self, text: str) -> str:
        # If not a chat-formatted string, wrap as a single user message via chat template
        if "<|im_start|>" not in text:
            msg = {"role": "user", "content": normalize_text(text)}
            text = self.tokenizer.apply_chat_template(
                [msg],
                tokenize=False,
                add_generation_prompt=True,
            )
        return text
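
    # With a ChatML-style template (which the <|im_start|> markers suggest these
    # tokenizers use), format_input("Hello, World!") should render roughly as:
    #   <|im_start|>user
    #   hello world<|im_end|>
    #   <|im_start|>assistant
    # The exact scaffolding depends on the tokenizer's chat template.
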
    def make_pred_mask(self, input_ids: np.ndarray) -> np.ndarray:
        """Return boolean mask: True where we should compute EoT prob (user tokens only)."""
        if self.START_TOKEN_ID is None or not self.USER_TOKEN_IDS:
            # Fallback: score all non-special tokens if start/user tokens are unavailable
            return np.array([tok not in self.SPECIAL_TOKENS for tok in input_ids], dtype=bool)
        user_mask = [False] * len(input_ids)
        is_user_role = False
        for i, tok in enumerate(input_ids):
            if tok == self.START_TOKEN_ID and i + 1 < len(input_ids):
                is_user_role = input_ids[i + 1] in self.USER_TOKEN_IDS
                user_mask[i] = False
                continue
            user_mask[i] = is_user_role and (tok not in self.SPECIAL_TOKENS)
        return np.array(user_mask, dtype=bool)
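
    # Sketch of the masking on a toy ChatML sequence (newline tokens elided):
    #   <|im_start|>  user   five  five  <|im_end|>  <|im_start|>  assistant
    #   False         False  True  True  True        False         False
    # Only tokens inside user turns get scored; the <|im_end|> of a user turn
    # stays True here and is later dropped from display via CONTROL_TOKS.
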
    def predict_eou(self, text: str) -> pd.DataFrame:
        text = self.format_input(text)
        inputs = self.tokenizer.encode(
            text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(self.model.device)
        with torch.no_grad(), torch.amp.autocast(self.model.device.type):
            outputs = self.model(inputs)
            # Probs over vocab at each position; then take the probability of the EOU token
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        if self.EOU_TOKEN_ID is None:
            # If the model/tokenizer doesn't have <|im_end|>, use newline as a proxy (last resort)
            fallback_id = self.NEWLINE_TOKEN_ID if self.NEWLINE_TOKEN_ID is not None else 0
            eou_probs = probs[..., fallback_id]
        else:
            eou_probs = probs[..., self.EOU_TOKEN_ID]
        eou_probs = eou_probs.squeeze(0).float().cpu().numpy()
        input_ids = inputs.squeeze(0).int().cpu().numpy()
        mask = self.make_pred_mask(input_ids)
        # Set positions outside the mask to NaN (not scored)
        eou_probs_masked = eou_probs.copy()
        eou_probs_masked[~mask] = np.nan
        tokens = [self.tokenizer.decode(i) for i in input_ids]
        return pd.DataFrame({"token": tokens, "pred": eou_probs_masked})
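
    # Shape of the returned frame, one row per token (values below are made up):
    #       token         pred
    #       <|im_start|>  NaN
    #       five          0.0012
    #       three         0.4175
    # NaN marks positions that make_pred_mask excluded from scoring.
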
    def make_styled_df(self, df: pd.DataFrame, cmap="coolwarm") -> str:
        EPS = 1e-12
        thresh = self.thresh
        _df = df[~df.token.isin(self.CONTROL_TOKS)].copy()
        _df.token = _df.token.replace({"\n": "⏎", " ": "␠"})
        _df["log_odds"] = (
            _df.pred.fillna(thresh)
            .add(EPS)
            .apply(log_odds)
            .sub(log_odds(thresh))
            .mask(_df.pred.isna())
        )
        _df["Prob(EoT) as %"] = _df.pred.mul(100).fillna(0).astype(int)
        vmin, vmax = _df.log_odds.min(), _df.log_odds.max()
        vmax_abs = max(abs(vmin), abs(vmax)) * 1.5 if pd.notna(vmin) and pd.notna(vmax) else 1.0
        fmt = (
            _df.drop(columns=["pred"])
            .style
            .bar(
                subset=["log_odds"],
                align="zero",
                vmin=-vmax_abs,
                vmax=vmax_abs,
                cmap=cmap,
                height=70,
                width=100,
            )
            .text_gradient(subset=["log_odds"], cmap=cmap, vmin=-vmax_abs, vmax=vmax_abs)
            .format(na_rep="", precision=1, subset=["log_odds"])
            .format("{:3d}", subset=["Prob(EoT) as %"])
            .hide(axis="index")
        )
        return fmt.to_html()
    def generate_highlighted_text(self, text: str):
        """Return (highlighted_list, styled_html) for Gradio."""
        eps = 1e-12
        threshold = self.thresh
        if not text:
            return [], "<div>No input.</div>"
        df = self.predict_eou(text)
        df.token = df.token.replace({"user": "\nUSER:", "assistant": "\nAGENT:"})
        df = df[~df.token.isin(self.CONTROL_TOKS)].copy()
        df["score"] = (
            df.pred.fillna(threshold)
            .add(eps)
            .apply(log_odds)
            .sub(log_odds(threshold))
            .mask(df.pred.isna() | df.pred.round(2).eq(0))
        )
        max_abs_score = df["score"].abs().max()
        if pd.notna(max_abs_score) and max_abs_score > 0:
            df.score = df.score / (max_abs_score * 1.5)
        styled_df = self.make_styled_df(df[["token", "pred"]])
        return list(zip(df.token, df.score)), styled_df
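
    # Note on the 1.5x factor above: dividing by 1.5 * max|score| keeps every
    # score within ±2/3, presumably so Gradio's continuous highlight colors
    # never reach full saturation and the text stays readable.
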

# ===== Cached loaders so switching models in the UI is fast =====
@lru_cache(maxsize=4)
def get_runner(model_id: str, revision: str | None) -> ModelRunner:
    return ModelRunner(model_id, revision)
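
# lru_cache keys on the (model_id, revision) pair, so each checkpoint loads
# once per process and switching between already-seen models is instant.
# maxsize=4 is an arbitrary cap chosen here to bound resident model memory.
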
# ===== Gradio App =====
@spaces.GPU  # assumed ZeroGPU entry point: on Spaces "running on Zero", GPU work belongs in a decorated function
def compare_models(
    text: str,
    model_a_id: str,
    model_a_rev: str,
    model_b_id: str,
    model_b_rev: str,
):
    runner_a = get_runner(model_a_id, model_a_rev or None)
    runner_b = get_runner(model_b_id, model_b_rev or None)
    ht_a, html_a = runner_a.generate_highlighted_text(text)
    ht_b, html_b = runner_b.generate_highlighted_text(text)
    # Prepend small headers indicating the model names in the HTML blocks
    html_a = f"<h4 style='margin:0 0 8px 0'>{model_a_id}@{model_a_rev or 'default'}</h4>" + html_a
    html_b = f"<h4 style='margin:0 0 8px 0'>{model_b_id}@{model_b_rev or 'default'}</h4>" + html_b
    return ht_a, html_a, ht_b, html_b

EXAMPLE_CONVO = """<|im_start|>assistant
what is your phone number<|im_end|>
<|im_start|>user
five five five four one zero zero four two three<|im_end|>"""

with gr.Blocks(theme="soft", title="Turn Detector Debugger — Side by Side") as demo:
    gr.Markdown(
        """# Turn Detector Debugger — Side by Side
Visualize predicted turn endings from **two models**.
Red ⇒ agent should reply • Blue ⇒ agent should wait"""
    )
    with gr.Row():
        text_in = gr.Textbox(
            label="Input Text",
            info="Input should follow the chat template shown in the example. Transcripts should be normalized to lowercase with punctuation removed.",
            value=EXAMPLE_CONVO,
            lines=4,
        )
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Model A")
            model_a_id = gr.Textbox(value=DEFAULT_MODEL_A_ID, label="Model ID")
            model_a_rev = gr.Textbox(value=DEFAULT_MODEL_A_REV, label="Revision (optional)")
        with gr.Column():
            gr.Markdown("### Model B")
            model_b_id = gr.Textbox(value=DEFAULT_MODEL_B_ID, label="Model ID")
            model_b_rev = gr.Textbox(value=DEFAULT_MODEL_B_REV, label="Revision (optional)")
    run_btn = gr.Button("Run Comparison", variant="primary")
    with gr.Row():
        with gr.Column():
            out_ht_a = gr.HighlightedText(
                label="EoT Predictions (Model A)",
                color_map="coolwarm",
                scale=1.5,
            )
            out_html_a = gr.HTML(label="Raw scores (Model A)")
        with gr.Column():
            out_ht_b = gr.HighlightedText(
                label="EoT Predictions (Model B)",
                color_map="coolwarm",
                scale=1.5,
            )
            out_html_b = gr.HTML(label="Raw scores (Model B)")
    run_btn.click(
        fn=compare_models,
        inputs=[text_in, model_a_id, model_a_rev, model_b_id, model_b_rev],
        outputs=[out_ht_a, out_html_a, out_ht_b, out_html_b],
    )

demo.launch()  # share=True is unnecessary on Hugging Face Spaces (Gradio warns and ignores it)