345 lines
10 KiB
Python
345 lines
10 KiB
Python
|
|
"""
faster-whisper model loader.

Loads CTranslate2-optimized Whisper models for efficient speech recognition.
"""
|
||
|
|
|
||
|
|
import asyncio
import logging
import os
import time
from pathlib import Path
from typing import Optional, Any, Dict, List, Iterator, Tuple, Union

from .base import BaseModelLoader, ModelInfo, ModelLoadError, ModelNotFoundError
from .registry import register_loader
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)

# Stand-in type for faster_whisper.WhisperModel. The real class is imported
# lazily inside WhisperLoader.load(), so annotations here use Any.
WhisperModel = Any

# Model size names passed straight through to faster-whisper as-is (see
# WhisperLoader._resolve_model_path); any other model_id is treated as a
# local path, a manifest entry, or a HuggingFace repository ID.
WHISPER_SIZES = [
    "tiny",
    "tiny.en",
    "base",
    "base.en",
    "small",
    "small.en",
    "medium",
    "medium.en",
    "large",
    "large-v1",
    "large-v2",
    "large-v3",
    "turbo",
]
|
||
|
|
|
||
|
|
|
||
|
|
def _get_whisper_device(device: Optional[str] = None) -> str:
|
||
|
|
"""Get faster-whisper compatible device string."""
|
||
|
|
if device is None or device == "auto":
|
||
|
|
# Check for CUDA availability
|
||
|
|
try:
|
||
|
|
import torch
|
||
|
|
if torch.cuda.is_available():
|
||
|
|
return "cuda"
|
||
|
|
except ImportError:
|
||
|
|
pass
|
||
|
|
return "cpu"
|
||
|
|
|
||
|
|
if device.startswith("cuda"):
|
||
|
|
return "cuda"
|
||
|
|
elif device == "mps":
|
||
|
|
logger.warning("faster-whisper doesn't support MPS, falling back to CPU")
|
||
|
|
return "cpu"
|
||
|
|
else:
|
||
|
|
return "cpu"
|
||
|
|
|
||
|
|
|
||
|
|
def _get_compute_type(compute_type: str, device: str) -> str:
|
||
|
|
"""Determine compute type based on device."""
|
||
|
|
if compute_type != "auto":
|
||
|
|
return compute_type
|
||
|
|
|
||
|
|
if device == "cuda":
|
||
|
|
return "float16"
|
||
|
|
else:
|
||
|
|
return "int8"
|
||
|
|
|
||
|
|
|
||
|
|
@register_loader("whisper", aliases=["faster-whisper", "ctranslate2", "fw"])
class WhisperLoader(BaseModelLoader[WhisperModel]):
    """
    faster-whisper model loader.

    Loads Whisper models optimized with CTranslate2 for fast inference.
    Supports model sizes, HuggingFace IDs, and local CTranslate2 directories.

    Example:
        >>> loader = WhisperLoader()
        >>> model = await loader.load("large-v3", device="cuda")

        >>> # Transcribe audio
        >>> segments, info = loader.transcribe("audio.wav")
        >>> for segment in segments:
        ...     print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")

        >>> # Or with streaming
        >>> for segment in loader.transcribe_stream("audio.wav"):
        ...     print(segment.text, end="", flush=True)
    """

    def __init__(self) -> None:
        super().__init__()
        self._device: str = "cpu"
        self._compute_type: str = "int8"

    @property
    def device(self) -> str:
        """Get the device the model is loaded on."""
        return self._device

    @property
    def compute_type(self) -> str:
        """Get the compute type being used."""
        return self._compute_type

    async def load(
        self,
        model_id: str,
        *,
        device: Optional[str] = None,
        compute_type: str = "auto",
        cpu_threads: int = 0,
        num_workers: int = 1,
        download_root: Optional[str] = None,
        local_files_only: bool = False,
        **kwargs: Any,
    ) -> WhisperModel:
        """
        Load a faster-whisper model.

        The ``WhisperModel`` constructor (which may download and load weights)
        runs in the event loop's default executor so it does not block the loop.

        Args:
            model_id: Model size ("tiny", "base", "small", "medium", "large-v3"),
                      HuggingFace ID ("deepdml/faster-whisper-large-v3-turbo-ct2"),
                      or local path to CT2 directory
            device: Device ("cuda", "cuda:N", "cpu", "auto"). A ``"cuda:N"``
                    suffix selects GPU ``N``.
            compute_type: "auto", "int8", "int8_float16", "float16", "float32"
            cpu_threads: Number of threads for CPU inference (0 = auto;
                         negative values are treated as 0)
            num_workers: Number of workers for transcription
            download_root: Custom download directory
            local_files_only: Only use local files, no downloads
            **kwargs: May contain ``device_index`` (int or list of ints), which
                      overrides the GPU ordinal parsed from ``device``. Other
                      keys are ignored (and logged at debug level).

        Returns:
            WhisperModel instance

        Raises:
            ModelLoadError: faster-whisper is not installed, or loading failed.
        """
        try:
            from faster_whisper import WhisperModel
        except ImportError as e:
            raise ModelLoadError(
                model_id,
                "faster-whisper not installed. Install with: pip install faster-whisper",
                cause=e,
            ) from e

        self._loading = True
        start_time = time.perf_counter()

        try:
            # Resolve settings into locals; they are committed to self only
            # after the model loads, so a failed load leaves state untouched.
            resolved_device = _get_whisper_device(device)
            resolved_compute = _get_compute_type(compute_type, resolved_device)

            # _get_whisper_device collapses "cuda:N" to "cuda"; faster-whisper
            # takes the GPU ordinal as a separate device_index argument.
            device_index = kwargs.pop("device_index", None)
            if device_index is None:
                device_index = self._cuda_device_index(device)
            if kwargs:
                logger.debug("Ignoring unsupported load() options: %s", sorted(kwargs))

            model_path_or_size = await self._resolve_model_path(model_id)

            logger.info(
                "Loading Whisper model '%s' on %s with %s",
                model_path_or_size,
                resolved_device,
                resolved_compute,
            )

            def _build() -> Any:
                return WhisperModel(
                    model_path_or_size,
                    device=resolved_device,
                    device_index=device_index,
                    compute_type=resolved_compute,
                    cpu_threads=max(cpu_threads, 0),
                    num_workers=num_workers,
                    download_root=download_root,
                    local_files_only=local_files_only,
                )

            loop = asyncio.get_running_loop()
            model = await loop.run_in_executor(None, _build)

            self._device = resolved_device
            self._compute_type = resolved_compute
            self._model = model

            is_size = model_id in WHISPER_SIZES
            local_path = Path(model_path_or_size)
            self._model_info = ModelInfo(
                model_id=model_id,
                # Only report a filesystem path when one actually exists;
                # HuggingFace IDs are not local paths.
                path=local_path if not is_size and local_path.exists() else None,
                device=resolved_device,
                dtype=resolved_compute,
                load_time_seconds=time.perf_counter() - start_time,
                metadata={
                    "model_size": model_id if is_size else "custom",
                    "compute_type": resolved_compute,
                    "num_workers": num_workers,
                    "device_index": device_index,
                },
            )

            logger.info(
                "Loaded Whisper model '%s' in %.2fs",
                model_id,
                self._model_info.load_time_seconds,
            )

            return model

        except ModelLoadError:
            raise
        except Exception as e:
            raise ModelLoadError(model_id, str(e), cause=e) from e
        finally:
            self._loading = False

    @staticmethod
    def _cuda_device_index(device: Optional[str]) -> int:
        """Return the GPU ordinal from a ``"cuda:N"`` device string, else 0."""
        if device and device.startswith("cuda:"):
            suffix = device.split(":", 1)[1]
            if suffix.isdigit():
                return int(suffix)
        return 0

    async def unload(self) -> None:
        """Unload the Whisper model."""
        if self._model is not None:
            del self._model
            self._model = None
            self._model_info = None
            logger.debug("Unloaded Whisper model")

    async def _resolve_model_path(self, model_id: str) -> str:
        """
        Resolve a model ID to a path or size string understood by faster-whisper.

        Order: known size name, existing local path, the project manifest
        (``ensure_model``, best-effort), then the ID unchanged (e.g. a
        HuggingFace repository ID).
        """
        if model_id in WHISPER_SIZES:
            return model_id

        path = Path(model_id)
        if path.exists():
            return str(path)

        # Manifest resolution is best-effort: any failure falls through to
        # treating model_id as a HuggingFace ID.
        try:
            from .loader import ensure_model

            resolved = ensure_model(model_id)
        except Exception as e:
            logger.debug("Failed to resolve via manifest: %s", e)
        else:
            if resolved:
                return str(resolved)

        return model_id

    def _require_model(self) -> WhisperModel:
        """Return the loaded model or raise ``RuntimeError`` if none is loaded."""
        if self._model is None:
            raise RuntimeError("No model loaded. Call load() first.")
        return self._model

    @staticmethod
    def _prepare_audio(audio: Any) -> Any:
        """
        Normalize the audio argument for faster-whisper.

        Path-like objects (e.g. ``pathlib.Path``) become ``str``. Everything
        else (``str`` paths, decoded waveform arrays, binary file objects) is
        passed through unchanged; stringifying an array would turn it into its
        repr text.
        """
        if isinstance(audio, os.PathLike):
            return os.fspath(audio)
        return audio

    def transcribe(
        self,
        audio: Union[str, Path, Any],
        *,
        language: Optional[str] = None,
        task: str = "transcribe",
        beam_size: int = 5,
        best_of: int = 5,
        patience: float = 1.0,
        length_penalty: float = 1.0,
        temperature: Union[float, List[float], Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
        initial_prompt: Optional[str] = None,
        word_timestamps: bool = False,
        vad_filter: bool = False,
        vad_parameters: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Tuple[Iterator[Any], Any]:
        """
        Transcribe audio file.

        Args:
            audio: Path to audio file (``str`` or path-like), numpy array, or
                   binary file object
            language: Language code (e.g., "en") or None for auto-detection
            task: "transcribe" or "translate"
            beam_size: Beam size for decoding
            best_of: Number of candidates when sampling
            patience: Patience for beam search
            length_penalty: Length penalty
            temperature: Temperature(s) for sampling (an immutable tuple default
                         avoids the shared-mutable-default pitfall)
            initial_prompt: Optional prompt to condition the model
            word_timestamps: Whether to output word-level timestamps
            vad_filter: Whether to use VAD to filter out silence
            vad_parameters: Parameters for VAD filter
            **kwargs: Additional arguments forwarded to ``WhisperModel.transcribe``

        Returns:
            Tuple of (segments iterator, transcription info)

        Raises:
            RuntimeError: No model has been loaded.
        """
        model = self._require_model()

        return model.transcribe(
            self._prepare_audio(audio),
            language=language,
            task=task,
            beam_size=beam_size,
            best_of=best_of,
            patience=patience,
            length_penalty=length_penalty,
            temperature=temperature,
            initial_prompt=initial_prompt,
            word_timestamps=word_timestamps,
            vad_filter=vad_filter,
            vad_parameters=vad_parameters,
            **kwargs,
        )

    def transcribe_stream(
        self,
        audio: Union[str, Path],
        **kwargs: Any,
    ) -> Iterator[Any]:
        """
        Stream transcription segments.

        Convenience method that yields segments one at a time.

        Args:
            audio: Path to audio file
            **kwargs: Arguments passed to transcribe()

        Yields:
            Transcription segments
        """
        segments, _ = self.transcribe(audio, **kwargs)
        yield from segments

    def detect_language(
        self,
        audio: Union[str, Path, Any],
    ) -> Tuple[str, float]:
        """
        Detect the language of audio.

        Uses ``WhisperModel.transcribe`` with ``language=None`` and reads the
        detected language from the returned info object. The returned segment
        iterator is never consumed here, so no transcript is generated. This
        accepts the same inputs as ``transcribe`` (paths, arrays, file objects).

        Args:
            audio: Path to audio file, numpy array, or binary file object

        Returns:
            Tuple of (language code, probability)

        Raises:
            RuntimeError: No model has been loaded.
        """
        model = self._require_model()
        _, info = model.transcribe(self._prepare_audio(audio), language=None)
        return info.language, info.language_probability

    def __call__(
        self,
        audio: Union[str, Path],
        **kwargs: Any,
    ) -> List[Dict[str, Any]]:
        """
        Shorthand for transcribe that returns a list of segment dicts.

        Args:
            audio: Path to audio file
            **kwargs: Arguments passed to transcribe()

        Returns:
            List of segment dictionaries with "start", "end" and "text" keys
        """
        segments, _ = self.transcribe(audio, **kwargs)
        return [
            {
                "start": seg.start,
                "end": seg.end,
                "text": seg.text,
            }
            for seg in segments
        ]
|