import gradio as gr
from huggingface_hub import HfApi
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import os
DEFAULT_FILE = "default_models.json"
USER_FILE = "models.json"
# -----------------------------
# Model data loading / saving
# -----------------------------
def load_default_models():
with open(DEFAULT_FILE, "r", encoding="utf-8") as f:
return json.load(f)
def load_user_models():
if os.path.exists(USER_FILE):
with open(USER_FILE, "r", encoding="utf-8") as f:
try:
return json.load(f)
except json.JSONDecodeError:
return {}
return {}
def save_user_models(data):
with open(USER_FILE, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
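# Note: on Hugging Face Spaces the container filesystem is ephemeral, so
# models.json written here is lost when the Space restarts or rebuilds
# unless persistent storage is enabled for the Space.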
def merge_models():
"""
Merge default + user models into one tree:
Category -> Family -> Model -> meta
User models can introduce new categories/families.
"""
base = load_default_models()
user = load_user_models()
for category, families in user.items():
if category not in base:
base[category] = {}
for family, models in families.items():
if family not in base[category]:
base[category][family] = {}
for model_name, meta in models.items():
base[category][family][model_name] = meta
return base
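# Illustrative shape of default_models.json / models.json (example values,
# not the real file contents), matching what add_model_box() writes below:
#
# {
#   "Custom": {
#     "User-Added": {
#       "my-model": {
#         "id": "user/my-model",
#         "description": "User-added model.",
#         "link": "https://huggingface.co/user/my-model",
#         "emoji": "✨"
#       }
#     }
#   }
# }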
# -----------------------------
# Utility: flatten and lookup
# -----------------------------
def flatten_models(model_tree):
"""
Returns a dict:
full_key -> (meta, category, family, model_name)
where full_key = "Category / Family / Model"
"""
flat = {}
for category, families in model_tree.items():
for family, models in families.items():
for model_name, meta in models.items():
full_key = f"{category} / {family} / {model_name}"
flat[full_key] = (meta, category, family, model_name)
return flat
# -----------------------------
# Debug logging
# -----------------------------
DEBUG_MESSAGES = [] # global buffer
def debug(msg):
"""Append a debug message to the global log."""
DEBUG_MESSAGES.append(str(msg))
if len(DEBUG_MESSAGES) > 300:
DEBUG_MESSAGES.pop(0)
return "\n".join(DEBUG_MESSAGES)
def get_debug_log():
return "\n".join(DEBUG_MESSAGES)
# -----------------------------
# Add a new model (from the box)
# -----------------------------
def add_model_box(
category,
family,
model_name,
model_id,
description,
link,
emoji
):
try:
if not model_id:
debug("Add model failed: no model_id provided")
            return gr.update(
                value="Please provide a Model ID like `user/model`."
            )
if not category:
category = "Custom"
if not family:
family = "User-Added"
if not model_name:
model_name = model_id.split("/")[-1]
if not description:
description = "User-added model."
if not link:
link = f"https://huggingface.co/{model_id}"
if not emoji:
emoji = "✨"
user_models = load_user_models()
if category not in user_models:
user_models[category] = {}
if family not in user_models[category]:
user_models[category][family] = {}
user_models[category][family][model_name] = {
"id": model_id,
"description": description,
"link": link,
"emoji": emoji
}
save_user_models(user_models)
msg = (
f"Added model under `{category} / {family}`: "
f"{emoji} **{model_name}** (`{model_id}`)\n\n"
f"It will appear in the model tree after reloading the Space."
)
debug(f"Model added: {category} / {family} / {model_name} ({model_id})")
        return gr.update(value=msg)
except Exception:
import traceback
tb = traceback.format_exc()
debug(f"ERROR in add_model_box:\n{tb}")
        return gr.update(
            value="An error occurred while adding the model. Check the Debug Log."
        )
# -----------------------------
# Helper: check model access (repo visibility)
# -----------------------------
def check_model_access(model_id, hf_token):
"""
Try to get model info; return (ok: bool, message: str).
This helps distinguish auth/gating vs other issues.
For local loading, this is not strictly required, but we keep
it to give clearer messages for private/gated models.
"""
try:
api = HfApi(token=hf_token.token if hf_token else None)
_ = api.model_info(model_id)
return True, ""
except Exception as e:
import traceback
tb = traceback.format_exc()
debug(f"ERROR in check_model_access for {model_id}:\n{tb}")
return False, str(e)
# -----------------------------
# Local model cache
# -----------------------------
LOCAL_MODEL_CACHE = {}
def load_local_model(model_id, token=None):
"""
Load a model + tokenizer locally and cache them.
This makes the Space behave like a dedicated model Space:
models are executed inside the container, not via Inference API.
"""
if model_id in LOCAL_MODEL_CACHE:
debug(f"Using cached model: {model_id}")
return LOCAL_MODEL_CACHE[model_id]
debug(f"Loading model locally: {model_id}")
try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
except Exception as e:
debug(f"ERROR loading tokenizer for {model_id}: {e}")
raise
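    # Note: device_map="auto" requires the `accelerate` package to be installed;
    # without it, from_pretrained raises an error before loading any weights.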
try:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=token,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto"
        )
except Exception as e:
debug(f"ERROR loading model weights for {model_id}: {e}")
raise
LOCAL_MODEL_CACHE[model_id] = (tokenizer, model)
return tokenizer, model
# -----------------------------
# Chat function (local models)
# -----------------------------
def respond(
message,
history,
system_message,
max_tokens,
temperature,
top_p,
active_model_key,
hf_token: gr.OAuthToken
):
# No model chosen
if active_model_key is None:
yield "No model selected. Please choose a model in the sidebar and click 'Use this model'."
return
models = merge_models()
flat = flatten_models(models)
meta_tuple = flat.get(active_model_key)
if meta_tuple is None:
yield "Selected model not found. Please choose a model again."
return
meta, _, _, _ = meta_tuple
model_id = meta["id"]
debug(f"Chat using local model: {model_id}")
# Optional: check repo access (for private/gated models)
ok, msg = check_model_access(model_id, hf_token)
if not ok:
yield (
f"Could not access model `{model_id}` on Hugging Face.\n\n"
f"This is usually because:\n"
f"- The repo is private or gated and this token has no access\n"
f"- Or the token is invalid/expired\n\n"
f"Raw error:\n{msg}\n\n"
f"Check Debug Log for more details."
)
return
# Load model locally
try:
        tokenizer, model = load_local_model(
            model_id, token=hf_token.token if hf_token else None
        )
except Exception:
import traceback
tb = traceback.format_exc()
debug(f"ERROR in load_local_model for {model_id}:\n{tb}")
yield (
f"Failed to load model `{model_id}` locally inside the Space.\n"
f"Check the Debug Log for details (likely out of memory or missing files)."
)
return
# Build chat-style prompt from history + current message
prompt = system_message.strip() + "\n\n"
for turn in history or []:
role = turn.get("role", "user")
content = turn.get("content", "")
if role == "user":
prompt += f"User: {content}\n"
else:
prompt += f"Assistant: {content}\n"
prompt += f"User: {message}\nAssistant:"
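    # Illustrative result for a one-turn history (values are examples only):
    #
    #   You are a friendly chatbot.
    #
    #   User: Hi there
    #   Assistant: Hello! How can I help?
    #   User: <current message>
    #   Assistant: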
debug(f"Prompt length (chars): {len(prompt)}")
try:
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Generate text locally
output_ids = model.generate(
**inputs,
max_new_tokens=int(max_tokens),
do_sample=True,
temperature=float(temperature),
top_p=float(top_p),
            pad_token_id=tokenizer.eos_token_id,
)
        # Decode only the newly generated tokens, not the echoed prompt
        new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
        answer = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        # If the model keeps writing further turns, cut at the next "User:"
        if "User:" in answer:
            answer = answer.split("User:")[0].strip()
        yield answer
except Exception:
import traceback
tb = traceback.format_exc()
debug(f"ERROR during local generation for {model_id}:\n{tb}")
yield (
"An error occurred during local text generation.\n"
"This is often due to running out of memory for large models.\n"
"Try a smaller model, fewer max tokens, or check the Debug Log."
)
# -----------------------------
# Use model helper
# -----------------------------
def use_model(fk, old_fk):
"""
fk: full key "Category / Family / Model" (from gr.State(full_key))
old_fk: previous active model key (from active_model_state)
Returns: (new_active_key, current_model_label_text)
"""
try:
models_local = merge_models()
flat_local = flatten_models(models_local)
meta_loc_tuple = flat_local.get(fk)
if not meta_loc_tuple:
debug(f"use_model: key not found: {fk}")
return old_fk, "**Current model:** _none selected_"
meta_loc, _, _, mname = meta_loc_tuple
emoji_local = meta_loc.get("emoji", "✨")
label_text = f"**Current model:** {emoji_local} {mname}"
debug(f"use_model: selected {fk}")
return fk, label_text
except Exception:
import traceback
tb = traceback.format_exc()
debug(f"ERROR in use_model:\n{tb}")
return old_fk, "**Current model:** _error occurred (see Debug Log)_"
# -----------------------------
# Build the sidebar tree
# -----------------------------
def build_model_tree(
models,
active_model_state,
current_model_label
):
"""
models: merged models dict (Category -> Family -> Model -> meta)
active_model_state: gr.State storing current active full key
current_model_label: gr.Markdown for 'Current model: ...'
"""
for category, families in models.items():
with gr.Accordion(category, open=False):
for family, model_dict in families.items():
with gr.Accordion(family, open=False):
for model_name, meta in model_dict.items():
emoji = meta.get("emoji", "✨")
full_key = f"{category} / {family} / {model_name}"
# Model accordion
with gr.Accordion(f"{emoji} {model_name}", open=False):
info_text = (
f"**Model ID:** `{meta['id']}` \n"
f"**Description:** {meta['description']} \n"
f"[Model card]({meta['link']})"
)
gr.Markdown(info_text)
use_btn = gr.Button("Use this model", size="sm")
# Wire button -> use_model
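                            # gr.State(full_key) snapshots this iteration's key at
                            # build time, so each button passes its own model key
                            # instead of the loop's last value.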
use_btn.click(
use_model,
inputs=[gr.State(full_key), active_model_state],
outputs=[active_model_state, current_model_label],
)
# -----------------------------
# Build the UI
# -----------------------------
with gr.Blocks() as demo:
models_tree = merge_models()
# Holds full key: "Category / Family / Model"
active_model_key = gr.State(value=None)
with gr.Sidebar():
gr.LoginButton()
# Collapsible "Add New Model" box
with gr.Accordion("Add New Model", open=False):
category_input = gr.Textbox(
label="Category (e.g. Exotic or new category)",
placeholder="Exotic"
)
family_input = gr.Textbox(
label="Family (e.g. RWKV)",
placeholder="RWKV"
)
model_name_input = gr.Textbox(
label="Model Name (e.g. RWKV-World-7B)",
placeholder="RWKV-World-7B"
)
model_id_input = gr.Textbox(
label="Model ID (e.g. BlinkDL/rwkv-7-world)",
placeholder="BlinkDL/rwkv-7-world"
)
description_input = gr.Textbox(
label="Description (optional)",
lines=2
)
link_input = gr.Textbox(
label="Link (optional, will default to https://huggingface.co/ModelID if empty)",
lines=1
)
emoji_input = gr.Textbox(
label="Emoji (optional, e.g. 🌍)",
lines=1
)
add_button = gr.Button("Add Model")
add_status = gr.Markdown("")
add_button.click(
add_model_box,
inputs=[
category_input,
family_input,
model_name_input,
model_id_input,
description_input,
link_input,
emoji_input,
],
outputs=add_status,
)
# Debug Log box (separate accordion)
with gr.Accordion("Debug Log", open=False):
debug_log = gr.Textbox(
label="System Debug Output",
value="",
lines=15,
max_lines=200,
interactive=False,
show_copy_button=True,
)
# Button to refresh debug log
refresh_debug = gr.Button("Refresh Debug Log", size="sm")
refresh_debug.click(
get_debug_log,
inputs=None,
outputs=debug_log
)
# Current model label under the debug box
current_model_label = gr.Markdown("**Current model:** _none selected_")
gr.Markdown("### Models")
# Build nested accordions for models
build_model_tree(
models_tree,
active_model_state=active_model_key,
current_model_label=current_model_label,
)
# Main chat interface
chatbot = gr.ChatInterface(
respond,
        # ChatInterface expects a string title, so the active model is shown
        # via current_model_label in the sidebar rather than here.
type="messages",
additional_inputs=[
gr.Textbox(
value="You are a friendly chatbot.",
label="System message"
),
gr.Slider(
minimum=1,
maximum=100000,
value=512,
step=1,
label="Max new tokens"
),
gr.Slider(
minimum=0.1,
maximum=4.0,
value=0.7,
step=0.1,
label="Temperature"
),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p"
),
active_model_key, # passes current active model key into respond()
],
)
if __name__ == "__main__":
demo.launch()