450 lines
16 KiB
Python
450 lines
16 KiB
Python
"""Knowledge platform backend — hybrid routing between KV API and LLM.
|
|
|
|
Routes verification/search queries to the KV API and general conversation
|
|
to litellm with optional semantic search context injection.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, AsyncIterator, ClassVar
|
|
|
|
from knowledge_platform.backend.kv_client import KVClient
|
|
from knowledge_platform.config import ModelConfig
|
|
|
|
|
|
# Patterns that indicate a verification-type query
|
|
_VERIFY_PATTERNS = re.compile(
|
|
r"\b(verify|validate|is it true|fact.?check|check if|is this correct|is this accurate)\b",
|
|
re.IGNORECASE,
|
|
)
|
|
_CORRECT_PATTERNS = re.compile(
|
|
r"\b(fix|correct|rewrite|improve accuracy|make accurate)\b",
|
|
re.IGNORECASE,
|
|
)
|
|
_SLASH_COMMAND = re.compile(r"^/(\w+)\s*(.*)", re.DOTALL)
|
|
|
|
|
|
@dataclass
class SlashCommandResult:
    """Outcome of one slash command: the command name, its markdown reply, and raw KV data."""

    # Canonical command name (the backend reports both /correct and /fix as "correct").
    command: str
    # Markdown text to display to the user.
    content: str
    # Unformatted KV API payload behind `content`; left empty for usage/help/error replies.
    raw_data: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
class KnowledgeBackend:
    """Hybrid backend that routes queries to the KV API or an LLM as appropriate.

    - Slash commands (/search, /verify, /correct or /fix, /status) go to the KV API directly
    - Verification queries go to the KV API validate endpoint
    - Correction queries go to the KV API correct endpoint
    - General conversation gets semantic search context, then LLM streaming
    """

    kv_api_url: str = "http://localhost:41233"

    # Map llama-http ports to their systemd user services. This table is also the
    # single source of truth for which localhost ports count as "local" models.
    _PORT_TO_SERVICE: ClassVar[dict[str, str]] = {
        "10010": "llama-http-3b.service",
        "10020": "llama-http-14b.service",
    }

    # ------------------------------------------------------------------
    # Query classification
    # ------------------------------------------------------------------

    def parse_slash_command(self, text: str) -> tuple[str, str] | None:
        """Parse a /command from user input.

        Returns ``(command, args)``, with the command lower-cased and the args
        stripped. Returns None when the stripped text does not start with ``/word``.
        """
        match = _SLASH_COMMAND.match(text.strip())
        if match:
            return match.group(1).lower(), match.group(2).strip()
        return None

    def is_verification_query(self, text: str) -> bool:
        """Detect whether user input is asking for fact verification."""
        return bool(_VERIFY_PATTERNS.search(text))

    def is_correction_query(self, text: str) -> bool:
        """Detect whether user input is asking for content correction."""
        return bool(_CORRECT_PATTERNS.search(text))

    # ------------------------------------------------------------------
    # Slash commands
    # ------------------------------------------------------------------

    async def handle_slash_command(self, command: str, args: str) -> SlashCommandResult:
        """Execute a slash command against the KV API.

        Args:
            command: Command name without the leading slash, as returned by
                ``parse_slash_command``. ``fix`` is an alias for ``correct``.
            args: Argument text; required by search, verify and correct.

        Returns:
            A SlashCommandResult with markdown content. Usage hints and the
            unknown-command help are built without opening a KV client, so
            they work even when the KV API is unreachable.

        Raises:
            For search/verify/correct, exceptions from the KV client propagate.
            /status instead reports failures in its content.
        """
        if command == "status":
            return await self._status_command()

        if command == "search":
            if not args:
                return SlashCommandResult(
                    command="search",
                    content="Usage: /search <query>",
                )
            async with KVClient(base_url=self.kv_api_url) as kv:
                results = await kv.search(args)
            return SlashCommandResult(
                command="search",
                content=_format_search_results(results, args),
                raw_data={"results": results},
            )

        if command == "verify":
            if not args:
                return SlashCommandResult(
                    command="verify",
                    content="Usage: /verify <claim to check>",
                )
            async with KVClient(base_url=self.kv_api_url) as kv:
                result = await kv.validate(args)
            return SlashCommandResult(
                command="verify",
                content=_format_validation_result(result),
                raw_data=result,
            )

        if command in ("correct", "fix"):
            if not args:
                return SlashCommandResult(
                    command="correct",
                    content="Usage: /correct <content to fix>",
                )
            async with KVClient(base_url=self.kv_api_url) as kv:
                result = await kv.correct(args, use_reasoning=True)
            return SlashCommandResult(
                command="correct",
                content=_format_correction_result(result),
                raw_data=result,
            )

        return SlashCommandResult(
            command=command,
            content=(
                f"Unknown command: `/{command}`\n\n"
                "Available commands:\n"
                "- `/search <query>` — Semantic search\n"
                "- `/verify <claim>` — Fact verification\n"
                "- `/correct <text>` — Content correction\n"
                "- `/status` — KV API health"
            ),
        )

    async def _status_command(self) -> SlashCommandResult:
        """Report KV API and LLM health for /status.

        Any ``Exception`` is turned into an "Unavailable" reply instead of
        propagating. This includes a failure while opening or closing the KV
        client itself.
        """
        try:
            async with KVClient(base_url=self.kv_api_url) as kv:
                health = await kv.health()
                llm_health = await kv.llm_health()
            return SlashCommandResult(
                command="status",
                content=_format_status(health, llm_health),
                raw_data={"health": health, "llm_health": llm_health},
            )
        except Exception as exc:  # deliberate best-effort: /status is a diagnostic
            return SlashCommandResult(
                command="status",
                content=f"**KV API Status**: Unavailable\n\nError: {exc}",
            )

    # ------------------------------------------------------------------
    # LLM context
    # ------------------------------------------------------------------

    async def get_context_for_llm(self, query: str) -> str | None:
        """Fetch semantic search context from KV API to inject into LLM conversation.

        Returns a context block string built from the top 3 hits, or None if
        there are no results. Raises on KV API failure: fail fast, with no
        silent degradation.
        """
        async with KVClient(base_url=self.kv_api_url) as kv:
            results = await kv.search(query, limit=3)
        if not results:
            return None
        return _build_context_block(results)

    # ------------------------------------------------------------------
    # Local (llama-http) model management
    # ------------------------------------------------------------------

    @staticmethod
    def _api_base_host_port(model: ModelConfig) -> tuple[str | None, str | None]:
        """Split ``model.api_base`` into (hostname, port as a string).

        Returns (None, None) when api_base is unset. Parsing the URL instead of
        substring matching avoids false positives such as ``:100100`` matching
        ``:10010`` or ``notlocalhost`` matching ``localhost``. A malformed or
        out-of-range port is reported as None.
        """
        from urllib.parse import urlsplit

        if not model.api_base:
            return None, None
        parts = urlsplit(str(model.api_base))
        try:
            port = parts.port
        except ValueError:  # non-numeric or out-of-range port
            port = None
        return parts.hostname, None if port is None else str(port)

    def _is_local_model(self, model: ModelConfig) -> bool:
        """Check if this model uses a local llama-http endpoint.

        True only for ``localhost`` URLs whose port is a key of ``_PORT_TO_SERVICE``.
        """
        host, port = self._api_base_host_port(model)
        return host == "localhost" and port in self._PORT_TO_SERVICE

    def _service_for_model(self, model: ModelConfig) -> str | None:
        """Return the systemd user service name for a local model's port, or None."""
        _, port = self._api_base_host_port(model)
        if port is None:
            return None
        return self._PORT_TO_SERVICE.get(port)

    @staticmethod
    def _local_base_url(model: ModelConfig) -> str:
        """Return the llama-http server root: api_base without a trailing '/' or '/v1'."""
        base_url = str(model.api_base).rstrip("/")
        if base_url.endswith("/v1"):
            base_url = base_url[:-3]
        return base_url

    async def _ensure_local_service(self, model: ModelConfig) -> None:
        """Start the llama-http systemd user service if it is not already healthy.

        Returns once ``GET <base>/health`` answers 200.

        Raises:
            RuntimeError: if no service is mapped for the model's port, if
                ``systemctl --user start`` exits non-zero, or if the server is
                still unhealthy after 30 one-second polls.
        """
        import asyncio
        import logging

        import httpx

        log = logging.getLogger(__name__)
        health_url = f"{self._local_base_url(model)}/health"

        # One client for the fast-path probe and every poll, instead of one per attempt.
        async with httpx.AsyncClient(timeout=httpx.Timeout(2.0)) as client:

            async def healthy() -> bool:
                # Any transport failure (refused, timeout, reset while the model
                # is still loading) just means "not ready yet", so keep polling.
                try:
                    resp = await client.get(health_url)
                except httpx.TransportError:
                    return False
                return resp.status_code == 200

            # Fast path: already running and healthy.
            if await healthy():
                return

            # Service is down — start it via systemd.
            service = self._service_for_model(model)
            if not service:
                raise RuntimeError(
                    f"No systemd service mapped for {model.api_base}. "
                    f"Cannot start local inference."
                )

            log.info("llama-http not running, starting %s", service)
            proc = await asyncio.create_subprocess_exec(
                "systemctl", "--user", "start", service,
                stdout=asyncio.subprocess.DEVNULL,
                stderr=asyncio.subprocess.PIPE,
            )
            _, stderr = await proc.communicate()
            if proc.returncode != 0:
                # errors="replace": a non-UTF-8 message must not mask the real failure.
                raise RuntimeError(
                    f"Failed to start {service}: {stderr.decode(errors='replace').strip()}"
                )

            # Wait for health (model loading on GPU takes a few seconds).
            for attempt in range(30):
                await asyncio.sleep(1)
                if await healthy():
                    log.info("%s ready after %ds", service, attempt + 1)
                    return

        raise RuntimeError(
            f"{service} started but not healthy after 30s. "
            f"Check: journalctl --user -u {service} -n 30"
        )

    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    async def stream_llm_response(
        self,
        messages: list[dict[str, Any]],
        model: ModelConfig,
        context: str | None = None,
    ) -> AsyncIterator[str]:
        """Stream an LLM response, optionally with injected context.

        Routes local models through direct httpx SSE (llama-http custom format),
        starting their systemd service first if needed. Cloud models go through
        litellm (standard OpenAI format).
        """
        if context and messages:
            messages = _inject_context(messages, context)

        if self._is_local_model(model):
            await self._ensure_local_service(model)
            async for chunk in self._stream_local(messages, model):
                yield chunk
        else:
            async for chunk in self._stream_litellm(messages, model):
                yield chunk

    async def _stream_local(
        self,
        messages: list[dict[str, Any]],
        model: ModelConfig,
    ) -> AsyncIterator[str]:
        """Stream from local llama-http using its custom SSE format.

        llama-http emits ``{"type": "chunk", "content": "token"}`` events instead
        of the standard OpenAI delta format. A ``done`` event ends the stream,
        and an ``error`` event raises RuntimeError.
        """
        import json

        import httpx

        url = f"{self._local_base_url(model)}/v1/chat/completions"
        payload = {
            "messages": messages,
            "temperature": model.temperature,
            "stream": True,
        }

        async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client:
            async with client.stream("POST", url, json=payload) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    # Only "data:" fields carry events; SSE makes the space after
                    # the colon optional, so accept both "data:x" and "data: x".
                    if not line.startswith("data:"):
                        continue
                    data_str = line[len("data:"):].strip()
                    if not data_str:
                        continue
                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue  # non-JSON payloads (e.g. a "[DONE]" sentinel)
                    if not isinstance(data, dict):
                        continue  # ignore JSON that is not an event object

                    msg_type = data.get("type")
                    if msg_type == "chunk":
                        content = data.get("content")
                        if content:
                            yield content
                    elif msg_type == "done":
                        return
                    elif msg_type == "error":
                        error_msg = data.get("error", "Unknown error")
                        raise RuntimeError(f"llama-http error: {error_msg}")

    async def _stream_litellm(
        self,
        messages: list[dict[str, Any]],
        model: ModelConfig,
    ) -> AsyncIterator[str]:
        """Stream from cloud models via litellm (standard OpenAI format).

        Yields each non-empty text delta. Chunks with no ``choices`` or with
        non-string ``delta.content`` are skipped rather than ending the stream,
        because role-only first chunks and tool-call deltas carry
        ``content=None``. The generator ends when litellm's stream is exhausted.
        """
        import litellm
        from litellm import acompletion

        # NOTE(review): this sets a process-wide litellm attribute; concurrent
        # requests for models with different organizations can race on it.
        litellm.organization = model.organization

        response = await acompletion(
            messages=messages,
            stream=True,
            model=model.name,
            temperature=model.temperature,
            max_retries=model.max_retries,
            api_key=model.api_key.get_secret_value() if model.api_key else None,
            api_base=model.api_base.unicode_string() if model.api_base else None,
        )

        async for chunk in response:
            choices = getattr(chunk, "choices", None)
            if not choices:
                continue
            delta = getattr(choices[0], "delta", None)
            text = getattr(delta, "content", None)
            if isinstance(text, str) and text:
                yield text
