""" faster-whisper model loader. Loads CTranslate2-optimized Whisper models for efficient speech recognition. """ from pathlib import Path from typing import Optional, Any, Dict, List, Iterator, Tuple, Union import time import logging from .base import BaseModelLoader, ModelInfo, ModelLoadError, ModelNotFoundError from .registry import register_loader logger = logging.getLogger(__name__) # Type alias WhisperModel = Any # Whisper model sizes WHISPER_SIZES = [ "tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2", "large-v3", "turbo", ] def _get_whisper_device(device: Optional[str] = None) -> str: """Get faster-whisper compatible device string.""" if device is None or device == "auto": # Check for CUDA availability try: import torch if torch.cuda.is_available(): return "cuda" except ImportError: pass return "cpu" if device.startswith("cuda"): return "cuda" elif device == "mps": logger.warning("faster-whisper doesn't support MPS, falling back to CPU") return "cpu" else: return "cpu" def _get_compute_type(compute_type: str, device: str) -> str: """Determine compute type based on device.""" if compute_type != "auto": return compute_type if device == "cuda": return "float16" else: return "int8" @register_loader("whisper", aliases=["faster-whisper", "ctranslate2", "fw"]) class WhisperLoader(BaseModelLoader[WhisperModel]): """ faster-whisper model loader. Loads Whisper models optimized with CTranslate2 for fast inference. Supports model sizes, HuggingFace IDs, and local CTranslate2 directories. Example: >>> loader = WhisperLoader() >>> model = await loader.load("large-v3", device="cuda") >>> # Transcribe audio >>> segments, info = loader.transcribe("audio.wav") >>> for segment in segments: ... print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}") >>> # Or with streaming >>> for segment in loader.transcribe_stream("audio.wav"): ... print(segment.text, end="", flush=True) """ def __init__(self) -> None: super().__init__() self._device: str = "cpu" self._compute_type: str = "int8" @property def device(self) -> str: """Get the device the model is loaded on.""" return self._device @property def compute_type(self) -> str: """Get the compute type being used.""" return self._compute_type async def load( self, model_id: str, *, device: Optional[str] = None, compute_type: str = "auto", cpu_threads: int = 0, num_workers: int = 1, download_root: Optional[str] = None, local_files_only: bool = False, **kwargs: Any, ) -> WhisperModel: """ Load a faster-whisper model. Args: model_id: Model size ("tiny", "base", "small", "medium", "large-v3"), HuggingFace ID ("deepdml/faster-whisper-large-v3-turbo-ct2"), or local path to CT2 directory device: Device ("cuda", "cpu", "auto") compute_type: "auto", "int8", "int8_float16", "float16", "float32" cpu_threads: Number of threads for CPU inference (0 = auto) num_workers: Number of workers for transcription download_root: Custom download directory local_files_only: Only use local files, no downloads Returns: WhisperModel instance """ try: from faster_whisper import WhisperModel except ImportError as e: raise ModelLoadError( model_id, "faster-whisper not installed. Install with: pip install faster-whisper", cause=e, ) self._loading = True start_time = time.time() try: # Determine device self._device = _get_whisper_device(device) self._compute_type = _get_compute_type(compute_type, self._device) # Resolve model path model_path_or_size = await self._resolve_model_path(model_id) logger.info( f"Loading Whisper model '{model_path_or_size}' " f"on {self._device} with {self._compute_type}" ) # Create model model = WhisperModel( model_path_or_size, device=self._device, compute_type=self._compute_type, cpu_threads=cpu_threads if cpu_threads > 0 else 0, num_workers=num_workers, download_root=download_root, local_files_only=local_files_only, ) self._model = model self._model_info = ModelInfo( model_id=model_id, path=Path(model_path_or_size) if not model_id in WHISPER_SIZES else None, device=self._device, dtype=self._compute_type, load_time_seconds=time.time() - start_time, metadata={ "model_size": model_id if model_id in WHISPER_SIZES else "custom", "compute_type": self._compute_type, "num_workers": num_workers, }, ) logger.info( f"Loaded Whisper model '{model_id}' in {self._model_info.load_time_seconds:.2f}s" ) return model except Exception as e: if isinstance(e, ModelLoadError): raise raise ModelLoadError(model_id, str(e), cause=e) finally: self._loading = False async def unload(self) -> None: """Unload the Whisper model.""" if self._model is not None: del self._model self._model = None self._model_info = None logger.debug("Unloaded Whisper model") async def _resolve_model_path(self, model_id: str) -> str: """Resolve model ID to a path or size string.""" # Check if it's a standard size if model_id in WHISPER_SIZES: return model_id # Check if it's already a path path = Path(model_id) if path.exists(): return str(path) # Try to resolve via model loader try: from .loader import ensure_model resolved = ensure_model(model_id) if resolved: return resolved except Exception as e: logger.debug(f"Failed to resolve via manifest: {e}") # Return as-is (might be a HuggingFace ID) return model_id def transcribe( self, audio: Union[str, Path, Any], *, language: Optional[str] = None, task: str = "transcribe", beam_size: int = 5, best_of: int = 5, patience: float = 1.0, length_penalty: float = 1.0, temperature: Union[float, List[float]] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0], initial_prompt: Optional[str] = None, word_timestamps: bool = False, vad_filter: bool = False, vad_parameters: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Tuple[Iterator[Any], Any]: """ Transcribe audio file. Args: audio: Path to audio file or numpy array language: Language code (e.g., "en") or None for auto-detection task: "transcribe" or "translate" beam_size: Beam size for decoding best_of: Number of candidates when sampling patience: Patience for beam search length_penalty: Length penalty temperature: Temperature(s) for sampling initial_prompt: Optional prompt to condition the model word_timestamps: Whether to output word-level timestamps vad_filter: Whether to use VAD to filter out silence vad_parameters: Parameters for VAD filter **kwargs: Additional arguments Returns: Tuple of (segments iterator, transcription info) """ if self._model is None: raise RuntimeError("No model loaded. Call load() first.") return self._model.transcribe( audio if isinstance(audio, str) else str(audio), language=language, task=task, beam_size=beam_size, best_of=best_of, patience=patience, length_penalty=length_penalty, temperature=temperature, initial_prompt=initial_prompt, word_timestamps=word_timestamps, vad_filter=vad_filter, vad_parameters=vad_parameters, **kwargs, ) def transcribe_stream( self, audio: Union[str, Path], **kwargs: Any, ) -> Iterator[Any]: """ Stream transcription segments. Convenience method that yields segments one at a time. Args: audio: Path to audio file **kwargs: Arguments passed to transcribe() Yields: Transcription segments """ segments, _ = self.transcribe(audio, **kwargs) yield from segments def detect_language( self, audio: Union[str, Path, Any], ) -> Tuple[str, float]: """ Detect the language of audio. Args: audio: Path to audio file or numpy array Returns: Tuple of (language code, probability) """ if self._model is None: raise RuntimeError("No model loaded. Call load() first.") return self._model.detect_language( audio if isinstance(audio, str) else str(audio) ) def __call__( self, audio: Union[str, Path], **kwargs: Any, ) -> List[Dict[str, Any]]: """ Shorthand for transcribe that returns a list of segment dicts. Args: audio: Path to audio file **kwargs: Arguments passed to transcribe() Returns: List of segment dictionaries """ segments, _ = self.transcribe(audio, **kwargs) return [ { "start": seg.start, "end": seg.end, "text": seg.text, } for seg in segments ]