Spaces:
Running
Running
| import asyncio | |
| import json | |
| import time | |
| import uuid | |
| import base64 | |
| import os | |
| import tempfile | |
| import threading | |
| import dotenv | |
| import logging | |
| from typing import List, Dict, Any, Optional, Union | |
| from fastapi import FastAPI, HTTPException, Depends, Request, BackgroundTasks, File, UploadFile, Form | |
| from fastapi.responses import StreamingResponse, JSONResponse, FileResponse, HTMLResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel, Field | |
| import uvicorn | |
| import re | |
| from .key_manager import KeyManager | |
| from .sora_integration import SoraClient | |
| from .config import Config | |
| from .utils import localize_image_urls # 导入新增的图片本地化功能 | |
| # 日志系统配置 | |
| class LogConfig: | |
| LEVEL = os.getenv("LOG_LEVEL", "WARNING").upper() | |
| FORMAT = "%(asctime)s [%(levelname)s] %(message)s" | |
| # 初始化日志 | |
| logging.basicConfig( | |
| level=getattr(logging, LogConfig.LEVEL), | |
| format=LogConfig.FORMAT, | |
| datefmt="%Y-%m-%d %H:%M:%S" | |
| ) | |
| logger = logging.getLogger("sora-api") | |
| # 打印日志级别信息 | |
| logger.info(f"日志级别设置为: {LogConfig.LEVEL}") | |
| logger.info(f"要调整日志级别,请设置环境变量 LOG_LEVEL=DEBUG|INFO|WARNING|ERROR") | |
| # 创建FastAPI应用 | |
| app = FastAPI(title="OpenAI Compatible Sora API") | |
| # 添加CORS支持 | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # 确保静态文件目录存在 | |
| os.makedirs(os.path.join(Config.STATIC_DIR, "admin"), exist_ok=True) | |
| os.makedirs(os.path.join(Config.STATIC_DIR, "admin/js"), exist_ok=True) | |
| os.makedirs(os.path.join(Config.STATIC_DIR, "admin/css"), exist_ok=True) | |
| os.makedirs(os.path.join(Config.STATIC_DIR, "images"), exist_ok=True) # 确保图片目录存在 | |
| # 打印配置信息 | |
| Config.print_config() | |
| # 挂载静态文件目录 | |
| app.mount("/static", StaticFiles(directory=Config.STATIC_DIR), name="static") | |
| # 初始化Key管理器 | |
| key_manager = KeyManager(storage_file=Config.KEYS_STORAGE_FILE) | |
| # 初始化时保存管理员密钥 | |
| Config.save_admin_key() | |
| # 创建Sora客户端池 | |
| sora_clients = {} | |
| # 存储生成结果的全局字典 | |
| generation_results = {} | |
| # 请求模型 | |
| class ContentItem(BaseModel): | |
| type: str | |
| text: Optional[str] = None | |
| image_url: Optional[Dict[str, str]] = None | |
| class ChatMessage(BaseModel): | |
| role: str | |
| content: Union[str, List[ContentItem]] | |
| class ChatCompletionRequest(BaseModel): | |
| model: str | |
| messages: List[ChatMessage] | |
| temperature: Optional[float] = 1.0 | |
| top_p: Optional[float] = 1.0 | |
| n: Optional[int] = 1 | |
| stream: Optional[bool] = False | |
| max_tokens: Optional[int] = None | |
| presence_penalty: Optional[float] = 0 | |
| frequency_penalty: Optional[float] = 0 | |
| # API密钥管理模型 | |
| class ApiKeyCreate(BaseModel): | |
| name: str = Field(..., description="密钥名称") | |
| key_value: str = Field(..., description="Bearer Token值") | |
| weight: int = Field(default=1, ge=1, le=10, description="权重值") | |
| rate_limit: int = Field(default=60, description="每分钟最大请求数") | |
| is_enabled: bool = Field(default=True, description="是否启用") | |
| notes: Optional[str] = Field(default=None, description="备注信息") | |
| class ApiKeyUpdate(BaseModel): | |
| name: Optional[str] = None | |
| key_value: Optional[str] = None | |
| weight: Optional[int] = None | |
| rate_limit: Optional[int] = None | |
| is_enabled: Optional[bool] = None | |
| notes: Optional[str] = None | |
| # 获取Sora客户端 | |
| def get_sora_client(auth_token: str) -> SoraClient: | |
| if auth_token not in sora_clients: | |
| proxy_host = Config.PROXY_HOST if Config.PROXY_HOST and Config.PROXY_HOST.strip() else None | |
| proxy_port = Config.PROXY_PORT if Config.PROXY_PORT and Config.PROXY_PORT.strip() else None | |
| sora_clients[auth_token] = SoraClient( | |
| proxy_host=proxy_host, | |
| proxy_port=proxy_port, | |
| auth_token=auth_token | |
| ) | |
| return sora_clients[auth_token] | |
| # 验证API key | |
| async def verify_api_key(request: Request): | |
| auth_header = request.headers.get("Authorization") | |
| if not auth_header or not auth_header.startswith("Bearer "): | |
| raise HTTPException(status_code=401, detail="缺少或无效的API key") | |
| api_key = auth_header.replace("Bearer ", "") | |
| # 在实际应用中,这里应该验证key的有效性 | |
| # 这里简化处理,假设所有key都有效 | |
| return api_key | |
| # 验证管理员权限 | |
| async def verify_admin(request: Request): | |
| auth_header = request.headers.get("Authorization") | |
| if not auth_header or not auth_header.startswith("Bearer "): | |
| raise HTTPException(status_code=401, detail="未授权") | |
| admin_key = auth_header.replace("Bearer ", "") | |
| # 这里应该检查是否为管理员密钥 | |
| # 简化处理,假设admin_key是预设的管理员密钥 | |
| if admin_key != Config.ADMIN_KEY: | |
| raise HTTPException(status_code=403, detail="没有管理员权限") | |
| return admin_key | |
| # 将处理中状态消息格式化为think代码块 | |
| def format_think_block(message): | |
| """将消息放入```think代码块中""" | |
| return f"```think\n{message}\n```" | |
| # 后台任务处理函数 - 文本生成图像 | |
| async def process_image_generation( | |
| request_id: str, | |
| sora_client: SoraClient, | |
| prompt: str, | |
| num_images: int = 1, | |
| width: int = 720, | |
| height: int = 480 | |
| ): | |
| try: | |
| # 更新状态为生成中 | |
| generation_results[request_id] = { | |
| "status": "processing", | |
| "message": format_think_block("正在生成图像中,请耐心等待..."), | |
| "timestamp": int(time.time()) | |
| } | |
| # 生成图像 | |
| logger.info(f"[{request_id}] 开始生成图像, 提示词: {prompt}") | |
| image_urls = await sora_client.generate_image( | |
| prompt=prompt, | |
| num_images=num_images, | |
| width=width, | |
| height=height | |
| ) | |
| # 验证生成结果 | |
| if isinstance(image_urls, str): | |
| logger.warning(f"[{request_id}] 图像生成失败或返回了错误信息: {image_urls}") | |
| generation_results[request_id] = { | |
| "status": "failed", | |
| "error": image_urls, | |
| "message": format_think_block(f"图像生成失败: {image_urls}"), | |
| "timestamp": int(time.time()) | |
| } | |
| return | |
| if not image_urls: | |
| logger.warning(f"[{request_id}] 图像生成返回了空列表") | |
| generation_results[request_id] = { | |
| "status": "failed", | |
| "error": "图像生成返回了空结果", | |
| "message": format_think_block("图像生成失败: 服务器返回了空结果"), | |
| "timestamp": int(time.time()) | |
| } | |
| return | |
| logger.info(f"[{request_id}] 成功生成 {len(image_urls)} 张图片") | |
| if logger.isEnabledFor(logging.DEBUG): | |
| for i, url in enumerate(image_urls): | |
| logger.debug(f"[{request_id}] 图片 {i+1}: {url}") | |
| # 本地化图片URL | |
| if Config.IMAGE_LOCALIZATION: | |
| logger.info(f"[{request_id}] 准备进行图片本地化处理") | |
| logger.debug(f"[{request_id}] 图片本地化配置: 启用={Config.IMAGE_LOCALIZATION}, 保存目录={Config.IMAGE_SAVE_DIR}") | |
| try: | |
| localized_urls = await localize_image_urls(image_urls) | |
| logger.info(f"[{request_id}] 图片本地化处理完成") | |
| # 检查本地化结果 | |
| if not localized_urls: | |
| logger.warning(f"[{request_id}] 本地化处理返回了空列表,将使用原始URL") | |
| localized_urls = image_urls | |
| # 检查是否所有URL都被正确本地化 | |
| local_count = sum(1 for url in localized_urls if url.startswith("/static/") or "/static/" in url) | |
| logger.info(f"[{request_id}] 本地化结果: 总计 {len(localized_urls)} 张图片,成功本地化 {local_count} 张") | |
| if local_count == 0: | |
| logger.warning(f"[{request_id}] 警告:没有一个URL被成功本地化,将使用原始URL") | |
| localized_urls = image_urls | |
| # 打印结果对比 | |
| if logger.isEnabledFor(logging.DEBUG): | |
| for i, (orig, local) in enumerate(zip(image_urls, localized_urls)): | |
| logger.debug(f"[{request_id}] 图片 {i+1} 本地化结果: {orig} -> {local}") | |
| image_urls = localized_urls | |
| except Exception as e: | |
| logger.error(f"[{request_id}] 图片本地化过程中发生错误: {str(e)}") | |
| if logger.isEnabledFor(logging.DEBUG): | |
| import traceback | |
| logger.debug(traceback.format_exc()) | |
| logger.info(f"[{request_id}] 由于错误,将使用原始URL") | |
| else: | |
| logger.info(f"[{request_id}] 图片本地化功能未启用,使用原始URL") | |
| # 存储结果 | |
| generation_results[request_id] = { | |
| "status": "completed", | |
| "image_urls": image_urls, | |
| "timestamp": int(time.time()) | |
| } | |
| # 30分钟后自动清理结果 | |
| threading.Timer(1800, lambda: generation_results.pop(request_id, None)).start() | |
| except Exception as e: | |
| error_message = f"图像生成失败: {str(e)}" | |
| generation_results[request_id] = { | |
| "status": "failed", | |
| "error": error_message, | |
| "message": format_think_block(error_message), | |
| "timestamp": int(time.time()) | |
| } | |
| logger.error(f"图像生成失败 (ID: {request_id}): {str(e)}") | |
| if logger.isEnabledFor(logging.DEBUG): | |
| import traceback | |
| logger.debug(traceback.format_exc()) | |
| # 后台任务处理函数 - 带图片的remix | |
| async def process_image_remix( | |
| request_id: str, | |
| sora_client: SoraClient, | |
| prompt: str, | |
| image_data: str, | |
| num_images: int = 1 | |
| ): | |
| try: | |
| # 更新状态为处理中 | |
| generation_results[request_id] = { | |
| "status": "processing", | |
| "message": format_think_block("正在处理上传的图片..."), | |
| "timestamp": int(time.time()) | |
| } | |
| # 保存base64图片到临时文件 | |
| temp_dir = tempfile.mkdtemp() | |
| temp_image_path = os.path.join(temp_dir, f"upload_{uuid.uuid4()}.png") | |
| try: | |
| # 解码并保存图片 | |
| with open(temp_image_path, "wb") as f: | |
| f.write(base64.b64decode(image_data)) | |
| # 更新状态为上传中 | |
| generation_results[request_id] = { | |
| "status": "processing", | |
| "message": format_think_block("正在上传图片到Sora服务..."), | |
| "timestamp": int(time.time()) | |
| } | |
| # 上传图片 | |
| upload_result = await sora_client.upload_image(temp_image_path) | |
| media_id = upload_result['id'] | |
| # 更新状态为生成中 | |
| generation_results[request_id] = { | |
| "status": "processing", | |
| "message": format_think_block("正在基于图片生成新图像..."), | |
| "timestamp": int(time.time()) | |
| } | |
| # 执行remix生成 | |
| logger.info(f"[{request_id}] 开始生成Remix图像, 提示词: {prompt}") | |
| image_urls = await sora_client.generate_image_remix( | |
| prompt=prompt, | |
| media_id=media_id, | |
| num_images=num_images | |
| ) | |
| # 本地化图片URL | |
| if Config.IMAGE_LOCALIZATION: | |
| logger.info(f"[{request_id}] 准备进行图片本地化处理") | |
| localized_urls = await localize_image_urls(image_urls) | |
| image_urls = localized_urls | |
| logger.info(f"[{request_id}] Remix图片本地化处理完成") | |
| # 存储结果 | |
| generation_results[request_id] = { | |
| "status": "completed", | |
| "image_urls": image_urls, | |
| "timestamp": int(time.time()) | |
| } | |
| # 30分钟后自动清理结果 | |
| threading.Timer(1800, lambda: generation_results.pop(request_id, None)).start() | |
| finally: | |
| # 清理临时文件 | |
| if os.path.exists(temp_image_path): | |
| os.remove(temp_image_path) | |
| if os.path.exists(temp_dir): | |
| os.rmdir(temp_dir) | |
| except Exception as e: | |
| error_message = f"图像Remix失败: {str(e)}" | |
| generation_results[request_id] = { | |
| "status": "failed", | |
| "error": error_message, | |
| "message": format_think_block(error_message), | |
| "timestamp": int(time.time()) | |
| } | |
| logger.error(f"图像Remix失败 (ID: {request_id}): {str(e)}") | |
| # 添加一个新端点用于检查生成状态 | |
| @app.get("/v1/generation/{request_id}") | |
| async def check_generation_status(request_id: str, api_key: str = Depends(verify_api_key)): | |
| """ | |
| 检查图像生成任务的状态 | |
| """ | |
| # 获取一个可用的key并记录开始时间 | |
| sora_auth_token = key_manager.get_key() | |
| if not sora_auth_token: | |
| raise HTTPException(status_code=429, detail="所有API key都已达到速率限制") | |
| start_time = time.time() | |
| success = False | |
| try: | |
| if request_id not in generation_results: | |
| raise HTTPException(status_code=404, detail=f"找不到生成任务: {request_id}") | |
| result = generation_results[request_id] | |
| if result["status"] == "completed": | |
| image_urls = result["image_urls"] | |
| # 构建OpenAI兼容的响应 | |
| response = { | |
| "id": request_id, | |
| "object": "chat.completion", | |
| "created": result["timestamp"], | |
| "model": "sora-1.0", | |
| "choices": [ | |
| { | |
| "index": i, | |
| "message": { | |
| "role": "assistant", | |
| "content": f"" | |
| }, | |
| "finish_reason": "stop" | |
| } | |
| for i, url in enumerate(image_urls) | |
| ], | |
| "usage": { | |
| "prompt_tokens": 0, # 简化的令牌计算 | |
| "completion_tokens": 20, | |
| "total_tokens": 20 | |
| } | |
| } | |
| success = True | |
| return JSONResponse(content=response) | |
| elif result["status"] == "failed": | |
| if "message" in result: | |
| # 返回带有格式化错误消息的响应 | |
| response = { | |
| "id": request_id, | |
| "object": "chat.completion", | |
| "created": result["timestamp"], | |
| "model": "sora-1.0", | |
| "choices": [ | |
| { | |
| "index": 0, | |
| "message": { | |
| "role": "assistant", | |
| "content": result["message"] | |
| }, | |
| "finish_reason": "error" | |
| } | |
| ], | |
| "usage": { | |
| "prompt_tokens": 0, | |
| "completion_tokens": 10, | |
| "total_tokens": 10 | |
| } | |
| } | |
| success = False | |
| return JSONResponse(content=response) | |
| else: | |
| # 向后兼容,使用老的方式 | |
| raise HTTPException(status_code=500, detail=f"生成失败: {result['error']}") | |
| else: # 处理中 | |
| message = result.get("message", "```think\n正在生成图像,请稍候...\n```") | |
| response = { | |
| "id": request_id, | |
| "object": "chat.completion", | |
| "created": result["timestamp"], | |
| "model": "sora-1.0", | |
| "choices": [ | |
| { | |
| "index": 0, | |
| "message": { | |
| "role": "assistant", | |
| "content": message | |
| }, | |
| "finish_reason": "processing" | |
| } | |
| ], | |
| "usage": { | |
| "prompt_tokens": 0, | |
| "completion_tokens": 10, | |
| "total_tokens": 10 | |
| } | |
| } | |
| success = True | |
| return JSONResponse(content=response) | |
| except Exception as e: | |
| success = False | |
| raise HTTPException(status_code=500, detail=f"检查任务状态失败: {str(e)}") | |
| finally: | |
| # 记录请求结果 | |
| response_time = time.time() - start_time | |
| key_manager.record_request_result(sora_auth_token, success, response_time) | |
| # 聊天完成端点 | |
| @app.post("/v1/chat/completions") | |
| async def chat_completions( | |
| request: ChatCompletionRequest, | |
| api_key: str = Depends(verify_api_key), | |
| background_tasks: BackgroundTasks = None | |
| ): | |
| # 获取一个可用的key | |
| sora_auth_token = key_manager.get_key() | |
| if not sora_auth_token: | |
| raise HTTPException(status_code=429, detail="所有API key都已达到速率限制") | |
| # 获取Sora客户端 | |
| sora_client = get_sora_client(sora_auth_token) | |
| # 分析最后一条用户消息以提取内容 | |
| user_messages = [m for m in request.messages if m.role == "user"] | |
| if not user_messages: | |
| raise HTTPException(status_code=400, detail="至少需要一条用户消息") | |
| last_user_message = user_messages[-1] | |
| prompt = "" | |
| image_data = None | |
| # 提取提示词和图片数据 | |
| if isinstance(last_user_message.content, str): | |
| # 简单的字符串内容 | |
| prompt = last_user_message.content | |
| # 检查是否包含内嵌的base64图片 | |
| pattern = r'data:image\/[^;]+;base64,([^"]+)' | |
| match = re.search(pattern, prompt) | |
| if match: | |
| image_data = match.group(1) | |
| # 从提示词中删除base64数据,以保持提示词的可读性 | |
| prompt = re.sub(pattern, "[已上传图片]", prompt) | |
| else: | |
| # 多模态内容,提取文本和图片 | |
| content_items = last_user_message.content | |
| text_parts = [] | |
| for item in content_items: | |
| if item.type == "text" and item.text: | |
| text_parts.append(item.text) | |
| elif item.type == "image_url" and item.image_url: | |
| # 如果有图片URL包含base64数据 | |
| url = item.image_url.get("url", "") | |
| if url.startswith("data:image/"): | |
| pattern = r'data:image\/[^;]+;base64,([^"]+)' | |
| match = re.search(pattern, url) | |
| if match: | |
| image_data = match.group(1) | |
| text_parts.append("[已上传图片]") | |
| prompt = " ".join(text_parts) | |
| # 记录开始时间 | |
| start_time = time.time() | |
| success = False | |
| # 处理图片生成 | |
| try: | |
| # 检查是否为流式响应 | |
| if request.stream: | |
| # 流式响应特殊处理文本+图片的情况 | |
| if image_data: | |
| response = StreamingResponse( | |
| generate_streaming_remix_response(sora_client, prompt, image_data, request.n), | |
| media_type="text/event-stream" | |
| ) | |
| else: | |
| response = StreamingResponse( | |
| generate_streaming_response(sora_client, prompt, request.n), | |
| media_type="text/event-stream" | |
| ) | |
| success = True | |
| # 记录请求结果(流式响应立即记录) | |
| response_time = time.time() - start_time | |
| key_manager.record_request_result(sora_auth_token, success, response_time) | |
| return response | |
| else: | |
| # 对于非流式响应,返回一个即时响应,表示任务已接收 | |
| # 创建一个唯一ID | |
| request_id = f"chatcmpl-{uuid.uuid4().hex}" | |
| # 在结果字典中创建初始状态 | |
| processing_message = "正在准备生成任务,请稍候..." | |
| generation_results[request_id] = { | |
| "status": "processing", | |
| "message": format_think_block(processing_message), | |
| "timestamp": int(time.time()) | |
| } | |
| # 添加后台任务 | |
| if image_data: | |
| background_tasks.add_task( | |
| process_image_remix, | |
| request_id, | |
| sora_client, | |
| prompt, | |
| image_data, | |
| request.n | |
| ) | |
| else: | |
| background_tasks.add_task( | |
| process_image_generation, | |
| request_id, | |
| sora_client, | |
| prompt, | |
| request.n, | |
| 720, # width | |
| 480 # height | |
| ) | |
| # 立即返回一个"正在处理中"的响应 | |
| response = { | |
| "id": request_id, | |
| "object": "chat.completion", | |
| "created": int(time.time()), | |
| "model": "sora-1.0", | |
| "choices": [ | |
| { | |
| "index": 0, | |
| "message": { | |
| "role": "assistant", | |
| "content": format_think_block(processing_message) | |
| }, | |
| "finish_reason": "processing" | |
| } | |
| ], | |
| "usage": { | |
| "prompt_tokens": len(prompt) // 4, | |
| "completion_tokens": 10, | |
| "total_tokens": len(prompt) // 4 + 10 | |
| } | |
| } | |
| success = True | |
| # 记录请求结果(非流式响应立即记录) | |
| response_time = time.time() - start_time | |
| key_manager.record_request_result(sora_auth_token, success, response_time) | |
| return JSONResponse(content=response) | |
| except Exception as e: | |
| success = False | |
| # 记录请求结果(异常情况也记录) | |
| response_time = time.time() - start_time | |
| key_manager.record_request_result(sora_auth_token, success, response_time) | |
| raise HTTPException(status_code=500, detail=f"图像生成失败: {str(e)}") | |
| # 流式响应生成器 - 普通文本到图像 | |
| async def generate_streaming_response( | |
| sora_client: SoraClient, | |
| prompt: str, | |
| n_images: int = 1 | |
| ): | |
| request_id = f"chatcmpl-{uuid.uuid4().hex}" | |
| # 发送开始事件 | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n" | |
| # 发送处理中的消息(放在代码块中) | |
| start_msg = "```think\n正在生成图像,请稍候...\n" | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': start_msg}, 'finish_reason': None}]})}\n\n" | |
| # 创建一个后台任务来生成图像 | |
| logger.info(f"[流式响应 {request_id}] 开始生成图像, 提示词: {prompt}") | |
| generation_task = asyncio.create_task(sora_client.generate_image( | |
| prompt=prompt, | |
| num_images=n_images, | |
| width=720, | |
| height=480 | |
| )) | |
| # 每5秒发送一条"仍在生成中"的消息,防止连接超时 | |
| progress_messages = [ | |
| "正在处理您的请求...", | |
| "仍在生成图像中,请继续等待...", | |
| "Sora正在创作您的图像作品...", | |
| "图像生成需要一点时间,感谢您的耐心等待...", | |
| "我们正在努力为您创作高质量图像..." | |
| ] | |
| i = 0 | |
| while not generation_task.done(): | |
| # 每5秒发送一次进度消息 | |
| await asyncio.sleep(5) | |
| progress_msg = progress_messages[i % len(progress_messages)] | |
| i += 1 | |
| content = "\n" + progress_msg + "\n" | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content}, 'finish_reason': None}]})}\n\n" | |
| try: | |
| # 获取生成结果 | |
| image_urls = await generation_task | |
| logger.info(f"[流式响应 {request_id}] 图像生成完成,获取到 {len(image_urls) if isinstance(image_urls, list) else '非列表'} 个URL") | |
| # 本地化图片URL | |
| if Config.IMAGE_LOCALIZATION and isinstance(image_urls, list) and image_urls: | |
| logger.info(f"[流式响应 {request_id}] 准备进行图片本地化处理") | |
| try: | |
| localized_urls = await localize_image_urls(image_urls) | |
| logger.info(f"[流式响应 {request_id}] 图片本地化处理完成") | |
| # 检查本地化结果 | |
| if not localized_urls: | |
| logger.warning(f"[流式响应 {request_id}] 本地化处理返回了空列表,将使用原始URL") | |
| localized_urls = image_urls | |
| # 检查是否所有URL都被正确本地化 | |
| local_count = sum(1 for url in localized_urls if url.startswith("/static/") or "/static/" in url) | |
| if local_count == 0: | |
| logger.warning(f"[流式响应 {request_id}] 警告:没有一个URL被成功本地化,将使用原始URL") | |
| localized_urls = image_urls | |
| else: | |
| logger.info(f"[流式响应 {request_id}] 成功本地化 {local_count}/{len(localized_urls)} 张图片") | |
| # 打印本地化对比结果 | |
| if logger.isEnabledFor(logging.DEBUG): | |
| for i, (orig, local) in enumerate(zip(image_urls, localized_urls)): | |
| logger.debug(f"[流式响应 {request_id}] 图片 {i+1}: {orig} -> {local}") | |
| image_urls = localized_urls | |
| except Exception as e: | |
| logger.error(f"[流式响应 {request_id}] 图片本地化过程中发生错误: {str(e)}") | |
| if logger.isEnabledFor(logging.DEBUG): | |
| import traceback | |
| logger.debug(traceback.format_exc()) | |
| logger.info(f"[流式响应 {request_id}] 由于错误,将使用原始URL") | |
| elif not Config.IMAGE_LOCALIZATION: | |
| logger.info(f"[流式响应 {request_id}] 图片本地化功能未启用") | |
| elif not isinstance(image_urls, list) or not image_urls: | |
| logger.warning(f"[流式响应 {request_id}] 无法进行本地化: 图像结果不是有效的URL列表") | |
| # 结束代码块 | |
| content_str = "\n```\n\n" | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content_str}, 'finish_reason': None}]})}\n\n" | |
| # 添加生成的图片URLs | |
| for i, url in enumerate(image_urls): | |
| if i > 0: | |
| content_str = "\n\n" | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content_str}, 'finish_reason': None}]})}\n\n" | |
| image_markdown = f"" | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': image_markdown}, 'finish_reason': None}]})}\n\n" | |
| # 发送完成事件 | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n" | |
| # 发送结束标志 | |
| yield "data: [DONE]\n\n" | |
| except Exception as e: | |
| error_msg = f"图像生成失败: {str(e)}" | |
| logger.error(f"[流式响应 {request_id}] 错误: {error_msg}") | |
| if logger.isEnabledFor(logging.DEBUG): | |
| import traceback | |
| logger.debug(traceback.format_exc()) | |
| error_content = f"\n{error_msg}\n```" | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': error_content}, 'finish_reason': 'error'}]})}\n\n" | |
| yield "data: [DONE]\n\n" | |
| # 流式响应生成器 - 带图片的remix | |
| async def generate_streaming_remix_response( | |
| sora_client: SoraClient, | |
| prompt: str, | |
| image_data: str, | |
| n_images: int = 1 | |
| ): | |
| request_id = f"chatcmpl-{uuid.uuid4().hex}" | |
| # 发送开始事件 | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n" | |
| try: | |
| # 保存base64图片到临时文件 | |
| temp_dir = tempfile.mkdtemp() | |
| temp_image_path = os.path.join(temp_dir, f"upload_{uuid.uuid4()}.png") | |
| try: | |
| # 解码并保存图片 | |
| with open(temp_image_path, "wb") as f: | |
| f.write(base64.b64decode(image_data)) | |
| # 上传图片 | |
| upload_msg = "```think\n上传图片中...\n" | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': upload_msg}, 'finish_reason': None}]})}\n\n" | |
| logger.info(f"[流式响应Remix {request_id}] 上传图片中") | |
| upload_result = await sora_client.upload_image(temp_image_path) | |
| media_id = upload_result['id'] | |
| # 发送生成中消息 | |
| generate_msg = "\n基于图片生成新图像中...\n" | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': generate_msg}, 'finish_reason': None}]})}\n\n" | |
| # 创建一个后台任务来生成图像 | |
| logger.info(f"[流式响应Remix {request_id}] 开始生成图像,提示词: {prompt}") | |
| generation_task = asyncio.create_task(sora_client.generate_image_remix( | |
| prompt=prompt, | |
| media_id=media_id, | |
| num_images=n_images | |
| )) | |
| # 每5秒发送一条"仍在生成中"的消息,防止连接超时 | |
| progress_messages = [ | |
| "正在处理您的请求...", | |
| "仍在生成图像中,请继续等待...", | |
| "Sora正在基于您的图片创作新作品...", | |
| "图像生成需要一点时间,感谢您的耐心等待...", | |
| "正在努力融合您的风格和提示词,打造专属图像..." | |
| ] | |
| i = 0 | |
| while not generation_task.done(): | |
| # 每5秒发送一次进度消息 | |
| await asyncio.sleep(5) | |
| progress_msg = progress_messages[i % len(progress_messages)] | |
| i += 1 | |
| content = "\n" + progress_msg + "\n" | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content}, 'finish_reason': None}]})}\n\n" | |
| # 获取生成结果 | |
| image_urls = await generation_task | |
| logger.info(f"[流式响应Remix {request_id}] 图像生成完成") | |
| # 本地化图片URL | |
| if Config.IMAGE_LOCALIZATION: | |
| logger.info(f"[流式响应Remix {request_id}] 进行图片本地化处理") | |
| localized_urls = await localize_image_urls(image_urls) | |
| image_urls = localized_urls | |
| logger.info(f"[流式响应Remix {request_id}] 图片本地化处理完成") | |
| else: | |
| logger.info(f"[流式响应Remix {request_id}] 图片本地化功能未启用") | |
| # 结束代码块 | |
| content_str = "\n```\n\n" | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': content_str}, 'finish_reason': None}]})}\n\n" | |
| # 发送图片URL作为Markdown | |
| for i, url in enumerate(image_urls): | |
| if i > 0: | |
| newline_str = "\n\n" | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': newline_str}, 'finish_reason': None}]})}\n\n" | |
| image_markdown = f"" | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': image_markdown}, 'finish_reason': None}]})}\n\n" | |
| # 发送完成事件 | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n" | |
| # 发送结束标志 | |
| yield "data: [DONE]\n\n" | |
| finally: | |
| # 清理临时文件 | |
| if os.path.exists(temp_image_path): | |
| os.remove(temp_image_path) | |
| if os.path.exists(temp_dir): | |
| os.rmdir(temp_dir) | |
| except Exception as e: | |
| error_msg = f"图像Remix失败: {str(e)}" | |
| logger.error(f"[流式响应Remix {request_id}] 错误: {error_msg}") | |
| if logger.isEnabledFor(logging.DEBUG): | |
| import traceback | |
| logger.debug(traceback.format_exc()) | |
| error_content = f"\n{error_msg}\n```" | |
| yield f"data: {json.dumps({'id': request_id, 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': 'sora-1.0', 'choices': [{'index': 0, 'delta': {'content': error_content}, 'finish_reason': 'error'}]})}\n\n" | |
| # 结束流 | |
| yield "data: [DONE]\n\n" | |
| # API密钥管理端点 | |
| @app.get("/api/keys") | |
| async def get_all_keys(admin_key: str = Depends(verify_admin)): | |
| """获取所有API密钥""" | |
| return key_manager.get_all_keys() | |
| @app.get("/api/keys/{key_id}") | |
| async def get_key(key_id: str, admin_key: str = Depends(verify_admin)): | |
| """获取单个API密钥详情""" | |
| key = key_manager.get_key_by_id(key_id) | |
| if not key: | |
| raise HTTPException(status_code=404, detail="密钥不存在") | |
| return key | |
| @app.post("/api/keys") | |
| async def create_key(key_data: ApiKeyCreate, admin_key: str = Depends(verify_admin)): | |
| """创建新API密钥""" | |
| try: | |
| # 确保密钥值包含 Bearer 前缀 | |
| key_value = key_data.key_value | |
| if not key_value.startswith("Bearer "): | |
| key_value = f"Bearer {key_value}" | |
| new_key = key_manager.add_key( | |
| key_value, | |
| name=key_data.name, | |
| weight=key_data.weight, | |
| rate_limit=key_data.rate_limit, | |
| is_enabled=key_data.is_enabled, | |
| notes=key_data.notes | |
| ) | |
| # 通过Config永久保存所有密钥 | |
| Config.save_api_keys(key_manager.keys) | |
| return new_key | |
| except Exception as e: | |
| raise HTTPException(status_code=400, detail=str(e)) | |
| @app.put("/api/keys/{key_id}") | |
| async def update_key(key_id: str, key_data: ApiKeyUpdate, admin_key: str = Depends(verify_admin)): | |
| """更新API密钥信息""" | |
| try: | |
| # 如果提供了新的密钥值,确保包含Bearer前缀 | |
| key_value = key_data.key_value | |
| if key_value and not key_value.startswith("Bearer "): | |
| key_value = f"Bearer {key_value}" | |
| key_data.key_value = key_value | |
| updated_key = key_manager.update_key( | |
| key_id, | |
| key_value=key_data.key_value, | |
| name=key_data.name, | |
| weight=key_data.weight, | |
| rate_limit=key_data.rate_limit, | |
| is_enabled=key_data.is_enabled, | |
| notes=key_data.notes | |
| ) | |
| if not updated_key: | |
| raise HTTPException(status_code=404, detail="密钥不存在") | |
| # 通过Config永久保存所有密钥 | |
| Config.save_api_keys(key_manager.keys) | |
| return updated_key | |
| except Exception as e: | |
| raise HTTPException(status_code=400, detail=str(e)) | |
| @app.delete("/api/keys/{key_id}") | |
| async def delete_key(key_id: str, admin_key: str = Depends(verify_admin)): | |
| """删除API密钥""" | |
| success = key_manager.delete_key(key_id) | |
| if not success: | |
| raise HTTPException(status_code=404, detail="密钥不存在") | |
| # 通过Config永久保存所有密钥 | |
| Config.save_api_keys(key_manager.keys) | |
| return {"status": "success", "message": "密钥已删除"} | |
| @app.get("/api/stats") | |
| async def get_usage_stats(admin_key: str = Depends(verify_admin)): | |
| """获取API使用统计""" | |
| return key_manager.get_usage_stats() | |
| @app.post("/api/keys/test") | |
| async def test_key(key_data: ApiKeyCreate, admin_key: str = Depends(verify_admin)): | |
| """测试API密钥是否有效""" | |
| try: | |
| # 确保密钥值包含 Bearer 前缀 | |
| key_value = key_data.key_value | |
| if not key_value.startswith("Bearer "): | |
| key_value = f"Bearer {key_value}" | |
| # 获取代理配置 | |
| proxy_host = Config.PROXY_HOST if Config.PROXY_HOST and Config.PROXY_HOST.strip() else None | |
| proxy_port = Config.PROXY_PORT if Config.PROXY_PORT and Config.PROXY_PORT.strip() else None | |
| # 创建临时客户端测试连接 | |
| test_client = SoraClient( | |
| proxy_host=proxy_host, | |
| proxy_port=proxy_port, | |
| auth_token=key_value | |
| ) | |
| # 执行简单API调用测试连接 | |
| test_result = await test_client.test_connection() | |
| return {"status": "success", "message": "API密钥测试成功", "details": test_result} | |
| except Exception as e: | |
| return {"status": "error", "message": f"API密钥测试失败: {str(e)}"} | |
| @app.post("/api/keys/batch") | |
| async def batch_operation(operation: Dict[str, Any], admin_key: str = Depends(verify_admin)): | |
| """批量操作API密钥""" | |
| action = operation.get("action") | |
| key_ids = operation.get("key_ids", []) | |
| if not action or not key_ids: | |
| raise HTTPException(status_code=400, detail="无效的请求参数") | |
| # 确保key_ids是一个列表 | |
| if isinstance(key_ids, str): | |
| key_ids = [key_ids] | |
| results = {} | |
| if action == "enable": | |
| for key_id in key_ids: | |
| success = key_manager.update_key(key_id, is_enabled=True) | |
| results[key_id] = "success" if success else "failed" | |
| elif action == "disable": | |
| for key_id in key_ids: | |
| success = key_manager.update_key(key_id, is_enabled=False) | |
| results[key_id] = "success" if success else "failed" | |
| elif action == "delete": | |
| for key_id in key_ids: | |
| success = key_manager.delete_key(key_id) | |
| results[key_id] = "success" if success else "failed" | |
| else: | |
| raise HTTPException(status_code=400, detail="不支持的操作类型") | |
| # 通过Config永久保存所有密钥 | |
| Config.save_api_keys(key_manager.keys) | |
| return {"status": "success", "results": results} | |
| # 健康检查端点 | |
| @app.get("/health") | |
| async def health_check(): | |
| return {"status": "ok", "timestamp": time.time()} | |
| # 管理界面路由 | |
| @app.get("/admin") | |
| async def admin_panel(): | |
| return FileResponse(os.path.join(Config.STATIC_DIR, "admin/index.html")) | |
| # 管理员密钥API | |
| @app.get("/admin/key") | |
| async def admin_key(): | |
| return {"admin_key": Config.ADMIN_KEY} | |
| # 挂载静态文件 | |
| app.mount("/admin", StaticFiles(directory=os.path.join(Config.STATIC_DIR, "admin"), html=True), name="admin") | |
| # 配置管理模型 | |
| class ConfigUpdate(BaseModel): | |
| IMAGE_LOCALIZATION: Optional[bool] = None | |
| IMAGE_SAVE_DIR: Optional[str] = None | |
| LOG_LEVEL: Optional[str] = None | |
| # 配置管理页面 | |
| @app.get("/admin/config") | |
| async def config_panel(): | |
| return FileResponse("src/static/admin/config.html") | |
| # 获取当前配置 | |
| @app.get("/api/config") | |
| async def get_config(admin_key: str = Depends(verify_admin)): | |
| """获取当前系统配置""" | |
| return { | |
| "IMAGE_LOCALIZATION": Config.IMAGE_LOCALIZATION, | |
| "IMAGE_SAVE_DIR": Config.IMAGE_SAVE_DIR, | |
| "LOG_LEVEL": LogConfig.LEVEL | |
| } | |
| # 更新配置 | |
| @app.post("/api/config") | |
| async def update_config(config_data: ConfigUpdate, admin_key: str = Depends(verify_admin)): | |
| """更新系统配置""" | |
| changes = {} | |
| if config_data.IMAGE_LOCALIZATION is not None: | |
| old_value = Config.IMAGE_LOCALIZATION | |
| Config.IMAGE_LOCALIZATION = config_data.IMAGE_LOCALIZATION | |
| os.environ["IMAGE_LOCALIZATION"] = str(config_data.IMAGE_LOCALIZATION) | |
| changes["IMAGE_LOCALIZATION"] = { | |
| "old": old_value, | |
| "new": Config.IMAGE_LOCALIZATION | |
| } | |
| if config_data.IMAGE_SAVE_DIR is not None: | |
| old_value = Config.IMAGE_SAVE_DIR | |
| Config.IMAGE_SAVE_DIR = config_data.IMAGE_SAVE_DIR | |
| os.environ["IMAGE_SAVE_DIR"] = config_data.IMAGE_SAVE_DIR | |
| # 确保目录存在 | |
| os.makedirs(Config.IMAGE_SAVE_DIR, exist_ok=True) | |
| changes["IMAGE_SAVE_DIR"] = { | |
| "old": old_value, | |
| "new": Config.IMAGE_SAVE_DIR | |
| } | |
| if config_data.LOG_LEVEL is not None: | |
| old_value = LogConfig.LEVEL | |
| level = config_data.LOG_LEVEL.upper() | |
| # 验证日志级别是否有效 | |
| valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] | |
| if level not in valid_levels: | |
| raise HTTPException(status_code=400, detail=f"无效的日志级别,有效值:{', '.join(valid_levels)}") | |
| # 更新日志级别 | |
| LogConfig.LEVEL = level | |
| logging.getLogger("sora-api").setLevel(getattr(logging, level)) | |
| os.environ["LOG_LEVEL"] = level | |
| changes["LOG_LEVEL"] = { | |
| "old": old_value, | |
| "new": level | |
| } | |
| logger.info(f"日志级别已更改为: {level}") | |
| # 保存到.env文件以持久化配置 | |
| try: | |
| dotenv_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env") | |
| # 读取现有.env文件 | |
| env_vars = {} | |
| if os.path.exists(dotenv_file): | |
| with open(dotenv_file, "r") as f: | |
| for line in f: | |
| if line.strip() and not line.startswith("#"): | |
| key, value = line.strip().split("=", 1) | |
| env_vars[key] = value | |
| # 更新值 | |
| if config_data.IMAGE_LOCALIZATION is not None: | |
| env_vars["IMAGE_LOCALIZATION"] = str(config_data.IMAGE_LOCALIZATION) | |
| if config_data.IMAGE_SAVE_DIR is not None: | |
| env_vars["IMAGE_SAVE_DIR"] = config_data.IMAGE_SAVE_DIR | |
| if config_data.LOG_LEVEL is not None: | |
| env_vars["LOG_LEVEL"] = config_data.LOG_LEVEL.upper() | |
| # 写回文件 | |
| with open(dotenv_file, "w") as f: | |
| for key, value in env_vars.items(): | |
| f.write(f"{key}={value}\n") | |
| except Exception as e: | |
| logger.error(f"保存配置到.env文件失败: {str(e)}") | |
| return { | |
| "success": True, | |
| "message": "配置已更新", | |
| "changes": changes | |
| } | |
| # 日志级别控制 | |
| class LogLevelUpdate(BaseModel): | |
| level: str = Field(..., description="日志级别") | |
| @app.post("/api/logs/level") | |
| async def update_log_level(data: LogLevelUpdate, admin_key: str = Depends(verify_admin)): | |
| """更新系统日志级别""" | |
| level = data.level.upper() | |
| # 验证日志级别是否有效 | |
| valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] | |
| if level not in valid_levels: | |
| raise HTTPException(status_code=400, detail=f"无效的日志级别,有效值:{', '.join(valid_levels)}") | |
| # 更新日志级别 | |
| old_level = LogConfig.LEVEL | |
| LogConfig.LEVEL = level | |
| logging.getLogger("sora-api").setLevel(getattr(logging, level)) | |
| os.environ["LOG_LEVEL"] = level | |
| # 记录变更 | |
| logger.info(f"日志级别已更改为: {level}") | |
| # 更新.env文件 | |
| try: | |
| dotenv_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env") | |
| # 读取现有.env文件 | |
| env_vars = {} | |
| if os.path.exists(dotenv_file): | |
| with open(dotenv_file, "r") as f: | |
| for line in f: | |
| if line.strip() and not line.startswith("#"): | |
| key, value = line.strip().split("=", 1) | |
| env_vars[key] = value | |
| # 更新日志级别 | |
| env_vars["LOG_LEVEL"] = level | |
| # 写回文件 | |
| with open(dotenv_file, "w") as f: | |
| for key, value in env_vars.items(): | |
| f.write(f"{key}={value}\n") | |
| except Exception as e: | |
| logger.warning(f"保存日志级别到.env文件失败: {str(e)}") | |
| return { | |
| "success": True, | |
| "message": f"日志级别已更改: {old_level} -> {level}" | |
| } |