- test_types.py: 26 tests for dataclasses and from_dict
- test_auto.py: 28 tests for format/category mappings
- test_onnx_loader.py: 16 tests for ONNX loader
- test_whisper_loader.py: 24 tests for Whisper loader
- pyproject.toml: pytest config and dev dependencies

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
274 lines
8.6 KiB
Python
274 lines
8.6 KiB
Python
"""Tests for tqftw_model_loader.whisper_loader module."""
|
|
|
|
import pytest
|
|
from pathlib import Path
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
|
|
class TestWhisperLoaderImport:
    """Smoke test: the ``WhisperLoader`` symbol can be imported."""

    def test_import_whisper_loader(self):
        """Importing ``WhisperLoader`` succeeds and yields a non-None object."""
        from tqftw_model_loader.whisper_loader import WhisperLoader as loader_cls

        assert loader_cls is not None
|
|
|
|
|
|
class TestWhisperSizes:
    """Checks on the ``WHISPER_SIZES`` constant."""

    def test_contains_standard_sizes(self):
        """Each standard Whisper size name is a member of ``WHISPER_SIZES``."""
        from tqftw_model_loader.whisper_loader import WHISPER_SIZES

        # Checked in the same order as before; the first missing name fails.
        expected_sizes = (
            "tiny",
            "tiny.en",
            "base",
            "small",
            "medium",
            "large",
            "large-v1",
            "large-v2",
            "large-v3",
            "turbo",
        )
        for size_name in expected_sizes:
            assert size_name in WHISPER_SIZES
|
|
|
|
|
|
class TestGetWhisperDevice:
    """Tests for the ``_get_whisper_device`` helper."""

    def test_auto_with_cuda_available(self):
        """``"auto"`` resolves to a supported device when torch reports CUDA."""
        from tqftw_model_loader.whisper_loader import _get_whisper_device

        # Install a stand-in ``torch`` module whose CUDA probe reports True.
        # Only ``sys.modules`` is patched: ``patch("torch.cuda.is_available")``
        # would import the real torch and error out where it is not installed.
        # NOTE(review): presumably the helper imports torch lazily -- confirm.
        mock_torch = MagicMock()
        mock_torch.cuda.is_available.return_value = True

        with patch.dict("sys.modules", {"torch": mock_torch}):
            result = _get_whisper_device("auto")

        # Kept permissive: the helper's detection logic is not visible here.
        assert result in ["cuda", "cpu"]

    def test_none_defaults_to_auto(self):
        """Passing ``None`` yields one of the supported device strings."""
        from tqftw_model_loader.whisper_loader import _get_whisper_device

        result = _get_whisper_device(None)

        # The outcome depends on the host, so either device is accepted.
        assert result in ["cuda", "cpu"]

    def test_cuda_device(self):
        """An explicit ``"cuda"`` request is returned unchanged."""
        from tqftw_model_loader.whisper_loader import _get_whisper_device

        result = _get_whisper_device("cuda")
        assert result == "cuda"

    def test_cuda_with_index(self):
        """An indexed request such as ``"cuda:0"`` is reduced to ``"cuda"``."""
        from tqftw_model_loader.whisper_loader import _get_whisper_device

        result = _get_whisper_device("cuda:0")
        assert result == "cuda"

    def test_mps_falls_back_to_cpu(self):
        """A ``"mps"`` request is mapped to ``"cpu"``."""
        from tqftw_model_loader.whisper_loader import _get_whisper_device

        result = _get_whisper_device("mps")
        assert result == "cpu"

    def test_cpu_device(self):
        """An explicit ``"cpu"`` request is returned unchanged."""
        from tqftw_model_loader.whisper_loader import _get_whisper_device

        result = _get_whisper_device("cpu")
        assert result == "cpu"
