model-boss/tools/benchmark/runner.py
2026-05-11 00:20:11 -07:00

207 lines
6.5 KiB
Python

"""Generic benchmark runner for model-boss inference API.
Submits inference requests to the model-boss HTTP API at localhost:8210 and
collects structured results suitable for accuracy/latency analysis.
"""
from __future__ import annotations
import asyncio
import base64
import io
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
import httpx
from PIL import Image
logger = logging.getLogger(__name__)
# Base URL of the model-boss HTTP API; the request helpers in this module
# append "/v1/chat/completions" to it.
MODEL_BOSS_URL = "http://localhost:8210"
@dataclass
class SampleResult:
sample_id: int | str
predictions: dict[str, Any] # gate_name -> bool (vlm_celeba) | answer string | exec verdict
ground_truth: dict[str, Any]
correct: int
total: int
inference_time_s: float
error: str | None = None
# Optional bag for suite-specific per-sample data (raw response, judge rubric, etc.)
extra: dict[str, Any] = field(default_factory=dict)
@dataclass
class BenchmarkResult:
    """Aggregated outcome of running one benchmark suite against one model.

    ``samples`` counts every submitted sample (including errored ones), and
    ``sample_results`` keeps the individual records for later inspection.
    """

    model_id: str
    suite_name: str
    timestamp: str
    samples: int
    overall_accuracy: float
    # Accuracy broken down by category: gates for vlm_celeba, subjects for
    # MMLU, datasets for code. The legacy "per_gate" name is retained so
    # older vision.md tables keep parsing.
    per_gate_accuracy: dict[str, float]
    avg_inference_time_s: float
    sample_results: list[SampleResult] = field(default_factory=list)
    # Scalar, suite-specific metrics such as {"gsm8k_acc": 0.62,
    # "mmlu_pro_acc": 0.41}.
    metrics: dict[str, float] = field(default_factory=dict)
def _extract_message_text(data: dict[str, Any]) -> str:
    """Pull the user-visible text out of a chat-completion payload.

    Thinking builds of llama-server (Qwen 3.6 variants in particular) split
    output into ``reasoning_content`` (chain of thought) and ``content``
    (final answer). A generation truncated mid-thought leaves ``content``
    empty. So that a regex parser scanning the whole string still finds the
    final answer marker:

    * both present -> reasoning, a newline, then content
    * only one present -> that one
    * neither present -> the empty string

    Only the first choice is read, and ``None`` values are treated as empty.
    """
    first_message = data["choices"][0]["message"]
    answer = first_message.get("content") or ""
    thinking = first_message.get("reasoning_content") or ""
    if not answer:
        return thinking
    if not thinking:
        return answer
    return f"{thinking}\n{answer}"
async def submit_vlm_request(
    model_id: str,
    image: Image.Image,
    prompt: str,
    *,
    max_tokens: int = 400,
    temperature: float = 0.0,
    timeout: float = 300.0,
) -> str:
    """Submit a VLM chat request through model-boss and return the text response.

    The image is PNG-encoded and inlined as a base64 ``data:`` URL, followed
    by the text prompt, in a single user message posted to
    ``/v1/chat/completions``.

    Args:
        model_id: Value sent as the request's ``model`` field.
        image: PIL image to send. If Pillow cannot write its mode as PNG
            (e.g. CMYK), it is converted to RGB before encoding.
        prompt: Text placed after the image in the user message.
        max_tokens: Generation budget forwarded to the server.
        temperature: Sampling temperature forwarded to the server.
        timeout: HTTP client timeout in seconds. The default equals the
            value previously hard-coded here.

    Returns:
        Text extracted from the response by ``_extract_message_text``.

    Raises:
        httpx.HTTPStatusError: If model-boss answers with a 4xx/5xx status.
    """
    buf = io.BytesIO()
    try:
        image.save(buf, format="PNG")
    except OSError:
        # Pillow's PNG writer rejects modes it cannot store (CMYK, YCbCr, ...)
        # with OSError("cannot write mode ... as PNG"). Re-encode as RGB
        # instead of failing the sample, discarding any partial bytes first.
        logger.debug("Image mode %s not writable as PNG; converting to RGB", image.mode)
        buf = io.BytesIO()
        image.convert("RGB").save(buf, format="PNG")
    img_b64 = base64.b64encode(buf.getvalue()).decode()
    body = {
        "model": model_id,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}},
                    {"type": "text", "text": prompt},
                ],
            }
        ],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.post(f"{MODEL_BOSS_URL}/v1/chat/completions", json=body)
        resp.raise_for_status()
        data = resp.json()
    return _extract_message_text(data)
