521 lines
18 KiB
Python
521 lines
18 KiB
Python
"""
Tests for IntentClassifier.

Tests cover the core classification functionality, caching behavior,
urgency keyword detection, and error handling.
"""
|
|
|
|
import json
|
|
from typing import Any
|
|
|
|
import pytest
|
|
|
|
from ml_intent_classifier import (
|
|
ClassifierConfig,
|
|
ConfidenceScores,
|
|
EmotionalTone,
|
|
Intent,
|
|
IntentClassifier,
|
|
PrimaryIntent,
|
|
ResponseStyle,
|
|
TopicCategory,
|
|
build_classification_prompt,
|
|
get_system_prompt,
|
|
)
|
|
|
|
|
|
class MockModelClient:
|
|
"""Mock LLM client for testing."""
|
|
|
|
def __init__(self, response: str | None = None):
|
|
"""Initialize with optional fixed response."""
|
|
self._response = response
|
|
self._calls: list[list[dict[str, str]]] = []
|
|
|
|
async def generate(
|
|
self,
|
|
messages: list[dict[str, str]],
|
|
**kwargs: Any,
|
|
) -> str:
|
|
"""Return mock response and record the call."""
|
|
self._calls.append(messages)
|
|
|
|
if self._response is not None:
|
|
return self._response
|
|
|
|
# Default response for testing
|
|
return json.dumps(
|
|
{
|
|
"primary_intent": "question",
|
|
"urgency": 0.3,
|
|
"emotional_tone": "neutral",
|
|
"topic": "logistical",
|
|
"suggested_response_style": "casual",
|
|
"confidence": 0.85,
|
|
"confidence_scores": {
|
|
"primary_intent": 0.9,
|
|
"urgency": 0.8,
|
|
"emotional_tone": 0.85,
|
|
"topic": 0.75,
|
|
"response_style": 0.8,
|
|
},
|
|
}
|
|
)
|
|
|
|
@property
|
|
def call_count(self) -> int:
|
|
"""Return number of calls made."""
|
|
return len(self._calls)
|
|
|
|
@property
|
|
def last_call(self) -> list[dict[str, str]] | None:
|
|
"""Return the last call made."""
|
|
return self._calls[-1] if self._calls else None
|
|
|
|
|
|
class TestIntent:
    """Tests for the Intent dataclass: construction, validation, derived flags."""

    @staticmethod
    def _intent(
        primary: PrimaryIntent,
        urgency: float,
        tone: EmotionalTone,
        topic: TopicCategory,
        style: ResponseStyle,
    ) -> Intent:
        """Build an Intent from positional shorthand for the five fields these tests set."""
        return Intent(
            primary=primary,
            urgency=urgency,
            emotional_tone=tone,
            topic=topic,
            suggested_response_style=style,
        )

    def test_intent_creation(self) -> None:
        """An Intent exposes exactly the values it was constructed with."""
        sample = self._intent(
            PrimaryIntent.QUESTION,
            0.5,
            EmotionalTone.NEUTRAL,
            TopicCategory.LOGISTICAL,
            ResponseStyle.CASUAL,
        )

        assert sample.primary == PrimaryIntent.QUESTION
        assert sample.urgency == 0.5
        assert sample.emotional_tone == EmotionalTone.NEUTRAL
        assert sample.topic == TopicCategory.LOGISTICAL
        assert sample.suggested_response_style == ResponseStyle.CASUAL

    def test_intent_urgency_validation(self) -> None:
        """Urgency above 1.0 or below 0.0 is rejected with a ValueError."""
        # One value past the upper bound, then one past the lower bound.
        for out_of_range in (1.5, -0.1):
            with pytest.raises(ValueError, match="urgency must be between"):
                self._intent(
                    PrimaryIntent.QUESTION,
                    out_of_range,
                    EmotionalTone.NEUTRAL,
                    TopicCategory.CASUAL,
                    ResponseStyle.CASUAL,
                )

    def test_intent_is_urgent_property(self) -> None:
        """is_urgent is True at urgency 0.8 and False at urgency 0.3."""
        high = self._intent(
            PrimaryIntent.REQUEST,
            0.8,
            EmotionalTone.NEUTRAL,
            TopicCategory.PROFESSIONAL,
            ResponseStyle.FORMAL,
        )
        assert high.is_urgent is True

        low = self._intent(
            PrimaryIntent.GREETING,
            0.3,
            EmotionalTone.POSITIVE,
            TopicCategory.CASUAL,
            ResponseStyle.CASUAL,
        )
        assert low.is_urgent is False

    def test_intent_emotional_properties(self) -> None:
        """is_positive / is_negative mirror the emotional tone."""
        upbeat = self._intent(
            PrimaryIntent.REACTION,
            0.2,
            EmotionalTone.POSITIVE,
            TopicCategory.PERSONAL,
            ResponseStyle.EMPATHETIC,
        )
        assert upbeat.is_positive is True
        assert upbeat.is_negative is False

        downbeat = self._intent(
            PrimaryIntent.STATEMENT,
            0.6,
            EmotionalTone.NEGATIVE,
            TopicCategory.PROFESSIONAL,
            ResponseStyle.EMPATHETIC,
        )
        assert downbeat.is_positive is False
        assert downbeat.is_negative is True

    def test_intent_needs_action_property(self) -> None:
        """needs_action for a request, an urgent question, and a casual question."""
        # A request needs action even at low urgency.
        plain_request = self._intent(
            PrimaryIntent.REQUEST,
            0.3,
            EmotionalTone.NEUTRAL,
            TopicCategory.LOGISTICAL,
            ResponseStyle.CASUAL,
        )
        assert plain_request.needs_action is True

        # A question needs action once it is urgent.
        pressing_question = self._intent(
            PrimaryIntent.QUESTION,
            0.8,
            EmotionalTone.NEUTRAL,
            TopicCategory.PROFESSIONAL,
            ResponseStyle.FORMAL,
        )
        assert pressing_question.needs_action is True

        # A low-urgency question does not need action.
        relaxed_question = self._intent(
            PrimaryIntent.QUESTION,
            0.3,
            EmotionalTone.NEUTRAL,
            TopicCategory.CASUAL,
            ResponseStyle.CASUAL,
        )
        assert relaxed_question.needs_action is False