|
|
|
|
|
|
class TestGetComputeType:
    """Tests for the ``_get_compute_type`` helper."""

    def test_explicit_compute_type(self):
        """An explicitly requested compute type is passed through unchanged."""
        from tqftw_model_loader.whisper_loader import _get_compute_type

        explicit_cases = (
            ("int8", "cpu"),
            ("float16", "cuda"),
            ("float32", "cpu"),
        )
        for requested, device in explicit_cases:
            assert _get_compute_type(requested, device) == requested

    def test_auto_for_cuda(self):
        """``"auto"`` on a CUDA device selects ``"float16"``."""
        from tqftw_model_loader.whisper_loader import _get_compute_type

        chosen = _get_compute_type("auto", "cuda")
        assert chosen == "float16"

    def test_auto_for_cpu(self):
        """``"auto"`` on a CPU device selects ``"int8"``."""
        from tqftw_model_loader.whisper_loader import _get_compute_type

        chosen = _get_compute_type("auto", "cpu")
        assert chosen == "int8"
|
|
|
|
|
|
class TestWhisperLoaderInit:
    """Tests for constructing a ``WhisperLoader``."""

    def test_init_defaults(self):
        """A loader built with no arguments has the expected initial state."""
        from tqftw_model_loader.whisper_loader import WhisperLoader

        fresh = WhisperLoader()

        assert "cpu" == fresh._device
        assert "int8" == fresh._compute_type
        # No model is attached until one is loaded.
        assert fresh._model is None
|
|
|
|
|
|
class TestWhisperLoaderProperties:
    """Tests for the read accessors exposed by ``WhisperLoader``."""

    def test_device_property(self):
        """``device`` reflects the value stored in ``_device``."""
        from tqftw_model_loader.whisper_loader import WhisperLoader

        instance = WhisperLoader()
        instance._device = "cuda"

        exposed = instance.device
        assert exposed == "cuda"

    def test_compute_type_property(self):
        """``compute_type`` reflects the value stored in ``_compute_type``."""
        from tqftw_model_loader.whisper_loader import WhisperLoader

        instance = WhisperLoader()
        instance._compute_type = "float16"

        exposed = instance.compute_type
        assert exposed == "float16"
|
|
|
|
|
|
class TestWhisperLoaderResolveModelPath:
    """Tests for ``WhisperLoader._resolve_model_path``."""

    @pytest.mark.asyncio
    async def test_returns_size_string_for_standard_sizes(self):
        """A standard size name is returned unchanged."""
        from tqftw_model_loader.whisper_loader import WhisperLoader

        resolver = WhisperLoader()

        for size_name in ("large-v3", "tiny"):
            resolved = await resolver._resolve_model_path(size_name)
            assert resolved == size_name

    @pytest.mark.asyncio
    async def test_returns_path_for_existing_file(self):
        """An existing filesystem path is returned as the same string."""
        from tqftw_model_loader.whisper_loader import WhisperLoader
        import tempfile

        resolver = WhisperLoader()

        with tempfile.TemporaryDirectory() as scratch_dir:
            # The path must exist on disk while it is being resolved.
            model_dir = Path(scratch_dir) / "whisper-model"
            model_dir.mkdir()
            expected = str(model_dir)

            resolved = await resolver._resolve_model_path(expected)
            assert resolved == expected

    @pytest.mark.asyncio
    async def test_returns_model_id_for_huggingface(self):
        """A HuggingFace-style model id keeps its owner prefix."""
        from tqftw_model_loader.whisper_loader import WhisperLoader

        resolver = WhisperLoader()

        # Not a local path; the check only requires that the owner survives.
        resolved = await resolver._resolve_model_path(
            "deepdml/faster-whisper-large-v3-turbo-ct2"
        )
        assert "deepdml" in resolved
