# Hugging Face Space: Turn Detector Debugger (ZeroGPU) — page banner removed.
import random
import re
import unicodedata
from pprint import pprint

import numpy as np
import pandas as pd
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint used for end-of-turn (EoT) prediction.
MODEL_ID = "livekit/turn-detector"
REVISION_ID = "v0.3.0-intl"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, revision=REVISION_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    revision=REVISION_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()

# Decision threshold on P(<|im_end|>) — presumably tuned for English; verify upstream.
EN_THRESHOLD = 0.0049

_to_id = tokenizer.convert_tokens_to_ids
START_TOKEN_ID = _to_id('<|im_start|>')
EOU_TOKEN_ID = _to_id("<|im_end|>")
NEWLINE_TOKEN_ID = _to_id('\n')
# Role tokens that mark the start of a user turn (two template variants).
USER_TOKEN_IDS = (_to_id('user'), _to_id('<|user|>'))
# Token ids that never receive an EoT prediction of their own.
SPECIAL_TOKENS = {
    NEWLINE_TOKEN_ID,
    START_TOKEN_ID,
    _to_id('user'),
    _to_id('assistant'),
}
# Decoded token strings filtered out of the UI views.
CONTROL_TOKS = ['<|im_start|>', '<|im_end|>', 'user', 'assistant', '\n']
def normalize_text(text):
    """Lowercase, NFKC-normalize, strip punctuation and collapse whitespace.

    Apostrophes and hyphens are kept because they occur inside spoken-style
    words (e.g. "don't", "twenty-one"); every other Unicode punctuation
    character (category P*) is dropped.
    """
    lowered = unicodedata.normalize("NFKC", text.lower())
    kept = []
    for ch in lowered:
        punct = unicodedata.category(ch).startswith('P')
        if not punct or ch in ("'", "-"):
            kept.append(ch)
    collapsed = re.sub(r'\s+', ' ', ''.join(kept))
    return collapsed.strip()
def format_input(text):
    """Wrap raw text in the chat template unless it is already templated.

    Input containing ``<|im_start|>`` is assumed to be a full multi-turn
    conversation and is returned untouched; anything else is normalized and
    rendered as a single user turn with a generation prompt appended.
    """
    if '<|im_start|>' in text:
        return text
    turn = {'role': 'user', 'content': normalize_text(text)}
    return tokenizer.apply_chat_template(
        [turn],
        tokenize=False,
        add_generation_prompt=True,
    )
def log_odds(p, eps=0):
    """Return log-odds ``log(p / (1 - p + eps))``; eps guards the p == 1 pole."""
    odds = p / (1 - p + eps)
    return np.log(odds)
def make_pred_mask(input_ids):
    """Boolean mask of positions where an EoT probability is meaningful.

    A position is flagged only while inside a *user* turn (the role token
    right after ``<|im_start|>`` decides this) and only when the token itself
    is not a control/special token. The last position is never flagged since
    the loop stops one short of the end (role lookahead needs ``i + 1``).
    """
    mask = [False] * len(input_ids)
    in_user_turn = False
    for idx, tok in enumerate(input_ids[:-1]):
        if tok == START_TOKEN_ID:
            # The role token immediately follows <|im_start|>.
            in_user_turn = input_ids[idx + 1] in USER_TOKEN_IDS
        mask[idx] = in_user_turn and tok not in SPECIAL_TOKENS
    return mask
def predict_eou(text):
    """Run the turn detector and return per-token end-of-turn probabilities.

    Returns a DataFrame with one row per input token:
      - ``token``: the decoded token string
      - ``pred``: P(next token is <|im_end|>), NaN outside user turns
    """
    text = format_input(text)
    with torch.no_grad():
        with torch.amp.autocast(model.device.type):
            inputs = tokenizer.encode(
                text,
                add_special_tokens=False,
                return_tensors="pt"
            ).to(model.device)
            outputs = model(inputs)
            # Probability that each position is followed by <|im_end|>.
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
            probs = probs.cpu().float().numpy()[:, :, EOU_TOKEN_ID].flatten()
            input_ids = inputs.cpu().int().flatten().numpy()
            # NaN out positions outside user turns / on control tokens.
            mask = np.array(make_pred_mask(input_ids))
            probs[~mask] = np.nan
            # `tok_id` instead of `id`: don't shadow the builtin.
            tokens = [tokenizer.decode(tok_id) for tok_id in input_ids]
            res = {'token': tokens, 'pred': probs}
            return pd.DataFrame(res)