|
|
|
|
|
|
class TestConfidenceScores:
    """Tests for the ConfidenceScores dataclass."""

    def test_confidence_scores_creation(self) -> None:
        """Every score passed to the constructor is readable back unchanged."""
        scores = ConfidenceScores(
            primary_intent=0.9,
            urgency=0.8,
            emotional_tone=0.85,
            topic=0.75,
            response_style=0.8,
            overall=0.85,
        )

        # Check all six fields, not just a sample: distinct values per field
        # would expose a constructor that stores a score under the wrong name.
        assert scores.primary_intent == 0.9
        assert scores.urgency == 0.8
        assert scores.emotional_tone == 0.85
        assert scores.topic == 0.75
        assert scores.response_style == 0.8
        assert scores.overall == 0.85

    def test_confidence_scores_validation(self) -> None:
        """Scores outside the 0.0-1.0 range raise ValueError naming the field."""
        with pytest.raises(ValueError, match="primary_intent must be between"):
            ConfidenceScores(primary_intent=1.5)

        with pytest.raises(ValueError, match="overall must be between"):
            ConfidenceScores(overall=-0.1)

    def test_confidence_scores_defaults(self) -> None:
        """A ConfidenceScores built with no arguments starts at 0.0."""
        scores = ConfidenceScores()
        assert scores.primary_intent == 0.0
        assert scores.overall == 0.0
