Spaces:
Runtime error
Runtime error
| # from fastapi import FastAPI, Request, Form, UploadFile, File | |
| # from fastapi.templating import Jinja2Templates | |
| # from fastapi.responses import HTMLResponse, RedirectResponse | |
| # from fastapi.staticfiles import StaticFiles | |
| # from dotenv import load_dotenv | |
| # import os, io | |
| # from PIL import Image | |
| # import markdown | |
| # import google.generativeai as genai | |
| # # Load environment variable | |
| # load_dotenv() | |
| # API_KEY = os.getenv("GOOGLE_API_KEY") | |
| # genai.configure(api_key=API_KEY) | |
| # app = FastAPI() | |
| # templates = Jinja2Templates(directory="templates") | |
| # app.mount("/static", StaticFiles(directory="static"), name="static") | |
| # model = genai.GenerativeModel('gemini-2.0-flash') | |
| # # Create a global chat session | |
| # chat = None | |
| # chat_history = [] | |
| # @app.get("/", response_class=HTMLResponse) | |
| # async def root(request: Request): | |
| # return templates.TemplateResponse("index.html", { | |
| # "request": request, | |
| # "chat_history": chat_history, | |
| # }) | |
| # @app.post("/", response_class=HTMLResponse) | |
| # async def handle_input( | |
| # request: Request, | |
| # user_input: str = Form(...), | |
| # image: UploadFile = File(None) | |
| # ): | |
| # global chat, chat_history | |
| # # Initialize chat session if needed | |
| # if chat is None: | |
| # chat = model.start_chat(history=[]) | |
| # parts = [] | |
| # if user_input: | |
| # parts.append(user_input) | |
| # # For display in the UI | |
| # user_message = user_input | |
| # if image and image.content_type.startswith("image/"): | |
| # data = await image.read() | |
| # try: | |
| # img = Image.open(io.BytesIO(data)) | |
| # parts.append(img) | |
| # user_message += " [Image uploaded]" # Indicate image in chat history | |
| # except Exception as e: | |
| # chat_history.append({ | |
| # "role": "model", | |
| # "content": markdown.markdown(f"**Error loading image:** {e}") | |
| # }) | |
| # return RedirectResponse("/", status_code=303) | |
| # # Store user message for display | |
| # chat_history.append({"role": "user", "content": user_message}) | |
| # try: | |
| # # Send message to Gemini model | |
| # resp = chat.send_message(parts) | |
| # # Add model response to history | |
| # raw = resp.text | |
| # chat_history.append({"role": "model", "content": raw}) | |
| # except Exception as e: | |
| # err = f"**Error:** {e}" | |
| # chat_history.append({ | |
| # "role": "model", | |
| # "content": markdown.markdown(err) | |
| # }) | |
| # # Post-Redirect-Get | |
| # return RedirectResponse("/", status_code=303) | |
| # # Clear chat history and start fresh | |
| # @app.post("/new") | |
| # async def new_chat(): | |
| # global chat, chat_history | |
| # chat = None | |
| # chat_history.clear() | |
| # return RedirectResponse("/", status_code=303) | |
| import os | |
| import io | |
| import streamlit as st | |
| from dotenv import load_dotenv | |
| from PIL import Image | |
| import google.generativeai as genai | |
| from langgraph.graph import StateGraph, END | |
| from typing import TypedDict, List, Union | |
# ---------------------------
# Gemini client configuration
# ---------------------------
# Pull variables from a local .env file (if one exists) into the process
# environment, then hand GOOGLE_API_KEY to the Gemini SDK.
load_dotenv()
API_KEY = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=API_KEY)
# One shared model handle; chat sessions are started from it per request.
model = genai.GenerativeModel("gemini-2.0-flash")
# ---------------------------
# Graph state schema
# ---------------------------
class ChatState(TypedDict):
    """State dictionary threaded through the LangGraph nodes for one chat turn."""

    # Text the user typed into the chat form.
    user_input: str
    # Decoded PIL image from the upload, or None when nothing was attached.
    image: Union[Image.Image, None]
    # Reply text exactly as returned by Gemini (or an "Error: ..." string).
    raw_response: str
    # raw_response after checking_node has removed filler lines.
    final_response: str
    # Session transcript: a list of {"role": ..., "content": ...} dicts.
    chat_history: List[dict]
# ---------------------------
# LangGraph nodes
# ---------------------------
def input_node(state: ChatState) -> ChatState:
    """Entry node of the graph: forward the incoming state untouched.

    No fields are read or modified; the same dict object is returned.
    """
    unchanged = state
    return unchanged
