""" Base model loader interface. All framework-specific loaders inherit from BaseModelLoader. """ from abc import ABC, abstractmethod from dataclasses import dataclass, field from pathlib import Path from typing import TypeVar, Generic, Optional, Any, Dict import logging logger = logging.getLogger(__name__) T = TypeVar("T") # Model type @dataclass class ModelInfo: """Information about a loaded model.""" model_id: str path: Optional[Path] = None device: str = "cpu" dtype: Optional[str] = None memory_used_mb: float = 0.0 load_time_seconds: float = 0.0 metadata: Dict[str, Any] = field(default_factory=dict) class BaseModelLoader(ABC, Generic[T]): """ Abstract base class for all model loaders. Provides a consistent interface for loading, unloading, and managing ML models across different frameworks (HuggingFace, Diffusers, GGUF, etc.). Type Parameters: T: The type of the loaded model (e.g., Pipeline, StableDiffusionPipeline) Example: >>> class MyLoader(BaseModelLoader[MyModelType]): ... async def load(self, model_id: str, **kwargs) -> MyModelType: ... ... """ def __init__(self) -> None: self._model: Optional[T] = None self._model_info: Optional[ModelInfo] = None self._loading: bool = False @property def model(self) -> Optional[T]: """Get the currently loaded model, or None if not loaded.""" return self._model @property def model_info(self) -> Optional[ModelInfo]: """Get info about the currently loaded model.""" return self._model_info @abstractmethod async def load(self, model_id: str, **kwargs: Any) -> T: """ Load a model and return the ready-to-use instance. Args: model_id: Model identifier (manifest ID, HuggingFace ID, or path) **kwargs: Framework-specific options (device, dtype, etc.) Returns: The loaded model instance Raises: ModelLoadError: If the model cannot be loaded """ pass def load_sync(self, model_id: str, **kwargs: Any) -> T: """ Synchronous version of load(). Default implementation runs the async load in an event loop. Override for better sync performance if needed. """ import asyncio try: loop = asyncio.get_running_loop() except RuntimeError: loop = None if loop and loop.is_running(): # We're in an async context, can't use run_until_complete import concurrent.futures with concurrent.futures.ThreadPoolExecutor() as pool: future = pool.submit(asyncio.run, self.load(model_id, **kwargs)) return future.result() else: return asyncio.run(self.load(model_id, **kwargs)) @abstractmethod async def unload(self) -> None: """ Unload the current model and free resources. Should release GPU memory, close file handles, etc. Safe to call even if no model is loaded. """ pass def unload_sync(self) -> None: """Synchronous version of unload().""" import asyncio try: loop = asyncio.get_running_loop() except RuntimeError: loop = None if loop and loop.is_running(): import concurrent.futures with concurrent.futures.ThreadPoolExecutor() as pool: future = pool.submit(asyncio.run, self.unload()) future.result() else: asyncio.run(self.unload()) def is_loaded(self) -> bool: """Check if a model is currently loaded.""" return self._model is not None def is_loading(self) -> bool: """Check if a model is currently being loaded.""" return self._loading def get_model_id(self) -> Optional[str]: """Get the ID of the currently loaded model.""" return self._model_info.model_id if self._model_info else None def get_device(self) -> Optional[str]: """Get the device the model is loaded on.""" return self._model_info.device if self._model_info else None async def reload(self, **kwargs: Any) -> T: """ Reload the current model with new options. Args: **kwargs: New options to apply Returns: The reloaded model Raises: ValueError: If no model is currently loaded """ if not self._model_info: raise ValueError("No model loaded to reload") model_id = self._model_info.model_id await self.unload() return await self.load(model_id, **kwargs) def __enter__(self) -> "BaseModelLoader[T]": """Context manager entry.""" return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: """Context manager exit - ensures model is unloaded.""" self.unload_sync() async def __aenter__(self) -> "BaseModelLoader[T]": """Async context manager entry.""" return self async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: """Async context manager exit.""" await self.unload() class ModelLoadError(Exception): """Raised when a model fails to load.""" def __init__(self, model_id: str, reason: str, cause: Optional[Exception] = None): self.model_id = model_id self.reason = reason self.cause = cause super().__init__(f"Failed to load model '{model_id}': {reason}") class ModelNotFoundError(ModelLoadError): """Raised when a model cannot be found.""" def __init__(self, model_id: str, searched_paths: Optional[list] = None): paths_str = ", ".join(str(p) for p in (searched_paths or [])) reason = f"Model not found" + (f" (searched: {paths_str})" if paths_str else "") super().__init__(model_id, reason) self.searched_paths = searched_paths or [] class DeviceNotAvailableError(ModelLoadError): """Raised when the requested device is not available.""" def __init__(self, model_id: str, device: str): super().__init__(model_id, f"Device '{device}' is not available") self.device = device