""" Memory monitoring and management utilities for production deployment. """ import gc import logging import os import threading import time import tracemalloc from functools import wraps from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, cast logger = logging.getLogger(__name__) # Environment flag to enable deeper / more frequent memory diagnostics MEMORY_DEBUG = os.getenv("MEMORY_DEBUG", "0") not in (None, "0", "false", "False") ENABLE_TRACEMALLOC = os.getenv("ENABLE_TRACEMALLOC", "0") not in ( None, "0", "false", "False", ) # Memory milestone thresholds (MB) which trigger enhanced logging once per run MEMORY_THRESHOLDS = [300, 400, 450, 500] _crossed_thresholds: "set[int]" = set() # type: ignore[type-arg] _tracemalloc_started = False _periodic_thread_started = False _periodic_thread: Optional[threading.Thread] = None def get_memory_usage() -> float: """ Get current memory usage in MB. Falls back to basic approach if psutil is not available. """ try: import psutil return psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 except ImportError: # Fallback: use tracemalloc if available try: current, peak = tracemalloc.get_traced_memory() return current / 1024 / 1024 except Exception: return 0.0 def log_memory_usage(context: str = "") -> float: """Log current memory usage with context and return the memory value.""" memory_mb = get_memory_usage() if context: logger.info(f"Memory usage ({context}): {memory_mb:.1f}MB") else: logger.info(f"Memory usage: {memory_mb:.1f}MB") return memory_mb def _collect_detailed_stats() -> Dict[str, Any]: """Collect additional (lightweight) diagnostics; guarded by MEMORY_DEBUG.""" stats: Dict[str, Any] = {} try: import psutil # type: ignore p = psutil.Process(os.getpid()) with p.oneshot(): mem = p.memory_info() stats["rss_mb"] = mem.rss / 1024 / 1024 stats["vms_mb"] = mem.vms / 1024 / 1024 stats["num_threads"] = p.num_threads() stats["open_files"] = len(p.open_files()) if hasattr(p, "open_files") else None except Exception: pass # tracemalloc snapshot (only if already tracing to avoid overhead) if tracemalloc.is_tracing(): try: current, peak = tracemalloc.get_traced_memory() stats["tracemalloc_current_mb"] = current / 1024 / 1024 stats["tracemalloc_peak_mb"] = peak / 1024 / 1024 except Exception: pass # GC counts are cheap try: stats["gc_counts"] = gc.get_count() except Exception: pass return stats def log_memory_checkpoint(context: str, force: bool = False): """Log a richer memory diagnostic line if MEMORY_DEBUG is enabled or force=True. Args: context: Label for where in code we are capturing this force: Override MEMORY_DEBUG gate """ if not (MEMORY_DEBUG or force): return base = get_memory_usage() stats = _collect_detailed_stats() logger.info( "[MEMORY CHECKPOINT] %s | rss=%.1fMB details=%s", context, base, stats, ) # Automatic milestone snapshot logging _maybe_log_milestone(base, context) # If tracemalloc enabled and memory above 380MB (pre-crit), log top allocations if ENABLE_TRACEMALLOC and base > 380: log_top_tracemalloc(f"high_mem_{context}") def start_tracemalloc(nframes: int = 25): """Start tracemalloc if enabled via environment flag.""" global _tracemalloc_started if ENABLE_TRACEMALLOC and not _tracemalloc_started: try: tracemalloc.start(nframes) _tracemalloc_started = True logger.info("tracemalloc started (nframes=%d)", nframes) except Exception as e: # pragma: no cover logger.warning(f"Failed to start tracemalloc: {e}") def log_top_tracemalloc(label: str, limit: int = 10): """Log top memory allocation traces if tracemalloc is running.""" if not tracemalloc.is_tracing(): return try: snapshot = tracemalloc.take_snapshot() top_stats = snapshot.statistics("lineno") logger.info("[TRACEMALLOC] Top %d allocations (%s)", limit, label) for stat in top_stats[:limit]: logger.info("[TRACEMALLOC] %s", stat) except Exception as e: # pragma: no cover logger.debug(f"Failed logging tracemalloc stats: {e}") def memory_summary(include_tracemalloc: bool = True) -> Dict[str, Any]: """Return a dictionary summary of current memory diagnostics.""" summary: Dict[str, Any] = {} summary["rss_mb"] = get_memory_usage() # Include which milestones crossed summary["milestones_crossed"] = sorted(list(_crossed_thresholds)) stats = _collect_detailed_stats() summary.update(stats) if include_tracemalloc and tracemalloc.is_tracing(): try: current, peak = tracemalloc.get_traced_memory() summary["tracemalloc_current_mb"] = current / 1024 / 1024 summary["tracemalloc_peak_mb"] = peak / 1024 / 1024 except Exception: pass return summary def start_periodic_memory_logger(interval_seconds: int = 60): """Start a background thread that logs memory every interval_seconds.""" global _periodic_thread_started, _periodic_thread if _periodic_thread_started: return def _runner(): logger.info( ("Periodic memory logger started (interval=%ds, " "debug=%s, tracemalloc=%s)"), interval_seconds, MEMORY_DEBUG, tracemalloc.is_tracing(), ) while True: try: log_memory_checkpoint("periodic", force=True) except Exception: # pragma: no cover logger.debug("Periodic memory logger iteration failed", exc_info=True) time.sleep(interval_seconds) _periodic_thread = threading.Thread(target=_runner, name="PeriodicMemoryLogger", daemon=True) _periodic_thread.start() _periodic_thread_started = True logger.info("Periodic memory logger thread started") R = TypeVar("R") def memory_monitor(func: Callable[..., R]) -> Callable[..., R]: """Decorator to monitor memory usage of functions.""" @wraps(func) def wrapper(*args: Tuple[Any, ...], **kwargs: Any): # type: ignore[override] memory_before = get_memory_usage() result = func(*args, **kwargs) memory_after = get_memory_usage() memory_diff = memory_after - memory_before logger.info( f"Memory change in {func.__name__}: " f"{memory_before:.1f}MB -> {memory_after:.1f}MB " f"(+{memory_diff:.1f}MB)" ) return result return cast(Callable[..., R], wrapper) def force_garbage_collection(): """Force garbage collection and log memory freed.""" memory_before = get_memory_usage() # Force garbage collection collected = gc.collect() memory_after = get_memory_usage() memory_freed = memory_before - memory_after logger.info(f"Garbage collection: freed {memory_freed:.1f}MB, " f"collected {collected} objects") def check_memory_threshold(threshold_mb: float = 400) -> bool: """ Check if memory usage exceeds threshold. Args: threshold_mb: Memory threshold in MB (default 400MB for 512MB limit) Returns: True if memory usage is above threshold """ current_memory = get_memory_usage() if current_memory > threshold_mb: logger.warning(f"Memory usage {current_memory:.1f}MB exceeds threshold {threshold_mb}MB") return True return False def clean_memory(context: str = ""): """ Clean memory and force garbage collection with context logging. Args: context: Description of when/why cleanup is happening """ memory_before = get_memory_usage() # Force garbage collection collected = gc.collect() memory_after = get_memory_usage() memory_freed = memory_before - memory_after if context: logger.info( f"Memory cleanup ({context}): " f"{memory_before:.1f}MB -> {memory_after:.1f}MB " f"(freed {memory_freed:.1f}MB, collected {collected} objects)" ) else: logger.info(f"Memory cleanup: freed {memory_freed:.1f}MB, collected {collected} objects") def optimize_memory(): """ Perform memory optimization operations. Called when memory usage gets high. """ logger.info("Performing memory optimization...") # Force garbage collection force_garbage_collection() # Clear any model caches if they exist try: from src.embedding.embedding_service import EmbeddingService if hasattr(EmbeddingService, "_model_cache"): cache_attr = getattr(EmbeddingService, "_model_cache") # type: ignore[attr-defined] try: cache_size = len(cache_attr) # Keep at least one model cached if cache_size > 1: keys = list(cache_attr.keys()) for key in keys[:-1]: del cache_attr[key] logger.info( "Cleared %d cached models, kept 1", cache_size - 1, ) except Exception as e: # pragma: no cover logger.debug("Failed clearing model cache: %s", e) except Exception as e: logger.debug("Could not clear model cache: %s", e) class MemoryManager: """Context manager for memory-intensive operations.""" def __init__(self, operation_name: str = "operation", threshold_mb: float = 400): self.operation_name = operation_name self.threshold_mb = threshold_mb self.start_memory: Optional[float] = None def __enter__(self): self.start_memory = get_memory_usage() logger.info(f"Starting {self.operation_name} (Memory: {self.start_memory:.1f}MB)") # Check if we're already near the threshold if self.start_memory > self.threshold_mb: logger.warning("Starting operation with high memory usage") optimize_memory() return self def __exit__( self, exc_type: Optional[type], exc_val: Optional[BaseException], exc_tb: Optional[Any], ) -> None: end_memory = get_memory_usage() memory_diff = end_memory - (self.start_memory or 0) logger.info( f"Completed {self.operation_name} " f"(Memory: {self.start_memory:.1f}MB -> {end_memory:.1f}MB, " f"Change: {memory_diff:+.1f}MB)" ) # If memory usage increased significantly, trigger cleanup if memory_diff > 50: # More than 50MB increase logger.info("Large memory increase detected, running cleanup") force_garbage_collection() # Capture a post-cleanup checkpoint if deep debugging enabled log_memory_checkpoint(f"post_cleanup_{self.operation_name}") # ---------- Milestone & force-clean helpers ---------- # def _maybe_log_milestone(current_mb: float, context: str): """Internal: log when crossing defined memory thresholds.""" for threshold in MEMORY_THRESHOLDS: if current_mb >= threshold and threshold not in _crossed_thresholds: _crossed_thresholds.add(threshold) logger.warning( "[MEMORY MILESTONE] %.1fMB crossed threshold %dMB " "(context=%s)", current_mb, threshold, context, ) # Provide immediate snapshot & optionally top allocations details = memory_summary(include_tracemalloc=True) logger.info("[MEMORY SNAPSHOT @%dMB] summary=%s", threshold, details) if ENABLE_TRACEMALLOC and tracemalloc.is_tracing(): log_top_tracemalloc(f"milestone_{threshold}MB") def force_clean_and_report(label: str = "manual") -> Dict[str, Any]: """Force GC + optimization and return post-clean summary.""" logger.info("Force clean invoked (%s)", label) force_garbage_collection() optimize_memory() summary = memory_summary(include_tracemalloc=True) logger.info("Post-clean memory summary (%s): %s", label, summary) return summary