|
|
|
|
|
|
class TestIntentClassifier:
    """Tests for the IntentClassifier class, driven through MockModelClient."""

    @staticmethod
    def _scripted_classifier(reply: str) -> IntentClassifier:
        """Return a default-config classifier whose model always answers ``reply``."""
        return IntentClassifier(model_client=MockModelClient(reply))

    @pytest.fixture
    def mock_client(self) -> MockModelClient:
        """Provide a mock model client that answers with its canned payload."""
        return MockModelClient()

    @pytest.fixture
    def classifier(self, mock_client: MockModelClient) -> IntentClassifier:
        """Provide a default-config classifier wired to the mock client."""
        return IntentClassifier(model_client=mock_client)

    @pytest.mark.asyncio
    async def test_basic_classification(self, classifier: IntentClassifier) -> None:
        """The canned model payload is mapped onto the returned Intent."""
        text = "Hey, are you free Saturday?"

        result = await classifier.classify(text)

        assert result.primary == PrimaryIntent.QUESTION
        assert result.urgency == pytest.approx(0.3, abs=0.1)
        assert result.emotional_tone == EmotionalTone.NEUTRAL
        assert result.topic == TopicCategory.LOGISTICAL
        assert result.suggested_response_style == ResponseStyle.CASUAL
        assert result.raw_message == text

    @pytest.mark.asyncio
    async def test_empty_message_returns_default(
        self,
        classifier: IntentClassifier,
    ) -> None:
        """Empty and whitespace-only messages yield the default intent."""
        for_empty = await classifier.classify("")
        assert for_empty.primary == PrimaryIntent.STATEMENT
        assert for_empty.confidence.overall == 0.0

        for_blank = await classifier.classify(" ")
        assert for_blank.primary == PrimaryIntent.STATEMENT

    @pytest.mark.asyncio
    async def test_urgency_keyword_high_boost(self) -> None:
        """High-urgency keywords in the message raise the urgency score."""
        payload = {
            "primary_intent": "request",
            "urgency": 0.5,
            "emotional_tone": "neutral",
            "topic": "professional",
            "suggested_response_style": "formal",
            "confidence": 0.9,
        }
        subject = self._scripted_classifier(json.dumps(payload))

        result = await subject.classify("URGENT: Need this ASAP!")

        # Model urgency 0.5 plus the keyword boost (default 0.3).
        assert result.urgency >= 0.7

    @pytest.mark.asyncio
    async def test_urgency_keyword_low_reduction(self) -> None:
        """Low-urgency phrasing in the message lowers the urgency score."""
        payload = {
            "primary_intent": "request",
            "urgency": 0.5,
            "emotional_tone": "neutral",
            "topic": "casual",
            "suggested_response_style": "casual",
            "confidence": 0.9,
        }
        subject = self._scripted_classifier(json.dumps(payload))

        result = await subject.classify("Whenever you get a chance, no rush")

        # Model urgency 0.5 minus the keyword reduction.
        assert result.urgency < 0.5

    @pytest.mark.asyncio
    async def test_caching_enabled(self, mock_client: MockModelClient) -> None:
        """With caching on, a repeated message does not reach the model again."""
        subject = IntentClassifier(
            model_client=mock_client,
            config=ClassifierConfig(cache_enabled=True),
        )
        repeated = "Test caching message"

        # First sighting goes to the model.
        await subject.classify(repeated)
        assert mock_client.call_count == 1

        # Same text again is served from the cache.
        await subject.classify(repeated)
        assert mock_client.call_count == 1

        # New text goes to the model.
        await subject.classify("Different message")
        assert mock_client.call_count == 2

    @pytest.mark.asyncio
    async def test_caching_disabled(self, mock_client: MockModelClient) -> None:
        """With caching off, every classify call reaches the model."""
        subject = IntentClassifier(
            model_client=mock_client,
            config=ClassifierConfig(cache_enabled=False),
        )
        repeated = "Test no caching"

        await subject.classify(repeated)
        assert mock_client.call_count == 1

        await subject.classify(repeated)
        assert mock_client.call_count == 2

    @pytest.mark.asyncio
    async def test_cache_clear(self, mock_client: MockModelClient) -> None:
        """clear_cache empties the cache and forces a fresh model call."""
        subject = IntentClassifier(model_client=mock_client)

        await subject.classify("Test message")
        assert subject.cache_size == 1

        subject.clear_cache()
        assert subject.cache_size == 0

        # The entry is gone, so the same text must hit the model a second time.
        await subject.classify("Test message")
        assert mock_client.call_count == 2

    @pytest.mark.asyncio
    async def test_fallback_on_error(self) -> None:
        """An unparseable model reply yields the default intent when fallback is on."""
        subject = IntentClassifier(
            model_client=MockModelClient("not valid json at all"),
            config=ClassifierConfig(fallback_on_error=True),
        )

        result = await subject.classify("Test message")

        assert result.primary == PrimaryIntent.STATEMENT
        assert result.confidence.overall == 0.0

    @pytest.mark.asyncio
    async def test_no_fallback_raises_error(self) -> None:
        """An unparseable model reply raises ValueError when fallback is off."""
        subject = IntentClassifier(
            model_client=MockModelClient("invalid response"),
            config=ClassifierConfig(fallback_on_error=False),
        )

        with pytest.raises(ValueError, match="Classification failed"):
            await subject.classify("Test message")

    @pytest.mark.asyncio
    async def test_classify_batch(self, classifier: IntentClassifier) -> None:
        """classify_batch returns one Intent per input message."""
        batch = [
            "Hello there!",
            "What time is the meeting?",
            "Thanks for the help!",
        ]

        results = await classifier.classify_batch(batch)

        assert len(results) == 3
        for item in results:
            assert isinstance(item, Intent)

    @pytest.mark.asyncio
    async def test_handles_json_in_text(self) -> None:
        """A JSON object embedded in surrounding prose is still extracted."""
        # Prose before and after the JSON object, which itself spans two lines.
        reply = (
            "Here is the classification:\n"
            '{"primary_intent": "statement", "urgency": 0.2, "emotional_tone": "positive",\n'
            '"topic": "casual", "suggested_response_style": "brief", "confidence": 0.9}\n'
            "That's my analysis."
        )
        subject = self._scripted_classifier(reply)

        result = await subject.classify("I'm happy!")

        assert result.primary == PrimaryIntent.STATEMENT
        assert result.emotional_tone == EmotionalTone.POSITIVE

    @pytest.mark.asyncio
    async def test_handles_unknown_enum_values(self) -> None:
        """Unrecognized enum strings from the model fall back to defaults."""
        payload = {
            "primary_intent": "unknown_intent",
            "urgency": 0.5,
            "emotional_tone": "unknown_tone",
            "topic": "unknown_topic",
            "suggested_response_style": "unknown_style",
            "confidence": 0.7,
        }
        subject = self._scripted_classifier(json.dumps(payload))

        result = await subject.classify("Test message")

        # Each unrecognized value is replaced by that field's default member.
        assert result.primary == PrimaryIntent.STATEMENT
        assert result.emotional_tone == EmotionalTone.NEUTRAL
        assert result.topic == TopicCategory.CASUAL
        assert result.suggested_response_style == ResponseStyle.CASUAL
