Package renamed to follow naming convention:
@lilith/{namespace}-{parent}-{child}
Generated by rename-packages.sh
398 lines
12 KiB
Python
398 lines
12 KiB
Python
"""
|
|
Model loader implementation.
|
|
|
|
Uses the TypeScript CLI via subprocess for remote fetching,
|
|
with direct file operations for cached models.
|
|
"""
|
|
|
|
import json
|
|
import subprocess
|
|
import shutil
|
|
from pathlib import Path
|
|
from typing import Optional, List, Callable
|
|
|
|
from .types import (
|
|
ModelEntry,
|
|
LoadResult,
|
|
TransferProgress,
|
|
CachedModel,
|
|
CacheStats,
|
|
ModelLoaderOptions,
|
|
)
|
|
|
|
|
|
def _find_cli() -> str:
    """
    Locate the model-loader CLI and return the command used to launch it.

    Returns:
        ``"npx"`` when an ``npx`` executable is found on PATH, otherwise
        ``"node <script>"`` where ``<script>`` is the locally built
        ``dist/bin/model-loader.js`` three directory levels above this file.

    Raises:
        RuntimeError: If neither launcher is available.
    """
    # Preferred launcher: npx (works if the package is installed).
    npx_executable = shutil.which("npx")
    if npx_executable:
        return "npx"

    # Development fallback: the built script inside the source checkout.
    package_root = Path(__file__).parent.parent.parent
    dev_script = package_root / "dist" / "bin" / "model-loader.js"
    if not dev_script.exists():
        raise RuntimeError(
            "model-loader CLI not found. Install @tqftw/model-loader or build from source."
        )
    return f"node {dev_script}"
|
|
|
|
|
|
def _run_cli(
    args: List[str],
    options: Optional[ModelLoaderOptions] = None,
) -> dict:
    """
    Run the model-loader CLI and return its parsed JSON output.

    Args:
        args: CLI arguments (sub-command followed by its operands).
        options: Loader options. Each of ``cache_dir``, ``manifest_path``,
            ``ssh_key`` and ``timeout`` that is set is forwarded as the
            matching ``--cache-dir`` / ``--manifest`` / ``--ssh-key`` /
            ``--timeout`` flag.

    Returns:
        Parsed JSON output from the CLI.

    Raises:
        RuntimeError: If the CLI cannot be found or started, exits with a
            non-zero status, or prints output that is not valid JSON.
    """
    launcher = _find_cli()
    if launcher == "npx":
        cmd = ["npx", "@tqftw/model-loader"]
    else:
        # The launcher has the form "node <script path>". Split off only the
        # interpreter so a script path containing spaces stays one argument.
        cmd = launcher.split(maxsplit=1)
    cmd.extend(args)

    # Forward the loader options that are set.
    if options:
        if options.cache_dir:
            cmd.extend(["--cache-dir", str(options.cache_dir)])
        if options.manifest_path:
            cmd.extend(["--manifest", str(options.manifest_path)])
        if options.ssh_key:
            cmd.extend(["--ssh-key", str(options.ssh_key)])
        if options.timeout:
            cmd.extend(["--timeout", str(options.timeout)])

    # Always request JSON output.
    cmd.append("--json")

    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=False,
        )
    except FileNotFoundError as exc:
        # The launcher executable itself could not be started.
        raise RuntimeError("model-loader CLI not found") from exc

    if result.returncode != 0:
        # Prefer the structured error message when the CLI printed one.
        try:
            payload = json.loads(result.stdout)
        except json.JSONDecodeError:
            payload = None
        # Guard against valid JSON that is not an object (list, number, ...).
        if isinstance(payload, dict) and "error" in payload:
            raise RuntimeError(payload["error"])
        raise RuntimeError(f"CLI failed: {result.stderr or result.stdout}")

    try:
        return json.loads(result.stdout)
    except json.JSONDecodeError as exc:
        raise RuntimeError(f"Invalid CLI output: {exc}") from exc