def processing_node(state: ChatState) -> ChatState:
    """Send the user's turn (text and/or image) to Gemini and store the reply.

    Earlier turns from ``state["chat_history"]`` are replayed as the chat
    session's history so follow-up questions are answered in context.

    Reads:  ``user_input``, ``image``, ``chat_history``.
    Writes: ``raw_response`` -- the model's text, or an ``"Error: ..."`` string
    when the request fails. Errors are reported in the chat, not raised.
    """
    parts = []
    # Gemini rejects empty text parts, so only send the text when it has
    # content; an image-only turn is a valid request on its own.
    if (state["user_input"] or "").strip():
        parts.append(state["user_input"])
    if state["image"] is not None:
        parts.append(state["image"])
    if not parts:
        state["raw_response"] = "Please enter a message or upload an image."
        return state

    # Convert the UI transcript ({"role", "content"}) into Gemini's
    # {"role", "parts"} history format. The roles stored by the UI are
    # "user" and "model", which are the roles Gemini uses. Empty entries
    # are skipped because Gemini rejects empty parts.
    history = [
        {"role": msg["role"], "parts": [msg["content"]]}
        for msg in (state.get("chat_history") or [])
        if msg.get("content")
    ]

    try:
        chat = model.start_chat(history=history)
        resp = chat.send_message(parts)
        # resp.text raises if the response was blocked; handled below.
        state["raw_response"] = resp.text
    except Exception as e:  # UI boundary: surface any API failure to the user
        state["raw_response"] = f"Error: {e}"
    return state
def checking_node(state: ChatState) -> ChatState:
    """Strip filler lines from the model reply and record the turn in the UI.

    When the reply starts with "Sure!" or mentions "The image shows", every
    line starting with "Sure!" or containing "The image shows" is dropped.
    If that would leave nothing, the unfiltered reply is kept so the chat
    never shows an empty message.

    Side effect: appends the user's message and the final reply to
    ``st.session_state.chat_history`` (which must already be initialised).
    """
    raw = state["raw_response"]
    final = raw
    if raw.startswith("Sure!") or "The image shows" in raw:
        kept = [
            line for line in raw.split("\n")
            if not line.startswith("Sure!") and "The image shows" not in line
        ]
        # Fall back to the raw reply rather than storing an empty message.
        final = "\n".join(kept).strip() or raw
    state["final_response"] = final

    user_message = state["user_input"]
    if state["image"] is not None:
        # The image itself is not kept in the transcript; mark that one was
        # sent so image-only turns do not render as an empty user message.
        user_message = f"{user_message} [Image uploaded]".strip()

    st.session_state.chat_history.append({"role": "user", "content": user_message})
    st.session_state.chat_history.append({"role": "model", "content": final})
    return state
# ---------------------------
# Graph wiring: input -> processing -> checking -> END
# ---------------------------
def _build_chat_graph():
    """Return an uncompiled StateGraph that runs the three nodes in order."""
    pipeline = (
        ("input", input_node),
        ("processing", processing_node),
        ("checking", checking_node),
    )
    graph_builder = StateGraph(ChatState)
    for node_name, node_fn in pipeline:
        graph_builder.add_node(node_name, node_fn)
    graph_builder.set_entry_point(pipeline[0][0])
    # Chain each node to its successor, then terminate after the last one.
    for (src, _), (dst, _) in zip(pipeline, pipeline[1:]):
        graph_builder.add_edge(src, dst)
    graph_builder.add_edge(pipeline[-1][0], END)
    return graph_builder


builder = _build_chat_graph()
graph = builder.compile()
# ---------------------------
# Streamlit page setup and transcript
# ---------------------------
st.set_page_config(page_title="Math Chatbot", layout="centered")
st.title("Math Chatbot")

# The transcript lives in session_state so it persists across reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Re-draw every stored message on each script run.
for entry in st.session_state.chat_history:
    speaker = entry["role"]
    with st.chat_message(speaker):
        st.markdown(entry["content"])
# ---------------------------
# Sidebar: reset the conversation
# ---------------------------
st.sidebar.header("Options")
if st.sidebar.button("New Chat"):
    # Drop the transcript, then rerun at once so the cleared page is shown.
    st.session_state.chat_history = []
    st.rerun()
# ---------------------------
# Chat input form (widgets reset after each submit)
# ---------------------------
chat_form = st.form("chat_form", clear_on_submit=True)
user_input = chat_form.text_input("Your message:", placeholder="Ask your math problem here")
uploaded_file = chat_form.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
submitted = chat_form.form_submit_button("Send")
if submitted:
    # Nothing to send: Gemini rejects a request with no content.
    if not user_input.strip() and uploaded_file is None:
        st.warning("Please enter a message or upload an image.")
        st.stop()

    # Decode the upload. Image.open only reads the header, so load() forces
    # a full decode and corrupt files are reported here instead of failing
    # later inside the model call.
    image = None
    if uploaded_file is not None:
        try:
            image = Image.open(io.BytesIO(uploaded_file.getvalue()))
            image.load()
        except Exception as e:  # UI boundary: show any decode failure
            st.error(f"Error loading image: {e}")
            st.stop()

    # Initial graph state for this turn.
    input_state = {
        "user_input": user_input,
        "image": image,
        "raw_response": "",
        "final_response": "",
        "chat_history": st.session_state.chat_history,
    }

    with st.spinner("Thinking..."):
        output = graph.invoke(input_state)

    # checking_node has appended both the user turn and the reply to
    # st.session_state.chat_history; rerun so they are rendered, in order,
    # by the transcript loop above the form.
    st.rerun()