from fastapi import FastAPI, Query, HTTPException
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
import os
import torch

# -----------------------
# Set cache dirs (avoid Docker errors)
# -----------------------
# os.environ["HF_HOME"] = "/tmp"
# os.environ["TRANSFORMERS_CACHE"] = "/tmp"
# os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
# os.makedirs("/tmp/torch_inductor_cache", exist_ok=True)
os.environ["TORCH_HOME"] = "/tmp/torch_home"
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
os.makedirs("/tmp/torch_home", exist_ok=True)
os.makedirs("/tmp/torch_inductor_cache", exist_ok=True)

# -----------------------
# Model Setup
# -----------------------
model_id = "LLM360/K2-Think"

print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="/tmp")

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True  # 8-bit quantization
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    cache_dir="/tmp"
)
print("Model loaded!")

# -----------------------
# FastAPI Setup
# -----------------------
app = FastAPI(title="K2-Think QA API", description="Serving K2-Think Hugging Face model with FastAPI")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve static assets; requires a "static/" directory alongside this file.
app.mount("/static", StaticFiles(directory="static"), name="static")

# -----------------------
# Request Schema
# -----------------------
class QueryRequest(BaseModel):
    question: str
    max_new_tokens: int = 50
    temperature: float = 0.7
    top_p: float = 0.9

# -----------------------
# Endpoints
# -----------------------
@app.get("/")
def home():
    return {"message": "Welcome to K2-Think QA API 🚀"}

@app.get("/health")
def health():
    return {"status": "ok"}

@app.get("/ask")
def ask(question: str = Query(...), max_new_tokens: int = Query(50)):
    try:
        # Move inputs to the model's device (needed with device_map="auto").
        inputs = tokenizer(question, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True
        )
        answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
        return {"question": question, "answer": answer}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/predict")
def predict(request: QueryRequest):
    try:
        # Move inputs to the model's device (needed with device_map="auto").
        inputs = tokenizer(request.question, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=request.max_new_tokens,
            do_sample=True,
            temperature=request.temperature,
            top_p=request.top_p,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True
        )
        answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
        return {"question": request.question, "answer": answer}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
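
# -----------------------
# Running & querying the API (sketch)
# -----------------------
# A minimal local entrypoint, assuming this file is saved as app.py.
# Port 7860 is an assumption (the Hugging Face Spaces default); Docker/Spaces
# deployments typically launch uvicorn directly instead.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example requests once the server is up (hostname/port are assumptions):
#   curl "http://localhost:7860/ask?question=What+is+K2-Think%3F&max_new_tokens=64"
#   curl -X POST "http://localhost:7860/predict" \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is K2-Think?", "max_new_tokens": 64, "temperature": 0.7, "top_p": 0.9}'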