import re
import unicodedata
from functools import lru_cache

import numpy as np
import pandas as pd
import requests
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import gradio as gr
import spaces

# ===== Defaults (you can change them from the UI too) =====
DEFAULT_MODEL_A_ID = "livekit/turn-detector"
DEFAULT_MODEL_A_REV = "v0.3.0-intl"
DEFAULT_MODEL_B_ID = "livekit/turn-detector"
DEFAULT_MODEL_B_REV = "v0.4.1-intl"  # adjust if there's a specific revision


# ===== Utilities =====
def normalize_text(text: str) -> str:
    """Lowercase, NFKC-normalize, drop punctuation (except ' and -), collapse whitespace."""
    text = unicodedata.normalize("NFKC", text.lower())
    text = "".join(
        ch
        for ch in text
        if not (unicodedata.category(ch).startswith("P") and ch not in ["'", "-"])
    )
    text = re.sub(r"\s+", " ", text).strip()
    return text


def log_odds(p, eps=1e-12):
    # eps guards against division by zero when p == 1.0
    return np.log(p / (1 - p + eps))


def get_threshold(model_id: str, rev_id: str | None) -> float:
    DEFAULT_THRESH = 0.0049
    url = f"https://huggingface.co/{model_id}/resolve/{rev_id or 'main'}/languages.json"
    try:
        config = requests.get(url, timeout=10).json().get("en") or {}
    except Exception as e:
        print(f"Error loading languages.json:\n{e}")
        config = {}
    return config.get("threshold", DEFAULT_THRESH)


# ===== Per-model runner (keeps tokenizer/model and token ids) =====
class ModelRunner:
    def __init__(self, model_id: str, revision: str | None = None, dtype=torch.bfloat16):
        self.model_id = model_id
        self.revision = revision
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto",
        )
        self.model.eval()
        self.thresh = get_threshold(model_id, revision)

        # Pull commonly used tokens, falling back gracefully if not present
        self.START_TOKEN_ID = self._tok_id("<|im_start|>")
        self.EOU_TOKEN_ID = self._tok_id("<|im_end|>")
        self.NEWLINE_TOKEN_ID = self._tok_id("\n")

        # Common role tokens; include both legacy and chat-template variants
        self.USER_TOKEN_IDS = tuple(
            tid
            for tid in [self._tok_id("user"), self._tok_id("<|user|>")]
            if tid is not None
        )

        # Tokens we do not want to score on (specials / scaffolding)
        self.SPECIAL_TOKENS = {
            tid
            for tid in [
                self.NEWLINE_TOKEN_ID,
                self.START_TOKEN_ID,
                self._tok_id("user"),
                self._tok_id("assistant"),
            ]
            if tid is not None
        }

        # For filtering in display
        self.CONTROL_TOKS = {"<|im_start|>", "<|im_end|>", "user", "assistant", "\n"}

    def _tok_id(self, tok: str) -> int | None:
        tid = self.tokenizer.convert_tokens_to_ids(tok)
        # Depending on the tokenizer, unknown tokens come back as None,
        # a negative id, or the unk id
        if tid is None or tid < 0 or tid == self.tokenizer.unk_token_id:
            return None
        return tid

    def format_input(self, text: str) -> str:
        # If not already chat-formatted, wrap as a single user message via the chat template
        if "<|im_start|>" not in text:
            msg = {"role": "user", "content": normalize_text(text)}
            text = self.tokenizer.apply_chat_template(
                [msg], tokenize=False, add_generation_prompt=True
            )
        return text

    def make_pred_mask(self, input_ids: np.ndarray) -> np.ndarray:
        """Return boolean mask: True where we should compute EoT prob (user tokens only)."""
        if self.START_TOKEN_ID is None or not self.USER_TOKEN_IDS:
            # Fallback: score all non-special tokens if start/user tokens are unavailable
            return np.array(
                [tok not in self.SPECIAL_TOKENS for tok in input_ids], dtype=bool
            )
        user_mask = [False] * len(input_ids)
        is_user_role = False
        for i, tok in enumerate(input_ids):
            if tok == self.START_TOKEN_ID and i + 1 < len(input_ids):
                # The token right after <|im_start|> names the role for this turn
                is_user_role = input_ids[i + 1] in self.USER_TOKEN_IDS
                user_mask[i] = False
                continue
            user_mask[i] = is_user_role and (tok not in self.SPECIAL_TOKENS)
        return np.array(user_mask, dtype=bool)
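
    # How the mask plays out on a chat-formatted input (a sketch; exact ids
    # depend on the tokenizer): for
    #   <|im_start|>user ... <|im_end|><|im_start|>assistant ...
    # the scaffolding (<|im_start|>, role names, newlines) stays False, the
    # user's content tokens are True, and everything flips back to False once
    # an assistant turn begins, so EoT is only scored where it matters.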

    @torch.no_grad()
    def predict_eou(self, text: str) -> pd.DataFrame:
        text = self.format_input(text)
        with torch.amp.autocast(self.model.device.type):
            inputs = self.tokenizer.encode(
                text, add_special_tokens=False, return_tensors="pt"
            ).to(self.model.device)
            outputs = self.model(inputs)

        # Probs over the vocab at each position; keep the probability of the EOU token
        logits = outputs.logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
        if self.EOU_TOKEN_ID is None:
            # If the model/tokenizer doesn't have <|im_end|>, use newline as a proxy (last resort)
            fallback_id = self.NEWLINE_TOKEN_ID if self.NEWLINE_TOKEN_ID is not None else 0
            eou_probs = probs[..., fallback_id]
        else:
            eou_probs = probs[..., self.EOU_TOKEN_ID]
        eou_probs = eou_probs.squeeze(0).float().cpu().numpy()

        input_ids = inputs.squeeze(0).int().cpu().numpy()
        mask = self.make_pred_mask(input_ids)
        # Set masked positions to NaN (not scored)
        eou_probs_masked = eou_probs.copy()
        eou_probs_masked[~mask] = np.nan

        tokens = [self.tokenizer.decode(i) for i in input_ids]
        return pd.DataFrame({"token": tokens, "pred": eou_probs_masked})

    def make_styled_df(self, df: pd.DataFrame, cmap="coolwarm") -> str:
        EPS = 1e-12
        thresh = self.thresh
        _df = df.copy()
        _df = _df[~_df.token.isin(self.CONTROL_TOKS)]
        _df.token = _df.token.replace({"\n": "⏎", " ": "␠"})
        # Log-odds centered on the decision threshold: 0 means "exactly at threshold"
        _df["log_odds"] = (
            _df.pred.fillna(thresh)
            .add(EPS)
            .apply(log_odds)
            .sub(log_odds(thresh))
            .mask(_df.pred.isna())
        )
        _df["Prob(EoT) as %"] = _df.pred.mul(100).fillna(0).astype(int)

        vmin, vmax = _df.log_odds.min(), _df.log_odds.max()
        vmax_abs = (
            max(abs(vmin), abs(vmax)) * 1.5 if pd.notna(vmin) and pd.notna(vmax) else 1.0
        )
        fmt = (
            _df.drop(columns=["pred"])
            .style.bar(
                subset=["log_odds"],
                align="zero",
                vmin=-vmax_abs,
                vmax=vmax_abs,
                cmap=cmap,
                height=70,
                width=100,
            )
            .text_gradient(subset=["log_odds"], cmap=cmap, vmin=-vmax_abs, vmax=vmax_abs)
            .format(na_rep="", precision=1, subset=["log_odds"])
            .format("{:3d}", subset=["Prob(EoT) as %"])
            .hide(axis="index")
        )
        return fmt.to_html()
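
    # Worked example of the threshold-centering used above (numbers are
    # illustrative, assuming the default threshold of 0.0049):
    #   log_odds(0.0049) ≈ -5.31, so
    #   pred = 0.50   →  0.00 - (-5.31) ≈ +5.31  (strong "end of turn", red)
    #   pred = 0.0049 → -5.31 - (-5.31) =  0.00  (exactly at the threshold)
    #   pred = 0.0005 → -7.60 - (-5.31) ≈ -2.29  (keep listening, blue)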

    def generate_highlighted_text(self, text: str):
        """Return (highlighted_list, styled_html) for Gradio."""
        eps = 1e-12
        threshold = self.thresh
        if not text:
            return [], "<p>No input.</p>"
        df = self.predict_eou(text)
        df.token = df.token.replace({"user": "\nUSER:", "assistant": "\nAGENT:"})
        df = df[~df.token.isin(self.CONTROL_TOKS)]
        df["score"] = (
            df.pred.fillna(threshold)
            .add(eps)
            .apply(log_odds)
            .sub(log_odds(threshold))
            .mask(df.pred.isna() | df.pred.round(2).eq(0))
        )
        # Normalize scores into roughly [-0.67, 0.67] so the highlight colors stay soft
        max_abs_score = df["score"].abs().max()
        if pd.notna(max_abs_score) and max_abs_score > 0:
            df.score = df.score / (max_abs_score * 1.5)
        styled_df = self.make_styled_df(df[["token", "pred"]])
        return list(zip(df.token, df.score)), styled_df


# ===== Cached loaders so switching models in the UI is fast =====
@lru_cache(maxsize=4)
def get_runner(model_id: str, revision: str | None):
    return ModelRunner(model_id, revision)

# ===== Gradio App =====
@spaces.GPU
def compare_models(
    text: str,
    model_a_id: str,
    model_a_rev: str,
    model_b_id: str,
    model_b_rev: str,
):
    runner_a = get_runner(model_a_id, model_a_rev if model_a_rev else None)
    runner_b = get_runner(model_b_id, model_b_rev if model_b_rev else None)
    ht_a, html_a = runner_a.generate_highlighted_text(text)
    ht_b, html_b = runner_b.generate_highlighted_text(text)
    # Prepend a small header naming each model above its HTML score table
    html_a = f"<h4>{model_a_id}@{model_a_rev or 'default'}</h4>" + html_a
    html_b = f"<h4>{model_b_id}@{model_b_rev or 'default'}</h4>" + html_b
    return ht_a, html_a, ht_b, html_b

EXAMPLE_CONVO = """<|im_start|>assistant
what is your phone number<|im_end|>
<|im_start|>user
five five five four one zero zero four two three<|im_end|>"""

with gr.Blocks(theme="soft", title="Turn Detector Debugger — Side by Side") as demo:
    gr.Markdown(
        """# Turn Detector Debugger — Side by Side
Visualize predicted turn endings from **two models**.
Red ⇒ agent should reply • Blue ⇒ agent should wait"""
    )
    with gr.Row():
        text_in = gr.Textbox(
            label="Input Text",
            info=(
                "Input should follow the chat template shown in the example. "
                "Transcripts should be normalized to lowercase, without punctuation."
            ),
            value=EXAMPLE_CONVO,
            lines=4,
        )
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Model A")
            model_a_id = gr.Textbox(value=DEFAULT_MODEL_A_ID, label="Model ID")
            model_a_rev = gr.Textbox(value=DEFAULT_MODEL_A_REV, label="Revision (optional)")
        with gr.Column():
            gr.Markdown("### Model B")
            model_b_id = gr.Textbox(value=DEFAULT_MODEL_B_ID, label="Model ID")
            model_b_rev = gr.Textbox(value=DEFAULT_MODEL_B_REV, label="Revision (optional)")

    run_btn = gr.Button("Run Comparison", variant="primary")

    with gr.Row():
        # HighlightedText renders numeric (token, score) pairs with its built-in
        # red/blue gradient; color_map only applies to categorical (dict) labels
        with gr.Column():
            out_ht_a = gr.HighlightedText(label="EoT Predictions (Model A)")
            out_html_a = gr.HTML(label="Raw scores (Model A)")
        with gr.Column():
            out_ht_b = gr.HighlightedText(label="EoT Predictions (Model B)")
            out_html_b = gr.HTML(label="Raw scores (Model B)")

    run_btn.click(
        fn=compare_models,
        inputs=[text_in, model_a_id, model_a_rev, model_b_id, model_b_rev],
        outputs=[out_ht_a, out_html_a, out_ht_b, out_html_b],
    )

demo.launch(share=True)