theomonnom committed on
Commit
8d2d3a6
·
verified ·
1 Parent(s): 7c74905

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +252 -150
app.py CHANGED
@@ -1,47 +1,26 @@
1
- import random
 
 
 
 
2
  import numpy as np
3
  import pandas as pd
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
 
7
- import unicodedata
8
- import re
9
  import gradio as gr
10
- from pprint import pprint
11
-
12
 
13
 
14
# Hub repository and pinned revision of the turn-detector checkpoint.
MODEL_ID = "livekit/turn-detector"
REVISION_ID = "v0.3.0-intl"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, revision=REVISION_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    revision=REVISION_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()

# Default probability threshold the visualisations are centred on.
EN_THRESHOLD = 0.0049

# Ids of the chat-scaffolding tokens, resolved once at import time.
START_TOKEN_ID = tokenizer.convert_tokens_to_ids('<|im_start|>')
EOU_TOKEN_ID = tokenizer.convert_tokens_to_ids("<|im_end|>")
NEWLINE_TOKEN_ID = tokenizer.convert_tokens_to_ids('\n')

# A token following <|im_start|> that matches one of these opens a user turn.
USER_TOKEN_IDS = (
    tokenizer.convert_tokens_to_ids('user'),
    tokenizer.convert_tokens_to_ids('<|user|>'),
)

# Token ids that are never scored, even inside a user turn.
SPECIAL_TOKENS = {
    NEWLINE_TOKEN_ID,
    START_TOKEN_ID,
    USER_TOKEN_IDS[0],
    tokenizer.convert_tokens_to_ids('assistant'),
}

# Decoded token strings dropped from the displayed tables.
CONTROL_TOKS = ['<|im_start|>', '<|im_end|>', 'user', 'assistant', '\n']
42
-
43
-
44
- def normalize_text(text):
45
  text = unicodedata.normalize("NFKC", text.lower())
