|
|
import asyncio
import random
import threading

import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
|
|
|
|
# ---- Model setup -------------------------------------------------------
print("===== Application Startup =====")

print("Loading model...")

# Small GPT-2 checkpoint keeps startup and CPU inference fast.
checkpoint = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# Shared text-generation pipeline used by chat_with_model() below.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

print("Model loaded successfully.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("Fetching dataset...")

# Only the first 200 training rows are pulled to keep startup light.
dataset = load_dataset("lvwerra/stack-exchange-paired", split="train[:200]")

print(f"Total prompts available: {len(dataset)}")

# BUG FIX: slicing a datasets.Dataset (dataset[:20]) returns a dict of
# *columns*, so iterating it yielded column names (strings) and
# item["question"] raised TypeError. Use .select() so iteration yields
# row dicts as intended.
n_initial = min(20, len(dataset))
initial_prompts = dataset.select(range(n_initial))
remaining_prompts = dataset.select(range(n_initial, len(dataset)))

# Shared prompt pool, seeded with the fast initial batch; the background
# task extends it with the remaining rows after startup.
prompts = [item["question"] for item in initial_prompts]

print(f"Loaded {len(prompts)} initial prompts for fast startup.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def load_remaining_prompts():
    """Extend the shared ``prompts`` pool with the rest of the dataset.

    Waits briefly so the UI comes up first, then appends every remaining
    question from the module-level ``remaining_prompts`` to ``prompts``.
    """
    print("Background: Loading remaining prompts...")
    # Let the server finish starting before doing the bulk append.
    await asyncio.sleep(2)
    prompts.extend(item["question"] for item in remaining_prompts)
    print(f"Background: Finished loading. Total prompts now = {len(prompts)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def chat_with_model(user_input):
    """Respond to user with a random dataset prompt + model output.

    Picks a random question from the shared ``prompts`` pool, prepends it
    to the user's message, and returns the sampled model continuation
    (one string; includes the prompt text, as the pipeline returns the
    full generated sequence).
    """
    # Background loading may still be in flight right after startup.
    if not prompts:
        return "Prompts not ready yet. Please wait..."
    prompt = random.choice(prompts)
    full_prompt = f"{prompt}\n\nUser: {user_input}\nAI:"
    # BUG FIX: max_length caps the TOTAL length (prompt + completion).
    # Stack-exchange questions routinely exceed 100 tokens, making
    # generation fail or produce nothing new. max_new_tokens guarantees
    # 100 fresh tokens; truncation keeps the input inside GPT-2's
    # 1024-token context window.
    response = generator(
        full_prompt,
        max_new_tokens=100,
        num_return_sequences=1,
        do_sample=True,
        truncation=True,
        pad_token_id=tokenizer.eos_token_id,  # silence GPT-2 pad warning
    )[0]["generated_text"]
    return response
|
|
|
|
|
# Gradio UI: one textbox in, plain text out, backed by chat_with_model.
demo = gr.Interface(
    fn=chat_with_model,
    title="Fast Prompt Loader Chatbot",
    description="Loads 20 prompts fast, then background loads 200+ prompts",
    inputs=gr.Textbox(lines=2, placeholder="Ask me something..."),
    outputs="text",
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # BUG FIX: the original did asyncio.get_event_loop().create_task(...),
    # but that loop was never run -- demo.launch() blocks the main thread
    # without driving it, so the background load never executed (and
    # calling get_event_loop() with no running loop is deprecated since
    # Python 3.10). Run the coroutine to completion on a daemon thread.
    threading.Thread(
        target=lambda: asyncio.run(load_remaining_prompts()),
        daemon=True,  # don't keep the process alive on shutdown
    ).start()
    # Bind to all interfaces on the conventional Gradio/Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|