|
|
|
|
|
|
class ModelLoader:
    """
    Model loader with remote fetching and caching support.

    Remote operations are delegated to the TypeScript CLI through a
    subprocess; cache lookups use direct file operations.

    Example:
        >>> loader = ModelLoader(verbose=True)
        >>> result = loader.ensure_model("ministral-3b-instruct")
        >>> print(result.path)
        ~/.cache/models/Ministral-3-3B-Instruct-2512-Q8_0.gguf
    """

    # File suffixes that list_cached() recognizes as model files.
    _MODEL_SUFFIXES = (".gguf", ".bin", ".safetensors")

    def __init__(
        self,
        cache_dir: Optional[Path] = None,
        manifest_path: Optional[Path] = None,
        ssh_key: Optional[Path] = None,
        timeout: Optional[int] = None,
        verbose: bool = False,
    ) -> None:
        """
        Initialize the model loader.

        Args:
            cache_dir: Cache directory (default: ~/.cache/models)
            manifest_path: Path to manifest JSON file
            ssh_key: SSH key file for remote access
            timeout: Timeout for remote operations (ms)
            verbose: Enable verbose output
        """
        self.options = ModelLoaderOptions(
            cache_dir=cache_dir,
            manifest_path=manifest_path,
            ssh_key=ssh_key,
            timeout=timeout,
            verbose=verbose,
        )
        # Coerce to Path so a plain string directory works with "/" below.
        self._cache_dir = Path(cache_dir) if cache_dir else Path.home() / ".cache" / "models"
        # Parsed manifest, filled in lazily by _load_manifest().
        self._manifest: Optional[dict] = None

    def _load_manifest(self) -> dict:
        """
        Load the manifest, reading the file at most once per instance.

        Uses ``options.manifest_path`` when set, otherwise ``manifest.json``
        inside the cache directory. When the file does not exist an empty
        manifest is used instead.
        """
        if self._manifest is not None:
            return self._manifest

        # Path() also accepts a manifest path that was given as a string.
        manifest_path = Path(self.options.manifest_path or (self._cache_dir / "manifest.json"))

        if manifest_path.exists():
            # Read as UTF-8 explicitly rather than the platform default encoding.
            self._manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
        else:
            self._manifest = {"version": "1.0", "models": {}, "remote": {}, "fallbacks": []}

        return self._manifest

    def _get_cached_path(self, model_id: str) -> Path:
        """
        Get the expected cache path for a model.

        The cache is flat: only the final path component of the manifest
        entry's ``path`` (or of ``model_id`` itself when it has no manifest
        entry) is used as the file name inside the cache directory.
        """
        models = self._load_manifest().get("models", {})

        if model_id in models:
            filename = Path(models[model_id]["path"]).name
        else:
            filename = Path(model_id).name

        return self._cache_dir / filename

    def ensure_model(self, model_id: str) -> LoadResult:
        """
        Ensure a model is available locally.

        Returns the cached file directly when it exists and either has no
        manifest entry or matches the manifest size; otherwise the CLI is
        invoked with the ``ensure`` sub-command.

        Args:
            model_id: Model ID from manifest or direct path

        Returns:
            LoadResult with local path and metadata
        """
        cached_path = self._get_cached_path(model_id)
        entry_data = self._load_manifest().get("models", {}).get(model_id)

        # Quick check: a usable cached file lets us skip the CLI entirely.
        if cached_path.exists():
            if not entry_data:
                # No manifest entry, just return cached path.
                return LoadResult(
                    path=str(cached_path),
                    source="cache",
                    duration=0,
                )
            if cached_path.stat().st_size == entry_data.get("size", 0):
                return LoadResult(
                    path=str(cached_path),
                    source="cache",
                    duration=0,
                    model=ModelEntry.from_dict(entry_data),
                    model_id=model_id,
                )
            # Size mismatch: fall through and let the CLI handle it.

        data = _run_cli(["ensure", model_id], self.options)
        return LoadResult.from_dict(data)

    def ensure_model_sync(self, model_id: str) -> str:
        """
        Ensure model is available (sync, cache-only).

        Only uses the cache and the manifest's local fallback directories;
        never performs a remote fetch. A model found in a fallback directory
        is copied into the cache.

        Args:
            model_id: Model ID or path

        Returns:
            Local path to model

        Raises:
            RuntimeError: If model not cached
        """
        cached_path = self._get_cached_path(model_id)
        if cached_path.exists():
            return str(cached_path)

        # Try the fallback directories listed in the manifest.
        manifest = self._load_manifest()
        entry_data = manifest.get("models", {}).get(model_id)

        if entry_data:
            model_path = entry_data["path"]
            for fallback in manifest.get("fallbacks", []):
                fallback_path = Path(fallback) / model_path
                if fallback_path.exists():
                    # Copy to cache so later lookups hit it directly.
                    self._cache_dir.mkdir(parents=True, exist_ok=True)
                    shutil.copy2(fallback_path, cached_path)
                    return str(cached_path)

        raise RuntimeError(f"Model not cached: {model_id}")

    def is_cached(self, model_id: str) -> bool:
        """
        Check if a model is cached.

        For models with a manifest entry the cached file must also have the
        size recorded in the manifest (a missing ``size`` counts as 0).
        """
        cached_path = self._get_cached_path(model_id)
        if not cached_path.exists():
            return False

        entry_data = self._load_manifest().get("models", {}).get(model_id)
        if entry_data:
            return cached_path.stat().st_size == entry_data.get("size", 0)

        return True

    def get_cached_path(self, model_id: str) -> Path:
        """Get the expected cache path for a model."""
        return self._get_cached_path(model_id)

    def get_model_entry(self, model_id: str) -> Optional[ModelEntry]:
        """Get model entry from manifest, or None when the ID is unknown."""
        entry_data = self._load_manifest().get("models", {}).get(model_id)
        if entry_data:
            return ModelEntry.from_dict(entry_data)
        return None

    def list_models(self) -> List[tuple[str, ModelEntry]]:
        """List all models from manifest as ``(model_id, entry)`` pairs."""
        manifest = self._load_manifest()
        return [
            (model_id, ModelEntry.from_dict(entry))
            for model_id, entry in manifest.get("models", {}).items()
        ]

    def list_model_ids(self) -> List[str]:
        """List all model IDs from manifest."""
        return list(self._load_manifest().get("models", {}))

    def list_cached(self) -> List[CachedModel]:
        """
        List cached models.

        Only regular files in the cache directory whose suffix is one of
        ``.gguf``, ``.bin`` or ``.safetensors`` are reported. Returns an empty
        list when the cache directory does not exist.
        """
        if not self._cache_dir.exists():
            return []

        manifest = self._load_manifest()

        # Reverse lookup: cached file name -> manifest model ID.
        path_to_id: dict[str, str] = {
            Path(entry["path"]).name: model_id
            for model_id, entry in manifest.get("models", {}).items()
        }

        result: List[CachedModel] = []
        for path in self._cache_dir.iterdir():
            if path.suffix not in self._MODEL_SUFFIXES:
                continue
            # Skip directories that merely look like model files; reporting
            # them would make clear_cache() fail when it tries to unlink them.
            if not path.is_file():
                continue

            result.append(CachedModel(
                name=path.name,
                path=str(path),
                size=path.stat().st_size,
                model_id=path_to_id.get(path.name),
            ))

        return result

    def remove_from_cache(self, model_id: str) -> None:
        """Remove a model from cache; does nothing when it is not cached."""
        # missing_ok avoids the race between an exists() check and unlink().
        self._get_cached_path(model_id).unlink(missing_ok=True)

    def clear_cache(self) -> None:
        """Clear entire cache (every file reported by list_cached())."""
        for model in self.list_cached():
            # The file may disappear between listing and deletion.
            Path(model.path).unlink(missing_ok=True)

    def get_cache_stats(self) -> CacheStats:
        """Get cache statistics: file count, total size and cache directory."""
        cached = self.list_cached()
        return CacheStats(
            count=len(cached),
            total_size=sum(m.size for m in cached),
            cache_dir=str(self._cache_dir),
        )
