ml-knowledge-platform/knowledge_platform/backend/knowledge_backend.py
Lilith 240b4328f1 chore(config): 🔧 Update 40 configuration files across project
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
2026-02-16 01:39:57 -08:00

450 lines
16 KiB
Python

"""Knowledge platform backend — hybrid routing between KV API and LLM.
Routes verification/search queries to the KV API and general conversation
to litellm with optional semantic search context injection.
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, ClassVar
from knowledge_platform.backend.kv_client import KVClient
from knowledge_platform.config import ModelConfig
# Phrases that mark user input as a fact-verification request. Matched
# case-insensitively on word boundaries; "fact.?check" allows at most one
# arbitrary character between the words ("factcheck", "fact-check",
# "fact check").
_VERIFY_PATTERNS = re.compile(
r"\b(verify|validate|is it true|fact.?check|check if|is this correct|is this accurate)\b",
re.IGNORECASE,
)
# Phrases that mark user input as a content-correction request.
# NOTE(review): the word "correct" also occurs inside the verification phrase
# "is this correct", so both patterns match that input — confirm the caller
# tests for verification before correction.
_CORRECT_PATTERNS = re.compile(
r"\b(fix|correct|rewrite|improve accuracy|make accurate)\b",
re.IGNORECASE,
)
# A leading "/command" (word characters only) followed by optional arguments.
# DOTALL lets the argument group span multiple lines.
_SLASH_COMMAND = re.compile(r"^/(\w+)\s*(.*)", re.DOTALL)
@dataclass
class SlashCommandResult:
    """Outcome of handling a single slash command.

    Attributes:
        command: Name of the command this result belongs to.
        content: Text to show the user for this command.
        raw_data: Optional payload accompanying the text; defaults to a
            fresh empty dict per instance.
    """

    command: str
    content: str
    raw_data: dict[str, Any] = field(default_factory=dict)
@dataclass
class KnowledgeBackend:
    """Hybrid backend that talks to the KV API and to LLM endpoints.

    - Slash commands (/search, /verify, /correct, /fix, /status) are executed
      against the KV API by ``handle_slash_command``
    - ``is_verification_query`` / ``is_correction_query`` classify free-text
      input so the caller can route it
    - ``get_context_for_llm`` fetches semantic search context
    - ``stream_llm_response`` streams from a local llama-http service or from
      litellm, optionally with that context injected
    """

    # Base URL handed to every KVClient this backend opens.
    kv_api_url: str = "http://localhost:41233"

    # Map llama-http ports to their systemd user services
    _PORT_TO_SERVICE: ClassVar[dict[str, str]] = {
        "10010": "llama-http-3b.service",
        "10020": "llama-http-14b.service",
    }

    # Usage hints for the commands that require an argument.
    _USAGE: ClassVar[dict[str, str]] = {
        "search": "Usage: /search <query>",
        "verify": "Usage: /verify <claim to check>",
        "correct": "Usage: /correct <content to fix>",
    }

    def parse_slash_command(self, text: str) -> tuple[str, str] | None:
        """Parse a /command from user input.

        Returns:
            ``(command, args)`` with the command lower-cased and the
            arguments stripped, or ``None`` when *text* is not a slash command.
        """
        match = _SLASH_COMMAND.match(text.strip())
        if match is None:
            return None
        return match.group(1).lower(), match.group(2).strip()

    def is_verification_query(self, text: str) -> bool:
        """Detect whether user input is asking for fact verification."""
        return bool(_VERIFY_PATTERNS.search(text))

    def is_correction_query(self, text: str) -> bool:
        """Detect whether user input is asking for content correction."""
        return bool(_CORRECT_PATTERNS.search(text))

    async def handle_slash_command(self, command: str, args: str) -> SlashCommandResult:
        """Execute a slash command against the KV API.

        Usage hints (missing argument) and the unknown-command help text are
        produced without opening a KV client, since they need no API call.

        Args:
            command: Command name without the leading slash; ``fix`` is an
                alias of ``correct``.
            args: Argument text following the command (ignored by ``status``).

        Returns:
            A ``SlashCommandResult`` holding the markdown to display and,
            where an API call was made, the raw response data.

        Raises:
            Whatever the KV client raises for search/verify/correct; only the
            ``status`` command converts failures into an "Unavailable" result.
        """
        canonical = "correct" if command == "fix" else command

        if canonical not in self._USAGE and canonical != "status":
            return SlashCommandResult(
                command=command,
                content=(
                    f"Unknown command: `/{command}`\n\n"
                    "Available commands:\n"
                    "- `/search <query>` — Semantic search\n"
                    "- `/verify <claim>` — Fact verification\n"
                    "- `/correct <text>` — Content correction\n"
                    "- `/status` — KV API health"
                ),
            )
        if canonical in self._USAGE and not args:
            return SlashCommandResult(command=canonical, content=self._USAGE[canonical])

        async with KVClient(base_url=self.kv_api_url) as kv:
            if canonical == "search":
                results = await kv.search(args)
                return SlashCommandResult(
                    command="search",
                    content=_format_search_results(results, args),
                    raw_data={"results": results},
                )
            if canonical == "verify":
                result = await kv.validate(args)
                return SlashCommandResult(
                    command="verify",
                    content=_format_validation_result(result),
                    raw_data=result,
                )
            if canonical == "correct":
                result = await kv.correct(args, use_reasoning=True)
                return SlashCommandResult(
                    command="correct",
                    content=_format_correction_result(result),
                    raw_data=result,
                )
            # Only "status" remains. This is a reporting boundary: any failure
            # is shown to the user as "Unavailable" rather than raised.
            try:
                health = await kv.health()
                llm_health = await kv.llm_health()
                return SlashCommandResult(
                    command="status",
                    content=_format_status(health, llm_health),
                    raw_data={"health": health, "llm_health": llm_health},
                )
            except Exception as exc:
                return SlashCommandResult(
                    command="status",
                    content=f"**KV API Status**: Unavailable\n\nError: {exc}",
                )

    async def get_context_for_llm(self, query: str) -> str | None:
        """Fetch semantic search context from KV API to inject into LLM conversation.

        Returns a context block string, or None if no results found.
        Raises on KV API failure — fail fast, no silent degradation.
        """
        async with KVClient(base_url=self.kv_api_url) as kv:
            results = await kv.search(query, limit=3)
        if not results:
            return None
        return _build_context_block(results)

    def _is_local_model(self, model: ModelConfig) -> bool:
        """Check if this model uses a local llama-http endpoint.

        A model is local when its ``api_base`` mentions ``localhost`` and one
        of the ports listed in ``_PORT_TO_SERVICE``.
        """
        if not model.api_base:
            return False
        base = str(model.api_base)
        # Derive the ports from the service map so the two cannot drift apart.
        return "localhost" in base and any(
            f":{port}" in base for port in self._PORT_TO_SERVICE
        )

    def _service_for_model(self, model: ModelConfig) -> str | None:
        """Return the systemd user service name for a local model, or None."""
        if not model.api_base:
            return None
        base = str(model.api_base)
        return next(
            (
                service
                for port, service in self._PORT_TO_SERVICE.items()
                if f":{port}" in base
            ),
            None,
        )

    @staticmethod
    def _local_base_url(model: ModelConfig) -> str:
        """Return ``model.api_base`` without trailing slashes or a ``/v1`` suffix."""
        base_url = str(model.api_base).rstrip("/")
        if base_url.endswith("/v1"):
            base_url = base_url[:-3]
        return base_url

    @staticmethod
    async def _is_healthy(base_url: str) -> bool:
        """Return True when ``GET {base_url}/health`` answers 200 within 2 seconds.

        Any transport-level failure (refused connection, connect/read timeout,
        dropped connection) counts as "not healthy" instead of propagating, so
        a service that is still loading can simply be polled again.
        """
        import httpx

        try:
            async with httpx.AsyncClient(timeout=httpx.Timeout(2.0)) as client:
                resp = await client.get(f"{base_url}/health")
        except httpx.TransportError:
            return False
        return resp.status_code == 200

    async def _ensure_local_service(self, model: ModelConfig) -> None:
        """Start the llama-http systemd service if not running.

        Raises:
            RuntimeError: If no service is mapped for the model's endpoint,
                if ``systemctl --user start`` exits non-zero, or if the
                service is still unhealthy after 30 polling attempts.
        """
        import asyncio
        import logging

        log = logging.getLogger(__name__)
        base_url = self._local_base_url(model)

        # Fast path: check if already healthy
        if await self._is_healthy(base_url):
            return

        # Service is down — start it via systemd
        service = self._service_for_model(model)
        if not service:
            raise RuntimeError(
                f"No systemd service mapped for {model.api_base}. "
                f"Cannot start local inference."
            )
        log.info("llama-http not running, starting %s", service)
        proc = await asyncio.create_subprocess_exec(
            "systemctl", "--user", "start", service,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        _stdout, stderr = await proc.communicate()
        if proc.returncode != 0:
            raise RuntimeError(
                f"Failed to start {service}: {stderr.decode().strip()}"
            )

        # Wait for health (model loading on GPU takes a few seconds)
        for attempt in range(30):
            await asyncio.sleep(1)
            if await self._is_healthy(base_url):
                log.info("%s ready after %ds", service, attempt + 1)
                return
        raise RuntimeError(
            f"{service} started but not healthy after 30s. "
            f"Check: journalctl --user -u {service} -n 30"
        )

    async def stream_llm_response(
        self,
        messages: list[dict[str, Any]],
        model: ModelConfig,
        context: str | None = None,
    ) -> AsyncIterator[str]:
        """Stream an LLM response, optionally with injected context.

        Routes local models through direct httpx SSE (llama-http custom format)
        and cloud models through litellm (standard OpenAI format).

        Args:
            messages: Chat messages; not mutated (context is injected into a copy).
            model: Model configuration selecting the endpoint.
            context: Optional knowledge-base context to inject.

        Yields:
            Text chunks of the response as they arrive.
        """
        if context and messages:
            messages = _inject_context(messages, context)

        if self._is_local_model(model):
            await self._ensure_local_service(model)
            stream = self._stream_local(messages, model)
        else:
            stream = self._stream_litellm(messages, model)
        async for chunk in stream:
            yield chunk

    async def _stream_local(
        self,
        messages: list[dict[str, Any]],
        model: ModelConfig,
    ) -> AsyncIterator[str]:
        """Stream from local llama-http using its custom SSE format.

        llama-http returns: {"type": "chunk", "content": "token"} events
        instead of standard OpenAI delta format.

        Raises:
            httpx.HTTPStatusError: If the endpoint answers with an error status.
            RuntimeError: If the stream carries an ``error`` event.
        """
        import json

        import httpx

        url = f"{self._local_base_url(model)}/v1/chat/completions"
        payload = {
            "messages": messages,
            "temperature": model.temperature,
            "stream": True,
        }
        async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
            async with client.stream("POST", url, json=payload) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    # The space after "data:" is optional in the SSE format.
                    if not line.startswith("data:"):
                        continue
                    data_str = line[5:].strip()
                    if not data_str:
                        continue
                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue
                    # Ignore valid JSON that is not an event object.
                    if not isinstance(data, dict):
                        continue
                    msg_type = data.get("type")
                    content = data.get("content")
                    if msg_type == "chunk" and content:
                        yield content
                    elif msg_type == "done":
                        return
                    elif msg_type == "error":
                        error_msg = data.get("error", "Unknown error")
                        raise RuntimeError(f"llama-http error: {error_msg}")

    async def _stream_litellm(
        self,
        messages: list[dict[str, Any]],
        model: ModelConfig,
    ) -> AsyncIterator[str]:
        """Stream from cloud models via litellm (standard OpenAI format).

        Chunks without text (empty ``choices`` or a non-string delta content)
        are skipped rather than ending the stream, so content that follows
        such a chunk is not lost; the stream ends when litellm's iterator does.
        """
        import litellm
        from litellm import acompletion

        # NOTE: this sets a litellm module-level attribute, not a per-call option.
        litellm.organization = model.organization
        response = await acompletion(
            messages=messages,
            stream=True,
            model=model.name,
            temperature=model.temperature,
            max_retries=model.max_retries,
            api_key=model.api_key.get_secret_value() if model.api_key else None,
            api_base=model.api_base.unicode_string() if model.api_base else None,
        )
        async for chunk in response:
            if not chunk.choices:
                continue
            chunk_content = chunk.choices[0].delta.content
            if isinstance(chunk_content, str):
                yield chunk_content
def _inject_context(messages: list[dict[str, Any]], context: str) -> list[dict[str, Any]]:
"""Inject knowledge base context into the conversation.
Adds context as a prefix to the last user message.
"""
messages = [m.copy() for m in messages]
for i in range(len(messages) - 1, -1, -1):
if messages[i].get("role") == "user":
original = messages[i].get("content", "")
messages[i]["content"] = (
f"[Platform Knowledge Context]\n{context}\n\n"
f"[User Query]\n{original}"
)
break
return messages
def _build_context_block(results: list[dict[str, Any]]) -> str:
"""Build a context block from semantic search results."""
lines = []
for r in results:
title = r.get("title", "")
content = r.get("content", r.get("text", ""))
score = r.get("score", 0)
if title:
lines.append(f"### {title} (relevance: {score:.2f})")
if content:
if len(content) > 500:
content = content[:500] + "..."
lines.append(content)
lines.append("")
return "\n".join(lines)
def _format_search_results(results: list[dict[str, Any]], query: str) -> str:
"""Format search results as markdown."""
if not results:
return f"No results found for: **{query}**"
lines = [f"## Search Results for: {query}\n"]
for i, r in enumerate(results, 1):
title = r.get("title", f"Result {i}")
content = r.get("content", r.get("text", ""))
score = r.get("score", 0)
lines.append(f"### {i}. {title}")
lines.append(f"*Relevance: {score:.2f}*\n")
if content:
if len(content) > 300:
content = content[:300] + "..."
lines.append(content)
lines.append("")
return "\n".join(lines)
def _format_validation_result(result: dict[str, Any]) -> str:
"""Format a validation result as markdown."""
verdict = result.get("verdict", "unknown")
confidence = result.get("confidence", 0)
explanation = result.get("explanation", "")
verdict_emoji = {"verified": "", "contradiction": "", "needs_review": "⚠️"}.get(
verdict, ""
)
lines = [
f"## {verdict_emoji} Verification Result\n",
f"**Verdict**: {verdict.replace('_', ' ').title()}",
f"**Confidence**: {confidence:.0%}\n",
]
if explanation:
lines.append(f"**Explanation**: {explanation}")
corrections = result.get("corrections", [])
if corrections:
lines.append("\n### Suggested Corrections")
for c in corrections:
lines.append(f"- {c}")
return "\n".join(lines)
def _format_correction_result(result: dict[str, Any]) -> str:
"""Format a correction result as markdown."""
corrected = result.get("corrected_content", "")
changes = result.get("changes", [])
reasoning = result.get("reasoning", "")
lines = ["## Correction Result\n"]
if corrected:
lines.append("### Corrected Content")
lines.append(f"```\n{corrected}\n```\n")
if changes:
lines.append("### Changes Made")
for c in changes:
lines.append(f"- {c}")
if reasoning:
lines.append(f"\n**Reasoning**: {reasoning}")
return "\n".join(lines)
def _format_status(health: dict[str, Any], llm_health: dict[str, Any]) -> str:
"""Format KV API and LLM health as markdown."""
api_status = health.get("status", "unknown")
llm_status = llm_health.get("status", "unknown")
api_emoji = "" if api_status == "ok" else ""
llm_emoji = "" if llm_status == "ok" else ""
lines = [
"## Service Status\n",
f"{api_emoji} **KV API**: {api_status}",
f"{llm_emoji} **LLM Service**: {llm_status}",
]
if "version" in health:
lines.append(f"**API Version**: {health['version']}")
if "model" in llm_health:
lines.append(f"**LLM Model**: {llm_health['model']}")
return "\n".join(lines)