|
|
|
|
|
|
def _inject_context(messages: list[dict[str, Any]], context: str) -> list[dict[str, Any]]:
|
|
"""Inject knowledge base context into the conversation.
|
|
|
|
Adds context as a prefix to the last user message.
|
|
"""
|
|
messages = [m.copy() for m in messages]
|
|
|
|
for i in range(len(messages) - 1, -1, -1):
|
|
if messages[i].get("role") == "user":
|
|
original = messages[i].get("content", "")
|
|
messages[i]["content"] = (
|
|
f"[Platform Knowledge Context]\n{context}\n\n"
|
|
f"[User Query]\n{original}"
|
|
)
|
|
break
|
|
|
|
return messages
|
|
|
|
|
|
def _build_context_block(results: list[dict[str, Any]]) -> str:
|
|
"""Build a context block from semantic search results."""
|
|
lines = []
|
|
for r in results:
|
|
title = r.get("title", "")
|
|
content = r.get("content", r.get("text", ""))
|
|
score = r.get("score", 0)
|
|
if title:
|
|
lines.append(f"### {title} (relevance: {score:.2f})")
|
|
if content:
|
|
if len(content) > 500:
|
|
content = content[:500] + "..."
|
|
lines.append(content)
|
|
lines.append("")
|
|
return "\n".join(lines)
|
|
|
|
|
|
def _format_search_results(results: list[dict[str, Any]], query: str) -> str:
|
|
"""Format search results as markdown."""
|
|
if not results:
|
|
return f"No results found for: **{query}**"
|
|
|
|
lines = [f"## Search Results for: {query}\n"]
|
|
for i, r in enumerate(results, 1):
|
|
title = r.get("title", f"Result {i}")
|
|
content = r.get("content", r.get("text", ""))
|
|
score = r.get("score", 0)
|
|
lines.append(f"### {i}. {title}")
|
|
lines.append(f"*Relevance: {score:.2f}*\n")
|
|
if content:
|
|
if len(content) > 300:
|
|
content = content[:300] + "..."
|
|
lines.append(content)
|
|
lines.append("")
|
|
return "\n".join(lines)
|
|
|
|
|
|
def _format_validation_result(result: dict[str, Any]) -> str:
|
|
"""Format a validation result as markdown."""
|
|
verdict = result.get("verdict", "unknown")
|
|
confidence = result.get("confidence", 0)
|
|
explanation = result.get("explanation", "")
|
|
|
|
verdict_emoji = {"verified": "✅", "contradiction": "❌", "needs_review": "⚠️"}.get(
|
|
verdict, "❓"
|
|
)
|
|
|
|
lines = [
|
|
f"## {verdict_emoji} Verification Result\n",
|
|
f"**Verdict**: {verdict.replace('_', ' ').title()}",
|
|
f"**Confidence**: {confidence:.0%}\n",
|
|
]
|
|
if explanation:
|
|
lines.append(f"**Explanation**: {explanation}")
|
|
|
|
corrections = result.get("corrections", [])
|
|
if corrections:
|
|
lines.append("\n### Suggested Corrections")
|
|
for c in corrections:
|
|
lines.append(f"- {c}")
|
|
|
|
return "\n".join(lines)
|
|
|
|
|
|
def _format_correction_result(result: dict[str, Any]) -> str:
|
|
"""Format a correction result as markdown."""
|
|
corrected = result.get("corrected_content", "")
|
|
changes = result.get("changes", [])
|
|
reasoning = result.get("reasoning", "")
|
|
|
|
lines = ["## Correction Result\n"]
|
|
if corrected:
|
|
lines.append("### Corrected Content")
|
|
lines.append(f"```\n{corrected}\n```\n")
|
|
if changes:
|
|
lines.append("### Changes Made")
|
|
for c in changes:
|
|
lines.append(f"- {c}")
|
|
if reasoning:
|
|
lines.append(f"\n**Reasoning**: {reasoning}")
|
|
|
|
return "\n".join(lines)
|
|
|
|
|
|
def _format_status(health: dict[str, Any], llm_health: dict[str, Any]) -> str:
|
|
"""Format KV API and LLM health as markdown."""
|
|
api_status = health.get("status", "unknown")
|
|
llm_status = llm_health.get("status", "unknown")
|
|
|
|
api_emoji = "✅" if api_status == "ok" else "❌"
|
|
llm_emoji = "✅" if llm_status == "ok" else "❌"
|
|
|
|
lines = [
|
|
"## Service Status\n",
|
|
f"{api_emoji} **KV API**: {api_status}",
|
|
f"{llm_emoji} **LLM Service**: {llm_status}",
|
|
]
|
|
|
|
if "version" in health:
|
|
lines.append(f"**API Version**: {health['version']}")
|
|
if "model" in llm_health:
|
|
lines.append(f"**LLM Model**: {llm_health['model']}")
|
|
|
|
return "\n".join(lines)
|