316 lines
8.7 KiB
Python
316 lines
8.7 KiB
Python
|
|
"""
|
||
|
|
Device management for ML model loading.
|
||
|
|
|
||
|
|
Handles GPU/CPU detection, allocation, and memory management.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from typing import Optional, List, Literal
|
||
|
|
import logging
|
||
|
|
import os
|
||
|
|
|
||
|
|
# Module-level logger, named after this module so applications can
# configure its verbosity independently.
logger = logging.getLogger(name=__name__)

# Backend families this module distinguishes. Device strings may carry an
# index suffix (e.g. "cuda:1"); the suffix is not part of this type.
DeviceType = Literal["cuda", "mps", "cpu"]
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass()
class DeviceInfo:
    """
    Snapshot describing a single compute device.

    Instances are created by ``DeviceManager`` while enumerating CUDA
    devices; memory figures reflect the moment the snapshot was taken.
    """

    name: str  # device name as reported by the backend
    type: DeviceType  # backend family: "cuda", "mps" or "cpu"
    index: int  # ordinal within its backend, as in "cuda:<index>"
    total_memory_mb: float  # total device memory in MiB (bytes / 1024**2)
    free_memory_mb: float  # free device memory in MiB at snapshot time
    is_available: bool  # whether the device was usable when sampled
|
||
|
|
|
||
|
|
|
||
|
|
class DeviceManager:
    """
    Centralized device management for ML models.

    Handles device detection, selection, and memory tracking across
    CUDA, MPS (Apple Silicon), and CPU backends.

    ``DeviceManager()`` always returns the same process-wide instance. Device
    detection runs on its first successful construction and again whenever
    ``get_cuda_devices`` or ``get_device_with_most_memory`` is called.
    PyTorch is imported lazily inside each method, so the class degrades to
    CPU-only answers when PyTorch is not installed.

    Example:
        >>> dm = DeviceManager()
        >>> device = dm.get_best_device()
        >>> print(device)  # "cuda:0" or "mps" or "cpu"

        >>> # Allocate with preference
        >>> device = dm.allocate_device(preference="cuda:1")
    """

    _instance: Optional["DeviceManager"] = None

    # Divisor that converts byte counts into MiB.
    _BYTES_PER_MB = 1024 * 1024

    def __new__(cls) -> "DeviceManager":
        """Return the process-wide instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # __init__ checks this flag so detection only runs once.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self) -> None:
        # __init__ runs on every DeviceManager() call; only the first
        # successful one populates state.
        if self._initialized:
            return
        self._allocated_devices: dict[str, int] = {}  # device -> allocation count
        self._refresh_devices()
        # Set only after detection succeeds, so an exception escaping
        # _refresh_devices is retried on the next DeviceManager() call instead
        # of leaving behind a singleton that lacks its device attributes.
        self._initialized = True

    @staticmethod
    def _parse_cuda_index(device: str) -> int:
        """
        Extract the ordinal from a CUDA device string.

        Args:
            device: A CUDA device string such as "cuda" or "cuda:1"

        Returns:
            The integer after ":"; a string without ":" maps to index 0.

        Raises:
            ValueError: If the text after ":" is not an integer (e.g. "cuda:x").
        """
        _, sep, index_text = device.partition(":")
        if not sep:
            return 0
        return int(index_text)

    def _refresh_devices(self) -> None:
        """
        Re-detect available accelerators and snapshot CUDA memory.

        Resets ``_cuda_available``, ``_mps_available`` and ``_cuda_devices``
        and repopulates them from PyTorch. A CUDA device whose memory query
        fails with RuntimeError is still listed, but with
        ``is_available=False`` and ``free_memory_mb=0.0``, so one bad device
        does not abort detection of the others.
        """
        self._cuda_available = False
        self._mps_available = False
        self._cuda_devices: List[DeviceInfo] = []

        try:
            import torch
        except ImportError:
            logger.debug("PyTorch not available, GPU detection disabled")
            return

        # Check CUDA
        if torch.cuda.is_available():
            self._cuda_available = True
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                try:
                    free_mem, total_mem = torch.cuda.mem_get_info(i)
                    usable = True
                except RuntimeError as exc:
                    # The runtime could not report memory for this device;
                    # fall back to the static total from its properties.
                    logger.warning("Could not query memory for cuda:%d: %s", i, exc)
                    free_mem, total_mem = 0, props.total_memory
                    usable = False
                self._cuda_devices.append(
                    DeviceInfo(
                        name=props.name,
                        type="cuda",
                        index=i,
                        total_memory_mb=total_mem / DeviceManager._BYTES_PER_MB,
                        free_memory_mb=free_mem / DeviceManager._BYTES_PER_MB,
                        is_available=usable,
                    )
                )

        # Check MPS (Apple Silicon)
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            self._mps_available = True

    @staticmethod
    def get_best_device() -> str:
        """
        Get the best available device.

        Priority: CUDA > MPS > CPU

        Returns:
            Device string (e.g., "cuda:0", "mps", "cpu")
        """
        try:
            import torch

            if torch.cuda.is_available():
                return "cuda:0"
            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                return "mps"
        except ImportError:
            pass

        return "cpu"

    @staticmethod
    def get_device_count() -> int:
        """
        Get the number of available GPU devices.

        Returns:
            Number of CUDA GPUs; 1 when only MPS is available; 0 if there is
            no accelerator or PyTorch is not installed.
        """
        try:
            import torch

            if torch.cuda.is_available():
                return torch.cuda.device_count()
            if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                return 1
        except ImportError:
            pass

        return 0

    def allocate_device(self, preference: Optional[str] = None) -> str:
        """
        Allocate a device with optional preference.

        If preference is not available (or is malformed, e.g. "cuda:x"),
        falls back to best available. Each call increments the allocation
        count of the returned device string.

        Args:
            preference: Preferred device (e.g., "cuda:0", "cuda:1", "mps", "cpu")

        Returns:
            Allocated device string
        """
        if preference and self._is_device_available(preference):
            device = preference
        else:
            if preference:
                logger.warning(
                    f"Preferred device '{preference}' not available, using best available"
                )
            device = self.get_best_device()

        self._allocated_devices[device] = self._allocated_devices.get(device, 0) + 1
        return device

    def release_device(self, device: str) -> None:
        """
        Release a previously allocated device.

        Decrements the device's allocation count and drops the entry once it
        reaches zero. Releasing a device that is not tracked is a no-op.

        Args:
            device: Device string exactly as returned by ``allocate_device``
        """
        if device in self._allocated_devices:
            self._allocated_devices[device] -= 1
            if self._allocated_devices[device] <= 0:
                del self._allocated_devices[device]

    def _is_device_available(self, device: str) -> bool:
        """
        Check whether a device string names a device usable on this machine.

        Args:
            device: Device string such as "cpu", "mps", "cuda" or "cuda:1"

        Returns:
            True if the device exists. False if it does not, if the string
            is malformed (e.g. "cuda:x", "cuda:-1", "cudax"), or if PyTorch
            is not installed (except for "cpu", which is always available).
        """
        if device == "cpu":
            return True

        try:
            import torch
        except ImportError:
            return False

        if device == "cuda" or device.startswith("cuda:"):
            if not torch.cuda.is_available():
                return False
            try:
                idx = self._parse_cuda_index(device)
            except ValueError:
                # A malformed preference is "not available" so callers such
                # as allocate_device fall back instead of crashing.
                return False
            return 0 <= idx < torch.cuda.device_count()

        if device == "mps":
            return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()

        return False

    def get_cuda_devices(self) -> List[DeviceInfo]:
        """
        Get list of CUDA devices, re-detecting them first.

        Returns:
            A new list (safe for the caller to mutate) of ``DeviceInfo``
            snapshots; devices whose memory could not be queried are
            included with ``is_available=False``.
        """
        self._refresh_devices()
        return list(self._cuda_devices)

    def get_device_with_most_memory(self) -> str:
        """
        Get the device with the most free memory.

        Returns:
            "cuda:<n>" for the usable CUDA device with the most free memory;
            otherwise "mps" if MPS is available, else "cpu".
        """
        self._refresh_devices()

        usable = [d for d in self._cuda_devices if d.is_available]
        if not usable:
            return "mps" if self._mps_available else "cpu"

        best = max(usable, key=lambda d: d.free_memory_mb)
        return f"cuda:{best.index}"

    @staticmethod
    def get_memory_info(device: str) -> tuple[float, float]:
        """
        Get memory info for a device.

        Args:
            device: Device string ("cpu", "cuda", "cuda:<n>", ...)

        Returns:
            Tuple of (free_mb, total_mb) in MiB. For "cpu" this is available
            and total system RAM via psutil. (0.0, 0.0) means unknown: psutil
            or PyTorch is missing, the CUDA query failed, or the backend
            (e.g. "mps") is not supported here.

        Raises:
            ValueError: If a CUDA device string has a non-integer index.
        """
        if device == "cpu":
            try:
                import psutil
            except ImportError:
                # psutil is optional; report "unknown" like other backends.
                logger.debug("psutil not available, CPU memory info unknown")
                return 0.0, 0.0

            mem = psutil.virtual_memory()
            return (
                mem.available / DeviceManager._BYTES_PER_MB,
                mem.total / DeviceManager._BYTES_PER_MB,
            )

        try:
            import torch

            if device.startswith("cuda"):
                idx = DeviceManager._parse_cuda_index(device)
                free, total = torch.cuda.mem_get_info(idx)
                return (
                    free / DeviceManager._BYTES_PER_MB,
                    total / DeviceManager._BYTES_PER_MB,
                )
        except (ImportError, RuntimeError):
            pass

        return 0.0, 0.0

    @staticmethod
    def clear_cache(device: Optional[str] = None) -> None:
        """
        Clear GPU cache to free memory.

        Args:
            device: Specific device to clear ("cuda", "cuda:<n>" or "mps"),
                or None/empty for every available backend. Any other value
                (e.g. "cpu") clears nothing.

        Raises:
            ValueError: If a CUDA device string has a non-integer index.
        """
        try:
            import torch
        except ImportError:
            return

        clear_all = not device

        if torch.cuda.is_available() and (clear_all or device.startswith("cuda")):
            if clear_all:
                torch.cuda.empty_cache()
            else:
                with torch.cuda.device(DeviceManager._parse_cuda_index(device)):
                    torch.cuda.empty_cache()

        # Only touch the MPS allocator when the backend is usable; on builds
        # without MPS, calling torch.mps.empty_cache() may raise.
        mps_usable = (
            hasattr(torch.backends, "mps")
            and torch.backends.mps.is_available()
            and hasattr(torch, "mps")
            and hasattr(torch.mps, "empty_cache")
        )
        if mps_usable and (clear_all or device == "mps"):
            torch.mps.empty_cache()

    @staticmethod
    def set_default_device(device: str) -> None:
        """
        Set the default device for tensor operations.

        Only CUDA strings have an effect: they select the current CUDA device
        via ``torch.cuda.set_device``. Other devices, or a missing PyTorch,
        are a no-op.

        Args:
            device: Device to set as default

        Raises:
            ValueError: If a CUDA device string has a non-integer index.
        """
        try:
            import torch
        except ImportError:
            return

        if device.startswith("cuda"):
            torch.cuda.set_device(DeviceManager._parse_cuda_index(device))
