# coding=utf-8
# Implements API for ChatGLM2-6B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py
# Visit http://localhost:8000/docs for documents.

import time
import torch
import uvicorn
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from starlette.responses import StreamingResponse
from typing import Any, Dict, List, Literal, Optional, Union
from transformers import AutoTokenizer, AutoModel


@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []
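

# The models below mirror the JSON shapes of OpenAI's chat completion API:
# ChatMessage carries a full message, DeltaMessage carries the incremental text of a
# streaming chunk, and the request/response wrappers match what OpenAI-style clients expect.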
class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system"]
    content: str


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    max_length: Optional[int] = None
    stream: Optional[bool] = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]]


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
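

# The two routes below emulate OpenAI's /v1/models and /v1/chat/completions endpoints so
# that existing OpenAI-compatible clients can talk to the local ChatGLM2-6B model; the
# model id "gpt-3.5-turbo" is returned only so such clients accept the server.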
@app.get("/v1/models", response_model=ModelList)
async def list_models():
    global model_args
    model_card = ModelCard(id="gpt-3.5-turbo")
    return ModelList(data=[model_card])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    if request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")
    query = request.messages[-1].content

    # A leading system message is prepended to the query; the remaining messages are
    # paired into (user, assistant) turns, the history format model.chat() expects.
    prev_messages = request.messages[:-1]
    if len(prev_messages) > 0 and prev_messages[0].role == "system":
        query = prev_messages.pop(0).content + query

    history = []
    if len(prev_messages) % 2 == 0:
        for i in range(0, len(prev_messages), 2):
            if prev_messages[i].role == "user" and prev_messages[i + 1].role == "assistant":
                history.append([prev_messages[i].content, prev_messages[i + 1].content])

    if request.stream:
        generate = predict(query, history, request.model)
        return StreamingResponse(generate, media_type="text/event-stream")

    response, _ = model.chat(tokenizer, query, history=history)
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=response),
        finish_reason="stop"
    )

    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
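

# predict() streams the reply as server-sent events in OpenAI's chunk format: a first
# chunk that only sets the assistant role, one chunk per new text fragment produced by
# model.stream_chat(), and a final chunk with an empty delta and finish_reason="stop".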
async def predict(query: str, history: List[List[str]], model_id: str):
    global model, tokenizer

    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    current_length = 0
    for new_response, _ in model.stream_chat(tokenizer, query, history):
        if len(new_response) == current_length:
            continue

        new_text = new_response[current_length:]
        current_length = len(new_response)

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(content=new_text),
            finish_reason=None
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
    # Multi-GPU support: use the following two lines instead of the line above,
    # and change num_gpus to the number of GPUs you actually have.
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
    model.eval()

    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
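

A minimal client sketch for trying the server above (not part of the original script). It assumes the API is reachable on localhost:8000 as configured above, uses the requests library, and sends a non-streaming request whose shape follows the ChatCompletionRequest model; run it from a separate shell once the server is up:

import requests

# Hypothetical smoke test against the /v1/chat/completions route defined above.
payload = {
    "model": "gpt-3.5-turbo",  # matches the id returned by /v1/models
    "messages": [{"role": "user", "content": "Hello, who are you?"}],
    "stream": False,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])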