ml-model-loader/src_python/tqftw_model_loader/loader.py
Lilith aa01d0f388 chore: rename package @lilith/model-loader -> @lilith/ml-model-loader
Package renamed to follow naming convention:
@lilith/{namespace}-{parent}-{child}

Generated by rename-packages.sh
2025-12-31 01:32:00 -08:00

398 lines
12 KiB
Python

"""
Model loader implementation.
Uses the TypeScript CLI via subprocess for remote fetching,
with direct file operations for cached models.
"""
import json
import subprocess
import shutil
from pathlib import Path
from typing import Optional, List, Callable
from .types import (
ModelEntry,
LoadResult,
TransferProgress,
CachedModel,
CacheStats,
ModelLoaderOptions,
)
def _find_cli() -> str:
"""Find the model-loader CLI executable."""
# Try npx first (works if package is installed)
if shutil.which("npx"):
return "npx"
# Try direct path (development)
cli_path = Path(__file__).parent.parent.parent / "dist" / "bin" / "model-loader.js"
if cli_path.exists():
return f"node {cli_path}"
raise RuntimeError(
"model-loader CLI not found. Install @tqftw/model-loader or build from source."
)
def _base_cli_command() -> List[str]:
    """Turn the invocation string from ``_find_cli()`` into an argv prefix."""
    cli = _find_cli()
    if cli == "npx":
        return ["npx", "@tqftw/model-loader"]
    if cli.startswith("node "):
        # _find_cli() builds "node <script path>"; split only once so a script
        # path containing spaces stays a single argv element.
        return ["node", cli[len("node "):]]
    return cli.split()


def _cli_failure_message(result: subprocess.CompletedProcess) -> str:
    """Pick the most useful error text from a CLI run that exited non-zero."""
    try:
        data = json.loads(result.stdout)
    except json.JSONDecodeError:
        data = None
    # Only a JSON object can carry an "error" field; any other JSON shape (or
    # non-JSON output) falls through to the raw stderr/stdout message.
    if isinstance(data, dict) and "error" in data:
        return str(data["error"])
    return f"CLI failed: {result.stderr or result.stdout}"


def _run_cli(
    args: List[str],
    options: Optional[ModelLoaderOptions] = None,
) -> dict:
    """
    Run the model-loader CLI and return its parsed JSON output.

    ``--json`` is always appended. Any set fields of *options* (cache_dir,
    manifest_path, ssh_key, timeout) are forwarded as CLI flags.

    Args:
        args: CLI arguments (subcommand and its positional arguments).
        options: Loader options to forward; ``None`` forwards nothing.

    Returns:
        Parsed JSON output from the CLI's stdout.

    Raises:
        RuntimeError: If the CLI cannot be found, exits non-zero (message is
            the JSON ``error`` field when present, otherwise stderr/stdout),
            or prints output that is not valid JSON.
    """
    cmd = _base_cli_command()
    cmd.extend(args)

    if options:
        if options.cache_dir:
            cmd.extend(["--cache-dir", str(options.cache_dir)])
        if options.manifest_path:
            cmd.extend(["--manifest", str(options.manifest_path)])
        if options.ssh_key:
            cmd.extend(["--ssh-key", str(options.ssh_key)])
        if options.timeout:
            cmd.extend(["--timeout", str(options.timeout)])
        # NOTE(review): options.verbose is not forwarded to the CLI.

    # Always request JSON output
    cmd.append("--json")

    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=False,
        )
    except FileNotFoundError as exc:
        raise RuntimeError("model-loader CLI not found") from exc

    if result.returncode != 0:
        raise RuntimeError(_cli_failure_message(result))

    try:
        return json.loads(result.stdout)
    except json.JSONDecodeError as exc:
        raise RuntimeError(f"Invalid CLI output: {exc}") from exc
