Rainbowdesign committed on
Commit
ac097cd
·
verified ·
1 Parent(s): 68a234e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +339 -37
app.py CHANGED
@@ -1,80 +1,382 @@
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
- import json, os
 
4
 
5
- MODEL_FILE = "models.json"
 
6
 
7
def load_models():
    """Load the model-name -> repo-id mapping from MODEL_FILE.

    Returns:
        dict: saved mapping, or the built-in default when the file
        does not exist yet.
    """
    if os.path.exists(MODEL_FILE):
        # Context manager closes the handle; the original
        # json.load(open(MODEL_FILE)) leaked it.
        with open(MODEL_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    return {"Default (GPT‑OSS‑20B)": "openai/gpt-oss-20b"}
11
 
12
def save_models(models):
    """Persist the model mapping to MODEL_FILE as JSON.

    Args:
        models: dict of display name -> repo id.
    """
    # Context manager guarantees flush + close; the original
    # json.dump(models, open(MODEL_FILE, "w")) leaked the write handle,
    # risking truncated/unflushed output.
    with open(MODEL_FILE, "w", encoding="utf-8") as f:
        json.dump(models, f)
 
 
 
 
14
 
15
- models = load_models()
16
 
 
 
 
 
 
 
 
 
17
 
18
def add_model(link):
    """Register a model repo under its trailing path segment and refresh
    the dropdown.

    Args:
        link: model repo id or URL, e.g. "meta-llama/Llama-3-8B".

    Returns:
        Gradio update refreshing the dropdown choices and selecting the
        newly added model.
    """
    name = link.split("/")[-1]
    models[name] = link
    save_models(models)
    # gr.Dropdown.update was removed in Gradio 4.x; gr.update is the
    # forward-compatible spelling and works in 3.x as well.
    return gr.update(choices=list(models.keys()), value=name)
23
 
 
 
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    model_choice,
    hf_token: gr.OAuthToken,
):
    """Stream a chat completion for *message* with the chosen model.

    Yields the accumulated response text after each streamed chunk so
    the ChatInterface renders incrementally.
    """
    model_id = models[model_choice]
    client = InferenceClient(token=hf_token.token, model=model_id)

    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": message})

    response = ""
    # Loop variable renamed: the original `for message in ...` shadowed
    # the user's `message` argument.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        choices = chunk.choices
        token = ""
        # Some stream events carry no choices or no delta content
        # (e.g. role/usage frames); guard before indexing.
        if len(choices) and choices[0].delta.content:
            token = choices[0].delta.content
        response += token
        yield response
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
with gr.Blocks() as demo:
    # Sidebar: login plus a small "add model" form feeding the dropdown.
    with gr.Sidebar():
        gr.LoginButton()
        gr.Markdown("### Add a new model")
        new_model = gr.Textbox(label="Model repo (e.g. meta-llama/Llama-3-8B)")
        model_dropdown = gr.Dropdown(list(models.keys()), label="Choose model")
        add_button = gr.Button("Add model")
        add_button.click(add_model, new_model, model_dropdown)

    # Chat UI; extra inputs are forwarded to respond() in order.
    chatbot = gr.ChatInterface(
        respond,
        type="messages",
        additional_inputs=[
            gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
            model_dropdown,
        ],
    )
80
 
 
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
+ import json
4
+ import os
5
 
6
+ DEFAULT_FILE = "default_models.json"
7
+ USER_FILE = "models.json"
8
 
 
 
 
 
9
 
10
# -----------------------------
# Model data loading / saving
# -----------------------------
def load_default_models():
    """Read the built-in model tree from DEFAULT_FILE (file must exist)."""
    with open(DEFAULT_FILE, encoding="utf-8") as fh:
        return json.load(fh)
16
 
 
17
 
18
def load_user_models():
    """Return user-added models from USER_FILE; {} when absent or corrupt."""
    if not os.path.exists(USER_FILE):
        return {}
    with open(USER_FILE, "r", encoding="utf-8") as fh:
        try:
            return json.load(fh)
        except json.JSONDecodeError:
            # A corrupt user file is treated as empty rather than
            # crashing the whole app at import time.
            return {}
26
 
 
 
 
 
 
27
 
28
def save_user_models(data):
    """Write the user model tree to USER_FILE as pretty-printed JSON."""
    with open(USER_FILE, mode="w", encoding="utf-8") as fh:
        json.dump(data, fh, indent=2, ensure_ascii=False)
31
 
