File size: 4,461 Bytes
ce929d0
be09bfa
ce929d0
be09bfa
ce929d0
 
be09bfa
ce929d0
248fe25
be09bfa
ce929d0
 
 
be09bfa
ce929d0
30bd2c9
ff12e01
30bd2c9
ff12e01
 
 
 
 
f901d63
 
 
ff12e01
be09bfa
d071e42
9274c9f
be09bfa
ce929d0
d071e42
be09bfa
d071e42
be09bfa
f901d63
 
 
 
 
ce929d0
 
be09bfa
ce929d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be09bfa
ce929d0
 
 
 
 
 
 
 
ff12e01
ce929d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be09bfa
ce929d0
 
 
 
 
 
 
 
 
 
 
 
 
 
ff12e01
ce929d0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
from flashpack.integrations.transformers import FlashPackTransformersModelMixin
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma3ForCausalLM
from transformers import pipeline as hf_pipeline

# ============================================================
# 1️⃣ Define FlashPack-enabled model class
# ============================================================
class FlashPackGemmaModel(Gemma3ForCausalLM, FlashPackTransformersModelMixin):
    """Gemma 3 causal LM with the FlashPack save/load methods mixed in.

    The base must be a concrete ``PreTrainedModel`` subclass. Auto classes
    such as ``AutoModelForCausalLM`` refuse direct instantiation, and their
    ``from_pretrained`` returns an instance of the resolved concrete class
    rather than of this subclass, so the FlashPack mixin methods
    (``save_pretrained_flashpack`` etc.) would be missing on the loaded model.

    NOTE(review): assumes the checkpoint's architecture is
    ``Gemma3ForCausalLM`` (text-only Gemma 3) — confirm against the model's
    config.json.
    """

# ============================================================
# 2️⃣ Load tokenizer
# ============================================================
# Repository that holds (or will receive) the FlashPack-serialized weights.
FLASHPACK_REPO = "rahul7star/FlashPack"
# Original Hugging Face checkpoint; the tokenizer is always loaded from here.
MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_ID)

# ============================================================
# 3️⃣ Load or create FlashPack model
# ============================================================
print("📂 Loading model from FlashPack repository...")
try:
    model = FlashPackGemmaModel.from_pretrained_flashpack(FLASHPACK_REPO)
except OSError as load_err:
    # FileNotFoundError (the only case handled before) is an OSError subclass.
    # Hub download failures (missing repo/file, HTTP errors) are typically
    # OSError subclasses as well, so they now also take the fallback path.
    # NOTE(review): confirm the exact exceptions from_pretrained_flashpack raises.
    print("⚠️ FlashPack model not found. Loading from HF Hub and uploading FlashPack...")
    print(f"   reason: {load_err!r}")
    model = FlashPackGemmaModel.from_pretrained(MODEL_ID)
    try:
        model.save_pretrained_flashpack(FLASHPACK_REPO, push_to_hub=True)
    except Exception as upload_err:
        # Best effort: the model is already loaded, so a failed upload
        # (e.g. missing write token) must not prevent the app from starting.
        print(f"⚠️ FlashPack upload failed, continuing with HF Hub weights: {upload_err!r}")
    else:
        print(f"✅ FlashPack model uploaded to Hugging Face Hub: {FLASHPACK_REPO}")

# ============================================================
# 4️⃣ Build text-generation pipeline
# ============================================================
# Wrap the already-loaded model and tokenizer in a text-generation pipeline.
# NOTE(review): `device_map` is normally consumed only when the pipeline itself
# loads the model from a name/path; with a model instance it is likely a
# no-op — confirm, and move the model to the target device explicitly if needed.
pipe = hf_pipeline(
    task="text-generation",
    tokenizer=tokenizer,
    model=model,
    device_map="auto",
)

# ============================================================
# 5️⃣ Define prompt enhancement function
# ============================================================
def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    """Expand a short prompt with the model and append the exchange to the chat.

    Args:
        user_prompt: Text typed by the user.
        temperature: Sampling temperature from the UI slider. Values <= 0 use
            greedy decoding instead, because transformers rejects sampling
            with a non-positive temperature.
        max_tokens: Upper bound on newly generated tokens. It is cast to
            ``int`` because the slider may deliver a float.
        chat_history: Current Chatbot value in "messages" format (a list of
            ``{"role": ..., "content": ...}`` dicts) or ``None``.

    Returns:
        A new list with the user turn and the model's reply appended. A blank
        prompt returns the existing history unchanged without calling the model.
    """
    chat_history = list(chat_history or [])
    if not (user_prompt and user_prompt.strip()):
        return chat_history

    messages = [
        {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
        {"role": "user", "content": user_prompt},
    ]

    temperature = float(temperature)
    generation_kwargs = {"max_new_tokens": int(max_tokens)}
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # Sampling requires a strictly positive temperature; 0 means greedy.
        generation_kwargs["do_sample"] = False

    # Passing the message list lets the pipeline apply the chat template and
    # tokenize in one step. Re-tokenizing a pre-rendered template string with
    # default settings can add special tokens (e.g. BOS) a second time.
    outputs = pipe(
        messages,
        # The default output is prompt + completion; keep only the new text.
        return_full_text=False,
        **generation_kwargs,
    )
    enhanced = outputs[0]["generated_text"].strip()

    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": enhanced})
    return chat_history

# ============================================================
# 6️⃣ Gradio UI
# ============================================================
with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
    # Page header and short description.
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (Gemma 3 270M)
        Enter a short prompt, and the model will **expand it with details and creative context**  
        using the Gemma chat-template interface.
        """
    )

    with gr.Row():
        # Conversation log in "messages" format (role/content dicts), the same
        # shape enhance_prompt reads and returns.
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")

        # Prompt entry and generation controls.
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                lines=3,
                label="Your Prompt",
                placeholder="Enter a short prompt...",
            )
            temperature = gr.Slider(
                0.0, 1.0, value=0.7, step=0.05, label="Temperature"
            )
            max_tokens = gr.Slider(
                32, 256, value=128, step=16, label="Max Tokens"
            )
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    # The button and pressing Enter in the textbox both run the same handler
    # on the same inputs and write the updated history back to the chatbot.
    enhance_inputs = [user_prompt, temperature, max_tokens, chatbot]
    send_btn.click(fn=enhance_prompt, inputs=enhance_inputs, outputs=chatbot)
    user_prompt.submit(fn=enhance_prompt, inputs=enhance_inputs, outputs=chatbot)
    # Reset the conversation to an empty history.
    clear_btn.click(fn=lambda: [], inputs=None, outputs=chatbot)

    # Usage tips shown below the controls.
    gr.Markdown(
        """
        ---
        💡 **Tips:**
        - Works best with short, descriptive prompts (e.g., "a cat sitting on a chair")
        - Increase *Temperature* for more creative output.
        """
    )

# ============================================================
# 7️⃣ Launch
# ============================================================
def _main():
    """Start the Gradio server for the demo defined above, with show_error enabled."""
    demo.launch(show_error=True)


if __name__ == "__main__":
    _main()