Add support for all ML model types beyond GGUF: - New discovery module for auto-scanning model directories - Detect formats: GGUF, safetensors, ONNX, PyTorch, diffusion pipelines - CLI commands: discover, scan, sync for manifest management - Manifest v2.0 with format field, directory support, file lists Python loaders (v2.0.0): - ONNXLoader with CUDA/TensorRT execution providers - WhisperLoader for faster-whisper with transcribe/stream - get_auto_loader() for automatic backend selection Breaking: Manifest schema upgraded to v2.0 (auto-migrates v1.x on load) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
278 lines
8.8 KiB
Python
278 lines
8.8 KiB
Python
"""
|
|
ONNX Runtime model loader.
|
|
|
|
Loads ONNX models with GPU acceleration via CUDA/TensorRT execution providers.
|
|
"""
|
|
|
|
from pathlib import Path
|
|
from typing import Optional, Any, Dict, List
|
|
import time
|
|
import logging
|
|
|
|
from .base import BaseModelLoader, ModelInfo, ModelLoadError, ModelNotFoundError
|
|
from .registry import register_loader
|
|
|
|
# Module-level logger shared by every loader instance in this file.
logger: logging.Logger = logging.getLogger(__name__)

# Stand-in type for an onnxruntime InferenceSession. It is deliberately ``Any``
# because onnxruntime is only imported lazily inside ONNXLoader.load(), so this
# module stays importable when onnxruntime is not installed.
ONNXSession = Any
|
|
|
|
|
|
def _get_onnx_providers(device: Optional[str] = None) -> List[str]:
|
|
"""Get ONNX execution providers based on device preference."""
|
|
if device is None or device == "auto":
|
|
# Try CUDA first, fall back to CPU
|
|
return [
|
|
"TensorrtExecutionProvider",
|
|
"CUDAExecutionProvider",
|
|
"CPUExecutionProvider",
|
|
]
|
|
elif device.startswith("cuda"):
|
|
return [
|
|
"TensorrtExecutionProvider",
|
|
"CUDAExecutionProvider",
|
|
"CPUExecutionProvider",
|
|
]
|
|
elif device == "tensorrt":
|
|
return [
|
|
"TensorrtExecutionProvider",
|
|
"CUDAExecutionProvider",
|
|
"CPUExecutionProvider",
|
|
]
|
|
else:
|
|
return ["CPUExecutionProvider"]
|
|
|
|
|
|
@register_loader("onnx", aliases=["onnxruntime", "ort"])
class ONNXLoader(BaseModelLoader[ONNXSession]):
    """
    ONNX Runtime model loader.

    Supports both single .onnx files and directory-based models
    (containing model.onnx + external data).

    Example:
        >>> loader = ONNXLoader()
        >>> session = await loader.load("silero-vad")
        >>> outputs = session.run(None, {"input": audio_data})

        # Or with specific options
        >>> session = await loader.load(
        ...     "model.onnx",
        ...     device="cuda",
        ...     providers=["CUDAExecutionProvider"],
        ... )
    """

    def __init__(self) -> None:
        super().__init__()
        # Tensor names captured from the most recently loaded session.
        self._input_names: List[str] = []
        self._output_names: List[str] = []
        # SessionOptions used for the most recent load(); None until then.
        self._session_options: Any = None

    @property
    def input_names(self) -> List[str]:
        """Input tensor names of the loaded session (empty if none loaded)."""
        return self._input_names

    @property
    def output_names(self) -> List[str]:
        """Output tensor names of the loaded session (empty if none loaded)."""
        return self._output_names

    async def load(
        self,
        model_id: str,
        *,
        device: Optional[str] = None,
        providers: Optional[List[str]] = None,
        provider_options: Optional[List[Dict[str, Any]]] = None,
        graph_optimization_level: str = "all",
        intra_op_num_threads: int = 0,
        inter_op_num_threads: int = 0,
        **kwargs: Any,
    ) -> ONNXSession:
        """
        Load an ONNX model and make it this loader's active session.

        Args:
            model_id: Model ID from manifest or path to .onnx file/directory
            device: Device preference ("cuda", "cpu", "tensorrt", "auto");
                only used when ``providers`` is not given
            providers: Explicit execution providers list (names, or
                ``(name, options)`` tuples). Entries not reported by
                ``ort.get_available_providers()`` are skipped; if none remain,
                CPUExecutionProvider is used and a warning is logged.
            provider_options: Provider-specific options. When it has exactly one
                entry per requested provider it is filtered together with the
                providers; otherwise it is passed through unchanged.
            graph_optimization_level: "disabled", "basic", "extended", "all"
                (case-insensitive; unknown values log a warning and use "all")
            intra_op_num_threads: Threads within ops (0 = auto)
            inter_op_num_threads: Threads between ops (0 = auto)
            **kwargs: Additional ``ort.SessionOptions`` attributes to set
                (e.g. ``enable_profiling=True``). Names that are not settable
                attributes of SessionOptions are logged and ignored.

        Returns:
            ONNX InferenceSession

        Raises:
            ModelLoadError: onnxruntime is not installed, or any step of the
                load failed. Errors that are not already ModelLoadError
                instances (including ModelNotFoundError, unless it subclasses
                ModelLoadError) are wrapped.
        """
        try:
            import onnxruntime as ort
        except ImportError as e:
            raise ModelLoadError(
                model_id,
                "onnxruntime not installed. Install with: pip install onnxruntime-gpu",
                cause=e,
            ) from e

        self._loading = True
        # perf_counter is monotonic, so the measured duration is unaffected by
        # wall-clock adjustments during the load.
        start_time = time.perf_counter()

        try:
            model_path = await self._resolve_model_path(model_id)
            if model_path is None:
                raise ModelNotFoundError(model_id)

            onnx_file = self._find_onnx_file(model_path)

            sess_options = self._build_session_options(
                ort,
                graph_optimization_level,
                intra_op_num_threads,
                inter_op_num_threads,
                kwargs,
            )
            self._session_options = sess_options

            requested = (
                list(providers) if providers is not None else _get_onnx_providers(device)
            )
            selected, selected_options = self._select_providers(
                requested, provider_options, ort.get_available_providers()
            )

            logger.info("Loading ONNX model with providers: %s", selected)

            # NOTE(review): InferenceSession construction is synchronous and runs
            # on the event-loop thread, so a slow build (e.g. a TensorRT engine)
            # blocks other coroutines. Consider offloading it to a worker thread.
            session = ort.InferenceSession(
                str(onnx_file),
                sess_options=sess_options,
                providers=selected,
                provider_options=selected_options,
            )

            self._input_names = [inp.name for inp in session.get_inputs()]
            self._output_names = [out.name for out in session.get_outputs()]

            # Report the providers the session actually registered, which may
            # differ from what was requested.
            active = list(session.get_providers())

            self._model = session
            self._model_info = ModelInfo(
                model_id=model_id,
                path=onnx_file,
                device=active[0] if active else "cpu",
                load_time_seconds=time.perf_counter() - start_time,
                metadata={
                    "providers": active,
                    "requested_providers": selected,
                    "input_names": self._input_names,
                    "output_names": self._output_names,
                    "graph_optimization": graph_optimization_level,
                },
            )

            logger.info(
                "Loaded ONNX model '%s' in %.2fs",
                model_id,
                self._model_info.load_time_seconds,
            )

            return session

        except ModelLoadError:
            raise
        except Exception as e:
            raise ModelLoadError(model_id, str(e), cause=e) from e
        finally:
            self._loading = False

    @staticmethod
    def _build_session_options(
        ort: Any,
        graph_optimization_level: str,
        intra_op_num_threads: int,
        inter_op_num_threads: int,
        extra_options: Dict[str, Any],
    ) -> Any:
        """Create an ``ort.SessionOptions`` from load() arguments.

        Unknown optimization levels fall back to ORT_ENABLE_ALL with a warning.
        Thread counts are only set when positive (0 leaves onnxruntime's
        default). ``extra_options`` are set as SessionOptions attributes when
        such a non-callable, public attribute exists; others are logged and
        skipped.
        """
        sess_options = ort.SessionOptions()

        opt_levels = {
            "disabled": ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
            "basic": ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
            "extended": ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED,
            "all": ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
        }
        level = opt_levels.get(str(graph_optimization_level).lower())
        if level is None:
            logger.warning(
                "Unknown graph_optimization_level %r; using 'all'",
                graph_optimization_level,
            )
            level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.graph_optimization_level = level

        if intra_op_num_threads > 0:
            sess_options.intra_op_num_threads = intra_op_num_threads
        if inter_op_num_threads > 0:
            sess_options.inter_op_num_threads = inter_op_num_threads

        for name, value in extra_options.items():
            # Only plain public attributes are settable; skip private names and
            # methods so a stray kwarg cannot clobber SessionOptions behavior.
            if (
                name.startswith("_")
                or not hasattr(sess_options, name)
                or callable(getattr(sess_options, name))
            ):
                logger.warning("Ignoring unsupported ONNX session option %r", name)
                continue
            setattr(sess_options, name, value)

        return sess_options

    @staticmethod
    def _select_providers(
        requested: List[Any],
        provider_options: Optional[List[Dict[str, Any]]],
        available: List[str],
    ) -> tuple:
        """Keep only available providers, filtering ``provider_options`` in lockstep.

        Entries of ``requested`` may be provider names or ``(name, options)``
        tuples; availability is checked on the name. When ``provider_options``
        has the same length as ``requested`` it is treated as parallel to it and
        filtered alongside, so each option dict stays with its provider.
        Otherwise it is returned unchanged.

        Returns:
            ``(providers, provider_options)`` for ``ort.InferenceSession``.
            Falls back to ``["CPUExecutionProvider"]`` (with a warning) when none
            of the requested providers is available.
        """
        available_set = set(available)
        aligned = provider_options is not None and len(provider_options) == len(requested)

        kept: List[Any] = []
        kept_options: List[Dict[str, Any]] = []
        dropped: List[Any] = []
        for index, provider in enumerate(requested):
            name = provider[0] if isinstance(provider, (tuple, list)) else provider
            if name in available_set:
                kept.append(provider)
                if aligned:
                    kept_options.append(provider_options[index])
            else:
                dropped.append(name)

        if dropped:
            logger.debug("Skipping unavailable ONNX providers: %s", dropped)

        if not kept:
            logger.warning(
                "None of the requested ONNX providers %s are available "
                "(available: %s); falling back to CPUExecutionProvider",
                dropped,
                available,
            )
            kept = ["CPUExecutionProvider"]
            kept_options = [{}]

        return kept, (kept_options if aligned else provider_options)

    async def unload(self) -> None:
        """Drop the loaded ONNX session and clear cached model metadata.

        Does nothing when no model is loaded.
        """
        if self._model is None:
            return

        # ONNX Runtime sessions don't have explicit cleanup; dropping the
        # reference lets the session be garbage-collected.
        self._model = None
        self._model_info = None
        self._input_names = []
        self._output_names = []
        self._session_options = None
        logger.debug("Unloaded ONNX model")

    async def _resolve_model_path(self, model_id: str) -> Optional[Path]:
        """Resolve *model_id* to a local path, or ``None`` if it cannot be found.

        An existing filesystem path is returned as-is. Otherwise the ID is looked
        up with ``ensure_model`` from the sibling ``loader`` module; any failure
        there is logged at debug level and treated as "not found".
        """
        path = Path(model_id)
        if path.exists():
            return path

        try:
            from .loader import ensure_model

            # NOTE(review): ensure_model is called synchronously from this
            # coroutine; if it performs downloads it will block the event loop.
            resolved = ensure_model(model_id)
            if resolved:
                return Path(resolved)
        except Exception as e:
            # Deliberate best-effort: load() reports ModelNotFoundError instead.
            logger.debug("Failed to resolve via manifest: %s", e)

        return None

    def _find_onnx_file(self, path: Path) -> Path:
        """Locate the ``.onnx`` graph file for *path* (a file or a directory).

        A file whose suffix is ``.onnx`` (any case) is returned unchanged. For a
        directory, a ``model.onnx`` file is preferred; otherwise the
        alphabetically first ``*.onnx`` file (any case) is used, so the choice
        does not depend on filesystem listing order.

        Raises:
            ModelNotFoundError: if no ``.onnx`` file can be found.
        """
        if path.is_file() and path.suffix.lower() == ".onnx":
            return path

        if path.is_dir():
            model_onnx = path / "model.onnx"
            if model_onnx.is_file():
                return model_onnx

            candidates = sorted(
                child
                for child in path.iterdir()
                if child.is_file() and child.suffix.lower() == ".onnx"
            )
            if candidates:
                if len(candidates) > 1:
                    logger.warning(
                        "Multiple .onnx files in %s; using %s",
                        path,
                        candidates[0].name,
                    )
                return candidates[0]

        raise ModelNotFoundError(
            str(path), searched_paths=[path, path / "model.onnx"]
        )

    def run(
        self,
        inputs: Dict[str, Any],
        output_names: Optional[List[str]] = None,
    ) -> List[Any]:
        """
        Run inference on the loaded session.

        Args:
            inputs: Dict mapping input names to values (e.g. numpy arrays)
            output_names: Optional list of output names (default: all)

        Returns:
            The outputs returned by the session's ``run()``.

        Raises:
            RuntimeError: if no model is loaded.
        """
        if self._model is None:
            raise RuntimeError("No model loaded. Call load() first.")

        return self._model.run(output_names, inputs)

    def __call__(self, inputs: Dict[str, Any]) -> List[Any]:
        """Shorthand for ``run(inputs)`` fetching all outputs."""
        return self.run(inputs)
|