"""Tests for ML exception classes.""" import pytest from lilith_ml_exceptions.errors import ( InferenceError, MLBaseError, ModelLoadError, ResourceError, ValidationError, ) from lilith_ml_exceptions.http_mappers import ( exception_to_http_status, exception_to_response, ) class TestMLBaseError: """Tests for MLBaseError base class.""" def test_basic_creation(self): """Test creating a basic exception.""" exc = MLBaseError("Something went wrong") assert str(exc) == "Something went wrong" assert exc.message == "Something went wrong" assert exc.details == {} assert exc.cause is None def test_with_details(self): """Test exception with details.""" exc = MLBaseError("Error", details={"key": "value"}) assert exc.details == {"key": "value"} def test_with_cause(self): """Test exception with underlying cause.""" cause = ValueError("Original error") exc = MLBaseError("Wrapped error", cause=cause) assert exc.cause is cause def test_to_dict(self): """Test converting exception to dictionary.""" exc = MLBaseError("Error message", details={"foo": "bar"}) result = exc.to_dict() assert result["error_code"] == "ML_ERROR" assert result["message"] == "Error message" assert result["details"] == {"foo": "bar"} class TestModelLoadError: """Tests for ModelLoadError.""" def test_error_code(self): """Test that error code is correct.""" exc = ModelLoadError("Model not found") assert exc.error_code == "MODEL_LOAD_ERROR" def test_with_model_path(self): """Test exception with model path.""" exc = ModelLoadError("Not found", model_path="/models/bert.pt") assert exc.model_path == "/models/bert.pt" assert exc.details["model_path"] == "/models/bert.pt" def test_with_model_name(self): """Test exception with model name.""" exc = ModelLoadError("Failed", model_name="bert-base") assert exc.model_name == "bert-base" assert exc.details["model_name"] == "bert-base" def test_http_status(self): """Test HTTP status code mapping.""" exc = ModelLoadError("Error") assert exception_to_http_status(exc) == 503 class TestInferenceError: """Tests for InferenceError.""" def test_error_code(self): """Test that error code is correct.""" exc = InferenceError("Inference failed") assert exc.error_code == "INFERENCE_ERROR" def test_with_input_shape(self): """Test exception with input shape.""" exc = InferenceError("OOM", input_shape=(1, 3, 224, 224)) assert exc.input_shape == (1, 3, 224, 224) assert exc.details["input_shape"] == (1, 3, 224, 224) def test_with_timeout(self): """Test exception with timeout.""" exc = InferenceError("Timeout", timeout_seconds=30.0) assert exc.timeout_seconds == 30.0 assert exc.details["timeout_seconds"] == 30.0 def test_http_status(self): """Test HTTP status code mapping.""" exc = InferenceError("Error") assert exception_to_http_status(exc) == 500 class TestValidationError: """Tests for ValidationError.""" def test_error_code(self): """Test that error code is correct.""" exc = ValidationError("Invalid input") assert exc.error_code == "VALIDATION_ERROR" def test_with_field(self): """Test exception with field name.""" exc = ValidationError("Invalid", field="image") assert exc.field == "image" assert exc.details["field"] == "image" def test_with_expected_type(self): """Test exception with expected type.""" exc = ValidationError("Wrong type", expected_type="numpy.ndarray") assert exc.expected_type == "numpy.ndarray" def test_with_constraints(self): """Test exception with constraints.""" constraints = {"min_size": 224, "max_size": 1024} exc = ValidationError("Out of range", constraints=constraints) assert exc.constraints == constraints def test_http_status(self): """Test HTTP status code mapping.""" exc = ValidationError("Error") assert exception_to_http_status(exc) == 400 class TestResourceError: """Tests for ResourceError.""" def test_error_code(self): """Test that error code is correct.""" exc = ResourceError("GPU not available") assert exc.error_code == "RESOURCE_ERROR" def test_with_resource_info(self): """Test exception with resource information.""" exc = ResourceError( "Insufficient VRAM", resource_type="VRAM", required_amount="8GB", available_amount="4GB", ) assert exc.resource_type == "VRAM" assert exc.required_amount == "8GB" assert exc.available_amount == "4GB" def test_http_status(self): """Test HTTP status code mapping.""" exc = ResourceError("Error") assert exception_to_http_status(exc) == 503 class TestExceptionToResponse: """Tests for exception_to_response function.""" def test_ml_exception_response(self): """Test response for ML exceptions.""" exc = ModelLoadError("Not found", model_name="bert") response = exception_to_response(exc) assert response["status"] == "error" assert response["error"]["error_code"] == "MODEL_LOAD_ERROR" assert response["error"]["message"] == "Not found" assert response["error"]["details"]["model_name"] == "bert" def test_generic_exception_response(self): """Test response for non-ML exceptions.""" exc = ValueError("Something went wrong") response = exception_to_response(exc) assert response["status"] == "error" assert response["error"]["error_code"] == "INTERNAL_ERROR" assert response["error"]["message"] == "Something went wrong" class TestUnknownExceptionStatus: """Tests for unknown exception handling.""" def test_unknown_exception_returns_500(self): """Test that unknown exceptions map to 500.""" exc = RuntimeError("Unknown error") assert exception_to_http_status(exc) == 500