jeradf committed on
Commit
20b69e8
·
verified ·
1 Parent(s): 9bc7a88

load version thresholds

Browse files
Files changed (1) hide show
  1. app.py +27 -11
app.py CHANGED
@@ -34,6 +34,22 @@ def log_odds(p, eps=0.0):
34
  return np.log(p / (1 - p + eps))
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # ===== Per-model runner (keeps tokenizer/model and token ids) =====
38
  class ModelRunner:
39
  def __init__(self, model_id: str, revision: str | None = None, dtype=torch.bfloat16):
@@ -48,6 +64,7 @@ class ModelRunner:
48
  device_map="auto",
49
  )
50
  self.model.eval()
 
51
 
52
  # Pull commonly used tokens, falling back gracefully if not present
53
  self.START_TOKEN_ID = self._tok_id("<|im_start|>")
@@ -149,6 +166,8 @@ class ModelRunner:
149
 
150
  def make_styled_df(self, df: pd.DataFrame, thresh: float, cmap="coolwarm") -> str:
151
  EPS = 1e-12
 
 
152
  _df = df.copy()
153
  _df = _df[~_df.token.isin(self.CONTROL_TOKS)]
154
  _df.token = _df.token.replace({"\n": "⏎", " ": "␠"})
@@ -185,6 +204,7 @@ class ModelRunner:
185
  def generate_highlighted_text(self, text: str, threshold: float):
186
  """Returns: (highlighted_list, styled_html) for Gradio"""
187
  eps = 1e-12
 
188
  if not text:
189
  return [], "<div>No input.</div>"
190
 
@@ -213,7 +233,6 @@ def get_runner(model_id: str, revision: str | None):
213
 
214
 
215
  # ===== Gradio App =====
216
- EN_THRESHOLD = 0.0049
217
 
218
  import spaces
219
 
@@ -221,17 +240,15 @@ import spaces
221
  def compare_models(
222
  text: str,
223
  model_a_id: str,
224
- model_a_rev: str,
225
- thresh_a: float,
226
  model_b_id: str,
227
- model_b_rev: str,
228
- thresh_b: float,
229
  ):
230
  runner_a = get_runner(model_a_id, model_a_rev if model_a_rev else None)
231
  runner_b = get_runner(model_b_id, model_b_rev if model_b_rev else None)
232
 
233
- ht_a, html_a = runner_a.generate_highlighted_text(text, threshold=thresh_a)
234
- ht_b, html_b = runner_b.generate_highlighted_text(text, threshold=thresh_b)
235
 
236
  # Optional: prepend small headers indicating model names in the HTML blocks
237
  html_a = f"<h4 style='margin:0 0 8px 0'>{model_a_id}@{model_a_rev or 'default'}</h4>" + html_a
@@ -265,13 +282,12 @@ Red ⇒ agent should reply • Blue ⇒ agent should wait"""
265
  gr.Markdown("### Model A")
266
  model_a_id = gr.Textbox(value=DEFAULT_MODEL_A_ID, label="Model ID")
267
  model_a_rev = gr.Textbox(value=DEFAULT_MODEL_A_REV, label="Revision (optional)")
268
- thresh_a = gr.Slider(0.0001, 0.05, value=EN_THRESHOLD, step=0.0001, label="Threshold")
269
 
270
  with gr.Column():
271
  gr.Markdown("### Model B")
272
  model_b_id = gr.Textbox(value=DEFAULT_MODEL_B_ID, label="Model ID")
273
- model_b_rev = gr.Textbox(value=DEFAULT_MODEL_B_REV, label="Revision (optional)")
274
- thresh_b = gr.Slider(0.0001, 0.05, value=EN_THRESHOLD, step=0.0001, label="Threshold")
275
 
276
  run_btn = gr.Button("Run Comparison", variant="primary")
277
 
@@ -293,7 +309,7 @@ Red ⇒ agent should reply • Blue ⇒ agent should wait"""
293
 
294
  run_btn.click(
295
  fn=compare_models,
296
- inputs=[text_in, model_a_id, model_a_rev, thresh_a, model_b_id, model_b_rev, thresh_b],
297
  outputs=[out_ht_a, out_html_a, out_ht_b, out_html_b]
298
  )
299
 
 
34
  return np.log(p / (1 - p + eps))
35
 
36
 
