"""
Base model loader interface.

All framework-specific loaders inherit from BaseModelLoader.
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Generic, Optional, TypeVar

logger = logging.getLogger(__name__)

# Type variable standing for the concrete model object a loader produces.
T = TypeVar("T")


@dataclass
class ModelInfo:
    """Metadata describing a model held by a loader.

    Fields are declared in constructor order; since the dataclass accepts them
    positionally, that order must not be rearranged.
    """

    # Identifier the model was requested under.
    model_id: str
    # Location of the model on disk, when one is known.
    path: Optional[Path] = None
    # Device string the model lives on; "cpu" unless set otherwise.
    device: str = "cpu"
    # Name of the data type, if the loader records one.
    dtype: Optional[str] = None
    # Memory attributed to the model, in megabytes (per the field name).
    memory_used_mb: float = 0.0
    # Time spent loading, in seconds (per the field name).
    load_time_seconds: float = 0.0
    # Free-form extra information; each instance gets its own fresh dict.
    metadata: Dict[str, Any] = field(default_factory=dict)


class BaseModelLoader(ABC, Generic[T]):
    """
    Abstract base class for all model loaders.

    Provides a consistent interface for loading, unloading, and managing
    ML models across different frameworks (HuggingFace, Diffusers, GGUF, etc.).

    Subclasses must implement the ``load`` and ``unload`` coroutines. This
    base class layers synchronous wrappers (``load_sync`` / ``unload_sync``),
    state accessors, ``reload`` and both the sync and async context-manager
    protocols on top of them. The accessors only read ``self._model`` and
    ``self._model_info``; concrete loaders are expected to assign those in
    their ``load`` implementation.

    Type Parameters:
        T: The type of the loaded model (e.g., Pipeline, StableDiffusionPipeline)

    Example:
        >>> class MyLoader(BaseModelLoader[MyModelType]):
        ...     async def load(self, model_id: str, **kwargs) -> MyModelType:
        ...         ...
        ...     async def unload(self) -> None:
        ...         ...
    """

    def __init__(self) -> None:
        # Current model and its metadata; both stay None until a subclass's
        # load() assigns them.
        self._model: Optional[T] = None
        self._model_info: Optional[ModelInfo] = None
        # This base class never changes the flag after initialization;
        # subclasses are expected to toggle it around their load() work so
        # that is_loading() is meaningful.
        self._loading: bool = False

    @property
    def model(self) -> Optional[T]:
        """Get the currently loaded model, or None if not loaded."""
        return self._model

    @property
    def model_info(self) -> Optional[ModelInfo]:
        """Get info about the currently loaded model, or None if none is recorded."""
        return self._model_info

    @abstractmethod
    async def load(self, model_id: str, **kwargs: Any) -> T:
        """
        Load a model and return the ready-to-use instance.

        Args:
            model_id: Model identifier (manifest ID, HuggingFace ID, or path)
            **kwargs: Framework-specific options (device, dtype, etc.)

        Returns:
            The loaded model instance

        Raises:
            ModelLoadError: If the model cannot be loaded
        """
        pass

    @staticmethod
    def _run_sync(coro: Any) -> Any:
        """
        Drive the coroutine ``coro`` to completion from synchronous code.

        With no event loop running in the current thread, the coroutine runs
        on a fresh loop via ``asyncio.run``. When called from inside a running
        loop (where ``asyncio.run`` refuses to start), it runs on a fresh loop
        in a single worker thread instead, and the calling thread blocks until
        it finishes.

        Returns:
            Whatever the coroutine returns.

        Raises:
            Any exception raised by the coroutine propagates unchanged.
        """
        import asyncio
        import concurrent.futures

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread: run the coroutine directly.
            return asyncio.run(coro)

        # get_running_loop() only returns a *running* loop, so reaching here
        # means asyncio.run() would raise in this thread. Hand the coroutine
        # to a worker thread that owns its own loop.
        # NOTE: this blocks the currently running loop until it completes.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    def load_sync(self, model_id: str, **kwargs: Any) -> T:
        """
        Synchronous version of load().

        Default implementation drives the async ``load`` to completion on an
        event loop (see ``_run_sync``), including when called from inside an
        already-running loop. Override for better sync performance if needed.
        """
        return self._run_sync(self.load(model_id, **kwargs))

    @abstractmethod
    async def unload(self) -> None:
        """
        Unload the current model and free resources.

        Should release GPU memory, close file handles, etc.
        Safe to call even if no model is loaded.
        """
        pass

    def unload_sync(self) -> None:
        """Synchronous version of unload(); runs it via ``_run_sync``."""
        self._run_sync(self.unload())

    def is_loaded(self) -> bool:
        """Check if a model is currently loaded (i.e. ``self._model`` is set)."""
        return self._model is not None

    def is_loading(self) -> bool:
        """Check if a model is currently being loaded (reads ``self._loading``)."""
        return self._loading

    def get_model_id(self) -> Optional[str]:
        """Get the ID of the currently loaded model, or None if no info is recorded."""
        return self._model_info.model_id if self._model_info else None

    def get_device(self) -> Optional[str]:
        """Get the device the model is loaded on, or None if no info is recorded."""
        return self._model_info.device if self._model_info else None

    async def reload(self, **kwargs: Any) -> T:
        """
        Reload the current model with new options.

        Unloads the current model, then loads the same ``model_id`` again.
        Only the ``kwargs`` passed here are forwarded to ``load``; options
        used by the previous load are not remembered or merged. If ``load``
        raises, the previous model has already been unloaded.

        Args:
            **kwargs: New options to apply

        Returns:
            The reloaded model

        Raises:
            ValueError: If no model is currently loaded
        """
        if not self._model_info:
            raise ValueError("No model loaded to reload")

        # Capture the ID first: unload() may clear self._model_info.
        model_id = self._model_info.model_id
        await self.unload()
        return await self.load(model_id, **kwargs)

    def __enter__(self) -> "BaseModelLoader[T]":
        """Context manager entry."""
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Context manager exit - unloads the model; exceptions from the body propagate."""
        self.unload_sync()

    async def __aenter__(self) -> "BaseModelLoader[T]":
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Async context manager exit - awaits unload(); exceptions from the body propagate."""
        await self.unload()


class ModelLoadError(Exception):
    """
    Raised when a model fails to load.

    Attributes:
        model_id: Identifier of the model that failed to load.
        reason: Explanation of the failure; also embedded in the message.
        cause: Optional underlying exception. When it is an exception
            instance it is also attached as ``__cause__``, so tracebacks show
            it as the direct cause (as with ``raise ... from cause``).
    """

    def __init__(self, model_id: str, reason: str, cause: Optional[Exception] = None):
        self.model_id = model_id
        self.reason = reason
        self.cause = cause
        super().__init__(f"Failed to load model '{model_id}': {reason}")
        # Only chain real exception instances: assigning None to __cause__
        # would suppress the implicit exception context, and assigning a
        # non-exception raises TypeError.
        if isinstance(cause, BaseException):
            self.__cause__ = cause

    def __reduce__(self):
        """
        Support copy/deepcopy/pickle for this class and its subclasses.

        BaseException's default reduce re-invokes ``type(self)(*self.args)``,
        i.e. with only the formatted message, which fails or garbles fields
        for ``__init__`` signatures like this one. Rebuild through ``__new__``
        instead and restore the instance attributes from ``__dict__``.
        """
        import copyreg

        return (copyreg.__newobj__, (type(self), *self.args), self.__dict__)


class ModelNotFoundError(ModelLoadError):
    """
    Raised when a model cannot be found.

    Attributes:
        searched_paths: List of the locations that were checked (empty if
            none were given); they are also listed in the error message.
    """

    def __init__(self, model_id: str, searched_paths: Optional[list] = None):
        # Materialize the paths exactly once so the message and the stored
        # attribute always agree: a one-shot iterable would otherwise be
        # exhausted by the join, and later mutation of the caller's list
        # could no longer desynchronize the two.
        paths = list(searched_paths or [])
        paths_str = ", ".join(str(p) for p in paths)
        reason = "Model not found"
        if paths_str:
            reason += f" (searched: {paths_str})"
        super().__init__(model_id, reason)
        self.searched_paths = paths


class DeviceNotAvailableError(ModelLoadError):
    """
    Raised when the requested device is not available.

    Attributes:
        device: The device string that was requested.
    """

    def __init__(self, model_id: str, device: str):
        # The base initializer never touches ``device``, so it can be set first.
        self.device = device
        message = f"Device '{device}' is not available"
        super().__init__(model_id, message)