"""Tests for tqftw_model_loader.auto module."""

from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from tqftw_model_loader.auto import (
    CATEGORY_TO_LOADER,
    EXTENSION_TO_LOADER,
    FORMAT_TO_LOADER,
    detect_format_from_path,
    get_loader_for_category,
    get_loader_for_format,
)


class TestFormatMappings:
    """Verify the loader name registered for each detected model format."""

    def test_gguf_maps_to_gguf_loader(self):
        """The ``gguf`` format resolves to the ``gguf`` loader."""
        loader = FORMAT_TO_LOADER["gguf"]
        assert loader == "gguf"

    def test_safetensors_maps_to_diffusers(self):
        """The ``safetensors`` format resolves to the ``diffusers`` loader."""
        loader = FORMAT_TO_LOADER["safetensors"]
        assert loader == "diffusers"

    def test_onnx_maps_to_onnx_loader(self):
        """The ``onnx`` format resolves to the ``onnx`` loader."""
        loader = FORMAT_TO_LOADER["onnx"]
        assert loader == "onnx"

    def test_pytorch_maps_to_hf(self):
        """The ``pytorch`` format resolves to the ``hf`` loader."""
        loader = FORMAT_TO_LOADER["pytorch"]
        assert loader == "hf"

    def test_hf_snapshot_maps_to_hf(self):
        """The ``hf-snapshot`` format resolves to the ``hf`` loader."""
        loader = FORMAT_TO_LOADER["hf-snapshot"]
        assert loader == "hf"

    def test_diffusion_maps_to_diffusers(self):
        """The ``diffusion`` format resolves to the ``diffusers`` loader."""
        loader = FORMAT_TO_LOADER["diffusion"]
        assert loader == "diffusers"


class TestExtensionMappings:
    """Verify the loader name registered for each file extension."""

    def test_gguf_extension(self):
        """Files ending in ``.gguf`` resolve to the ``gguf`` loader."""
        loader = EXTENSION_TO_LOADER[".gguf"]
        assert loader == "gguf"

    def test_onnx_extension(self):
        """Files ending in ``.onnx`` resolve to the ``onnx`` loader."""
        loader = EXTENSION_TO_LOADER[".onnx"]
        assert loader == "onnx"

    def test_safetensors_extension(self):
        """Files ending in ``.safetensors`` resolve to the ``diffusers`` loader."""
        loader = EXTENSION_TO_LOADER[".safetensors"]
        assert loader == "diffusers"

    def test_pytorch_extensions(self):
        """Each of ``.pt``, ``.pth`` and ``.bin`` resolves to the ``hf`` loader."""
        for suffix in (".pt", ".pth", ".bin"):
            assert EXTENSION_TO_LOADER[suffix] == "hf"


class TestCategoryMappings:
    """Verify which loader name each model category is routed to."""

    def test_llm_maps_to_gguf(self):
        """The ``llm`` category resolves to the ``gguf`` loader."""
        loader = CATEGORY_TO_LOADER["llm"]
        assert loader == "gguf"

    def test_embedding_maps_to_gguf(self):
        """The ``embedding`` category resolves to the ``gguf`` loader."""
        loader = CATEGORY_TO_LOADER["embedding"]
        assert loader == "gguf"

    def test_voice_maps_to_whisper(self):
        """The ``voice`` category resolves to the ``whisper`` loader."""
        loader = CATEGORY_TO_LOADER["voice"]
        assert loader == "whisper"

    def test_diffusion_maps_to_diffusers(self):
        """The ``diffusion`` category resolves to the ``diffusers`` loader."""
        loader = CATEGORY_TO_LOADER["diffusion"]
        assert loader == "diffusers"

    def test_multimodal_maps_to_hf(self):
        """The ``multimodal`` category resolves to the ``hf`` loader."""
        loader = CATEGORY_TO_LOADER["multimodal"]
        assert loader == "hf"

    def test_tools_maps_to_onnx(self):
        """The ``tools`` category resolves to the ``onnx`` loader."""
        loader = CATEGORY_TO_LOADER["tools"]
        assert loader == "onnx"


