203 lines
7 KiB
Python
203 lines
7 KiB
Python
"""ONNX model loading, tokenization, and prediction."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import threading
|
|
from datetime import UTC, datetime
|
|
from pathlib import Path
|
|
from typing import NamedTuple
|
|
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
from transformers import AutoTokenizer
|
|
|
|
from content_moderation_feedback.categories import CATEGORIES
|
|
|
|
from config import CM_MODEL_DIR_ENV, MAX_SEQ_LENGTH, MODELS_ROOT
|
|
|
|
# Module-level logger, named after this module's import path (__name__).
logger = logging.getLogger(__name__)


# ── Model state ───────────────────────────────────────────────────────────────


class ModelState(NamedTuple):
    """Snapshot of everything needed to serve predictions.

    Built by ``load_model()``. Hot reload replaces the module-level snapshot
    wholesale rather than mutating fields in place. Field order is part of the
    positional constructor interface.
    """

    # ONNX Runtime session whose first output holds per-category logits.
    session: ort.InferenceSession
    # NOTE(review): holds the object returned by AutoTokenizer.from_pretrained(),
    # i.e. a concrete tokenizer instance rather than the AutoTokenizer factory.
    tokenizer: AutoTokenizer
    # Per-category decision thresholds; categories missing from thresholds.json get 0.5.
    thresholds: dict[str, float]
    # Category names aligned, by position, with the model's output columns.
    categories: tuple[str, ...]
    # Model identifier derived from the model directory name.
    version: str
    # Directory the ONNX model and tokenizer were loaded from.
    model_dir: Path
    # ISO-8601 timestamp (UTC) of when the model was loaded.
    loaded_at: str


# Model currently being served; None until initialize() (or reload()) stores one.
_state: ModelState | None = None

# Guards reads and writes of _state.
# NOTE(review): this is re-entrant (RLock), but no nested acquisition is visible
# in this file; a plain Lock would probably suffice — confirm before changing.
_state_lock = threading.RLock()


# ── Discovery ─────────────────────────────────────────────────────────────────


def _find_latest_model_dir(models_root: Path | None = None) -> Path:
    """Return the onnx/ sub-directory of the most recently modified model.

    A model is any immediate sub-directory ``<root>/<name>/`` whose ``onnx/``
    folder contains ``model_fp16.onnx`` or ``model.onnx``. The candidate whose
    ``onnx/`` directory has the newest mtime wins. Equal mtimes are broken by
    the larger ``<name>``, so the result does not depend on the arbitrary order
    in which the filesystem lists directory entries.

    Args:
        models_root: directory to scan; defaults to ``MODELS_ROOT``.

    Raises:
        RuntimeError: if the root is missing or not a directory, or if it
            contains no usable model.
    """
    root = MODELS_ROOT if models_root is None else models_root
    # is_dir() rather than exists(): a regular file here would otherwise make
    # iterdir() raise NotADirectoryError instead of the RuntimeError below.
    if not root.is_dir():
        raise RuntimeError(f"Models directory not found: {root}")

    candidates: list[tuple[float, str, Path]] = []
    for entry in root.iterdir():
        if not entry.is_dir():
            continue
        onnx_dir = entry / "onnx"
        if not any((onnx_dir / name).exists() for name in ("model_fp16.onnx", "model.onnx")):
            continue
        try:
            mtime = onnx_dir.stat().st_mtime
        except FileNotFoundError:
            # The model was removed between the existence check and stat().
            continue
        candidates.append((mtime, entry.name, onnx_dir))

    if not candidates:
        raise RuntimeError(f"No ONNX models found under {root}/*/onnx/")

    # Key on (mtime, name) only; names are unique within root, so Paths never compare.
    _, _, latest = max(candidates, key=lambda c: (c[0], c[1]))
    return latest


def _select_model_file(model_dir: Path) -> Path:
    """Pick the ONNX weights file to load from *model_dir*.

    The half-precision export (``model_fp16.onnx``) is preferred, with the
    full-precision ``model.onnx`` as fallback. Any other file in the directory,
    such as a legacy q8 export, is never selected.

    Raises:
        RuntimeError: if neither supported file exists in *model_dir*.
    """
    ranked = (model_dir / "model_fp16.onnx", model_dir / "model.onnx")
    chosen = next((candidate for candidate in ranked if candidate.exists()), None)
    if chosen is None:
        raise RuntimeError(
            f"No compatible ONNX model in {model_dir}. "
            "Expected model_fp16.onnx or model.onnx."
        )
    return chosen


# ── Loading ───────────────────────────────────────────────────────────────────


def _resolve_model_dir(model_dir_override: Path | None) -> Path:
    """Return the onnx/ directory to load: override, then env setting, then discovery."""
    if model_dir_override is not None:
        return model_dir_override
    if CM_MODEL_DIR_ENV:
        return Path(CM_MODEL_DIR_ENV)
    return _find_latest_model_dir()


def _execution_providers() -> list[str]:
    """Return ONNX Runtime providers in priority order: CUDA if available, then CPU."""
    providers: list[str] = []
    if "CUDAExecutionProvider" in ort.get_available_providers():
        providers.append("CUDAExecutionProvider")
    providers.append("CPUExecutionProvider")
    return providers


def _label_categories(session: ort.InferenceSession) -> tuple[str, ...]:
    """Return the category names matching the width of the model's first output.

    ONNX Runtime reports a dynamic dimension as a ``str`` (symbolic name) or
    ``None`` instead of an ``int``. Slicing ``CATEGORIES`` with a ``str`` would
    raise ``TypeError``, so a dynamic width falls back to all known categories.
    """
    num_labels = session.get_outputs()[0].shape[1]
    if not isinstance(num_labels, int):
        logger.warning(
            "Model output has a dynamic label dimension (%r); assuming all %d categories",
            num_labels,
            len(CATEGORIES),
        )
        return tuple(CATEGORIES)
    if num_labels > len(CATEGORIES):
        # Logits beyond the known categories are never reported by the predict functions.
        logger.warning(
            "Model emits %d labels but only %d categories are defined; extra outputs ignored",
            num_labels,
            len(CATEGORIES),
        )
    return tuple(CATEGORIES[:num_labels])


def _load_thresholds(model_dir: Path, categories: tuple[str, ...]) -> dict[str, float]:
    """Read ``thresholds.json`` from *model_dir*; any category it omits defaults to 0.5.

    Raises:
        RuntimeError: if the file does not contain a JSON object.
        ValueError / TypeError: if a threshold value is not numeric.
    """
    thresholds: dict[str, float] = {}
    thresholds_path = model_dir / "thresholds.json"
    if thresholds_path.exists():
        with thresholds_path.open(encoding="utf-8") as fh:
            raw = json.load(fh)
        if not isinstance(raw, dict):
            raise RuntimeError(
                f"{thresholds_path} must contain a JSON object mapping category to "
                f"threshold, got {type(raw).__name__}"
            )
        # Coerce to float so downstream comparisons are numeric, as annotated.
        thresholds = {cat: float(value) for cat, value in raw.items()}

    for cat in categories:
        thresholds.setdefault(cat, 0.5)
    return thresholds


def load_model(model_dir_override: Path | None = None) -> ModelState:
    """Load model from disk and return a fresh ModelState.

    The onnx/ directory is chosen in priority order: *model_dir_override*,
    then ``CM_MODEL_DIR_ENV``, then the most recently modified model under
    ``MODELS_ROOT``. The CUDA provider is requested when ONNX Runtime reports
    it available, followed by the CPU provider.

    Args:
        model_dir_override: explicit onnx/ directory; auto-discover if None.

    Returns:
        A new ModelState. Module-level state is not modified here.

    Raises:
        RuntimeError: if no model directory or compatible model file can be
            found, or if thresholds.json is not a JSON object.
    """
    model_dir = _resolve_model_dir(model_dir_override)
    model_path = _select_model_file(model_dir)

    providers = _execution_providers()
    logger.info("Loading ONNX session from %s (providers: %s)", model_path, providers)
    session = ort.InferenceSession(str(model_path), providers=providers)

    logger.info("Loading tokenizer from %s", model_dir)
    tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(str(model_dir))

    categories = _label_categories(session)
    thresholds = _load_thresholds(model_dir, categories)

    # A model stored as "<name>/onnx" is versioned by "<name>".
    version = model_dir.parent.name if model_dir.name == "onnx" else model_dir.name
    loaded_at = datetime.now(UTC).isoformat()

    logger.info(
        "Model %s loaded: %d categories, providers=%s",
        version,
        len(categories),
        session.get_providers(),
    )
    return ModelState(
        session=session,
        tokenizer=tokenizer,
        thresholds=thresholds,
        categories=categories,
        version=version,
        model_dir=model_dir,
        loaded_at=loaded_at,
    )


# ── Module-level init ─────────────────────────────────────────────────────────


def initialize() -> None:
    """Populate the module-level model state; intended to run once at startup.

    Calls ``load_model()`` while holding ``_state_lock`` and then publishes the
    result as the current state. If loading raises, the exception propagates
    and the previously stored state (if any) is left untouched.
    """
    global _state
    with _state_lock:
        loaded = load_model()
        _state = loaded


# Serializes hot reloads so two concurrent reload() calls cannot interleave
# their load-and-swap sequences.
_reload_lock = threading.Lock()


def reload() -> tuple[str, str]:
    """Hot-reload the model from disk. Thread-safe.

    The new model is loaded *without* holding ``_state_lock``. Concurrent
    ``get_state()`` callers therefore keep serving the previous model for the
    whole (potentially slow) load instead of blocking on it. The state lock is
    held only for the swap itself, and concurrent reloads are serialized by
    ``_reload_lock``. If loading fails, the exception propagates and the
    current state is kept.

    Returns:
        ``(prev_version, new_version)``. ``prev_version`` is ``""`` when no
        model had been loaded yet.
    """
    global _state
    with _reload_lock:
        new_state = load_model()
        with _state_lock:
            # Read the previous version at swap time so it reflects the state actually replaced.
            prev_version = _state.version if _state is not None else ""
            _state = new_state
    logger.info("Model reloaded: %s → %s", prev_version, new_state.version)
    return prev_version, new_state.version


def get_state() -> ModelState:
    """Return the model state currently being served.

    Raises:
        RuntimeError: if neither ``initialize()`` nor ``reload()`` has stored a model yet.
    """
    with _state_lock:
        current = _state
        if current is not None:
            return current
    raise RuntimeError("Model not loaded. Call initialize() first.")


# ── Inference ─────────────────────────────────────────────────────────────────


def _run_session(
    session: ort.InferenceSession,
    tokenizer: AutoTokenizer,
    texts: list[str],
) -> np.ndarray:
    """Tokenize *texts* and run the ONNX session.

    Every text is padded or truncated to ``MAX_SEQ_LENGTH`` tokens.

    Returns:
        float32 logits of shape (len(texts), num_labels). The explicit cast
        matters for the fp16 export (``model_fp16.onnx``), whose output tensor
        may itself be float16.
    """
    encoded = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        return_tensors="np",
    )
    # Feed only the tensors the graph declares as inputs (tokenizer outputs the
    # graph does not take are dropped). NOTE(review): int64 is presumably the
    # exported graph's input dtype — confirm against the ONNX model.
    feed: dict[str, np.ndarray] = {
        inp.name: encoded[inp.name].astype(np.int64)
        for inp in session.get_inputs()
        if inp.name in encoded
    }
    logits = session.run(None, feed)[0]
    return np.asarray(logits, dtype=np.float32)


def predict(text: str) -> dict[str, float]:
    """Run inference on a single text.

    Delegates to ``predict_batch`` with a one-element batch so the single and
    batched entry points share one tokenization/sigmoid/rounding path.

    Returns:
        ``{category: probability}``, with probabilities rounded to 4 decimals.

    Raises:
        RuntimeError: if the model has not been loaded.
    """
    return predict_batch([text])[0]


def predict_batch(texts: list[str]) -> list[dict[str, float]]:
    """Run inference on multiple texts.

    Returns:
        One ``{category: probability}`` dict per input text, in input order,
        with probabilities rounded to 4 decimals. An empty *texts* list yields
        an empty list without invoking the tokenizer or the ONNX session.

    Raises:
        RuntimeError: if the model has not been loaded.
    """
    state = get_state()
    if not texts:
        return []

    logits = _run_session(state.session, state.tokenizer, texts)
    # Vectorized sigmoid in float64. In float32, np.exp(-x) overflows (with a
    # RuntimeWarning) once x < about -88.7; in float16, once x < about -11.
    probs = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float64)))
    return [
        {cat: round(float(row[i]), 4) for i, cat in enumerate(state.categories)}
        for row in probs
    ]