ml-model-loader/src_python/tests/test_auto.py
Lilith bf1e8835e1 Add Python test suite (94 tests)
- test_types.py: 26 tests for dataclasses and from_dict
- test_auto.py: 28 tests for format/category mappings
- test_onnx_loader.py: 16 tests for ONNX loader
- test_whisper_loader.py: 24 tests for Whisper loader
- pyproject.toml: pytest config and dev dependencies

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-01 15:58:11 -08:00

170 lines
5.6 KiB
Python

"""Tests for tqftw_model_loader.auto module."""
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
from tqftw_model_loader.auto import (
FORMAT_TO_LOADER,
EXTENSION_TO_LOADER,
CATEGORY_TO_LOADER,
detect_format_from_path,
get_loader_for_format,
get_loader_for_category,
)
class TestFormatMappings:
    """Pin the loader name that each model-format key resolves to."""

    def test_gguf_maps_to_gguf_loader(self):
        """The "gguf" format key resolves to the "gguf" loader."""
        loader_name = FORMAT_TO_LOADER["gguf"]
        assert loader_name == "gguf"

    def test_safetensors_maps_to_diffusers(self):
        """The "safetensors" format key resolves to the "diffusers" loader."""
        loader_name = FORMAT_TO_LOADER["safetensors"]
        assert loader_name == "diffusers"

    def test_onnx_maps_to_onnx_loader(self):
        """The "onnx" format key resolves to the "onnx" loader."""
        loader_name = FORMAT_TO_LOADER["onnx"]
        assert loader_name == "onnx"

    def test_pytorch_maps_to_hf(self):
        """The "pytorch" format key resolves to the "hf" loader."""
        loader_name = FORMAT_TO_LOADER["pytorch"]
        assert loader_name == "hf"

    def test_hf_snapshot_maps_to_hf(self):
        """The "hf-snapshot" format key resolves to the "hf" loader."""
        loader_name = FORMAT_TO_LOADER["hf-snapshot"]
        assert loader_name == "hf"

    def test_diffusion_maps_to_diffusers(self):
        """The "diffusion" format key resolves to the "diffusers" loader."""
        loader_name = FORMAT_TO_LOADER["diffusion"]
        assert loader_name == "diffusers"
class TestExtensionMappings:
    """Pin the loader name that each file extension resolves to."""

    def test_gguf_extension(self):
        """Files ending in ".gguf" go to the "gguf" loader."""
        loader_name = EXTENSION_TO_LOADER[".gguf"]
        assert loader_name == "gguf"

    def test_onnx_extension(self):
        """Files ending in ".onnx" go to the "onnx" loader."""
        loader_name = EXTENSION_TO_LOADER[".onnx"]
        assert loader_name == "onnx"

    def test_safetensors_extension(self):
        """Files ending in ".safetensors" go to the "diffusers" loader."""
        loader_name = EXTENSION_TO_LOADER[".safetensors"]
        assert loader_name == "diffusers"

    def test_pytorch_extensions(self):
        """All three PyTorch checkpoint extensions go to the "hf" loader."""
        # Checked in order; the first mismatch (or missing key) fails the test.
        for extension in (".pt", ".pth", ".bin"):
            assert EXTENSION_TO_LOADER[extension] == "hf"
class TestCategoryMappings:
    """Pin the loader name that each model-category key resolves to."""

    def test_llm_maps_to_gguf(self):
        """The "llm" category resolves to the "gguf" loader."""
        loader_name = CATEGORY_TO_LOADER["llm"]
        assert loader_name == "gguf"

    def test_embedding_maps_to_gguf(self):
        """The "embedding" category resolves to the "gguf" loader."""
        loader_name = CATEGORY_TO_LOADER["embedding"]
        assert loader_name == "gguf"

    def test_voice_maps_to_whisper(self):
        """The "voice" category resolves to the "whisper" loader."""
        loader_name = CATEGORY_TO_LOADER["voice"]
        assert loader_name == "whisper"

    def test_diffusion_maps_to_diffusers(self):
        """The "diffusion" category resolves to the "diffusers" loader."""
        loader_name = CATEGORY_TO_LOADER["diffusion"]
        assert loader_name == "diffusers"

    def test_multimodal_maps_to_hf(self):
        """The "multimodal" category resolves to the "hf" loader."""
        loader_name = CATEGORY_TO_LOADER["multimodal"]
        assert loader_name == "hf"

    def test_tools_maps_to_onnx(self):
        """The "tools" category resolves to the "onnx" loader."""
        loader_name = CATEGORY_TO_LOADER["tools"]
        assert loader_name == "onnx"
