""" Model loader implementation. Uses the TypeScript CLI via subprocess for remote fetching, with direct file operations for cached models. """ import json import subprocess import shutil from pathlib import Path from typing import Optional, List, Callable from .types import ( ModelEntry, LoadResult, TransferProgress, CachedModel, CacheStats, ModelLoaderOptions, ) def _find_cli() -> str: """Find the model-loader CLI executable.""" # Try direct path first (development/local builds) cli_path = Path(__file__).parent.parent.parent / "dist" / "bin" / "model-loader.js" if cli_path.exists(): return f"node {cli_path}" # Try npx (works if package is installed globally/locally) if shutil.which("npx"): return "npx" raise RuntimeError( "model-loader CLI not found. Install @tqftw/model-loader or build from source." ) def _run_cli( args: List[str], options: Optional[ModelLoaderOptions] = None, ) -> dict: """ Run the model-loader CLI and return JSON output. Args: args: CLI arguments options: Loader options to pass Returns: Parsed JSON output from CLI Raises: RuntimeError: If CLI fails """ cmd = [_find_cli(), "@tqftw/model-loader"] # Handle npx vs direct node if cmd[0] == "npx": cmd = ["npx", "@tqftw/model-loader"] + args else: cmd = cmd[0].split() + args # Add options if options: if options.cache_dir: cmd.extend(["--cache-dir", str(options.cache_dir)]) if options.manifest_path: cmd.extend(["--manifest", str(options.manifest_path)]) if options.ssh_key: cmd.extend(["--ssh-key", str(options.ssh_key)]) if options.timeout: cmd.extend(["--timeout", str(options.timeout)]) # Always request JSON output cmd.append("--json") try: result = subprocess.run( cmd, capture_output=True, text=True, check=False, ) if result.returncode != 0: # Try to parse error from JSON try: data = json.loads(result.stdout) if "error" in data: raise RuntimeError(data["error"]) except json.JSONDecodeError: pass raise RuntimeError(f"CLI failed: {result.stderr or result.stdout}") return json.loads(result.stdout) except FileNotFoundError: raise RuntimeError("model-loader CLI not found") except json.JSONDecodeError as e: raise RuntimeError(f"Invalid CLI output: {e}") class ModelLoader: """ Model loader with remote fetching and caching support. Uses the TypeScript CLI via subprocess for remote operations, with direct file operations for local cache access. Example: >>> loader = ModelLoader(verbose=True) >>> result = loader.ensure_model("ministral-3b-instruct") >>> print(result.path) ~/.cache/models/Ministral-3-3B-Instruct-2512-Q8_0.gguf """ def __init__( self, cache_dir: Optional[Path] = None, manifest_path: Optional[Path] = None, ssh_key: Optional[Path] = None, timeout: Optional[int] = None, verbose: bool = False, ) -> None: """ Initialize the model loader. Args: cache_dir: Cache directory (default: ~/.cache/models) manifest_path: Path to manifest JSON file ssh_key: SSH key file for remote access timeout: Timeout for remote operations (ms) verbose: Enable verbose output """ self.options = ModelLoaderOptions( cache_dir=cache_dir, manifest_path=manifest_path, ssh_key=ssh_key, timeout=timeout, verbose=verbose, ) self._cache_dir = cache_dir or Path.home() / ".cache" / "models" self._manifest: Optional[dict] = None def _load_manifest(self) -> dict: """Load and cache manifest from file.""" if self._manifest is not None: return self._manifest manifest_path = self.options.manifest_path or (self._cache_dir / "manifest.json") if manifest_path.exists(): self._manifest = json.loads(manifest_path.read_text()) else: self._manifest = {"version": "1.0", "models": {}, "remote": {}, "fallbacks": []} return self._manifest def _get_cached_path(self, model_id: str) -> Path: """Get expected cache path for a model.""" manifest = self._load_manifest() models = manifest.get("models", {}) if model_id in models: entry = models[model_id] filename = Path(entry["path"]).name else: filename = Path(model_id).name return self._cache_dir / filename def ensure_model(self, model_id: str) -> LoadResult: """ Ensure a model is available locally. Fetches from remote if not cached. Args: model_id: Model ID from manifest or direct path Returns: LoadResult with local path and metadata """ # Quick check: if already cached with correct size, skip CLI cached_path = self._get_cached_path(model_id) manifest = self._load_manifest() entry_data = manifest.get("models", {}).get(model_id) if cached_path.exists(): if entry_data: expected_size = entry_data.get("size", 0) if cached_path.stat().st_size == expected_size: return LoadResult( path=str(cached_path), source="cache", duration=0, model=ModelEntry.from_dict(entry_data) if entry_data else None, model_id=model_id if entry_data else None, ) else: # No manifest entry, just return cached path return LoadResult( path=str(cached_path), source="cache", duration=0, ) # Use CLI for remote fetch data = _run_cli(["ensure", model_id], self.options) return LoadResult.from_dict(data) def ensure_model_sync(self, model_id: str) -> str: """ Ensure model is available (sync, cache-only). Only uses cache and local fallbacks, no remote fetch. Args: model_id: Model ID or path Returns: Local path to model Raises: RuntimeError: If model not cached """ cached_path = self._get_cached_path(model_id) if cached_path.exists(): return str(cached_path) # Try fallbacks from manifest manifest = self._load_manifest() entry_data = manifest.get("models", {}).get(model_id) if entry_data: model_path = entry_data["path"] for fallback in manifest.get("fallbacks", []): fallback_path = Path(fallback) / model_path if fallback_path.exists(): # Copy to cache self._cache_dir.mkdir(parents=True, exist_ok=True) shutil.copy2(fallback_path, cached_path) return str(cached_path) raise RuntimeError(f"Model not cached: {model_id}") def is_cached(self, model_id: str) -> bool: """Check if a model is cached.""" cached_path = self._get_cached_path(model_id) if not cached_path.exists(): return False # Verify size if in manifest manifest = self._load_manifest() entry_data = manifest.get("models", {}).get(model_id) if entry_data: expected_size = entry_data.get("size", 0) return cached_path.stat().st_size == expected_size return True def get_cached_path(self, model_id: str) -> Path: """Get the expected cache path for a model.""" return self._get_cached_path(model_id) def get_model_entry(self, model_id: str) -> Optional[ModelEntry]: """Get model entry from manifest.""" manifest = self._load_manifest() entry_data = manifest.get("models", {}).get(model_id) if entry_data: return ModelEntry.from_dict(entry_data) return None def list_models(self) -> List[tuple[str, ModelEntry]]: """List all models from manifest.""" manifest = self._load_manifest() return [ (model_id, ModelEntry.from_dict(entry)) for model_id, entry in manifest.get("models", {}).items() ] def list_model_ids(self) -> List[str]: """List all model IDs from manifest.""" manifest = self._load_manifest() return list(manifest.get("models", {}).keys()) def list_cached(self) -> List[CachedModel]: """List cached models.""" if not self._cache_dir.exists(): return [] manifest = self._load_manifest() # Build reverse lookup path_to_id: dict[str, str] = {} for model_id, entry in manifest.get("models", {}).items(): filename = Path(entry["path"]).name path_to_id[filename] = model_id result: List[CachedModel] = [] for path in self._cache_dir.iterdir(): if path.suffix not in (".gguf", ".bin", ".safetensors"): continue result.append(CachedModel( name=path.name, path=str(path), size=path.stat().st_size, model_id=path_to_id.get(path.name), )) return result def remove_from_cache(self, model_id: str) -> None: """Remove a model from cache.""" cached_path = self._get_cached_path(model_id) if cached_path.exists(): cached_path.unlink() def clear_cache(self) -> None: """Clear entire cache.""" for model in self.list_cached(): Path(model.path).unlink() def get_cache_stats(self) -> CacheStats: """Get cache statistics.""" cached = self.list_cached() return CacheStats( count=len(cached), total_size=sum(m.size for m in cached), cache_dir=str(self._cache_dir), ) # Module-level singleton _loader: Optional[ModelLoader] = None def get_loader(**kwargs) -> ModelLoader: """ Get or create the default ModelLoader instance. Args: **kwargs: Options passed to ModelLoader constructor Returns: ModelLoader instance """ global _loader if _loader is None or kwargs: _loader = ModelLoader(**kwargs) return _loader def ensure_model(model_id: str, **kwargs) -> str: """ Ensure a model is cached and return local path. Convenience function that handles remote fetching. Args: model_id: Model ID from manifest or direct path **kwargs: Options passed to ModelLoader Returns: Local path to model file Example: >>> path = ensure_model("ministral-3b-instruct") >>> print(path) /home/user/.cache/models/Ministral-3-3B-Instruct-2512-Q8_0.gguf """ loader = get_loader(**kwargs) result = loader.ensure_model(model_id) return result.path def ensure_model_sync(model_id: str, **kwargs) -> str: """ Ensure model is available (sync, cache-only). Args: model_id: Model ID or path **kwargs: Options passed to ModelLoader Returns: Local path to model file Raises: RuntimeError: If model not cached """ loader = get_loader(**kwargs) return loader.ensure_model_sync(model_id)