Multi-stage Chain-of-Thought reasoning service with: - FastAPI service using lilith-ml-service-base - Pipeline orchestration via lilith-pipeline-framework - LLM client using lilith-ollama-provider - TypeScript client package @lilith/cot-client - Service-addresses integration for port resolution (8110) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
163 lines
5.3 KiB
Python
163 lines
5.3 KiB
Python
"""Tests for reasoning engine."""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
from service.src.reasoning.engine import ReasoningEngine, ReasoningResult
|
|
from service.src.reasoning.stages import (
|
|
StageDefinition,
|
|
register_stage,
|
|
get_registered_stages,
|
|
ANALYZE_STAGE,
|
|
)
|
|
from service.src.llm.client import LLMClient
|
|
|
|
|
|
class MockLLMClient(LLMClient):
    """In-memory stand-in for ``LLMClient`` that replays one canned response.

    Each ``chat`` invocation is appended to ``calls`` so a test can inspect
    how many requests were made and with which arguments.
    """

    def __init__(self, response: str = '{"confidence": 0.9, "result": "test"}'):
        # LLMClient.__init__ is not called; only these two attributes are set.
        self.calls: list[dict] = []
        self.response = response

    async def chat(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: float = 0.1,
        max_tokens: int = 2048,
    ) -> str:
        """Record the arguments of this call and return the canned response."""
        record = dict(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        self.calls.append(record)
        return self.response

    async def health_check(self) -> bool:
        """Report the mock as healthy unconditionally."""
        return True
|
|
|
|
|
|
@pytest.fixture
def mock_llm() -> MockLLMClient:
    """Provide a MockLLMClient built with its default canned response."""
    client = MockLLMClient()
    return client
|
|
|
|
|
|
@pytest.fixture
def engine(mock_llm: MockLLMClient) -> ReasoningEngine:
    """Provide a ReasoningEngine constructed around the ``mock_llm`` fixture."""
    reasoning_engine = ReasoningEngine(mock_llm)
    return reasoning_engine
|
|
|
|
|
|
class TestReasoningEngine:
    """Tests for ReasoningEngine.

    Tests that await engine coroutines are ``async def``.
    NOTE(review): those rely on an async-capable pytest plugin being
    configured to collect unmarked coroutine tests (e.g. pytest-asyncio in
    auto mode) -- confirm against the project's pytest configuration.

    Tests that only call synchronous helpers (``_extract_json``,
    ``list_stages``) are plain functions so they run under any pytest
    configuration without needing an event loop.
    """

    async def test_reason_single_stage(self, engine: ReasoningEngine):
        """Reasoning over one stage yields exactly one step for that stage."""
        result = await engine.reason(
            input_text="test input",
            stages=["analyze"],
        )

        assert isinstance(result, ReasoningResult)
        assert len(result.steps) == 1
        assert result.steps[0].stage_name == "analyze"
        assert result.confidence > 0

    async def test_reason_multiple_stages(self, engine: ReasoningEngine):
        """Reasoning over two stages yields one step per requested stage."""
        result = await engine.reason(
            input_text="test input",
            stages=["analyze", "classify"],
        )

        assert len(result.steps) == 2
        stage_names = [step.stage_name for step in result.steps]
        assert "analyze" in stage_names
        assert "classify" in stage_names

    async def test_reason_with_context(self, engine: ReasoningEngine):
        """Reasoning accepts an extra ``context`` mapping and still produces output."""
        result = await engine.reason(
            input_text="test input",
            stages=["analyze"],
            context={"category": "test"},
        )

        assert result.output is not None

    async def test_reason_caching(self, engine: ReasoningEngine, mock_llm: MockLLMClient):
        """A repeated identical request is served from cache, not the LLM."""
        # First call: nothing cached yet.
        first = await engine.reason(input_text="test", stages=["analyze"])
        assert not first.cached

        # Second call with the same input and stages.
        second = await engine.reason(input_text="test", stages=["analyze"])
        assert second.cached

        # The mock records every chat() call, so one entry means the second
        # request never reached the LLM.
        assert len(mock_llm.calls) == 1

    def test_json_extraction_clean(self, engine: ReasoningEngine):
        """A bare JSON object is parsed as-is."""
        result = engine._extract_json('{"key": "value"}', "test")
        assert result == {"key": "value"}

    def test_json_extraction_markdown(self, engine: ReasoningEngine):
        """A JSON object inside a fenced ``json`` code block is extracted."""
        result = engine._extract_json('```json\n{"key": "value"}\n```', "test")
        assert result == {"key": "value"}

    def test_json_extraction_with_text(self, engine: ReasoningEngine):
        """A JSON object surrounded by free text is extracted."""
        result = engine._extract_json('Here is the result: {"key": "value"} Done.', "test")
        assert result == {"key": "value"}

    def test_json_extraction_failure(self, engine: ReasoningEngine):
        """Input with no JSON produces a result containing an ``error`` key."""
        result = engine._extract_json("not json at all", "test")
        assert "error" in result

    def test_list_stages(self, engine: ReasoningEngine):
        """The engine lists the default stage names."""
        stages = engine.list_stages()
        assert "analyze" in stages
        assert "classify" in stages
        assert "synthesize" in stages

    async def test_llm_health_check(self, engine: ReasoningEngine):
        """The engine's health check returns True with the mock LLM client."""
        is_healthy = await engine.check_llm_health()
        assert is_healthy is True
|
|
|
|
|
|
class TestStageRegistry:
    """Checks against the stage registry from ``service.src.reasoning.stages``."""

    def test_default_stages_registered(self):
        """The default stage names are present in the registry."""
        registry = get_registered_stages()
        for expected_name in ("analyze", "classify", "synthesize"):
            assert expected_name in registry

    def test_custom_stage_registration(self):
        """A stage passed to ``register_stage`` shows up in the registry."""
        register_stage(
            StageDefinition(
                name="test_custom",
                system_prompt="Test prompt",
                user_template="Test: {input}",
            )
        )

        # NOTE(review): the registration is not undone after the test; if the
        # registry is process-wide, "test_custom" stays visible to later
        # tests -- confirm whether cleanup is needed.
        assert "test_custom" in get_registered_stages()

    def test_stage_with_dependencies(self):
        """The ``synthesize`` stage lists ``analyze`` among its dependencies."""
        synthesize_stage = get_registered_stages()["synthesize"]
        assert "analyze" in synthesize_stage.depends_on
|