"""
ONNX Runtime model loader.

Loads ONNX models with GPU acceleration via CUDA/TensorRT execution providers.
"""

from pathlib import Path
from typing import Optional, Any, Dict, List, Tuple
import time
import logging

from .base import BaseModelLoader, ModelInfo, ModelLoadError, ModelNotFoundError
from .registry import register_loader

logger = logging.getLogger(__name__)

# Type alias for ONNX session. onnxruntime is imported lazily inside load(),
# so the concrete InferenceSession type is not available at import time.
ONNXSession = Any

# Provider priority used for any GPU-capable device request. Providers missing
# from the installed onnxruntime build are filtered out at load time, leaving
# CPU as the final fallback.
_GPU_PROVIDERS = (
    "TensorrtExecutionProvider",
    "CUDAExecutionProvider",
    "CPUExecutionProvider",
)

# Providers that understand the ``device_id`` provider option.
_DEVICE_ID_PROVIDERS = frozenset(
    {"TensorrtExecutionProvider", "CUDAExecutionProvider"}
)


def _get_onnx_providers(device: Optional[str] = None) -> List[str]:
    """Get ONNX execution providers based on device preference.

    ``None``/``"auto"``, ``"cuda"``/``"cuda:N"`` and ``"tensorrt"`` all map to
    TensorRT -> CUDA -> CPU; any other value maps to CPU only. Matching is
    case-insensitive.
    """
    normalized = "auto" if device is None else device.lower()
    if normalized in ("auto", "tensorrt") or normalized.startswith("cuda"):
        return list(_GPU_PROVIDERS)
    return ["CPUExecutionProvider"]


def _parse_device_index(device: Optional[str]) -> Optional[int]:
    """Return ``N`` for a ``"cuda:N"`` device string, otherwise ``None``."""
    if device is None:
        return None
    prefix, sep, index = device.lower().partition(":")
    if prefix == "cuda" and sep and index.isdecimal():
        return int(index)
    return None


