David Zhao
set b model to 0.4.1-intl
ebe0b0a
import os
import re
import unicodedata
from functools import lru_cache
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
# ===== Default checkpoints (each one can also be edited in the UI) =====
# Side A: the baseline revision.
DEFAULT_MODEL_A_ID: str = "livekit/turn-detector"
DEFAULT_MODEL_A_REV: str = "v0.3.0-intl"
# Side B: the candidate revision being compared against A.
DEFAULT_MODEL_B_ID: str = "livekit/turn-detector"
DEFAULT_MODEL_B_REV: str = "v0.4.1-intl"
# ===== Utilities =====
def normalize_text(text: str) -> str:
    """Canonicalize a transcript before it is fed to the turn detector.

    The text is lower-cased and then NFKC-normalized. Every Unicode
    punctuation character (category ``P*``) is dropped except the apostrophe
    and the hyphen-minus. Runs of whitespace are collapsed to a single space
    and the ends are trimmed.
    """
    kept_punctuation = {"'", "-"}
    folded = unicodedata.normalize("NFKC", text.lower())

    pieces = []
    for char in folded:
        is_punctuation = unicodedata.category(char).startswith("P")
        if char in kept_punctuation or not is_punctuation:
            pieces.append(char)

    collapsed = re.sub(r"\s+", " ", "".join(pieces))
    return collapsed.strip()
def log_odds(p, eps=0.0):
    """Return the natural-log odds ``log(p / (1 - p + eps))``.

    The expression is elementwise, so ``p`` may be a scalar, a NumPy array
    or a pandas Series. ``eps`` is added to the denominator only.
    """
    denominator = 1 - p + eps
    return np.log(p / denominator)
def get_threshold(
    rev_id,
    model_id: str = "livekit/turn-detector",
    language: str = "en",
    default: float = 0.0049,
    timeout: float = 10.0,
):
    """Fetch the end-of-turn decision threshold from a revision's languages.json.

    Downloads ``languages.json`` from ``model_id`` at revision ``rev_id`` on the
    Hugging Face Hub and returns ``languages[language]["threshold"]``. Any
    failure (network error, non-2xx status, invalid JSON, missing language or
    missing key) falls back to ``default``.

    Args:
        rev_id: branch, tag or commit; ``None`` or ``""`` means ``"main"``.
        model_id: Hub repo that hosts ``languages.json``.
        language: top-level key to read inside ``languages.json``.
        default: threshold returned when the lookup fails.
        timeout: seconds to wait for the HTTP request before giving up.

    Returns:
        The threshold stored in the file, or ``default``.
    """
    import requests

    # from_pretrained(revision=None) loads "main", so read the same revision here.
    revision = rev_id or "main"
    url = f"https://huggingface.co/{model_id}/resolve/{revision}/languages.json"
    try:
        response = requests.get(url, timeout=timeout)
        # A 404 body is not a languages mapping; treat it as a failed lookup.
        response.raise_for_status()
        languages = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers JSON decoding failures across requests versions.
        print(f"Error loading languages.json: \n{e}")
        return default
    return _threshold_from_languages(languages, language, default)


def _threshold_from_languages(languages, language: str, default: float):
    """Return ``languages[language]["threshold"]``, or ``default`` if any level is missing."""
    config = languages.get(language) if isinstance(languages, dict) else None
    if not isinstance(config, dict):
        return default
    return config.get("threshold", default)
