File size: 5,448 Bytes
35c9e22
2c51b55
ac3979b
775d1cf
2c51b55
401afad
2c51b55
401afad
35c9e22
ac3979b
401afad
 
 
 
 
 
 
 
 
ac3979b
401afad
77384a1
161ef48
2c51b55
ac3979b
2c51b55
401afad
fd97af7
35c9e22
ac3979b
 
 
 
 
401afad
ac3979b
 
 
 
401afad
ac3979b
401afad
 
 
ac3979b
840691e
 
fd97af7
ac3979b
fd97af7
 
8c48eed
401afad
35c9e22
4d4640c
 
161ef48
 
4d4640c
 
161ef48
4d4640c
 
 
 
 
 
 
 
161ef48
4d4640c
 
 
 
 
 
161ef48
ac3979b
4ed1f6e
840691e
a02142f
 
 
 
4d4640c
840691e
 
4d4640c
 
 
 
161ef48
4d4640c
 
 
 
161ef48
4d4640c
 
 
 
4ed1f6e
 
840691e
2c51b55
401afad
2c51b55
401afad
 
 
 
ac3979b
 
401afad
8c48eed
 
401afad
 
 
 
 
 
 
 
 
 
7e4024e
401afad
7e4024e
35c9e22
ac3979b
421c62b
401afad
 
470c095
35c9e22
401afad
 
 
ac3979b
 
 
401afad
 
35c9e22
2c51b55
401afad
2c51b55
401afad
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# ============================================================
# 1️⃣ Load model and tokenizer
# ============================================================
MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"

# Device index handed to the transformers pipeline:
# 0 -> first GPU when CUDA is available, -1 -> CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

# One shared text-generation pipeline built from the already-loaded
# model/tokenizer pair; every UI handler below calls it.
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)


