import torch
import gradio as gr
from transformers import AutoTokenizer, pipeline as hf_pipeline
# Gemma3ForCausalLM requires a transformers release with Gemma 3 support.
from transformers import Gemma3ForCausalLM
from flashpack.integrations.transformers import FlashPackTransformersModelMixin
# ============================================================
# 1️⃣ Define FlashPack-enabled model class
# ============================================================
# Subclass the concrete Gemma 3 class rather than AutoModelForCausalLM: the
# Auto* factories resolve to the concrete class at load time, so a mixin
# applied to the factory would never end up on the returned model instance.
class FlashPackGemmaModel(Gemma3ForCausalLM, FlashPackTransformersModelMixin):
    """Gemma 3 model wrapped with FlashPackTransformersModelMixin."""
    pass
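# FlashPack packs the weights into a single flat file designed for fast,
# memory-mapped loading; the mixin contributes the from_pretrained_flashpack /
# save_pretrained_flashpack methods used below.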
# ============================================================
# 2️⃣ Load tokenizer
# ============================================================
MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
FLASHPACK_REPO = "rahul7star/FlashPack"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
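# The tokenizer always comes from the original HF repo; its chat template is
# what apply_chat_template() uses below, regardless of how weights are loaded.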
# ============================================================
# 3️⃣ Load or create FlashPack model
# ============================================================
try:
    print("📂 Loading model from FlashPack repository...")
    model = FlashPackGemmaModel.from_pretrained_flashpack(FLASHPACK_REPO)
except Exception as exc:
    # A missing FlashPack artifact can surface as FileNotFoundError locally or
    # as a hub "entry not found" error, so fall back on any load failure.
    print(f"⚠️ FlashPack model not found ({exc}). Loading from HF Hub and uploading FlashPack...")
    model = FlashPackGemmaModel.from_pretrained(MODEL_ID)
    model.save_pretrained_flashpack(FLASHPACK_REPO, push_to_hub=True)
    print(f"✅ FlashPack model uploaded to Hugging Face Hub: {FLASHPACK_REPO}")
# ============================================================
# 4️⃣ Build text-generation pipeline
# ============================================================
# The model is already instantiated, so pass an explicit device instead of
# device_map (device_map only takes effect when the pipeline loads the model
# itself from a checkpoint name).
pipe = hf_pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)
# ============================================================
# 5️⃣ Define prompt enhancement function
# ============================================================
def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    chat_history = chat_history or []
    messages = [
        {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
        {"role": "user", "content": user_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # temperature=0.0 is invalid when do_sample=True, so clamp it to a small
    # positive floor; return_full_text=False strips the echoed prompt so that
    # only the model's continuation is kept.
    outputs = pipe(
        prompt,
        max_new_tokens=int(max_tokens),
        temperature=max(float(temperature), 1e-3),
        do_sample=True,
        return_full_text=False,
    )
    enhanced = outputs[0]["generated_text"].strip()
    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": enhanced})
    return chat_history
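# Quick smoke test (commented out), runnable without the UI:
#   history = enhance_prompt("a cat sitting on a chair", 0.7, 128, [])
#   print(history[-1]["content"])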
# ============================================================
# 6️⃣ Gradio UI
# ============================================================
with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (Gemma 3 270M)
        Enter a short prompt, and the model will **expand it with details and creative context**
        using the Gemma chat-template interface.
        """
    )
    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a short prompt...",
                label="Your Prompt",
                lines=3,
            )
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    # Bind UI actions
    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)

    gr.Markdown(
        """
        ---
        💡 **Tips:**
        - Works best with short, descriptive prompts (e.g., "a cat sitting on a chair")
        - Increase *Temperature* for more creative output.
        """
    )
# ============================================================
# 7️⃣ Launch
# ============================================================
if __name__ == "__main__":
    demo.launch(show_error=True)