# Doubt-Solver / app.py
# Last commit: dbd41ce ("Update app.py") by ak0601
# from fastapi import FastAPI, Request, Form, UploadFile, File
# from fastapi.templating import Jinja2Templates
# from fastapi.responses import HTMLResponse, RedirectResponse
# from fastapi.staticfiles import StaticFiles
# from dotenv import load_dotenv
# import os, io
# from PIL import Image
# import markdown
# import google.generativeai as genai
# # Load environment variable
# load_dotenv()
# API_KEY = os.getenv("GOOGLE_API_KEY")
# genai.configure(api_key=API_KEY)
# app = FastAPI()
# templates = Jinja2Templates(directory="templates")
# app.mount("/static", StaticFiles(directory="static"), name="static")
# model = genai.GenerativeModel('gemini-2.0-flash')
# # Create a global chat session
# chat = None
# chat_history = []
# @app.get("/", response_class=HTMLResponse)
# async def root(request: Request):
# return templates.TemplateResponse("index.html", {
# "request": request,
# "chat_history": chat_history,
# })
# @app.post("/", response_class=HTMLResponse)
# async def handle_input(
# request: Request,
# user_input: str = Form(...),
# image: UploadFile = File(None)
# ):
# global chat, chat_history
# # Initialize chat session if needed
# if chat is None:
# chat = model.start_chat(history=[])
# parts = []
# if user_input:
# parts.append(user_input)
# # For display in the UI
# user_message = user_input
# if image and image.content_type.startswith("image/"):
# data = await image.read()
# try:
# img = Image.open(io.BytesIO(data))
# parts.append(img)
# user_message += " [Image uploaded]" # Indicate image in chat history
# except Exception as e:
# chat_history.append({
# "role": "model",
# "content": markdown.markdown(f"**Error loading image:** {e}")
# })
# return RedirectResponse("/", status_code=303)
# # Store user message for display
# chat_history.append({"role": "user", "content": user_message})
# try:
# # Send message to Gemini model
# resp = chat.send_message(parts)
# # Add model response to history
# raw = resp.text
# chat_history.append({"role": "model", "content": raw})
# except Exception as e:
# err = f"**Error:** {e}"
# chat_history.append({
# "role": "model",
# "content": markdown.markdown(err)
# })
# # Post-Redirect-Get
# return RedirectResponse("/", status_code=303)
# # Clear chat history and start fresh
# @app.post("/new")
# async def new_chat():
# global chat, chat_history
# chat = None
# chat_history.clear()
# return RedirectResponse("/", status_code=303)
import os
import io
import streamlit as st
from dotenv import load_dotenv
from PIL import Image
import google.generativeai as genai
from langgraph.graph import StateGraph, END
from typing import TypedDict, List, Union
# ---------------------------
# Load API Key
# ---------------------------
load_dotenv()  # copy variables from a local .env file (if present) into the environment
API_KEY = os.environ.get("GOOGLE_API_KEY")
genai.configure(api_key=API_KEY)
# One shared Gemini model handle; processing_node starts a chat on it each turn.
model = genai.GenerativeModel(model_name="gemini-2.0-flash")
# ---------------------------
# State Definition
# ---------------------------
class ChatState(TypedDict):
    """State dict threaded through the LangGraph nodes for one chat turn."""

    # Text typed into the form; may be an empty string.
    user_input: str
    # Uploaded picture decoded with PIL, or None when nothing was uploaded.
    image: Union[None, Image.Image]
    # Reply text from Gemini, or an "Error: ..." string set by processing_node.
    raw_response: str
    # raw_response after checking_node has filtered it.
    final_response: str
    # Conversation so far as {"role": ..., "content": ...} dicts.
    chat_history: List[dict]
# ---------------------------
# LangGraph Nodes
# ---------------------------
def input_node(state: ChatState) -> ChatState:
    """Graph entry node: forward the incoming state untouched."""
    forwarded = state
    return forwarded
def processing_node(state: ChatState) -> ChatState:
    """Send the user's text and optional image to Gemini.

    The reply text is stored in ``state["raw_response"]``. Any exception raised
    while talking to the model is caught and stored as an ``"Error: ..."``
    string instead, so the graph always continues to ``checking_node``.

    Args:
        state: The current chat state; ``user_input`` and ``image`` are read.

    Returns:
        The same state dict with ``raw_response`` filled in.
    """
    parts = []
    # Only add a text part when there is text: image-only turns would otherwise
    # send an empty text part alongside the image.
    if state["user_input"]:
        parts.append(state["user_input"])
    if state["image"] is not None:
        parts.append(state["image"])

    if not parts:
        # Nothing to ask; skip the network round-trip.
        state["raw_response"] = "Error: please enter a message or upload an image."
        return state

    try:
        # NOTE(review): a fresh session with empty history is started on every
        # turn, so the model never sees earlier messages even though
        # state["chat_history"] is available. Confirm this is intended.
        chat = model.start_chat(history=[])
        resp = chat.send_message(parts)
        state["raw_response"] = resp.text
    except Exception as e:
        state["raw_response"] = f"Error: {e}"
    return state
