ml-model-loader/src_python/tqftw_model_loader/auto.py

223 lines
6 KiB
Python
Raw Permalink Normal View History

"""
Auto-loader selection based on manifest format field or file extension.
Provides intelligent loader selection without explicit loader specification.
"""
from pathlib import Path
from typing import Optional, Any, TYPE_CHECKING
import logging
from .registry import get_loader, is_loader_registered
from .types import ModelFormat
if TYPE_CHECKING:
from .base import BaseModelLoader
logger = logging.getLogger(__name__)
# Manifest `format` value -> name of the loader passed to `get_loader`.
FORMAT_TO_LOADER: dict[str, str] = {
    "gguf": "gguf",
    # Plain safetensors weights default to the diffusers loader.
    "safetensors": "diffusers",
    "onnx": "onnx",
    "pytorch": "hf",
    "hf-snapshot": "hf",
    "diffusion": "diffusers",
}
# File suffix -> loader name, used when the manifest gives no usable format.
# Keys are lower-case and include the leading dot; `get_auto_loader` lower-cases
# the path suffix before looking it up here.
EXTENSION_TO_LOADER: dict[str, str] = {
    ".gguf": "gguf",
    ".onnx": "onnx",
    ".safetensors": "diffusers",
    # PyTorch-style checkpoint suffixes all go to the "hf" loader.
    ".pt": "hf",
    ".pth": "hf",
    ".bin": "hf",
}
# Manifest `category` value -> default loader name for that kind of model.
CATEGORY_TO_LOADER: dict[str, str] = {
    "llm": "gguf",
    "embedding": "gguf",
    "voice": "whisper",
    "diffusion": "diffusers",
    "multimodal": "hf",
    "tools": "onnx",
}
def get_auto_loader(
    model_id: str,
    manifest_path: Optional[str] = None,
    **init_kwargs: Any,
) -> "BaseModelLoader":
    """
    Automatically select and return appropriate loader for a model.

    Selection priority:
    1. Manifest `format` field (if model in manifest)
    2. File extension (for direct paths)
    3. Manifest `category` field (fallback)
    4. Default to "hf" loader

    The manifest entry is looked up with
    `ModelLoader(manifest_path=manifest_path).get_model_entry(model_id)`.
    Any exception raised during that lookup is logged at debug level and
    treated as "no manifest entry", so selection continues with the
    remaining steps.

    If the selected loader name is not registered, a warning is logged and
    the "hf" loader is used instead.

    Args:
        model_id: Model ID from manifest or direct path
        manifest_path: Optional custom manifest path
        **init_kwargs: Arguments passed to loader constructor

    Returns:
        Instantiated loader appropriate for the model

    Example:
        >>> loader = get_auto_loader("whisper-large-v3-turbo")
        >>> model = await loader.load("whisper-large-v3-turbo")
        >>> # Works with paths too
        >>> loader = get_auto_loader("/path/to/model.onnx")
        >>> model = await loader.load("/path/to/model.onnx")
    """
    loader_name: Optional[str] = None
    # The category result is held back until after the extension lookup, so
    # the documented order (format > extension > category) is honoured.
    category: Optional[str] = None
    category_loader: Optional[str] = None

    # Manifest lookup is best-effort: a missing or broken manifest must not
    # stop extension-based or default selection.
    try:
        # Imported lazily, as in the original code.
        from .loader import ModelLoader

        path_loader = ModelLoader(manifest_path=manifest_path)
        entry = path_loader.get_model_entry(model_id)
        if entry:
            # Priority 1: explicit format field
            if entry.format:
                loader_name = FORMAT_TO_LOADER.get(entry.format)
                if loader_name:
                    logger.debug(
                        "Selected loader '%s' from format '%s'",
                        loader_name,
                        entry.format,
                    )
            # Remember the category candidate for priority 3.
            if not loader_name and entry.category:
                category = entry.category
                category_loader = CATEGORY_TO_LOADER.get(category)
    except Exception as e:
        logger.debug("Failed to get manifest entry: %s", e)

    # Priority 2: extension-based (for direct paths)
    if not loader_name:
        suffix = Path(model_id).suffix
        if suffix:
            loader_name = EXTENSION_TO_LOADER.get(suffix.lower())
            if loader_name:
                logger.debug(
                    "Selected loader '%s' from extension '%s'",
                    loader_name,
                    suffix,
                )

    # Priority 3: category fallback
    if not loader_name and category_loader:
        loader_name = category_loader
        logger.debug(
            "Selected loader '%s' from category '%s'",
            loader_name,
            category,
        )

    # Priority 4: default fallback
    if not loader_name:
        loader_name = "hf"
        logger.debug("Using default loader 'hf'")

    # Verify the loader exists; the mapping tables may name an unregistered one.
    if not is_loader_registered(loader_name):
        logger.warning(
            "Loader '%s' not registered, falling back to 'hf'",
            loader_name,
        )
        loader_name = "hf"

    return get_loader(loader_name, **init_kwargs)