|
||
|
|
|
||
|
|
|
||
|
|
# Convenience functions
|
||
|
|
def get_best_device() -> str:
    """
    Return the preferred device string for this machine.

    Module-level alias for ``DeviceManager.get_best_device``; it does not
    construct the ``DeviceManager`` singleton.
    """
    best = DeviceManager.get_best_device()
    return best
|
||
|
|
|
||
|
|
|
||
|
|
def get_device_count() -> int:
    """
    Return how many accelerator devices are visible.

    Module-level alias for ``DeviceManager.get_device_count``, which counts
    CUDA GPUs (or reports 1 when only MPS is available). It does not
    construct the ``DeviceManager`` singleton.
    """
    count = DeviceManager.get_device_count()
    return count
|
||
|
|
|
||
|
|
|
||
|
|
def allocate_device(preference: Optional[str] = None) -> str:
    """
    Allocate a device through the shared ``DeviceManager`` instance.

    Args:
        preference: Preferred device string (e.g. "cuda:1"); when it is not
            available the manager falls back to the best available device.

    Returns:
        The device string that was allocated.
    """
    manager = DeviceManager()
    return manager.allocate_device(preference=preference)
|
||
|
|
|
||
|
|
|
||
|
|
def clear_gpu_cache(device: Optional[str] = None) -> None:
    """
    Clear GPU memory cache.

    Args:
        device: Optional device string (e.g. "cuda:1" or "mps") forwarded
            to ``DeviceManager.clear_cache``; None (the default) clears all
            backends, matching the previous zero-argument behavior.
    """
    DeviceManager.clear_cache(device)
|