Rainbowdesign committed on
Commit
296d19e
·
verified ·
1 Parent(s): 03f3ac3

Update app.py

Browse files

Add model selector

Files changed (1) hide show
  1. app.py +39 -27
app.py CHANGED
@@ -1,5 +1,25 @@
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
 
5
  def respond(
@@ -9,17 +29,14 @@ def respond(
9
  max_tokens,
10
  temperature,
11
  top_p,
 
12
  hf_token: gr.OAuthToken,
13
  ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
 
19
  messages = [{"role": "system", "content": system_message}]
20
-
21
  messages.extend(history)
22
-
23
  messages.append({"role": "user", "content": message})
24
 
25
  response = ""
@@ -40,31 +57,26 @@ def respond(
40
  yield response
41
 
42
 
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
61
- )
62
-
63
  with gr.Blocks() as demo:
64
  with gr.Sidebar():
65
  gr.LoginButton()
66
- chatbot.render()
 
 
 
 
67
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  if __name__ == "__main__":
70
  demo.launch()
 
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
+ import json, os
4
+
5
MODEL_FILE = "models.json"


def load_models():
    """Return the persisted model mapping (display name -> HF repo id).

    Reads MODEL_FILE if present; otherwise returns a single default entry.
    """
    # EAFP: attempt the read and fall back on FileNotFoundError. This avoids
    # the exists()/open() race of the original and — via the context manager —
    # the file handle leaked by `json.load(open(MODEL_FILE))`.
    try:
        with open(MODEL_FILE, encoding="utf-8") as fh:
            return json.load(fh)
    except FileNotFoundError:
        return {"Default (GPT‑OSS‑20B)": "openai/gpt-oss-20b"}


def save_models(models):
    """Persist *models* (display name -> HF repo id dict) to MODEL_FILE as JSON."""
    # The context manager guarantees the handle is flushed and closed, unlike
    # the original `json.dump(models, open(MODEL_FILE, "w"))`.
    with open(MODEL_FILE, "w", encoding="utf-8") as fh:
        json.dump(models, fh)


# Module-level registry shared by add_model() and the UI dropdown.
models = load_models()
16
+
17
+
18
def add_model(link):
    """Register a new model repo and refresh the model dropdown.

    Args:
        link: HF repo id typed by the user (e.g. "meta-llama/Llama-3-8B").
              The display name is the last path segment.

    Returns:
        A Gradio Dropdown component patch with the updated choices and the
        new entry selected; an empty update when the input is blank.
    """
    link = link.strip()
    if not link:
        # Ignore empty submissions instead of registering a "" key.
        return gr.update()
    name = link.split("/")[-1]
    models[name] = link
    save_models(models)
    # gr.Dropdown.update() was removed in Gradio 4+; returning a component
    # instance (or gr.update) is the supported way to patch its props.
    return gr.Dropdown(choices=list(models.keys()), value=name)
23
 
24
 
25
  def respond(
 
29
  max_tokens,
30
  temperature,
31
  top_p,
32
+ model_choice,
33
  hf_token: gr.OAuthToken,
34
  ):
35
+ model_id = models[model_choice]
36
+ client = InferenceClient(token=hf_token.token, model=model_id)
 
 
37
 
38
  messages = [{"role": "system", "content": system_message}]
 
39
  messages.extend(history)
 
40
  messages.append({"role": "user", "content": message})
41
 
42
  response = ""
 
57
  yield response
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
# App layout: sidebar holds login plus the "add a model" controls; the main
# area is the chat interface wired to respond().
with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.LoginButton()
        gr.Markdown("### Add a new model")
        new_model = gr.Textbox(label="Model repo (e.g. meta-llama/Llama-3-8B)")
        # Dropdown is seeded from the persisted registry loaded at startup.
        model_dropdown = gr.Dropdown(list(models.keys()), label="Choose model")
        add_button = gr.Button("Add model")
        # add_model() takes the textbox value and returns an update for the
        # dropdown (new choices + newly added entry selected).
        add_button.click(add_model, new_model, model_dropdown)

    chatbot = gr.ChatInterface(
        respond,
        type="messages",
        additional_inputs=[
            gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
            # Same component as in the sidebar: its current value is forwarded
            # to respond() as the model_choice argument.
            model_dropdown,
        ],
    )
80
 
81
# Standard script entry point: launch the Gradio server only when this file
# is executed directly (not when imported, e.g. by the Spaces runtime).
if __name__ == "__main__":
    demo.launch()