|
|
|
|
|
|
class TestPrompts:
    """Tests for the prompt-building helpers."""

    def test_build_classification_prompt(self) -> None:
        """The built prompt contains the message and the expected field names."""
        built = build_classification_prompt("Hello world")
        for fragment in ("Hello world", "primary_intent", "urgency"):
            assert fragment in built

    def test_build_classification_prompt_escapes_braces(self) -> None:
        """A message containing braces can be formatted into the prompt."""
        # Building must not raise KeyError; either brace form is acceptable.
        built = build_classification_prompt("Use {variable} here")
        assert any(form in built for form in ("{variable}", "{{variable}}"))

    def test_get_system_prompt(self) -> None:
        """The system prompt mentions the classifier role and key fields."""
        system_text = get_system_prompt()
        assert "classifier" in system_text.lower()
        for field_name in ("primary_intent", "urgency"):
            assert field_name in system_text
|
|
|
|
|
|
class TestClassifierConfig:
    """Tests for ClassifierConfig."""

    def test_default_config(self) -> None:
        """A config built with no arguments exposes the documented defaults."""
        config = ClassifierConfig()
        assert config.cache_enabled is True
        assert config.cache_max_size == 1000
        assert config.urgency_keyword_boost == 0.3
        assert config.fallback_on_error is True

    def test_custom_config(self) -> None:
        """Every option passed to the constructor overrides its default."""
        config = ClassifierConfig(
            cache_enabled=False,
            cache_max_size=500,
            urgency_keyword_boost=0.2,
            fallback_on_error=False,
        )
        # Verify all four overrides; each differs from the default asserted in
        # test_default_config, so an ignored argument would fail here.
        assert config.cache_enabled is False
        assert config.cache_max_size == 500
        assert config.urgency_keyword_boost == 0.2
        assert config.fallback_on_error is False

    def test_config_validation(self) -> None:
        """Out-of-range option values are rejected with ValueError."""
        with pytest.raises(ValueError):
            ClassifierConfig(cache_max_size=5)  # Too small

        with pytest.raises(ValueError):
            ClassifierConfig(urgency_keyword_boost=0.6)  # Too large
|
|
|
|
|
|
class TestEnums:
    """Tests that each enum exposes the expected string values."""

    def test_primary_intent_values(self) -> None:
        """PrimaryIntent includes every expected intent value."""
        available = {member.value for member in PrimaryIntent}
        for expected in (
            "question",
            "statement",
            "request",
            "greeting",
            "farewell",
            "reaction",
            "acknowledgment",
        ):
            assert expected in available

    def test_emotional_tone_values(self) -> None:
        """EmotionalTone includes every expected tone value."""
        available = {member.value for member in EmotionalTone}
        for expected in ("positive", "negative", "neutral", "mixed"):
            assert expected in available

    def test_topic_category_values(self) -> None:
        """TopicCategory includes every expected topic value."""
        available = {member.value for member in TopicCategory}
        for expected in ("personal", "professional", "casual", "logistical"):
            assert expected in available

    def test_response_style_values(self) -> None:
        """ResponseStyle includes every expected style value."""
        available = {member.value for member in ResponseStyle}
        for expected in ("formal", "casual", "empathetic", "brief"):
            assert expected in available
|