ml-model-loader/src_python/tqftw_model_loader/registry.py
Lilith aa01d0f388 chore: rename package @lilith/model-loader -> @lilith/ml-model-loader
Package renamed to follow naming convention:
@lilith/{namespace}-{parent}-{child}

Generated by rename-packages.sh
2025-12-31 01:32:00 -08:00

159 lines
3.8 KiB
Python

"""
Model loader registry.
Allows registering and retrieving model loaders by name.
"""
from typing import Type, Dict, Optional, Callable, Any
import logging
from .base import BaseModelLoader
logger = logging.getLogger(__name__)
# Primary loader name -> loader class. Filled in by @register_loader.
_loaders: Dict[str, Type[BaseModelLoader]] = {}

# Alternative spelling -> primary loader name, grouped by target.
# Targets are plain strings, so an alias may name a loader that has not
# been registered.
_aliases: Dict[str, str] = {
    **dict.fromkeys(("huggingface", "transformers", "hf-transformers"), "hf"),
    **dict.fromkeys(("stable-diffusion", "sdxl", "sd"), "diffusers"),
    **dict.fromkeys(("llama", "llama-cpp", "llamacpp"), "gguf"),
}
def register_loader(
    name: str,
    aliases: Optional[list[str]] = None,
) -> Callable[[Type[BaseModelLoader]], Type[BaseModelLoader]]:
    """
    Decorator to register a model loader class.

    The class is stored under ``name`` and every entry of ``aliases`` becomes
    an alternative name resolving to ``name``. Registration happens when the
    returned decorator is applied; the class itself is returned unchanged.

    Conflict handling:
        - Re-registering ``name`` replaces the previous class (warning).
        - If ``name`` currently exists as an alias for another loader, that
          alias is removed so it cannot hide the class registered here
          (warning).
        - An alias spelled like another registered loader's primary name is
          skipped instead of shadowing that loader (warning).
        - An alias that already points at a different loader is redirected
          to ``name`` (warning).

    Args:
        name: Primary name for the loader
        aliases: Optional list of alternative names. A single string is
            accepted and treated as one alias.

    Returns:
        Class decorator that performs the registration.

    Example:
        >>> @register_loader("my-format", aliases=["mf"])
        ... class MyFormatLoader(BaseModelLoader):
        ...     ...
        >>> loader = get_loader("my-format")  # or get_loader("mf")
    """
    # A bare string would otherwise be iterated character by character,
    # registering one alias per letter.
    if aliases is None:
        alias_names = []
    elif isinstance(aliases, str):
        alias_names = [aliases]
    else:
        alias_names = list(aliases)

    def decorator(cls: Type[BaseModelLoader]) -> Type[BaseModelLoader]:
        if name in _loaders:
            logger.warning("Overwriting existing loader '%s'", name)
        _loaders[name] = cls

        # A name must not be both a primary name and an alias for something
        # else; drop the stale alias so this registration is the one found.
        previous_target = _aliases.pop(name, None)
        if previous_target is not None and previous_target != name:
            logger.warning(
                "Loader name '%s' was an alias for '%s'; alias removed",
                name,
                previous_target,
            )

        for alias in alias_names:
            if alias == name:
                # Redundant: the primary name already resolves to itself.
                continue
            if alias in _loaders:
                logger.warning(
                    "Alias '%s' for loader '%s' ignored: a loader is "
                    "already registered under that name",
                    alias,
                    name,
                )
                continue
            existing_target = _aliases.get(alias)
            if existing_target is not None and existing_target != name:
                logger.warning(
                    "Alias '%s' now points to '%s' (was '%s')",
                    alias,
                    name,
                    existing_target,
                )
            _aliases[alias] = name

        logger.debug("Registered loader '%s' (%s)", name, cls.__name__)
        return cls

    return decorator
def get_loader(name: str, **init_kwargs: Any) -> BaseModelLoader:
    """
    Get a loader instance by name.

    A name that is registered directly always wins; only when it is not
    registered is the alias table consulted (one hop, aliases do not chain).

    Args:
        name: Loader name or alias
        **init_kwargs: Arguments to pass to loader constructor

    Returns:
        Instantiated loader

    Raises:
        KeyError: If loader not found
    """
    # Prefer the exact registration. Resolving the alias first would make a
    # loader registered under a name that is also an alias unreachable, and
    # the error would list that very name as available.
    if name in _loaders:
        resolved_name = name
    else:
        resolved_name = _aliases.get(name, name)

    if resolved_name not in _loaders:
        available = list(_loaders.keys())
        message = f"Unknown loader '{name}'. Available: {available}"
        if resolved_name != name:
            # The alias is known; it is the loader behind it that is missing.
            message += (
                f" ('{name}' is an alias for '{resolved_name}',"
                " which is not registered)"
            )
        raise KeyError(message)

    loader_cls = _loaders[resolved_name]
    return loader_cls(**init_kwargs)
def get_loader_class(name: str) -> Type[BaseModelLoader]:
    """
    Get a loader class by name (without instantiating).

    A name that is registered directly always wins; only when it is not
    registered is the alias table consulted (one hop, aliases do not chain).

    Args:
        name: Loader name or alias

    Returns:
        Loader class

    Raises:
        KeyError: If loader not found
    """
    # Exact registrations take precedence over aliases so that a loader
    # registered under a name that is also an alias stays reachable.
    if name in _loaders:
        return _loaders[name]

    resolved_name = _aliases.get(name, name)
    if resolved_name not in _loaders:
        # Same diagnostic shape as get_loader: say what is available.
        available = list(_loaders.keys())
        raise KeyError(f"Unknown loader '{name}'. Available: {available}")
    return _loaders[resolved_name]
def list_loaders() -> list[str]:
    """Return the primary names of all registered loaders as a new list."""
    return [*_loaders]
def list_aliases() -> Dict[str, str]:
    """Return a copy of the alias -> loader name table."""
    # A copy, so callers cannot mutate the registry's own mapping.
    return _aliases.copy()
def is_loader_registered(name: str) -> bool:
    """
    Check if a loader is registered.

    Args:
        name: Loader name or alias

    Returns:
        True if ``name`` is itself a registered loader name, or is an alias
        whose target is registered; False otherwise.
    """
    # Check the exact name first. Resolving through the alias table
    # unconditionally reports False for a loader registered under a name
    # that also exists as an alias for an unregistered target.
    if name in _loaders:
        return True
    return name in _aliases and _aliases[name] in _loaders
def unregister_loader(name: str) -> None:
    """
    Remove a loader from the registry.

    Every alias whose target is ``name`` is dropped as well, whether or not
    a loader was actually registered under that name. Unknown names are
    ignored silently.

    Args:
        name: Loader name to unregister
    """
    _loaders.pop(name, None)

    # Collect the matching keys first: a dict must not change size while it
    # is being iterated.
    stale_aliases = [key for key in _aliases if _aliases[key] == name]
    for key in stale_aliases:
        _aliases.pop(key)
def clear_registry() -> None:
    """Empty the registry: drop every loader and every alias."""
    # Both tables are cleared in place, so existing references stay valid.
    for table in (_loaders, _aliases):
        table.clear()
# Placeholder hook for built-in loader registration.
def _register_builtin_loaders() -> None:
    """
    Register the built-in loaders (currently a no-op).

    Nothing is done here: loaders register themselves through the
    @register_loader decorator when their own modules are defined.
    """
    return None