Spaces:
Running
Running
| import os | |
| import logging | |
| import dotenv | |
| from typing import Dict, Any, List | |
| from fastapi import APIRouter, Depends, HTTPException | |
| from fastapi.responses import FileResponse | |
| from ..models.schemas import ApiKeyCreate, ApiKeyUpdate, ConfigUpdate, LogLevelUpdate, BatchOperation, BatchImportOperation | |
| from ..api.dependencies import verify_admin, verify_admin_jwt | |
| from ..config import Config | |
| from ..key_manager import key_manager | |
| from ..sora_integration import SoraClient | |
| # 设置日志 | |
| logger = logging.getLogger("sora-api.admin") | |
| # 日志系统配置 | |
| class LogConfig: | |
| LEVEL = os.getenv("LOG_LEVEL", "WARNING").upper() | |
| FORMAT = "%(asctime)s [%(levelname)s] %(message)s" | |
| # 创建路由 | |
| router = APIRouter(prefix="/api") | |
| # 密钥管理API | |
| async def get_all_keys(admin_token = Depends(verify_admin_jwt)): | |
| """获取所有API密钥""" | |
| return key_manager.get_all_keys() | |
| async def get_key(key_id: str, admin_token = Depends(verify_admin_jwt)): | |
| """获取单个API密钥详情""" | |
| key = key_manager.get_key_by_id(key_id) | |
| if not key: | |
| raise HTTPException(status_code=404, detail="密钥不存在") | |
| return key | |
| async def create_key(key_data: ApiKeyCreate, admin_token = Depends(verify_admin_jwt)): | |
| """创建新API密钥""" | |
| try: | |
| # 确保密钥值包含 Bearer 前缀 | |
| key_value = key_data.key_value | |
| if not key_value.startswith("Bearer "): | |
| key_value = f"Bearer {key_value}" | |
| new_key = key_manager.add_key( | |
| key_value, | |
| name=key_data.name, | |
| weight=key_data.weight, | |
| rate_limit=key_data.rate_limit, | |
| is_enabled=key_data.is_enabled, | |
| notes=key_data.notes | |
| ) | |
| # 通过Config永久保存所有密钥 | |
| Config.save_api_keys(key_manager.keys) | |
| return new_key | |
| except Exception as e: | |
| logger.error(f"创建密钥失败: {str(e)}", exc_info=True) | |
| raise HTTPException(status_code=400, detail=str(e)) | |
| async def update_key(key_id: str, key_data: ApiKeyUpdate, admin_token = Depends(verify_admin_jwt)): | |
| """更新API密钥信息""" | |
| try: | |
| # 如果提供了新的密钥值,确保包含Bearer前缀 | |
| key_value = key_data.key_value | |
| if key_value and not key_value.startswith("Bearer "): | |
| key_value = f"Bearer {key_value}" | |
| key_data.key_value = key_value | |
| updated_key = key_manager.update_key( | |
| key_id, | |
| key_value=key_data.key_value, | |
| name=key_data.name, | |
| weight=key_data.weight, | |
| rate_limit=key_data.rate_limit, | |
| is_enabled=key_data.is_enabled, | |
| notes=key_data.notes | |
| ) | |
| if not updated_key: | |
| raise HTTPException(status_code=404, detail="密钥不存在") | |
| # 通过Config永久保存所有密钥 | |
| Config.save_api_keys(key_manager.keys) | |
| return updated_key | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"更新密钥失败: {str(e)}", exc_info=True) | |
| raise HTTPException(status_code=400, detail=str(e)) | |
| async def delete_key(key_id: str, admin_token = Depends(verify_admin_jwt)): | |
| """删除API密钥""" | |
| success = key_manager.delete_key(key_id) | |
| if not success: | |
| raise HTTPException(status_code=404, detail="密钥不存在") | |
| # 通过Config永久保存所有密钥 | |
| Config.save_api_keys(key_manager.keys) | |
| return {"status": "success", "message": "密钥已删除"} | |
| async def get_usage_stats(admin_token = Depends(verify_admin_jwt)): | |
| """获取API使用统计""" | |
| stats = key_manager.get_usage_stats() | |
| # 处理daily_usage数据,确保前端能够正确显示 | |
| daily_usage = {} | |
| keys_usage = {} | |
| # 从past_7_days数据转换为daily_usage格式 | |
| for date, counts in stats.get("past_7_days", {}).items(): | |
| daily_usage[date] = counts.get("successful", 0) + counts.get("failed", 0) | |
| # 获取每个密钥的使用情况 | |
| for key in key_manager.keys: | |
| key_id = key.get("id") | |
| key_name = key.get("name") or f"密钥_{key_id[:6]}" | |
| # 获取该密钥的使用统计 | |
| if key_id in key_manager.usage_stats: | |
| key_stats = key_manager.usage_stats[key_id] | |
| total_requests = key_stats.get("total_requests", 0) | |
| if total_requests > 0: | |
| keys_usage[key_name] = total_requests | |
| # 添加到返回数据中 | |
| stats["daily_usage"] = daily_usage | |
| stats["keys_usage"] = keys_usage | |
| return stats | |
| async def test_key(key_data: ApiKeyCreate, admin_token = Depends(verify_admin_jwt)): | |
| """测试API密钥是否有效""" | |
| try: | |
| # 获取密钥值 | |
| key_value = key_data.key_value.strip() | |
| # 确保密钥格式正确 | |
| if not key_value.startswith("Bearer "): | |
| key_value = f"Bearer {key_value}" | |
| # 获取代理配置 | |
| proxy_host = Config.PROXY_HOST if Config.PROXY_HOST and Config.PROXY_HOST.strip() else None | |
| proxy_port = Config.PROXY_PORT if Config.PROXY_PORT and Config.PROXY_PORT.strip() else None | |
| proxy_user = Config.PROXY_USER if Config.PROXY_USER and Config.PROXY_USER.strip() else None | |
| proxy_pass = Config.PROXY_PASS if Config.PROXY_PASS and Config.PROXY_PASS.strip() else None | |
| test_client = SoraClient( | |
| proxy_host=proxy_host, | |
| proxy_port=proxy_port, | |
| proxy_user=proxy_user, | |
| proxy_pass=proxy_pass, | |
| auth_token=key_value | |
| ) | |
| # 执行简单API调用测试连接 | |
| test_result = await test_client.test_connection() | |
| logger.info(f"API密钥测试结果: {test_result}") | |
| # 检查底层测试结果的状态 | |
| if test_result.get("status") == "success": | |
| # API连接测试成功 | |
| return { | |
| "status": "success", | |
| "message": "API密钥测试成功", | |
| "details": test_result, | |
| "success": True | |
| } | |
| else: | |
| # API连接测试失败 | |
| return { | |
| "status": "error", | |
| "message": f"API密钥测试失败: {test_result.get('message', '连接失败')}", | |
| "details": test_result, | |
| "success": False | |
| } | |
| except Exception as e: | |
| logger.error(f"测试密钥失败: {str(e)}", exc_info=True) | |
| return { | |
| "status": "error", | |
| "message": f"API密钥测试失败: {str(e)}", | |
| "success": False | |
| } | |
| async def batch_operation(operation: Dict[str, Any], admin_token = Depends(verify_admin_jwt)): | |
| """批量操作API密钥""" | |
| try: | |
| action = operation.get("action") | |
| logger.info(f"接收到批量操作请求: {action}") | |
| if not action: | |
| logger.warning("批量操作缺少action参数") | |
| raise HTTPException(status_code=400, detail="缺少必要参数: action") | |
| logger.info(f"批量操作类型: {action}") | |
| if action == "import": | |
| # 批量导入API密钥 | |
| keys_data = operation.get("keys", []) | |
| if not keys_data: | |
| logger.warning("批量导入缺少keys数据") | |
| raise HTTPException(status_code=400, detail="未提供密钥数据") | |
| logger.info(f"准备导入 {len(keys_data)} 个密钥") | |
| # 对每个密钥处理Bearer前缀 | |
| for key_data in keys_data: | |
| if isinstance(key_data, dict): | |
| key_value = key_data.get("key", "").strip() | |
| if key_value and not key_value.startswith("Bearer "): | |
| key_data["key"] = f"Bearer {key_value}" | |
| # 执行批量导入 | |
| try: | |
| result = key_manager.batch_import_keys(keys_data) | |
| logger.info(f"导入结果: 成功={result['imported']}, 跳过={result['skipped']}") | |
| # 通过Config永久保存所有密钥 | |
| Config.save_api_keys(key_manager.keys) | |
| return { | |
| "success": True, | |
| "message": f"成功导入 {result['imported']} 个密钥,跳过 {result['skipped']} 个重复密钥", | |
| "imported": result["imported"], | |
| "skipped": result["skipped"] | |
| } | |
| except Exception as e: | |
| logger.error(f"批量导入密钥错误: {str(e)}", exc_info=True) | |
| raise HTTPException(status_code=500, detail=f"导入密钥失败: {str(e)}") | |
| elif action not in ["enable", "disable", "delete"]: | |
| logger.warning(f"不支持的批量操作: {action}") | |
| raise HTTPException(status_code=400, detail=f"不支持的操作: {action}") | |
| # 对于非导入操作,需要提供key_ids | |
| key_ids = operation.get("key_ids", []) | |
| if not key_ids: | |
| logger.warning(f"{action}操作缺少key_ids参数") | |
| raise HTTPException(status_code=400, detail="缺少必要参数: key_ids") | |
| # 确保key_ids是一个列表 | |
| if isinstance(key_ids, str): | |
| key_ids = [key_ids] | |
| logger.info(f"批量{action}操作 {len(key_ids)} 个密钥") | |
| if action == "enable": | |
| # 批量启用 | |
| success_count = 0 | |
| for key_id in key_ids: | |
| updated = key_manager.update_key(key_id, is_enabled=True) | |
| if updated: | |
| success_count += 1 | |
| # 通过Config永久保存所有密钥 | |
| Config.save_api_keys(key_manager.keys) | |
| logger.info(f"成功启用 {success_count} 个密钥") | |
| return { | |
| "success": True, | |
| "message": f"已成功启用 {success_count} 个密钥", | |
| "affected": success_count | |
| } | |
| elif action == "disable": | |
| # 批量禁用 | |
| success_count = 0 | |
| for key_id in key_ids: | |
| updated = key_manager.update_key(key_id, is_enabled=False) | |
| if updated: | |
| success_count += 1 | |
| # 通过Config永久保存所有密钥 | |
| Config.save_api_keys(key_manager.keys) | |
| logger.info(f"成功禁用 {success_count} 个密钥") | |
| return { | |
| "success": True, | |
| "message": f"已成功禁用 {success_count} 个密钥", | |
| "affected": success_count | |
| } | |
| elif action == "delete": | |
| # 批量删除 | |
| success_count = 0 | |
| for key_id in key_ids: | |
| if key_manager.delete_key(key_id): | |
| success_count += 1 | |
| # 通过Config永久保存所有密钥 | |
| Config.save_api_keys(key_manager.keys) | |
| logger.info(f"成功删除 {success_count} 个密钥") | |
| return { | |
| "success": True, | |
| "message": f"已成功删除 {success_count} 个密钥", | |
| "affected": success_count | |
| } | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"批量操作失败: {str(e)}", exc_info=True) | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| # 配置管理API | |
| async def get_config(admin_token = Depends(verify_admin_jwt)): | |
| """获取当前系统配置""" | |
| return { | |
| "HOST": Config.HOST, | |
| "PORT": Config.PORT, | |
| "BASE_URL": Config.BASE_URL, | |
| "PROXY_HOST": Config.PROXY_HOST, | |
| "PROXY_PORT": Config.PROXY_PORT, | |
| "PROXY_USER": Config.PROXY_USER, | |
| "PROXY_PASS": "******" if Config.PROXY_PASS else "", | |
| "IMAGE_LOCALIZATION": Config.IMAGE_LOCALIZATION, | |
| "IMAGE_SAVE_DIR": Config.IMAGE_SAVE_DIR, | |
| "API_AUTH_TOKEN": bool(Config.API_AUTH_TOKEN) # 只返回是否设置,不返回实际值 | |
| } | |
| async def update_config(config_data: ConfigUpdate, admin_token = Depends(verify_admin_jwt)): | |
| """更新系统配置""" | |
| try: | |
| changes = [] | |
| # 更新代理设置 | |
| if config_data.PROXY_HOST is not None: | |
| Config.PROXY_HOST = config_data.PROXY_HOST | |
| changes.append("PROXY_HOST") | |
| # 更新环境变量 | |
| os.environ["PROXY_HOST"] = config_data.PROXY_HOST | |
| if config_data.PROXY_PORT is not None: | |
| Config.PROXY_PORT = config_data.PROXY_PORT | |
| changes.append("PROXY_PORT") | |
| # 更新环境变量 | |
| os.environ["PROXY_PORT"] = config_data.PROXY_PORT | |
| # 更新代理认证设置 | |
| if config_data.PROXY_USER is not None: | |
| Config.PROXY_USER = config_data.PROXY_USER | |
| changes.append("PROXY_USER") | |
| # 更新环境变量 | |
| os.environ["PROXY_USER"] = config_data.PROXY_USER | |
| if config_data.PROXY_PASS is not None: | |
| Config.PROXY_PASS = config_data.PROXY_PASS | |
| changes.append("PROXY_PASS") | |
| # 更新环境变量 | |
| os.environ["PROXY_PASS"] = config_data.PROXY_PASS | |
| # 更新基础URL设置 | |
| if config_data.BASE_URL is not None: | |
| Config.BASE_URL = config_data.BASE_URL | |
| changes.append("BASE_URL") | |
| # 更新环境变量 | |
| os.environ["BASE_URL"] = config_data.BASE_URL | |
| # 更新图片本地化设置 | |
| if config_data.IMAGE_LOCALIZATION is not None: | |
| Config.IMAGE_LOCALIZATION = config_data.IMAGE_LOCALIZATION | |
| changes.append("IMAGE_LOCALIZATION") | |
| # 更新环境变量 | |
| os.environ["IMAGE_LOCALIZATION"] = str(config_data.IMAGE_LOCALIZATION) | |
| if config_data.IMAGE_SAVE_DIR is not None: | |
| Config.IMAGE_SAVE_DIR = config_data.IMAGE_SAVE_DIR | |
| changes.append("IMAGE_SAVE_DIR") | |
| # 更新环境变量 | |
| os.environ["IMAGE_SAVE_DIR"] = config_data.IMAGE_SAVE_DIR | |
| # 确保目录存在 | |
| os.makedirs(Config.IMAGE_SAVE_DIR, exist_ok=True) | |
| # 保存到.env文件 | |
| if config_data.save_to_env and changes: | |
| env_file = os.path.join(Config.BASE_DIR, '.env') | |
| env_data = {} | |
| # 先读取现有的.env文件 | |
| if os.path.exists(env_file): | |
| env_data = dotenv.dotenv_values(env_file) | |
| # 更新环境变量 | |
| for field in changes: | |
| value = getattr(Config, field) | |
| env_data[field] = str(value) | |
| # 写入.env文件 | |
| with open(env_file, 'w') as f: | |
| for key, value in env_data.items(): | |
| f.write(f"{key}={value}\n") | |
| logger.info(f"已将配置保存到.env文件: {changes}") | |
| return { | |
| "status": "success", | |
| "message": f"配置已更新: {', '.join(changes) if changes else '无变更'}" | |
| } | |
| except Exception as e: | |
| logger.error(f"更新配置失败: {str(e)}", exc_info=True) | |
| raise HTTPException(status_code=400, detail=f"更新配置失败: {str(e)}") | |
| async def update_log_level(data: LogLevelUpdate, admin_token = Depends(verify_admin_jwt)): | |
| """更新日志级别""" | |
| try: | |
| # 验证日志级别 | |
| level = data.level.upper() | |
| valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] | |
| if level not in valid_levels: | |
| raise HTTPException(status_code=400, detail=f"无效的日志级别: {level}") | |
| # 更新根日志记录器的级别 | |
| root_logger = logging.getLogger() | |
| root_logger.setLevel(getattr(logging, level)) | |
| # 同时更新sora-api模块的日志级别 | |
| sora_logger = logging.getLogger("sora-api") | |
| sora_logger.setLevel(getattr(logging, level)) | |
| # 记录日志级别变更 | |
| logger.info(f"日志级别已更新为: {level}") | |
| # 如果需要,保存到环境变量 | |
| if data.save_to_env: | |
| env_file = os.path.join(Config.BASE_DIR, '.env') | |
| env_data = {} | |
| # 先读取现有的.env文件 | |
| if os.path.exists(env_file): | |
| env_data = dotenv.dotenv_values(env_file) | |
| # 更新LOG_LEVEL环境变量 | |
| env_data["LOG_LEVEL"] = level | |
| # 写入.env文件 | |
| with open(env_file, 'w') as f: | |
| for key, value in env_data.items(): | |
| f.write(f"{key}={value}\n") | |
| # 记录配置保存 | |
| logger.info(f"已将日志级别保存到.env文件: LOG_LEVEL={level}") | |
| return {"status": "success", "message": f"日志级别已更新为: {level}"} | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"更新日志级别失败: {str(e)}", exc_info=True) | |
| raise HTTPException(status_code=400, detail=f"更新日志级别失败: {str(e)}") | |
| # 管理员密钥API - 已移至app.py中 |