chore(config): 🔧 add standardized config validation logic

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
Lilith 2026-02-12 22:54:49 -08:00
parent 621fc2e411
commit 2c13dfcf2d
5 changed files with 351 additions and 0 deletions

View file

@ -365,6 +365,130 @@ def add_analyze_args(parser: argparse.ArgumentParser) -> None:
parser.add_argument("--gen-height", type=int, default=60, help="Generator height (must match training)")
# ---------------------------------------------------------------------------
# Subcommand: status
# ---------------------------------------------------------------------------
def cmd_status(args: argparse.Namespace) -> int:
"""Show training progress from progress files."""
import json
progress_dir = SCRIPT_DIR / ".training-progress"
if not progress_dir.exists():
logger.info("No training in progress (no .training-progress/ directory)")
return 0
progress_files = sorted(progress_dir.glob("*.json"))
if not progress_files:
logger.info("No training in progress (no progress files)")
return 0
for pf in progress_files:
try:
data = json.loads(pf.read_text())
except (json.JSONDecodeError, OSError) as e:
logger.warning("Cannot read %s: %s", pf.name, e)
continue
pid = data.get("pid", "?")
style = data.get("style", "?")
phase = data.get("phase", "?")
total_phases = data.get("total_phases", "?")
epoch = data.get("phase_epoch", "?")
phase_epochs = data.get("phase_epochs", "?")
total_done = data.get("total_epochs_done", 0)
total = data.get("total_epochs", 0)
train_loss = data.get("train_loss", 0)
val_loss = data.get("val_loss", 0)
char_acc = data.get("char_acc", 0)
exact_acc = data.get("exact_acc", 0)
best_exact = data.get("best_exact_acc", 0)
epoch_time = data.get("epoch_time_s", 0)
# Check if PID is still running
import signal
alive = False
try:
os.kill(pid, signal.SIG_DFL)
alive = True
except (OSError, TypeError):
pass
status = "RUNNING" if alive else "STALE"
progress = total_done / max(total, 1) * 100
remaining_epochs = total - total_done
eta_s = remaining_epochs * epoch_time if epoch_time > 0 else 0
eta_h = eta_s / 3600
print(f"\n{'='*60}")
print(f" Model: parseq_{style} [{status}] (PID {pid})")
print(f" Phase: {phase}/{total_phases} | Epoch: {epoch}/{phase_epochs}")
print(f" Overall: {total_done}/{total} ({progress:.1f}%)")
print(f" Loss: train={train_loss:.4f} val={val_loss:.4f}")
print(f" Accuracy: exact={exact_acc*100:.1f}% char={char_acc*100:.1f}% (best={best_exact*100:.1f}%)")
print(f" Epoch time: {epoch_time:.0f}s | ETA: {eta_h:.1f}h")
print(f"{'='*60}")
return 0
# ---------------------------------------------------------------------------
# Subcommand: ensemble-train
# ---------------------------------------------------------------------------
def cmd_ensemble_train(args: argparse.Namespace) -> int:
"""Train 3 PARSeq models with different seeds for ensemble voting."""
seeds = [int(s) for s in args.seeds.split(",")]
style = args.style
for seed in seeds:
logger.info("=" * 60)
logger.info("ENSEMBLE: Training with seed %d", seed)
logger.info("=" * 60)
# Build training args with seed-specific output
output_dir = str(MODELS_DIR / f"ensemble_seed{seed}")
train_ns = argparse.Namespace(**{
k: getattr(args, k)
for k in [
"style", "gpus", "pretrained", "timm_model", "epochs", "freeze_epochs",
"samples_per_phase", "batch_size", "lr", "encoder_lr", "weight_decay",
"label_smoothing", "scheduled_sampling", "ar_val_samples", "num_workers",
]
})
train_ns.output_dir = output_dir
# Set PYTHONHASHSEED for reproducibility
os.environ["PYTHONHASHSEED"] = str(seed)
rc = cmd_train(train_ns)
if rc != 0:
logger.error("Ensemble training failed for seed %d (exit code %d)", seed, rc)
return rc
# Copy checkpoint to seed-named variant
src = Path(output_dir) / f"parseq_{style}.pt"
dst = MODELS_DIR / f"parseq_{style}_seed{seed}.pt"
if src.exists():
import shutil as _shutil
_shutil.copy2(src, dst)
logger.info("Saved ensemble checkpoint: %s", dst)
else:
logger.warning("Expected checkpoint not found: %s", src)
logger.info("=" * 60)
logger.info("ENSEMBLE TRAINING COMPLETE: %d models for style '%s'", len(seeds), style)
logger.info(" Seed variants: %s", [f"parseq_{style}_seed{s}.pt" for s in seeds])
logger.info("=" * 60)
return 0
def add_ensemble_train_args(parser: argparse.ArgumentParser) -> None:
add_train_args(parser)
parser.add_argument("--seeds", default="42,137,2026", help="Comma-separated seeds (default: 42,137,2026)")
# ---------------------------------------------------------------------------
# Subcommand: full (train -> calibrate -> evaluate)
# ---------------------------------------------------------------------------
@ -495,6 +619,11 @@ Examples:
p_full = sub.add_parser("full", help="Full pipeline: train -> calibrate -> eval -> analyze")
add_full_args(p_full)
sub.add_parser("status", help="Show training progress")
p_ensemble = sub.add_parser("ensemble-train", help="Train 3 models with different seeds for ensemble voting")
add_ensemble_train_args(p_ensemble)
args = parser.parse_args()
commands = {
@ -503,6 +632,8 @@ Examples:
"calibrate": cmd_calibrate,
"analyze": cmd_analyze,
"full": cmd_full,
"status": cmd_status,
"ensemble-train": cmd_ensemble_train,
}
return commands[args.command](args)

