# Source: msse-ai-engineering / src/utils/memory_utils.py
# Author: sethmcknight
# Last change: "Refactor test cases for improved readability and consistency" (commit 159faf0)
"""
Memory monitoring and management utilities for production deployment.
"""
import gc
import logging
import os
import threading
import time
import tracemalloc
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, cast
logger = logging.getLogger(__name__)

# Environment flag to enable deeper / more frequent memory diagnostics.
# NOTE(review): os.getenv(..., "0") can never return None, so the None entry in
# the tuple below is unreachable (harmless); any other value enables the flag.
MEMORY_DEBUG = os.getenv("MEMORY_DEBUG", "0") not in (None, "0", "false", "False")
# Opt-in flag for tracemalloc-based allocation tracing (adds runtime overhead).
ENABLE_TRACEMALLOC = os.getenv("ENABLE_TRACEMALLOC", "0") not in (
    None,
    "0",
    "false",
    "False",
)
# Memory milestone thresholds (MB) which trigger enhanced logging once per run
MEMORY_THRESHOLDS = [300, 400, 450, 500]
# Thresholds already reported this run (each milestone is logged only once).
_crossed_thresholds: "set[int]" = set()  # type: ignore[type-arg]
# Guard so tracemalloc.start() is invoked at most once per process.
_tracemalloc_started = False
# Guard + handle for the background periodic memory-logging thread.
_periodic_thread_started = False
_periodic_thread: Optional[threading.Thread] = None
def get_memory_usage() -> float:
    """Return this process's resident memory usage in MB.

    Prefers psutil's RSS figure; when psutil is not installed, falls back to
    tracemalloc's currently-traced total, and finally to 0.0 on any failure.
    """
    try:
        import psutil

        rss_bytes = psutil.Process(os.getpid()).memory_info().rss
        return rss_bytes / 1024 / 1024
    except ImportError:
        # Fallback: tracemalloc reports (0, 0) when tracing is not active.
        try:
            traced_now, _ = tracemalloc.get_traced_memory()
            return traced_now / 1024 / 1024
        except Exception:
            return 0.0
def log_memory_usage(context: str = "") -> float:
    """Log current memory usage (optionally labelled) and return it in MB."""
    usage = get_memory_usage()
    if context:
        logger.info(f"Memory usage ({context}): {usage:.1f}MB")
    else:
        logger.info(f"Memory usage: {usage:.1f}MB")
    return usage
def _collect_detailed_stats() -> Dict[str, Any]:
"""Collect additional (lightweight) diagnostics; guarded by MEMORY_DEBUG."""
stats: Dict[str, Any] = {}
try:
import psutil # type: ignore
p = psutil.Process(os.getpid())
with p.oneshot():
mem = p.memory_info()
stats["rss_mb"] = mem.rss / 1024 / 1024
stats["vms_mb"] = mem.vms / 1024 / 1024
stats["num_threads"] = p.num_threads()
stats["open_files"] = len(p.open_files()) if hasattr(p, "open_files") else None
except Exception:
pass
# tracemalloc snapshot (only if already tracing to avoid overhead)
if tracemalloc.is_tracing():
try:
current, peak = tracemalloc.get_traced_memory()
stats["tracemalloc_current_mb"] = current / 1024 / 1024
stats["tracemalloc_peak_mb"] = peak / 1024 / 1024
except Exception:
pass
# GC counts are cheap
try:
stats["gc_counts"] = gc.get_count()
except Exception:
pass
return stats
def log_memory_checkpoint(context: str, force: bool = False):
    """Emit a detailed memory diagnostic line.

    Skipped entirely unless MEMORY_DEBUG is enabled or ``force`` is True.

    Args:
        context: Label identifying where in the code this checkpoint fires.
        force: Bypass the MEMORY_DEBUG gate.
    """
    if not force and not MEMORY_DEBUG:
        return
    rss = get_memory_usage()
    details = _collect_detailed_stats()
    logger.info(
        "[MEMORY CHECKPOINT] %s | rss=%.1fMB details=%s",
        context,
        rss,
        details,
    )
    # Record any newly crossed thresholds (each milestone logs once per run).
    _maybe_log_milestone(rss, context)
    # Above 380MB (pre-critical) dump the biggest allocation sites, if tracing.
    if ENABLE_TRACEMALLOC and rss > 380:
        log_top_tracemalloc(f"high_mem_{context}")
