lilith-ml-intent-classifier/tests/test_classifier.py
2026-01-21 12:48:54 -08:00

521 lines
18 KiB
Python

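"""
Tests for the ml_intent_classifier package.
Tests cover the Intent and ConfidenceScores data models, the core
IntentClassifier classification functionality, caching behavior, urgency
keyword detection, error handling, prompt building, ClassifierConfig
defaults and validation, and enum values.
"""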
import json
from typing import Any
import pytest
from ml_intent_classifier import (
ClassifierConfig,
ConfidenceScores,
EmotionalTone,
Intent,
IntentClassifier,
PrimaryIntent,
ResponseStyle,
TopicCategory,
build_classification_prompt,
get_system_prompt,
)
class MockModelClient:
"""Mock LLM client for testing."""
def __init__(self, response: str | None = None):
"""Initialize with optional fixed response."""
self._response = response
self._calls: list[list[dict[str, str]]] = []
async def generate(
self,
messages: list[dict[str, str]],
**kwargs: Any,
) -> str:
"""Return mock response and record the call."""
self._calls.append(messages)
if self._response is not None:
return self._response
# Default response for testing
return json.dumps(
{
"primary_intent": "question",
"urgency": 0.3,
"emotional_tone": "neutral",
"topic": "logistical",
"suggested_response_style": "casual",
"confidence": 0.85,
"confidence_scores": {
"primary_intent": 0.9,
"urgency": 0.8,
"emotional_tone": 0.85,
"topic": 0.75,
"response_style": 0.8,
},
}
)
@property
def call_count(self) -> int:
"""Return number of calls made."""
return len(self._calls)
@property
def last_call(self) -> list[dict[str, str]] | None:
"""Return the last call made."""
return self._calls[-1] if self._calls else None
class TestIntent:
    """Unit tests for the Intent dataclass: fields, validation, derived flags."""

    @staticmethod
    def _make(
        primary: PrimaryIntent,
        urgency: float,
        tone: EmotionalTone,
        topic: TopicCategory,
        style: ResponseStyle,
    ) -> Intent:
        """Build an Intent from positional arguments to keep the tests compact."""
        return Intent(
            primary=primary,
            urgency=urgency,
            emotional_tone=tone,
            topic=topic,
            suggested_response_style=style,
        )

    def test_intent_creation(self) -> None:
        """Constructor arguments are exposed unchanged as attributes."""
        built = self._make(
            PrimaryIntent.QUESTION,
            0.5,
            EmotionalTone.NEUTRAL,
            TopicCategory.LOGISTICAL,
            ResponseStyle.CASUAL,
        )
        assert built.primary == PrimaryIntent.QUESTION
        assert built.urgency == 0.5
        assert built.emotional_tone == EmotionalTone.NEUTRAL
        assert built.topic == TopicCategory.LOGISTICAL
        assert built.suggested_response_style == ResponseStyle.CASUAL

    def test_intent_urgency_validation(self) -> None:
        """Urgency outside the closed range 0.0..1.0 is rejected."""
        for out_of_range in (1.5, -0.1):
            with pytest.raises(ValueError, match="urgency must be between"):
                self._make(
                    PrimaryIntent.QUESTION,
                    out_of_range,
                    EmotionalTone.NEUTRAL,
                    TopicCategory.CASUAL,
                    ResponseStyle.CASUAL,
                )

    def test_intent_is_urgent_property(self) -> None:
        """High urgency sets is_urgent; low urgency clears it."""
        pressing = self._make(
            PrimaryIntent.REQUEST,
            0.8,
            EmotionalTone.NEUTRAL,
            TopicCategory.PROFESSIONAL,
            ResponseStyle.FORMAL,
        )
        assert pressing.is_urgent is True
        relaxed = self._make(
            PrimaryIntent.GREETING,
            0.3,
            EmotionalTone.POSITIVE,
            TopicCategory.CASUAL,
            ResponseStyle.CASUAL,
        )
        assert relaxed.is_urgent is False

    def test_intent_emotional_properties(self) -> None:
        """is_positive / is_negative mirror the emotional tone."""
        upbeat = self._make(
            PrimaryIntent.REACTION,
            0.2,
            EmotionalTone.POSITIVE,
            TopicCategory.PERSONAL,
            ResponseStyle.EMPATHETIC,
        )
        assert upbeat.is_positive is True
        assert upbeat.is_negative is False
        downbeat = self._make(
            PrimaryIntent.STATEMENT,
            0.6,
            EmotionalTone.NEGATIVE,
            TopicCategory.PROFESSIONAL,
            ResponseStyle.EMPATHETIC,
        )
        assert downbeat.is_positive is False
        assert downbeat.is_negative is True

    def test_intent_needs_action_property(self) -> None:
        """Requests and urgent questions need action; casual questions don't."""
        # A request needs action even at low urgency.
        request = self._make(
            PrimaryIntent.REQUEST,
            0.3,
            EmotionalTone.NEUTRAL,
            TopicCategory.LOGISTICAL,
            ResponseStyle.CASUAL,
        )
        assert request.needs_action is True
        # A question needs action once it is urgent.
        pressing_question = self._make(
            PrimaryIntent.QUESTION,
            0.8,
            EmotionalTone.NEUTRAL,
            TopicCategory.PROFESSIONAL,
            ResponseStyle.FORMAL,
        )
        assert pressing_question.needs_action is True
        # A low-urgency question does not.
        idle_question = self._make(
            PrimaryIntent.QUESTION,
            0.3,
            EmotionalTone.NEUTRAL,
            TopicCategory.CASUAL,
            ResponseStyle.CASUAL,
        )
        assert idle_question.needs_action is False
class TestConfidenceScores:
    """Checks for ConfidenceScores construction, range validation and defaults."""

    def test_confidence_scores_creation(self) -> None:
        """Explicitly supplied scores are stored as given."""
        supplied = {
            "primary_intent": 0.9,
            "urgency": 0.8,
            "emotional_tone": 0.85,
            "topic": 0.75,
            "response_style": 0.8,
            "overall": 0.85,
        }
        scores = ConfidenceScores(**supplied)
        assert scores.primary_intent == 0.9
        assert scores.overall == 0.85

    def test_confidence_scores_validation(self) -> None:
        """Each score must lie within 0.0..1.0; the error names the field."""
        bad_inputs = (("primary_intent", 1.5), ("overall", -0.1))
        for field_name, bad_value in bad_inputs:
            with pytest.raises(ValueError, match=f"{field_name} must be between"):
                ConfidenceScores(**{field_name: bad_value})

    def test_confidence_scores_defaults(self) -> None:
        """Omitted scores default to 0.0."""
        blank = ConfidenceScores()
        for field_name in ("primary_intent", "overall"):
            assert getattr(blank, field_name) == 0.0
class TestIntentClassifier:
    """Tests for the IntentClassifier class.

    Every test drives the classifier through ``MockModelClient``, so no real
    LLM is contacted. The classifier parses the mock's default JSON or a
    per-test override.
    """

    @staticmethod
    def _llm_response(**fields: Any) -> str:
        """Serialize ``fields`` as the JSON text a mocked LLM would return."""
        return json.dumps(fields)

    @pytest.fixture
    def mock_client(self) -> MockModelClient:
        """Create a mock model client that returns its default response."""
        return MockModelClient()

    @pytest.fixture
    def classifier(self, mock_client: MockModelClient) -> IntentClassifier:
        """Create a classifier with the mock client and the default config."""
        return IntentClassifier(model_client=mock_client)

    @pytest.mark.asyncio
    async def test_basic_classification(self, classifier: IntentClassifier) -> None:
        """Test basic message classification against the mock's default response."""
        intent = await classifier.classify("Hey, are you free Saturday?")
        assert intent.primary == PrimaryIntent.QUESTION
        assert intent.urgency == pytest.approx(0.3, abs=0.1)
        assert intent.emotional_tone == EmotionalTone.NEUTRAL
        assert intent.topic == TopicCategory.LOGISTICAL
        assert intent.suggested_response_style == ResponseStyle.CASUAL
        assert intent.raw_message == "Hey, are you free Saturday?"

    @pytest.mark.asyncio
    async def test_empty_message_returns_default(
        self,
        classifier: IntentClassifier,
    ) -> None:
        """Test that empty and whitespace-only messages return the default intent."""
        intent = await classifier.classify("")
        assert intent.primary == PrimaryIntent.STATEMENT
        assert intent.confidence.overall == 0.0
        intent = await classifier.classify(" ")
        assert intent.primary == PrimaryIntent.STATEMENT

    @pytest.mark.asyncio
    async def test_urgency_keyword_high_boost(self) -> None:
        """Test that high urgency keywords boost urgency score."""
        client = MockModelClient(
            self._llm_response(
                primary_intent="request",
                urgency=0.5,
                emotional_tone="neutral",
                topic="professional",
                suggested_response_style="formal",
                confidence=0.9,
            )
        )
        classifier = IntentClassifier(model_client=client)
        intent = await classifier.classify("URGENT: Need this ASAP!")
        # The LLM reports 0.5, and ClassifierConfig.urgency_keyword_boost
        # defaults to 0.3. Only a lower bound is asserted, so the test does
        # not pin how multiple keywords combine or how the result is clamped.
        assert intent.urgency >= 0.7

    @pytest.mark.asyncio
    async def test_urgency_keyword_low_reduction(self) -> None:
        """Test that low urgency keywords reduce urgency score."""
        client = MockModelClient(
            self._llm_response(
                primary_intent="request",
                urgency=0.5,
                emotional_tone="neutral",
                topic="casual",
                suggested_response_style="casual",
                confidence=0.9,
            )
        )
        classifier = IntentClassifier(model_client=client)
        intent = await classifier.classify("Whenever you get a chance, no rush")
        # Anything below the LLM-reported 0.5 shows a reduction was applied.
        assert intent.urgency < 0.5

    @pytest.mark.asyncio
    async def test_caching_enabled(self, mock_client: MockModelClient) -> None:
        """Test that caching prevents duplicate LLM calls."""
        config = ClassifierConfig(cache_enabled=True)
        classifier = IntentClassifier(model_client=mock_client, config=config)
        message = "Test caching message"
        # First call should hit LLM
        await classifier.classify(message)
        assert mock_client.call_count == 1
        # Second call should use cache
        await classifier.classify(message)
        assert mock_client.call_count == 1
        # Different message should hit LLM
        await classifier.classify("Different message")
        assert mock_client.call_count == 2

    @pytest.mark.asyncio
    async def test_caching_disabled(self, mock_client: MockModelClient) -> None:
        """Test that disabling cache calls LLM every time."""
        config = ClassifierConfig(cache_enabled=False)
        classifier = IntentClassifier(model_client=mock_client, config=config)
        message = "Test no caching"
        await classifier.classify(message)
        assert mock_client.call_count == 1
        await classifier.classify(message)
        assert mock_client.call_count == 2

    @pytest.mark.asyncio
    async def test_cache_clear(self, mock_client: MockModelClient) -> None:
        """Test that clear_cache empties the cache and forces a fresh LLM call."""
        # Enable caching explicitly instead of relying on the config default.
        config = ClassifierConfig(cache_enabled=True)
        classifier = IntentClassifier(model_client=mock_client, config=config)
        await classifier.classify("Test message")
        assert mock_client.call_count == 1
        assert classifier.cache_size == 1
        classifier.clear_cache()
        assert classifier.cache_size == 0
        # Should call LLM again after cache clear and repopulate the cache
        await classifier.classify("Test message")
        assert mock_client.call_count == 2
        assert classifier.cache_size == 1

    @pytest.mark.asyncio
    async def test_fallback_on_error(self) -> None:
        """Test that fallback returns the default intent on an unparseable LLM reply."""
        client = MockModelClient("not valid json at all")
        config = ClassifierConfig(fallback_on_error=True)
        classifier = IntentClassifier(model_client=client, config=config)
        intent = await classifier.classify("Test message")
        assert intent.primary == PrimaryIntent.STATEMENT
        assert intent.confidence.overall == 0.0

    @pytest.mark.asyncio
    async def test_no_fallback_raises_error(self) -> None:
        """Test that an error is raised when fallback is disabled."""
        client = MockModelClient("invalid response")
        config = ClassifierConfig(fallback_on_error=False)
        classifier = IntentClassifier(model_client=client, config=config)
        with pytest.raises(ValueError, match="Classification failed"):
            await classifier.classify("Test message")

    @pytest.mark.asyncio
    async def test_classify_batch(self, classifier: IntentClassifier) -> None:
        """Test that batch classification yields one Intent per message."""
        messages = [
            "Hello there!",
            "What time is the meeting?",
            "Thanks for the help!",
        ]
        intents = await classifier.classify_batch(messages)
        assert len(intents) == 3
        for intent in intents:
            assert isinstance(intent, Intent)

    @pytest.mark.asyncio
    async def test_handles_json_in_text(self) -> None:
        """Test that JSON is extracted from surrounding text."""
        response = (
            "Here is the classification:\n"
            '{"primary_intent": "statement", "urgency": 0.2, "emotional_tone": "positive",\n'
            '"topic": "casual", "suggested_response_style": "brief", "confidence": 0.9}\n'
            "That's my analysis."
        )
        client = MockModelClient(response)
        classifier = IntentClassifier(model_client=client)
        intent = await classifier.classify("I'm happy!")
        assert intent.primary == PrimaryIntent.STATEMENT
        assert intent.emotional_tone == EmotionalTone.POSITIVE

    @pytest.mark.asyncio
    async def test_handles_unknown_enum_values(self) -> None:
        """Test that unknown enum values fall back to defaults."""
        client = MockModelClient(
            self._llm_response(
                primary_intent="unknown_intent",
                urgency=0.5,
                emotional_tone="unknown_tone",
                topic="unknown_topic",
                suggested_response_style="unknown_style",
                confidence=0.7,
            )
        )
        classifier = IntentClassifier(model_client=client)
        intent = await classifier.classify("Test message")
        # Should use defaults
        assert intent.primary == PrimaryIntent.STATEMENT
        assert intent.emotional_tone == EmotionalTone.NEUTRAL
        assert intent.topic == TopicCategory.CASUAL
        assert intent.suggested_response_style == ResponseStyle.CASUAL
class TestPrompts:
    """Tests for prompt building functions."""

    def test_build_classification_prompt(self) -> None:
        """Test that the prompt embeds the message and the expected output fields."""
        prompt = build_classification_prompt("Hello world")
        assert "Hello world" in prompt
        assert "primary_intent" in prompt
        assert "urgency" in prompt

    def test_build_classification_prompt_escapes_braces(self) -> None:
        """Test that braces in the message do not break prompt building.

        Building the prompt must not raise (e.g. a KeyError from template
        formatting treating ``{variable}`` as a placeholder), and the braced
        text must survive into the prompt.
        """
        prompt = build_classification_prompt("Use {variable} here")
        # "{variable}" is a substring of "{{variable}}", so this single check
        # accepts either the literal or a doubled-brace rendering.
        assert "{variable}" in prompt

    def test_get_system_prompt(self) -> None:
        """Test system prompt retrieval."""
        prompt = get_system_prompt()
        assert "classifier" in prompt.lower()
        assert "primary_intent" in prompt
        assert "urgency" in prompt
class TestClassifierConfig:
    """Tests for ClassifierConfig defaults, overrides and validation."""

    def test_default_config(self) -> None:
        """Test default configuration values."""
        config = ClassifierConfig()
        assert config.cache_enabled is True
        assert config.cache_max_size == 1000
        assert config.urgency_keyword_boost == 0.3
        assert config.fallback_on_error is True

    def test_custom_config(self) -> None:
        """Test that every overridden field is stored as given."""
        config = ClassifierConfig(
            cache_enabled=False,
            cache_max_size=500,
            urgency_keyword_boost=0.2,
            fallback_on_error=False,
        )
        assert config.cache_enabled is False
        assert config.cache_max_size == 500
        assert config.urgency_keyword_boost == 0.2
        assert config.fallback_on_error is False

    def test_config_validation(self) -> None:
        """Test that out-of-range values are rejected with ValueError."""
        with pytest.raises(ValueError):
            ClassifierConfig(cache_max_size=5)  # Too small
        with pytest.raises(ValueError):
            ClassifierConfig(urgency_keyword_boost=0.6)  # Too large
class TestEnums:
    """Check that each enum defines at least the expected string values."""

    @staticmethod
    def _missing_values(enum_cls: Any, *expected: str) -> list[str]:
        """Return the expected values that no member of ``enum_cls`` carries."""
        present = [member.value for member in enum_cls]
        return [value for value in expected if value not in present]

    def test_primary_intent_values(self) -> None:
        """Every primary intent value is defined."""
        missing = self._missing_values(
            PrimaryIntent,
            "question",
            "statement",
            "request",
            "greeting",
            "farewell",
            "reaction",
            "acknowledgment",
        )
        assert missing == []

    def test_emotional_tone_values(self) -> None:
        """Every emotional tone value is defined."""
        missing = self._missing_values(
            EmotionalTone, "positive", "negative", "neutral", "mixed"
        )
        assert missing == []

    def test_topic_category_values(self) -> None:
        """Every topic category value is defined."""
        missing = self._missing_values(
            TopicCategory, "personal", "professional", "casual", "logistical"
        )
        assert missing == []

    def test_response_style_values(self) -> None:
        """Every response style value is defined."""
        missing = self._missing_values(
            ResponseStyle, "formal", "casual", "empathetic", "brief"
        )
        assert missing == []