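"""Gradio demo: Prompt Enhancer.

Expands a short prompt into a more detailed one using the
gokaygokay/prompt-enhancer-gemma-3-270m-it model (a Gemma 3 270M instruction-tuned fine-tune).
"""
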
import re

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# ============================================================
# 1️⃣ Load model and tokenizer
# ============================================================
MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
# Use GPU if available
device = 0 if torch.cuda.is_available() else -1
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
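# Note: the 270M checkpoint is small enough to load in full precision on CPU; for larger
# models you could pass torch_dtype or device_map to from_pretrained (not needed here).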
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
device=device, # 0 for GPU, -1 for CPU
)
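
# Illustrative smoke test (not part of the original app); uncomment to verify the pipeline
# before launching the UI:
# print(pipe("a cat sitting on a chair", max_new_tokens=32)[0]["generated_text"])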
# ============================================================
# 2️⃣ Define the generation function (chat-template style)
# ============================================================
def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    chat_history = chat_history or []
    # Build messages using proper roles
    messages = [
        {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
        {"role": "user", "content": user_prompt},
    ]
    # Use the tokenizer's chat template to build the model input
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Generate output (the pipeline returns the prompt plus the continuation)
    output = pipe(
        prompt,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        do_sample=True,
    )[0]["generated_text"].strip()
    print(output)
    # Append the exchange to the chat history
    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": output})
    return chat_history
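
# enhance_prompt returns the raw generation (including the echoed prompt and any leftover
# turn tags such as <end_of_turn>); enhance_prompt1 below runs the same generation but
# strips that residue with extract_later_part before showing it in the chat.
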
def extract_later_part(user_prompt, generated_text):
    """
    Cleans the model output and extracts only the enhanced (later) portion.
    Removes prompt echoes and tags like <end_of_turn>, <start_of_turn>, etc.
    """
    # Step 1: Strip turn tags such as <end_of_turn> and <start_of_turn>
    cleaned = re.sub(r"<.*?>", "", generated_text).strip()
    # Step 2: Normalize whitespace
    cleaned = re.sub(r"\s+", " ", cleaned)
    # Step 3: Drop the original prompt if the model echoed it back
    user_prompt_clean = user_prompt.strip().lower()
    if cleaned.lower().startswith(user_prompt_clean):
        cleaned = cleaned[len(user_prompt_clean):].strip(",. ").strip()
    return cleaned
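
# Illustrative example (hypothetical strings, for documentation only):
#   extract_later_part("a cat", "a cat, a fluffy tabby curled up on a sunlit chair <end_of_turn>")
#   -> "a fluffy tabby curled up on a sunlit chair"
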
# ===================== Prompt Enhancer Function =====================
def enhance_prompt1(user_prompt, temperature, max_tokens, chat_history):
    messages = [
        {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
        {"role": "user", "content": user_prompt},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Honor the UI sliders instead of hard-coding the generation settings
    output = pipe(
        prompt,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        do_sample=True,
    )
    raw_output = output[0]["generated_text"]
    print("=== RAW MODEL OUTPUT ===")
    print(raw_output)
    # Extract the cleaned, later portion
    later_part = extract_later_part(user_prompt, raw_output)
    print("=== EXTRACTED CLEANED OUTPUT ===")
    print(later_part)
    # Append to chat history for Gradio
    chat_history = chat_history or []
    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": later_part})
    return chat_history
# ============================================================
# 3️⃣ Gradio UI
# ============================================================
with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (Gemma 3 270M)
        Enter a short prompt, and the model will **expand it with details and creative context**
        using the Gemma chat-template interface.
        """
    )
    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a short prompt...",
                label="Your Prompt",
                lines=3,
            )
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            clear_btn = gr.Button("🧹 Clear Chat")
            add_btn = gr.Button("🚀 Enhance Prompt", variant="primary")

    # Bind UI actions: Enter submits the raw-output version, the button runs the cleaned version
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)
    add_btn.click(enhance_prompt1, [user_prompt, temperature, max_tokens, chatbot], chatbot)

    gr.Markdown(
        """
        ---
        💡 **Tips:**
        - Works best with short, descriptive prompts (e.g., "a cat sitting on a chair")
        - Increase *Temperature* for more creative output.
        """
    )
# ============================================================
# 4️⃣ Launch
# ============================================================
if __name__ == "__main__":
    demo.launch(show_error=True)