from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel

import json
import logging
import logging.config

from core.config import API_HOST, API_PORT, CORS_SETTINGS, LOG_CONFIG
from core.exceptions import APIError, handle_api_error
from core.text_generation import text_generator

# Configure logging
logging.config.dictConfig(LOG_CONFIG)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="AI Text Generation API",
    description="API for text generation using multiple AI providers",
    version="1.0.0",
)

# Enable CORS with specific headers for SSE
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update this in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["Content-Type", "Cache-Control"],
)


class PromptRequest(BaseModel):
    """Schema for the generation request body."""

    model: str
    prompt: str
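
# Illustrative POST body for the generation endpoint (the model id below is a
# placeholder; valid ids come from text_generator.get_available_models()):
#   {"model": "example-model", "prompt": "Hello"}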


@app.get("/")
async def read_root():
    """API root endpoint."""
    return {"status": "ok", "message": "API is running"}


@app.get("/models")  # Route path assumed for this handler
async def get_models():
    """Get the list of all available models."""
    try:
        # Return models as a JSON array
        return JSONResponse(content=text_generator.get_available_models())
    except APIError as e:
        error_response = handle_api_error(e)
        raise HTTPException(
            status_code=error_response["status_code"],
            detail=error_response["detail"],
        )
    except Exception as e:
        logger.error(f"Unexpected error in get_models: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")


async def generate_stream(model: str, prompt: str):
    """Stream generator for text generation."""
    try:
        async for chunk in text_generator.generate_stream(model, prompt):
            # Extra newline ensures proper SSE event separation
            yield f"data: {json.dumps({'content': chunk})}\n\n"
    except APIError as e:
        error_response = handle_api_error(e)
        yield f"data: {json.dumps({'error': error_response['detail']})}\n\n"
    except Exception as e:
        logger.error(f"Unexpected error in generate_stream: {e}")
        yield f"data: {json.dumps({'error': 'Internal server error'})}\n\n"
    finally:
        yield "data: [DONE]\n\n"


@app.api_route("/generate", methods=["GET", "POST"])  # Route path assumed for this handler
async def generate_response(request: Request):
    """Generate a response using the selected model (supports both GET and POST)."""
    try:
        # Handle both GET and POST methods
        if request.method == "GET":
            params = dict(request.query_params)
            model = params.get("model")
            prompt = params.get("prompt")
        else:
            body = await request.json()
            model = body.get("model")
            prompt = body.get("prompt")

        if not model or not prompt:
            raise HTTPException(status_code=400, detail="Missing model or prompt parameter")

        logger.info(f"Received {request.method} request for model: {model}")

        headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable buffering for nginx
        }
        return StreamingResponse(
            generate_stream(model, prompt),
            media_type="text/event-stream",
            headers=headers,
        )
    except HTTPException:
        # Re-raise as-is so the 400 above is not swallowed by the generic handler below
        raise
    except APIError as e:
        error_response = handle_api_error(e)
        raise HTTPException(
            status_code=error_response["status_code"],
            detail=error_response["detail"],
        )
    except Exception as e:
        logger.error(f"Unexpected error in generate_response: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=API_HOST, port=API_PORT)
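
# Quick smoke test, assuming the server listens on localhost:8000 and the
# routes match the assumed paths above:
#
#   curl http://localhost:8000/models
#   curl -N "http://localhost:8000/generate?model=example-model&prompt=Hello"
#
# curl's -N flag disables output buffering so SSE events print as they arrive.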