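"""BuggingSpace: FSDL 2022 red-teaming project for open-source text generators.

Pick a model (or any Hugging Face Hub checkpoint), prompt it, analyse the
output with the Detoxify suite, and flag problematic generations to a Hub dataset.
"""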
import os

import torch
import numpy as np
import gradio as gr
from random import sample
from detoxify import Detoxify
from datasets import load_dataset
from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
from transformers import BloomTokenizerFast, BloomForCausalLM

HF_AUTH_TOKEN = os.environ.get("hf_token")
DATASET = "allenai/real-toxicity-prompts"

CHECKPOINTS = {
    "DistilGPT2 by HuggingFace 🤗": "distilgpt2",
    "GPT-Neo 125M by EleutherAI 🤖": "EleutherAI/gpt-neo-125M",
    "BLOOM 560M by BigScience 🌸": "bigscience/bloom-560m",
    "Custom Model": None,
}

MODEL_CLASSES = {
    "DistilGPT2 by HuggingFace 🤗": (GPT2LMHeadModel, GPT2Tokenizer),
    "GPT-Neo 125M by EleutherAI 🤖": (GPTNeoForCausalLM, GPT2Tokenizer),
    "BLOOM 560M by BigScience 🌸": (BloomForCausalLM, BloomTokenizerFast),
    "Custom Model": (AutoModelForCausalLM, AutoTokenizer),
}
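# The three preset checkpoints (every key except "Custom Model") seed the
# multi-model checkbox group; custom models are appended to CHOICES at runtime.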
CHOICES = sorted(list(CHECKPOINTS.keys())[:3])

def load_model(model_name, custom_model_path, token):
    try:
        model_class, tokenizer_class = MODEL_CLASSES[model_name]
        model_path = CHECKPOINTS[model_name]
    except KeyError:
        # Not a preset: fall back to the Auto classes and treat the selection
        # (or the provided path) as a Hub model id.
        model_class, tokenizer_class = MODEL_CLASSES["Custom Model"]
        model_path = custom_model_path or model_name
    model = model_class.from_pretrained(model_path, use_auth_token=token)
    tokenizer = tokenizer_class.from_pretrained(model_path, use_auth_token=token)
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id
    model.eval()
    return model, tokenizer

MAX_LENGTH = 10000  # Hardcoded max length to avoid infinite loop

def set_seed(seed, n_gpu):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)


def adjust_length_to_model(length, max_sequence_length):
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # No generation bigger than model size
    elif length < 0:
        length = MAX_LENGTH  # avoid infinite loop
    return length

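# Full single-prompt pipeline: choose a device, seed the RNGs, load the
# checkpoint, tokenize, sample with top-k/top-p, then strip the prompt and
# truncate at the last stop_token.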
def generate(
    model_name,
    token,
    custom_model_path,
    input_sentence,
    length=75,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    seed=42,
    no_cuda=False,
    num_return_sequences=1,
    stop_token=".",
):
    # Pick the device
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
    )
    n_gpu = 0 if no_cuda else torch.cuda.device_count()

    # Set seed
    set_seed(seed, n_gpu)

    # Load model
    model, tokenizer = load_model(model_name, custom_model_path, token)
    model.to(device)
    # length = adjust_length_to_model(length, max_sequence_length=model.config.max_position_embeddings)

    # Tokenize input
    encoded_prompt = tokenizer.encode(
        input_sentence, add_special_tokens=False, return_tensors="pt"
    )
    encoded_prompt = encoded_prompt.to(device)
    input_ids = encoded_prompt

    # Generate output
    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=length + len(encoded_prompt[0]),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=num_return_sequences,
    )

    generated_sequences = []
    for generated_sequence in output_sequences:
        generated_sequence = generated_sequence.tolist()
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
        # Remove the prompt from the decoded text
        text = text[
            len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :
        ]
        # Drop everything after the last occurrence of stop_token
        # (rfind returns -1 when absent, which empties the text)
        text = text[: text.rfind(stop_token) + 1]
        generated_sequences.append(text)
    return generated_sequences[0]

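# Example call (hypothetical prompt; public checkpoint, so no token needed):
#   generate("DistilGPT2 by HuggingFace 🤗", token=None, custom_model_path=None,
#            input_sentence="The weather today is", length=50)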
def show_mode(mode):
    if mode == "Single Model":
        return (gr.update(visible=True), gr.update(visible=False))
    if mode == "Multi-Model":
        return (gr.update(visible=False), gr.update(visible=True))

def prepare_dataset(dataset):
    dataset = load_dataset(dataset, split="train")
    return dataset


def load_prompts(dataset):
    prompts = [dataset[i]["prompt"]["text"] for i in range(len(dataset))]
    return prompts


def random_sample(prompt_list):
    subset = sample(prompt_list, 10)
    return subset

def show_dataset(dataset):
    raw_data = prepare_dataset(dataset)
    prompts = load_prompts(raw_data)
    return (
        gr.update(
            choices=random_sample(prompts),
            label="You can find below a random subset from the RealToxicityPrompts dataset",
            visible=True,
        ),
        gr.update(visible=True),
        prompts,
    )


def update_dropdown(prompts):
    return gr.update(choices=random_sample(prompts))

def show_search_bar(value):
    if value == "Custom Model":
        return (value, gr.update(visible=True))
    else:
        return (value, gr.update(visible=False))

def search_model(model_name, token):
    api = HfApi()
    model_args = ModelSearchArguments()
    filt = ModelFilter(
        task=model_args.pipeline_tag.TextGeneration, library=model_args.library.PyTorch
    )
    results = api.list_models(filter=filt, search=model_name, use_auth_token=token)
    model_list = [model.modelId for model in results]
    return gr.update(
        visible=True,
        choices=model_list,
        label="Choose the model",
    )

def show_api_key_textbox(checkbox):
    if checkbox:
        return gr.update(visible=True)
    else:
        return gr.update(visible=False)


def forward_model_choice(model_choice_path):
    return (model_choice_path, model_choice_path)

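# Wrap the generated continuation in an "OUTPUT" entity span so that
# gr.HighlightedText can render prompt and completion in distinct colours.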
def auto_complete(input, generated):
    output = input + " " + generated
    output_spans = [{"entity": "OUTPUT", "start": len(input), "end": len(output)}]
    completed_prompt = {"text": output, "entities": output_spans}
    return completed_prompt

def process_user_input(
    model, token, custom_model_path, input, length, temperature, top_p, top_k
):
    warning = "Please enter a valid prompt."
    if not input:  # covers both None and an empty textbox
        input = ""
        generated = warning
    else:
        generated = generate(
            model_name=model,
            token=token,
            custom_model_path=custom_model_path,
            input_sentence=input,
            length=length,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
    generated = generated.replace("\n", " ")
    generated_with_spans = auto_complete(input=input, generated=generated)
    return (
        gr.update(value=generated_with_spans),
        gr.update(visible=True),
        gr.update(visible=True),
        input,
        generated,
    )

def pass_to_textbox(input):
    return gr.update(value=input)

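# Note: Detoxify("original") re-instantiates the classifier on every call;
# simple, but it reloads the checkpoint each time rather than caching it.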
def run_detoxify(text):
    results = Detoxify("original").predict(text)
    json_ready_results = {cat: float(score) for (cat, score) in results.items()}
    return json_ready_results


def compute_toxi_output(output_text):
    scores = run_detoxify(output_text)
    return (gr.update(value=scores, visible=True), gr.update(visible=True))

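# Relative change between input and output toxicity, as a percentage:
# ((output - input) / input) * 100, e.g. 0.02 -> 0.03 reads as +50.0.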
def compute_change(input, output):
    change_percent = round(((float(output) - input) / input) * 100, 2)
    return change_percent


def compare_toxi_scores(input_text, output_scores):
    input_scores = run_detoxify(input_text)
    json_ready_results = {cat: float(score) for (cat, score) in input_scores.items()}
    compare_scores = {
        cat: compute_change(json_ready_results[cat], output_scores[cat])
        for cat in json_ready_results
        if cat in output_scores
    }
    return (
        gr.update(value=json_ready_results, visible=True),
        gr.update(value=compare_scores, visible=True),
    )

def show_flag_choices():
    return gr.update(visible=True)


def update_flag(flag_value):
    return (
        flag_value,
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=False),
    )

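# upload_flag receives its values positionally, in the same order as the
# component list passed to flagging_callback.setup() further down.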
def upload_flag(*args):
    flags = list(args)
    # Encode the generated text as UTF-8 bytes, presumably so the dataset
    # saver handles unusual characters in model output safely.
    flags[1] = bytes(flags[1], "utf-8")
    flagging_callback.flag(flags)
    return gr.update(visible=True)

def forward_model_choice_multi(model_choice_path):
    CHOICES.append(model_choice_path)
    return gr.update(choices=CHOICES)

def process_user_input_multi(models, input, token, length, temperature, top_p, top_k):
    warning = "Please enter a valid prompt."
    if not input:  # covers both None and an empty textbox
        input = ""
        generated_dict = {model: warning for model in sorted(models)}
    else:
        generated_dict = {
            model: generate(
                model_name=model,
                token=token,
                custom_model_path=None,
                input_sentence=input,
                length=length,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )
            for model in sorted(models)
        }
    generated_with_spans_dict = {
        model: auto_complete(input, generated)
        for model, generated in generated_dict.items()
    }
    update_outputs = [
        gr.HighlightedText.update(value=output, label=model)
        for model, output in generated_with_spans_dict.items()
    ]
    update_hide = [
        gr.HighlightedText.update(visible=False) for _ in range(10 - len(models))
    ]
    return update_outputs + update_hide

def show_choices_multi(models):
    update_show = [gr.HighlightedText.update(visible=True) for model in sorted(models)]
    update_hide = [
        gr.HighlightedText.update(visible=False, value=None, label=None)
        for _ in range(10 - len(models))
    ]
    return update_show + update_hide


def show_params(checkbox):
    if checkbox:
        return gr.update(visible=True)
    else:
        return gr.update(visible=False)

CSS = """
#inside_group {
    padding-top: 0.6em;
    padding-bottom: 0.6em;
}
#pw textarea {
    -webkit-text-security: disc;
}
"""

with gr.Blocks(css=CSS) as demo:

    dataset = gr.Variable(value=DATASET)
    prompts_var = gr.Variable(value=None)
    input_var = gr.Variable(label="Input Prompt", value=None)
    output_var = gr.Variable(label="Output", value=None)
    model_choice = gr.Variable(label="Model", value=None)
    custom_model_path = gr.Variable(value=None)
    flag_choice = gr.Variable(label="Flag", value=None)

    flagging_callback = gr.HuggingFaceDatasetSaver(
        hf_token=HF_AUTH_TOKEN,
        dataset_name="fsdlredteam/flagged_3",
        private=True,
    )
| gr.Markdown("<p align='center'><img src='https://i.imgur.com/ZxbbLUQ.png>'/></p>") | |
| gr.Markdown("<h1 align='center'>BuggingSpace</h1>") | |
| gr.Markdown( | |
| "<h2 align='center'>FSDL 2022 Red-Teaming Open-Source Models Project</h2>" | |
| ) | |
| gr.Markdown( | |
| "### Pick a text generation model below, write a prompt and explore the output" | |
| ) | |
| gr.Markdown("### Or compare the output of multiple models at the same time") | |
| choose_mode = gr.Radio( | |
| choices=["Single Model", "Multi-Model"], | |
| value="Single Model", | |
| interactive=True, | |
| visible=True, | |
| show_label=False, | |
| ) | |
| with gr.Group() as single_model: | |
| gr.Markdown( | |
| "You can upload any model from the Hugging Face hub -even private ones, \ | |
| provided you use your private key! " | |
| "Write your prompt or alternatively use one from the \ | |
| [RealToxicityPrompts](https://allenai.org/data/real-toxicity-prompts) dataset." | |
| ) | |
| gr.Markdown( | |
| "Use it to audit the model for potential failure modes, \ | |
| analyse its output with the Detoxify suite and contribute by reporting any problematic result." | |
| ) | |
| gr.Markdown( | |
| "Beware ! Generation can take up to a few minutes with very large models." | |
| ) | |
        with gr.Row():
            with gr.Column(scale=1):  # input & prompts dataset exploration
                gr.Markdown("### 1. Select a prompt", elem_id="inside_group")
                input_text = gr.Textbox(
                    label="Write your prompt below.",
                    interactive=True,
                    lines=4,
                    elem_id="inside_group",
                )
| gr.Markdown("β or β", elem_id="inside_group") | |
                inspo_button = gr.Button(
                    "Click here if you need some inspiration", elem_id="inside_group"
                )
                prompts_drop = gr.Dropdown(visible=False, elem_id="inside_group")
                randomize_button = gr.Button(
                    "Show another subset", visible=False, elem_id="inside_group"
                )
                show_params_checkbox_single = gr.Checkbox(
                    label="Set custom params", interactive=True, value=False
                )
                with gr.Box(visible=False) as params_box_single:
                    length_single = gr.Slider(
                        label="Output length",
                        visible=True,
                        interactive=True,
                        minimum=50,
                        maximum=200,
                        value=75,
                    )
                    top_k_single = gr.Slider(
                        label="top_k",
                        visible=True,
                        interactive=True,
                        minimum=1,
                        maximum=100,
                        value=50,
                    )
                    top_p_single = gr.Slider(
                        label="top_p",
                        visible=True,
                        interactive=True,
                        minimum=0.1,
                        maximum=1,
                        value=0.95,
                    )
                    temperature_single = gr.Slider(
                        label="temperature",
                        visible=True,
                        interactive=True,
                        minimum=0.1,
                        maximum=1,
                        value=0.7,
                    )
            with gr.Column(scale=1):  # Model choice & output
                gr.Markdown("### 2. Evaluate output")
                model_radio = gr.Radio(
                    choices=list(CHECKPOINTS.keys()),
                    label="Model",
                    interactive=True,
                    elem_id="inside_group",
                )
                search_bar = gr.Textbox(
                    label="Search model",
                    interactive=True,
                    visible=False,
                    elem_id="inside_group",
                )
                model_drop = gr.Dropdown(visible=False)
                private_checkbox = gr.Checkbox(
                    visible=True, label="Private model?", elem_id="inside_group"
                )
                api_key_textbox = gr.Textbox(
                    label="Enter your AUTH TOKEN below",
                    value=None,
                    interactive=True,
                    visible=False,
                    elem_id="pw",
                )
                generate_button = gr.Button(
                    "Submit your prompt", elem_id="inside_group"
                )
                output_spans = gr.HighlightedText(visible=True, label="Generated text")
                flag_button = gr.Button(
                    "Report output here", visible=False, elem_id="inside_group"
                )
        with gr.Row():  # Flagging
            with gr.Column(scale=1):
                flag_radio = gr.Radio(
                    choices=[
                        "Toxic",
                        "Offensive",
                        "Repetitive",
                        "Incorrect",
                        "Other",
                    ],
                    label="What's wrong with the output?",
                    interactive=True,
                    visible=False,
                    elem_id="inside_group",
                )
                user_comment = gr.Textbox(
                    label="(Optional) Briefly describe the issue",
                    visible=False,
                    interactive=True,
                    elem_id="inside_group",
                )
                confirm_flag_button = gr.Button(
                    "Confirm report", visible=False, elem_id="inside_group"
                )

        with gr.Row():  # Flagging success
            success_message = gr.Markdown(
                "Your report has been successfully registered. Thank you!",
                visible=False,
                elem_id="inside_group",
            )

        with gr.Row():  # Toxicity buttons
            toxi_button = gr.Button(
                "Run a toxicity analysis of the model's output",
                visible=False,
                elem_id="inside_group",
            )
            toxi_button_compare = gr.Button(
                "Compare toxicity on input and output",
                visible=False,
                elem_id="inside_group",
            )

        with gr.Row():  # Toxicity scores
            toxi_scores_input = gr.JSON(
                label="Detoxify classification of your input",
                visible=False,
                elem_id="inside_group",
            )
            toxi_scores_output = gr.JSON(
                label="Detoxify classification of the model's output",
                visible=False,
                elem_id="inside_group",
            )
            toxi_scores_compare = gr.JSON(
                label="Percentage change between input and output",
                visible=False,
                elem_id="inside_group",
            )
    with gr.Group(visible=False) as multi_model:
        model_list = []
        gr.Markdown(
            "#### Run the same input on multiple models and compare the outputs"
        )
        gr.Markdown(
            "You can use any model from the Hugging Face hub, even private ones, provided you use your auth token!"
        )
        gr.Markdown(
            "Use this feature to compare the same model at different checkpoints, "
            "or to benchmark your model against another one as a reference."
        )
        gr.Markdown(
            "Beware! Generation can take up to a few minutes with very large models."
        )

        with gr.Row(elem_id="inside_group"):
            with gr.Column():
                models_multi = gr.CheckboxGroup(
                    choices=CHOICES,
                    label="Models",
                    interactive=True,
                    elem_id="inside_group",
                    value=None,
                )
            with gr.Column():
                generate_button_multi = gr.Button(
                    "Submit your prompt", elem_id="inside_group"
                )
                show_params_checkbox_multi = gr.Checkbox(
                    label="Set custom params", interactive=True, value=False
                )
                with gr.Box(visible=False) as params_box_multi:
                    length_multi = gr.Slider(
                        label="Output length",
                        visible=True,
                        interactive=True,
                        minimum=50,
                        maximum=200,
                        value=75,
                    )
                    top_k_multi = gr.Slider(
                        label="top_k",
                        visible=True,
                        interactive=True,
                        minimum=1,
                        maximum=100,
                        value=50,
                    )
                    top_p_multi = gr.Slider(
                        label="top_p",
                        visible=True,
                        interactive=True,
                        minimum=0.1,
                        maximum=1,
                        value=0.95,
                    )
                    temperature_multi = gr.Slider(
                        label="temperature",
                        visible=True,
                        interactive=True,
                        minimum=0.1,
                        maximum=1,
                        value=0.7,
                    )
        with gr.Row(elem_id="inside_group"):
            with gr.Column(elem_id="inside_group", scale=1):
                input_text_multi = gr.Textbox(
                    label="Write your prompt below.",
                    interactive=True,
                    lines=4,
                    elem_id="inside_group",
                )
            with gr.Column(elem_id="inside_group", scale=1):
                search_bar_multi = gr.Textbox(
                    label="Search another model",
                    interactive=True,
                    visible=True,
                    elem_id="inside_group",
                )
                model_drop_multi = gr.Dropdown(visible=False, elem_id="inside_group")
                private_checkbox_multi = gr.Checkbox(
                    visible=True, label="Private model?"
                )
                api_key_textbox_multi = gr.Textbox(
                    label="Enter your AUTH TOKEN below",
                    value=None,
                    interactive=True,
                    visible=False,
                    elem_id="pw",
                )
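        # Ten HighlightedText slots are pre-allocated and hidden; the handlers
        # show one per selected model and hide the rest.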
        with gr.Row() as outputs_row:
            for _ in range(10):
                output_spans_multi = gr.HighlightedText(
                    visible=False, elem_id="inside_group"
                )
                model_list.append(output_spans_multi)

    with gr.Row():
        gr.Markdown(
            "App made during the [FSDL course](https://fullstackdeeplearning.com) "
            "by Team53: Jean-Antoine, Sajenthan, Sashank, Kemp, Srihari, Astitwa"
        )
    # Single Model
    choose_mode.change(
        fn=show_mode, inputs=choose_mode, outputs=[single_model, multi_model]
    )
    inspo_button.click(
        fn=show_dataset,
        inputs=dataset,
        outputs=[prompts_drop, randomize_button, prompts_var],
    )
    prompts_drop.change(fn=pass_to_textbox, inputs=prompts_drop, outputs=input_text)
    randomize_button.click(
        fn=update_dropdown, inputs=prompts_var, outputs=prompts_drop
    )
    model_radio.change(
        fn=show_search_bar, inputs=model_radio, outputs=[model_choice, search_bar]
    )
    search_bar.submit(
        fn=search_model,
        inputs=[search_bar, api_key_textbox],
        outputs=model_drop,
        show_progress=True,
    )
    private_checkbox.change(
        fn=show_api_key_textbox, inputs=private_checkbox, outputs=api_key_textbox
    )
    model_drop.change(
        fn=forward_model_choice,
        inputs=model_drop,
        outputs=[model_choice, custom_model_path],
    )
    generate_button.click(
        fn=process_user_input,
        inputs=[
            model_choice,
            api_key_textbox,
            custom_model_path,
            input_text,
            length_single,
            temperature_single,
            top_p_single,
            top_k_single,
        ],
        outputs=[output_spans, toxi_button, flag_button, input_var, output_var],
        show_progress=True,
    )
    toxi_button.click(
        fn=compute_toxi_output,
        inputs=output_var,
        outputs=[toxi_scores_output, toxi_button_compare],
        show_progress=True,
    )
    toxi_button_compare.click(
        fn=compare_toxi_scores,
        inputs=[input_text, toxi_scores_output],
        outputs=[toxi_scores_input, toxi_scores_compare],
        show_progress=True,
    )
    flag_button.click(fn=show_flag_choices, inputs=None, outputs=flag_radio)
    flag_radio.change(
        fn=update_flag,
        inputs=flag_radio,
        outputs=[flag_choice, confirm_flag_button, user_comment, flag_button],
    )
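    # The component order registered here must match the order of
    # confirm_flag_button's inputs, since upload_flag consumes them positionally.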
    flagging_callback.setup(
        [input_var, output_var, model_choice, user_comment, flag_choice],
        "flagged_data_points",
    )
    confirm_flag_button.click(
        fn=upload_flag,
        inputs=[input_var, output_var, model_choice, user_comment, flag_choice],
        outputs=success_message,
    )
    show_params_checkbox_single.change(
        fn=show_params, inputs=show_params_checkbox_single, outputs=params_box_single
    )
    # Model comparison
    search_bar_multi.submit(
        fn=search_model,
        inputs=[search_bar_multi, api_key_textbox_multi],
        outputs=model_drop_multi,
        show_progress=True,
    )
    show_params_checkbox_multi.change(
        fn=show_params, inputs=show_params_checkbox_multi, outputs=params_box_multi
    )
    private_checkbox_multi.change(
        fn=show_api_key_textbox,
        inputs=private_checkbox_multi,
        outputs=api_key_textbox_multi,
    )
    model_drop_multi.change(
        fn=forward_model_choice_multi, inputs=model_drop_multi, outputs=[models_multi]
    )
    models_multi.change(fn=show_choices_multi, inputs=models_multi, outputs=model_list)
    generate_button_multi.click(
        fn=process_user_input_multi,
        inputs=[
            models_multi,
            input_text_multi,
            api_key_textbox_multi,
            length_multi,
            temperature_multi,
            top_p_multi,
            top_k_multi,
        ],
        outputs=model_list,
        show_progress=True,
    )
if __name__ == "__main__":
    # demo.queue(concurrency_count=3)
    demo.launch(debug=True)