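"""
Gradio Space for chatting with Hugging Face models loaded locally inside the container.

The sidebar shows a Category / Family / Model tree built from default_models.json plus
user-added entries in models.json, an "Add New Model" box, and a debug log. Selected
models are downloaded with transformers and cached in-process, so generation runs
locally rather than through the Inference API.
"""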
import gradio as gr
from huggingface_hub import HfApi
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import os

DEFAULT_FILE = "default_models.json"
USER_FILE = "models.json"
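
# Both JSON files use the same nested layout: Category -> Family -> Model -> meta.
# Illustrative example only (a hypothetical entry, not the actual file contents):
#
# {
#   "Exotic": {
#     "RWKV": {
#       "RWKV-World-7B": {
#         "id": "BlinkDL/rwkv-7-world",
#         "description": "User-added model.",
#         "link": "https://huggingface.co/BlinkDL/rwkv-7-world",
#         "emoji": "✨"
#       }
#     }
#   }
# }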

# -----------------------------
# Model data loading / saving
# -----------------------------
def load_default_models():
    with open(DEFAULT_FILE, "r", encoding="utf-8") as f:
        return json.load(f)


def load_user_models():
    if os.path.exists(USER_FILE):
        with open(USER_FILE, "r", encoding="utf-8") as f:
            try:
                return json.load(f)
            except json.JSONDecodeError:
                return {}
    return {}


def save_user_models(data):
    with open(USER_FILE, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)


def merge_models():
    """
    Merge default + user models into one tree:
      Category -> Family -> Model -> meta
    User models can introduce new categories/families.
    """
    base = load_default_models()
    user = load_user_models()

    for category, families in user.items():
        if category not in base:
            base[category] = {}
        for family, models in families.items():
            if family not in base[category]:
                base[category][family] = {}
            for model_name, meta in models.items():
                base[category][family][model_name] = meta

    return base


# -----------------------------
# Utility: flatten and lookup
# -----------------------------
def flatten_models(model_tree):
    """
    Returns a dict:
      full_key -> (meta, category, family, model_name)
    where full_key = "Category / Family / Model"
    """
    flat = {}
    for category, families in model_tree.items():
        for family, models in families.items():
            for model_name, meta in models.items():
                full_key = f"{category} / {family} / {model_name}"
                flat[full_key] = (meta, category, family, model_name)
    return flat


# -----------------------------
# Debug logging
# -----------------------------
DEBUG_MESSAGES = []  # global buffer

def debug(msg):
    """Append a debug message to the global log."""
    DEBUG_MESSAGES.append(str(msg))
    if len(DEBUG_MESSAGES) > 300:
        DEBUG_MESSAGES.pop(0)
    return "\n".join(DEBUG_MESSAGES)


def get_debug_log():
    return "\n".join(DEBUG_MESSAGES)


# -----------------------------
# Add a new model (from the box)
# -----------------------------
def add_model_box(
    category,
    family,
    model_name,
    model_id,
    description,
    link,
    emoji
):
    try:
        if not model_id:
            debug("Add model failed: no model_id provided")
            return gr.update(
                value="Please provide a Model ID like `user/model`."
            )

        if not category:
            category = "Custom"
        if not family:
            family = "User-Added"
        if not model_name:
            model_name = model_id.split("/")[-1]
        if not description:
            description = "User-added model."
        if not link:
            link = f"https://huggingface.co/{model_id}"
        if not emoji:
            emoji = "✨"

        user_models = load_user_models()

        if category not in user_models:
            user_models[category] = {}
        if family not in user_models[category]:
            user_models[category][family] = {}

        user_models[category][family][model_name] = {
            "id": model_id,
            "description": description,
            "link": link,
            "emoji": emoji
        }

        save_user_models(user_models)

        msg = (
            f"Added model under `{category} / {family}`: "
            f"{emoji} **{model_name}** (`{model_id}`)\n\n"
            f"It will appear in the model tree after reloading the Space."
        )
        debug(f"Model added: {category} / {family} / {model_name} ({model_id})")
        return gr.update(value=msg)
    except Exception:
        import traceback
        tb = traceback.format_exc()
        debug(f"ERROR in add_model_box:\n{tb}")
        return gr.update(
            value="An error occurred while adding the model. Check Debug Log."
        )


# -----------------------------
# Helper: check model access (repo visibility)
# -----------------------------
def check_model_access(model_id, hf_token):
    """
    Try to get model info; return (ok: bool, message: str).
    This helps distinguish auth/gating vs other issues.
    For local loading, this is not strictly required, but we keep
    it to give clearer messages for private/gated models.
    """
    try:
        api = HfApi(token=hf_token.token if hf_token else None)
        _ = api.model_info(model_id)
        return True, ""
    except Exception as e:
        import traceback
        tb = traceback.format_exc()
        debug(f"ERROR in check_model_access for {model_id}:\n{tb}")
        return False, str(e)


# -----------------------------
# Local model cache
# -----------------------------

LOCAL_MODEL_CACHE = {}

def load_local_model(model_id):
    """
    Load a model + tokenizer locally and cache them.
    This makes the Space behave like a dedicated model Space:
    models are executed inside the container, not via the Inference API.
    """
    if model_id in LOCAL_MODEL_CACHE:
        debug(f"Using cached model: {model_id}")
        return LOCAL_MODEL_CACHE[model_id]

    debug(f"Loading model locally: {model_id}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    except Exception as e:
        debug(f"ERROR loading tokenizer for {model_id}: {e}")
        raise

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto"
        )
    except Exception as e:
        debug(f"ERROR loading model weights for {model_id}: {e}")
        raise

    LOCAL_MODEL_CACHE[model_id] = (tokenizer, model)
    return tokenizer, model


# -----------------------------
# Chat function (local models)
# -----------------------------
def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
    active_model_key,
    hf_token: gr.OAuthToken | None = None
):
    # No model chosen
    if active_model_key is None:
        yield "No model selected. Please choose a model in the sidebar and click 'Use this model'."
        return

    models = merge_models()
    flat = flatten_models(models)

    meta_tuple = flat.get(active_model_key)
    if meta_tuple is None:
        yield "Selected model not found. Please choose a model again."
        return

    meta, _, _, _ = meta_tuple
    model_id = meta["id"]

    debug(f"Chat using local model: {model_id}")

    # Optional: check repo access (for private/gated models)
    ok, msg = check_model_access(model_id, hf_token)
    if not ok:
        yield (
            f"Could not access model `{model_id}` on Hugging Face.\n\n"
            f"This is usually because:\n"
            f"- The repo is private or gated and this token has no access\n"
            f"- Or the token is invalid/expired\n\n"
            f"Raw error:\n{msg}\n\n"
            f"Check Debug Log for more details."
        )
        return

    # Load model locally
    try:
        tokenizer, model = load_local_model(model_id)
    except Exception:
        import traceback
        tb = traceback.format_exc()
        debug(f"ERROR in load_local_model for {model_id}:\n{tb}")
        yield (
            f"Failed to load model `{model_id}` locally inside the Space.\n"
            f"Check the Debug Log for details (likely out of memory or missing files)."
        )
        return

    # Build chat-style prompt from history + current message
    prompt = system_message.strip() + "\n\n"
    for turn in history or []:
        role = turn.get("role", "user")
        content = turn.get("content", "")
        if role == "user":
            prompt += f"User: {content}\n"
        else:
            prompt += f"Assistant: {content}\n"
    prompt += f"User: {message}\nAssistant:"

    debug(f"Prompt length (chars): {len(prompt)}")

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate text locally
        output_ids = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            pad_token_id=tokenizer.eos_token_id,
        )

        output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        # Extract only the assistant's final answer
        if "Assistant:" in output_text:
            answer = output_text.split("Assistant:")[-1].strip()
        else:
            answer = output_text.strip()

        yield answer

    except Exception:
        import traceback
        tb = traceback.format_exc()
        debug(f"ERROR during local generation for {model_id}:\n{tb}")
        yield (
            "An error occurred during local text generation.\n"
            "This is often due to running out of memory for large models.\n"
            "Try a smaller model, fewer max tokens, or check the Debug Log."
        )


# -----------------------------
# Use model helper
# -----------------------------
def use_model(fk, old_fk):
    """
    fk: full key "Category / Family / Model" (from gr.State(full_key))
    old_fk: previous active model key (from active_model_state)
    Returns: (new_active_key, current_model_label_text)
    """
    try:
        models_local = merge_models()
        flat_local = flatten_models(models_local)
        meta_loc_tuple = flat_local.get(fk)

        if not meta_loc_tuple:
            debug(f"use_model: key not found: {fk}")
            return old_fk, "**Current model:** _none selected_"

        meta_loc, _, _, mname = meta_loc_tuple
        emoji_local = meta_loc.get("emoji", "✨")
        label_text = f"**Current model:** {emoji_local} {mname}"

        debug(f"use_model: selected {fk}")
        return fk, label_text

    except Exception:
        import traceback
        tb = traceback.format_exc()
        debug(f"ERROR in use_model:\n{tb}")
        return old_fk, "**Current model:** _error occurred (see Debug Log)_" 


# -----------------------------
# Build the sidebar tree
# -----------------------------
def build_model_tree(
    models,
    active_model_state,
    current_model_label
):
    """
    models: merged models dict (Category -> Family -> Model -> meta)
    active_model_state: gr.State storing current active full key
    current_model_label: gr.Markdown for 'Current model: ...'
    """

    for category, families in models.items():
        with gr.Accordion(category, open=False):
            for family, model_dict in families.items():
                with gr.Accordion(family, open=False):
                    for model_name, meta in model_dict.items():
                        emoji = meta.get("emoji", "✨")
                        full_key = f"{category} / {family} / {model_name}"

                        # Model accordion
                        with gr.Accordion(f"{emoji} {model_name}", open=False):
                            info_text = (
                                f"**Model ID:** `{meta['id']}`  \n"
                                f"**Description:** {meta['description']}  \n"
                                f"[Model card]({meta['link']})"
                            )
                            gr.Markdown(info_text)

                            use_btn = gr.Button("Use this model", size="sm")

                            # Wire button -> use_model; the per-model gr.State
                            # holds this entry's full_key so each button passes
                            # its own key to the handler
                            use_btn.click(
                                use_model,
                                inputs=[gr.State(full_key), active_model_state],
                                outputs=[active_model_state, current_model_label],
                            )


# -----------------------------
# Build the UI
# -----------------------------
with gr.Blocks() as demo:
    models_tree = merge_models()

    # Holds full key: "Category / Family / Model"
    active_model_key = gr.State(value=None)

    with gr.Sidebar():
        gr.LoginButton()

        # Collapsible "Add New Model" box
        with gr.Accordion("Add New Model", open=False):
            category_input = gr.Textbox(
                label="Category (e.g. Exotic or new category)",
                placeholder="Exotic"
            )
            family_input = gr.Textbox(
                label="Family (e.g. RWKV)",
                placeholder="RWKV"
            )
            model_name_input = gr.Textbox(
                label="Model Name (e.g. RWKV-World-7B)",
                placeholder="RWKV-World-7B"
            )
            model_id_input = gr.Textbox(
                label="Model ID (e.g. BlinkDL/rwkv-7-world)",
                placeholder="BlinkDL/rwkv-7-world"
            )
            description_input = gr.Textbox(
                label="Description (optional)",
                lines=2
            )
            link_input = gr.Textbox(
                label="Link (optional, will default to https://huggingface.co/ModelID if empty)",
                lines=1
            )
            emoji_input = gr.Textbox(
                label="Emoji (optional, e.g. 🌍)",
                lines=1
            )

            add_button = gr.Button("Add Model")
            add_status = gr.Markdown("")

            add_button.click(
                add_model_box,
                inputs=[
                    category_input,
                    family_input,
                    model_name_input,
                    model_id_input,
                    description_input,
                    link_input,
                    emoji_input,
                ],
                outputs=add_status,
            )

        # Debug Log box (separate accordion)
        with gr.Accordion("Debug Log", open=False):
            debug_log = gr.Textbox(
                label="System Debug Output",
                value="",
                lines=15,
                max_lines=200,
                interactive=False,
                show_copy_button=True,
            )

        # Button to refresh debug log
        refresh_debug = gr.Button("Refresh Debug Log", size="sm")

        refresh_debug.click(
            get_debug_log,
            inputs=None,
            outputs=debug_log
        )

        # Current model label under the debug box
        current_model_label = gr.Markdown("**Current model:** _none selected_")

        gr.Markdown("### Models")

        # Build nested accordions for models
        build_model_tree(
            models_tree,
            active_model_state=active_model_key,
            current_model_label=current_model_label,
        )

    # Main chat interface.
    # ChatInterface expects a plain string for `title`, not a component;
    # the active model is shown via current_model_label in the sidebar.
    chatbot = gr.ChatInterface(
        respond,
        title="Local Model Chat",  # placeholder title; any string works here
        type="messages",
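        # additional_inputs are passed to respond() positionally, in this order,
        # after the message and history arguments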
        additional_inputs=[
            gr.Textbox(
                value="You are a friendly chatbot.",
                label="System message"
            ),
            gr.Slider(
                minimum=1,
                maximum=100000,
                value=512,
                step=1,
                label="Max new tokens"
            ),
            gr.Slider(
                minimum=0.1,
                maximum=4.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p"
            ),
            active_model_key,  # passes current active model key into respond()
        ],
    )

if __name__ == "__main__":
    demo.launch()