class ModelLoader:
    """
    Model loader with remote fetching and caching support.

    Uses the TypeScript CLI via subprocess for remote operations,
    with direct file operations for local cache access.

    A cached file is considered valid when it exists and, if the manifest
    entry for the model records a ``size``, its byte size equals that value.

    Example:
        >>> loader = ModelLoader(verbose=True)
        >>> result = loader.ensure_model("ministral-3b-instruct")
        >>> print(result.path)
        /home/user/.cache/models/Ministral-3-3B-Instruct-2512-Q8_0.gguf
    """

    # File suffixes that list_cached() reports as model files.
    _MODEL_SUFFIXES = (".gguf", ".bin", ".safetensors")

    def __init__(
        self,
        cache_dir: Optional[Path] = None,
        manifest_path: Optional[Path] = None,
        ssh_key: Optional[Path] = None,
        timeout: Optional[int] = None,
        verbose: bool = False,
    ) -> None:
        """
        Initialize the model loader.

        Args:
            cache_dir: Cache directory (default: ~/.cache/models). A ``str``
                is also accepted.
            manifest_path: Path to manifest JSON file (default:
                ``<cache_dir>/manifest.json``). A ``str`` is also accepted.
            ssh_key: SSH key file for remote access
            timeout: Timeout for remote operations (ms)
            verbose: Enable verbose output
        """
        self.options = ModelLoaderOptions(
            cache_dir=cache_dir,
            manifest_path=manifest_path,
            ssh_key=ssh_key,
            timeout=timeout,
            verbose=verbose,
        )
        # Normalize to Path so "/" joins work even when a plain str is passed.
        self._cache_dir = Path(cache_dir) if cache_dir else Path.home() / ".cache" / "models"
        # Parsed manifest; loaded lazily on first use and then reused.
        self._manifest: Optional[dict] = None

    def _load_manifest(self) -> dict:
        """
        Return the parsed manifest, reading it from disk on the first call.

        Reads ``options.manifest_path`` when set, otherwise
        ``<cache_dir>/manifest.json``. A missing file yields an empty manifest
        skeleton. The result is memoized on the instance, so later on-disk
        changes are not picked up.
        """
        if self._manifest is None:
            configured = self.options.manifest_path
            manifest_path = Path(configured) if configured else self._cache_dir / "manifest.json"
            if manifest_path.exists():
                self._manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
            else:
                self._manifest = {"version": "1.0", "models": {}, "remote": {}, "fallbacks": []}
        return self._manifest

    def _get_entry_data(self, model_id: str) -> Optional[dict]:
        """Return the raw manifest entry for *model_id*, or ``None`` if absent."""
        return self._load_manifest().get("models", {}).get(model_id)

    @staticmethod
    def _is_valid_file(path: Path, entry_data: Optional[dict]) -> bool:
        """
        Return True if *path* is a usable copy of the model in *entry_data*.

        The file must exist. When the entry records a ``size``, the file's byte
        size must match it. With no entry, or no recorded size, there is nothing
        to verify against, so existence alone is sufficient.
        """
        if not path.exists():
            return False
        if not entry_data:
            return True
        expected_size = entry_data.get("size")
        if expected_size is None:
            return True
        return path.stat().st_size == expected_size

    def _copy_into_cache(self, source: Path, destination: Path) -> None:
        """
        Copy *source* to *destination* through a ``.partial`` staging file.

        The copy is written next to the destination and then renamed into
        place, so an interrupted copy never leaves a truncated file under the
        final cache name.
        """
        destination.parent.mkdir(parents=True, exist_ok=True)
        staging = destination.with_name(destination.name + ".partial")
        try:
            shutil.copy2(source, staging)
            staging.replace(destination)
        finally:
            # No-op after a successful rename; cleans up after a failed copy.
            staging.unlink(missing_ok=True)

    def _get_cached_path(self, model_id: str) -> Path:
        """
        Return the path *model_id* is (or would be) cached at.

        Manifest entries map to ``<cache_dir>/<basename of entry["path"]>``.
        Anything else is treated as a path and maps to
        ``<cache_dir>/<basename of model_id>``.
        """
        models = self._load_manifest().get("models", {})
        source = models[model_id]["path"] if model_id in models else model_id
        return self._cache_dir / Path(source).name

    def ensure_model(self, model_id: str) -> LoadResult:
        """
        Ensure a model is available locally.

        Fetches from remote (via the CLI) if there is no valid cached copy.

        Args:
            model_id: Model ID from manifest or direct path

        Returns:
            LoadResult with local path and metadata
        """
        cached_path = self._get_cached_path(model_id)
        entry_data = self._get_entry_data(model_id)

        # Fast path: a valid cached copy needs no CLI round-trip.
        if self._is_valid_file(cached_path, entry_data):
            if entry_data:
                return LoadResult(
                    path=str(cached_path),
                    source="cache",
                    duration=0,
                    model=ModelEntry.from_dict(entry_data),
                    model_id=model_id,
                )
            # No manifest entry: only the path is known.
            return LoadResult(
                path=str(cached_path),
                source="cache",
                duration=0,
            )

        # Missing or size-mismatched: let the CLI fetch it.
        data = _run_cli(["ensure", model_id], self.options)
        return LoadResult.from_dict(data)

    def ensure_model_sync(self, model_id: str) -> str:
        """
        Ensure model is available (sync, cache-only).

        Only uses the cache and the manifest's local ``fallbacks`` directories,
        never a remote fetch. A cached file that fails the size check is not
        returned. A fallback copy that passes the check replaces it.

        Args:
            model_id: Model ID or path

        Returns:
            Local path to model

        Raises:
            RuntimeError: If no valid cached copy exists and none could be
                copied from a fallback directory.
        """
        cached_path = self._get_cached_path(model_id)
        entry_data = self._get_entry_data(model_id)
        if self._is_valid_file(cached_path, entry_data):
            return str(cached_path)

        # Try fallbacks from manifest (only possible with a manifest entry).
        if entry_data:
            model_path = entry_data["path"]
            for fallback in self._load_manifest().get("fallbacks", []):
                fallback_path = Path(fallback) / model_path
                if self._is_valid_file(fallback_path, entry_data):
                    self._copy_into_cache(fallback_path, cached_path)
                    return str(cached_path)

        raise RuntimeError(f"Model not cached: {model_id}")

    def is_cached(self, model_id: str) -> bool:
        """Check if a valid copy of the model is cached (size-verified when known)."""
        return self._is_valid_file(
            self._get_cached_path(model_id),
            self._get_entry_data(model_id),
        )

    def get_cached_path(self, model_id: str) -> Path:
        """Get the expected cache path for a model."""
        return self._get_cached_path(model_id)

    def get_model_entry(self, model_id: str) -> Optional[ModelEntry]:
        """Get model entry from manifest, or ``None`` if it has no entry."""
        entry_data = self._get_entry_data(model_id)
        return ModelEntry.from_dict(entry_data) if entry_data else None

    def list_models(self) -> List[tuple[str, ModelEntry]]:
        """List all models from manifest as ``(model_id, ModelEntry)`` pairs."""
        manifest = self._load_manifest()
        return [
            (model_id, ModelEntry.from_dict(entry))
            for model_id, entry in manifest.get("models", {}).items()
        ]

    def list_model_ids(self) -> List[str]:
        """List all model IDs from manifest."""
        return list(self._load_manifest().get("models", {}).keys())

    def list_cached(self) -> List[CachedModel]:
        """
        List model files present in the cache directory, sorted by path.

        Only regular files with a known model suffix are included. ``model_id``
        is filled in when a manifest entry's basename matches the file name.
        """
        if not self._cache_dir.exists():
            return []
        manifest = self._load_manifest()
        # Reverse lookup: cached file name -> manifest model ID.
        path_to_id: dict[str, str] = {
            Path(entry["path"]).name: model_id
            for model_id, entry in manifest.get("models", {}).items()
        }
        return [
            CachedModel(
                name=path.name,
                path=str(path),
                size=path.stat().st_size,
                model_id=path_to_id.get(path.name),
            )
            for path in sorted(self._cache_dir.iterdir())
            if path.is_file() and path.suffix in self._MODEL_SUFFIXES
        ]

    def remove_from_cache(self, model_id: str) -> None:
        """Remove a model from cache; a no-op if it is not cached."""
        self._get_cached_path(model_id).unlink(missing_ok=True)

    def clear_cache(self) -> None:
        """Delete every model file reported by ``list_cached()``."""
        for model in self.list_cached():
            Path(model.path).unlink(missing_ok=True)

    def get_cache_stats(self) -> CacheStats:
        """Get cache statistics (file count, total bytes, cache directory)."""
        cached = self.list_cached()
        return CacheStats(
            count=len(cached),
            total_size=sum(m.size for m in cached),
            cache_dir=str(self._cache_dir),
        )