# ============================================================
# 2️⃣ Define the generation function (chat-template style)
# ============================================================
def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
    """
    Expand ``user_prompt`` with the model and append the exchange to the chat.

    Args:
        user_prompt: Short prompt typed by the user.
        temperature: Sampling temperature from the UI slider. Values <= 0
            switch to greedy decoding, because sampling needs a strictly
            positive temperature.
        max_tokens: Maximum number of new tokens to generate.
        chat_history: List of ``{"role": ..., "content": ...}`` dicts (the
            Chatbot's ``type="messages"`` format), or ``None``.

    Returns:
        The chat history with the user turn and the model reply appended.
        A blank prompt returns the history unchanged without generating.
    """
    chat_history = chat_history or []

    # Nothing to enhance; skip a pointless generation.
    if not user_prompt or not user_prompt.strip():
        return chat_history

    # Build messages using proper roles
    messages = [
        {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
        {"role": "user", "content": user_prompt},
    ]

    # Render the chat template to a string and open the model's turn.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    temperature = float(temperature)
    generation_kwargs = {"max_new_tokens": int(max_tokens)}
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # transformers rejects temperature 0 when sampling; use greedy decoding instead.
        generation_kwargs["do_sample"] = False

    # return_full_text=False: by default the pipeline prepends the whole
    # rendered prompt (system text and turn tags) to the generated reply.
    output = pipe(prompt, return_full_text=False, **generation_kwargs)[0]["generated_text"].strip()

    # Append conversation to history
    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": output})

    return chat_history

import re

def extract_later_part(user_prompt, generated_text, turn_marker="<start_of_turn>model"):
    """
    Cleans the model output and extracts only the enhanced (later) portion.

    Steps:
      1. If ``turn_marker`` occurs in ``generated_text`` (the pipeline echoed
         the whole chat-template prompt), keep only the text after its last
         occurrence.
      2. Remove tag-like spans such as <end_of_turn>, <start_of_turn>, <bos>.
      3. Collapse every whitespace run to a single space and trim.
      4. If the result starts with the user's prompt (case-insensitive,
         whitespace-normalized, on a word boundary), drop that echo plus any
         leading commas, periods, or spaces that separated it from the reply.

    Args:
        user_prompt: The prompt the user typed; may be ``None`` or blank.
        generated_text: Raw text returned by the generation pipeline; may be
            ``None``.
        turn_marker: Chat-template marker that opens the model's reply.
            Pass ``""`` or ``None`` to skip step 1.

    Returns:
        The cleaned enhanced text (possibly an empty string).
    """
    text = generated_text or ""

    if turn_marker:
        # rpartition returns ("", "", text) when the marker is absent, so this
        # only changes the text when an echoed chat template is present.
        text = text.rpartition(turn_marker)[2]

    # Remove tag-like spans such as <end_of_turn>, <start_of_turn>, <bos>.
    cleaned = re.sub(r"<.*?>", "", text)
    # Collapse whitespace (including newlines) into single spaces.
    cleaned = re.sub(r"\s+", " ", cleaned).strip()

    # Normalize the prompt exactly like the output. Otherwise extra spaces or
    # newlines in the prompt defeat the prefix match, and slicing by the raw
    # prompt's length would cut into the reply.
    prompt = re.sub(r"\s+", " ", user_prompt or "").strip()

    # Compare an equal-length prefix instead of lowercasing the whole string
    # first: str.lower() can change a string's length and misalign the slice.
    if prompt and cleaned[: len(prompt)].lower() == prompt.lower():
        rest = cleaned[len(prompt):]
        # Require a word boundary so "a cat" does not eat the start of "a category".
        if not (prompt[-1].isalnum() and rest[:1].isalnum()):
            # Strip separators on the left only; the reply's own final period stays.
            cleaned = rest.lstrip(",. ")

    return cleaned


# ===================== Prompt Enhancer Function =====================
def enhance_prompt1(user_prompt, temperature, max_tokens, chat_history):
    """
    Enhance ``user_prompt``, clean the reply, and append both to the chat.

    The raw reply is passed through ``extract_later_part`` to strip leftover
    tags and an echoed copy of the prompt before it is shown.

    Args:
        user_prompt: Short prompt typed by the user.
        temperature: Sampling temperature from the UI slider. Values <= 0
            switch to greedy decoding, because sampling needs a strictly
            positive temperature.
        max_tokens: Maximum number of new tokens to generate.
        chat_history: List of ``{"role": ..., "content": ...}`` dicts (the
            Chatbot's ``type="messages"`` format), or ``None``.

    Returns:
        The chat history with the user turn and cleaned reply appended.
        A blank prompt returns the history unchanged without generating.
    """
    chat_history = chat_history or []

    # Nothing to enhance; skip a pointless generation.
    if not user_prompt or not user_prompt.strip():
        return chat_history

    messages = [
        {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
        {"role": "user", "content": user_prompt}
    ]

    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Honour the UI sliders (previously ignored in favour of a fixed 256 tokens).
    temperature = float(temperature)
    generation_kwargs = {"max_new_tokens": int(max_tokens)}
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # transformers rejects temperature 0 when sampling; use greedy decoding instead.
        generation_kwargs["do_sample"] = False

    # return_full_text=False: return only the newly generated text, not the
    # rendered prompt followed by the reply.
    output = pipe(prompt, return_full_text=False, **generation_kwargs)
    raw_output = output[0]['generated_text']

    print("=== RAW MODEL OUTPUT ===")
    print(raw_output)

    # Extract the cleaned, later portion
    later_part = extract_later_part(user_prompt, raw_output)
    print("=== EXTRACTED CLEANED OUTPUT ===")
    print(later_part)

    # Append to chat history for Gradio
    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": later_part})

    return chat_history

# ============================================================
# 3️⃣ Gradio UI
# ============================================================
with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (Gemma 3 270M)
        Enter a short prompt, and the model will **expand it with details and creative context**  
        using the Gemma chat-template interface.
        """
    )

    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(
                placeholder="Enter a short prompt...",
                label="Your Prompt",
                lines=3,
            )
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            # Developer-only button with no bound action; hidden so end users
            # are not shown a control that does nothing.
            send_btn = gr.Button("🚀 dev dont click", variant="primary", visible=False)
            clear_btn = gr.Button("🧹 Clear Chat")
            add_btn = gr.Button("🚀 Enhance Prompt", variant="primary")

    # Bind UI actions. Pressing Enter in the textbox and clicking the main
    # button run the same handler, so both produce the same cleaned output.
    _handler_inputs = [user_prompt, temperature, max_tokens, chatbot]
    user_prompt.submit(enhance_prompt1, _handler_inputs, chatbot)
    add_btn.click(enhance_prompt1, _handler_inputs, chatbot)
    clear_btn.click(lambda: [], None, chatbot)

    gr.Markdown(
        """
        ---
        💡 **Tips:**
        - Works best with short, descriptive prompts (e.g., "a cat sitting on a chair")
        - Increase *Temperature* for more creative output.
        """
    )

# ============================================================
# 4️⃣ Launch
# ============================================================
if __name__ == "__main__":
    # show_error=True asks Gradio to surface handler exceptions in the browser.
    demo.launch(show_error=True)