207 lines
6.5 KiB
Python
207 lines
6.5 KiB
Python
"""Generic benchmark runner for model-boss inference API.
|
|
|
|
Submits inference requests to the model-boss HTTP API at localhost:8210 and
|
|
collects structured results suitable for accuracy/latency analysis.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import io
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
import httpx
|
|
from PIL import Image
|
|
|
|
# Base URL of the model-boss HTTP API; the request helpers append
# "/v1/chat/completions" to it.
MODEL_BOSS_URL = "http://localhost:8210"

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class SampleResult:
    """Outcome of running a single benchmark sample.

    ``correct`` and ``total`` are the per-sample tallies that aggregation sums
    across samples. A non-``None`` ``error`` marks the sample as failed.
    Aggregation drops failed samples from accuracy and latency figures.
    """

    sample_id: int | str
    # Keyed by gate/category name. Values are a bool (vlm_celeba), an answer
    # string, or an exec verdict, depending on the suite.
    predictions: dict[str, Any]
    ground_truth: dict[str, Any]
    correct: int
    total: int
    inference_time_s: float
    error: str | None = field(default=None)
    # Free-form, suite-specific per-sample payload (raw response, judge rubric, ...).
    extra: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
class BenchmarkResult:
    """Aggregate outcome of one benchmark suite run against one model."""

    model_id: str
    suite_name: str
    timestamp: str
    samples: int
    overall_accuracy: float
    # Accuracy broken down by category: gates for vlm_celeba, subjects for
    # MMLU, datasets for code suites. The legacy field name is kept so older
    # vision.md tables continue to parse.
    per_gate_accuracy: dict[str, float]
    avg_inference_time_s: float
    sample_results: list[SampleResult] = field(default_factory=list)
    # Extra scalar metrics specific to a suite, e.g. {"gsm8k_acc": 0.62, "mmlu_pro_acc": 0.41}.
    metrics: dict[str, float] = field(default_factory=dict)
|
|
|
|
|
|
def _extract_message_text(data: dict[str, Any]) -> str:
|
|
"""Return user-visible text from a chat completion response.
|
|
|
|
Some llama-server builds (Qwen 3.6 thinking variants in particular) split
|
|
responses into `reasoning_content` (chain-of-thought) and `content` (final
|
|
answer). When the model gets truncated mid-thought, `content` is empty
|
|
while `reasoning_content` holds everything. We concatenate when both are
|
|
present so a regex parser scanning the whole string still finds the final
|
|
answer marker, and fall back to `reasoning_content` when `content` is empty.
|
|
"""
|
|
msg = data["choices"][0]["message"]
|
|
content = msg.get("content") or ""
|
|
reasoning = msg.get("reasoning_content") or ""
|
|
if content and reasoning:
|
|
return f"{reasoning}\n{content}"
|
|
return content or reasoning
|
|
|
|
|
|
async def submit_vlm_request(
    model_id: str,
    image: Image.Image,
    prompt: str,
    *,
    max_tokens: int = 400,
    temperature: float = 0.0,
    timeout: float = 300.0,
) -> str:
    """Submit a single-image VLM chat request through model-boss and return the text response.

    The image is PNG-encoded and inlined as a base64 ``data:`` URL. It is sent
    as an ``image_url`` content part, followed by the text prompt, in a single
    user turn.

    Args:
        model_id: Model identifier sent as the request's ``model`` field.
        image: PIL image to send; it is re-encoded as PNG.
        prompt: Text instruction placed after the image.
        max_tokens: Completion token budget.
        temperature: Sampling temperature. Defaults to 0.0 for reproducibility.
        timeout: HTTP client timeout in seconds. Defaults to 300 s. Exposed as
            a keyword argument for parity with ``submit_chat_request``.

    Returns:
        The user-visible response text, as produced by ``_extract_message_text``.

    Raises:
        httpx.HTTPStatusError: If model-boss responds with a 4xx/5xx status.
        httpx.TimeoutException: If the request exceeds ``timeout``.
    """
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    img_b64 = base64.b64encode(buf.getvalue()).decode()

    body = {
        "model": model_id,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}},
                    {"type": "text", "text": prompt},
                ],
            }
        ],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }

    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.post(f"{MODEL_BOSS_URL}/v1/chat/completions", json=body)
        resp.raise_for_status()
        data = resp.json()
    return _extract_message_text(data)
|
|
|
|
|
|
async def submit_chat_request(
    model_id: str,
    prompt: str,
    *,
    max_tokens: int = 4096,
    temperature: float = 0.0,
    system: str | None = None,
    timeout: float = 600.0,
) -> str:
    """Send a text-only chat completion to model-boss and return the reply text.

    Shared by the reasoning, code, and judge paths. ``temperature`` defaults
    to 0 so runs are reproducible. ``max_tokens`` defaults to 4096 so thinking
    models have room for both their chain of thought and a final answer.

    A non-empty ``system`` string is sent as a leading system turn; otherwise
    only the user turn is sent. Non-2xx responses raise via
    ``raise_for_status``.
    """
    # Only a truthy system prompt produces a system turn, so "" is skipped.
    system_turns: list[dict[str, Any]] = [{"role": "system", "content": system}] if system else []
    payload: dict[str, Any] = {
        "model": model_id,
        "messages": [*system_turns, {"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    endpoint = f"{MODEL_BOSS_URL}/v1/chat/completions"

    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(endpoint, json=payload)
        response.raise_for_status()
        parsed = response.json()
    return _extract_message_text(parsed)
|
|
|
|
|
|
def score_predictions(
    predictions: dict[str, Any],
    ground_truth: dict[str, Any],
) -> tuple[int, int]:
    """Return ``(correct, total)`` for a single sample.

    ``total`` is the number of keys in ``ground_truth``. A key counts as
    correct only when ``predictions`` contains it and
    ``predictions[k] == ground_truth[k]``. Missing keys count as wrong, even
    when the ground-truth label is ``None``. Extra keys in ``predictions`` are
    ignored.

    Args:
        predictions: Model outputs keyed by gate/category name.
        ground_truth: Reference labels keyed by the same names.

    Returns:
        Tuple of (number of matching keys, number of ground-truth keys).
    """
    total = len(ground_truth)
    # Test membership explicitly: ``predictions.get(k)`` would return None
    # for a missing key and wrongly match a None label.
    correct = sum(
        1
        for gate, label in ground_truth.items()
        if gate in predictions and predictions[gate] == label
    )
    return correct, total
|
|
|
|
|
|
def aggregate_results(
    sample_results: list[SampleResult],
    all_gates: list[str],
    model_id: str,
    suite_name: str,
    *,
    metrics: dict[str, float] | None = None,
) -> BenchmarkResult:
    """Compute aggregate metrics from a list of SampleResults.

    Samples whose ``error`` is not ``None`` are excluded from accuracy and
    latency figures, but they are still counted in ``BenchmarkResult.samples``.
    When no sample succeeded, every accuracy and the average latency are 0.0.

    Args:
        sample_results: All per-sample results, successful or not.
        all_gates: Category names to report in ``per_gate_accuracy``. A gate
            that no successful sample labels gets 0.0.
        model_id: Identifier of the evaluated model.
        suite_name: Name of the benchmark suite.
        metrics: Optional suite-specific scalar metrics. The dict is copied
            into the result.

    Returns:
        A BenchmarkResult timestamped with the current UTC time.
    """
    valid = [s for s in sample_results if s.error is None]

    total_correct = sum(s.correct for s in valid)
    total_attrs = sum(s.total for s in valid)
    overall_accuracy = total_correct / total_attrs if total_attrs > 0 else 0.0
    # Latency is averaged over successful samples only.
    avg_time = sum(s.inference_time_s for s in valid) / len(valid) if valid else 0.0

    per_gate: dict[str, float] = {}
    for gate in all_gates:
        # Only samples that carry a label for this gate count toward its denominator.
        labelled = [s for s in valid if gate in s.ground_truth]
        hits = sum(
            1
            for s in labelled
            # A missing prediction is wrong even when the label itself is None.
            if gate in s.predictions and s.predictions[gate] == s.ground_truth[gate]
        )
        per_gate[gate] = hits / len(labelled) if labelled else 0.0

    return BenchmarkResult(
        model_id=model_id,
        suite_name=suite_name,
        timestamp=datetime.now(timezone.utc).isoformat(),
        samples=len(sample_results),
        overall_accuracy=overall_accuracy,
        per_gate_accuracy=per_gate,
        avg_inference_time_s=avg_time,
        sample_results=sample_results,
        # Copy so later mutation of the caller's dict does not leak into the result.
        metrics=dict(metrics) if metrics else {},
    )
|