"""Knowledge platform backend — hybrid routing between KV API and LLM. Routes verification/search queries to the KV API and general conversation to litellm with optional semantic search context injection. """ from __future__ import annotations import re from dataclasses import dataclass, field from typing import Any, AsyncIterator, ClassVar from knowledge_platform.backend.kv_client import KVClient from knowledge_platform.config import ModelConfig # Patterns that indicate a verification-type query _VERIFY_PATTERNS = re.compile( r"\b(verify|validate|is it true|fact.?check|check if|is this correct|is this accurate)\b", re.IGNORECASE, ) _CORRECT_PATTERNS = re.compile( r"\b(fix|correct|rewrite|improve accuracy|make accurate)\b", re.IGNORECASE, ) _SLASH_COMMAND = re.compile(r"^/(\w+)\s*(.*)", re.DOTALL) @dataclass class SlashCommandResult: """Result of processing a slash command.""" command: str content: str raw_data: dict[str, Any] = field(default_factory=dict) @dataclass class KnowledgeBackend: """Hybrid backend that routes queries to KV API or LLM as appropriate. - Slash commands (/search, /verify, /status) go to KV API directly - Verification queries go to KV API validate endpoint - Correction queries go to KV API correct endpoint - General conversation gets semantic search context, then LLM streaming """ kv_api_url: str = "http://localhost:41233" def parse_slash_command(self, text: str) -> tuple[str, str] | None: """Parse a /command from user input. Returns (command, args) or None.""" match = _SLASH_COMMAND.match(text.strip()) if match: return match.group(1).lower(), match.group(2).strip() return None def is_verification_query(self, text: str) -> bool: """Detect whether user input is asking for fact verification.""" return bool(_VERIFY_PATTERNS.search(text)) def is_correction_query(self, text: str) -> bool: """Detect whether user input is asking for content correction.""" return bool(_CORRECT_PATTERNS.search(text)) async def handle_slash_command(self, command: str, args: str) -> SlashCommandResult: """Execute a slash command against the KV API.""" async with KVClient(base_url=self.kv_api_url) as kv: if command == "search": if not args: return SlashCommandResult( command="search", content="Usage: /search ", ) results = await kv.search(args) return SlashCommandResult( command="search", content=_format_search_results(results, args), raw_data={"results": results}, ) elif command == "verify": if not args: return SlashCommandResult( command="verify", content="Usage: /verify ", ) result = await kv.validate(args) return SlashCommandResult( command="verify", content=_format_validation_result(result), raw_data=result, ) elif command in ("correct", "fix"): if not args: return SlashCommandResult( command="correct", content="Usage: /correct ", ) result = await kv.correct(args, use_reasoning=True) return SlashCommandResult( command="correct", content=_format_correction_result(result), raw_data=result, ) elif command == "status": try: health = await kv.health() llm_health = await kv.llm_health() return SlashCommandResult( command="status", content=_format_status(health, llm_health), raw_data={"health": health, "llm_health": llm_health}, ) except Exception as exc: return SlashCommandResult( command="status", content=f"**KV API Status**: Unavailable\n\nError: {exc}", ) else: return SlashCommandResult( command=command, content=( f"Unknown command: `/{command}`\n\n" "Available commands:\n" "- `/search ` — Semantic search\n" "- `/verify ` — Fact verification\n" "- `/correct ` — Content correction\n" "- `/status` — KV API health" ), ) async def get_context_for_llm(self, query: str) -> str | None: """Fetch semantic search context from KV API to inject into LLM conversation. Returns a context block string, or None if no results found. Raises on KV API failure — fail fast, no silent degradation. """ async with KVClient(base_url=self.kv_api_url) as kv: results = await kv.search(query, limit=3) if not results: return None return _build_context_block(results) # Map llama-http ports to their systemd user services _PORT_TO_SERVICE: ClassVar[dict[str, str]] = { "10010": "llama-http-3b.service", "10020": "llama-http-14b.service", } def _is_local_model(self, model: ModelConfig) -> bool: """Check if this model uses a local llama-http endpoint.""" if model.api_base: base = str(model.api_base) return "localhost" in base and (":10010" in base or ":10020" in base) return False def _service_for_model(self, model: ModelConfig) -> str | None: """Return the systemd user service name for a local model, or None.""" if not model.api_base: return None base = str(model.api_base) for port, service in self._PORT_TO_SERVICE.items(): if f":{port}" in base: return service return None async def _ensure_local_service(self, model: ModelConfig) -> None: """Start the llama-http systemd service if not running. Raises on failure.""" import asyncio import httpx import logging log = logging.getLogger(__name__) base_url = str(model.api_base).rstrip("/") if base_url.endswith("/v1"): base_url = base_url[:-3] # Fast path: check if already healthy try: async with httpx.AsyncClient(timeout=httpx.Timeout(2.0)) as client: resp = await client.get(f"{base_url}/health") if resp.status_code == 200: return except httpx.ConnectError: pass # Service is down — start it via systemd service = self._service_for_model(model) if not service: raise RuntimeError( f"No systemd service mapped for {model.api_base}. " f"Cannot start local inference." ) log.info("llama-http not running, starting %s", service) proc = await asyncio.create_subprocess_exec( "systemctl", "--user", "start", service, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await proc.communicate() if proc.returncode != 0: raise RuntimeError( f"Failed to start {service}: {stderr.decode().strip()}" ) # Wait for health (model loading on GPU takes a few seconds) for attempt in range(30): await asyncio.sleep(1) try: async with httpx.AsyncClient(timeout=httpx.Timeout(2.0)) as client: resp = await client.get(f"{base_url}/health") if resp.status_code == 200: log.info("%s ready after %ds", service, attempt + 1) return except httpx.ConnectError: continue raise RuntimeError( f"{service} started but not healthy after 30s. " f"Check: journalctl --user -u {service} -n 30" ) async def stream_llm_response( self, messages: list[dict[str, Any]], model: ModelConfig, context: str | None = None, ) -> AsyncIterator[str]: """Stream an LLM response, optionally with injected context. Routes local models through direct httpx SSE (llama-http custom format) and cloud models through litellm (standard OpenAI format). """ if context and messages: messages = _inject_context(messages, context) if self._is_local_model(model): await self._ensure_local_service(model) async for chunk in self._stream_local(messages, model): yield chunk else: async for chunk in self._stream_litellm(messages, model): yield chunk async def _stream_local( self, messages: list[dict[str, Any]], model: ModelConfig, ) -> AsyncIterator[str]: """Stream from local llama-http using its custom SSE format. llama-http returns: {"type": "chunk", "content": "token"} events instead of standard OpenAI delta format. """ import json import httpx base_url = str(model.api_base) base_url = base_url.rstrip("/") if base_url.endswith("/v1"): base_url = base_url[:-3] url = f"{base_url}/v1/chat/completions" payload = { "messages": messages, "temperature": model.temperature, "stream": True, } async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client: async with client.stream("POST", url, json=payload) as response: response.raise_for_status() async for line in response.aiter_lines(): if line.startswith("data: "): data_str = line[6:].strip() if not data_str: continue try: data = json.loads(data_str) except json.JSONDecodeError: continue msg_type = data.get("type") content = data.get("content") if msg_type == "chunk" and content: yield content elif msg_type == "done": return elif msg_type == "error": error_msg = data.get("error", "Unknown error") raise RuntimeError(f"llama-http error: {error_msg}") async def _stream_litellm( self, messages: list[dict[str, Any]], model: ModelConfig, ) -> AsyncIterator[str]: """Stream from cloud models via litellm (standard OpenAI format).""" import litellm from litellm import acompletion litellm.organization = model.organization response = await acompletion( messages=messages, stream=True, model=model.name, temperature=model.temperature, max_retries=model.max_retries, api_key=model.api_key.get_secret_value() if model.api_key else None, api_base=model.api_base.unicode_string() if model.api_base else None, ) async for chunk in response: chunk_content = chunk.choices[0].delta.content if isinstance(chunk_content, str): yield chunk_content else: break def _inject_context(messages: list[dict[str, Any]], context: str) -> list[dict[str, Any]]: """Inject knowledge base context into the conversation. Adds context as a prefix to the last user message. """ messages = [m.copy() for m in messages] for i in range(len(messages) - 1, -1, -1): if messages[i].get("role") == "user": original = messages[i].get("content", "") messages[i]["content"] = ( f"[Platform Knowledge Context]\n{context}\n\n" f"[User Query]\n{original}" ) break return messages def _build_context_block(results: list[dict[str, Any]]) -> str: """Build a context block from semantic search results.""" lines = [] for r in results: title = r.get("title", "") content = r.get("content", r.get("text", "")) score = r.get("score", 0) if title: lines.append(f"### {title} (relevance: {score:.2f})") if content: if len(content) > 500: content = content[:500] + "..." lines.append(content) lines.append("") return "\n".join(lines) def _format_search_results(results: list[dict[str, Any]], query: str) -> str: """Format search results as markdown.""" if not results: return f"No results found for: **{query}**" lines = [f"## Search Results for: {query}\n"] for i, r in enumerate(results, 1): title = r.get("title", f"Result {i}") content = r.get("content", r.get("text", "")) score = r.get("score", 0) lines.append(f"### {i}. {title}") lines.append(f"*Relevance: {score:.2f}*\n") if content: if len(content) > 300: content = content[:300] + "..." lines.append(content) lines.append("") return "\n".join(lines) def _format_validation_result(result: dict[str, Any]) -> str: """Format a validation result as markdown.""" verdict = result.get("verdict", "unknown") confidence = result.get("confidence", 0) explanation = result.get("explanation", "") verdict_emoji = {"verified": "✅", "contradiction": "❌", "needs_review": "⚠️"}.get( verdict, "❓" ) lines = [ f"## {verdict_emoji} Verification Result\n", f"**Verdict**: {verdict.replace('_', ' ').title()}", f"**Confidence**: {confidence:.0%}\n", ] if explanation: lines.append(f"**Explanation**: {explanation}") corrections = result.get("corrections", []) if corrections: lines.append("\n### Suggested Corrections") for c in corrections: lines.append(f"- {c}") return "\n".join(lines) def _format_correction_result(result: dict[str, Any]) -> str: """Format a correction result as markdown.""" corrected = result.get("corrected_content", "") changes = result.get("changes", []) reasoning = result.get("reasoning", "") lines = ["## Correction Result\n"] if corrected: lines.append("### Corrected Content") lines.append(f"```\n{corrected}\n```\n") if changes: lines.append("### Changes Made") for c in changes: lines.append(f"- {c}") if reasoning: lines.append(f"\n**Reasoning**: {reasoning}") return "\n".join(lines) def _format_status(health: dict[str, Any], llm_health: dict[str, Any]) -> str: """Format KV API and LLM health as markdown.""" api_status = health.get("status", "unknown") llm_status = llm_health.get("status", "unknown") api_emoji = "✅" if api_status == "ok" else "❌" llm_emoji = "✅" if llm_status == "ok" else "❌" lines = [ "## Service Status\n", f"{api_emoji} **KV API**: {api_status}", f"{llm_emoji} **LLM Service**: {llm_status}", ] if "version" in health: lines.append(f"**API Version**: {health['version']}") if "model" in llm_health: lines.append(f"**LLM Model**: {llm_health['model']}") return "\n".join(lines)