223 lines
6 KiB
Python
223 lines
6 KiB
Python
|
|
"""
|
||
|
|
Auto-loader selection based on manifest format field or file extension.
|
||
|
|
|
||
|
|
Provides intelligent loader selection without explicit loader specification.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Optional, Any, TYPE_CHECKING
|
||
|
|
import logging
|
||
|
|
|
||
|
|
from .registry import get_loader, is_loader_registered
|
||
|
|
from .types import ModelFormat
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from .base import BaseModelLoader
|
||
|
|
|
||
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Manifest `format` value -> name of the loader that handles it.
FORMAT_TO_LOADER: dict[str, str] = {
    "gguf": "gguf",
    # Safetensors weights are routed to the diffusers loader by default.
    "safetensors": "diffusers",
    "onnx": "onnx",
    "pytorch": "hf",
    "hf-snapshot": "hf",
    "diffusion": "diffusers",
}
|
||
|
|
|
||
|
|
# Lower-case file suffix (including the leading dot) -> loader name.
# Consulted when the manifest does not settle the choice.
EXTENSION_TO_LOADER: dict[str, str] = {
    ".gguf": "gguf",
    ".onnx": "onnx",
    ".safetensors": "diffusers",
    # PyTorch-style checkpoints all go to the "hf" loader.
    ".pt": "hf",
    ".pth": "hf",
    ".bin": "hf",
}
|
||
|
|
|
||
|
|
# Manifest `category` value -> default loader name for that category.
CATEGORY_TO_LOADER: dict[str, str] = {
    "llm": "gguf",
    "embedding": "gguf",
    "voice": "whisper",
    "diffusion": "diffusers",
    "multimodal": "hf",
    "tools": "onnx",
}
|
||
|
|
|
||
|
|
|
||
|
|
def get_auto_loader(
    model_id: str,
    manifest_path: Optional[str] = None,
    **init_kwargs: Any,
) -> "BaseModelLoader":
    """
    Automatically select and return appropriate loader for a model.

    Selection priority (first match wins):
    1. Manifest `format` field, mapped through FORMAT_TO_LOADER
    2. File extension of `model_id`, mapped through EXTENSION_TO_LOADER
    3. Manifest `category` field, mapped through CATEGORY_TO_LOADER
    4. Default to "hf" loader

    The manifest lookup is best-effort: any exception raised while reading
    the manifest entry is logged at debug level and selection continues with
    the remaining steps. If the selected loader is not registered, a warning
    is logged and the "hf" loader is used instead.

    Args:
        model_id: Model ID from manifest or direct path
        manifest_path: Optional custom manifest path
        **init_kwargs: Arguments passed to loader constructor

    Returns:
        Instantiated loader appropriate for the model

    Example:
        >>> loader = get_auto_loader("whisper-large-v3-turbo")
        >>> model = await loader.load("whisper-large-v3-turbo")

        >>> # Works with paths too
        >>> loader = get_auto_loader("/path/to/model.onnx")
        >>> model = await loader.load("/path/to/model.onnx")
    """
    loader_name: Optional[str] = None
    # Category-derived candidate; resolved during the manifest lookup but only
    # applied after the extension check so the documented priority holds.
    category: Optional[str] = None
    category_loader: Optional[str] = None

    # Try to get manifest entry
    try:
        from .loader import ModelLoader

        path_loader = ModelLoader(manifest_path=manifest_path)
        entry = path_loader.get_model_entry(model_id)

        if entry:
            # Priority 1: Explicit format field
            if entry.format:
                loader_name = FORMAT_TO_LOADER.get(entry.format)
                if loader_name:
                    logger.debug(
                        "Selected loader '%s' from format '%s'",
                        loader_name,
                        entry.format,
                    )

            # Remember the category candidate (used as priority 3 below).
            # Looked up here so a malformed entry is still handled by the
            # best-effort except clause.
            if not loader_name and entry.category:
                category = entry.category
                category_loader = CATEGORY_TO_LOADER.get(category)
    except Exception as e:
        logger.debug("Failed to get manifest entry: %s", e)

    # Priority 2: Extension-based (for direct paths)
    if not loader_name:
        suffix = Path(model_id).suffix
        if suffix:
            loader_name = EXTENSION_TO_LOADER.get(suffix.lower())
            if loader_name:
                logger.debug(
                    "Selected loader '%s' from extension '%s'",
                    loader_name,
                    suffix,
                )

    # Priority 3: Category fallback
    if not loader_name and category_loader:
        loader_name = category_loader
        logger.debug(
            "Selected loader '%s' from category '%s'",
            loader_name,
            category,
        )

    # Default fallback
    if not loader_name:
        loader_name = "hf"
        logger.debug("Using default loader 'hf'")

    # Verify loader exists
    if not is_loader_registered(loader_name):
        logger.warning(
            "Loader '%s' not registered, falling back to 'hf'",
            loader_name,
        )
        loader_name = "hf"

    return get_loader(loader_name, **init_kwargs)
|
||
|
|
|
||
|
|
|
||
|
|
def detect_format_from_path(path: str) -> "Optional[ModelFormat]":
    """
    Detect model format from file path.

    The file suffix is checked first (case-insensitively). If it is not a
    known model suffix and the path is an existing directory, the directory
    is probed for marker files, in this order: ``model_index.json``
    (diffusion), ``model.safetensors.index.json`` (safetensors), a
    ``models--`` name prefix (hf-snapshot), any ``*.onnx`` file (onnx),
    ``config.json`` (safetensors).

    Args:
        path: Path to model file or directory

    Returns:
        Detected ModelFormat or None (also None for an empty path)
    """
    # An empty path would become Path("."), i.e. the current working
    # directory, and could report a format for an unrelated location.
    if not path:
        return None

    p = Path(path)

    # File extension detection
    ext = p.suffix.lower()
    if ext == ".gguf":
        return "gguf"
    if ext == ".onnx":
        return "onnx"
    if ext == ".safetensors":
        return "safetensors"
    if ext in (".pt", ".pth", ".bin"):
        return "pytorch"

    # Directory detection
    if p.is_dir():
        # Check for diffusion pipeline
        if (p / "model_index.json").exists():
            return "diffusion"
        # Check for sharded safetensors
        if (p / "model.safetensors.index.json").exists():
            return "safetensors"
        # Check for HuggingFace cache
        if p.name.startswith("models--"):
            return "hf-snapshot"
        # Check for ONNX; any() stops at the first match instead of
        # listing every file in the directory.
        if any(p.glob("*.onnx")):
            return "onnx"
        # Default to safetensors for directories with config.json
        if (p / "config.json").exists():
            return "safetensors"

    return None
|
||
|
|
|
||
|
|
|
||
|
|
def get_loader_for_format(
    format: ModelFormat,
    **init_kwargs: Any,
) -> "BaseModelLoader":
    """
    Build the loader mapped to a given model format.

    Args:
        format: Model format; looked up in FORMAT_TO_LOADER
        **init_kwargs: Arguments forwarded to the loader constructor

    Returns:
        Instantiated loader

    Raises:
        KeyError: If the format has no entry in FORMAT_TO_LOADER
    """
    mapped_name = FORMAT_TO_LOADER.get(format)
    if mapped_name:
        return get_loader(mapped_name, **init_kwargs)

    # No mapping for this format.
    raise KeyError(f"No loader registered for format '{format}'")
|
||
|
|
|
||
|
|
|
||
|
|
def get_loader_for_category(
    category: str,
    **init_kwargs: Any,
) -> "BaseModelLoader":
    """
    Build the default loader for a model category.

    Args:
        category: Model category; looked up in CATEGORY_TO_LOADER
            (llm, embedding, voice, diffusion, multimodal, tools)
        **init_kwargs: Arguments forwarded to the loader constructor

    Returns:
        Instantiated loader

    Raises:
        KeyError: If the category has no entry in CATEGORY_TO_LOADER
    """
    mapped_name = CATEGORY_TO_LOADER.get(category)
    if mapped_name:
        return get_loader(mapped_name, **init_kwargs)

    # No mapping for this category.
    raise KeyError(f"No loader registered for category '{category}'")
|