ml-model-loader/src_python/tqftw_model_loader/registry.py

160 lines
3.8 KiB
Python
Raw Permalink Normal View History

2025-12-28 04:32:35 -08:00
"""
Model loader registry.
Allows registering and retrieving model loaders by name.
"""
from typing import Type, Dict, Optional, Callable, Any
import logging
from .base import BaseModelLoader
logger = logging.getLogger(__name__)

# Primary loader name -> loader class. Populated by @register_loader.
_loaders: Dict[str, Type[BaseModelLoader]] = {}

# Alias -> primary loader name. Seeded here with common alternative
# spellings grouped by the loader they point at; register_loader() adds
# further aliases at registration time.
_aliases: Dict[str, str] = {
    alias: target
    for target, alias_group in (
        ("hf", ("huggingface", "transformers", "hf-transformers")),
        ("diffusers", ("stable-diffusion", "sdxl", "sd")),
        ("gguf", ("llama", "llama-cpp", "llamacpp")),
    )
    for alias in alias_group
}
def register_loader(
    name: str,
    aliases: Optional[list[str]] = None,
) -> Callable[[Type[BaseModelLoader]], Type[BaseModelLoader]]:
    """
    Decorator to register a model loader class.

    Args:
        name: Primary name for the loader
        aliases: Optional list of alternative names

    Returns:
        A class decorator that stores the class under ``name`` (and maps each
        alias to ``name``) and returns the class unchanged.

    Note:
        Lookups resolve aliases before primary names, so the most recent
        registration wins on a collision. If ``name`` is currently an alias,
        that alias is dropped so the new loader is reachable by its own name.
        If an alias replaces another mapping or hides an already-registered
        loader of the same name, a warning is logged.

    Example:
        >>> @register_loader("my-format", aliases=["mf"])
        ... class MyFormatLoader(BaseModelLoader):
        ...     ...
        >>> loader = get_loader("my-format")  # or get_loader("mf")
    """
    def decorator(cls: Type[BaseModelLoader]) -> Type[BaseModelLoader]:
        if name in _loaders:
            logger.warning("Overwriting existing loader '%s'", name)
        _loaders[name] = cls

        # A pre-existing alias spelled like `name` (e.g. the default
        # "llama" -> "gguf") would otherwise redirect get_loader(name)
        # away from the class just registered.
        shadowed_target = _aliases.pop(name, None)
        if shadowed_target is not None and shadowed_target != name:
            logger.warning(
                "Removed alias '%s' -> '%s'; '%s' is now a registered loader",
                name, shadowed_target, name,
            )

        # Register aliases
        for alias in aliases or []:
            previous = _aliases.get(alias)
            if previous is not None and previous != name:
                logger.warning(
                    "Reassigning alias '%s' from '%s' to '%s'",
                    alias, previous, name,
                )
            elif alias != name and alias in _loaders:
                logger.warning(
                    "Alias '%s' for '%s' hides the registered loader '%s'",
                    alias, name, alias,
                )
            _aliases[alias] = name

        logger.debug("Registered loader '%s' (%s)", name, cls.__name__)
        return cls
    return decorator
def get_loader(name: str, **init_kwargs: Any) -> BaseModelLoader:
    """
    Instantiate the loader registered under ``name``.

    Args:
        name: Primary loader name or one of its aliases
        **init_kwargs: Keyword arguments forwarded to the loader constructor

    Returns:
        A freshly constructed instance of the resolved loader class

    Raises:
        KeyError: If ``name`` resolves to no registered loader; the message
            lists the registered loader names.
    """
    # An alias maps to a primary name; anything else is tried as-is.
    target = _aliases.get(name, name)
    if target in _loaders:
        return _loaders[target](**init_kwargs)
    raise KeyError(
        f"Unknown loader '{name}'. Available: {list(_loaders)}"
    )
def get_loader_class(name: str) -> Type[BaseModelLoader]:
    """
    Get a loader class by name (without instantiating).

    Args:
        name: Loader name or alias

    Returns:
        The loader class registered under the resolved name

    Raises:
        KeyError: If loader not found. The message lists the registered
            loader names, in the same form ``get_loader`` uses.
    """
    resolved_name = _aliases.get(name, name)
    if resolved_name not in _loaders:
        # Listing what *is* registered makes typos/aliases easy to diagnose.
        available = list(_loaders.keys())
        raise KeyError(f"Unknown loader '{name}'. Available: {available}")
    return _loaders[resolved_name]
def list_loaders() -> list[str]:
    """Return the primary names of all registered loaders, in registration order."""
    return [*_loaders]
def list_aliases() -> Dict[str, str]:
    """Return a shallow copy of the alias -> primary loader name mapping."""
    # A copy, so callers cannot mutate the registry through the result.
    return _aliases.copy()
def is_loader_registered(name: str) -> bool:
    """Return True if ``name`` (a primary name or alias) resolves to a registered loader."""
    return _aliases.get(name, name) in _loaders
def unregister_loader(name: str) -> None:
    """
    Remove a loader from the registry.

    Every alias that points at ``name`` is removed as well, even if no
    loader was registered under ``name``. Unknown names are ignored.

    Args:
        name: Primary loader name to remove
    """
    _loaders.pop(name, None)
    # Gather the matching aliases first; deleting keys while iterating
    # the dict itself would raise RuntimeError.
    for stale_alias in [a for a, target in _aliases.items() if target == name]:
        del _aliases[stale_alias]
def clear_registry() -> None:
    """Empty the registry: drop every loader and every alias, including the default aliases."""
    for table in (_loaders, _aliases):
        table.clear()
# Hook for built-in loader registration. Its body is currently a no-op;
# per the original note, built-in loaders register themselves through the
# @register_loader decorator rather than through this function.
def _register_builtin_loaders() -> None:
    """Placeholder for registering built-in loaders; intentionally does nothing."""
    return None