"""
Device management for ML model loading.

Handles GPU/CPU detection, allocation, and memory management.
"""

from dataclasses import dataclass
from typing import Optional, List, Literal
import logging
import os

logger = logging.getLogger(__name__)

DeviceType = Literal["cuda", "mps", "cpu"]

# All memory figures in this module are reported in MB (2**20 bytes).
_BYTES_PER_MB = 1024 * 1024


def _parse_cuda_index(device: str) -> Optional[int]:
    """
    Parse the CUDA ordinal out of a device string.

    ``"cuda"`` is treated as ordinal 0; ``"cuda:<N>"`` requires N to be a
    non-negative decimal integer.

    Args:
        device: Device string such as "cuda", "cuda:1", "mps" or "cpu".

    Returns:
        The CUDA ordinal, or None if ``device`` is not a well-formed CUDA
        device string (e.g. "cpu", "cudafoo", "cuda:", "cuda:-1").
    """
    if device == "cuda":
        return 0
    prefix, sep, suffix = device.partition(":")
    # isdecimal() rejects signs, so negative ordinals are refused here.
    if prefix != "cuda" or not sep or not suffix.isdecimal():
        return None
    return int(suffix)


def _torch_mps_available(torch_module) -> bool:
    """Return True if the given (already imported) torch module reports a usable MPS backend."""
    backends = torch_module.backends
    return hasattr(backends, "mps") and backends.mps.is_available()


@dataclass
class DeviceInfo:
    """Information about a compute device."""

    name: str  # Device name (from torch.cuda.get_device_properties for CUDA).
    type: DeviceType  # Backend family: "cuda", "mps" or "cpu".
    index: int  # Backend ordinal, e.g. the N in "cuda:N".
    total_memory_mb: float  # Total device memory in MB at probe time.
    free_memory_mb: float  # Free device memory in MB at probe time.
    is_available: bool  # Whether the device was usable when probed.


