ml-model-loader/python/lilith_model_loader/device.py
2025-12-28 04:32:35 -08:00

315 lines
8.7 KiB
Python

"""
Device management for ML model loading.
Handles GPU/CPU detection, allocation, and memory management.
"""
from dataclasses import dataclass
from typing import Optional, List, Literal
import logging
import os
logger = logging.getLogger(__name__)
DeviceType = Literal["cuda", "mps", "cpu"]
@dataclass(init=True, repr=True, eq=True)
class DeviceInfo:
    """Snapshot of a single compute device.

    Attributes:
        name: Device name as reported by the backend.
        type: Backend family: "cuda", "mps" or "cpu".
        index: Ordinal of the device within its backend.
        total_memory_mb: Total device memory in megabytes.
        free_memory_mb: Free device memory in megabytes when the snapshot was taken.
        is_available: Whether the device was usable when the snapshot was taken.
    """

    name: str
    type: DeviceType
    index: int
    total_memory_mb: float
    free_memory_mb: float
    is_available: bool
class DeviceManager:
    """
    Centralized device management for ML models.

    Handles device detection, selection, and memory tracking across
    CUDA, MPS (Apple Silicon), and CPU backends. Every ``DeviceManager()``
    call returns the same shared instance.

    Example:
        >>> dm = DeviceManager()
        >>> device = dm.get_best_device()
        >>> print(device)  # "cuda:0" or "mps" or "cpu"
        >>> # Allocate with preference
        >>> device = dm.allocate_device(preference="cuda:1")
    """

    _instance: Optional["DeviceManager"] = None

    def __new__(cls) -> "DeviceManager":
        """Return the shared instance, creating it on first use (singleton)."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # __init__ runs on every DeviceManager() call; this flag makes
            # its body run only once for the shared instance.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self) -> None:
        if self._initialized:
            return
        self._initialized = True
        # device string -> number of outstanding allocate_device() calls
        self._allocated_devices: dict[str, int] = {}
        self._refresh_devices()

    def _refresh_devices(self) -> None:
        """
        Re-probe CUDA and MPS and rebuild the cached device list.

        A CUDA device whose properties or memory cannot be queried is
        skipped (with a warning) instead of aborting detection of the rest.
        """
        self._cuda_available = False
        self._mps_available = False
        self._cuda_devices: List[DeviceInfo] = []
        try:
            import torch
        except ImportError:
            logger.debug("PyTorch not available, GPU detection disabled")
            return

        # Check CUDA
        if torch.cuda.is_available():
            self._cuda_available = True
            for i in range(torch.cuda.device_count()):
                try:
                    props = torch.cuda.get_device_properties(i)
                    free_mem, total_mem = torch.cuda.mem_get_info(i)
                except RuntimeError as err:
                    # One unusable device should not hide the others.
                    logger.warning("Skipping CUDA device %d: %s", i, err)
                    continue
                self._cuda_devices.append(
                    DeviceInfo(
                        name=props.name,
                        type="cuda",
                        index=i,
                        total_memory_mb=total_mem / 1024 / 1024,
                        free_memory_mb=free_mem / 1024 / 1024,
                        is_available=True,
                    )
                )

        # Check MPS (Apple Silicon)
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            self._mps_available = True

    @staticmethod
    def _parse_cuda_index(device: str) -> Optional[int]:
        """
        Parse the CUDA ordinal out of a device string.

        Args:
            device: Device string such as "cuda" or "cuda:1".

        Returns:
            0 for a bare "cuda", N for "cuda:N" where N is a non-negative
            decimal integer, and None for anything else (other backends,
            negative or non-numeric indices, strings such as "cudafoo").
        """
        backend, sep, suffix = device.partition(":")
        if backend != "cuda":
            return None
        if not sep:
            return 0
        # isascii() rejects non-ASCII digits that isdigit() accepts but int() does not.
        if not (suffix.isascii() and suffix.isdigit()):
            return None
        return int(suffix)

    @staticmethod
    def get_best_device() -> str:
        """
        Get the best available device.

        Priority: CUDA > MPS > CPU

        Returns:
            Device string (e.g., "cuda:0", "mps", "cpu")
        """
        try:
            import torch

            if torch.cuda.is_available():
                return "cuda:0"
            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                return "mps"
        except ImportError:
            pass
        return "cpu"

    @staticmethod
    def get_device_count() -> int:
        """
        Get the number of available GPU devices.

        Returns:
            Number of CUDA devices, 1 if only MPS is available, otherwise 0
            (including when PyTorch is not installed).
        """
        try:
            import torch

            if torch.cuda.is_available():
                return torch.cuda.device_count()
            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                return 1
        except ImportError:
            pass
        return 0

    def allocate_device(self, preference: Optional[str] = None) -> str:
        """
        Allocate a device with optional preference.

        If preference is not available, falls back to best available.

        Args:
            preference: Preferred device (e.g., "cuda:0", "cuda:1", "mps", "cpu")

        Returns:
            Allocated device string
        """
        if preference:
            if self._is_device_available(preference):
                self._allocated_devices[preference] = (
                    self._allocated_devices.get(preference, 0) + 1
                )
                return preference
            logger.warning(
                "Preferred device '%s' not available, using best available",
                preference,
            )
        device = self.get_best_device()
        self._allocated_devices[device] = self._allocated_devices.get(device, 0) + 1
        return device

    def release_device(self, device: str) -> None:
        """
        Release a previously allocated device.

        Unknown devices are ignored; the count entry is removed once it
        drops to zero.

        Args:
            device: Device to release
        """
        if device in self._allocated_devices:
            self._allocated_devices[device] -= 1
            if self._allocated_devices[device] <= 0:
                del self._allocated_devices[device]

    def _is_device_available(self, device: str) -> bool:
        """
        Check if a specific device is available.

        Only "cpu", "mps", "cuda" and "cuda:N" (N >= 0) can be available;
        malformed CUDA strings such as "cuda:-1" or "cuda:abc" report False
        instead of being accepted or raising.
        """
        if device == "cpu":
            return True
        try:
            import torch
        except ImportError:
            return False
        if device == "mps":
            return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        idx = DeviceManager._parse_cuda_index(device)
        if idx is None or not torch.cuda.is_available():
            return False
        return idx < torch.cuda.device_count()

    def get_cuda_devices(self) -> List["DeviceInfo"]:
        """Re-probe and return a copy of the list of available CUDA devices."""
        self._refresh_devices()
        return list(self._cuda_devices)

    def get_device_with_most_memory(self) -> str:
        """
        Get the device with the most free memory.

        Re-probes devices first. Falls back to "mps" and then "cpu" when no
        CUDA device was detected.

        Returns:
            Device string
        """
        self._refresh_devices()
        if not self._cuda_devices:
            if self._mps_available:
                return "mps"
            return "cpu"
        best = max(self._cuda_devices, key=lambda d: d.free_memory_mb)
        return f"cuda:{best.index}"

    @staticmethod
    def get_memory_info(device: str) -> tuple[float, float]:
        """
        Get memory info for a device.

        Args:
            device: Device string ("cpu", "cuda" or "cuda:N").

        Returns:
            Tuple of (free_mb, total_mb). CPU figures come from psutil, which
            must be importable. (0.0, 0.0) is returned for other backends,
            malformed CUDA strings, or when CUDA cannot be queried.
        """
        if device == "cpu":
            import psutil

            mem = psutil.virtual_memory()
            return mem.available / 1024 / 1024, mem.total / 1024 / 1024
        idx = DeviceManager._parse_cuda_index(device)
        if idx is None:
            return 0.0, 0.0
        try:
            import torch

            # On CPU-only builds mem_get_info fails with AssertionError, so check first.
            if not torch.cuda.is_available():
                return 0.0, 0.0
            free, total = torch.cuda.mem_get_info(idx)
        except (ImportError, RuntimeError):
            return 0.0, 0.0
        return free / 1024 / 1024, total / 1024 / 1024

    @staticmethod
    def clear_cache(device: Optional[str] = None) -> None:
        """
        Clear GPU cache to free memory.

        Args:
            device: Specific device to clear, or None for all. A CUDA string
                clears CUDA only ("cuda:N" inside that device's context; a
                malformed CUDA string falls back to a plain empty_cache()),
                "mps" clears MPS only, and any other value clears nothing.
        """
        try:
            import torch
        except ImportError:
            return

        clear_all = not device
        if (clear_all or device.startswith("cuda")) and torch.cuda.is_available():
            idx = None if clear_all else DeviceManager._parse_cuda_index(device)
            if idx is None:
                torch.cuda.empty_cache()
            else:
                with torch.cuda.device(idx):
                    torch.cuda.empty_cache()

        # torch.mps.empty_cache exists on every build but raises RuntimeError
        # when the MPS backend is unavailable, so gate it on availability.
        if (
            (clear_all or device == "mps")
            and hasattr(torch.backends, "mps")
            and torch.backends.mps.is_available()
            and hasattr(torch, "mps")
            and hasattr(torch.mps, "empty_cache")
        ):
            torch.mps.empty_cache()

    @staticmethod
    def set_default_device(device: str) -> None:
        """
        Set the current CUDA device for tensor operations.

        Only CUDA strings have an effect; other devices and a missing
        PyTorch installation are a no-op.

        Args:
            device: Device to set as default ("cuda" or "cuda:N").

        Raises:
            ValueError: If ``device`` starts with "cuda" but is not a valid
                "cuda"/"cuda:N" string.
        """
        try:
            import torch
        except ImportError:
            return
        if not device.startswith("cuda"):
            return
        idx = DeviceManager._parse_cuda_index(device)
        if idx is None:
            raise ValueError(f"Invalid CUDA device string: {device!r}")
        torch.cuda.set_device(idx)
# Convenience functions
def get_best_device() -> str:
    """Return the preferred device string (CUDA > MPS > CPU) via DeviceManager."""
    best = DeviceManager.get_best_device()
    return best
def get_device_count() -> int:
    """Return how many GPU devices DeviceManager reports (0 when none)."""
    count = DeviceManager.get_device_count()
    return count
def allocate_device(preference: Optional[str] = None) -> str:
    """Allocate a device through the shared DeviceManager instance.

    Args:
        preference: Preferred device string; the best available device is
            used when it is missing or unavailable.

    Returns:
        The allocated device string.
    """
    manager = DeviceManager()
    return manager.allocate_device(preference=preference)
def clear_gpu_cache() -> None:
    """Clear cached GPU memory for all devices via DeviceManager.clear_cache."""
    DeviceManager.clear_cache(device=None)