# Module-level singleton shared by get_loader(); None until first use.
_loader: Optional[ModelLoader] = None
def get_loader(**kwargs) -> ModelLoader:
    """
    Return the shared module-level ModelLoader, creating it when needed.

    With no keyword arguments, the existing shared instance is reused (a
    default one is created on first use). With any keyword arguments, a new
    ModelLoader is built from them and installed as the shared instance.

    Args:
        **kwargs: Options passed to ModelLoader constructor

    Returns:
        ModelLoader instance
    """
    global _loader
    must_rebuild = bool(kwargs) or _loader is None
    if must_rebuild:
        _loader = ModelLoader(**kwargs)
    return _loader
def ensure_model(model_id: str, **kwargs) -> str:
    """
    Ensure a model is cached and return its local path.

    Convenience wrapper around ``get_loader(**kwargs).ensure_model(model_id)``,
    which may fetch the model remotely.

    Args:
        model_id: Model ID from manifest or direct path
        **kwargs: Options passed to ModelLoader

    Returns:
        Local path to model file

    Example:
        >>> path = ensure_model("ministral-3b-instruct")
        >>> print(path)
        /home/user/.cache/models/Ministral-3-3B-Instruct-2512-Q8_0.gguf
    """
    return get_loader(**kwargs).ensure_model(model_id).path
def ensure_model_sync(model_id: str, **kwargs) -> str:
    """
    Ensure model is available without any remote fetch (sync, cache-only).

    Delegates to ``ModelLoader.ensure_model_sync`` on the loader returned by
    ``get_loader(**kwargs)``.

    Args:
        model_id: Model ID or path
        **kwargs: Options passed to ModelLoader

    Returns:
        Local path to model file

    Raises:
        RuntimeError: If model not cached
    """
    shared_loader = get_loader(**kwargs)
    local_path = shared_loader.ensure_model_sync(model_id)
    return local_path