|
|
|
|
|
"""
|
|
|
AI & ML API Router
|
|
|
==================
|
|
|
API endpoints for AI predictions, backtesting, and ML training
|
|
|
"""
|
|
|
|
|
|
from fastapi import APIRouter, HTTPException, Depends, Body, Query, Path
|
|
|
from fastapi.responses import JSONResponse
|
|
|
from typing import Optional, List, Dict, Any
|
|
|
from pydantic import BaseModel, Field
|
|
|
from datetime import datetime
|
|
|
from sqlalchemy.orm import Session
|
|
|
import logging
|
|
|
|
|
|
from backend.services.backtesting_service import BacktestingService
|
|
|
from backend.services.ml_training_service import MLTrainingService
|
|
|
from database.db_manager import db_manager
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
router = APIRouter(
|
|
|
prefix="/api/ai",
|
|
|
tags=["AI & ML"]
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BacktestRequest(BaseModel):
|
|
|
"""Request model for starting a backtest."""
|
|
|
strategy: str = Field(..., description="Strategy name (e.g., 'simple_moving_average', 'rsi_strategy', 'macd_strategy')")
|
|
|
symbol: str = Field(..., description="Trading pair (e.g., 'BTC/USDT')")
|
|
|
start_date: datetime = Field(..., description="Backtest start date")
|
|
|
end_date: datetime = Field(..., description="Backtest end date")
|
|
|
initial_capital: float = Field(..., gt=0, description="Starting capital for backtest")
|
|
|
|
|
|
|
|
|
class TrainingRequest(BaseModel):
|
|
|
"""Request model for starting ML training."""
|
|
|
model_name: str = Field(..., description="Name of the model to train")
|
|
|
training_data_start: datetime = Field(..., description="Start date for training data")
|
|
|
training_data_end: datetime = Field(..., description="End date for training data")
|
|
|
batch_size: int = Field(32, gt=0, description="Training batch size")
|
|
|
learning_rate: Optional[float] = Field(None, gt=0, description="Learning rate")
|
|
|
config: Optional[Dict[str, Any]] = Field(None, description="Additional training configuration")
|
|
|
|
|
|
|
|
|
class TrainingStepRequest(BaseModel):
|
|
|
"""Request model for executing a training step."""
|
|
|
step_number: int = Field(..., ge=1, description="Step number")
|
|
|
loss: Optional[float] = Field(None, description="Training loss")
|
|
|
accuracy: Optional[float] = Field(None, ge=0, le=1, description="Training accuracy")
|
|
|
learning_rate: Optional[float] = Field(None, gt=0, description="Current learning rate")
|
|
|
metrics: Optional[Dict[str, Any]] = Field(None, description="Additional metrics")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_db() -> Session:
|
|
|
"""Get database session."""
|
|
|
db = db_manager.SessionLocal()
|
|
|
try:
|
|
|
yield db
|
|
|
finally:
|
|
|
db.close()
|
|
|
|
|
|
|
|
|
def get_backtesting_service(db: Session = Depends(get_db)) -> BacktestingService:
|
|
|
"""Get backtesting service instance."""
|
|
|
return BacktestingService(db)
|
|
|
|
|
|
|
|
|
def get_ml_training_service(db: Session = Depends(get_db)) -> MLTrainingService:
|
|
|
"""Get ML training service instance."""
|
|
|
return MLTrainingService(db)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/backtest")
|
|
|
async def start_backtest(
|
|
|
backtest_request: BacktestRequest,
|
|
|
service: BacktestingService = Depends(get_backtesting_service)
|
|
|
) -> JSONResponse:
|
|
|
"""
|
|
|
Start a backtest for a specific strategy.
|
|
|
|
|
|
Runs a backtest simulation using historical data and returns comprehensive
|
|
|
performance metrics including total return, Sharpe ratio, max drawdown, and win rate.
|
|
|
|
|
|
Args:
|
|
|
backtest_request: Backtest configuration
|
|
|
service: Backtesting service instance
|
|
|
|
|
|
Returns:
|
|
|
JSON response with backtest results
|
|
|
"""
|
|
|
try:
|
|
|
|
|
|
if backtest_request.end_date <= backtest_request.start_date:
|
|
|
raise ValueError("end_date must be after start_date")
|
|
|
|
|
|
|
|
|
results = service.start_backtest(
|
|
|
strategy=backtest_request.strategy,
|
|
|
symbol=backtest_request.symbol,
|
|
|
start_date=backtest_request.start_date,
|
|
|
end_date=backtest_request.end_date,
|
|
|
initial_capital=backtest_request.initial_capital
|
|
|
)
|
|
|
|
|
|
return JSONResponse(
|
|
|
status_code=200,
|
|
|
content={
|
|
|
"success": True,
|
|
|
"message": "Backtest completed successfully",
|
|
|
"data": results
|
|
|
}
|
|
|
)
|
|
|
|
|
|
except ValueError as e:
|
|
|
raise HTTPException(status_code=400, detail=str(e))
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error running backtest: {e}", exc_info=True)
|
|
|
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
|
|
|
|
|
|
|
|
@router.post("/train")
|
|
|
async def start_training(
|
|
|
training_request: TrainingRequest,
|
|
|
service: MLTrainingService = Depends(get_ml_training_service)
|
|
|
) -> JSONResponse:
|
|
|
"""
|
|
|
Start training a model.
|
|
|
|
|
|
Initiates the model training process with specified configuration.
|
|
|
|
|
|
Args:
|
|
|
training_request: Training configuration
|
|
|
service: ML training service instance
|
|
|
|
|
|
Returns:
|
|
|
JSON response with training job details
|
|
|
"""
|
|
|
try:
|
|
|
job = service.start_training(
|
|
|
model_name=training_request.model_name,
|
|
|
training_data_start=training_request.training_data_start,
|
|
|
training_data_end=training_request.training_data_end,
|
|
|
batch_size=training_request.batch_size,
|
|
|
learning_rate=training_request.learning_rate,
|
|
|
config=training_request.config
|
|
|
)
|
|
|
|
|
|
return JSONResponse(
|
|
|
status_code=201,
|
|
|
content={
|
|
|
"success": True,
|
|
|
"message": "Training job created successfully",
|
|
|
"data": job
|
|
|
}
|
|
|
)
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error starting training: {e}", exc_info=True)
|
|
|
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
|
|
|
|
|
|
|
|
@router.post("/train-step")
|
|
|
async def execute_training_step(
|
|
|
job_id: str = Query(..., description="Training job ID"),
|
|
|
step_request: TrainingStepRequest = Body(...),
|
|
|
service: MLTrainingService = Depends(get_ml_training_service)
|
|
|
) -> JSONResponse:
|
|
|
"""
|
|
|
Execute a training step.
|
|
|
|
|
|
Records a single training step with metrics.
|
|
|
|
|
|
Args:
|
|
|
job_id: Training job ID
|
|
|
step_request: Training step data
|
|
|
service: ML training service instance
|
|
|
|
|
|
Returns:
|
|
|
JSON response with step details
|
|
|
"""
|
|
|
try:
|
|
|
step = service.execute_training_step(
|
|
|
job_id=job_id,
|
|
|
step_number=step_request.step_number,
|
|
|
loss=step_request.loss,
|
|
|
accuracy=step_request.accuracy,
|
|
|
learning_rate=step_request.learning_rate,
|
|
|
metrics=step_request.metrics
|
|
|
)
|
|
|
|
|
|
return JSONResponse(
|
|
|
status_code=200,
|
|
|
content={
|
|
|
"success": True,
|
|
|
"message": "Training step executed successfully",
|
|
|
"data": step
|
|
|
}
|
|
|
)
|
|
|
|
|
|
except ValueError as e:
|
|
|
raise HTTPException(status_code=400, detail=str(e))
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error executing training step: {e}", exc_info=True)
|
|
|
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
|
|
|
|
|
|
|
|
@router.get("/train/status")
|
|
|
async def get_training_status(
|
|
|
job_id: str = Query(..., description="Training job ID"),
|
|
|
service: MLTrainingService = Depends(get_ml_training_service)
|
|
|
) -> JSONResponse:
|
|
|
"""
|
|
|
Get the current training status.
|
|
|
|
|
|
Retrieves the current status and metrics for a training job.
|
|
|
|
|
|
Args:
|
|
|
job_id: Training job ID
|
|
|
service: ML training service instance
|
|
|
|
|
|
Returns:
|
|
|
JSON response with training status
|
|
|
"""
|
|
|
try:
|
|
|
status = service.get_training_status(job_id)
|
|
|
|
|
|
return JSONResponse(
|
|
|
status_code=200,
|
|
|
content={
|
|
|
"success": True,
|
|
|
"data": status
|
|
|
}
|
|
|
)
|
|
|
|
|
|
except ValueError as e:
|
|
|
raise HTTPException(status_code=404, detail=str(e))
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error getting training status: {e}", exc_info=True)
|
|
|
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
|
|
|
|
|
|
|
|
@router.get("/train/history")
|
|
|
async def get_training_history(
|
|
|
model_name: Optional[str] = Query(None, description="Filter by model name"),
|
|
|
limit: int = Query(100, ge=1, le=1000, description="Maximum number of jobs to return"),
|
|
|
service: MLTrainingService = Depends(get_ml_training_service)
|
|
|
) -> JSONResponse:
|
|
|
"""
|
|
|
Get training history.
|
|
|
|
|
|
Retrieves the training history for all models or a specific model.
|
|
|
|
|
|
Args:
|
|
|
model_name: Optional model name filter
|
|
|
limit: Maximum number of jobs to return
|
|
|
service: ML training service instance
|
|
|
|
|
|
Returns:
|
|
|
JSON response with training history
|
|
|
"""
|
|
|
try:
|
|
|
history = service.get_training_history(
|
|
|
model_name=model_name,
|
|
|
limit=limit
|
|
|
)
|
|
|
|
|
|
return JSONResponse(
|
|
|
status_code=200,
|
|
|
content={
|
|
|
"success": True,
|
|
|
"count": len(history),
|
|
|
"data": history
|
|
|
}
|
|
|
)
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error retrieving training history: {e}", exc_info=True)
|
|
|
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
|
|
|
|
|
|