def start_tracemalloc(nframes: int = 25):
    """Begin tracemalloc tracing once, gated by the ENABLE_TRACEMALLOC flag.

    Args:
        nframes: Depth of the traceback captured for each allocation.
    """
    global _tracemalloc_started
    if not ENABLE_TRACEMALLOC or _tracemalloc_started:
        return
    try:
        tracemalloc.start(nframes)
        _tracemalloc_started = True
        logger.info("tracemalloc started (nframes=%d)", nframes)
    except Exception as e:  # pragma: no cover
        logger.warning(f"Failed to start tracemalloc: {e}")
def log_top_tracemalloc(label: str, limit: int = 10):
    """Log the ``limit`` largest allocation sites; no-op unless tracing."""
    if not tracemalloc.is_tracing():
        return
    try:
        ranked = tracemalloc.take_snapshot().statistics("lineno")
        logger.info("[TRACEMALLOC] Top %d allocations (%s)", limit, label)
        for entry in ranked[:limit]:
            logger.info("[TRACEMALLOC] %s", entry)
    except Exception as e:  # pragma: no cover
        logger.debug(f"Failed logging tracemalloc stats: {e}")
def memory_summary(include_tracemalloc: bool = True) -> Dict[str, Any]:
    """Build and return a dict snapshot of current memory diagnostics."""
    summary: Dict[str, Any] = {
        "rss_mb": get_memory_usage(),
        # Which of MEMORY_THRESHOLDS have been crossed so far this run.
        "milestones_crossed": sorted(_crossed_thresholds),
    }
    summary.update(_collect_detailed_stats())
    if include_tracemalloc and tracemalloc.is_tracing():
        try:
            traced_now, traced_peak = tracemalloc.get_traced_memory()
            summary["tracemalloc_current_mb"] = traced_now / 1024 / 1024
            summary["tracemalloc_peak_mb"] = traced_peak / 1024 / 1024
        except Exception:
            pass
    return summary
def start_periodic_memory_logger(interval_seconds: int = 60):
    """Spawn a daemon thread logging a memory checkpoint every interval.

    Idempotent: calls after the first are no-ops.

    Args:
        interval_seconds: Delay between checkpoints.
    """
    global _periodic_thread_started, _periodic_thread
    if _periodic_thread_started:
        return

    def _loop():
        logger.info(
            ("Periodic memory logger started (interval=%ds, " "debug=%s, tracemalloc=%s)"),
            interval_seconds,
            MEMORY_DEBUG,
            tracemalloc.is_tracing(),
        )
        while True:
            try:
                log_memory_checkpoint("periodic", force=True)
            except Exception:  # pragma: no cover
                logger.debug("Periodic memory logger iteration failed", exc_info=True)
            time.sleep(interval_seconds)

    _periodic_thread = threading.Thread(target=_loop, name="PeriodicMemoryLogger", daemon=True)
    _periodic_thread.start()
    _periodic_thread_started = True
    logger.info("Periodic memory logger thread started")
R = TypeVar("R")