37
def get_threshold(rev_id, lang="en", default=0.0049, timeout=10.0):
    """Fetch the per-language ``threshold`` from the published ``languages.json``.

    The file is read from the ``livekit/turn-detector`` repository on the
    Hugging Face Hub at the requested revision. The lookup is best effort:
    any network, HTTP, or parsing problem is reported on stdout and the
    default is returned instead of raising.

    Args:
        rev_id: Revision (branch, tag, or commit) to read the file from.
            ``None`` or an empty string falls back to ``"main"``.
        lang: Top-level key in ``languages.json`` to read; defaults to ``"en"``.
        default: Value returned when the threshold cannot be determined.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The numeric ``threshold`` stored under ``lang``, or ``default`` when
        the file, the language entry, or the key is missing or malformed.
    """
    import requests  # kept function-local, as in the original helper

    # A missing revision would otherwise be interpolated as the text "None".
    # NOTE(review): assumes "main" is what an unspecified revision resolves to.
    revision = rev_id or "main"
    url = f"https://huggingface.co/livekit/turn-detector/resolve/{revision}/languages.json"

    try:
        # Without a timeout a stalled connection would block the caller forever.
        response = requests.get(url, timeout=timeout)
        # A 404 page is not JSON; fail here with a clear HTTP error instead.
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError) as e:
        print(f"Error loading languages.json: \n{e}")
        return default

    # The language entry may be absent (None) or not a mapping at all.
    config = payload.get(lang) if isinstance(payload, dict) else None
    if not isinstance(config, dict):
        return default

    value = config.get("threshold", default)
    # Only hand back something usable as a numeric threshold.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return default
    return value


# ModelRunner.__init__ calls this helper as `get_thresh(revision)`; keep that
# spelling resolvable so constructing a runner does not raise NameError.
get_thresh = get_threshold
50
+
51
+
52
+
53
  # ===== Per-model runner (keeps tokenizer/model and token ids) =====
54
  class ModelRunner:
55
  def __init__(self, model_id: str, revision: str | None = None, dtype=torch.bfloat16):
 
64
  device_map="auto",
65
  )
66
  self.model.eval()
67
+ self.thresh = get_thresh(revision)
68
 
69
  # Pull commonly used tokens, falling back gracefully if not present
70
  self.START_TOKEN_ID = self._tok_id("<|im_start|>")
 
166
 
167
  def make_styled_df(self, df: pd.DataFrame, thresh: float, cmap="coolwarm") -> str:
168
  EPS = 1e-12
169
+ thresh = self.thresh
170
+
171
  _df = df.copy()
172
  _df = _df[~_df.token.isin(self.CONTROL_TOKS)]
173
  _df.token = _df.token.replace({"\n": "⏎", " ": "␠"})
 
204
  def generate_highlighted_text(self, text: str, threshold: float):
205
  """Returns: (highlighted_list, styled_html) for Gradio"""
206
  eps = 1e-12
207
+ threshold = self.thresh
208
  if not text:
209
  return [], "<div>No input.</div>"
210
 
 
233
 
234
 
235
  # ===== Gradio App =====
 
236
 
237
  import spaces
238
 
 
240
  def compare_models(
241
  text: str,
242
  model_a_id: str,
243
+ model_a_rev: str,
 
244
  model_b_id: str,
245
+ model_b_rev: str,
 
246
  ):
247
  runner_a = get_runner(model_a_id, model_a_rev if model_a_rev else None)
248
  runner_b = get_runner(model_b_id, model_b_rev if model_b_rev else None)
249
 
250
+ ht_a, html_a = runner_a.generate_highlighted_text(text)
251
+ ht_b, html_b = runner_b.generate_highlighted_text(text)
252
 
253
  # Optional: prepend small headers indicating model names in the HTML blocks
254
  html_a = f"<h4 style='margin:0 0 8px 0'>{model_a_id}@{model_a_rev or 'default'}</h4>" + html_a
 
282
  gr.Markdown("### Model A")
283
  model_a_id = gr.Textbox(value=DEFAULT_MODEL_A_ID, label="Model ID")
284
  model_a_rev = gr.Textbox(value=DEFAULT_MODEL_A_REV, label="Revision (optional)")
285
+
286
 
287
  with gr.Column():
288
  gr.Markdown("### Model B")
289
  model_b_id = gr.Textbox(value=DEFAULT_MODEL_B_ID, label="Model ID")
290
+ model_b_rev = gr.Textbox(value=DEFAULT_MODEL_B_REV, label="Revision (optional)")
 
291
 
292
  run_btn = gr.Button("Run Comparison", variant="primary")
293
 
 
309
 
310
  run_btn.click(
311
  fn=compare_models,
312
+ inputs=[text_in, model_a_id, model_a_rev, model_b_id, model_b_rev],
313
  outputs=[out_ht_a, out_html_a, out_ht_b, out_html_b]
314
  )
315