import gradio as gr
from huggingface_hub import HfApi
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import os

DEFAULT_FILE = "default_models.json"
USER_FILE = "models.json"
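
# Both files are expected to use the same nested layout the rest of the app assumes
# (Category -> Family -> Model -> meta). A hypothetical models.json entry:
#
# {
#   "Custom": {
#     "User-Added": {
#       "My-Model": {
#         "id": "user/model",
#         "description": "User-added model.",
#         "link": "https://huggingface.co/user/model",
#         "emoji": "✨"
#       }
#     }
#   }
# }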


def load_default_models():
    with open(DEFAULT_FILE, "r", encoding="utf-8") as f:
        return json.load(f)


def load_user_models():
    if os.path.exists(USER_FILE):
        with open(USER_FILE, "r", encoding="utf-8") as f:
            try:
                return json.load(f)
            except json.JSONDecodeError:
                return {}
    return {}


def save_user_models(data):
    with open(USER_FILE, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)


def merge_models():
    """
    Merge default + user models into one tree:
    Category -> Family -> Model -> meta
    User models can introduce new categories/families.
    """
    base = load_default_models()
    user = load_user_models()

    for category, families in user.items():
        if category not in base:
            base[category] = {}
        for family, models in families.items():
            if family not in base[category]:
                base[category][family] = {}
            for model_name, meta in models.items():
                base[category][family][model_name] = meta

    return base
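
# Note: a user entry that shares the same Category / Family / Model name as a
# default entry overwrites the default meta in the merged tree.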


def flatten_models(model_tree):
    """
    Returns a dict:
    full_key -> (meta, category, family, model_name)
    where full_key = "Category / Family / Model"
    """
    flat = {}
    for category, families in model_tree.items():
        for family, models in families.items():
            for model_name, meta in models.items():
                full_key = f"{category} / {family} / {model_name}"
                flat[full_key] = (meta, category, family, model_name)
    return flat
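
# For example (hypothetical names), {"Custom": {"User-Added": {"My-Model": meta}}}
# flattens to {"Custom / User-Added / My-Model": (meta, "Custom", "User-Added", "My-Model")}.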


DEBUG_MESSAGES = []


def debug(msg):
    """Append a debug message to the global log."""
    DEBUG_MESSAGES.append(str(msg))
    if len(DEBUG_MESSAGES) > 300:
        DEBUG_MESSAGES.pop(0)
    return "\n".join(DEBUG_MESSAGES)


def get_debug_log():
    return "\n".join(DEBUG_MESSAGES)


def add_model_box(
    category,
    family,
    model_name,
    model_id,
    description,
    link,
    emoji
):
    try:
        if not model_id:
            debug("Add model failed: no model_id provided")
            return gr.update(
                value="Please provide a Model ID like `user/model`."
            )

        if not category:
            category = "Custom"
        if not family:
            family = "User-Added"
        if not model_name:
            model_name = model_id.split("/")[-1]
        if not description:
            description = "User-added model."
        if not link:
            link = f"https://huggingface.co/{model_id}"
        if not emoji:
            emoji = "✨"

        user_models = load_user_models()

        if category not in user_models:
            user_models[category] = {}
        if family not in user_models[category]:
            user_models[category][family] = {}

        user_models[category][family][model_name] = {
            "id": model_id,
            "description": description,
            "link": link,
            "emoji": emoji
        }

        save_user_models(user_models)

        msg = (
            f"Added model under `{category} / {family}`: "
            f"{emoji} **{model_name}** (`{model_id}`)\n\n"
            f"It will appear in the model tree after reloading the Space."
        )
        debug(f"Model added: {category} / {family} / {model_name} ({model_id})")
        return gr.update(value=msg)
    except Exception:
        import traceback
        tb = traceback.format_exc()
        debug(f"ERROR in add_model_box:\n{tb}")
        return gr.update(
            value="An error occurred while adding the model. Check Debug Log."
        )


def check_model_access(model_id, hf_token):
    """
    Try to get model info; return (ok: bool, message: str).
    This helps distinguish auth/gating vs other issues.
    For local loading, this is not strictly required, but we keep
    it to give clearer messages for private/gated models.
    """
    try:
        api = HfApi(token=hf_token.token if hf_token else None)
        _ = api.model_info(model_id)
        return True, ""
    except Exception as e:
        import traceback
        tb = traceback.format_exc()
        debug(f"ERROR in check_model_access for {model_id}:\n{tb}")
        return False, str(e)


LOCAL_MODEL_CACHE = {}


def load_local_model(model_id):
    """
    Load a model + tokenizer locally and cache them.
    This makes the Space behave like a dedicated model Space:
    models are executed inside the container, not via Inference API.
    """
    if model_id in LOCAL_MODEL_CACHE:
        debug(f"Using cached model: {model_id}")
        return LOCAL_MODEL_CACHE[model_id]

    debug(f"Loading model locally: {model_id}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    except Exception as e:
        debug(f"ERROR loading tokenizer for {model_id}: {e}")
        raise

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto"
        )
    except Exception as e:
        debug(f"ERROR loading model weights for {model_id}: {e}")
        raise

    LOCAL_MODEL_CACHE[model_id] = (tokenizer, model)
    return tokenizer, model
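
# Quick local smoke test (hypothetical tiny model id; any small causal LM would do):
#     tokenizer, model = load_local_model("sshleifer/tiny-gpt2")
# Weights are loaded inside the Space container itself, so insufficient RAM/VRAM is
# the usual failure mode; the cache above keeps one entry per model id.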


def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
    active_model_key,
    hf_token: gr.OAuthToken
):
    if active_model_key is None:
        yield "No model selected. Please choose a model in the sidebar and click 'Use this model'."
        return

    models = merge_models()
    flat = flatten_models(models)

    meta_tuple = flat.get(active_model_key)
    if meta_tuple is None:
        yield "Selected model not found. Please choose a model again."
        return

    meta, _, _, _ = meta_tuple
    model_id = meta["id"]

    debug(f"Chat using local model: {model_id}")

    ok, msg = check_model_access(model_id, hf_token)
    if not ok:
        yield (
            f"Could not access model `{model_id}` on Hugging Face.\n\n"
            f"This is usually because:\n"
            f"- The repo is private or gated and this token has no access\n"
            f"- Or the token is invalid/expired\n\n"
            f"Raw error:\n{msg}\n\n"
            f"Check Debug Log for more details."
        )
        return

    try:
        tokenizer, model = load_local_model(model_id)
    except Exception:
        import traceback
        tb = traceback.format_exc()
        debug(f"ERROR in load_local_model for {model_id}:\n{tb}")
        yield (
            f"Failed to load model `{model_id}` locally inside the Space.\n"
            f"Check the Debug Log for details (likely out of memory or missing files)."
        )
        return

    prompt = system_message.strip() + "\n\n"
    for turn in history or []:
        role = turn.get("role", "user")
        content = turn.get("content", "")
        if role == "user":
            prompt += f"User: {content}\n"
        else:
            prompt += f"Assistant: {content}\n"
    prompt += f"User: {message}\nAssistant:"

    debug(f"Prompt length (chars): {len(prompt)}")

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        output_ids = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            pad_token_id=tokenizer.eos_token_id,
        )

        output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        if "Assistant:" in output_text:
            answer = output_text.split("Assistant:")[-1].strip()
        else:
            answer = output_text.strip()

        yield answer

    except Exception:
        import traceback
        tb = traceback.format_exc()
        debug(f"ERROR during local generation for {model_id}:\n{tb}")
        yield (
            "An error occurred during local text generation.\n"
            "This is often due to running out of memory for large models.\n"
            "Try a smaller model, fewer max tokens, or check the Debug Log."
        )


def use_model(fk, old_fk):
    """
    fk: full key "Category / Family / Model" (from gr.State(full_key))
    old_fk: previous active model key (from active_model_state)
    Returns: (new_active_key, current_model_label_text)
    """
    try:
        models_local = merge_models()
        flat_local = flatten_models(models_local)
        meta_loc_tuple = flat_local.get(fk)

        if not meta_loc_tuple:
            debug(f"use_model: key not found: {fk}")
            return old_fk, "**Current model:** _none selected_"

        meta_loc, _, _, mname = meta_loc_tuple
        emoji_local = meta_loc.get("emoji", "✨")
        label_text = f"**Current model:** {emoji_local} {mname}"

        debug(f"use_model: selected {fk}")
        return fk, label_text

    except Exception:
        import traceback
        tb = traceback.format_exc()
        debug(f"ERROR in use_model:\n{tb}")
        return old_fk, "**Current model:** _error occurred (see Debug Log)_"


def build_model_tree(
    models,
    active_model_state,
    current_model_label
):
    """
    models: merged models dict (Category -> Family -> Model -> meta)
    active_model_state: gr.State storing current active full key
    current_model_label: gr.Markdown for 'Current model: ...'
    """
    for category, families in models.items():
        with gr.Accordion(category, open=False):
            for family, model_dict in families.items():
                with gr.Accordion(family, open=False):
                    for model_name, meta in model_dict.items():
                        emoji = meta.get("emoji", "✨")
                        full_key = f"{category} / {family} / {model_name}"

                        with gr.Accordion(f"{emoji} {model_name}", open=False):
                            info_text = (
                                f"**Model ID:** `{meta['id']}` \n"
                                f"**Description:** {meta['description']} \n"
                                f"[Model card]({meta['link']})"
                            )
                            gr.Markdown(info_text)

                            use_btn = gr.Button("Use this model", size="sm")
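
                            # Binding full_key through gr.State makes each button pass its
                            # own model key to use_model as a constant input, rather than
                            # relying on a late-bound closure over the loop variable.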

                            use_btn.click(
                                use_model,
                                inputs=[gr.State(full_key), active_model_state],
                                outputs=[active_model_state, current_model_label],
                            )


with gr.Blocks() as demo:
    models_tree = merge_models()

    active_model_key = gr.State(value=None)

    with gr.Sidebar():
        gr.LoginButton()

        with gr.Accordion("Add New Model", open=False):
            category_input = gr.Textbox(
                label="Category (e.g. Exotic or new category)",
                placeholder="Exotic"
            )
            family_input = gr.Textbox(
                label="Family (e.g. RWKV)",
                placeholder="RWKV"
            )
            model_name_input = gr.Textbox(
                label="Model Name (e.g. RWKV-World-7B)",
                placeholder="RWKV-World-7B"
            )
            model_id_input = gr.Textbox(
                label="Model ID (e.g. BlinkDL/rwkv-7-world)",
                placeholder="BlinkDL/rwkv-7-world"
            )
            description_input = gr.Textbox(
                label="Description (optional)",
                lines=2
            )
            link_input = gr.Textbox(
                label="Link (optional, will default to https://huggingface.co/ModelID if empty)",
                lines=1
            )
            emoji_input = gr.Textbox(
                label="Emoji (optional, e.g. 🌍)",
                lines=1
            )

            add_button = gr.Button("Add Model")
            add_status = gr.Markdown("")

            add_button.click(
                add_model_box,
                inputs=[
                    category_input,
                    family_input,
                    model_name_input,
                    model_id_input,
                    description_input,
                    link_input,
                    emoji_input,
                ],
                outputs=add_status,
            )

        with gr.Accordion("Debug Log", open=False):
            debug_log = gr.Textbox(
                label="System Debug Output",
                value="",
                lines=15,
                max_lines=200,
                interactive=False,
                show_copy_button=True,
            )

            refresh_debug = gr.Button("Refresh Debug Log", size="sm")

            refresh_debug.click(
                get_debug_log,
                inputs=None,
                outputs=debug_log
            )

        current_model_label = gr.Markdown("**Current model:** _none selected_")

        gr.Markdown("### Models")

        build_model_tree(
            models_tree,
            active_model_state=active_model_key,
            current_model_label=current_model_label,
        )
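
    # additional_inputs below must line up, in order, with respond()'s extra
    # parameters: system_message, max_tokens, temperature, top_p, active_model_key.
    # The hf_token argument is not listed here; Gradio injects it from the
    # gr.OAuthToken annotation once the user signs in via the LoginButton.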

    chatbot = gr.ChatInterface(
        respond,
        type="messages",
        additional_inputs=[
            gr.Textbox(
                value="You are a friendly chatbot.",
                label="System message"
            ),
            gr.Slider(
                minimum=1,
                maximum=100000,
                value=512,
                step=1,
                label="Max new tokens"
            ),
            gr.Slider(
                minimum=0.1,
                maximum=4.0,
                value=0.7,
                step=0.1,
                label="Temperature"
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p"
            ),
            active_model_key,
        ],
    )


if __name__ == "__main__":
    demo.launch()