def memory_monitor(func: Callable[..., R]) -> Callable[..., R]:
    """Decorator that logs a function's memory delta (before vs. after).

    The wrapped function's return value and any raised exception pass through
    unchanged; only an INFO log line is added.
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any):  # type: ignore[override]
        memory_before = get_memory_usage()
        result = func(*args, **kwargs)
        memory_after = get_memory_usage()
        memory_diff = memory_after - memory_before
        # Fix: use the "+" format-spec flag so the sign is correct — the old
        # message hard-coded "+" and rendered "+-3.2MB" whenever memory shrank.
        logger.info(
            f"Memory change in {func.__name__}: "
            f"{memory_before:.1f}MB -> {memory_after:.1f}MB "
            f"({memory_diff:+.1f}MB)"
        )
        return result

    return cast(Callable[..., R], wrapper)
def force_garbage_collection():
    """Run a full GC pass and log how much memory was reclaimed."""
    before_mb = get_memory_usage()
    reclaimed_objects = gc.collect()
    freed_mb = before_mb - get_memory_usage()
    logger.info(f"Garbage collection: freed {freed_mb:.1f}MB, " f"collected {reclaimed_objects} objects")
def check_memory_threshold(threshold_mb: float = 400) -> bool:
    """Report whether current memory usage exceeds ``threshold_mb``.

    Args:
        threshold_mb: Limit in MB (default 400MB, sized for a 512MB container).

    Returns:
        True when usage is above the threshold (a warning is also logged).
    """
    current_memory = get_memory_usage()
    exceeded = current_memory > threshold_mb
    if exceeded:
        logger.warning(f"Memory usage {current_memory:.1f}MB exceeds threshold {threshold_mb}MB")
    return exceeded
def clean_memory(context: str = ""):
    """Force a GC pass and log before/after figures, optionally labelled.

    Args:
        context: Description of when/why the cleanup is happening.
    """
    before_mb = get_memory_usage()
    reclaimed = gc.collect()
    after_mb = get_memory_usage()
    freed_mb = before_mb - after_mb
    if context:
        logger.info(
            f"Memory cleanup ({context}): "
            f"{before_mb:.1f}MB -> {after_mb:.1f}MB "
            f"(freed {freed_mb:.1f}MB, collected {reclaimed} objects)"
        )
    else:
        logger.info(f"Memory cleanup: freed {freed_mb:.1f}MB, collected {reclaimed} objects")
def optimize_memory():
    """Aggressively reduce memory: GC plus trimming the model cache.

    Intended to be called when memory usage gets high.
    """
    logger.info("Performing memory optimization...")
    force_garbage_collection()
    # Best effort: evict all but the most recently added entry of the
    # EmbeddingService model cache, when such a cache attribute exists.
    try:
        from src.embedding.embedding_service import EmbeddingService

        cache = getattr(EmbeddingService, "_model_cache", None)
        if cache is not None:
            try:
                total = len(cache)
                if total > 1:  # always keep at least one model cached
                    for stale_key in list(cache.keys())[:-1]:
                        del cache[stale_key]
                    logger.info(
                        "Cleared %d cached models, kept 1",
                        total - 1,
                    )
            except Exception as e:  # pragma: no cover
                logger.debug("Failed clearing model cache: %s", e)
    except Exception as e:
        logger.debug("Could not clear model cache: %s", e)
class MemoryManager:
    """Context manager that brackets a memory-intensive operation with logging.

    On entry, records the starting memory and runs optimize_memory() if usage
    is already above ``threshold_mb``. On exit, logs the delta and forces a GC
    pass when the operation grew memory by more than 50MB. Exceptions raised
    inside the block are never suppressed (``__exit__`` returns None).
    """

    def __init__(self, operation_name: str = "operation", threshold_mb: float = 400):
        self.operation_name = operation_name
        self.threshold_mb = threshold_mb
        # Set by __enter__; None until the context is entered.
        self.start_memory: Optional[float] = None

    def __enter__(self):
        self.start_memory = get_memory_usage()
        logger.info(f"Starting {self.operation_name} (Memory: {self.start_memory:.1f}MB)")
        # Already over the limit on entry: try to free memory up front.
        if self.start_memory > self.threshold_mb:
            logger.warning("Starting operation with high memory usage")
            optimize_memory()
        return self

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        end_memory = get_memory_usage()
        delta = end_memory - (self.start_memory or 0)
        logger.info(
            f"Completed {self.operation_name} "
            f"(Memory: {self.start_memory:.1f}MB -> {end_memory:.1f}MB, "
            f"Change: {delta:+.1f}MB)"
        )
        if delta > 50:  # grew by more than 50MB: clean up immediately
            logger.info("Large memory increase detected, running cleanup")
            force_garbage_collection()
            # Extra diagnostics after cleanup when deep debugging is enabled.
            log_memory_checkpoint(f"post_cleanup_{self.operation_name}")
# ---------- Milestone & force-clean helpers ---------- #
def _maybe_log_milestone(current_mb: float, context: str):
    """Internal: emit a one-time warning + snapshot per threshold crossed."""
    pending = [
        t for t in MEMORY_THRESHOLDS if current_mb >= t and t not in _crossed_thresholds
    ]
    for threshold in pending:
        _crossed_thresholds.add(threshold)
        logger.warning(
            "[MEMORY MILESTONE] %.1fMB crossed threshold %dMB " "(context=%s)",
            current_mb,
            threshold,
            context,
        )
        # Capture an immediate summary, plus top allocations when tracing.
        details = memory_summary(include_tracemalloc=True)
        logger.info("[MEMORY SNAPSHOT @%dMB] summary=%s", threshold, details)
        if ENABLE_TRACEMALLOC and tracemalloc.is_tracing():
            log_top_tracemalloc(f"milestone_{threshold}MB")
def force_clean_and_report(label: str = "manual") -> Dict[str, Any]:
    """Run GC plus optimization, then log and return a memory summary."""
    logger.info("Force clean invoked (%s)", label)
    force_garbage_collection()
    optimize_memory()
    post_clean = memory_summary(include_tracemalloc=True)
    logger.info("Post-clean memory summary (%s): %s", label, post_clean)
    return post_clean