import os

# ✅ Force the Hugging Face cache to /tmp (writable in Spaces).
# Set these before importing transformers so the cache location takes effect.
os.environ["HF_HOME"] = "/tmp"
os.environ["TRANSFORMERS_CACHE"] = "/tmp"

import torch
from fastapi import FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "rabiyulfahim/qa_python_gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="/tmp")
model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir="/tmp")
model.eval()  # inference only
app = FastAPI(title="QA GPT2 API UI", description="Serving HuggingFace model with FastAPI")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request schema
class QueryRequest(BaseModel):
    question: str
    max_new_tokens: int = 50
    temperature: float = 0.7
    top_p: float = 0.9
@app.get("/")
def home():
    return {"message": "Welcome to QA GPT2 API 🚀"}
@app.get("/ask")
def ask(question: str, max_new_tokens: int = 50):
    inputs = tokenizer(question, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"question": question, "answer": answer}
# Mount static folder
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/ui", response_class=HTMLResponse)
def serve_ui():
    html_path = os.path.join("static", "index.html")
    with open(html_path, "r", encoding="utf-8") as f:
        return HTMLResponse(f.read())
# Health check endpoint
@app.get("/health")
def health():
    return {"status": "ok"}
# Inference endpoint (uses the sampling parameters from the request body)
@app.post("/predict")
def predict(request: QueryRequest):
    inputs = tokenizer(request.question, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=request.max_new_tokens,
            do_sample=True,
            temperature=request.temperature,
            top_p=request.top_p,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
        )
    answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    return {
        "question": request.question,
        "answer": answer,
    }
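

# Example request for /predict (a sketch, assuming the app is served on port 7860,
# the default for Hugging Face Spaces; adjust host/port for your deployment):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"question": "How do I reverse a list in Python?", "max_new_tokens": 50}'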
@app.get("/answers")
def answers(
    question: str = Query(..., description="The question to ask"),
    max_new_tokens: int = Query(50, description="Max new tokens to generate"),
):
    # Tokenize the input question
    inputs = tokenizer(question, return_tensors="pt")

    # Generate output from the model
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
        )

    # Decode output
    answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    return {
        "question": question,
        "answer": answer,
    }
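

# Local development entry point (a minimal sketch). In a Hugging Face Space the
# server is usually started externally, e.g. `uvicorn app:app --host 0.0.0.0 --port 7860`;
# the module name `app` and port 7860 are assumptions here, not taken from the Space config.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)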