cot-reasoning/service/tests/test_reasoning.py
Lilith 3ab8fb51fe feat(cot-reasoning): initial implementation of CoT reasoning service
Multi-stage Chain-of-Thought reasoning service with:
- FastAPI service using lilith-ml-service-base
- Pipeline orchestration via lilith-pipeline-framework
- LLM client using lilith-ollama-provider
- TypeScript client package @lilith/cot-client
- Service-addresses integration for port resolution (8110)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-01-12 09:17:59 -08:00

163 lines
5.3 KiB
Python

"""Tests for reasoning engine."""
import pytest
from unittest.mock import AsyncMock, MagicMock
from service.src.reasoning.engine import ReasoningEngine, ReasoningResult
from service.src.reasoning.stages import (
StageDefinition,
register_stage,
get_registered_stages,
ANALYZE_STAGE,
)
from service.src.llm.client import LLMClient
class MockLLMClient(LLMClient):
    """Recording stand-in for LLMClient used by the engine tests.

    ``chat()`` performs no real request: it appends its arguments to
    ``calls`` and returns the fixed ``response`` string. Tests can therefore
    inspect the prompts/parameters sent and count how often the LLM was hit.
    """

    def __init__(self, response: str = '{"confidence": 0.9, "result": "test"}'):
        # LLMClient.__init__ is not invoked, so no base-class setup runs.
        self.response = response
        self.calls: list[dict] = []

    async def chat(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: float = 0.1,
        max_tokens: int = 2048,
    ) -> str:
        """Record this call's arguments and return the canned response."""
        recorded = dict(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        self.calls.append(recorded)
        return self.response

    async def health_check(self) -> bool:
        """Always report the fake backend as healthy."""
        return True
@pytest.fixture
def mock_llm() -> MockLLMClient:
    """Provide a fresh recording MockLLMClient with the default canned response."""
    client = MockLLMClient()
    return client
@pytest.fixture
def engine(mock_llm: MockLLMClient) -> ReasoningEngine:
    """Provide a ReasoningEngine wired to the ``mock_llm`` fixture.

    Within one test pytest supplies the same ``mock_llm`` instance to every
    requester, so a test can drive the engine and then inspect
    ``mock_llm.calls``.
    """
    reasoning_engine = ReasoningEngine(mock_llm)
    return reasoning_engine
class TestReasoningEngine:
    """Tests for ReasoningEngine, driven through MockLLMClient.

    The coroutine tests below carry no explicit ``@pytest.mark.asyncio``
    marker. NOTE(review): this presumably relies on pytest-asyncio running with
    ``asyncio_mode = auto``; confirm against the project's pytest config.
    Tests that only call synchronous engine methods are plain ``def`` so they
    run regardless of async-plugin configuration.
    """

    async def test_reason_single_stage(self, engine: ReasoningEngine) -> None:
        """A single requested stage yields one step named after that stage."""
        result = await engine.reason(
            input_text="test input",
            stages=["analyze"],
        )
        assert isinstance(result, ReasoningResult)
        assert len(result.steps) == 1
        assert result.steps[0].stage_name == "analyze"
        assert result.confidence > 0

    async def test_reason_multiple_stages(self, engine: ReasoningEngine) -> None:
        """Every requested stage produces a step (order is not asserted)."""
        result = await engine.reason(
            input_text="test input",
            stages=["analyze", "classify"],
        )
        assert len(result.steps) == 2
        stage_names = [s.stage_name for s in result.steps]
        assert "analyze" in stage_names
        assert "classify" in stage_names

    async def test_reason_with_context(self, engine: ReasoningEngine) -> None:
        """Passing an extra ``context`` mapping still produces an output."""
        result = await engine.reason(
            input_text="test input",
            stages=["analyze"],
            context={"category": "test"},
        )
        assert result.output is not None

    async def test_reason_caching(
        self, engine: ReasoningEngine, mock_llm: MockLLMClient
    ) -> None:
        """A repeated identical request is served from cache without a second LLM call."""
        # First call: a fresh engine (function-scoped fixture) should miss.
        result1 = await engine.reason(input_text="test", stages=["analyze"])
        assert not result1.cached

        # Second call with the same input should hit the cache.
        result2 = await engine.reason(input_text="test", stages=["analyze"])
        assert result2.cached

        # Only the first call may have reached the LLM.
        assert len(mock_llm.calls) == 1

    # _extract_json is synchronous (its return value is compared directly to a
    # dict), so these tests need no event loop.
    def test_json_extraction_clean(self, engine: ReasoningEngine) -> None:
        """A bare JSON object is parsed as-is."""
        result = engine._extract_json('{"key": "value"}', "test")
        assert result == {"key": "value"}

    def test_json_extraction_markdown(self, engine: ReasoningEngine) -> None:
        """JSON wrapped in a ```json fenced block is unwrapped and parsed."""
        result = engine._extract_json('```json\n{"key": "value"}\n```', "test")
        assert result == {"key": "value"}

    def test_json_extraction_with_text(self, engine: ReasoningEngine) -> None:
        """A JSON object embedded in surrounding prose is located and parsed."""
        result = engine._extract_json('Here is the result: {"key": "value"} Done.', "test")
        assert result == {"key": "value"}

    def test_json_extraction_failure(self, engine: ReasoningEngine) -> None:
        """Unparseable text yields a result carrying an ``error`` entry instead of raising."""
        result = engine._extract_json("not json at all", "test")
        assert "error" in result

    def test_list_stages(self, engine: ReasoningEngine) -> None:
        """The engine exposes the built-in stages."""
        stages = engine.list_stages()
        assert "analyze" in stages
        assert "classify" in stages
        assert "synthesize" in stages

    async def test_llm_health_check(self, engine: ReasoningEngine) -> None:
        """Engine health reflects the (always-healthy) mock LLM."""
        is_healthy = await engine.check_llm_health()
        assert is_healthy is True
class TestStageRegistry:
    """Tests for the stage registry (``register_stage`` / ``get_registered_stages``)."""

    def test_default_stages_registered(self) -> None:
        """The built-in analyze/classify/synthesize stages are registered without setup."""
        stages = get_registered_stages()
        assert "analyze" in stages
        assert "classify" in stages
        assert "synthesize" in stages

    def test_custom_stage_registration(self) -> None:
        """A stage passed to register_stage() appears in get_registered_stages()."""
        custom_stage = StageDefinition(
            name="test_custom",
            system_prompt="Test prompt",
            user_template="Test: {input}",
        )
        register_stage(custom_stage)
        try:
            assert "test_custom" in get_registered_stages()
        finally:
            # The registry appears to be shared state that outlives this test,
            # so drop the custom stage again to keep other tests independent of
            # execution order.
            # NOTE(review): this only has an effect if get_registered_stages()
            # returns the live registry dict; if it returns a copy, the pop is
            # a harmless no-op. Prefer a dedicated unregister helper if one exists.
            registry = get_registered_stages()
            if isinstance(registry, dict):
                registry.pop("test_custom", None)

    def test_stage_with_dependencies(self) -> None:
        """The synthesize stage declares a dependency on analyze."""
        stages = get_registered_stages()
        synthesize = stages["synthesize"]
        assert "analyze" in synthesize.depends_on