def detect_format_from_path(path: str) -> "Optional[ModelFormat]":
    """
    Detect model format from file path.

    The file suffix is checked first (case-insensitively) and wins whenever
    it is recognised. Otherwise, if the path is an existing directory, its
    contents and name are inspected in this order:

    1. `model_index.json` present            -> "diffusion"
    2. `model.safetensors.index.json` present -> "safetensors"
    3. directory name starts with `models--`  -> "hf-snapshot"
    4. at least one `*.onnx` entry directly inside -> "onnx"
    5. `config.json` present                  -> "safetensors"

    Args:
        path: Path to model file or directory

    Returns:
        Detected ModelFormat or None
    """
    p = Path(path)

    # File extension detection
    suffix_formats = {
        ".gguf": "gguf",
        ".onnx": "onnx",
        ".safetensors": "safetensors",
        ".pt": "pytorch",
        ".pth": "pytorch",
        ".bin": "pytorch",
    }
    detected = suffix_formats.get(p.suffix.lower())
    if detected is not None:
        return detected

    # Everything below inspects directory contents.
    if not p.is_dir():
        return None

    # Check for diffusion pipeline
    if (p / "model_index.json").exists():
        return "diffusion"
    # Check for sharded safetensors
    if (p / "model.safetensors.index.json").exists():
        return "safetensors"
    # Check for HuggingFace cache
    if p.name.startswith("models--"):
        return "hf-snapshot"
    # Check for ONNX: stop at the first match instead of listing every one.
    if next(p.glob("*.onnx"), None) is not None:
        return "onnx"
    # Default to safetensors for directories with config.json
    if (p / "config.json").exists():
        return "safetensors"
    return None
def get_loader_for_format(
    format: ModelFormat,
    **init_kwargs: Any,
) -> "BaseModelLoader":
    """
    Build the loader mapped to a model format.

    The format is translated to a loader name through `FORMAT_TO_LOADER`,
    and that name is handed to `get_loader` together with `init_kwargs`.

    Args:
        format: Model format to look up in `FORMAT_TO_LOADER`
        **init_kwargs: Arguments forwarded to `get_loader`

    Returns:
        Whatever `get_loader` returns for the mapped loader name

    Raises:
        KeyError: If `FORMAT_TO_LOADER` has no entry for the format
    """
    mapped_name = FORMAT_TO_LOADER.get(format)
    if mapped_name:
        return get_loader(mapped_name, **init_kwargs)
    raise KeyError(f"No loader registered for format '{format}'")
def get_loader_for_category(
    category: str,
    **init_kwargs: Any,
) -> "BaseModelLoader":
    """
    Build the default loader for a model category.

    The category is translated to a loader name through `CATEGORY_TO_LOADER`,
    and that name is handed to `get_loader` together with `init_kwargs`.

    Args:
        category: Model category; a key of `CATEGORY_TO_LOADER`
            (llm, embedding, voice, diffusion, multimodal, tools)
        **init_kwargs: Arguments forwarded to `get_loader`

    Returns:
        Whatever `get_loader` returns for the mapped loader name

    Raises:
        KeyError: If `CATEGORY_TO_LOADER` has no entry for the category
    """
    mapped_name = CATEGORY_TO_LOADER.get(category)
    if mapped_name:
        return get_loader(mapped_name, **init_kwargs)
    raise KeyError(f"No loader registered for category '{category}'")