# # To Run:
# # python -m dsp.modules.hf_server --port 4242 --model "google/flan-t5-base"
# # To Query:
# # curl -d '{"prompt":".."}' -X POST "http://0.0.0.0:4242" -H 'Content-Type: application/json'
# # Or use the HF client. TODO: Add support for kwargs to the server.
# from functools import lru_cache
# import argparse
# import time
# import random
# import os
# import sys
# import uvicorn
# import warnings
# from fastapi import FastAPI
# from pydantic import BaseModel
# from argparse import ArgumentParser
# from starlette.middleware.cors import CORSMiddleware
# from dsp.modules.hf import HFModel
# class Query(BaseModel):
#     prompt: str
#     kwargs: dict = {}
# warnings.filterwarnings("ignore")
# app = FastAPI()
# app.add_middleware(
#     CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
# )
# parser = argparse.ArgumentParser("Server for Hugging Face models")
# parser.add_argument("--port", type=int, required=True, help="Server port")
# parser.add_argument("--model", type=str, required=True, help="Hugging Face model")
# args = parser.parse_args()
# # TODO: Convert this to a log message
# print(f"#> Loading the language model {args.model}")
# lm = HFModel(args.model)
# @lru_cache(maxsize=None)
# def generate(prompt, **kwargs):
#     global lm
#     generateStart = time.time()
#     # TODO: Convert this to a log message
#     print(f'#> kwargs: "{kwargs}" (type={type(kwargs)})')
#     response = lm._generate(prompt, **kwargs)
#     # TODO: Convert this to a log message
#     print(f'#> Response: "{response}"')
#     latency = (time.time() - generateStart) * 1000.0
#     response["latency"] = latency
#     print(f'#> Latency:', '{:.3f}'.format(latency / 1000.0), 'seconds')
#     return response
# @app.post("/")
# async def generate_post(query: Query):
#     return generate(query.prompt, **query.kwargs)
# if __name__ == "__main__":
#     uvicorn.run(
#         app,
#         host="0.0.0.0",
#         port=args.port,
#         reload=False,
#         log_level="info",
#     )  # can make reload=True later