from fastapi import FastAPI, Query
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles


# ✅ Force Hugging Face cache to /tmp (writable in Spaces)
os.environ["HF_HOME"] = "/tmp"
os.environ["TRANSFORMERS_CACHE"] = "/tmp"


model_id = "rabiyulfahim/qa_python_gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="/tmp")
model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir="/tmp")
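# Note: tokenizer and model are created once at import time, so request handlers
# below reuse them instead of reloading weights (cached under /tmp) per call.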


app = FastAPI(title="QA GPT2 API UI", description="Serving HuggingFace model with FastAPI")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
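# allow_origins=["*"] leaves the API open to any origin, a typical open-demo setup.
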
# Request schema
class QueryRequest(BaseModel):
    question: str
    max_new_tokens: int = 50
    temperature: float = 0.7
    top_p: float = 0.9
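# Illustrative /predict payload using the fields above (values are examples only):
# {"question": "How do I reverse a list in Python?", "max_new_tokens": 60,
#  "temperature": 0.7, "top_p": 0.9}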


@app.get("/")
def home():
    return {"message": "Welcome to QA GPT2 API 🚀"}

@app.get("/ask")
def ask(question: str, max_new_tokens: int = 50):
    inputs = tokenizer(question, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"question": question, "answer": answer}


# Mount static folder
app.mount("/static", StaticFiles(directory="static"), name="static")

@app.get("/ui", response_class=HTMLResponse)
def serve_ui():
    html_path = os.path.join("static", "index.html")
    with open(html_path, "r", encoding="utf-8") as f:
        return HTMLResponse(f.read())


# Health check endpoint
@app.get("/health")
def health():
    return {"status": "ok"}

# Inference endpoint
@app.post("/predict")
def predict(request: QueryRequest):
    inputs = tokenizer(request.question, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=request.max_new_tokens,
        do_sample=True,
        temperature=request.temperature,
        top_p=request.top_p,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True
    )

    answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    return {
        "question": request.question,
        "answer": answer
    }




@app.get("/answers")
def predict(question: str = Query(..., description="The question to ask"), max_new_tokens: int = Query(50, description="Max new tokens to generate")):
    # Tokenize the input question
    inputs = tokenizer(question, return_tensors="pt")

    # Generate output from model
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True
    )

    # Decode output
    answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

    return {
        "question": question,
        "answer": answer
    }
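

# --- Usage sketch (illustrative; assumes this file is saved as app.py) ---
# Run the server locally (the port here is an assumption, not set anywhere in this file):
#   uvicorn app:app --host 0.0.0.0 --port 7860
# Query the simple GET endpoint:
#   curl "http://localhost:7860/ask?question=What+is+a+decorator&max_new_tokens=40"
# Query the POST endpoint with a QueryRequest body:
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is a generator?", "max_new_tokens": 60}'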