"""Tests for tqftw_model_loader.auto module."""

import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock

from tqftw_model_loader.auto import (
    FORMAT_TO_LOADER,
    EXTENSION_TO_LOADER,
    CATEGORY_TO_LOADER,
    detect_format_from_path,
    get_loader_for_format,
    get_loader_for_category,
)


def _make_directory_mock(marker_file):
    """Build a ``Path`` stand-in for a directory that contains *marker_file*.

    The returned mock has an empty ``suffix`` and ``is_dir()`` returns True.
    Joining it with ``/`` yields a child mock whose ``exists()`` is True only
    when the joined name equals *marker_file*; every other child reports
    ``exists() == False``.
    """
    mock_path = MagicMock()
    mock_path.suffix = ""
    mock_path.is_dir.return_value = True

    def _child(name):
        # MagicMock's ``__truediv__`` side_effect receives only the right-hand
        # operand, so no ``self`` parameter is needed here.
        child = MagicMock()
        child.exists.return_value = name == marker_file
        return child

    mock_path.__truediv__.side_effect = _child
    return mock_path


class TestFormatMappings:
    """Tests for format-to-loader mappings."""

    def test_gguf_maps_to_gguf_loader(self):
        assert FORMAT_TO_LOADER["gguf"] == "gguf"

    def test_safetensors_maps_to_diffusers(self):
        assert FORMAT_TO_LOADER["safetensors"] == "diffusers"

    def test_onnx_maps_to_onnx_loader(self):
        assert FORMAT_TO_LOADER["onnx"] == "onnx"

    def test_pytorch_maps_to_hf(self):
        assert FORMAT_TO_LOADER["pytorch"] == "hf"

    def test_hf_snapshot_maps_to_hf(self):
        assert FORMAT_TO_LOADER["hf-snapshot"] == "hf"

    def test_diffusion_maps_to_diffusers(self):
        assert FORMAT_TO_LOADER["diffusion"] == "diffusers"


class TestExtensionMappings:
    """Tests for extension-to-loader mappings."""

    def test_gguf_extension(self):
        assert EXTENSION_TO_LOADER[".gguf"] == "gguf"

    def test_onnx_extension(self):
        assert EXTENSION_TO_LOADER[".onnx"] == "onnx"

    def test_safetensors_extension(self):
        assert EXTENSION_TO_LOADER[".safetensors"] == "diffusers"

    def test_pytorch_extensions(self):
        assert EXTENSION_TO_LOADER[".pt"] == "hf"
        assert EXTENSION_TO_LOADER[".pth"] == "hf"
        assert EXTENSION_TO_LOADER[".bin"] == "hf"


class TestCategoryMappings:
    """Tests for category-to-loader mappings."""

    def test_llm_maps_to_gguf(self):
        assert CATEGORY_TO_LOADER["llm"] == "gguf"

    def test_embedding_maps_to_gguf(self):
        assert CATEGORY_TO_LOADER["embedding"] == "gguf"

    def test_voice_maps_to_whisper(self):
        assert CATEGORY_TO_LOADER["voice"] == "whisper"

    def test_diffusion_maps_to_diffusers(self):
        assert CATEGORY_TO_LOADER["diffusion"] == "diffusers"

    def test_multimodal_maps_to_hf(self):
        assert CATEGORY_TO_LOADER["multimodal"] == "hf"

    def test_tools_maps_to_onnx(self):
        assert CATEGORY_TO_LOADER["tools"] == "onnx"


class TestDetectFormatFromPath:
    """Tests for detect_format_from_path function."""

    def test_detects_gguf(self):
        assert detect_format_from_path("model.gguf") == "gguf"
        # Upper-case extension must be detected too.
        assert detect_format_from_path("/path/to/model-q4.GGUF") == "gguf"

    def test_detects_onnx(self):
        assert detect_format_from_path("model.onnx") == "onnx"

    def test_detects_safetensors(self):
        assert detect_format_from_path("model.safetensors") == "safetensors"

    def test_detects_pytorch(self):
        assert detect_format_from_path("model.pt") == "pytorch"
        assert detect_format_from_path("model.pth") == "pytorch"
        assert detect_format_from_path("model.bin") == "pytorch"

    def test_returns_none_for_unknown(self):
        assert detect_format_from_path("model.txt") is None
        assert detect_format_from_path("model.json") is None

    def test_returns_none_for_no_extension(self):
        assert detect_format_from_path("model-directory") is None

    @patch("tqftw_model_loader.auto.Path")
    def test_detects_diffusion_directory(self, mock_path_cls):
        """Test detecting diffusion pipeline from model_index.json."""
        mock_path_cls.return_value = _make_directory_mock("model_index.json")

        result = detect_format_from_path("stable-diffusion-xl")

        assert result == "diffusion"

    @patch("tqftw_model_loader.auto.Path")
    def test_detects_sharded_safetensors_directory(self, mock_path_cls):
        """Test detecting sharded safetensors from index file."""
        mock_path_cls.return_value = _make_directory_mock(
            "model.safetensors.index.json"
        )

        result = detect_format_from_path("llama-sharded")

        assert result == "safetensors"


class TestGetLoaderForFormat:
    """Tests for get_loader_for_format function."""

    @patch("tqftw_model_loader.auto.get_loader")
    def test_returns_loader_for_valid_format(self, mock_get_loader):
        mock_loader = MagicMock()
        mock_get_loader.return_value = mock_loader

        result = get_loader_for_format("gguf")

        mock_get_loader.assert_called_once_with("gguf")
        # Identity check: the exact object from get_loader must be returned.
        assert result is mock_loader

    def test_raises_for_unknown_format(self):
        with pytest.raises(KeyError, match="No loader registered for format"):
            get_loader_for_format("unknown_format")


class TestGetLoaderForCategory:
    """Tests for get_loader_for_category function."""

    @patch("tqftw_model_loader.auto.get_loader")
    def test_returns_loader_for_valid_category(self, mock_get_loader):
        mock_loader = MagicMock()
        mock_get_loader.return_value = mock_loader

        result = get_loader_for_category("llm")

        # "llm" is expected to resolve to the "gguf" loader name.
        mock_get_loader.assert_called_once_with("gguf")
        # Identity check: the exact object from get_loader must be returned.
        assert result is mock_loader

    def test_raises_for_unknown_category(self):
        with pytest.raises(KeyError, match="No loader registered for category"):
            get_loader_for_category("unknown_category")