ml-model-loader/src_python/tqftw_model_loader/base.py
Lilith aa01d0f388 chore: rename package @lilith/model-loader -> @lilith/ml-model-loader
Package renamed to follow naming convention:
@lilith/{namespace}-{parent}-{child}

Generated by rename-packages.sh
2025-12-31 01:32:00 -08:00

209 lines
6.2 KiB
Python

"""
Base model loader interface.
All framework-specific loaders inherit from BaseModelLoader.
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Coroutine, Dict, Generic, Optional, TypeVar
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Type variable for the concrete model type a loader produces;
# BaseModelLoader is generic over it.
T = TypeVar("T")  # Model type
@dataclass
class ModelInfo:
    """
    Bookkeeping record describing a model that a loader has loaded.

    Attributes:
        model_id: Identifier the model was loaded under.
        path: Filesystem location of the model, or None when not known.
        device: Device string the model was placed on (defaults to "cpu").
        dtype: Data-type name as a string, or None when not recorded.
        memory_used_mb: Memory attributed to the model (MB, per the name).
        load_time_seconds: Time the load took (seconds, per the name).
        metadata: Free-form extra information; each instance gets its own dict.
    """

    # Required identifier -- the only field without a default.
    model_id: str

    # Optional descriptive fields.
    path: Optional[Path] = None
    device: str = "cpu"
    dtype: Optional[str] = None

    # Load statistics, zero until a loader fills them in.
    memory_used_mb: float = 0.0
    load_time_seconds: float = 0.0

    # default_factory so instances never share one mutable dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
class BaseModelLoader(ABC, Generic[T]):
    """
    Abstract base class for all model loaders.

    Provides a consistent interface for loading, unloading, and managing
    ML models across different frameworks (HuggingFace, Diffusers, GGUF, etc.).
    Subclasses implement the async ``load()`` and ``unload()`` methods; this
    class supplies blocking wrappers (``load_sync`` / ``unload_sync``), state
    queries, ``reload()``, and sync/async context-manager support.

    Type Parameters:
        T: The type of the loaded model (e.g., Pipeline, StableDiffusionPipeline)

    Example::

        class MyLoader(BaseModelLoader[MyModelType]):
            async def load(self, model_id: str, **kwargs) -> MyModelType:
                ...

            async def unload(self) -> None:
                ...
    """

    def __init__(self) -> None:
        # After __init__ this base class only *reads* these fields, so
        # subclass load()/unload() implementations presumably keep them in
        # sync with the real model state (NOTE(review): confirm in subclasses).
        self._model: Optional[T] = None
        self._model_info: Optional[ModelInfo] = None
        self._loading: bool = False

    @property
    def model(self) -> Optional[T]:
        """Get the currently loaded model, or None if not loaded."""
        return self._model

    @property
    def model_info(self) -> Optional[ModelInfo]:
        """Get info about the currently loaded model."""
        return self._model_info

    @abstractmethod
    async def load(self, model_id: str, **kwargs: Any) -> T:
        """
        Load a model and return the ready-to-use instance.

        Args:
            model_id: Model identifier (manifest ID, HuggingFace ID, or path)
            **kwargs: Framework-specific options (device, dtype, etc.)

        Returns:
            The loaded model instance

        Raises:
            ModelLoadError: If the model cannot be loaded
        """

    @staticmethod
    def _run_blocking(coro: Coroutine[Any, Any, Any]) -> Any:
        """
        Drive ``coro`` to completion from synchronous code and return its result.

        With no event loop running in the current thread, ``asyncio.run`` is
        used directly. If a loop *is* running (a sync method was called from
        async code), ``asyncio.run`` cannot be used in this thread, so the
        coroutine runs on a fresh loop in a one-off worker thread instead.
        The calling thread -- and therefore its event loop -- blocks until the
        worker finishes, so ``coro`` must not depend on objects bound to the
        caller's loop.
        """
        import asyncio
        import concurrent.futures

        try:
            # Only returns when a loop is actively running in this thread.
            asyncio.get_running_loop()
        except RuntimeError:
            inside_running_loop = False
        else:
            inside_running_loop = True

        # Run outside the except clause so failures from the coroutine are not
        # chained onto the "no running event loop" RuntimeError.
        if not inside_running_loop:
            return asyncio.run(coro)

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    def load_sync(self, model_id: str, **kwargs: Any) -> T:
        """
        Synchronous version of load().

        Default implementation runs the async load in an event loop (in a
        worker thread when called from inside a running loop, blocking that
        loop until loading completes). Override for better sync performance
        if needed.
        """
        return self._run_blocking(self.load(model_id, **kwargs))

    @abstractmethod
    async def unload(self) -> None:
        """
        Unload the current model and free resources.

        Should release GPU memory, close file handles, etc.
        Safe to call even if no model is loaded.
        """

    def unload_sync(self) -> None:
        """Synchronous version of unload(); same event-loop handling as load_sync()."""
        self._run_blocking(self.unload())

    def is_loaded(self) -> bool:
        """Check if a model is currently loaded (i.e. ``self._model`` is set)."""
        return self._model is not None

    def is_loading(self) -> bool:
        """Check if a model is currently being loaded (reports ``self._loading``)."""
        return self._loading

    def get_model_id(self) -> Optional[str]:
        """Get the ID of the currently loaded model, or None without model info."""
        return self._model_info.model_id if self._model_info else None

    def get_device(self) -> Optional[str]:
        """Get the device the model is loaded on, or None without model info."""
        return self._model_info.device if self._model_info else None

    async def reload(self, **kwargs: Any) -> T:
        """
        Reload the current model with new options.

        The model ID is taken from the current ``model_info``. Only the
        ``kwargs`` given here are passed to ``load()``; options used for the
        previous load are not remembered by this base class. If ``load()``
        fails, the previous model has already been unloaded.

        Args:
            **kwargs: New options to apply

        Returns:
            The reloaded model

        Raises:
            ValueError: If no model is currently loaded
        """
        if not self._model_info:
            raise ValueError("No model loaded to reload")
        model_id = self._model_info.model_id
        await self.unload()
        return await self.load(model_id, **kwargs)

    def __enter__(self) -> "BaseModelLoader[T]":
        """Context manager entry; returns the loader itself."""
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """
        Context manager exit - ensures model is unloaded via unload_sync().

        Exceptions raised in the ``with`` body are not suppressed.
        """
        self.unload_sync()

    async def __aenter__(self) -> "BaseModelLoader[T]":
        """Async context manager entry; returns the loader itself."""
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Async context manager exit - awaits unload(); exceptions propagate."""
        await self.unload()