@register_loader("onnx", aliases=["onnxruntime", "ort"])
class ONNXLoader(BaseModelLoader[ONNXSession]):
    """
    ONNX Runtime model loader.

    Supports both single .onnx files and directory-based models
    (containing model.onnx + external data).

    Example:
        >>> loader = ONNXLoader()
        >>> session = await loader.load("silero-vad")
        >>> outputs = session.run(None, {"input": audio_data})

        # Or with specific options
        >>> session = await loader.load(
        ...     "model.onnx",
        ...     device="cuda",
        ...     providers=["CUDAExecutionProvider"],
        ... )
    """

    def __init__(self) -> None:
        super().__init__()
        self._input_names: List[str] = []
        self._output_names: List[str] = []
        # onnxruntime.SessionOptions used by the last successful load()
        # (typed Any because onnxruntime is imported lazily).
        self._session_options: Any = None

    @property
    def input_names(self) -> List[str]:
        """Get input tensor names."""
        return self._input_names

    @property
    def output_names(self) -> List[str]:
        """Get output tensor names."""
        return self._output_names

    async def load(
        self,
        model_id: str,
        *,
        device: Optional[str] = None,
        providers: Optional[List[str]] = None,
        provider_options: Optional[List[Dict[str, Any]]] = None,
        graph_optimization_level: str = "all",
        intra_op_num_threads: int = 0,
        inter_op_num_threads: int = 0,
        **kwargs: Any,
    ) -> ONNXSession:
        """
        Load an ONNX model.

        Args:
            model_id: Model ID from manifest or path to .onnx file/directory
            device: Device preference ("cuda", "cuda:N", "cpu", "tensorrt",
                "auto"). For "cuda:N", ``device_id=N`` is forwarded to the
                CUDA/TensorRT providers unless ``provider_options`` is given.
            providers: Explicit execution providers list (overrides ``device``
                for provider selection)
            provider_options: Provider-specific options, one dict per entry in
                ``providers``. Entries for unavailable providers are dropped
                together with their provider.
            graph_optimization_level: "disabled", "basic", "extended", "all"
                (case-insensitive; unknown values fall back to "all" with a
                warning)
            intra_op_num_threads: Threads within ops (0 = auto)
            inter_op_num_threads: Threads between ops (0 = auto)
            **kwargs: Additional session options, applied as attributes of
                ``onnxruntime.SessionOptions`` when such an attribute exists;
                other keys are ignored (logged at debug level).

        Returns:
            ONNX InferenceSession

        Raises:
            ModelNotFoundError: If the model cannot be resolved or contains no
                .onnx file (re-raised as-is when it is a ModelLoadError).
            ModelLoadError: If onnxruntime is missing or session creation fails.
        """
        try:
            import onnxruntime as ort
        except ImportError as e:
            raise ModelLoadError(
                model_id,
                "onnxruntime not installed. Install with: pip install onnxruntime-gpu",
                cause=e,
            ) from e

        self._loading = True
        start_time = time.perf_counter()

        try:
            model_path = await self._resolve_model_path(model_id)
            if model_path is None:
                raise ModelNotFoundError(model_id)

            onnx_file = self._find_onnx_file(model_path)

            sess_options = self._build_session_options(
                ort,
                graph_optimization_level=graph_optimization_level,
                intra_op_num_threads=intra_op_num_threads,
                inter_op_num_threads=inter_op_num_threads,
                extra_options=kwargs,
            )

            requested = (
                providers if providers is not None else _get_onnx_providers(device)
            )
            selected, selected_options = self._select_providers(
                requested,
                provider_options,
                ort.get_available_providers(),
                device,
            )

            logger.info("Loading ONNX model with providers: %s", selected)

            session = ort.InferenceSession(
                str(onnx_file),
                sess_options=sess_options,
                providers=selected,
                provider_options=selected_options,
            )

            input_names = [inp.name for inp in session.get_inputs()]
            output_names = [out.name for out in session.get_outputs()]
            # ORT may fall back to a different provider than requested; record
            # what the session actually uses.
            active = list(session.get_providers())

            model_info = ModelInfo(
                model_id=model_id,
                path=onnx_file,
                device=active[0] if active else "cpu",
                load_time_seconds=time.perf_counter() - start_time,
                metadata={
                    "providers": active,
                    "requested_providers": selected,
                    "input_names": input_names,
                    "output_names": output_names,
                    "graph_optimization": graph_optimization_level,
                },
            )

            # Commit loader state only after every step has succeeded.
            self._session_options = sess_options
            self._input_names = input_names
            self._output_names = output_names
            self._model = session
            self._model_info = model_info

            logger.info(
                "Loaded ONNX model '%s' in %.2fs",
                model_id,
                model_info.load_time_seconds,
            )
            return session

        except ModelLoadError:
            raise
        except Exception as e:
            raise ModelLoadError(model_id, str(e), cause=e) from e
        finally:
            self._loading = False

    @staticmethod
    def _build_session_options(
        ort: Any,
        *,
        graph_optimization_level: str,
        intra_op_num_threads: int,
        inter_op_num_threads: int,
        extra_options: Dict[str, Any],
    ) -> Any:
        """Create an ``ort.SessionOptions`` configured from load() arguments."""
        sess_options = ort.SessionOptions()

        opt_levels = {
            "disabled": ort.GraphOptimizationLevel.ORT_DISABLE_ALL,
            "basic": ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
            "extended": ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED,
            "all": ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
        }
        level = opt_levels.get(str(graph_optimization_level).lower())
        if level is None:
            logger.warning(
                "Unknown graph_optimization_level %r; using 'all'",
                graph_optimization_level,
            )
            level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.graph_optimization_level = level

        # 0 keeps onnxruntime's own default thread counts.
        if intra_op_num_threads > 0:
            sess_options.intra_op_num_threads = intra_op_num_threads
        if inter_op_num_threads > 0:
            sess_options.inter_op_num_threads = inter_op_num_threads

        for name, value in extra_options.items():
            # Only plain, public data attributes are treated as session options;
            # methods and unknown keys are skipped rather than failing the load.
            if (
                name.startswith("_")
                or not hasattr(sess_options, name)
                or callable(getattr(sess_options, name))
            ):
                logger.debug("Ignoring unsupported ONNX session option %r", name)
                continue
            setattr(sess_options, name, value)

        return sess_options

    @staticmethod
    def _select_providers(
        requested: List[str],
        provider_options: Optional[List[Dict[str, Any]]],
        available: List[str],
        device: Optional[str],
    ) -> Tuple[List[str], Optional[List[Dict[str, Any]]]]:
        """Filter providers to those available, keeping options aligned.

        Returns the provider list and a matching options list (or ``None``).
        """
        available_set = set(available)
        keep = [i for i, name in enumerate(requested) if name in available_set]
        selected = [requested[i] for i in keep]

        if not selected:
            # Nothing requested is available: plain CPU, without options that
            # were written for other providers.
            return ["CPUExecutionProvider"], None

        if provider_options is not None and len(provider_options) == len(requested):
            # Options parallel the requested list: drop them in lockstep.
            options: Optional[List[Dict[str, Any]]] = [
                provider_options[i] for i in keep
            ]
        else:
            # No options, or options already aligned with the filtered list
            # (legacy usage); any other mismatch is reported by onnxruntime.
            options = provider_options

        device_index = _parse_device_index(device)
        if (
            options is None
            and device_index is not None
            and any(name in _DEVICE_ID_PROVIDERS for name in selected)
        ):
            options = [
                {"device_id": device_index} if name in _DEVICE_ID_PROVIDERS else {}
                for name in selected
            ]

        return selected, options

    async def unload(self) -> None:
        """Unload the ONNX session."""
        if self._model is not None:
            # ONNX Runtime sessions don't have explicit cleanup;
            # just drop references to allow garbage collection.
            self._model = None
            self._model_info = None
            self._input_names = []
            self._output_names = []
            self._session_options = None
            logger.debug("Unloaded ONNX model")

    async def _resolve_model_path(self, model_id: str) -> Optional[Path]:
        """Resolve a model ID to a local path, or return None if unresolvable."""
        # Check if it's already an existing filesystem path.
        path = Path(model_id)
        if path.exists():
            return path

        # Best-effort resolution via the model manifest; failures are logged
        # and reported to the caller as "not found" (None).
        # NOTE(review): ensure_model is called synchronously here — confirm it
        # is not a coroutine and does not block the event loop for long.
        try:
            from .loader import ensure_model

            resolved = ensure_model(model_id)
            if resolved:
                return Path(resolved)
        except Exception as e:
            logger.debug("Failed to resolve via manifest: %s", e)

        return None

    def _find_onnx_file(self, path: Path) -> Path:
        """Find the .onnx file in a path (file or directory).

        For a directory, ``model.onnx`` wins; otherwise the first ``*.onnx``
        file in sorted order is used so the choice is deterministic.

        Raises:
            ModelNotFoundError: If no .onnx file can be found.
        """
        if path.is_file() and path.suffix.lower() == ".onnx":
            return path

        if path.is_dir():
            model_onnx = path / "model.onnx"
            if model_onnx.exists():
                return model_onnx

            onnx_files = sorted(p for p in path.glob("*.onnx") if p.is_file())
            if onnx_files:
                return onnx_files[0]

        raise ModelNotFoundError(
            str(path), searched_paths=[path, path / "model.onnx"]
        )

    def run(
        self,
        inputs: Dict[str, Any],
        output_names: Optional[List[str]] = None,
    ) -> List[Any]:
        """
        Run inference on the loaded model.

        Args:
            inputs: Dict mapping input names to numpy arrays
            output_names: Optional list of output names (default: all)

        Returns:
            List of output numpy arrays

        Raises:
            RuntimeError: If no model is loaded.
        """
        if self._model is None:
            raise RuntimeError("No model loaded. Call load() first.")

        return self._model.run(output_names, inputs)

    def __call__(
        self,
        inputs: Dict[str, Any],
        output_names: Optional[List[str]] = None,
    ) -> List[Any]:
        """Shorthand for run()."""
        return self.run(inputs, output_names)