import gradio as gr from huggingface_hub import HfApi from transformers import AutoTokenizer, AutoModelForCausalLM import torch import json import os DEFAULT_FILE = "default_models.json" USER_FILE = "models.json" # ----------------------------- # Model data loading / saving # ----------------------------- def load_default_models(): with open(DEFAULT_FILE, "r", encoding="utf-8") as f: return json.load(f) def load_user_models(): if os.path.exists(USER_FILE): with open(USER_FILE, "r", encoding="utf-8") as f: try: return json.load(f) except json.JSONDecodeError: return {} return {} def save_user_models(data): with open(USER_FILE, "w", encoding="utf-8") as f: json.dump(data, f, indent=2, ensure_ascii=False) def merge_models(): """ Merge default + user models into one tree: Category -> Family -> Model -> meta User models can introduce new categories/families. """ base = load_default_models() user = load_user_models() for category, families in user.items(): if category not in base: base[category] = {} for family, models in families.items(): if family not in base[category]: base[category][family] = {} for model_name, meta in models.items(): base[category][family][model_name] = meta return base # ----------------------------- # Utility: flatten and lookup # ----------------------------- def flatten_models(model_tree): """ Returns a dict: full_key -> (meta, category, family, model_name) where full_key = "Category / Family / Model" """ flat = {} for category, families in model_tree.items(): for family, models in families.items(): for model_name, meta in models.items(): full_key = f"{category} / {family} / {model_name}" flat[full_key] = (meta, category, family, model_name) return flat # ----------------------------- # Debug logging # ----------------------------- DEBUG_MESSAGES = [] # global buffer def debug(msg): """Append a debug message to the global log.""" DEBUG_MESSAGES.append(str(msg)) if len(DEBUG_MESSAGES) > 300: DEBUG_MESSAGES.pop(0) return "\n".join(DEBUG_MESSAGES) def get_debug_log(): return "\n".join(DEBUG_MESSAGES) # ----------------------------- # Add a new model (from the box) # ----------------------------- def add_model_box( category, family, model_name, model_id, description, link, emoji ): try: if not model_id: debug("Add model failed: no model_id provided") return gr.Markdown.update( value="Please provide a Model ID like `user/model`." ) if not category: category = "Custom" if not family: family = "User-Added" if not model_name: model_name = model_id.split("/")[-1] if not description: description = "User-added model." if not link: link = f"https://huggingface.co/{model_id}" if not emoji: emoji = "✨" user_models = load_user_models() if category not in user_models: user_models[category] = {} if family not in user_models[category]: user_models[category][family] = {} user_models[category][family][model_name] = { "id": model_id, "description": description, "link": link, "emoji": emoji } save_user_models(user_models) msg = ( f"Added model under `{category} / {family}`: " f"{emoji} **{model_name}** (`{model_id}`)\n\n" f"It will appear in the model tree after reloading the Space." ) debug(f"Model added: {category} / {family} / {model_name} ({model_id})") return gr.Markdown.update(value=msg) except Exception: import traceback tb = traceback.format_exc() debug(f"ERROR in add_model_box:\n{tb}") return gr.Markdown.update( value="An error occurred while adding the model. Check Debug Log." ) # ----------------------------- # Helper: check model access (repo visibility) # ----------------------------- def check_model_access(model_id, hf_token): """ Try to get model info; return (ok: bool, message: str). This helps distinguish auth/gating vs other issues. For local loading, this is not strictly required, but we keep it to give clearer messages for private/gated models. """ try: api = HfApi(token=hf_token.token if hf_token else None) _ = api.model_info(model_id) return True, "" except Exception as e: import traceback tb = traceback.format_exc() debug(f"ERROR in check_model_access for {model_id}:\n{tb}") return False, str(e) # ----------------------------- # Local model cache # ----------------------------- from transformers import AutoTokenizer, AutoModelForCausalLM import torch LOCAL_MODEL_CACHE = {} def load_local_model(model_id): """ Load a model + tokenizer locally and cache them. This makes the Space behave like a dedicated model Space: models are executed inside the container, not via Inference API. """ if model_id in LOCAL_MODEL_CACHE: debug(f"Using cached model: {model_id}") return LOCAL_MODEL_CACHE[model_id] debug(f"Loading model locally: {model_id}") try: tokenizer = AutoTokenizer.from_pretrained(model_id) except Exception as e: debug(f"ERROR loading tokenizer for {model_id}: {e}") raise try: model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto" ) except Exception as e: debug(f"ERROR loading model weights for {model_id}: {e}") raise LOCAL_MODEL_CACHE[model_id] = (tokenizer, model) return tokenizer, model # ----------------------------- # Chat function (local models) # ----------------------------- def respond( message, history, system_message, max_tokens, temperature, top_p, active_model_key, hf_token: gr.OAuthToken ): # No model chosen if active_model_key is None: yield "No model selected. Please choose a model in the sidebar and click 'Use this model'." return models = merge_models() flat = flatten_models(models) meta_tuple = flat.get(active_model_key) if meta_tuple is None: yield "Selected model not found. Please choose a model again." return meta, _, _, _ = meta_tuple model_id = meta["id"] debug(f"Chat using local model: {model_id}") # Optional: check repo access (for private/gated models) ok, msg = check_model_access(model_id, hf_token) if not ok: yield ( f"Could not access model `{model_id}` on Hugging Face.\n\n" f"This is usually because:\n" f"- The repo is private or gated and this token has no access\n" f"- Or the token is invalid/expired\n\n" f"Raw error:\n{msg}\n\n" f"Check Debug Log for more details." ) return # Load model locally try: tokenizer, model = load_local_model(model_id) except Exception: import traceback tb = traceback.format_exc() debug(f"ERROR in load_local_model for {model_id}:\n{tb}") yield ( f"Failed to load model `{model_id}` locally inside the Space.\n" f"Check the Debug Log for details (likely out of memory or missing files)." ) return # Build chat-style prompt from history + current message prompt = system_message.strip() + "\n\n" for turn in history or []: role = turn.get("role", "user") content = turn.get("content", "") if role == "user": prompt += f"User: {content}\n" else: prompt += f"Assistant: {content}\n" prompt += f"User: {message}\nAssistant:" debug(f"Prompt length (chars): {len(prompt)}") try: inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # Generate text locally output_ids = model.generate( **inputs, max_new_tokens=int(max_tokens), do_sample=True, temperature=float(temperature), top_p=float(top_p), pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else None, ) output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) # Extract only the assistant's final answer if "Assistant:" in output_text: answer = output_text.split("Assistant:")[-1].strip() else: answer = output_text.strip() yield answer except Exception: import traceback tb = traceback.format_exc() debug(f"ERROR during local generation for {model_id}:\n{tb}") yield ( "An error occurred during local text generation.\n" "This is often due to running out of memory for large models.\n" "Try a smaller model, fewer max tokens, or check the Debug Log." ) # ----------------------------- # Use model helper # ----------------------------- def use_model(fk, old_fk): """ fk: full key "Category / Family / Model" (from gr.State(full_key)) old_fk: previous active model key (from active_model_state) Returns: (new_active_key, current_model_label_text) """ try: models_local = merge_models() flat_local = flatten_models(models_local) meta_loc_tuple = flat_local.get(fk) if not meta_loc_tuple: debug(f"use_model: key not found: {fk}") return old_fk, "**Current model:** _none selected_" meta_loc, _, _, mname = meta_loc_tuple emoji_local = meta_loc.get("emoji", "✨") label_text = f"**Current model:** {emoji_local} {mname}" debug(f"use_model: selected {fk}") return fk, label_text except Exception: import traceback tb = traceback.format_exc() debug(f"ERROR in use_model:\n{tb}") return old_fk, "**Current model:** _error occurred (see Debug Log)_" # ----------------------------- # Build the sidebar tree # ----------------------------- def build_model_tree( models, active_model_state, current_model_label ): """ models: merged models dict (Category -> Family -> Model -> meta) active_model_state: gr.State storing current active full key current_model_label: gr.Markdown for 'Current model: ...' """ for category, families in models.items(): with gr.Accordion(category, open=False): for family, model_dict in families.items(): with gr.Accordion(family, open=False): for model_name, meta in model_dict.items(): emoji = meta.get("emoji", "✨") full_key = f"{category} / {family} / {model_name}" # Model accordion with gr.Accordion(f"{emoji} {model_name}", open=False): info_text = ( f"**Model ID:** `{meta['id']}` \n" f"**Description:** {meta['description']} \n" f"[Model card]({meta['link']})" ) gr.Markdown(info_text) use_btn = gr.Button("Use this model", size="sm") # Wire button -> use_model use_btn.click( use_model, inputs=[gr.State(full_key), active_model_state], outputs=[active_model_state, current_model_label], ) # ----------------------------- # Build the UI # ----------------------------- with gr.Blocks() as demo: models_tree = merge_models() # Holds full key: "Category / Family / Model" active_model_key = gr.State(value=None) with gr.Sidebar(): gr.LoginButton() # Collapsible "Add New Model" box with gr.Accordion("Add New Model", open=False): category_input = gr.Textbox( label="Category (e.g. Exotic or new category)", placeholder="Exotic" ) family_input = gr.Textbox( label="Family (e.g. RWKV)", placeholder="RWKV" ) model_name_input = gr.Textbox( label="Model Name (e.g. RWKV-World-7B)", placeholder="RWKV-World-7B" ) model_id_input = gr.Textbox( label="Model ID (e.g. BlinkDL/rwkv-7-world)", placeholder="BlinkDL/rwkv-7-world" ) description_input = gr.Textbox( label="Description (optional)", lines=2 ) link_input = gr.Textbox( label="Link (optional, will default to https://huggingface.co/ModelID if empty)", lines=1 ) emoji_input = gr.Textbox( label="Emoji (optional, e.g. 🌍)", lines=1 ) add_button = gr.Button("Add Model") add_status = gr.Markdown("") add_button.click( add_model_box, inputs=[ category_input, family_input, model_name_input, model_id_input, description_input, link_input, emoji_input, ], outputs=add_status, ) # Debug Log box (separate accordion) with gr.Accordion("Debug Log", open=False): debug_log = gr.Textbox( label="System Debug Output", value="", lines=15, max_lines=200, interactive=False, show_copy_button=True, ) # Button to refresh debug log refresh_debug = gr.Button("Refresh Debug Log", size="sm") refresh_debug.click( get_debug_log, inputs=None, outputs=debug_log ) # Current model label under the debug box current_model_label = gr.Markdown("**Current model:** _none selected_") gr.Markdown("### Models") # Build nested accordions for models build_model_tree( models_tree, active_model_state=active_model_key, current_model_label=current_model_label, ) # Main chat interface chatbot = gr.ChatInterface( respond, title=current_model_label, type="messages", additional_inputs=[ gr.Textbox( value="You are a friendly chatbot.", label="System message" ), gr.Slider( minimum=1, maximum=100000, value=512, step=1, label="Max new tokens" ), gr.Slider( minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature" ), gr.Slider( minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p" ), active_model_key, # passes current active model key into respond() ], ) if __name__ == "__main__": demo.launch()