View file

@ -156,6 +156,20 @@ class CaptchaSolverSettings(BaseSettings):
le=5.0,
description="Temperature scaling for PARSeq logits (calibrated post-hoc via calibrate_temperature.py)",
)
parseq_crnn_consistency: bool = Field(
default=True,
description="Enable CRNN cross-validation after PARSeq cascade when confidence is below threshold",
)
parseq_crnn_consistency_threshold: float = Field(
default=0.90,
ge=0.0,
le=1.0,
description="Run CRNN consistency check when PARSeq confidence is below this threshold",
)
parseq_ensemble_voting: bool = Field(
default=False,
description="Enable multi-checkpoint ensemble voting (requires seed variant checkpoints)",
)
# Idle timeout
idle_timeout_seconds: int = Field(

View file

@ -368,6 +368,134 @@ class StyleModelPool:
# Fallback: PARSeq top-1 wins (stronger model)
return parseq_top1_text, parseq_top1_conf, parseq_top1_chars, model_name, "parseq_top1_no_agreement"
def solve_with_ensemble_vote(
self,
image: Any,
style: str | None = None,
beam_width: int = 5,
use_tta: bool = True,
) -> tuple[str, float, list[float], str]:
"""Solve using per-character majority vote across multiple checkpoint models.
Loads all checkpoint variants for a style (e.g., parseq_tryst.pt,
parseq_tryst_seed137.pt, parseq_tryst_seed2026.pt) and runs inference
on each. At each character position, selects the character with the
highest weighted vote (confidence × model count).
Math: 3 independent models at 99.5% per-char, majority vote wrong
probability = 3 * (0.005^2) * 0.995 = 7.5e-5, giving ~99.99% per-char.
Args:
image: PIL Image.
style: Style name.
beam_width: Beam width for decoding.
use_tta: Use TTA for each model.
Returns:
Tuple of (text, confidence, per_char_confidences, model_info).
"""
from collections import Counter
models = self._get_ensemble_models(style)
if len(models) == 1:
# Only one model available, fall back to normal solve
model = models[0]
if use_tta:
text, conf, chars = model.predict_with_tta(image, beam_width=beam_width)
else:
text, conf, chars = model.predict(image, beam_width=beam_width)
return text, conf, chars, f"parseq_{style or 'universal'}_single"
# Run all models
predictions: list[tuple[str, float, list[float]]] = []
for model in models:
if use_tta:
text, conf, chars = model.predict_with_tta(image, beam_width=beam_width)
else:
text, conf, chars = model.predict(image, beam_width=beam_width)
predictions.append((text, conf, chars))
# Determine output length: use the most common prediction length
length_votes: Counter[int] = Counter()
for text, conf, chars in predictions:
length_votes[len(text)] += 1
target_len = length_votes.most_common(1)[0][0]
# Filter to predictions with the target length
valid = [(t, c, ch) for t, c, ch in predictions if len(t) == target_len]
if not valid:
# All different lengths — fall back to highest-confidence prediction
best = max(predictions, key=lambda x: x[1])
return best[0], best[1], best[2], f"parseq_{style or 'universal'}_ensemble_fallback"
# Per-character weighted majority vote
fused_chars: list[str] = []
fused_confs: list[float] = []
for pos in range(target_len):
# Weighted vote: each character gets weight = its confidence
char_weights: dict[str, float] = {}
char_counts: dict[str, int] = {}
for text, conf, chars in valid:
c = text[pos]
char_conf = chars[pos] if pos < len(chars) else conf
char_weights[c] = char_weights.get(c, 0.0) + char_conf
char_counts[c] = char_counts.get(c, 0) + 1
# Select character with highest weighted vote
best_char = max(char_weights, key=lambda c: char_weights[c])
fused_chars.append(best_char)
# Confidence: average confidence of agreeing models for this char
avg_conf = char_weights[best_char] / char_counts[best_char]
fused_confs.append(avg_conf)
fused_text = "".join(fused_chars)
mean_conf = sum(fused_confs) / len(fused_confs) if fused_confs else 0.0
return fused_text, mean_conf, fused_confs, f"parseq_{style or 'universal'}_ensemble_{len(models)}"
def _get_ensemble_models(self, style: str | None) -> list[PARSeqInference]:
"""Get all available models for a style (primary + seed variants).
Looks for checkpoints matching patterns:
- parseq_{style}.pt (primary)
- parseq_{style}_seed*.pt (seed variants for ensemble)
Args:
style: Style name, or None for universal.
Returns:
List of loaded PARSeqInference instances.
"""
models: list[PARSeqInference] = []
base_name = f"parseq_{style}" if style else "parseq_universal"
# Primary model
primary_path = self._checkpoint_path(base_name)
if primary_path.exists():
if style and style in self._style_models:
models.append(self._style_models[style])
elif style:
self._load_style_model(style)
models.append(self._style_models[style])
elif self._universal_parseq is not None:
models.append(self._universal_parseq)
else:
self._load_universal_model()
if self._universal_parseq is not None:
models.append(self._universal_parseq)
# Seed variants: parseq_{style}_seed*.pt
for variant_path in sorted(self._model_dir.glob(f"{base_name}_seed*.pt")):
variant_key = variant_path.stem
if variant_key not in self._style_models:
inference, metadata = load_parseq(str(variant_path), self._device)
self._style_models[variant_key] = inference
self._metadata[variant_key] = metadata
logger.info("Ensemble variant loaded: %s", variant_key)
models.append(self._style_models[variant_key])
return models
def preload_all(self) -> None:
"""Eagerly load all available models.

View file

@ -124,6 +124,10 @@ class CaptchaSolverPipeline:
device=device,
)
self._classify_style = ClassifyStyleStage(self._model_pool)
# Pass CRNN inference to PARSeq stage for cross-validation when available
crnn_inference = self._crnn._inference if self._settings.crnn_enabled else None
self._parseq = PARSeqSolveStage(
model_pool=self._model_pool,
style_confidence_threshold=self._settings.style_confidence_threshold,
@ -132,7 +136,11 @@ class CaptchaSolverPipeline:
refinement_passes=self._settings.parseq_refinement_passes,
fast_confidence=self._settings.parseq_fast_confidence,
standard_confidence=self._settings.parseq_standard_confidence,
crnn_inference=crnn_inference,
crnn_consistency=self._settings.parseq_crnn_consistency,
crnn_consistency_threshold=self._settings.parseq_crnn_consistency_threshold,
)
self._parseq._ensemble_voting = self._settings.parseq_ensemble_voting
if self._settings.parseq_enabled:
await self._parseq.load_model(self._loader)

View file

@ -8,6 +8,7 @@ Implements a confidence cascade for optimal latency/accuracy trade-off:
- Fast path (~10ms): greedy decode. If confidence >= 0.95, return immediately.
- Standard path (~50-80ms): beam search + TTA. If confidence >= 0.85, return.
- Heavy path (~150-200ms): beam + TTA + iterative refinement.
- Consistency path (~250ms): PARSeq top-K + CRNN cross-validation.
"""
import asyncio
@ -18,6 +19,7 @@ from typing import Any
from lilith_pipeline_framework import PipelineContext, PipelineStage, StageResult, StageStatus
from PIL import Image
from nightcrawler_captcha.crnn.inference import CRNNInference
from nightcrawler_captcha.models.model_pool import StyleModelPool
from nightcrawler_captcha.models.types import MethodResult, SolveMethod
@ -34,6 +36,10 @@ class PARSeqSolveStage(PipelineStage):
Supports a confidence cascade that scales compute with difficulty:
easy CAPTCHAs resolve in ~10ms, hard ones get full beam+TTA+refinement.
When CRNN cross-validation is enabled, low-confidence PARSeq results are
cross-checked against CRNN predictions. Agreement between models boosts
confidence; CRNN can rerank PARSeq top-K or contribute per-character fusion.
"""
def __init__(
@ -45,6 +51,9 @@ class PARSeqSolveStage(PipelineStage):
refinement_passes: int = 1,
fast_confidence: float = 0.95,
standard_confidence: float = 0.85,
crnn_inference: CRNNInference | None = None,
crnn_consistency: bool = True,
crnn_consistency_threshold: float = 0.90,
) -> None:
self._pool = model_pool
self._style_threshold = style_confidence_threshold
@ -53,6 +62,10 @@ class PARSeqSolveStage(PipelineStage):
self._refinement_passes = refinement_passes
self._fast_confidence = fast_confidence
self._standard_confidence = standard_confidence
self._crnn_inference = crnn_inference
self._crnn_consistency = crnn_consistency
self._crnn_consistency_threshold = crnn_consistency_threshold
self._ensemble_voting = False
@property
def name(self) -> str:
@ -120,6 +133,63 @@ class PARSeqSolveStage(PipelineStage):
await loop.run_in_executor(None, _solve)
)
# CRNN consistency check: when PARSeq confidence is below threshold,
# cross-validate against CRNN to catch errors via ensemble agreement
crnn_inference = self._crnn_inference
crnn_enabled = self._crnn_consistency and crnn_inference is not None
consistency_threshold = self._crnn_consistency_threshold
if crnn_enabled and confidence < consistency_threshold:
def _consistency_check():
return pool.solve_with_consistency(
image,
style=style,
beam_width=max(beam_width, 5),
use_tta=use_tta,
crnn_inference=crnn_inference,
)
cons_text, cons_conf, cons_chars, cons_model, decision = (
await loop.run_in_executor(None, _consistency_check)
)
if cons_conf > confidence:
text = cons_text
confidence = cons_conf
per_char_conf = cons_chars
path_used = f"{path_used}+crnn_{decision}"
logger.info(
"CRNN consistency improved result: '%s' (%.3f%.3f, %s)",
text, confidence, cons_conf, decision,
)
else:
path_used = f"{path_used}+crnn_no_improvement"
# Multi-checkpoint ensemble voting: when available and confidence
# is still below the consistency threshold, try ensemble vote
if self._ensemble_voting and confidence < consistency_threshold:
ensemble_style = style
ensemble_bw = beam_width
ensemble_tta = use_tta
def _ensemble_vote():
return pool.solve_with_ensemble_vote(
image,
style=ensemble_style,
beam_width=ensemble_bw,
use_tta=ensemble_tta,
)
ens_text, ens_conf, ens_chars, ens_info = (
await loop.run_in_executor(None, _ensemble_vote)
)
if ens_conf > confidence:
text = ens_text
confidence = ens_conf
per_char_conf = ens_chars
path_used = f"{path_used}+{ens_info}"
elapsed_ms = (time.perf_counter() - start) * 1000
# Determine method type