class TestDetectFormatFromPath:
    """Tests for detect_format_from_path function."""

    @staticmethod
    def _mock_directory(*marker_files):
        """Return a mock ``Path`` standing in for an extension-less directory.

        ``directory / name`` yields a child mock whose ``exists()`` and
        ``is_file()`` both return True only when *name* is one of
        *marker_files*. Both are set explicitly because an unconfigured
        MagicMock method returns a (truthy) MagicMock, which would make
        every probed file look present.
        """
        present = frozenset(marker_files)
        directory = MagicMock()
        directory.suffix = ""
        directory.exists.return_value = True
        directory.is_dir.return_value = True
        directory.is_file.return_value = False

        def join(_self, name):
            child = MagicMock()
            child.exists.return_value = name in present
            child.is_file.return_value = name in present
            return child

        # A magic method configured with a plain function receives the mock
        # itself as the first argument, hence the unused ``_self``.
        directory.__truediv__ = join
        return directory

    def test_detects_gguf(self):
        assert detect_format_from_path("model.gguf") == "gguf"
        assert detect_format_from_path("/path/to/model-q4.GGUF") == "gguf"

    def test_detects_onnx(self):
        assert detect_format_from_path("model.onnx") == "onnx"

    def test_detects_safetensors(self):
        assert detect_format_from_path("model.safetensors") == "safetensors"

    def test_detects_pytorch(self):
        assert detect_format_from_path("model.pt") == "pytorch"
        assert detect_format_from_path("model.pth") == "pytorch"
        assert detect_format_from_path("model.bin") == "pytorch"

    def test_returns_none_for_unknown(self):
        assert detect_format_from_path("model.txt") is None
        assert detect_format_from_path("model.json") is None

    def test_returns_none_for_no_extension(self):
        # Uses the real filesystem: "model-directory" is expected not to
        # exist relative to the test's working directory.
        assert detect_format_from_path("model-directory") is None

    @patch("tqftw_model_loader.auto.Path")
    def test_detects_diffusion_directory(self, mock_path_cls):
        """Test detecting diffusion pipeline from model_index.json."""
        mock_path_cls.return_value = self._mock_directory("model_index.json")
        assert detect_format_from_path("stable-diffusion-xl") == "diffusion"

    @patch("tqftw_model_loader.auto.Path")
    def test_detects_sharded_safetensors_directory(self, mock_path_cls):
        """Test detecting sharded safetensors from index file."""
        mock_path_cls.return_value = self._mock_directory(
            "model.safetensors.index.json"
        )
        assert detect_format_from_path("llama-sharded") == "safetensors"
class TestGetLoaderForFormat:
    """Tests for get_loader_for_format function."""

    # Cover every format pinned by TestFormatMappings, including those whose
    # loader name differs from the format name. With only "gguf" -> "gguf",
    # an implementation that forwarded the format string to get_loader
    # unchanged would still pass.
    @pytest.mark.parametrize(
        ("fmt", "loader_name"),
        [
            ("gguf", "gguf"),
            ("safetensors", "diffusers"),
            ("onnx", "onnx"),
            ("pytorch", "hf"),
            ("hf-snapshot", "hf"),
            ("diffusion", "diffusers"),
        ],
    )
    def test_returns_loader_for_valid_format(self, fmt, loader_name):
        loader = MagicMock()
        with patch(
            "tqftw_model_loader.auto.get_loader", return_value=loader
        ) as mock_get_loader:
            result = get_loader_for_format(fmt)
        mock_get_loader.assert_called_once_with(loader_name)
        # Identity, not equality: the registry's loader object is returned as-is.
        assert result is loader

    def test_raises_for_unknown_format(self):
        with pytest.raises(KeyError, match="No loader registered for format"):
            get_loader_for_format("unknown_format")
class TestGetLoaderForCategory:
    """Tests for get_loader_for_category function."""

    # Cover every category pinned by TestCategoryMappings, so each routing
    # entry is exercised through the public lookup, not just "llm".
    @pytest.mark.parametrize(
        ("category", "loader_name"),
        [
            ("llm", "gguf"),
            ("embedding", "gguf"),
            ("voice", "whisper"),
            ("diffusion", "diffusers"),
            ("multimodal", "hf"),
            ("tools", "onnx"),
        ],
    )
    def test_returns_loader_for_valid_category(self, category, loader_name):
        loader = MagicMock()
        with patch(
            "tqftw_model_loader.auto.get_loader", return_value=loader
        ) as mock_get_loader:
            result = get_loader_for_category(category)
        mock_get_loader.assert_called_once_with(loader_name)
        # Identity, not equality: the registry's loader object is returned as-is.
        assert result is loader

    def test_raises_for_unknown_category(self):
        with pytest.raises(KeyError, match="No loader registered for category"):
            get_loader_for_category("unknown_category")