ml-model-loader/src_python/tqftw_model_loader/onnx_loader.py
Lilith bbbccd685b feat: universal ML model support with auto-discovery
Add support for all ML model types beyond GGUF:
- New discovery module for auto-scanning model directories
- Detect formats: GGUF, safetensors, ONNX, PyTorch, diffusion pipelines
- CLI commands: discover, scan, sync for manifest management
- Manifest v2.0 with format field, directory support, file lists

Python loaders (v2.0.0):
- ONNXLoader with CUDA/TensorRT execution providers
- WhisperLoader for faster-whisper with transcribe/stream
- get_auto_loader() for automatic backend selection

Breaking: Manifest schema upgraded to v2.0 (auto-migrates v1.x on load)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-01 15:21:52 -08:00

278 lines
8.8 KiB
Python

"""
ONNX Runtime model loader.
Loads ONNX models with GPU acceleration via CUDA/TensorRT execution providers.
"""
from pathlib import Path
from typing import Optional, Any, Dict, List
import time
import logging
from .base import BaseModelLoader, ModelInfo, ModelLoadError, ModelNotFoundError
from .registry import register_loader
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type alias for the ONNX session object handled by the loader. It is `Any`
# because onnxruntime is not imported at module level here (it is imported
# lazily inside ONNXLoader.load()).
ONNXSession = Any
def _get_onnx_providers(device: Optional[str] = None) -> List[str]:
"""Get ONNX execution providers based on device preference."""
if device is None or device == "auto":
# Try CUDA first, fall back to CPU
return [
"TensorrtExecutionProvider",
"CUDAExecutionProvider",
"CPUExecutionProvider",
]
elif device.startswith("cuda"):
return [
"TensorrtExecutionProvider",
"CUDAExecutionProvider",
"CPUExecutionProvider",
]
elif device == "tensorrt":
return [
"TensorrtExecutionProvider",
"CUDAExecutionProvider",
"CPUExecutionProvider",
]
else:
return ["CPUExecutionProvider"]
@register_loader("onnx", aliases=["onnxruntime", "ort"])
class ONNXLoader(BaseModelLoader[ONNXSession]):
    """
    ONNX Runtime model loader.

    Supports both single .onnx files and directory-based models
    (containing model.onnx + external data).

    Example:
        >>> loader = ONNXLoader()
        >>> session = await loader.load("silero-vad")
        >>> outputs = session.run(None, {"input": audio_data})

        # Or with specific options
        >>> session = await loader.load(
        ...     "model.onnx",
        ...     device="cuda",
        ...     providers=["CUDAExecutionProvider"],
        ... )
    """

    def __init__(self) -> None:
        super().__init__()
        # Tensor names of the currently loaded session; empty when unloaded.
        self._input_names: List[str] = []
        self._output_names: List[str] = []
        # SessionOptions used for the most recent load() call.
        self._session_options: Any = None

    @property
    def input_names(self) -> List[str]:
        """Get input tensor names."""
        return self._input_names

    @property
    def output_names(self) -> List[str]:
        """Get output tensor names."""
        return self._output_names

    @staticmethod
    def _select_available_providers(
        providers: List[Any],
        provider_options: Optional[List[Dict[str, Any]]],
        available: List[str],
    ) -> tuple:
        """Drop unavailable providers while keeping their options aligned.

        Args:
            providers: Requested providers; each entry is either a provider
                name or a ``(name, options)`` pair as accepted by
                ``onnxruntime.InferenceSession``.
            provider_options: Optional per-provider option dicts.
            available: Provider names reported as available by ONNX Runtime.

        Returns:
            ``(providers, provider_options)`` restricted to available
            providers. When nothing requested is available, returns
            ``(["CPUExecutionProvider"], None)``.
        """
        # Only treat options as per-provider when the lengths line up;
        # otherwise pass them through untouched and let ONNX Runtime validate.
        paired = provider_options is not None and len(provider_options) == len(
            providers
        )
        kept_providers: List[Any] = []
        kept_options: List[Dict[str, Any]] = []
        for index, provider in enumerate(providers):
            # A provider may be a bare name or a (name, options) pair.
            name = provider if isinstance(provider, str) else provider[0]
            if name not in available:
                continue
            kept_providers.append(provider)
            if paired:
                kept_options.append(provider_options[index])

        if not kept_providers:
            # Options written for other providers must not be applied to the
            # CPU fallback.
            return ["CPUExecutionProvider"], None
        if paired:
            return kept_providers, kept_options
        return kept_providers, provider_options

    async def load(
        self,
        model_id: str,
        *,
        device: Optional[str] = None,
        providers: Optional[List[str]] = None,
        provider_options: Optional[List[Dict[str, Any]]] = None,
        graph_optimization_level: str = "all",
        intra_op_num_threads: int = 0,
        inter_op_num_threads: int = 0,
        **kwargs: Any,
    ) -> ONNXSession:
        """
        Load an ONNX model.

        Args:
            model_id: Model ID from manifest or path to .onnx file/directory
            device: Device preference ("cuda", "cpu", "tensorrt", "auto");
                only consulted when ``providers`` is not given
            providers: Explicit execution providers list. Entries that are
                not available in the installed ONNX Runtime are dropped; if
                none remain, CPU is used.
            provider_options: Provider-specific options, one dict per entry
                in ``providers``
            graph_optimization_level: "disabled", "basic", "extended", "all".
                Unknown values fall back to "all" with a warning.
            intra_op_num_threads: Threads within ops (0 = leave unset)
            inter_op_num_threads: Threads between ops (0 = leave unset)
            **kwargs: Accepted for interface compatibility; not currently
                applied to the session.

        Returns:
            ONNX InferenceSession

        Raises:
            ModelLoadError: If onnxruntime is not installed, or if any other
                error occurs while resolving or loading the model.
        """
        try:
            import onnxruntime as ort
        except ImportError as e:
            raise ModelLoadError(
                model_id,
                "onnxruntime not installed. Install with: pip install onnxruntime-gpu",
                cause=e,
            ) from e

        self._loading = True
        # perf_counter is monotonic, so the measured duration is not affected
        # by wall-clock adjustments.
        start_time = time.perf_counter()
        try:
            # Resolve model path
            model_path = await self._resolve_model_path(model_id)
            if model_path is None:
                raise ModelNotFoundError(model_id)

            # Find the actual .onnx file
            onnx_file = self._find_onnx_file(model_path)

            # Set up session options
            sess_options = ort.SessionOptions()

            # Graph optimization
            opt_levels = {
                "disabled": ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
                "basic": ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
                "extended": ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED,
                "all": ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
            }
            if graph_optimization_level not in opt_levels:
                logger.warning(
                    f"Unknown graph_optimization_level "
                    f"'{graph_optimization_level}', using 'all'"
                )
            sess_options.graph_optimization_level = opt_levels.get(
                graph_optimization_level, ort.GraphOptimizationLevel.ORT_ENABLE_ALL
            )

            # Thread configuration
            if intra_op_num_threads > 0:
                sess_options.intra_op_num_threads = intra_op_num_threads
            if inter_op_num_threads > 0:
                sess_options.inter_op_num_threads = inter_op_num_threads
            self._session_options = sess_options

            # Determine providers
            if providers is None:
                providers = _get_onnx_providers(device)

            # Filter to available providers, keeping options aligned
            providers, provider_options = self._select_available_providers(
                providers, provider_options, ort.get_available_providers()
            )
            provider_names = [
                p if isinstance(p, str) else p[0] for p in providers
            ]
            logger.info(f"Loading ONNX model with providers: {provider_names}")

            # Create session
            session = ort.InferenceSession(
                str(onnx_file),
                sess_options=sess_options,
                providers=providers,
                provider_options=provider_options,
            )

            # Store input/output names
            self._input_names = [inp.name for inp in session.get_inputs()]
            self._output_names = [out.name for out in session.get_outputs()]
            self._model = session
            self._model_info = ModelInfo(
                model_id=model_id,
                path=onnx_file,
                device=provider_names[0] if provider_names else "cpu",
                load_time_seconds=time.perf_counter() - start_time,
                metadata={
                    "providers": provider_names,
                    "input_names": self._input_names,
                    "output_names": self._output_names,
                    "graph_optimization": graph_optimization_level,
                },
            )
            logger.info(
                f"Loaded ONNX model '{model_id}' in {self._model_info.load_time_seconds:.2f}s"
            )
            return session
        except ModelLoadError:
            # Already the loader's error type; propagate unchanged.
            # NOTE(review): ModelNotFoundError passes through here only if it
            # derives from ModelLoadError - confirm against .base.
            raise
        except Exception as e:
            raise ModelLoadError(model_id, str(e), cause=e) from e
        finally:
            self._loading = False

    async def unload(self) -> None:
        """Unload the ONNX session."""
        if self._model is not None:
            # ONNX Runtime sessions don't have explicit cleanup
            # Just dereference to allow garbage collection
            self._model = None
            self._model_info = None
            self._input_names = []
            self._output_names = []
            logger.debug("Unloaded ONNX model")

    async def _resolve_model_path(self, model_id: str) -> Optional[Path]:
        """Resolve model ID to a local path.

        Returns:
            ``model_id`` itself as a path when it exists on disk, otherwise
            the path returned by ``ensure_model``; ``None`` when neither
            yields a path.
        """
        # Check if it's already a path
        path = Path(model_id)
        if path.exists():
            return path

        # Try to resolve via model loader. Best-effort: a failure here is
        # logged and reported to the caller as "not found" (None).
        try:
            from .loader import ensure_model

            resolved = ensure_model(model_id)
            if resolved:
                return Path(resolved)
        except Exception as e:
            logger.debug(f"Failed to resolve via manifest: {e}")
        return None

    def _find_onnx_file(self, path: Path) -> Path:
        """Find the .onnx file in a path (file or directory).

        Raises:
            ModelNotFoundError: If ``path`` is neither an .onnx file nor a
                directory containing one.
        """
        if path.is_file() and path.suffix.lower() == ".onnx":
            return path

        if path.is_dir():
            # Look for model.onnx first
            model_onnx = path / "model.onnx"
            if model_onnx.exists():
                return model_onnx

            # Look for any .onnx file; sorted so the choice does not depend
            # on filesystem enumeration order.
            onnx_files = sorted(path.glob("*.onnx"))
            if onnx_files:
                return onnx_files[0]

        raise ModelNotFoundError(
            str(path), searched_paths=[path, path / "model.onnx"]
        )

    def run(
        self,
        inputs: Dict[str, Any],
        output_names: Optional[List[str]] = None,
    ) -> List[Any]:
        """
        Run inference on the loaded model.

        Args:
            inputs: Dict mapping input names to numpy arrays
            output_names: Optional list of output names (default: all)

        Returns:
            List of output numpy arrays

        Raises:
            RuntimeError: If no model has been loaded.
        """
        if self._model is None:
            raise RuntimeError("No model loaded. Call load() first.")
        return self._model.run(output_names, inputs)

    def __call__(self, inputs: Dict[str, Any]) -> List[Any]:
        """Shorthand for run()."""
        return self.run(inputs)