def checking_node(state: ChatState) -> ChatState:
    """Filter boilerplate out of the model reply and record the turn.

    When the reply starts with "Sure!" or mentions "The image shows", every
    line starting with "Sure!" or containing "The image shows" is dropped and
    the remainder is stripped. Otherwise the reply is used unchanged. The
    user's message and the final reply are then appended to
    ``st.session_state.chat_history``.

    Args:
        state: Chat state with ``raw_response``, ``user_input`` and ``image``.

    Returns:
        The same state dict with ``final_response`` filled in.
    """
    raw = state["raw_response"]
    # Remove unnecessary lines from Gemini responses
    if raw.startswith("Sure!") or "The image shows" in raw:
        kept = [
            line for line in raw.split("\n")
            if not line.startswith("Sure!") and "The image shows" not in line
        ]
        state["final_response"] = "\n".join(kept).strip()
    else:
        state["final_response"] = raw

    # Image-only turns have no text; label them so the stored history entry
    # does not render as a blank user bubble.
    user_message = state["user_input"]
    if state["image"] is not None:
        user_message = f"{user_message} [Image uploaded]".strip()

    # Save to session chat history
    st.session_state.chat_history.append({"role": "user", "content": user_message})
    st.session_state.chat_history.append({"role": "model", "content": state["final_response"]})
    return state
# ---------------------------
# Build the LangGraph
# ---------------------------
# Linear pipeline: each node runs once, in this order, then the graph ends.
_PIPELINE = (
    ("input", input_node),
    ("processing", processing_node),
    ("checking", checking_node),
)
builder = StateGraph(ChatState)
for _node_name, _node_fn in _PIPELINE:
    builder.add_node(_node_name, _node_fn)
builder.set_entry_point(_PIPELINE[0][0])
# Chain consecutive nodes, finishing with an edge from the last node to END.
_route = [name for name, _ in _PIPELINE] + [END]
for _src, _dst in zip(_route, _route[1:]):
    builder.add_edge(_src, _dst)
graph = builder.compile()
# ---------------------------
# Streamlit UI Setup
# ---------------------------
# Tab title / layout, then the on-page heading.
st.set_page_config(page_title="Math Chatbot", layout="centered")
st.title("Math Chatbot")
# Create the history list only on the first run of a session, so that
# later reruns keep the existing conversation.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []
# Redraw every stored turn; Streamlit reruns the whole script on each interaction.
for entry in st.session_state.chat_history:
    speaker, text = entry["role"], entry["content"]
    with st.chat_message(speaker):
        st.markdown(text)
# ---------------------------
# Sidebar
# ---------------------------
with st.sidebar:
    st.header("Options")
    # Empty the history, then rerun so the cleared conversation is drawn
    # (the history loop has already rendered for this run).
    reset_requested = st.button("New Chat")
    if reset_requested:
        st.session_state.chat_history = []
        st.rerun()
# ---------------------------
# Chat Input Form
# ---------------------------
with st.form("chat_form", clear_on_submit=True):
    user_input = st.text_input("Your message:", placeholder="Ask your math problem here")
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    submitted = st.form_submit_button("Send")
if submitted:
    # Nothing to send: don't call the model or record an empty turn.
    if not user_input.strip() and uploaded_file is None:
        st.warning("Please type a question or upload an image.")
        st.stop()

    # Load image safely
    image = None
    if uploaded_file is not None:
        try:
            image = Image.open(io.BytesIO(uploaded_file.getvalue()))
            # Image.open only parses the header; decode the pixels now so a
            # corrupt upload is reported here rather than later in the graph.
            image.load()
        except Exception as e:
            st.error(f"Error loading image: {e}")
            st.stop()

    # Echo the user's turn right away: the history loop above has already run
    # for this script execution, so otherwise the question only appears after
    # the next rerun.
    with st.chat_message("user"):
        if user_input:
            st.markdown(user_input)
        if image is not None:
            st.image(image)

    # Prepare state
    input_state = {
        "user_input": user_input,
        "image": image,
        "raw_response": "",
        "final_response": "",
        "chat_history": st.session_state.chat_history,
    }
    # Run LangGraph
    output = graph.invoke(input_state)
    # Show model response
    with st.chat_message("model"):
        st.markdown(output["final_response"])