def make_styled_df(df, thresh=EN_THRESHOLD, cmap="coolwarm"):
    """Render token/probability rows as an HTML table with log-odds bars.

    The bars are centered on the decision threshold: positive log-odds mean
    the model would end the turn, negative mean it would keep waiting.
    """
    EPS = 1e-12
    df = df.copy()
    # Drop template/control tokens and make whitespace tokens visible.
    df = df[~df.token.isin(CONTROL_TOKS)]
    df.token = df.token.replace({"\n": "⏎"," ": "␠",})
    # Log-odds relative to the threshold; positions without a prediction stay NaN.
    df['log_odds'] = (
        df.pred.fillna(thresh)
        .add(EPS)
        .apply(log_odds)
        .sub(log_odds(thresh))
        .mask(df.pred.isna())
    )
    df['Prob(EoT) as %'] = df.pred.mul(100).fillna(0).astype(int)
    # Symmetric color range around zero, padded so bars never saturate.
    bound = max(abs(df.log_odds.min()), abs(df.log_odds.max())) * 1.5
    styler = (
        df.drop(columns=['pred'])
        .style
        .bar(
            subset=['log_odds'],
            align="zero",
            vmin=-bound,
            vmax=bound,
            cmap=cmap,
            height=70,
            width=100,
        )
        .text_gradient(subset=['log_odds'], cmap=cmap, vmin=-bound, vmax=bound)
        .format(na_rep='', precision=1, subset=['log_odds'])
        .format("{:3d}", subset=['Prob(EoT) as %'])
        .hide(axis='index')
    )
    return styler.to_html()
def generate_highlighted_text(text, threshold=EN_THRESHOLD):
    """Produce the two Gradio outputs: highlighted tokens and an HTML table.

    Tokens are scored by log-odds relative to ``threshold`` and rescaled to
    roughly [-2/3, 2/3] for the HighlightedText color map.
    """
    eps = 1e-12
    if not text:
        # The Interface wires two outputs, so the empty case must return both.
        return [], ""
    df = predict_eou(text)
    # Re-label role tokens so they survive the control-token filter below.
    df.token = df.token.replace({"user": "\nUSER:", "assistant": "\nAGENT:"})
    df = df[~df.token.isin(CONTROL_TOKS)]
    df['score'] = (
        df.pred.fillna(threshold)
        .add(eps)
        .apply(log_odds).sub(log_odds(threshold))
        # BUG FIX: `|` binds tighter than `==`, so the original expression
        # parsed as `(isna | round) == 0`; the intended mask is below.
        .mask(df.pred.isna() | (df.pred.round(2) == 0))
    )
    # NaN-safe normalization: if every score is NaN, `NaN > 0` is False.
    max_abs_score = df['score'].abs().max() * 1.5
    if max_abs_score > 0:
        df.score = df.score / max_abs_score
    styled_df = make_styled_df(df[['token', 'pred']])
    return list(zip(df.token, df.score)), styled_df
# Example multi-turn conversation, already in chat-template format.
convo_text = """<|im_start|>assistant
what is your phone number<|im_end|>
<|im_start|>user
555 410 0423<|im_end|>"""

demo = gr.Interface(
    fn=generate_highlighted_text,
    theme="soft",
    inputs=gr.Textbox(
        label="Input Text",
        value=convo_text,
        lines=2,
    ),
    outputs=[
        gr.HighlightedText(
            label="EoT Predictions",
            color_map="coolwarm",
            scale=1.5,
        ),
        gr.HTML(label="Raw scores"),
    ],
    title="Turn Detector Debugger",
    description="Visualize predicted turn endings. The coloring is based on log-odds, centered on the threshold.\n Red means agent should reply; Blue means agent should wait",
    allow_flagging="never",
)

demo.launch()