Rainbowdesign committed · Commit 6f2bef2 · verified · 1 Parent(s): e774293

Update app.py

Files changed (1):
  1. app.py  +112 -79

app.py CHANGED
@@ -1,5 +1,7 @@
 import gradio as gr
-from huggingface_hub import InferenceClient, HfApi
+from huggingface_hub import HfApi
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
 import json
 import os
 
@@ -150,12 +152,14 @@ def add_model_box(
 
 
 # -----------------------------
-# Helper: check model access
+# Helper: check model access (repo visibility)
 # -----------------------------
 def check_model_access(model_id, hf_token):
     """
     Try to get model info; return (ok: bool, message: str).
     This helps distinguish auth/gating vs other issues.
+    For local loading, this is not strictly required, but we keep
+    it to give clearer messages for private/gated models.
     """
     try:
         api = HfApi(token=hf_token.token if hf_token else None)
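The diff context above cuts off before check_model_access actually queries the Hub, so the exact body is not shown here. As a point of reference only, a minimal sketch of such a probe, assuming huggingface_hub's HfApi.model_info call and its GatedRepoError / RepositoryNotFoundError exceptions (the helper name below is hypothetical, not the commit's code), could look like this:

from huggingface_hub import HfApi
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

def check_model_access_sketch(model_id, token=None):
    """Hypothetical sketch: return (ok, message) for a repo-visibility probe."""
    api = HfApi(token=token)
    try:
        api.model_info(model_id)  # raises if the repo is missing, private, or gated
        return True, "Model is accessible."
    except GatedRepoError:
        return False, "Model is gated; request access on the model page."
    except RepositoryNotFoundError:
        return False, "Model not found, or private and this token has no access."
    except Exception as e:
        return False, f"Unexpected error while checking access: {e}"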
@@ -169,7 +173,46 @@ def check_model_access(model_id, hf_token):
 
 
 # -----------------------------
-# Chat function (robust)
+# Local model cache
+# -----------------------------
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
+LOCAL_MODEL_CACHE = {}
+
+def load_local_model(model_id):
+    """
+    Load a model + tokenizer locally and cache them.
+    This makes the Space behave like a dedicated model Space:
+    models are executed inside the container, not via Inference API.
+    """
+    if model_id in LOCAL_MODEL_CACHE:
+        debug(f"Using cached model: {model_id}")
+        return LOCAL_MODEL_CACHE[model_id]
+
+    debug(f"Loading model locally: {model_id}")
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+    except Exception as e:
+        debug(f"ERROR loading tokenizer for {model_id}: {e}")
+        raise
+
+    try:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            device_map="auto"
+        )
+    except Exception as e:
+        debug(f"ERROR loading model weights for {model_id}: {e}")
+        raise
+
+    LOCAL_MODEL_CACHE[model_id] = (tokenizer, model)
+    return tokenizer, model
+
+
+# -----------------------------
+# Chat function (local models)
 # -----------------------------
 def respond(
     message,
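LOCAL_MODEL_CACHE in the hunk above is a plain module-level dict keyed by model id, so each checkpoint is loaded at most once per running Space process; note that the commit's device_map="auto" additionally requires the accelerate package at runtime. Below is a minimal standalone sketch of the same caching pattern, assuming transformers and torch are installed, omitting device_map for simplicity, and using "sshleifer/tiny-gpt2" purely as an illustrative model id:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

_CACHE = {}  # same idea as LOCAL_MODEL_CACHE above

def get_cached_model(model_id):
    """Load tokenizer + model once per process and reuse them afterwards."""
    if model_id not in _CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        _CACHE[model_id] = (tokenizer, model)
    return _CACHE[model_id]

tok_a, model_a = get_cached_model("sshleifer/tiny-gpt2")
tok_b, model_b = get_cached_model("sshleifer/tiny-gpt2")
assert model_a is model_b  # the second call reuses the cached instance, no reload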
@@ -197,19 +240,13 @@ def respond(
     meta, _, _, _ = meta_tuple
     model_id = meta["id"]
 
-    debug(f"Chat using model: {model_id}")
+    debug(f"Chat using local model: {model_id}")
 
-    # Check token presence
-    if hf_token is None or hf_token.token is None:
-        debug("No HF token available from Login.")
-        yield "No Hugging Face token detected. Please click Login in the sidebar and try again."
-        return
-
-    # Check access to the model
+    # Optional: check repo access (for private/gated models)
     ok, msg = check_model_access(model_id, hf_token)
     if not ok:
         yield (
-            f"Could not access model `{model_id}`.\n\n"
+            f"Could not access model `{model_id}` on Hugging Face.\n\n"
             f"This is usually because:\n"
             f"- The repo is private or gated and this token has no access\n"
             f"- Or the token is invalid/expired\n\n"
@@ -218,68 +255,63 @@ def respond(
         )
         return
 
-    # Build messages
-    messages = [{"role": "system", "content": system_message}]
-    messages.extend(history or [])
-    messages.append({"role": "user", "content": message})
+    # Load model locally
+    try:
+        tokenizer, model = load_local_model(model_id)
+    except Exception:
+        import traceback
+        tb = traceback.format_exc()
+        debug(f"ERROR in load_local_model for {model_id}:\n{tb}")
+        yield (
+            f"Failed to load model `{model_id}` locally inside the Space.\n"
+            f"Check the Debug Log for details (likely out of memory or missing files)."
+        )
+        return
+
+    # Build chat-style prompt from history + current message
+    prompt = system_message.strip() + "\n\n"
+    for turn in history or []:
+        role = turn.get("role", "user")
+        content = turn.get("content", "")
+        if role == "user":
+            prompt += f"User: {content}\n"
+        else:
+            prompt += f"Assistant: {content}\n"
+    prompt += f"User: {message}\nAssistant:"
+
+    debug(f"Prompt length (chars): {len(prompt)}")
 
     try:
-        client = InferenceClient(token=hf_token.token, model=model_id)
-
-        # Try chat_completion first
-        response = ""
-        try:
-            for msg_obj in client.chat_completion(
-                messages,
-                max_tokens=max_tokens,
-                stream=True,
-                temperature=temperature,
-                top_p=top_p,
-            ):
-                # Defensive handling of streaming structure
-                choice_list = getattr(msg_obj, "choices", [])
-                if not choice_list:
-                    continue
-                delta = getattr(choice_list[0], "delta", None)
-                if delta is None:
-                    continue
-                chunk = getattr(delta, "content", "") or ""
-                response += chunk
-                yield response
-
-            # If we streamed something, we're done
-            if response:
-                return
-
-        except Exception:
-            # Fall back to text_generation if chat_completion fails or model isn't chat-style
-            import traceback
-            tb = traceback.format_exc()
-            debug(f"chat_completion failed, falling back to text_generation:\n{tb}")
-
-            prompt = (
-                system_message
-                + "\n\n"
-                + "\n".join([f"User: {h['content']}" if h["role"] == "user" else f"Assistant: {h['content']}" for h in (history or [])])
-                + f"\nUser: {message}\nAssistant:"
-            )
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+        # Generate text locally
+        output_ids = model.generate(
+            **inputs,
+            max_new_tokens=int(max_tokens),
+            do_sample=True,
+            temperature=float(temperature),
+            top_p=float(top_p),
+            pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else None,
+        )
 
-            text = client.text_generation(
-                prompt,
-                max_new_tokens=max_tokens,
-                temperature=temperature,
-                top_p=top_p,
-                stream=False,
-            )
-            yield text
+        output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+        # Extract only the assistant's final answer
+        if "Assistant:" in output_text:
+            answer = output_text.split("Assistant:")[-1].strip()
+        else:
+            answer = output_text.strip()
+
+        yield answer
 
     except Exception:
         import traceback
         tb = traceback.format_exc()
-        debug(f"ERROR in respond:\n{tb}")
+        debug(f"ERROR during local generation for {model_id}:\n{tb}")
         yield (
-            "An unexpected error occurred while talking to the model.\n"
-            "Please check the Debug Log for more details."
+            "An error occurred during local text generation.\n"
+            "This is often due to running out of memory for large models.\n"
+            "Try a smaller model, fewer max tokens, or check the Debug Log."
        )
 
 
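The rewritten respond() builds a hand-rolled "User:/Assistant:" prompt and yields the full answer once, so the token-by-token streaming of the old chat_completion path is gone. If streaming were wanted with local models, transformers offers TextIteratorStreamer, and chat-tuned checkpoints usually ship a chat template. The sketch below shows that alternative under those assumptions; it is not what this commit does, and "Qwen/Qwen2.5-0.5B-Instruct" is only an illustrative model id:

from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative model id only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
# Use the model's own chat template instead of a hand-rolled User:/Assistant: prompt
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Generate in a background thread and stream decoded text pieces as they arrive
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = Thread(target=model.generate, kwargs=dict(**inputs, max_new_tokens=64, streamer=streamer))
thread.start()

partial = ""
for chunk in streamer:
    partial += chunk  # a Gradio handler would `yield partial` here for live updates
thread.join()
print(partial)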
 
@@ -339,21 +371,21 @@ def build_model_tree(
 
         # Model accordion
         with gr.Accordion(f"{emoji} {model_name}", open=False):
-            info_text = (
-                f"**Model ID:** `{meta['id']}` \n"
-                f"**Description:** {meta['description']} \n"
-                f"[Model card]({meta['link']})"
-            )
-            gr.Markdown(info_text)
+            info_text = (
+                f"**Model ID:** `{meta['id']}` \n"
+                f"**Description:** {meta['description']} \n"
+                f"[Model card]({meta['link']})"
+            )
+            gr.Markdown(info_text)
 
-            use_btn = gr.Button("Use this model", size="sm")
+            use_btn = gr.Button("Use this model", size="sm")
 
-            # Wire button -> use_model
-            use_btn.click(
-                use_model,
-                inputs=[gr.State(full_key), active_model_state],
-                outputs=[active_model_state, current_model_label],
-            )
+            # Wire button -> use_model
+            use_btn.click(
+                use_model,
+                inputs=[gr.State(full_key), active_model_state],
+                outputs=[active_model_state, current_model_label],
+            )
 
 
 # -----------------------------
@@ -417,6 +449,7 @@ with gr.Blocks() as demo:
         )
 
         # Debug Log box (separate accordion)
+        with gr.Accordion("Debug Log", open=False):
         debug_log = gr.Textbox(
             label="System Debug Output",
             value="",
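In Gradio Blocks, a component is attached to a container only when its constructor runs inside that container's with block, so the accordion added in the last hunk wraps the textbox only if debug_log is created (indented) under it. A minimal standalone sketch of that pattern, with names chosen to mirror the commit but otherwise hypothetical:

import gradio as gr

with gr.Blocks() as demo:
    # The textbox is constructed inside the accordion's `with` block,
    # so it renders (and collapses) inside the "Debug Log" accordion.
    with gr.Accordion("Debug Log", open=False):
        debug_log = gr.Textbox(label="System Debug Output", value="", lines=10)

demo.launch()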