# ===== Per-model runner (keeps tokenizer/model and token ids) =====
class ModelRunner:
def __init__(self, model_id: str, revision: str | None = None, dtype=torch.bfloat16):
self.model_id = model_id
self.revision = revision
self.tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
self.model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto",
)
self.model.eval()
self.thresh = get_threshold(revision)
# Pull commonly used tokens, falling back gracefully if not present
self.START_TOKEN_ID = self._tok_id("<|im_start|>")
self.EOU_TOKEN_ID = self._tok_id("<|im_end|>")
self.NEWLINE_TOKEN_ID = self._tok_id("\n")
# Common role tokens; include both legacy and chat-template variants
self.USER_TOKEN_IDS = tuple(
tid for tid in [
self._tok_id("user"),
self._tok_id("<|user|>")
] if tid is not None
)
# Tokens we do not want to score on (specials / scaffolding)
self.SPECIAL_TOKENS = set(
tid for tid in [
self.NEWLINE_TOKEN_ID,
self.START_TOKEN_ID,
self._tok_id("user"),
self._tok_id("assistant"),
] if tid is not None
)
# For filtering in display
self.CONTROL_TOKS = set([
"<|im_start|>", "<|im_end|>", "user", "assistant", "\n"
])
def _tok_id(self, tok: str) -> int | None:
tid = self.tokenizer.convert_tokens_to_ids(tok)
# convert_tokens_to_ids returns None or 0/-1 if unknown depending on tokenizer
if tid is None or tid < 0:
return None
return tid
def format_input(self, text: str) -> str:
# If not a chat-formatted string, wrap as a single user message via chat template
if "<|im_start|>" not in text:
msg = {"role": "user", "content": normalize_text(text)}
text = self.tokenizer.apply_chat_template(
[msg],
tokenize=False,
add_generation_prompt=True
)
return text
def make_pred_mask(self, input_ids: np.ndarray) -> np.ndarray:
"""Return boolean mask: True where we should compute EoT prob (user tokens only)."""
if self.START_TOKEN_ID is None or not self.USER_TOKEN_IDS:
# Fallback: score all non-special tokens if start/user not available
return np.array([tok not in self.SPECIAL_TOKENS for tok in input_ids], dtype=bool)
user_mask = [False] * len(input_ids)
is_user_role = False
for i in range(len(input_ids)):
tok = input_ids[i]
if (self.START_TOKEN_ID is not None) and (tok == self.START_TOKEN_ID) and i + 1 < len(input_ids):
is_user_role = input_ids[i + 1] in self.USER_TOKEN_IDS
user_mask[i] = False
continue
user_mask[i] = is_user_role and (tok not in self.SPECIAL_TOKENS)
return np.array(user_mask, dtype=bool)
@torch.no_grad()
def predict_eou(self, text: str) -> pd.DataFrame:
text = self.format_input(text)
with torch.amp.autocast(self.model.device.type):
inputs = self.tokenizer.encode(
text,
add_special_tokens=False,
return_tensors="pt"
).to(self.model.device)
outputs = self.model(inputs)
# probs over vocab for each position; then take the probability of EOU token
logits = outputs.logits
probs = torch.nn.functional.softmax(logits, dim=-1)
if self.EOU_TOKEN_ID is None:
# If the model/tokenizer doesn't have <|im_end|>, use newline as a proxy (last resort)
fallback_id = self.NEWLINE_TOKEN_ID if self.NEWLINE_TOKEN_ID is not None else 0
eou_probs = probs[..., fallback_id]
else:
eou_probs = probs[..., self.EOU_TOKEN_ID]
eou_probs = eou_probs.squeeze(0).float().cpu().numpy()
input_ids = inputs.squeeze(0).int().cpu().numpy()
mask = self.make_pred_mask(input_ids)
# set masked positions to NaN (not scored)
eou_probs_masked = eou_probs.copy()
eou_probs_masked[~mask] = np.nan
tokens = [self.tokenizer.decode(i) for i in input_ids]
return pd.DataFrame({"token": tokens, "pred": eou_probs_masked})
def make_styled_df(self, df: pd.DataFrame, cmap="coolwarm") -> str:
EPS = 1e-12
thresh = self.thresh
_df = df.copy()
_df = _df[~_df.token.isin(self.CONTROL_TOKS)]
_df.token = _df.token.replace({"\n": "⏎", " ": "␠"})
_df["log_odds"] = (
_df.pred.fillna(thresh)
.add(EPS)
.apply(log_odds).sub(log_odds(thresh))
.mask(_df.pred.isna())
)
_df["Prob(EoT) as %"] = _df.pred.mul(100).fillna(0).astype(int)
vmin, vmax = _df.log_odds.min(), _df.log_odds.max()
vmax_abs = max(abs(vmin), abs(vmax)) * 1.5 if pd.notna(vmin) and pd.notna(vmax) else 1.0
fmt = (
_df.drop(columns=["pred"])
.style
.bar(
subset=["log_odds"],
align="zero",
vmin=-vmax_abs,
vmax=vmax_abs,
cmap=cmap,
height=70,
width=100,
)
.text_gradient(subset=["log_odds"], cmap=cmap, vmin=-vmax_abs, vmax=vmax_abs)
.format(na_rep="", precision=1, subset=["log_odds"])
.format("{:3d}", subset=["Prob(EoT) as %"])
.hide(axis="index")
)
return fmt.to_html()
def generate_highlighted_text(self, text: str):
"""Returns: (highlighted_list, styled_html) for Gradio"""
eps = 1e-12
threshold = self.thresh
if not text:
return [], "<div>No input.</div>"
df = self.predict_eou(text)
df.token = df.token.replace({"user": "\nUSER:", "assistant": "\nAGENT:"})
df = df[~df.token.isin(self.CONTROL_TOKS)]
df["score"] = (
df.pred.fillna(threshold)
.add(eps)
.apply(log_odds).sub(log_odds(threshold))
.mask(df.pred.isna() | df.pred.round(2).eq(0))
)
max_abs_score = df["score"].abs().max()
if pd.notna(max_abs_score) and max_abs_score > 0:
df.score = df.score / (max_abs_score * 1.5)
styled_df = self.make_styled_df(df[["token", "pred"]])
return list(zip(df.token, df.score)), styled_df
# ===== Cached loaders so switching models in the UI is fast =====
@lru_cache(maxsize=4)
def get_runner(model_id: str, revision: str | None):
return ModelRunner(model_id, revision)
# ===== Gradio App =====
import spaces
@spaces.GPU
def compare_models(
    text: str,
    model_a_id: str,
    model_a_rev: str,
    model_b_id: str,
    model_b_rev: str,
):
    """Run the same input through two turn-detector checkpoints.

    Args:
        text: conversation text shared by both models.
        model_a_id, model_b_id: Hub repo ids typed in the UI.
        model_a_rev, model_b_rev: revisions typed in the UI; blank means default.

    Returns:
        ``(highlighted_a, html_a, highlighted_b, html_b)``. Each HTML block is
        prefixed with a header naming the model and revision.
    """
    import html

    def _run(model_id: str, model_rev: str):
        # Trim stray whitespace from the textboxes; a blank revision means default.
        model_id = (model_id or "").strip()
        revision = (model_rev or "").strip() or None
        runner = get_runner(model_id, revision)
        highlighted, table_html = runner.generate_highlighted_text(text)
        # The id/revision come from free-text inputs, so escape them before
        # embedding them in HTML.
        label = html.escape(f"{model_id}@{revision or 'default'}")
        header = f"<h4 style='margin:0 0 8px 0'>{label}</h4>"
        return highlighted, header + table_html

    ht_a, html_a = _run(model_a_id, model_a_rev)
    ht_b, html_b = _run(model_b_id, model_b_rev)
    return ht_a, html_a, ht_b, html_b
