import deepsparse
import gradio as gr
from typing import Tuple, List

deepsparse.cpu.print_hardware_capability()
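# The call above logs the CPU features DeepSparse detected (e.g. AVX2,
# AVX-512, VNNI); sparse/quantized kernels run fastest with these extensions.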

# SparseZoo stub for a Llama 2 7B model pruned to 60% sparsity, quantized,
# and finetuned on GSM8K
MODEL_ID = "zoo:llama2-7b-gsm8k_llama2_pretrain-pruned60_quantized"
| DESCRIPTION = f""" | |
| # Llama 2 Sparse Finetuned on GSM8k with DeepSparse | |
|  | |
| Model ID: {MODEL_ID} | |
| 🚀 **Experience the power of LLM mathematical reasoning** through [our Llama 2 sparse finetuned](https://arxiv.org/abs/2310.06927) on the [GSM8K dataset](https://huggingface.co/datasets/gsm8k). | |
| GSM8K, short for Grade School Math 8K, is a collection of 8.5K high-quality linguistically diverse grade school math word problems, designed to challenge question-answering systems with multi-step reasoning. | |
| Observe the model's performance in deciphering complex math questions and offering detailed step-by-step solutions. | |
| ## Accelerated Inferenced on CPUs | |
| The Llama 2 model runs purely on CPU courtesy of [sparse software execution by DeepSparse](https://github.com/neuralmagic/deepsparse/tree/main/research/mpt). | |
| DeepSparse provides accelerated inference by taking advantage of the model's weight sparsity to deliver tokens fast! | |
| %3C%2Fspan%3E%3C!-- HTML_TAG_END --> | |
| """ | |

MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 200

# Set up the DeepSparse text-generation engine
pipe = deepsparse.TextGeneration(
    model=MODEL_ID, sequence_length=MAX_MAX_NEW_TOKENS, num_cores=8
)
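# Note: num_cores=8 assumes an 8-core host; adjust to the target machine.
# Left unset, DeepSparse typically uses all available physical cores.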

def clear_and_save_textbox(message: str) -> Tuple[str, str]:
    """Clear the textbox and stash the submitted message."""
    return "", message


def display_input(
    message: str, history: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
    """Append the user message to the chat history with an empty reply."""
    history.append((message, ""))
    return history


def delete_prev_fn(
    history: List[Tuple[str, str]]
) -> Tuple[List[Tuple[str, str]], str]:
    """Remove the last exchange and return its message for re-submission."""
    try:
        message, _ = history.pop()
    except IndexError:
        message = ""
    return history, message or ""

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(DESCRIPTION)
        with gr.Column():
            gr.Markdown("""### Sparse Finetuned Llama Demo""")
            with gr.Group():
                chatbot = gr.Chatbot(label="Chatbot")
                with gr.Row():
                    textbox = gr.Textbox(
                        container=False,
                        placeholder="Type a message...",
                        scale=10,
                    )
                    submit_button = gr.Button(
                        "Submit", variant="primary", scale=1, min_width=0
                    )
            with gr.Row():
                retry_button = gr.Button("🔄 Retry", variant="secondary")
                undo_button = gr.Button("↩️ Undo", variant="secondary")
                clear_button = gr.Button("🗑️ Clear", variant="secondary")
            saved_input = gr.State()
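            # saved_input holds the last submitted message so the Retry and
            # Undo handlers below can restore it.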
            gr.Examples(
                examples=[
                    "James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?",
                    "Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?",
                    "Gretchen has 110 coins. There are 30 more gold coins than silver coins. How many gold coins does Gretchen have?",
                ],
                inputs=[textbox],
            )
            max_new_tokens = gr.Slider(
                label="Max new tokens",
                value=DEFAULT_MAX_NEW_TOKENS,
                minimum=0,
                maximum=MAX_MAX_NEW_TOKENS,
                step=1,
                interactive=True,
                info="The maximum number of new tokens to generate",
            )
            temperature = gr.Slider(
                label="Temperature",
                value=0.3,
                minimum=0.05,
                maximum=1.0,
                step=0.05,
                interactive=True,
                info="Higher values produce more diverse outputs",
            )

    # Generation inference: stream tokens from the DeepSparse pipeline
    # into the last chat turn.
    def generate(
        message,
        history,
        max_new_tokens: int,
        temperature: float,
    ):
        generation_config = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        }
        inference = pipe(sequences=message, streaming=True, **generation_config)
        # Echo the prompt first so the streamed completion reads as a
        # continuation of the question.
        history[-1][1] += message
        for token in inference:
            history[-1][1] += token.generations[0].text
            yield history
        # Log per-stage timing statistics collected by the pipeline
        print(pipe.timer_manager)
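
    # Minimal non-UI sketch of the same pipeline call, assuming the default
    # deepsparse.TextGeneration output schema used above:
    #   output = pipe(sequences="What is 7 times 8?", max_new_tokens=64)
    #   print(output.generations[0].text)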

    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).success(
        generate,
        inputs=[
            saved_input,
            chatbot,
            max_new_tokens,
            temperature,
        ],
        outputs=[chatbot],
        api_name=False,
    )

    submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).success(
        generate,
        inputs=[
            saved_input,
            chatbot,
            max_new_tokens,
            temperature,
        ],
        outputs=[chatbot],
        api_name=False,
    )

    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        generate,
        inputs=[
            saved_input,
            chatbot,
            max_new_tokens,
            temperature,
        ],
        outputs=[chatbot],
        api_name=False,
    )

    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    clear_button.click(
        fn=lambda: ([], ""),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )
demo.queue().launch()