46
  text = ''.join(
47
  ch for ch in text
@@ -51,145 +30,268 @@ def normalize_text(text):
51
  return text
52
 
53
 
54
def format_input(text):
    """Return ``text`` as a chat-formatted prompt string.

    Input that already contains ``<|im_start|>`` is returned untouched.
    Anything else is normalised and wrapped as a single user message through
    the tokenizer's chat template, with the generation prompt appended.
    """
    if '<|im_start|>' in text:
        return text

    message = {'role': 'user', 'content': normalize_text(text)}
    return tokenizer.apply_chat_template(
        [message],
        tokenize=False,
        add_generation_prompt=True,
    )
64
-
65
 
66
def log_odds(p, eps=0):
    """Return ``log(p / (1 - p + eps))``.

    Args:
        p: Probability, as a Python/NumPy scalar or an array-like that
            supports arithmetic (ndarray, pandas Series).
        eps: Value added to the denominator only.

    Returns:
        The log-odds with the same container type as ``p`` (a NumPy float for
        scalar input). ``p == 0`` gives ``-inf`` and ``p == 1`` with
        ``eps == 0`` gives ``+inf``.
    """
    # Plain Python numbers raise ZeroDivisionError at p == 1; promoting them to
    # a NumPy scalar gives the same IEEE result (inf) the array path produces.
    if isinstance(p, (int, float)):
        p = np.float64(p)
    # The infinities at the boundaries are intentional, so silence the
    # divide-by-zero warnings they would otherwise emit.
    with np.errstate(divide="ignore"):
        return np.log(p / (1 - p + eps))
68
 
 
 
 
 
 
69
 
70
def make_pred_mask(input_ids):
    """Return a list of booleans marking the positions that should be scored.

    A position is True only when it lies inside a turn whose role token (the
    token right after ``START_TOKEN_ID``) is in ``USER_TOKEN_IDS`` and the
    token itself is not in ``SPECIAL_TOKENS``.
    """
    n_tokens = len(input_ids)
    mask = []
    in_user_turn = False
    for pos, tok in enumerate(input_ids):
        if tok == START_TOKEN_ID and pos + 1 < n_tokens:
            # The token after the start marker names the role of this turn.
            in_user_turn = input_ids[pos + 1] in USER_TOKEN_IDS
            mask.append(False)
        else:
            mask.append(in_user_turn and (tok not in SPECIAL_TOKENS))
    return mask
 
 
 
 
 
 
 
 
 
 
81
 
 
 
 
 
 
 
 
 
 
82
 
83
def predict_eou(text):
    """Score every token of ``text`` with the model's end-of-utterance probability.

    Returns a DataFrame with one row per input token: ``token`` is the decoded
    token string and ``pred`` is the softmax probability of ``EOU_TOKEN_ID``
    at that position, or NaN where ``make_pred_mask`` excludes the position.
    """
    prompt = format_input(text)
    with torch.no_grad(), torch.amp.autocast(model.device.type):
        inputs = tokenizer.encode(
            prompt,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(model.device)
        logits = model(inputs).logits
        vocab_probs = torch.nn.functional.softmax(logits, dim=-1)
        probs = vocab_probs.cpu().float().numpy()[:, :, EOU_TOKEN_ID].flatten()

        input_ids = inputs.cpu().int().flatten().numpy()
        scored = np.array(make_pred_mask(input_ids))
        probs[~scored] = np.nan  # unscored positions carry no prediction

        tokens = [tokenizer.decode(tok_id) for tok_id in input_ids]
        return pd.DataFrame({'token': tokens, 'pred': probs})
103
-
104
-
105
def make_styled_df(df, thresh=EN_THRESHOLD, cmap="coolwarm"):
    """Render the per-token predictions as a styled HTML table.

    Args:
        df: DataFrame with ``token`` and ``pred`` columns (``pred`` may be NaN).
        thresh: Probability the log-odds column is centred on.
        cmap: Colormap name used for the bars and the text gradient.

    Returns:
        HTML string with ``token``, ``log_odds`` and ``Prob(EoT) as %`` columns.
    """
    EPS = 1e-12
    # Copy after filtering so the column assignments below write to an
    # independent frame instead of a slice of the caller's data.
    table = df[~df.token.isin(CONTROL_TOKS)].copy()
    table["token"] = table.token.replace({"\n": "⏎", " ": "␠"})

    table["log_odds"] = (
        table.pred.fillna(thresh)
        .add(EPS)
        .apply(log_odds)
        .sub(log_odds(thresh))
        .mask(table.pred.isna())
    )
    table["Prob(EoT) as %"] = table.pred.mul(100).fillna(0).astype(int)

    # Symmetric colour range. With no scored tokens (NaN) or only zero
    # log-odds the range would be degenerate, so fall back to a unit range.
    bound = table.log_odds.abs().max() * 1.5
    if not (pd.notna(bound) and bound > 0):
        bound = 1.0

    fmt = (
        table.drop(columns=["pred"])
        .style
        .bar(
            subset=["log_odds"],
            align="zero",
            vmin=-bound,
            vmax=bound,
            cmap=cmap,
            height=70,
            width=100,
        )
        .text_gradient(subset=["log_odds"], cmap=cmap, vmin=-bound, vmax=bound)
        .format(na_rep="", precision=1, subset=["log_odds"])
        .format("{:3d}", subset=["Prob(EoT) as %"])
        .hide(axis="index")
    )
    return fmt.to_html()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
 
 
140
 
141
def generate_highlighted_text(text, threshold=EN_THRESHOLD):
    """Build the two Gradio outputs for ``text``.

    Args:
        text: Raw or chat-formatted input.
        threshold: Probability the scores are centred on.

    Returns:
        ``(highlighted, html)``. ``highlighted`` is a list of
        ``(token, score)`` pairs, with scores being log-odds relative to
        ``threshold`` divided by 1.5x the largest magnitude. ``html`` is the
        styled score table. Empty input yields ``([], "")`` so both outputs
        always receive a value.
    """
    eps = 1e-12
    if not text:
        return [], ""

    df = predict_eou(text)
    df["token"] = df.token.replace({"user": "\nUSER:", "assistant": "\nAGENT:"})
    df = df[~df.token.isin(CONTROL_TOKS)].copy()

    # `|` binds tighter than `==`, so the comparison must be parenthesised.
    # Otherwise the expression becomes `(isna | rounded) == 0` and unscored
    # tokens are not masked.
    unscored = df.pred.isna() | (df.pred.round(2) == 0)
    df["score"] = (
        df.pred.fillna(threshold)
        .add(eps)
        .apply(log_odds)
        .sub(log_odds(threshold))
        .mask(unscored)
    )
    max_abs_score = df["score"].abs().max() * 1.5

    if max_abs_score > 0:  # False for NaN, i.e. when nothing was scored
        df["score"] = df.score / max_abs_score

    # Forward the caller's threshold so the table and the highlights agree.
    styled_df = make_styled_df(df[["token", "pred"]], thresh=threshold)
    return list(zip(df.token, df.score)), styled_df
 
163
 
 
164
 
165
 
166
# Example conversation pre-filled in the input box.
convo_text = (
    "<|im_start|>assistant\n"
    "what is your phone number<|im_end|>\n"
    "<|im_start|>user\n"
    "555 410 0423<|im_end|>"
)

input_box = gr.Textbox(
    label="Input Text",
    info="If <|im_start|> is present it will treat input as formatted convo. if not it will format it as convo with 1 user message.",
    value=convo_text,
    lines=2,
)
highlighted_out = gr.HighlightedText(
    label="EoT Predictions",
    color_map="coolwarm",
    scale=1.5,
)
raw_scores_out = gr.HTML(label="Raw scores")

# Single-function UI: one text input, highlighted tokens plus the raw table.
demo = gr.Interface(
    fn=generate_highlighted_text,
    theme="soft",
    inputs=input_box,
    outputs=[highlighted_out, raw_scores_out],
    title="Turn Detector Debugger",
    description="Visualize predicted turn endings. The coloring is based on log-odds, centered on the threshold.\n Red means agent should reply; Blue means agent should wait",
    allow_flagging="never",
)

demo.launch()
 
1
+ import os
2
+ import re
3
+ import unicodedata
4
+ from functools import lru_cache
5
+
6
  import numpy as np
7
  import pandas as pd
8
  import torch
9
  from transformers import AutoModelForCausalLM, AutoTokenizer
10
 
 
 
11
  import gradio as gr
 
 
12
 
13
 
14
# ===== Defaults (editable from the UI as well) =====
# Model A: Hub repository id and revision pre-filled in the left column.
DEFAULT_MODEL_A_ID: str = "livekit/turn-detector"
DEFAULT_MODEL_A_REV: str = "v0.3.0-intl"

# Model B: Hub repository id and revision pre-filled in the right column.
DEFAULT_MODEL_B_ID: str = "livekit/eou-experiment"
DEFAULT_MODEL_B_REV: str = "main"  # adjust if there's a specific revision
 
 
 
 
 
 
20
 
21
 
22
+ # ===== Utilities =====
23
+ def normalize_text(text: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  text = unicodedata.normalize("NFKC", text.lower())
25
  text = ''.join(
26
  ch for ch in text
 
30
  return text
31
 
32
 
33
def log_odds(p, eps=0.0):
    """Return ``log(p / (1 - p + eps))``.

    Args:
        p: Probability, as a Python/NumPy scalar or an array-like that
            supports arithmetic (ndarray, pandas Series).
        eps: Value added to the denominator only.

    Returns:
        The log-odds with the same container type as ``p`` (a NumPy float for
        scalar input). ``p == 0`` gives ``-inf`` and ``p == 1`` with
        ``eps == 0`` gives ``+inf``.
    """
    # Plain Python numbers raise ZeroDivisionError at p == 1; promoting them to
    # a NumPy scalar gives the same IEEE result (inf) the array path produces.
    if isinstance(p, (int, float)):
        p = np.float64(p)
    # The infinities at the boundaries are intentional, so silence the
    # divide-by-zero warnings they would otherwise emit.
    with np.errstate(divide="ignore"):
        return np.log(p / (1 - p + eps))
 
 
 
 
 
 
 
 
 
35
 
 
 
36
 
37
+ # ===== Per-model runner (keeps tokenizer/model and token ids) =====
38
+ class ModelRunner:
39
+ def __init__(self, model_id: str, revision: str | None = None, dtype=torch.bfloat16):
40
+ self.model_id = model_id
41
+ self.revision = revision
42
 
43
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
44
+ self.model = AutoModelForCausalLM.from_pretrained(
45
+ model_id,
46
+ revision=revision,
47
+ torch_dtype=dtype,
48
+ device_map="auto",
49
+ )
50
+ self.model.eval()
51
+
52
+ # Pull commonly used tokens, falling back gracefully if not present
53
+ self.START_TOKEN_ID = self._tok_id("<|im_start|>")
54
+ self.EOU_TOKEN_ID = self._tok_id("<|im_end|>")
55
+ self.NEWLINE_TOKEN_ID = self._tok_id("\n")
56
+
57
+ # Common role tokens; include both legacy and chat-template variants
58
+ self.USER_TOKEN_IDS = tuple(
59
+ tid for tid in [
60
+ self._tok_id("user"),
61
+ self._tok_id("<|user|>")
62
+ ] if tid is not None
63
+ )
64
 
65
+ # Tokens we do not want to score on (specials / scaffolding)
66
+ self.SPECIAL_TOKENS = set(
67
+ tid for tid in [
68
+ self.NEWLINE_TOKEN_ID,
69
+ self.START_TOKEN_ID,
70
+ self._tok_id("user"),
71
+ self._tok_id("assistant"),
72
+ ] if tid is not None
73
+ )
74
 
75
+ # For filtering in display
76
+ self.CONTROL_TOKS = set([
77
+ "<|im_start|>", "<|im_end|>", "user", "assistant", "\n"
78
+ ])
79
+
80
+ def _tok_id(self, tok: str) -> int | None:
81
+ tid = self.tokenizer.convert_tokens_to_ids(tok)
82
+ # convert_tokens_to_ids returns None or 0/-1 if unknown depending on tokenizer
83
+ if tid is None or tid < 0:
84
+ return None
85
+ return tid
86
+
87
+ def format_input(self, text: str) -> str:
88
+ # If not a chat-formatted string, wrap as a single user message via chat template
89
+ if "<|im_start|>" not in text:
90
+ msg = {"role": "user", "content": normalize_text(text)}
91
+ text = self.tokenizer.apply_chat_template(
92
+ [msg],
93
+ tokenize=False,
94
+ add_generation_prompt=True
95
+ )
96
+ return text
97
+
98
+ def make_pred_mask(self, input_ids: np.ndarray) -> np.ndarray:
99
+ """Return boolean mask: True where we should compute EoT prob (user tokens only)."""
100
+ if self.START_TOKEN_ID is None or not self.USER_TOKEN_IDS:
101
+ # Fallback: score all non-special tokens if start/user not available
102
+ return np.array([tok not in self.SPECIAL_TOKENS for tok in input_ids], dtype=bool)
103
+
104
+ user_mask = [False] * len(input_ids)
105
+ is_user_role = False
106
+ for i in range(len(input_ids)):
107
+ tok = input_ids[i]
108
+ if (self.START_TOKEN_ID is not None) and (tok == self.START_TOKEN_ID) and i + 1 < len(input_ids):
109
+ is_user_role = input_ids[i + 1] in self.USER_TOKEN_IDS
110
+ user_mask[i] = False
111
+ continue
112
+ user_mask[i] = is_user_role and (tok not in self.SPECIAL_TOKENS)
113
+ return np.array(user_mask, dtype=bool)
114
+
115
+ @torch.no_grad()
116
+ def predict_eou(self, text: str) -> pd.DataFrame:
117
+ text = self.format_input(text)
118
+
119
+ with torch.amp.autocast(self.model.device.type):
120
+ inputs = self.tokenizer.encode(
121
  text,
122
  add_special_tokens=False,
123
  return_tensors="pt"
124
+ ).to(self.model.device)
125
+
126
+ outputs = self.model(inputs)
127
+
128
+ # probs over vocab for each position; then take the probability of EOU token
129
+ logits = outputs.logits
130
+ probs = torch.nn.functional.softmax(logits, dim=-1)
131
+ if self.EOU_TOKEN_ID is None:
132
+ # If the model/tokenizer doesn't have <|im_end|>, use newline as a proxy (last resort)
133
+ fallback_id = self.NEWLINE_TOKEN_ID if self.NEWLINE_TOKEN_ID is not None else 0
134
+ eou_probs = probs[..., fallback_id]
135
+ else:
136
+ eou_probs = probs[..., self.EOU_TOKEN_ID]
137
+
138
+ eou_probs = eou_probs.squeeze(0).float().cpu().numpy()
139
+
140
+ input_ids = inputs.squeeze(0).int().cpu().numpy()
141
+ mask = self.make_pred_mask(input_ids)
142
+
143
+ # set masked positions to NaN (not scored)
144
+ eou_probs_masked = eou_probs.copy()
145
+ eou_probs_masked[~mask] = np.nan
146
+
147
+ tokens = [self.tokenizer.decode(i) for i in input_ids]
148
+ return pd.DataFrame({"token": tokens, "pred": eou_probs_masked})
149
+
150
+ def make_styled_df(self, df: pd.DataFrame, thresh: float, cmap="coolwarm") -> str:
151
+ EPS = 1e-12
152
+ _df = df.copy()
153
+ _df = _df[~_df.token.isin(self.CONTROL_TOKS)]
154
+ _df.token = _df.token.replace({"\n": "⏎", " ": "␠"})
155
+
156
+ _df["log_odds"] = (
157
+ _df.pred.fillna(thresh)
158
+ .add(EPS)
159
+ .apply(log_odds).sub(log_odds(thresh))
160
+ .mask(_df.pred.isna())
 
 
 
 
161
  )
162
+ _df["Prob(EoT) as %"] = _df.pred.mul(100).fillna(0).astype(int)
163
+ vmin, vmax = _df.log_odds.min(), _df.log_odds.max()
164
+ vmax_abs = max(abs(vmin), abs(vmax)) * 1.5 if pd.notna(vmin) and pd.notna(vmax) else 1.0
165
+
166
+ fmt = (
167
+ _df.drop(columns=["pred"])
168
+ .style
169
+ .bar(
170
+ subset=["log_odds"],
171
+ align="zero",
172
+ vmin=-vmax_abs,
173
+ vmax=vmax_abs,
174
+ cmap=cmap,
175
+ height=70,
176
+ width=100,
177
+ )
178
+ .text_gradient(subset=["log_odds"], cmap=cmap, vmin=-vmax_abs, vmax=vmax_abs)
179
+ .format(na_rep="", precision=1, subset=["log_odds"])
180
+ .format("{:3d}", subset=["Prob(EoT) as %"])
181
+ .hide(axis="index")
182
+ )
183
+ return fmt.to_html()
184
+
185
+ def generate_highlighted_text(self, text: str, threshold: float):
186
+ """Returns: (highlighted_list, styled_html) for Gradio"""
187
+ eps = 1e-12
188
+ if not text:
189
+ return [], "<div>No input.</div>"
190
+
191
+ df = self.predict_eou(text)
192
+ df.token = df.token.replace({"user": "\nUSER:", "assistant": "\nAGENT:"})
193
+ df = df[~df.token.isin(self.CONTROL_TOKS)]
194
+
195
+ df["score"] = (
196
+ df.pred.fillna(threshold)
197
+ .add(eps)
198
+ .apply(log_odds).sub(log_odds(threshold))
199
+ .mask(df.pred.isna() | df.pred.round(2).eq(0))
200
+ )
201
+ max_abs_score = df["score"].abs().max()
202
+ if pd.notna(max_abs_score) and max_abs_score > 0:
203
+ df.score = df.score / (max_abs_score * 1.5)
204
 
205
+ styled_df = self.make_styled_df(df[["token", "pred"]], thresh=threshold)
206
+ return list(zip(df.token, df.score)), styled_df
207
 
 
 
 
 
208
 
209
+ # ===== Cached loaders so switching models in the UI is fast =====
210
@lru_cache(maxsize=4)
def get_runner(model_id: str, revision: str | None):
    """Return the ``ModelRunner`` for ``(model_id, revision)``.

    Results are memoised for the four most recently used argument pairs, so
    repeating a pair does not construct a new runner.
    """
    runner = ModelRunner(model_id, revision)
    return runner
213
 
 
 
 
 
 
 
 
214
 
215
+ # ===== Gradio App =====
216
# Default probability threshold offered by both sliders.
EN_THRESHOLD = 0.0049


def _clean_field(value: str | None) -> str:
    """Return ``value`` stripped of surrounding whitespace ('' for None)."""
    return (value or "").strip()


def _model_header(model_id: str, revision: str | None) -> str:
    """Return the HTML heading naming a model, with user text escaped."""
    import html  # local import keeps this block self-contained

    label = f"{html.escape(model_id)}@{html.escape(revision or 'default')}"
    return f"<h4 style='margin:0 0 8px 0'>{label}</h4>"


def compare_models(
    text: str,
    model_a_id: str,
    model_a_rev: str,
    thresh_a: float,
    model_b_id: str,
    model_b_rev: str,
    thresh_b: float,
):
    """Score ``text`` with two models and return the four Gradio outputs.

    Args:
        text: Raw or chat-formatted input.
        model_a_id, model_b_id: Repository ids typed into the UI.
        model_a_rev, model_b_rev: Revisions; blank means no explicit revision.
        thresh_a, thresh_b: Per-model probability thresholds.

    Returns:
        ``(highlighted_a, html_a, highlighted_b, html_b)``. Each HTML block is
        prefixed with a heading that names the model and revision.
    """
    results = []
    for raw_id, raw_rev, thresh in (
        (model_a_id, model_a_rev, thresh_a),
        (model_b_id, model_b_rev, thresh_b),
    ):
        model_id = _clean_field(raw_id)
        # A blank or whitespace-only revision means "no explicit revision".
        revision = _clean_field(raw_rev) or None

        runner = get_runner(model_id, revision)
        highlighted, table_html = runner.generate_highlighted_text(text, threshold=thresh)

        # The id and revision are free text from the UI, so they are escaped
        # before being embedded in the returned markup.
        results.extend([highlighted, _model_header(model_id, revision) + table_html])

    ht_a, html_a, ht_b, html_b = results
    return ht_a, html_a, ht_b, html_b
238
 
239
 
240
# Example conversation pre-filled in the input box.
EXAMPLE_CONVO = (
    "<|im_start|>assistant\n"
    "what is your phone number<|im_end|>\n"
    "<|im_start|>user\n"
    "555 410 0423<|im_end|>"
)


def _model_controls(heading, default_id, default_rev):
    """Create one model's heading, id box, revision box and threshold slider."""
    gr.Markdown(heading)
    id_box = gr.Textbox(value=default_id, label="Model ID")
    rev_box = gr.Textbox(value=default_rev, label="Revision (optional)")
    slider = gr.Slider(0.0001, 0.05, value=EN_THRESHOLD, step=0.0001, label="Threshold")
    return id_box, rev_box, slider


def _model_outputs(tag):
    """Create one model's highlighted-text and raw-score output components."""
    highlighted = gr.HighlightedText(
        label=f"EoT Predictions ({tag})",
        color_map="coolwarm",
        scale=1.5,
    )
    raw_scores = gr.HTML(label=f"Raw scores ({tag})")
    return highlighted, raw_scores


with gr.Blocks(theme="soft", title="Turn Detector Debugger — Side by Side") as demo:
    gr.Markdown(
        "# Turn Detector Debugger — Side by Side\n"
        "Visualize predicted turn endings from **two models**.\n"
        "Red ⇒ agent should reply • Blue ⇒ agent should wait"
    )

    with gr.Row():
        text_in = gr.Textbox(
            label="Input Text",
            info="If <|im_start|> is present, input is treated as chat-formatted; otherwise it's wrapped as a single user turn.",
            value=EXAMPLE_CONVO,
            lines=4,
        )

    # Side-by-side configuration columns, one per model.
    with gr.Row():
        with gr.Column():
            model_a_id, model_a_rev, thresh_a = _model_controls(
                "### Model A", DEFAULT_MODEL_A_ID, DEFAULT_MODEL_A_REV
            )
        with gr.Column():
            model_b_id, model_b_rev, thresh_b = _model_controls(
                "### Model B", DEFAULT_MODEL_B_ID, DEFAULT_MODEL_B_REV
            )

    run_btn = gr.Button("Run Comparison", variant="primary")

    # Side-by-side result columns, in the same order as the controls.
    with gr.Row():
        with gr.Column():
            out_ht_a, out_html_a = _model_outputs("Model A")
        with gr.Column():
            out_ht_b, out_html_b = _model_outputs("Model B")

    run_btn.click(
        fn=compare_models,
        inputs=[text_in, model_a_id, model_a_rev, thresh_a, model_b_id, model_b_rev, thresh_b],
        outputs=[out_ht_a, out_html_a, out_ht_b, out_html_b],
    )

demo.launch()