EXAMPLE_CONVO = """<|im_start|>assistant
what is your phone number<|im_end|>
<|im_start|>user
five five five four one zero zero four two three<|im_end|>"""

# Side-by-side comparison UI: one shared input box, an id/revision pair per
# model, and a highlighted-text view plus raw score table per model.
with gr.Blocks(theme="soft", title="Turn Detector Debugger — Side by Side") as demo:
    gr.Markdown(
        """# Turn Detector Debugger — Side by Side
Visualize predicted turn endings from **two models**.
Red ⇒ agent should reply • Blue ⇒ agent should wait"""
    )
    with gr.Row():
        text_in = gr.Textbox(
            label="Input Text",
            info="Input text should follow the following chat template. Transcripts should be normalized to be lowercase and without punctuation.",
            value=EXAMPLE_CONVO,
            lines=4,
        )
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Model A")
            model_a_id = gr.Textbox(value=DEFAULT_MODEL_A_ID, label="Model ID")
            model_a_rev = gr.Textbox(value=DEFAULT_MODEL_A_REV, label="Revision (optional)")
        with gr.Column():
            gr.Markdown("### Model B")
            model_b_id = gr.Textbox(value=DEFAULT_MODEL_B_ID, label="Model ID")
            model_b_rev = gr.Textbox(value=DEFAULT_MODEL_B_REV, label="Revision (optional)")
    run_btn = gr.Button("Run Comparison", variant="primary")
    with gr.Row():
        # The highlighted pairs carry numeric scores, so `color_map` (a dict of
        # label -> colour) does not apply. `scale` must be an integer and only
        # matters inside a Row, so neither is passed.
        with gr.Column():
            out_ht_a = gr.HighlightedText(label="EoT Predictions (Model A)")
            out_html_a = gr.HTML(label="Raw scores (Model A)")
        with gr.Column():
            out_ht_b = gr.HighlightedText(label="EoT Predictions (Model B)")
            out_html_b = gr.HTML(label="Raw scores (Model B)")
    run_btn.click(
        fn=compare_models,
        inputs=[text_in, model_a_id, model_a_rev, model_b_id, model_b_rev],
        outputs=[out_ht_a, out_html_a, out_ht_b, out_html_b],
    )

# Launch only when run as a script, so importing this module (e.g. for
# reload mode or tests) does not start a server.
if __name__ == "__main__":
    demo.launch(share=True)