| """ | |
| Memory monitoring and management utilities for production deployment. | |
| """ | |
| import gc | |
| import logging | |
| import os | |
| import threading | |
| import time | |
| import tracemalloc | |
| from functools import wraps | |
| from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, cast | |
| logger = logging.getLogger(__name__) | |
| # Environment flag to enable deeper / more frequent memory diagnostics | |
| MEMORY_DEBUG = os.getenv("MEMORY_DEBUG", "0") not in (None, "0", "false", "False") | |
| ENABLE_TRACEMALLOC = os.getenv("ENABLE_TRACEMALLOC", "0") not in ( | |
| None, | |
| "0", | |
| "false", | |
| "False", | |
| ) | |
| # Memory milestone thresholds (MB) which trigger enhanced logging once per run | |
| MEMORY_THRESHOLDS = [300, 400, 450, 500] | |
| _crossed_thresholds: "set[int]" = set() # type: ignore[type-arg] | |
| _tracemalloc_started = False | |
| _periodic_thread_started = False | |
| _periodic_thread: Optional[threading.Thread] = None | |
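# Illustrative configuration (a sketch, not part of the original module): both flags
# are read once at import time, so they must be set in the environment before the
# process starts. Assuming a hypothetical entry point `main.py`, this might look like:
#
#     MEMORY_DEBUG=1 ENABLE_TRACEMALLOC=1 python main.py
#
# MEMORY_DEBUG gates the richer checkpoints; ENABLE_TRACEMALLOC gates allocation tracing.
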
def get_memory_usage() -> float:
    """
    Get current memory usage in MB.

    Falls back to tracemalloc if psutil is not available.
    """
    try:
        import psutil

        return psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
    except ImportError:
        # Fallback: use tracemalloc if available
        try:
            current, _peak = tracemalloc.get_traced_memory()
            return current / 1024 / 1024
        except Exception:
            return 0.0

def log_memory_usage(context: str = "") -> float:
    """Log current memory usage with context and return the memory value."""
    memory_mb = get_memory_usage()
    if context:
        logger.info(f"Memory usage ({context}): {memory_mb:.1f}MB")
    else:
        logger.info(f"Memory usage: {memory_mb:.1f}MB")
    return memory_mb

def _collect_detailed_stats() -> Dict[str, Any]:
    """Collect additional (lightweight) diagnostics; guarded by MEMORY_DEBUG."""
    stats: Dict[str, Any] = {}
    try:
        import psutil  # type: ignore

        p = psutil.Process(os.getpid())
        with p.oneshot():
            mem = p.memory_info()
            stats["rss_mb"] = mem.rss / 1024 / 1024
            stats["vms_mb"] = mem.vms / 1024 / 1024
            stats["num_threads"] = p.num_threads()
            stats["open_files"] = len(p.open_files()) if hasattr(p, "open_files") else None
    except Exception:
        pass

    # tracemalloc snapshot (only if already tracing, to avoid overhead)
    if tracemalloc.is_tracing():
        try:
            current, peak = tracemalloc.get_traced_memory()
            stats["tracemalloc_current_mb"] = current / 1024 / 1024
            stats["tracemalloc_peak_mb"] = peak / 1024 / 1024
        except Exception:
            pass

    # GC counts are cheap
    try:
        stats["gc_counts"] = gc.get_count()
    except Exception:
        pass

    return stats

def log_memory_checkpoint(context: str, force: bool = False) -> None:
    """Log a richer memory diagnostic line if MEMORY_DEBUG is enabled or force=True.

    Args:
        context: Label for the point in the code where this checkpoint is captured
        force: Override the MEMORY_DEBUG gate
    """
    if not (MEMORY_DEBUG or force):
        return
    base = get_memory_usage()
    stats = _collect_detailed_stats()
    logger.info(
        "[MEMORY CHECKPOINT] %s | rss=%.1fMB details=%s",
        context,
        base,
        stats,
    )
    # Automatic milestone snapshot logging
    _maybe_log_milestone(base, context)
    # If tracemalloc is enabled and memory is above 380MB (pre-critical), log top allocations
    if ENABLE_TRACEMALLOC and base > 380:
        log_top_tracemalloc(f"high_mem_{context}")

def start_tracemalloc(nframes: int = 25) -> None:
    """Start tracemalloc if enabled via environment flag."""
    global _tracemalloc_started
    if ENABLE_TRACEMALLOC and not _tracemalloc_started:
        try:
            tracemalloc.start(nframes)
            _tracemalloc_started = True
            logger.info("tracemalloc started (nframes=%d)", nframes)
        except Exception as e:  # pragma: no cover
            logger.warning(f"Failed to start tracemalloc: {e}")

def log_top_tracemalloc(label: str, limit: int = 10) -> None:
    """Log top memory allocation traces if tracemalloc is running."""
    if not tracemalloc.is_tracing():
        return
    try:
        snapshot = tracemalloc.take_snapshot()
        top_stats = snapshot.statistics("lineno")
        logger.info("[TRACEMALLOC] Top %d allocations (%s)", limit, label)
        for stat in top_stats[:limit]:
            logger.info("[TRACEMALLOC] %s", stat)
    except Exception as e:  # pragma: no cover
        logger.debug(f"Failed logging tracemalloc stats: {e}")

def memory_summary(include_tracemalloc: bool = True) -> Dict[str, Any]:
    """Return a dictionary summary of current memory diagnostics."""
    summary: Dict[str, Any] = {}
    summary["rss_mb"] = get_memory_usage()
    # Include which milestones have been crossed
    summary["milestones_crossed"] = sorted(_crossed_thresholds)
    stats = _collect_detailed_stats()
    summary.update(stats)
    if include_tracemalloc and tracemalloc.is_tracing():
        try:
            current, peak = tracemalloc.get_traced_memory()
            summary["tracemalloc_current_mb"] = current / 1024 / 1024
            summary["tracemalloc_peak_mb"] = peak / 1024 / 1024
        except Exception:
            pass
    return summary

def start_periodic_memory_logger(interval_seconds: int = 60) -> None:
    """Start a background thread that logs memory every interval_seconds."""
    global _periodic_thread_started, _periodic_thread
    if _periodic_thread_started:
        return

    def _runner() -> None:
        logger.info(
            "Periodic memory logger started (interval=%ds, debug=%s, tracemalloc=%s)",
            interval_seconds,
            MEMORY_DEBUG,
            tracemalloc.is_tracing(),
        )
        while True:
            try:
                log_memory_checkpoint("periodic", force=True)
            except Exception:  # pragma: no cover
                logger.debug("Periodic memory logger iteration failed", exc_info=True)
            time.sleep(interval_seconds)

    _periodic_thread = threading.Thread(target=_runner, name="PeriodicMemoryLogger", daemon=True)
    _periodic_thread.start()
    _periodic_thread_started = True
    logger.info("Periodic memory logger thread started")

R = TypeVar("R")


def memory_monitor(func: Callable[..., R]) -> Callable[..., R]:
    """Decorator to monitor memory usage of functions."""

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> R:
        memory_before = get_memory_usage()
        result = func(*args, **kwargs)
        memory_after = get_memory_usage()
        memory_diff = memory_after - memory_before
        logger.info(
            f"Memory change in {func.__name__}: "
            f"{memory_before:.1f}MB -> {memory_after:.1f}MB "
            f"({memory_diff:+.1f}MB)"
        )
        return result

    return cast(Callable[..., R], wrapper)

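# Illustrative usage (assumed caller code, not part of the original module):
#
#     @memory_monitor
#     def build_index(documents: list) -> None:  # hypothetical memory-heavy function
#         ...
#
# Each call to build_index() then logs the RSS before and after the call plus the delta.
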
def force_garbage_collection() -> None:
    """Force garbage collection and log memory freed."""
    memory_before = get_memory_usage()
    # Force garbage collection
    collected = gc.collect()
    memory_after = get_memory_usage()
    memory_freed = memory_before - memory_after
    logger.info(f"Garbage collection: freed {memory_freed:.1f}MB, collected {collected} objects")

def check_memory_threshold(threshold_mb: float = 400) -> bool:
    """
    Check if memory usage exceeds threshold.

    Args:
        threshold_mb: Memory threshold in MB (default 400MB for 512MB limit)

    Returns:
        True if memory usage is above threshold
    """
    current_memory = get_memory_usage()
    if current_memory > threshold_mb:
        logger.warning(f"Memory usage {current_memory:.1f}MB exceeds threshold {threshold_mb}MB")
        return True
    return False

def clean_memory(context: str = "") -> None:
    """
    Clean memory and force garbage collection with context logging.

    Args:
        context: Description of when/why cleanup is happening
    """
    memory_before = get_memory_usage()
    # Force garbage collection
    collected = gc.collect()
    memory_after = get_memory_usage()
    memory_freed = memory_before - memory_after
    if context:
        logger.info(
            f"Memory cleanup ({context}): "
            f"{memory_before:.1f}MB -> {memory_after:.1f}MB "
            f"(freed {memory_freed:.1f}MB, collected {collected} objects)"
        )
    else:
        logger.info(f"Memory cleanup: freed {memory_freed:.1f}MB, collected {collected} objects")

def optimize_memory() -> None:
    """
    Perform memory optimization operations.

    Called when memory usage gets high.
    """
    logger.info("Performing memory optimization...")

    # Force garbage collection
    force_garbage_collection()

    # Clear any model caches if they exist
    try:
        from src.embedding.embedding_service import EmbeddingService

        if hasattr(EmbeddingService, "_model_cache"):
            cache_attr = getattr(EmbeddingService, "_model_cache")  # type: ignore[attr-defined]
            try:
                cache_size = len(cache_attr)
                # Keep at least one model cached
                if cache_size > 1:
                    keys = list(cache_attr.keys())
                    for key in keys[:-1]:
                        del cache_attr[key]
                    logger.info(
                        "Cleared %d cached models, kept 1",
                        cache_size - 1,
                    )
            except Exception as e:  # pragma: no cover
                logger.debug("Failed clearing model cache: %s", e)
    except Exception as e:
        logger.debug("Could not clear model cache: %s", e)

class MemoryManager:
    """Context manager for memory-intensive operations."""

    def __init__(self, operation_name: str = "operation", threshold_mb: float = 400):
        self.operation_name = operation_name
        self.threshold_mb = threshold_mb
        self.start_memory: Optional[float] = None

    def __enter__(self) -> "MemoryManager":
        self.start_memory = get_memory_usage()
        logger.info(f"Starting {self.operation_name} (Memory: {self.start_memory:.1f}MB)")

        # Check if we're already near the threshold
        if self.start_memory > self.threshold_mb:
            logger.warning("Starting operation with high memory usage")
            optimize_memory()

        return self

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        end_memory = get_memory_usage()
        memory_diff = end_memory - (self.start_memory or 0)
        logger.info(
            f"Completed {self.operation_name} "
            f"(Memory: {self.start_memory:.1f}MB -> {end_memory:.1f}MB, "
            f"Change: {memory_diff:+.1f}MB)"
        )

        # If memory usage increased significantly, trigger cleanup
        if memory_diff > 50:  # More than 50MB increase
            logger.info("Large memory increase detected, running cleanup")
            force_garbage_collection()
        # Capture a post-cleanup checkpoint if deep debugging is enabled
        log_memory_checkpoint(f"post_cleanup_{self.operation_name}")

# ---------- Milestone & force-clean helpers ---------- #


def _maybe_log_milestone(current_mb: float, context: str) -> None:
    """Internal: log when crossing defined memory thresholds."""
    for threshold in MEMORY_THRESHOLDS:
        if current_mb >= threshold and threshold not in _crossed_thresholds:
            _crossed_thresholds.add(threshold)
            logger.warning(
                "[MEMORY MILESTONE] %.1fMB crossed threshold %dMB (context=%s)",
                current_mb,
                threshold,
                context,
            )
            # Provide an immediate snapshot & optionally top allocations
            details = memory_summary(include_tracemalloc=True)
            logger.info("[MEMORY SNAPSHOT @%dMB] summary=%s", threshold, details)
            if ENABLE_TRACEMALLOC and tracemalloc.is_tracing():
                log_top_tracemalloc(f"milestone_{threshold}MB")

def force_clean_and_report(label: str = "manual") -> Dict[str, Any]:
    """Force GC + optimization and return a post-clean summary."""
    logger.info("Force clean invoked (%s)", label)
    force_garbage_collection()
    optimize_memory()
    summary = memory_summary(include_tracemalloc=True)
    logger.info("Post-clean memory summary (%s): %s", label, summary)
    return summary
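

# Minimal smoke test (an illustrative sketch, not part of the original module): running
# the file directly exercises the main entry points using only names defined above.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    start_tracemalloc()  # only starts if ENABLE_TRACEMALLOC=1
    log_memory_usage("startup")
    with MemoryManager("demo_allocation"):
        _buffer = [bytearray(1024) for _ in range(10_000)]  # ~10MB of throwaway data
    print(force_clean_and_report("demo"))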