import os
import re
import unicodedata
from functools import lru_cache

import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import gradio as gr


# ===== Defaults (these can also be changed from the UI) =====
DEFAULT_MODEL_A_ID = "livekit/turn-detector"
DEFAULT_MODEL_A_REV = "v0.3.0-intl"

DEFAULT_MODEL_B_ID = "livekit/turn-detector"
DEFAULT_MODEL_B_REV = "v0.4.1-intl"   # adjust if there's a specific revision


# ===== Utilities =====
def normalize_text(text: str) -> str:
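    """Lowercase, NFKC-normalize, strip punctuation (keeping ' and -), and collapse whitespace.

    Example: normalize_text("Hello,  World!") -> "hello world"
    """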
    text = unicodedata.normalize("NFKC", text.lower())
    text = ''.join(
        ch for ch in text
        if not (unicodedata.category(ch).startswith('P') and ch not in ["'", "-"])
    )
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def log_odds(p, eps=1e-12):
    """Numerically stable logit: log(p / (1 - p)), with p clipped away from 0 and 1."""
    p = np.clip(p, eps, 1.0 - eps)
    return np.log(p / (1 - p))
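
# Worked example: log_odds(0.5) == 0.0 and log_odds(0.0049) ≈ -5.31, so the
# threshold-centered scores computed below read as "distance from the decision
# threshold" in logit space.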


def get_threshold(rev_id):
    """Fetch the English end-of-turn threshold shipped with a model revision."""
    import requests

    DEFAULT_THRESH = 0.0049
    # NOTE: the threshold is always read from livekit/turn-detector, regardless
    # of the Model ID entered in the UI.
    URL = f"https://huggingface.co/livekit/turn-detector/resolve/{rev_id or 'main'}/languages.json"

    try:
        config = requests.get(URL, timeout=10).json().get("en") or {}
    except Exception as e:
        print(f"Error loading languages.json:\n{e}")
        config = {}

    return config.get("threshold", DEFAULT_THRESH)
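
# languages.json is expected to map language codes to per-language config,
# e.g. {"en": {"threshold": 0.0049, ...}} (shape inferred from the lookups above).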



# ===== Per-model runner (keeps tokenizer/model and token ids) =====
class ModelRunner:
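    """Bundles one (model_id, revision): tokenizer, model, special-token ids, and EoT threshold."""
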
    def __init__(self, model_id: str, revision: str | None = None, dtype=torch.bfloat16):
        self.model_id = model_id
        self.revision = revision

        self.tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map="auto",
        )
        self.model.eval()
        self.thresh = get_threshold(revision)

        # Pull commonly used tokens, falling back gracefully if not present
        self.START_TOKEN_ID = self._tok_id("<|im_start|>")
        self.EOU_TOKEN_ID   = self._tok_id("<|im_end|>")
        self.NEWLINE_TOKEN_ID = self._tok_id("\n")

        # Common role tokens; include both legacy and chat-template variants
        self.USER_TOKEN_IDS = tuple(
            tid for tid in [
                self._tok_id("user"),
                self._tok_id("<|user|>")
            ] if tid is not None
        )

        # Tokens we do not want to score on (specials / scaffolding)
        self.SPECIAL_TOKENS = set(
            tid for tid in [
                self.NEWLINE_TOKEN_ID,
                self.START_TOKEN_ID,
                self._tok_id("user"),
                self._tok_id("assistant"),
            ] if tid is not None
        )

        # For filtering in display
        self.CONTROL_TOKS = set([
            "<|im_start|>", "<|im_end|>", "user", "assistant", "\n"
        ])

    def _tok_id(self, tok: str) -> int | None:
        tid = self.tokenizer.convert_tokens_to_ids(tok)
        # Depending on the tokenizer, unknown tokens come back as None, a
        # negative id, or the unk token id; treat all of those as "missing".
        if tid is None or tid < 0 or tid == self.tokenizer.unk_token_id:
            return None
        return tid

    def format_input(self, text: str) -> str:
        # If not a chat-formatted string, wrap as a single user message via chat template
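        # With the Qwen-style chat template these checkpoints appear to use,
        # a bare string becomes roughly:
        #   <|im_start|>user\n<content><|im_end|>\n<|im_start|>assistant\n
        # (the exact scaffolding depends on the tokenizer's template)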
        if "<|im_start|>" not in text:
            msg = {"role": "user", "content": normalize_text(text)}
            text = self.tokenizer.apply_chat_template(
                [msg],
                tokenize=False,
                add_generation_prompt=True
            )
        return text

    def make_pred_mask(self, input_ids: np.ndarray) -> np.ndarray:
        """Return boolean mask: True where we should compute EoT prob (user tokens only)."""
        if self.START_TOKEN_ID is None or not self.USER_TOKEN_IDS:
            # Fallback: score all non-special tokens if start/user not available
            return np.array([tok not in self.SPECIAL_TOKENS for tok in input_ids], dtype=bool)

        user_mask = [False] * len(input_ids)
        is_user_role = False
        for i, tok in enumerate(input_ids):
            # A start token flips role tracking based on the token that follows it
            if tok == self.START_TOKEN_ID and i + 1 < len(input_ids):
                is_user_role = input_ids[i + 1] in self.USER_TOKEN_IDS
                user_mask[i] = False
                continue
            user_mask[i] = is_user_role and (tok not in self.SPECIAL_TOKENS)
        return np.array(user_mask, dtype=bool)

    @torch.no_grad()
    def predict_eou(self, text: str) -> pd.DataFrame:
        text = self.format_input(text)

        with torch.amp.autocast(self.model.device.type):
            inputs = self.tokenizer.encode(
                text,
                add_special_tokens=False,
                return_tensors="pt"
            ).to(self.model.device)

            outputs = self.model(inputs)

        # probs over vocab for each position; then take the probability of EOU token
        logits = outputs.logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
        if self.EOU_TOKEN_ID is None:
            # If the model/tokenizer doesn't have <|im_end|>, use newline as a proxy (last resort)
            fallback_id = self.NEWLINE_TOKEN_ID if self.NEWLINE_TOKEN_ID is not None else 0
            eou_probs = probs[..., fallback_id]
        else:
            eou_probs = probs[..., self.EOU_TOKEN_ID]

        eou_probs = eou_probs.squeeze(0).float().cpu().numpy()

        input_ids = inputs.squeeze(0).int().cpu().numpy()
        mask = self.make_pred_mask(input_ids)

        # set masked positions to NaN (not scored)
        eou_probs_masked = eou_probs.copy()
        eou_probs_masked[~mask] = np.nan

        tokens = [self.tokenizer.decode(i) for i in input_ids]
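        # Row i holds P(next token == <|im_end|>) after token i, i.e. the
        # model's belief that the user's turn ends right after that token.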
        return pd.DataFrame({"token": tokens, "pred": eou_probs_masked})

    def make_styled_df(self, df: pd.DataFrame, cmap="coolwarm") -> str:
        EPS = 1e-12
        thresh = self.thresh
        
        _df = df.copy()
        _df = _df[~_df.token.isin(self.CONTROL_TOKS)].copy()
        _df["token"] = _df["token"].replace({"\n": "⏎", " ": "␠"})

        # Center log-odds on the decision threshold: 0 = at threshold,
        # positive = the model would fire end-of-turn at this position.
        _df["log_odds"] = (
            _df.pred.fillna(thresh)
            .add(EPS)
            .apply(log_odds).sub(log_odds(thresh))
            .mask(_df.pred.isna())
        )
        _df["Prob(EoT) as %"] = _df.pred.mul(100).fillna(0).astype(int)
        vmin, vmax = _df.log_odds.min(), _df.log_odds.max()
        vmax_abs = max(abs(vmin), abs(vmax)) * 1.5 if pd.notna(vmin) and pd.notna(vmax) else 1.0

        fmt = (
            _df.drop(columns=["pred"])
            .style
            .bar(
                subset=["log_odds"],
                align="zero",
                vmin=-vmax_abs,
                vmax=vmax_abs,
                cmap=cmap,
                height=70,
                width=100,
            )
            .text_gradient(subset=["log_odds"], cmap=cmap, vmin=-vmax_abs, vmax=vmax_abs)
            .format(na_rep="", precision=1, subset=["log_odds"])
            .format("{:3d}", subset=["Prob(EoT) as %"])
            .hide(axis="index")
        )
        return fmt.to_html()

    def generate_highlighted_text(self, text: str):
        """Returns: (highlighted_list, styled_html) for Gradio"""
        eps = 1e-12
        threshold = self.thresh
        if not text:
            return [], "<div>No input.</div>"

        df = self.predict_eou(text)
        df["token"] = df["token"].replace({"user": "\nUSER:", "assistant": "\nAGENT:"})
        df = df[~df.token.isin(self.CONTROL_TOKS)].copy()

        df["score"] = (
            df.pred.fillna(threshold)
            .add(eps)
            .apply(log_odds).sub(log_odds(threshold))
            .mask(df.pred.isna() | df.pred.round(2).eq(0))
        )
        max_abs_score = df["score"].abs().max()
        if pd.notna(max_abs_score) and max_abs_score > 0:
            df.score = df.score / (max_abs_score * 1.5)

        styled_df = self.make_styled_df(df[["token", "pred"]])
        return list(zip(df.token, df.score)), styled_df


# ===== Cached loaders so switching models in the UI is fast =====
@lru_cache(maxsize=4)
def get_runner(model_id: str, revision: str | None):
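    # Each cached runner pins a full model in GPU memory; lower maxsize above
    # if larger checkpoints cause out-of-memory errors.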
    return ModelRunner(model_id, revision)


# ===== Gradio App =====

import spaces  # ZeroGPU decorator (HF Spaces docs recommend importing this before torch)

@spaces.GPU
def compare_models(
    text: str,
    model_a_id: str,
    model_a_rev: str,    
    model_b_id: str,
    model_b_rev: str,    
):
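    """Run both models on the same input; returns highlighted tokens and an HTML score table for each."""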
    runner_a = get_runner(model_a_id, model_a_rev if model_a_rev else None)
    runner_b = get_runner(model_b_id, model_b_rev if model_b_rev else None)

    ht_a, html_a = runner_a.generate_highlighted_text(text)
    ht_b, html_b = runner_b.generate_highlighted_text(text)

    # Optional: prepend small headers indicating model names in the HTML blocks
    html_a = f"<h4 style='margin:0 0 8px 0'>{model_a_id}@{model_a_rev or 'default'}</h4>" + html_a
    html_b = f"<h4 style='margin:0 0 8px 0'>{model_b_id}@{model_b_rev or 'default'}</h4>" + html_b

    return ht_a, html_a, ht_b, html_b


EXAMPLE_CONVO = """<|im_start|>assistant
what is your phone number<|im_end|>
<|im_start|>user
five five five four one zero zero four two three<|im_end|>"""

with gr.Blocks(theme="soft", title="Turn Detector Debugger — Side by Side") as demo:
    gr.Markdown(
        """# Turn Detector Debugger — Side by Side  
Visualize predicted turn endings from **two models**.  
Red ⇒ agent should reply • Blue ⇒ agent should wait"""
    )

    with gr.Row():
        text_in = gr.Textbox(
            label="Input Text",
            info="Input text should follow the following chat template. Transcripts should be normalized to be lowercase and without punctuation.",
            value=EXAMPLE_CONVO,
            lines=4,
        )

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Model A")
            model_a_id = gr.Textbox(value=DEFAULT_MODEL_A_ID, label="Model ID")
            model_a_rev = gr.Textbox(value=DEFAULT_MODEL_A_REV, label="Revision (optional)")
            

        with gr.Column():
            gr.Markdown("### Model B")
            model_b_id = gr.Textbox(value=DEFAULT_MODEL_B_ID, label="Model ID")
            model_b_rev = gr.Textbox(value=DEFAULT_MODEL_B_REV, label="Revision (optional)")            

    run_btn = gr.Button("Run Comparison", variant="primary")

    with gr.Row():
        with gr.Column():
            out_ht_a = gr.HighlightedText(
                label="EoT Predictions (Model A)",
                color_map="coolwarm",
                scale=1.5,
            )
            out_html_a = gr.HTML(label="Raw scores (Model A)")
        with gr.Column():
            out_ht_b = gr.HighlightedText(
                label="EoT Predictions (Model B)",
                color_map="coolwarm",
                scale=1.5,
            )
            out_html_b = gr.HTML(label="Raw scores (Model B)")

    run_btn.click(
        fn=compare_models,
        inputs=[text_in, model_a_id, model_a_rev, model_b_id, model_b_rev],
        outputs=[out_ht_a, out_html_a, out_ht_b, out_html_b]
    )

demo.launch(share=True)