|
|
|
|
|
|
# Process-wide default loader, created on first use by get_loader().
_loader: Optional[ModelLoader] = None


def get_loader(**kwargs) -> ModelLoader:
    """
    Return the shared default ModelLoader, creating it when needed.

    A new instance is built, and becomes the shared default, when none
    exists yet or whenever any keyword argument is supplied.

    Args:
        **kwargs: Options passed to ModelLoader constructor

    Returns:
        ModelLoader instance
    """
    global _loader
    needs_new_instance = bool(kwargs) or _loader is None
    if needs_new_instance:
        _loader = ModelLoader(**kwargs)
    return _loader
|
|
|
|
|
|
def ensure_model(model_id: str, **kwargs) -> str:
    """
    Make sure a model is available locally and return its path.

    Convenience wrapper: obtains the default loader through ``get_loader``
    (forwarding ``kwargs``) and calls its ``ensure_model`` method, which
    handles remote fetching.

    Args:
        model_id: Model ID from manifest or direct path
        **kwargs: Options passed to ModelLoader

    Returns:
        Local path to model file

    Example:
        >>> path = ensure_model("ministral-3b-instruct")
        >>> print(path)
        /home/user/.cache/models/Ministral-3-3B-Instruct-2512-Q8_0.gguf
    """
    return get_loader(**kwargs).ensure_model(model_id).path
|
|
|
|
|
|
def ensure_model_sync(model_id: str, **kwargs) -> str:
    """
    Return the local path of a model without any remote fetch (sync, cache-only).

    Convenience wrapper: obtains the default loader through ``get_loader``
    (forwarding ``kwargs``) and calls its ``ensure_model_sync`` method.

    Args:
        model_id: Model ID or path
        **kwargs: Options passed to ModelLoader

    Returns:
        Local path to model file

    Raises:
        RuntimeError: If model not cached
    """
    default_loader = get_loader(**kwargs)
    local_path = default_loader.ensure_model_sync(model_id)
    return local_path
|