from fastapi import FastAPI, Query, HTTPException
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
import os
import torch
# -----------------------
# Cache directories: point caches at writable /tmp paths
# (avoids permission errors inside the Docker/Spaces container)
# -----------------------
# os.environ["HF_HOME"] = "/tmp"
# os.environ["TRANSFORMERS_CACHE"] = "/tmp"
# os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
# os.makedirs("/tmp/torch_inductor_cache", exist_ok=True)
os.environ["TORCH_HOME"] = "/tmp/torch_home"
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
os.makedirs("/tmp/torch_home", exist_ok=True)
os.makedirs("/tmp/torch_inductor_cache", exist_ok=True)
# -----------------------
# Model Setup
# -----------------------
model_id = "LLM360/K2-Think"
print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="/tmp")
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True  # 8-bit quantization
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    cache_dir="/tmp"
)
print("Model loaded!")
# -----------------------
# FastAPI Setup
# -----------------------
app = FastAPI(title="K2-Think QA API", description="Serving the K2-Think Hugging Face model with FastAPI")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
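# StaticFiles verifies the directory exists at startup, so a missing "static"
# folder is one common cause of a Space "Runtime error"; creating it defensively
# is an addition here, not part of the original app.
os.makedirs("static", exist_ok=True)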
app.mount("/static", StaticFiles(directory="static"), name="static")
# -----------------------
# Request Schema
# -----------------------
class QueryRequest(BaseModel):
    question: str
    max_new_tokens: int = 50
    temperature: float = 0.7
    top_p: float = 0.9
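# Example /predict request body (fields map one-to-one onto QueryRequest;
# the question text below is illustrative only):
#   {"question": "What is 8-bit quantization?", "max_new_tokens": 64,
#    "temperature": 0.7, "top_p": 0.9}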
# -----------------------
# Endpoints
# -----------------------
@app.get("/")
def home():
return {"message": "Welcome to K2-Think QA API 🚀"}
@app.get("/health")
def health():
return {"status": "ok"}
@app.get("/ask")
def ask(question: str = Query(...), max_new_tokens: int = Query(50)):
try:
inputs = tokenizer(question, return_tensors="pt")
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=0.7,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id,
return_dict_in_generate=True
)
answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
return {"question": question, "answer": answer}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
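# Example call (assuming the server listens on localhost:7860, the Spaces default port):
#   curl "http://localhost:7860/ask?question=What%20is%20K2-Think%3F&max_new_tokens=64"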
@app.post("/predict")
def predict(request: QueryRequest):
try:
inputs = tokenizer(request.question, return_tensors="pt")
outputs = model.generate(
**inputs,
max_new_tokens=request.max_new_tokens,
do_sample=True,
temperature=request.temperature,
top_p=request.top_p,
pad_token_id=tokenizer.eos_token_id,
return_dict_in_generate=True
)
answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
return {"question": request.question, "answer": answer}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
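# Example call (same host/port assumption as above):
#   curl -X POST "http://localhost:7860/predict" \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is K2-Think?", "max_new_tokens": 64}'

# -----------------------
# Local entrypoint (a minimal sketch: the Space may use its own launcher instead,
# and port 7860 is assumed here as the Spaces default, not taken from the original)
# -----------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)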