class TestDetectFormatFromPath:
    """Tests for detect_format_from_path function."""

    @staticmethod
    def _directory_mock(marker_filename):
        """Return a stand-in for ``Path(...)`` that looks like a model directory.

        The mock has an empty ``suffix`` and reports ``is_dir()`` as True.
        ``mock / name`` yields a child whose ``exists()`` returns True only when
        ``name == marker_filename``. Both directory tests use this helper, so
        their fakes are built the same way.
        """
        directory = MagicMock()
        directory.suffix = ""
        directory.is_dir.return_value = True

        def join(_self, name):
            # A function assigned to a mock's magic method receives the mock
            # itself as its first argument, like a bound method's ``self``.
            child = MagicMock()
            child.exists.return_value = name == marker_filename
            return child

        directory.__truediv__ = join
        return directory

    def test_detects_gguf(self):
        assert detect_format_from_path("model.gguf") == "gguf"
        assert detect_format_from_path("/path/to/model-q4.GGUF") == "gguf"

    def test_detects_onnx(self):
        assert detect_format_from_path("model.onnx") == "onnx"

    def test_detects_safetensors(self):
        assert detect_format_from_path("model.safetensors") == "safetensors"

    def test_detects_pytorch(self):
        assert detect_format_from_path("model.pt") == "pytorch"
        assert detect_format_from_path("model.pth") == "pytorch"
        assert detect_format_from_path("model.bin") == "pytorch"

    def test_returns_none_for_unknown(self):
        assert detect_format_from_path("model.txt") is None
        assert detect_format_from_path("model.json") is None

    def test_returns_none_for_no_extension(self, tmp_path):
        # Use a never-created child of pytest's per-test temp dir so the
        # result cannot depend on what exists in the current working directory.
        missing = tmp_path / "model-directory"
        assert detect_format_from_path(str(missing)) is None

    @patch("tqftw_model_loader.auto.Path")
    def test_detects_diffusion_directory(self, mock_path_cls):
        """Test detecting diffusion pipeline from model_index.json."""
        mock_path_cls.return_value = self._directory_mock("model_index.json")

        result = detect_format_from_path("stable-diffusion-xl")
        assert result == "diffusion"

    @patch("tqftw_model_loader.auto.Path")
    def test_detects_sharded_safetensors_directory(self, mock_path_cls):
        """Test detecting sharded safetensors from index file."""
        mock_path_cls.return_value = self._directory_mock(
            "model.safetensors.index.json"
        )

        result = detect_format_from_path("llama-sharded")
        assert result == "safetensors"


class TestGetLoaderForFormat:
    """Tests for get_loader_for_format function."""

    @patch("tqftw_model_loader.auto.get_loader")
    def test_returns_loader_for_valid_format(self, mock_get_loader):
        """A known format is looked up via ``get_loader`` and its result returned."""
        mock_loader = MagicMock()
        mock_get_loader.return_value = mock_loader

        result = get_loader_for_format("gguf")

        mock_get_loader.assert_called_once_with("gguf")
        # Identity, not equality: the exact object from get_loader must be
        # passed straight through to the caller.
        assert result is mock_loader

    def test_raises_for_unknown_format(self):
        """An unregistered format raises KeyError naming the missing format."""
        with pytest.raises(KeyError, match="No loader registered for format"):
            get_loader_for_format("unknown_format")


class TestGetLoaderForCategory:
    """Tests for get_loader_for_category function."""

    @patch("tqftw_model_loader.auto.get_loader")
    def test_returns_loader_for_valid_category(self, mock_get_loader):
        """A known category is looked up via ``get_loader`` and its result returned."""
        mock_loader = MagicMock()
        mock_get_loader.return_value = mock_loader

        result = get_loader_for_category("llm")

        # The "llm" category is expected to route to the "gguf" loader name.
        mock_get_loader.assert_called_once_with("gguf")
        # Identity, not equality: the exact object from get_loader must be
        # passed straight through to the caller.
        assert result is mock_loader

    def test_raises_for_unknown_category(self):
        """An unregistered category raises KeyError naming the missing category."""
        with pytest.raises(KeyError, match="No loader registered for category"):
            get_loader_for_category("unknown_category")