async def submit_chat_request(
    model_id: str,
    prompt: str,
    *,
    max_tokens: int = 4096,
    temperature: float = 0.0,
    system: str | None = None,
    timeout: float = 600.0,
) -> str:
    """Send a text-only chat completion to model-boss and return its text.

    Shared by the reasoning, code and judge calls. A non-empty ``system``
    string is sent as a leading system message; otherwise only the user turn
    is sent. ``temperature`` defaults to 0 for reproducibility. The
    ``max_tokens`` default of 4096 leaves thinking models room for both their
    chain of thought and a final answer.

    Raises:
        httpx.HTTPStatusError: If model-boss answers with a 4xx/5xx status.
    """
    preamble: list[dict[str, Any]] = (
        [{"role": "system", "content": system}] if system else []
    )
    payload: dict[str, Any] = {
        "model": model_id,
        "messages": [*preamble, {"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    endpoint = f"{MODEL_BOSS_URL}/v1/chat/completions"
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(endpoint, json=payload)
        response.raise_for_status()
        decoded = response.json()
    return _extract_message_text(decoded)
def score_predictions(
    predictions: dict[str, Any],
    ground_truth: dict[str, Any],
) -> tuple[int, int]:
    """Return (correct, total) for a single sample.

    ``total`` is the number of keys in ``ground_truth``. A key counts as
    correct only if ``predictions`` contains it and the predicted value
    compares equal (``==``) to the label. A missing key is always wrong, even
    when its label is ``None``; ``dict.get`` would have let a missing
    prediction silently "match" a ``None`` label. Keys present only in
    ``predictions`` are ignored.
    """
    correct = sum(
        1
        for gate, label in ground_truth.items()
        if gate in predictions and predictions[gate] == label
    )
    return correct, len(ground_truth)
def aggregate_results(
    sample_results: list[SampleResult],
    all_gates: list[str],
    model_id: str,
    suite_name: str,
    *,
    metrics: dict[str, float] | None = None,
) -> BenchmarkResult:
    """Compute aggregate metrics from a list of SampleResults.

    Samples whose ``error`` is not ``None`` are left out of every accuracy
    and timing figure. They are still counted in ``samples`` and kept in
    ``sample_results``.

    Args:
        sample_results: All per-sample results, errored ones included.
        all_gates: Categories to report in ``per_gate_accuracy``. A category
            that no valid sample labels is reported as 0.0.
        model_id: Copied into the result.
        suite_name: Copied into the result.
        metrics: Optional suite-specific scalars. ``None`` becomes ``{}``.

    Returns:
        A BenchmarkResult with the following fields:

        * ``overall_accuracy``: sum of ``correct`` over sum of ``total``
          across valid samples, or 0.0 if that total is 0.
        * ``avg_inference_time_s``: mean over valid samples, or 0.0 if there
          are none.
        * ``timestamp``: the current UTC time in ISO-8601 format.
    """
    valid = [s for s in sample_results if s.error is None]
    total_correct = sum(s.correct for s in valid)
    total_attrs = sum(s.total for s in valid)
    overall_accuracy = total_correct / total_attrs if total_attrs > 0 else 0.0
    avg_time = sum(s.inference_time_s for s in valid) / len(valid) if valid else 0.0
    return BenchmarkResult(
        model_id=model_id,
        suite_name=suite_name,
        timestamp=datetime.now(timezone.utc).isoformat(),
        samples=len(sample_results),
        overall_accuracy=overall_accuracy,
        per_gate_accuracy=_per_gate_accuracy(valid, all_gates),
        avg_inference_time_s=avg_time,
        sample_results=sample_results,
        metrics=metrics or {},
    )


def _per_gate_accuracy(
    valid: list[SampleResult],
    all_gates: list[str],
) -> dict[str, float]:
    """Return per-gate accuracy over ``valid``, computed in one pass.

    A gate's denominator is the number of samples whose ``ground_truth``
    contains it. A hit requires the gate to also be present in
    ``predictions`` with an equal value, so a missing prediction never
    matches a ``None`` label. These figures are recomputed from
    predictions/ground_truth and do not read each sample's
    ``correct``/``total``.
    """
    hits = dict.fromkeys(all_gates, 0)
    seen = dict.fromkeys(all_gates, 0)
    for s in valid:
        for gate, label in s.ground_truth.items():
            if gate not in seen:
                continue  # not a gate the caller asked to report
            seen[gate] += 1
            if gate in s.predictions and s.predictions[gate] == label:
                hits[gate] += 1
    return {g: hits[g] / seen[g] if seen[g] > 0 else 0.0 for g in all_gates}