# app.py
import os
import time
import requests
import gradio as gr
import google.generativeai as genai
from google.api_core.exceptions import ResourceExhausted
# -----------------------
# Config / Secrets
# -----------------------
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
HF_API_TOKEN = os.environ.get("HF_API_TOKEN") # required for TTS
HF_TTS_MODEL = os.environ.get("HF_TTS_MODEL", "microsoft/speecht5_tts")
AUDIO_TMP_DIR = "/tmp"
if not GEMINI_API_KEY:
    raise RuntimeError("Missing GEMINI_API_KEY in environment. Add it to HF Space Secrets.")
if not HF_API_TOKEN:
    print("Warning: HF_API_TOKEN not set. Audio will be unavailable until set in Space Secrets.")
# Configure Gemini SDK
genai.configure(api_key=GEMINI_API_KEY)
gemini_model = genai.GenerativeModel("gemini-2.5-flash")
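# Optional sketch: the Gemini model name could be made overridable the same way as
# HF_TTS_MODEL above. GEMINI_MODEL is a hypothetical variable name, not something
# the app reads today.
#   GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.5-flash")
#   gemini_model = genai.GenerativeModel(GEMINI_MODEL)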
# -----------------------
# In-memory chat memory
# -----------------------
class SimpleMemory:
    def __init__(self, max_messages=40):
        self.max_messages = max_messages
        self.history = []  # list of (role, text) where role in {"user", "bot"}

    def add(self, role, text):
        self.history.append((role, text))
        if len(self.history) > self.max_messages:
            self.history = self.history[-self.max_messages:]

    def as_prompt_text(self):
        lines = []
        for role, txt in self.history:
            if role == "user":
                lines.append(f"User: {txt}")
            else:
                lines.append(f"Chatbot: {txt}")
        return "\n".join(lines)
memory = SimpleMemory(max_messages=40)
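# Usage sketch (illustrative only, not executed by the app): the memory keeps the most
# recent `max_messages` entries and renders them as alternating "User:" / "Chatbot:" lines.
#   m = SimpleMemory(max_messages=4)
#   m.add("user", "Hi"); m.add("bot", "Hello!"); m.add("user", "How are you?")
#   m.as_prompt_text()
#   # -> "User: Hi\nChatbot: Hello!\nUser: How are you?"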
# -----------------------
# Prompt template
# -----------------------
PROMPT_TEMPLATE = """You are a helpful assistant.
{chat_history}
User: {user_message}
Chatbot:"""
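# For example, with one prior exchange in memory, the rendered prompt looks like:
#   You are a helpful assistant.
#   User: Hi
#   Chatbot: Hello!
#   User: What can you do?
#   Chatbot: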
# -----------------------
# Robust Gemini generator (tries multiple formats)
# Returns (text, error)
# -----------------------
def generate_text_with_gemini(user_message):
    chat_history_text = memory.as_prompt_text()
    full_prompt = PROMPT_TEMPLATE.format(chat_history=chat_history_text, user_message=user_message)
    # The system instruction ("You are a helpful assistant.") already lives at the top
    # of PROMPT_TEMPLATE, so each attempt only has to deliver full_prompt.
    # 1) plain string prompt
    try:
        resp = gemini_model.generate_content(full_prompt)
        text = getattr(resp, "text", None) or str(resp)
        return text, None
    except ResourceExhausted as e:
        print("Gemini quota exhausted (raw):", e)
        return None, "Gemini quota exceeded. Please try again later."
    except Exception as e1:
        print("generate_content(raw) failed, trying structured contents:", repr(e1))
    # 2) structured contents list (the SDK expects "role"/"parts" entries)
    try:
        contents = [{"role": "user", "parts": [full_prompt]}]
        resp = gemini_model.generate_content(contents)
        text = getattr(resp, "text", None) or str(resp)
        return text, None
    except ResourceExhausted as e:
        print("Gemini quota exhausted (contents):", e)
        return None, "Gemini quota exceeded. Please try again later."
    except Exception as e2:
        print("generate_content(contents) failed, trying chat session:", repr(e2))
    # 3) chat session fallback
    try:
        chat = gemini_model.start_chat()
        resp = chat.send_message(full_prompt)
        text = getattr(resp, "text", None) or str(resp)
        return text, None
    except ResourceExhausted as e:
        print("Gemini quota exhausted (chat):", e)
        return None, "Gemini quota exceeded. Please try again later."
    except Exception as efinal:
        print("Gemini all attempts failed:", repr(efinal))
        return None, f"Gemini error: {repr(efinal)}"
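# Optional sketch (not wired into the UI below): instead of returning immediately on a
# quota error, a caller could retry with exponential backoff. `generate_with_backoff`
# is an illustrative helper name, not part of the original app.
def generate_with_backoff(user_message, retries=3, base_delay=2.0):
    for attempt in range(retries):
        text, err = generate_text_with_gemini(user_message)
        if text is not None:
            return text, None
        if err and "quota" not in err.lower():
            return None, err  # non-quota errors are not retried
        time.sleep(base_delay * (2 ** attempt))
    return None, "Gemini quota exceeded after retries."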
# -----------------------
# Hugging Face Router-aware TTS
# Tries legacy api-inference endpoint, then router.huggingface.co
# Returns (path, error)
# -----------------------
def generate_audio_hf_inference(text):
    if not HF_API_TOKEN:
        return "", "HF_API_TOKEN not configured for TTS."
    model = HF_TTS_MODEL  # e.g. "microsoft/speecht5_tts"
    router_url = f"https://router.huggingface.co/models/{model}"
    legacy_url = f"https://api-inference.huggingface.co/models/{model}"
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    payload = {"inputs": text}

    def _save_bytes(content, content_type_hint=""):
        ct = content_type_hint or ""
        ext = ".mp3" if "mpeg" in ct else ".wav"
        filename = f"audio_{int(time.time()*1000)}_{abs(hash(text)) % 100000}{ext}"
        path = os.path.join(AUDIO_TMP_DIR, filename)
        with open(path, "wb") as f:
            f.write(content)
        return path

    last_err = None
    for url in [legacy_url, router_url]:
        try:
            h = headers.copy()
            h["Accept"] = "audio/mpeg, audio/wav, */*"
            resp = requests.post(url, headers=h, json=payload, timeout=60)
        except Exception as e:
            last_err = f"HuggingFace request to {url} failed: {e}"
            print(last_err)
            continue
        if resp.status_code == 410:
            last_err = f"HuggingFace returned 410 for {url}: {resp.text}"
            print(last_err)
            continue
        if resp.status_code == 200:
            try:
                content_type = resp.headers.get("content-type", "")
                path = _save_bytes(resp.content, content_type)
                print(f"HuggingFace TTS: audio saved to {path} using URL {url} (content-type={content_type})")
                return path, ""
            except Exception as e:
                last_err = f"Failed to save HF audio from {url}: {e}"
                print(last_err)
                continue
        else:
            try:
                body = resp.json()
            except Exception:
                body = resp.text
            last_err = f"HuggingFace TTS error {resp.status_code} from {url}: {body}"
            print(last_err)
            if resp.status_code in (401, 403):
                # auth problem, stop trying other endpoints
                break
            continue
    return "", last_err or "Unknown HuggingFace error"
# -----------------------
# Convert memory -> messages list for Gradio
# -----------------------
def convert_memory_to_messages(history):
    messages = []
    for role, msg in history:
        role_out = "assistant" if role == "bot" else "user"
        messages.append({"role": role_out, "content": msg})
    return messages
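# Example: [("user", "Hi"), ("bot", "Hello!")] becomes
#   [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
# which is the messages format gr.Chatbot(type="messages") expects.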
# -----------------------
# Combined chat workflow
# Returns (messages_list, audio_path, error)
# -----------------------
def process_user_message(user_message):
    text, gen_err = generate_text_with_gemini(user_message)
    if gen_err:
        memory.add("user", user_message)
        fallback = "Sorry, the assistant is temporarily unavailable: " + gen_err
        memory.add("bot", fallback)
        return convert_memory_to_messages(memory.history), "", gen_err
    memory.add("user", user_message)
    memory.add("bot", text)
    audio_path, audio_err = generate_audio_hf_inference(text)
    if audio_err:
        print("Audio generation error (HF):", audio_err)
    return convert_memory_to_messages(memory.history), audio_path or "", audio_err or ""
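# Example of the tuple this returns on a successful turn (illustrative values only):
#   ([... message dicts ...], "/tmp/audio_1700000000000_12345.wav", "")
# On a Gemini failure the audio path is "" and the third element carries the error text,
# which maps directly onto the chatbot, audio player, and debug box outputs below.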
# -----------------------
# Gradio UI (Blocks) with debug UI
# -----------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Gemini + Hugging Face TTS Chatbot\n\nAudio generated via Hugging Face Inference (router).")
    # type="messages" so the component accepts the {"role", "content"} dicts built above
    chatbot = gr.Chatbot(type="messages")
    with gr.Row():
        txt = gr.Textbox(show_label=False, placeholder="Type your message and press Enter")
        send_btn = gr.Button("Send")
    audio_player = gr.Audio(label="Last reply audio (if available)", visible=False)
    debug_box = gr.Textbox(label="Last debug message (audio path or error)", interactive=False, visible=False)

    def submit_message(message):
        messages, audio_path, err = process_user_message(message)
        if audio_path:
            debug_msg = f"Audio saved: {audio_path}"
            return messages, gr.update(value=audio_path, visible=True), gr.update(value=debug_msg, visible=True)
        elif err:
            return messages, gr.update(value=None, visible=False), gr.update(value=err, visible=True)
        else:
            return messages, gr.update(value=None, visible=False), gr.update(value="No audio generated", visible=True)

    send_btn.click(fn=submit_message, inputs=[txt], outputs=[chatbot, audio_player, debug_box])
    txt.submit(fn=submit_message, inputs=[txt], outputs=[chatbot, audio_player, debug_box])
# Launch
if __name__ == "__main__":
    demo.launch(debug=True)