32
+
33
def merge_models():
    """
    Merge default + user models into one tree:
        Category -> Family -> Model -> meta
    User models can introduce new categories/families; on a name clash
    the user's entry wins.
    """
    merged = load_default_models()
    for category, families in load_user_models().items():
        cat_node = merged.setdefault(category, {})
        for family, entries in families.items():
            # update() overwrites per model name, matching the original
            # per-model assignment semantics.
            cat_node.setdefault(family, {}).update(entries)
    return merged
52
+
53
+
54
# -----------------------------
# Utility: flatten and lookup
# -----------------------------
def flatten_models(model_tree):
    """
    Returns a dict:
        full_key -> (meta, category, family, model_name)
    where full_key = "Category / Family / Model"
    """
    return {
        f"{category} / {family} / {name}": (meta, category, family, name)
        for category, families in model_tree.items()
        for family, entries in families.items()
        for name, meta in entries.items()
    }
70
+
71
+
72
# -----------------------------
# Add a new model
# -----------------------------
def add_model_box(
    category,
    family,
    model_name,
    model_id,
    description,
    link,
    emoji
):
    """Validate the form, persist the new model to USER_FILE, and return
    a status-Markdown update.

    Only model_id is required ("user/model"); every other field falls
    back to a sensible default.
    """
    # Basic validation
    if not model_id:
        # gr.Markdown.update was removed in Gradio 4.x; gr.update works
        # in both 3.x and 4.x.
        return gr.update(
            value="Please provide a Model ID like `user/model`."
        )

    # Fallbacks
    if not category:
        category = "Custom"
    if not family:
        family = "User-Added"
    if not model_name:
        model_name = model_id.split("/")[-1]
    if not description:
        description = "User-added model."
    if not link:
        link = f"https://huggingface.co/{model_id}"
    if not emoji:
        emoji = "✨"

    user_models = load_user_models()

    # Ensure the nested Category -> Family structure exists, then insert.
    user_models.setdefault(category, {}).setdefault(family, {})[model_name] = {
        "id": model_id,
        "description": description,
        "link": link,
        "emoji": emoji
    }

    save_user_models(user_models)

    msg = (
        f"Added model under `{category} / {family}`: "
        f"{emoji} **{model_name}** (`{model_id}`)\n\n"
        f"It will appear in the model tree after reloading the Space."
    )
    return gr.update(value=msg)
127
+
128
+
129
# -----------------------------
# Chat function
# -----------------------------
def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
    active_model_key,
    hf_token: gr.OAuthToken
):
    """Stream a chat completion for *message* using the active model.

    Yields the accumulated response text after each streamed chunk so
    the ChatInterface renders incremental output.
    """
    if active_model_key is None:
        yield "No model selected. Please choose a model in the sidebar and click 'Use this model'."
        return

    models = merge_models()
    flat = flatten_models(models)

    meta_tuple = flat.get(active_model_key)
    if meta_tuple is None:
        yield "Selected model not found. Please choose a model again."
        return

    meta, _, _, _ = meta_tuple
    model_id = meta["id"]

    client = InferenceClient(token=hf_token.token, model=model_id)

    messages = [{"role": "system", "content": system_message}]
    messages.extend(history or [])
    messages.append({"role": "user", "content": message})

    response = ""
    for msg in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # Some stream events carry an empty `choices` list (e.g. final
        # usage frames); guard before indexing to avoid IndexError.
        if msg.choices:
            response += msg.choices[0].delta.content or ""
        yield response
174
 
175
 