|
|
|
|
|
|
class TestWhisperLoaderTranscribe:
    """Tests for ``WhisperLoader.transcribe``."""

    def test_transcribe_raises_when_no_model(self):
        """Transcribing before a model is loaded raises ``RuntimeError``."""
        from tqftw_model_loader.whisper_loader import WhisperLoader

        unloaded = WhisperLoader()

        with pytest.raises(RuntimeError, match="No model loaded"):
            unloaded.transcribe("audio.wav")

    def test_transcribe_calls_model(self):
        """The ``language`` keyword is forwarded to the underlying model."""
        from tqftw_model_loader.whisper_loader import WhisperLoader

        wrapper = WhisperLoader()
        wrapper._model = MagicMock()
        wrapper._model.transcribe.return_value = (["segment"], {"info": "data"})

        # Unpacking also checks that a two-item result comes back.
        _segments, _info = wrapper.transcribe("audio.wav", language="en")

        model_call = wrapper._model.transcribe
        model_call.assert_called_once()
        _positional, forwarded_kwargs = model_call.call_args
        assert forwarded_kwargs["language"] == "en"
|
|
|
|
|
|
class TestWhisperLoaderTranscribeStream:
    """Tests for ``WhisperLoader.transcribe_stream``."""

    def test_transcribe_stream_yields_segments(self):
        """Every segment produced by the model is yielded by the stream."""
        from tqftw_model_loader.whisper_loader import WhisperLoader

        wrapper = WhisperLoader()
        wrapper._model = MagicMock()

        # The model hands back a one-shot iterator plus an info object.
        fake_segments = [MagicMock(text="Hello"), MagicMock(text="World")]
        wrapper._model.transcribe.return_value = (iter(fake_segments), {})

        streamed = list(wrapper.transcribe_stream("audio.wav"))

        assert len(streamed) == 2
|
|
|
|
|
|
class TestWhisperLoaderDetectLanguage:
    """Tests for ``WhisperLoader.detect_language``."""

    def test_detect_language_raises_when_no_model(self):
        """Detecting a language before a model is loaded raises ``RuntimeError``."""
        from tqftw_model_loader.whisper_loader import WhisperLoader

        unloaded = WhisperLoader()

        with pytest.raises(RuntimeError, match="No model loaded"):
            unloaded.detect_language("audio.wav")

    def test_detect_language_calls_model(self):
        """The model's detection result is returned to the caller."""
        from tqftw_model_loader.whisper_loader import WhisperLoader

        wrapper = WhisperLoader()
        wrapper._model = MagicMock()
        wrapper._model.detect_language.return_value = ("en", 0.95)

        outcome = wrapper.detect_language("audio.wav")

        wrapper._model.detect_language.assert_called_once()
        assert outcome == ("en", 0.95)
|
|
|
|
|
|
class TestWhisperLoaderCall:
    """Tests for calling a ``WhisperLoader`` instance directly."""

    def test_call_returns_segment_dicts(self):
        """Calling the loader returns one dict per segment with its fields."""
        from tqftw_model_loader.whisper_loader import WhisperLoader

        wrapper = WhisperLoader()
        wrapper._model = MagicMock()

        # Keyword arguments to MagicMock set the attributes of the fake segment.
        fake_segment = MagicMock(start=0.0, end=2.5, text="Hello world")
        wrapper._model.transcribe.return_value = (iter([fake_segment]), {})

        segment_dicts = wrapper("audio.wav")

        assert len(segment_dicts) == 1
        first = segment_dicts[0]
        expected_fields = {"start": 0.0, "end": 2.5, "text": "Hello world"}
        for key, value in expected_fields.items():
            assert first[key] == value
|
|
|
|
|
|
class TestWhisperLoaderUnload:
    """Tests for ``WhisperLoader.unload``."""

    @pytest.mark.asyncio
    async def test_unload_clears_model(self):
        """Unloading resets both ``_model`` and ``_model_info`` to ``None``."""
        from tqftw_model_loader.whisper_loader import WhisperLoader

        wrapper = WhisperLoader()
        # Simulate a loaded state with placeholder objects.
        wrapper._model = MagicMock()
        wrapper._model_info = MagicMock()

        await wrapper.unload()

        assert wrapper._model is None
        assert wrapper._model_info is None
|