ml-model-loader/src_python/build/lib/tqftw_model_loader/whisper_loader.py
Lilith 8f4a35ba79 chore: add publishConfig to prevent public npm publishing
All @lilith/* packages should publish to forge.nasty.sh only.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-03 00:42:23 -08:00

344 lines
10 KiB
Python

"""
faster-whisper model loader.
Loads CTranslate2-optimized Whisper models for efficient speech recognition.
"""
import asyncio
import functools
import logging
import os
import time
from pathlib import Path
from typing import Optional, Any, Dict, List, Iterator, Tuple, Union

from .base import BaseModelLoader, ModelInfo, ModelLoadError, ModelNotFoundError
from .registry import register_loader
logger = logging.getLogger(__name__)

# Placeholder annotation for faster_whisper.WhisperModel. The real class is
# only imported inside WhisperLoader.load(), so this module can be imported
# even when faster-whisper is not installed.
WhisperModel = Any

# Model size names that the loader passes straight through to faster-whisper
# instead of treating them as a path or repository ID. Each multilingual
# size is followed by its English-only ".en" variant where one exists.
WHISPER_SIZES = [
    "tiny",
    "tiny.en",
    "base",
    "base.en",
    "small",
    "small.en",
    "medium",
    "medium.en",
    # "large" family: the bare name plus explicit revisions.
    "large",
    "large-v1",
    "large-v2",
    "large-v3",
    "turbo",
]
def _get_whisper_device(device: Optional[str] = None) -> str:
"""Get faster-whisper compatible device string."""
if device is None or device == "auto":
# Check for CUDA availability
try:
import torch
if torch.cuda.is_available():
return "cuda"
except ImportError:
pass
return "cpu"
if device.startswith("cuda"):
return "cuda"
elif device == "mps":
logger.warning("faster-whisper doesn't support MPS, falling back to CPU")
return "cpu"
else:
return "cpu"
def _get_compute_type(compute_type: str, device: str) -> str:
"""Determine compute type based on device."""
if compute_type != "auto":
return compute_type
if device == "cuda":
return "float16"
else:
return "int8"
@register_loader("whisper", aliases=["faster-whisper", "ctranslate2", "fw"])
class WhisperLoader(BaseModelLoader[WhisperModel]):
    """
    faster-whisper model loader.

    Loads Whisper models optimized with CTranslate2 for fast inference.
    Supports model sizes, HuggingFace IDs, and local CTranslate2 directories.

    Example:
        >>> loader = WhisperLoader()
        >>> model = await loader.load("large-v3", device="cuda")
        >>> # Transcribe audio
        >>> segments, info = loader.transcribe("audio.wav")
        >>> for segment in segments:
        ...     print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")
        >>> # Or with streaming
        >>> for segment in loader.transcribe_stream("audio.wav"):
        ...     print(segment.text, end="", flush=True)
    """

    def __init__(self) -> None:
        super().__init__()
        # Resolved device ("cuda" or "cpu"); updated by load().
        self._device: str = "cpu"
        # Resolved CTranslate2 compute type; updated by load().
        self._compute_type: str = "int8"

    @property
    def device(self) -> str:
        """Get the device the model is loaded on."""
        return self._device

    @property
    def compute_type(self) -> str:
        """Get the compute type being used."""
        return self._compute_type

    @staticmethod
    def _parse_device_index(device: Optional[str]) -> int:
        """Return the GPU ordinal from a ``"cuda:N"`` string, or 0 otherwise."""
        if device:
            prefix, sep, index = device.strip().lower().partition(":")
            if sep and prefix == "cuda" and index.isdigit():
                return int(index)
        return 0

    @staticmethod
    def _normalize_audio(audio: Any) -> Any:
        """Convert path-like audio inputs to ``str``; pass anything else through.

        faster-whisper accepts a file path, a binary file object, or a decoded
        numpy array. Only path-like objects need converting: calling ``str()``
        on an array or a file object would produce a bogus "path".
        """
        if isinstance(audio, os.PathLike):
            return os.fspath(audio)
        return audio

    async def load(
        self,
        model_id: str,
        *,
        device: Optional[str] = None,
        compute_type: str = "auto",
        cpu_threads: int = 0,
        num_workers: int = 1,
        download_root: Optional[str] = None,
        local_files_only: bool = False,
        **kwargs: Any,
    ) -> WhisperModel:
        """
        Load a faster-whisper model.

        Args:
            model_id: Model size ("tiny", "base", "small", "medium", "large-v3"),
                HuggingFace ID ("deepdml/faster-whisper-large-v3-turbo-ct2"),
                or local path to CT2 directory
            device: Device ("cuda", "cuda:N", "cpu", "auto"). The ordinal in
                "cuda:N" is forwarded to faster-whisper as ``device_index``.
            compute_type: "auto", "int8", "int8_float16", "float16", "float32"
            cpu_threads: Number of threads for CPU inference (0 = auto;
                negative values are treated as 0)
            num_workers: Number of workers for transcription
            download_root: Custom download directory
            local_files_only: Only use local files, no downloads
            **kwargs: Accepted for interface compatibility with other loaders;
                not used by this loader (logged at debug level).

        Returns:
            WhisperModel instance

        Raises:
            ModelLoadError: If faster-whisper is not installed, or resolving
                or constructing the model fails.
        """
        try:
            from faster_whisper import WhisperModel
        except ImportError as e:
            raise ModelLoadError(
                model_id,
                "faster-whisper not installed. Install with: pip install faster-whisper",
                cause=e,
            ) from e

        if kwargs:
            logger.debug(
                "Ignoring unsupported Whisper load options: %s", sorted(kwargs)
            )

        self._loading = True
        start_time = time.perf_counter()
        try:
            self._device = _get_whisper_device(device)
            self._compute_type = _get_compute_type(compute_type, self._device)
            # A GPU ordinal is only meaningful when running on CUDA.
            device_index = (
                self._parse_device_index(device) if self._device == "cuda" else 0
            )

            model_path_or_size = await self._resolve_model_path(model_id)
            logger.info(
                "Loading Whisper model '%s' on %s with %s",
                model_path_or_size,
                self._device,
                self._compute_type,
            )

            # Constructing the model may download weights and reads them from
            # disk. Run it in the default executor so the event loop is not
            # blocked for the whole load.
            loop = asyncio.get_running_loop()
            model = await loop.run_in_executor(
                None,
                functools.partial(
                    WhisperModel,
                    model_path_or_size,
                    device=self._device,
                    device_index=device_index,
                    compute_type=self._compute_type,
                    cpu_threads=max(cpu_threads, 0),
                    num_workers=num_workers,
                    download_root=download_root,
                    local_files_only=local_files_only,
                ),
            )

            is_standard_size = model_id in WHISPER_SIZES
            local_path = Path(model_path_or_size)
            self._model = model
            self._model_info = ModelInfo(
                model_id=model_id,
                # Record a path only when the model really lives on disk;
                # HuggingFace repo IDs are not filesystem paths.
                path=(
                    local_path
                    if not is_standard_size and local_path.exists()
                    else None
                ),
                device=self._device,
                dtype=self._compute_type,
                load_time_seconds=time.perf_counter() - start_time,
                metadata={
                    "model_size": model_id if is_standard_size else "custom",
                    "compute_type": self._compute_type,
                    "num_workers": num_workers,
                    "device_index": device_index,
                },
            )
            logger.info(
                "Loaded Whisper model '%s' in %.2fs",
                model_id,
                self._model_info.load_time_seconds,
            )
            return model
        except ModelLoadError:
            raise
        except Exception as e:
            raise ModelLoadError(model_id, str(e), cause=e) from e
        finally:
            self._loading = False

    async def unload(self) -> None:
        """Unload the Whisper model and clear its metadata."""
        if self._model is not None:
            # Dropping our reference lets the model be garbage-collected once
            # nothing else holds it.
            self._model = None
            self._model_info = None
            logger.debug("Unloaded Whisper model")

    async def _resolve_model_path(self, model_id: str) -> str:
        """Resolve a model ID to a size name, local path, or repository ID.

        Resolution order:
        1. A standard size name (returned unchanged).
        2. An existing local path (``~`` is expanded).
        3. The project manifest via ``ensure_model`` (best effort).
        4. The ID unchanged; presumably a HuggingFace repo ID.
        """
        if model_id in WHISPER_SIZES:
            return model_id

        # os.path.expanduser leaves the string unchanged if it cannot expand
        # it, unlike Path.expanduser, which raises.
        path = Path(os.path.expanduser(model_id))
        if path.exists():
            return str(path)

        # Best effort: a manifest lookup failure only means we fall through
        # to treating the ID as a remote repository.
        try:
            from .loader import ensure_model

            resolved = ensure_model(model_id)
            if resolved:
                return str(resolved)
        except Exception as e:
            logger.debug("Failed to resolve via manifest: %s", e)

        return model_id

    def transcribe(
        self,
        audio: Union[str, Path, Any],
        *,
        language: Optional[str] = None,
        task: str = "transcribe",
        beam_size: int = 5,
        best_of: int = 5,
        patience: float = 1.0,
        length_penalty: float = 1.0,
        temperature: Union[float, List[float], Tuple[float, ...]] = (
            0.0, 0.2, 0.4, 0.6, 0.8, 1.0,
        ),
        initial_prompt: Optional[str] = None,
        word_timestamps: bool = False,
        vad_filter: bool = False,
        vad_parameters: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Tuple[Iterator[Any], Any]:
        """
        Transcribe audio.

        Args:
            audio: Path to an audio file (str or path-like), a binary file
                object, or a decoded numpy array
            language: Language code (e.g., "en") or None for auto-detection
            task: "transcribe" or "translate"
            beam_size: Beam size for decoding
            best_of: Number of candidates when sampling
            patience: Patience for beam search
            length_penalty: Length penalty
            temperature: Temperature(s) for sampling; a sequence is tried in
                order as fallbacks
            initial_prompt: Optional prompt to condition the model
            word_timestamps: Whether to output word-level timestamps
            vad_filter: Whether to use VAD to filter out silence
            vad_parameters: Parameters for VAD filter
            **kwargs: Additional arguments forwarded to faster-whisper

        Returns:
            Tuple of (segments iterator, transcription info)

        Raises:
            RuntimeError: If no model is loaded.
        """
        if self._model is None:
            raise RuntimeError("No model loaded. Call load() first.")
        # The default is an immutable tuple so it cannot be mutated across
        # calls; hand faster-whisper a fresh list as before.
        if isinstance(temperature, tuple):
            temperature = list(temperature)
        return self._model.transcribe(
            self._normalize_audio(audio),
            language=language,
            task=task,
            beam_size=beam_size,
            best_of=best_of,
            patience=patience,
            length_penalty=length_penalty,
            temperature=temperature,
            initial_prompt=initial_prompt,
            word_timestamps=word_timestamps,
            vad_filter=vad_filter,
            vad_parameters=vad_parameters,
            **kwargs,
        )

    def transcribe_stream(
        self,
        audio: Union[str, Path],
        **kwargs: Any,
    ) -> Iterator[Any]:
        """
        Stream transcription segments.

        Convenience method that yields segments one at a time. Because this
        is a generator, errors such as "no model loaded" surface on the
        first iteration rather than at call time.

        Args:
            audio: Path to audio file
            **kwargs: Arguments passed to transcribe()

        Yields:
            Transcription segments
        """
        segments, _ = self.transcribe(audio, **kwargs)
        yield from segments

    def detect_language(
        self,
        audio: Union[str, Path, Any],
    ) -> Tuple[str, float]:
        """
        Detect the language of audio.

        Runs ``transcribe()`` with automatic language detection and reads
        the result from the returned transcription info. This accepts the
        same inputs as ``transcribe()`` and works across faster-whisper
        versions. faster-whisper yields segments lazily, and the segments
        iterator is never consumed, so the full transcript is not decoded.

        Args:
            audio: Path to an audio file, a binary file object, or a decoded
                numpy array

        Returns:
            Tuple of (language code, probability)

        Raises:
            RuntimeError: If no model is loaded.
        """
        if self._model is None:
            raise RuntimeError("No model loaded. Call load() first.")
        _, info = self._model.transcribe(
            self._normalize_audio(audio), language=None
        )
        return info.language, info.language_probability

    def __call__(
        self,
        audio: Union[str, Path],
        **kwargs: Any,
    ) -> List[Dict[str, Any]]:
        """
        Shorthand for transcribe that returns a list of segment dicts.

        Args:
            audio: Path to audio file
            **kwargs: Arguments passed to transcribe()

        Returns:
            List of segment dictionaries with "start", "end" and "text" keys
        """
        segments, _ = self.transcribe(audio, **kwargs)
        return [
            {
                "start": seg.start,
                "end": seg.end,
                "text": seg.text,
            }
            for seg in segments
        ]