176
# -----------------------------
# Build the sidebar tree
# -----------------------------
def build_model_tree(
    models,
    active_model_state,
    current_model_label,
    info_markdowns,
    use_buttons
):
    """
    models: merged models dict (Category -> Family -> Model -> meta)
    active_model_state: gr.State storing current active full key
    current_model_label: gr.Markdown for 'Current model: ...'
    info_markdowns: dict full_key -> gr.Markdown (Model Info)
    use_buttons: dict full_key -> gr.Button (Use this model)
    """
    # NOTE: the closures below bind full_key through a default argument
    # (fk=full_key) to sidestep Python's late-binding-closure pitfall.
    for category, families in models.items():
        with gr.Accordion(category, open=False):
            for family, model_dict in families.items():
                with gr.Accordion(family, open=False):
                    for model_name, meta in model_dict.items():
                        emoji = meta.get("emoji", "✨")
                        full_key = f"{category} / {family} / {model_name}"

                        # Model button row
                        model_button = gr.Button(f"{emoji} {model_name}", size="sm")

                        # Model info area (hidden until its button is clicked)
                        info_md = gr.Markdown(visible=False)
                        info_markdowns[full_key] = info_md

                        # "Use this model" button, revealed with the info
                        use_btn = gr.Button("Use this model", visible=False, size="sm")
                        use_buttons[full_key] = use_btn

                        # ---- Click model button → show this info, hide others
                        def show_info(clicked_key, fk=full_key):
                            models_local = merge_models()
                            flat_local = flatten_models(models_local)

                            updates = {}
                            for key, (meta_loc, _, _, _) in flat_local.items():
                                md_out = info_markdowns.get(key)
                                btn_out = use_buttons.get(key)
                                if md_out is None or btn_out is None:
                                    continue

                                if key == fk:
                                    text = (
                                        f"**Model ID:** `{meta_loc['id']}` \n"
                                        f"**Description:** {meta_loc['description']} \n"
                                        f"[Model card]({meta_loc['link']})"
                                    )
                                    # gr.Markdown.update / gr.Button.update were
                                    # removed in Gradio 4.x; gr.update works in
                                    # both 3.x and 4.x.
                                    updates[md_out] = gr.update(value=text, visible=True)
                                    updates[btn_out] = gr.update(visible=True)
                                else:
                                    updates[md_out] = gr.update(visible=False)
                                    updates[btn_out] = gr.update(visible=False)
                            return updates

                        model_button.click(
                            show_info,
                            inputs=active_model_state,
                            outputs=list(info_markdowns.values()) + list(use_buttons.values()),
                        )

                        # ---- Click "Use this model" → set active model + label
                        def use_model(fk=full_key):
                            models_local = merge_models()
                            flat_local = flatten_models(models_local)
                            meta_loc_tuple = flat_local.get(fk)

                            if not meta_loc_tuple:
                                return fk, gr.update(
                                    value="**Current model:** _none selected_"
                                )

                            meta_loc, _, _, mname = meta_loc_tuple
                            emoji_local = meta_loc.get("emoji", "✨")
                            label_text = f"**Current model:** {emoji_local} {mname}"
                            return fk, gr.update(value=label_text)

                        use_btn.click(
                            use_model,
                            inputs=None,
                            outputs=[active_model_state, current_model_label],
                        )
268
+
269
+
270
# -----------------------------
# Build the UI
# -----------------------------
with gr.Blocks() as demo:
    models_tree = merge_models()

    # Holds full key: "Category / Family / Model"
    active_model_key = gr.State(value=None)

    with gr.Sidebar():
        gr.LoginButton()

        # Collapsible "Add New Model" box
        with gr.Accordion("Add New Model", open=False):
            category_input = gr.Textbox(
                label="Category (e.g. Exotic or new category)",
                placeholder="Exotic",
            )
            family_input = gr.Textbox(label="Family (e.g. RWKV)", placeholder="RWKV")
            model_name_input = gr.Textbox(
                label="Model Name (e.g. RWKV-World-7B)",
                placeholder="RWKV-World-7B",
            )
            model_id_input = gr.Textbox(
                label="Model ID (e.g. BlinkDL/rwkv-7-world)",
                placeholder="BlinkDL/rwkv-7-world",
            )
            description_input = gr.Textbox(label="Description (optional)", lines=2)
            link_input = gr.Textbox(
                label="Link (optional, will default to https://huggingface.co/ModelID if empty)",
                lines=1,
            )
            emoji_input = gr.Textbox(label="Emoji (optional, e.g. 🌍)", lines=1)

            add_button = gr.Button("Add Model")
            add_status = gr.Markdown("")

            # Form fields are forwarded positionally to add_model_box().
            add_button.click(
                add_model_box,
                inputs=[
                    category_input,
                    family_input,
                    model_name_input,
                    model_id_input,
                    description_input,
                    link_input,
                    emoji_input,
                ],
                outputs=add_status,
            )

        # Current model label under the Add box
        current_model_label = gr.Markdown("**Current model:** _none selected_")

        # Registries mapping full model keys to their info/use widgets.
        info_markdowns = {}
        use_buttons = {}

        gr.Markdown("### Models")

        # Build nested accordions for models
        build_model_tree(
            models_tree,
            active_model_state=active_model_key,
            current_model_label=current_model_label,
            info_markdowns=info_markdowns,
            use_buttons=use_buttons,
        )

    # Main chat interface
    chatbot = gr.ChatInterface(
        respond,
        title="chatbot",
        type="messages",
        additional_inputs=[
            gr.Textbox(value="You are a friendly chatbot.", label="System message"),
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
            active_model_key,  # passes current active model key into respond()
        ],
    )
382