""" Auto-loader selection based on manifest format field or file extension. Provides intelligent loader selection without explicit loader specification. """ from pathlib import Path from typing import Optional, Any, TYPE_CHECKING import logging from .registry import get_loader, is_loader_registered from .types import ModelFormat if TYPE_CHECKING: from .base import BaseModelLoader logger = logging.getLogger(__name__) # Format to loader mapping FORMAT_TO_LOADER: dict[str, str] = { "gguf": "gguf", "safetensors": "diffusers", # Default to diffusers for safetensors "onnx": "onnx", "pytorch": "hf", "hf-snapshot": "hf", "diffusion": "diffusers", } # Extension-based fallback EXTENSION_TO_LOADER: dict[str, str] = { ".gguf": "gguf", ".onnx": "onnx", ".safetensors": "diffusers", ".pt": "hf", ".pth": "hf", ".bin": "hf", } # Category-based defaults CATEGORY_TO_LOADER: dict[str, str] = { "llm": "gguf", "embedding": "gguf", "voice": "whisper", "diffusion": "diffusers", "multimodal": "hf", "tools": "onnx", } def get_auto_loader( model_id: str, manifest_path: Optional[str] = None, **init_kwargs: Any, ) -> "BaseModelLoader": """ Automatically select and return appropriate loader for a model. Selection priority: 1. Manifest `format` field (if model in manifest) 2. File extension (for direct paths) 3. Manifest `category` field (fallback) 4. Default to "hf" loader Args: model_id: Model ID from manifest or direct path manifest_path: Optional custom manifest path **init_kwargs: Arguments passed to loader constructor Returns: Instantiated loader appropriate for the model Example: >>> loader = get_auto_loader("whisper-large-v3-turbo") >>> model = await loader.load("whisper-large-v3-turbo") >>> # Works with paths too >>> loader = get_auto_loader("/path/to/model.onnx") >>> model = await loader.load("/path/to/model.onnx") """ loader_name: Optional[str] = None # Try to get manifest entry try: from .loader import ModelLoader path_loader = ModelLoader(manifest_path=manifest_path) entry = path_loader.get_model_entry(model_id) if entry: # Priority 1: Explicit format field if entry.format: loader_name = FORMAT_TO_LOADER.get(entry.format) if loader_name: logger.debug( f"Selected loader '{loader_name}' from format '{entry.format}'" ) # Priority 3: Category fallback if not loader_name and entry.category: loader_name = CATEGORY_TO_LOADER.get(entry.category) if loader_name: logger.debug( f"Selected loader '{loader_name}' from category '{entry.category}'" ) except Exception as e: logger.debug(f"Failed to get manifest entry: {e}") # Priority 2: Extension-based (for direct paths) if not loader_name: path = Path(model_id) if path.suffix: loader_name = EXTENSION_TO_LOADER.get(path.suffix.lower()) if loader_name: logger.debug( f"Selected loader '{loader_name}' from extension '{path.suffix}'" ) # Default fallback if not loader_name: loader_name = "hf" logger.debug(f"Using default loader 'hf'") # Verify loader exists if not is_loader_registered(loader_name): logger.warning( f"Loader '{loader_name}' not registered, falling back to 'hf'" ) loader_name = "hf" return get_loader(loader_name, **init_kwargs) def detect_format_from_path(path: str) -> Optional[ModelFormat]: """ Detect model format from file path. Args: path: Path to model file or directory Returns: Detected ModelFormat or None """ p = Path(path) # File extension detection ext = p.suffix.lower() if ext == ".gguf": return "gguf" elif ext == ".onnx": return "onnx" elif ext == ".safetensors": return "safetensors" elif ext in (".pt", ".pth", ".bin"): return "pytorch" # Directory detection if p.is_dir(): # Check for diffusion pipeline if (p / "model_index.json").exists(): return "diffusion" # Check for sharded safetensors if (p / "model.safetensors.index.json").exists(): return "safetensors" # Check for HuggingFace cache if p.name.startswith("models--"): return "hf-snapshot" # Check for ONNX if list(p.glob("*.onnx")): return "onnx" # Default to safetensors for directories with config.json if (p / "config.json").exists(): return "safetensors" return None def get_loader_for_format( format: ModelFormat, **init_kwargs: Any, ) -> "BaseModelLoader": """ Get a loader instance for a specific format. Args: format: Model format **init_kwargs: Arguments passed to loader constructor Returns: Instantiated loader Raises: KeyError: If no loader registered for format """ loader_name = FORMAT_TO_LOADER.get(format) if not loader_name: raise KeyError(f"No loader registered for format '{format}'") return get_loader(loader_name, **init_kwargs) def get_loader_for_category( category: str, **init_kwargs: Any, ) -> "BaseModelLoader": """ Get a loader instance for a model category. Args: category: Model category (llm, embedding, voice, diffusion, multimodal) **init_kwargs: Arguments passed to loader constructor Returns: Instantiated loader Raises: KeyError: If no loader registered for category """ loader_name = CATEGORY_TO_LOADER.get(category) if not loader_name: raise KeyError(f"No loader registered for category '{category}'") return get_loader(loader_name, **init_kwargs)