"""
Model loader registry.

Allows registering and retrieving model loaders by name.
"""

from typing import Type, Dict, Optional, Callable, Any
import logging

from .base import BaseModelLoader

logger = logging.getLogger(__name__)

# Global registry of loaders: primary loader name -> loader class.
_loaders: Dict[str, Type[BaseModelLoader]] = {}

# Aliases for common names: alias -> primary loader name.
# An alias only resolves to a loader once its target name has been registered
# (presumably by the built-in loader modules via @register_loader).
_aliases: Dict[str, str] = {
    "huggingface": "hf",
    "transformers": "hf",
    "hf-transformers": "hf",
    "stable-diffusion": "diffusers",
    "sdxl": "diffusers",
    "sd": "diffusers",
    "llama": "gguf",
    "llama-cpp": "gguf",
    "llamacpp": "gguf",
}


def _resolve_name(name: str) -> str:
    """
    Map a loader name or alias to the key used in ``_loaders``.

    An exactly registered loader name takes precedence over an alias with the
    same spelling. Otherwise a loader registered as e.g. ``"sd"`` would be
    shadowed by the default ``"sd" -> "diffusers"`` alias and be unreachable.

    Args:
        name: Loader name or alias

    Returns:
        The registered name if ``name`` is registered, else the alias target
        if ``name`` is an alias, else ``name`` unchanged.
    """
    if name in _loaders:
        return name
    return _aliases.get(name, name)


def register_loader(
    name: str,
    aliases: Optional[list[str]] = None,
) -> Callable[[Type[BaseModelLoader]], Type[BaseModelLoader]]:
    """
    Decorator to register a model loader class.

    Registering a name that already exists replaces the previous class (with a
    warning). Each alias is pointed at ``name``. If an alias previously pointed
    at a different loader, it is re-pointed, also with a warning.

    Args:
        name: Primary name for the loader
        aliases: Optional list of alternative names

    Returns:
        A decorator that registers the class and returns it unchanged.

    Example:
        >>> @register_loader("my-format", aliases=["mf"])
        ... class MyFormatLoader(BaseModelLoader):
        ...     ...
        >>> loader = get_loader("my-format")  # or get_loader("mf")
    """

    def decorator(cls: Type[BaseModelLoader]) -> Type[BaseModelLoader]:
        if name in _loaders:
            logger.warning("Overwriting existing loader '%s'", name)
        _loaders[name] = cls

        # Register aliases. Silently re-pointing an alias at another loader
        # would be as surprising as silently overwriting a loader, so warn.
        for alias in aliases or []:
            previous = _aliases.get(alias)
            if previous is not None and previous != name:
                logger.warning(
                    "Alias '%s' now points to loader '%s' (was '%s')",
                    alias,
                    name,
                    previous,
                )
            _aliases[alias] = name

        logger.debug("Registered loader '%s' (%s)", name, cls.__name__)
        return cls

    return decorator


def get_loader(name: str, **init_kwargs: Any) -> BaseModelLoader:
    """
    Get a loader instance by name.

    Args:
        name: Loader name or alias
        **init_kwargs: Arguments to pass to loader constructor

    Returns:
        Instantiated loader

    Raises:
        KeyError: If loader not found
    """
    loader_cls = get_loader_class(name)
    return loader_cls(**init_kwargs)


def get_loader_class(name: str) -> Type[BaseModelLoader]:
    """
    Get a loader class by name (without instantiating).

    Args:
        name: Loader name or alias

    Returns:
        Loader class

    Raises:
        KeyError: If loader not found; the message lists the registered
            loader names.
    """
    resolved_name = _resolve_name(name)
    try:
        return _loaders[resolved_name]
    except KeyError:
        available = list(_loaders.keys())
        # Report the name the caller used, not the resolved alias target.
        raise KeyError(
            f"Unknown loader '{name}'. Available: {available}"
        ) from None


def list_loaders() -> list[str]:
    """Get list of registered loader names (primary names only, no aliases)."""
    return list(_loaders.keys())


def list_aliases() -> Dict[str, str]:
    """
    Get a copy of the alias -> loader name mapping.

    The mapping may include aliases whose target loader is not registered.
    """
    return dict(_aliases)


def is_loader_registered(name: str) -> bool:
    """Check whether a loader name or alias resolves to a registered loader."""
    return _resolve_name(name) in _loaders


def unregister_loader(name: str) -> None:
    """
    Unregister a loader.

    Also removes any aliases pointing to it. This includes matching default
    aliases, which are not restored if the loader is registered again later.
    Does nothing if ``name`` is not a registered loader name.

    Args:
        name: Loader name to unregister (primary name, not an alias)
    """
    if name not in _loaders:
        return
    del _loaders[name]

    # Build the list first: deleting while iterating the dict would raise.
    aliases_to_remove = [alias for alias, target in _aliases.items() if target == name]
    for alias in aliases_to_remove:
        del _aliases[alias]


def clear_registry() -> None:
    """
    Clear all registered loaders and aliases.

    The default aliases defined in this module are removed too and are not
    restored by later registrations.
    """
    _loaders.clear()
    _aliases.clear()


def _register_builtin_loaders() -> None:
    """
    Placeholder hook for registering the built-in loaders.

    Currently a no-op and not called anywhere in this module. Built-in loaders
    register themselves via the @register_loader decorator when their modules
    are imported.
    """
    pass