class ModelLoadError(Exception):
    """
    Raised when a model fails to load.

    Attributes:
        model_id: Identifier of the model that failed to load.
        reason: Human-readable explanation of the failure.
        cause: Underlying exception, if any. When it is an exception it is also
            recorded as ``__cause__`` so tracebacks show it even when the error
            is raised without ``raise ... from ...``.
    """

    def __init__(
        self, model_id: str, reason: str, cause: Optional[Exception] = None
    ) -> None:
        self.model_id = model_id
        self.reason = reason
        self.cause = cause
        super().__init__(f"Failed to load model '{model_id}': {reason}")
        if isinstance(cause, BaseException):
            # Equivalent to ``raise ... from cause``; an explicit
            # ``raise ... from X`` at the raise site still takes precedence.
            self.__cause__ = cause

    def __reduce__(self) -> tuple:
        # BaseException's default reduce re-invokes ``type(self)(*self.args)``
        # with args == (message,), which does not match this class's (or its
        # subclasses') __init__ signature, so pickle/copy would raise or
        # mis-populate fields. Rebuild through __new__ (restoring args) and
        # then restore the instance dict, bypassing __init__ for every subclass.
        import copyreg

        return (copyreg.__newobj__, (type(self), *self.args), self.__dict__)
class ModelNotFoundError(ModelLoadError):
    """
    Raised when a model cannot be found.

    Attributes:
        searched_paths: The locations that were checked, as a list
            (empty when none were supplied).
    """

    def __init__(self, model_id: str, searched_paths: Optional[list] = None):
        # Materialise once so a one-shot iterable (e.g. a generator) is not
        # consumed while building the message and then stored exhausted.
        paths = list(searched_paths) if searched_paths else []
        paths_str = ", ".join(str(p) for p in paths)
        reason = "Model not found"
        if paths_str:
            reason += f" (searched: {paths_str})"
        super().__init__(model_id, reason)
        self.searched_paths = paths
class DeviceNotAvailableError(ModelLoadError):
    """
    Raised when the requested device is not available.

    Attributes:
        device: The device string that was requested.
    """

    def __init__(self, model_id: str, device: str):
        # Build the human-readable reason first, then delegate to the base.
        message = f"Device '{device}' is not available"
        super().__init__(model_id, message)
        self.device = device