class DeviceManager:
    """
    Centralized device management for ML models.

    Handles device detection, selection, and memory tracking across CUDA,
    MPS (Apple Silicon), and CPU backends. The class is a singleton: every
    ``DeviceManager()`` call returns the same instance, so allocation counts
    are shared by all callers.

    Example::

        dm = DeviceManager()
        device = dm.get_best_device()  # "cuda:0" or "mps" or "cpu"
        # Allocate with preference
        device = dm.allocate_device(preference="cuda:1")
    """

    _instance: Optional["DeviceManager"] = None

    def __new__(cls) -> "DeviceManager":
        """Return the shared instance, creating it on first use (singleton pattern)."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # __init__ runs on every DeviceManager() call; this flag makes it
            # set up state only the first time.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self) -> None:
        """Initialize allocation tracking and probe devices (first call only)."""
        if self._initialized:
            return
        self._initialized = True
        # device string -> number of outstanding allocations
        self._allocated_devices: dict[str, int] = {}
        self._refresh_devices()

    def _refresh_devices(self) -> None:
        """
        Re-probe the CUDA and MPS backends and rebuild the cached device info.

        Sets ``_cuda_available``, ``_mps_available`` and ``_cuda_devices``.
        A CUDA device whose properties/memory cannot be queried is logged and
        left out of ``_cuda_devices`` rather than aborting the whole probe.
        """
        self._cuda_available = False
        self._mps_available = False
        self._cuda_devices: List[DeviceInfo] = []

        try:
            import torch
        except ImportError:
            logger.debug("PyTorch not available, GPU detection disabled")
            return

        # Check CUDA
        if torch.cuda.is_available():
            self._cuda_available = True
            for i in range(torch.cuda.device_count()):
                try:
                    props = torch.cuda.get_device_properties(i)
                    free_mem, total_mem = torch.cuda.mem_get_info(i)
                except RuntimeError as err:
                    # One failing device should not hide the others (or MPS).
                    logger.warning("Could not query CUDA device %d: %s", i, err)
                    continue
                self._cuda_devices.append(
                    DeviceInfo(
                        name=props.name,
                        type="cuda",
                        index=i,
                        total_memory_mb=total_mem / _BYTES_PER_MB,
                        free_memory_mb=free_mem / _BYTES_PER_MB,
                        is_available=True,
                    )
                )

        # Check MPS (Apple Silicon)
        self._mps_available = _torch_mps_available(torch)

    @staticmethod
    def get_best_device() -> str:
        """
        Get the best available device.

        Priority: CUDA > MPS > CPU

        Returns:
            Device string (e.g., "cuda:0", "mps", "cpu")
        """
        try:
            import torch
        except ImportError:
            return "cpu"
        if torch.cuda.is_available():
            return "cuda:0"
        if _torch_mps_available(torch):
            return "mps"
        return "cpu"

    @staticmethod
    def get_device_count() -> int:
        """
        Get the number of available GPU devices.

        Returns:
            Number of CUDA GPUs if CUDA is available, otherwise 1 if MPS is
            available, otherwise 0 (including when PyTorch is not installed).
        """
        try:
            import torch
        except ImportError:
            return 0
        if torch.cuda.is_available():
            return torch.cuda.device_count()
        if _torch_mps_available(torch):
            return 1
        return 0

    def allocate_device(self, preference: Optional[str] = None) -> str:
        """
        Allocate a device with optional preference.

        If preference is not available (or is malformed), falls back to the
        best available device. Allocation counts are keyed by the exact
        device string, so "cuda" and "cuda:0" are tracked separately.

        Args:
            preference: Preferred device (e.g., "cuda:0", "cuda:1", "mps", "cpu")

        Returns:
            Allocated device string
        """
        if preference and self._is_device_available(preference):
            device = preference
        else:
            if preference:
                logger.warning(
                    "Preferred device '%s' not available, using best available",
                    preference,
                )
            device = self.get_best_device()

        self._allocated_devices[device] = self._allocated_devices.get(device, 0) + 1
        return device

    def release_device(self, device: str) -> None:
        """
        Release a previously allocated device.

        Releasing a device that has no outstanding allocations is a no-op.

        Args:
            device: Device to release (the exact string returned by allocate_device)
        """
        if device in self._allocated_devices:
            self._allocated_devices[device] -= 1
            if self._allocated_devices[device] <= 0:
                del self._allocated_devices[device]

    def _is_device_available(self, device: str) -> bool:
        """
        Check if a specific device is available.

        "cpu" is always available. "mps" requires a usable MPS backend.
        "cuda" / "cuda:N" require CUDA and 0 <= N < device_count(). Anything
        else, including malformed CUDA strings, is reported as unavailable.
        """
        if device == "cpu":
            return True
        try:
            import torch
        except ImportError:
            return False
        if device == "mps":
            return _torch_mps_available(torch)
        idx = _parse_cuda_index(device)
        if idx is None:
            return False
        return torch.cuda.is_available() and idx < torch.cuda.device_count()

    def get_cuda_devices(self) -> List[DeviceInfo]:
        """
        Get list of available CUDA devices.

        Re-probes devices on every call and returns a new list, so callers
        may modify it without affecting the manager's cached state.
        """
        self._refresh_devices()
        return list(self._cuda_devices)

    def get_device_with_most_memory(self) -> str:
        """
        Get the device with the most free memory.

        Only CUDA devices are compared; without any, falls back to "mps"
        (if available) and then "cpu".

        Returns:
            Device string
        """
        self._refresh_devices()
        if not self._cuda_devices:
            return "mps" if self._mps_available else "cpu"
        best = max(self._cuda_devices, key=lambda d: d.free_memory_mb)
        return f"cuda:{best.index}"

    @staticmethod
    def get_memory_info(device: str) -> tuple[float, float]:
        """
        Get memory info for a device.

        CPU figures come from ``psutil`` (an optional dependency); CUDA
        figures from ``torch.cuda.mem_get_info``. Returns (0.0, 0.0) when
        the information cannot be obtained: MPS or unknown devices, malformed
        or out-of-range CUDA strings, CUDA unavailable, or psutil missing.

        Args:
            device: Device string

        Returns:
            Tuple of (free_mb, total_mb)
        """
        if device == "cpu":
            try:
                import psutil
            except ImportError:
                logger.debug("psutil not available, CPU memory info disabled")
                return 0.0, 0.0
            mem = psutil.virtual_memory()
            return mem.available / _BYTES_PER_MB, mem.total / _BYTES_PER_MB

        idx = _parse_cuda_index(device)
        if idx is None:
            return 0.0, 0.0
        try:
            import torch
        except ImportError:
            return 0.0, 0.0
        # Check first: on builds without CUDA, mem_get_info raises
        # something other than RuntimeError.
        if not torch.cuda.is_available() or idx >= torch.cuda.device_count():
            return 0.0, 0.0
        try:
            free, total = torch.cuda.mem_get_info(idx)
        except RuntimeError:
            return 0.0, 0.0
        return free / _BYTES_PER_MB, total / _BYTES_PER_MB

    @staticmethod
    def clear_cache(device: Optional[str] = None) -> None:
        """
        Clear GPU cache to free memory.

        With ``device=None`` both the CUDA and MPS caches are cleared (each
        only if its backend is usable). With "cuda"/"cuda:N" only that CUDA
        device's cache is cleared; with "mps" only the MPS cache. Other
        devices (e.g. "cpu") have nothing to clear. A malformed or
        out-of-range CUDA string is logged and skipped.

        Args:
            device: Specific device to clear, or None for all
        """
        try:
            import torch
        except ImportError:
            return

        want_cuda = device is None or device.startswith("cuda")
        want_mps = device is None or device == "mps"

        if want_cuda and torch.cuda.is_available():
            if device is None:
                torch.cuda.empty_cache()
            else:
                idx = _parse_cuda_index(device)
                if idx is None or idx >= torch.cuda.device_count():
                    logger.warning("Cannot clear cache for unknown CUDA device '%s'", device)
                else:
                    with torch.cuda.device(idx):
                        torch.cuda.empty_cache()

        # torch.mps can exist in builds without MPS support, where calling
        # empty_cache() can raise RuntimeError, so require a usable backend.
        if (
            want_mps
            and _torch_mps_available(torch)
            and hasattr(torch, "mps")
            and hasattr(torch.mps, "empty_cache")
        ):
            torch.mps.empty_cache()

    @staticmethod
    def set_default_device(device: str) -> None:
        """
        Set the default device for tensor operations.

        Only CUDA devices are acted on (via ``torch.cuda.set_device``); other
        device strings, or a missing PyTorch install, are a no-op.

        Args:
            device: Device to set as default

        Raises:
            ValueError: If ``device`` starts with "cuda" but is not a
                well-formed "cuda" / "cuda:N" string.
        """
        try:
            import torch
        except ImportError:
            return
        if not device.startswith("cuda"):
            return
        idx = _parse_cuda_index(device)
        if idx is None:
            raise ValueError(f"Invalid CUDA device string: {device!r}")
        torch.cuda.set_device(idx)


# Convenience functions
def get_best_device() -> str:
    """Get the best available device."""
    return DeviceManager.get_best_device()


def get_device_count() -> int:
    """Get the number of available GPU devices."""
    return DeviceManager.get_device_count()


def allocate_device(preference: Optional[str] = None) -> str:
    """Allocate a device with optional preference (uses the shared DeviceManager)."""
    return DeviceManager().allocate_device(preference)


def clear_gpu_cache() -> None:
    """Clear GPU memory cache."""
    DeviceManager.clear_cache()