chore(config): 🔧 add standardized config validation logic
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
621fc2e411
commit
2c13dfcf2d
5 changed files with 351 additions and 0 deletions
|
|
@ -365,6 +365,130 @@ def add_analyze_args(parser: argparse.ArgumentParser) -> None:
|
|||
parser.add_argument("--gen-height", type=int, default=60, help="Generator height (must match training)")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subcommand: status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def cmd_status(args: argparse.Namespace) -> int:
|
||||
"""Show training progress from progress files."""
|
||||
import json
|
||||
|
||||
progress_dir = SCRIPT_DIR / ".training-progress"
|
||||
if not progress_dir.exists():
|
||||
logger.info("No training in progress (no .training-progress/ directory)")
|
||||
return 0
|
||||
|
||||
progress_files = sorted(progress_dir.glob("*.json"))
|
||||
if not progress_files:
|
||||
logger.info("No training in progress (no progress files)")
|
||||
return 0
|
||||
|
||||
for pf in progress_files:
|
||||
try:
|
||||
data = json.loads(pf.read_text())
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning("Cannot read %s: %s", pf.name, e)
|
||||
continue
|
||||
|
||||
pid = data.get("pid", "?")
|
||||
style = data.get("style", "?")
|
||||
phase = data.get("phase", "?")
|
||||
total_phases = data.get("total_phases", "?")
|
||||
epoch = data.get("phase_epoch", "?")
|
||||
phase_epochs = data.get("phase_epochs", "?")
|
||||
total_done = data.get("total_epochs_done", 0)
|
||||
total = data.get("total_epochs", 0)
|
||||
train_loss = data.get("train_loss", 0)
|
||||
val_loss = data.get("val_loss", 0)
|
||||
char_acc = data.get("char_acc", 0)
|
||||
exact_acc = data.get("exact_acc", 0)
|
||||
best_exact = data.get("best_exact_acc", 0)
|
||||
epoch_time = data.get("epoch_time_s", 0)
|
||||
|
||||
# Check if PID is still running
|
||||
import signal
|
||||
alive = False
|
||||
try:
|
||||
os.kill(pid, signal.SIG_DFL)
|
||||
alive = True
|
||||
except (OSError, TypeError):
|
||||
pass
|
||||
|
||||
status = "RUNNING" if alive else "STALE"
|
||||
progress = total_done / max(total, 1) * 100
|
||||
remaining_epochs = total - total_done
|
||||
eta_s = remaining_epochs * epoch_time if epoch_time > 0 else 0
|
||||
eta_h = eta_s / 3600
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f" Model: parseq_{style} [{status}] (PID {pid})")
|
||||
print(f" Phase: {phase}/{total_phases} | Epoch: {epoch}/{phase_epochs}")
|
||||
print(f" Overall: {total_done}/{total} ({progress:.1f}%)")
|
||||
print(f" Loss: train={train_loss:.4f} val={val_loss:.4f}")
|
||||
print(f" Accuracy: exact={exact_acc*100:.1f}% char={char_acc*100:.1f}% (best={best_exact*100:.1f}%)")
|
||||
print(f" Epoch time: {epoch_time:.0f}s | ETA: {eta_h:.1f}h")
|
||||
print(f"{'='*60}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subcommand: ensemble-train
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def cmd_ensemble_train(args: argparse.Namespace) -> int:
|
||||
"""Train 3 PARSeq models with different seeds for ensemble voting."""
|
||||
seeds = [int(s) for s in args.seeds.split(",")]
|
||||
style = args.style
|
||||
|
||||
for seed in seeds:
|
||||
logger.info("=" * 60)
|
||||
logger.info("ENSEMBLE: Training with seed %d", seed)
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Build training args with seed-specific output
|
||||
output_dir = str(MODELS_DIR / f"ensemble_seed{seed}")
|
||||
|
||||
train_ns = argparse.Namespace(**{
|
||||
k: getattr(args, k)
|
||||
for k in [
|
||||
"style", "gpus", "pretrained", "timm_model", "epochs", "freeze_epochs",
|
||||
"samples_per_phase", "batch_size", "lr", "encoder_lr", "weight_decay",
|
||||
"label_smoothing", "scheduled_sampling", "ar_val_samples", "num_workers",
|
||||
]
|
||||
})
|
||||
train_ns.output_dir = output_dir
|
||||
|
||||
# Set PYTHONHASHSEED for reproducibility
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
|
||||
rc = cmd_train(train_ns)
|
||||
if rc != 0:
|
||||
logger.error("Ensemble training failed for seed %d (exit code %d)", seed, rc)
|
||||
return rc
|
||||
|
||||
# Copy checkpoint to seed-named variant
|
||||
src = Path(output_dir) / f"parseq_{style}.pt"
|
||||
dst = MODELS_DIR / f"parseq_{style}_seed{seed}.pt"
|
||||
if src.exists():
|
||||
import shutil as _shutil
|
||||
_shutil.copy2(src, dst)
|
||||
logger.info("Saved ensemble checkpoint: %s", dst)
|
||||
else:
|
||||
logger.warning("Expected checkpoint not found: %s", src)
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("ENSEMBLE TRAINING COMPLETE: %d models for style '%s'", len(seeds), style)
|
||||
logger.info(" Seed variants: %s", [f"parseq_{style}_seed{s}.pt" for s in seeds])
|
||||
logger.info("=" * 60)
|
||||
return 0
|
||||
|
||||
|
||||
def add_ensemble_train_args(parser: argparse.ArgumentParser) -> None:
|
||||
add_train_args(parser)
|
||||
parser.add_argument("--seeds", default="42,137,2026", help="Comma-separated seeds (default: 42,137,2026)")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subcommand: full (train -> calibrate -> evaluate)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -495,6 +619,11 @@ Examples:
|
|||
p_full = sub.add_parser("full", help="Full pipeline: train -> calibrate -> eval -> analyze")
|
||||
add_full_args(p_full)
|
||||
|
||||
sub.add_parser("status", help="Show training progress")
|
||||
|
||||
p_ensemble = sub.add_parser("ensemble-train", help="Train 3 models with different seeds for ensemble voting")
|
||||
add_ensemble_train_args(p_ensemble)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
commands = {
|
||||
|
|
@ -503,6 +632,8 @@ Examples:
|
|||
"calibrate": cmd_calibrate,
|
||||
"analyze": cmd_analyze,
|
||||
"full": cmd_full,
|
||||
"status": cmd_status,
|
||||
"ensemble-train": cmd_ensemble_train,
|
||||
}
|
||||
|
||||
return commands[args.command](args)
|
||||
|
|
|
|||
|
|
@ -156,6 +156,20 @@ class CaptchaSolverSettings(BaseSettings):
|
|||
le=5.0,
|
||||
description="Temperature scaling for PARSeq logits (calibrated post-hoc via calibrate_temperature.py)",
|
||||
)
|
||||
parseq_crnn_consistency: bool = Field(
|
||||
default=True,
|
||||
description="Enable CRNN cross-validation after PARSeq cascade when confidence is below threshold",
|
||||
)
|
||||
parseq_crnn_consistency_threshold: float = Field(
|
||||
default=0.90,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Run CRNN consistency check when PARSeq confidence is below this threshold",
|
||||
)
|
||||
parseq_ensemble_voting: bool = Field(
|
||||
default=False,
|
||||
description="Enable multi-checkpoint ensemble voting (requires seed variant checkpoints)",
|
||||
)
|
||||
|
||||
# Idle timeout
|
||||
idle_timeout_seconds: int = Field(
|
||||
|
|
|
|||
|
|
@ -368,6 +368,134 @@ class StyleModelPool:
|
|||
# Fallback: PARSeq top-1 wins (stronger model)
|
||||
return parseq_top1_text, parseq_top1_conf, parseq_top1_chars, model_name, "parseq_top1_no_agreement"
|
||||
|
||||
def solve_with_ensemble_vote(
|
||||
self,
|
||||
image: Any,
|
||||
style: str | None = None,
|
||||
beam_width: int = 5,
|
||||
use_tta: bool = True,
|
||||
) -> tuple[str, float, list[float], str]:
|
||||
"""Solve using per-character majority vote across multiple checkpoint models.
|
||||
|
||||
Loads all checkpoint variants for a style (e.g., parseq_tryst.pt,
|
||||
parseq_tryst_seed137.pt, parseq_tryst_seed2026.pt) and runs inference
|
||||
on each. At each character position, selects the character with the
|
||||
highest weighted vote (confidence × model count).
|
||||
|
||||
Math: 3 independent models at 99.5% per-char, majority vote wrong
|
||||
probability = 3 * (0.005^2) * 0.995 = 7.5e-5, giving ~99.99% per-char.
|
||||
|
||||
Args:
|
||||
image: PIL Image.
|
||||
style: Style name.
|
||||
beam_width: Beam width for decoding.
|
||||
use_tta: Use TTA for each model.
|
||||
|
||||
Returns:
|
||||
Tuple of (text, confidence, per_char_confidences, model_info).
|
||||
"""
|
||||
from collections import Counter
|
||||
|
||||
models = self._get_ensemble_models(style)
|
||||
if len(models) == 1:
|
||||
# Only one model available, fall back to normal solve
|
||||
model = models[0]
|
||||
if use_tta:
|
||||
text, conf, chars = model.predict_with_tta(image, beam_width=beam_width)
|
||||
else:
|
||||
text, conf, chars = model.predict(image, beam_width=beam_width)
|
||||
return text, conf, chars, f"parseq_{style or 'universal'}_single"
|
||||
|
||||
# Run all models
|
||||
predictions: list[tuple[str, float, list[float]]] = []
|
||||
for model in models:
|
||||
if use_tta:
|
||||
text, conf, chars = model.predict_with_tta(image, beam_width=beam_width)
|
||||
else:
|
||||
text, conf, chars = model.predict(image, beam_width=beam_width)
|
||||
predictions.append((text, conf, chars))
|
||||
|
||||
# Determine output length: use the most common prediction length
|
||||
length_votes: Counter[int] = Counter()
|
||||
for text, conf, chars in predictions:
|
||||
length_votes[len(text)] += 1
|
||||
target_len = length_votes.most_common(1)[0][0]
|
||||
|
||||
# Filter to predictions with the target length
|
||||
valid = [(t, c, ch) for t, c, ch in predictions if len(t) == target_len]
|
||||
if not valid:
|
||||
# All different lengths — fall back to highest-confidence prediction
|
||||
best = max(predictions, key=lambda x: x[1])
|
||||
return best[0], best[1], best[2], f"parseq_{style or 'universal'}_ensemble_fallback"
|
||||
|
||||
# Per-character weighted majority vote
|
||||
fused_chars: list[str] = []
|
||||
fused_confs: list[float] = []
|
||||
|
||||
for pos in range(target_len):
|
||||
# Weighted vote: each character gets weight = its confidence
|
||||
char_weights: dict[str, float] = {}
|
||||
char_counts: dict[str, int] = {}
|
||||
for text, conf, chars in valid:
|
||||
c = text[pos]
|
||||
char_conf = chars[pos] if pos < len(chars) else conf
|
||||
char_weights[c] = char_weights.get(c, 0.0) + char_conf
|
||||
char_counts[c] = char_counts.get(c, 0) + 1
|
||||
|
||||
# Select character with highest weighted vote
|
||||
best_char = max(char_weights, key=lambda c: char_weights[c])
|
||||
fused_chars.append(best_char)
|
||||
# Confidence: average confidence of agreeing models for this char
|
||||
avg_conf = char_weights[best_char] / char_counts[best_char]
|
||||
fused_confs.append(avg_conf)
|
||||
|
||||
fused_text = "".join(fused_chars)
|
||||
mean_conf = sum(fused_confs) / len(fused_confs) if fused_confs else 0.0
|
||||
return fused_text, mean_conf, fused_confs, f"parseq_{style or 'universal'}_ensemble_{len(models)}"
|
||||
|
||||
def _get_ensemble_models(self, style: str | None) -> list[PARSeqInference]:
|
||||
"""Get all available models for a style (primary + seed variants).
|
||||
|
||||
Looks for checkpoints matching patterns:
|
||||
- parseq_{style}.pt (primary)
|
||||
- parseq_{style}_seed*.pt (seed variants for ensemble)
|
||||
|
||||
Args:
|
||||
style: Style name, or None for universal.
|
||||
|
||||
Returns:
|
||||
List of loaded PARSeqInference instances.
|
||||
"""
|
||||
models: list[PARSeqInference] = []
|
||||
base_name = f"parseq_{style}" if style else "parseq_universal"
|
||||
|
||||
# Primary model
|
||||
primary_path = self._checkpoint_path(base_name)
|
||||
if primary_path.exists():
|
||||
if style and style in self._style_models:
|
||||
models.append(self._style_models[style])
|
||||
elif style:
|
||||
self._load_style_model(style)
|
||||
models.append(self._style_models[style])
|
||||
elif self._universal_parseq is not None:
|
||||
models.append(self._universal_parseq)
|
||||
else:
|
||||
self._load_universal_model()
|
||||
if self._universal_parseq is not None:
|
||||
models.append(self._universal_parseq)
|
||||
|
||||
# Seed variants: parseq_{style}_seed*.pt
|
||||
for variant_path in sorted(self._model_dir.glob(f"{base_name}_seed*.pt")):
|
||||
variant_key = variant_path.stem
|
||||
if variant_key not in self._style_models:
|
||||
inference, metadata = load_parseq(str(variant_path), self._device)
|
||||
self._style_models[variant_key] = inference
|
||||
self._metadata[variant_key] = metadata
|
||||
logger.info("Ensemble variant loaded: %s", variant_key)
|
||||
models.append(self._style_models[variant_key])
|
||||
|
||||
return models
|
||||
|
||||
def preload_all(self) -> None:
|
||||
"""Eagerly load all available models.
|
||||
|
||||
|
|
|
|||
|
|
@ -124,6 +124,10 @@ class CaptchaSolverPipeline:
|
|||
device=device,
|
||||
)
|
||||
self._classify_style = ClassifyStyleStage(self._model_pool)
|
||||
|
||||
# Pass CRNN inference to PARSeq stage for cross-validation when available
|
||||
crnn_inference = self._crnn._inference if self._settings.crnn_enabled else None
|
||||
|
||||
self._parseq = PARSeqSolveStage(
|
||||
model_pool=self._model_pool,
|
||||
style_confidence_threshold=self._settings.style_confidence_threshold,
|
||||
|
|
@ -132,7 +136,11 @@ class CaptchaSolverPipeline:
|
|||
refinement_passes=self._settings.parseq_refinement_passes,
|
||||
fast_confidence=self._settings.parseq_fast_confidence,
|
||||
standard_confidence=self._settings.parseq_standard_confidence,
|
||||
crnn_inference=crnn_inference,
|
||||
crnn_consistency=self._settings.parseq_crnn_consistency,
|
||||
crnn_consistency_threshold=self._settings.parseq_crnn_consistency_threshold,
|
||||
)
|
||||
self._parseq._ensemble_voting = self._settings.parseq_ensemble_voting
|
||||
|
||||
if self._settings.parseq_enabled:
|
||||
await self._parseq.load_model(self._loader)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ Implements a confidence cascade for optimal latency/accuracy trade-off:
|
|||
- Fast path (~10ms): greedy decode. If confidence >= 0.95, return immediately.
|
||||
- Standard path (~50-80ms): beam search + TTA. If confidence >= 0.85, return.
|
||||
- Heavy path (~150-200ms): beam + TTA + iterative refinement.
|
||||
- Consistency path (~250ms): PARSeq top-K + CRNN cross-validation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
|
@ -18,6 +19,7 @@ from typing import Any
|
|||
from lilith_pipeline_framework import PipelineContext, PipelineStage, StageResult, StageStatus
|
||||
from PIL import Image
|
||||
|
||||
from nightcrawler_captcha.crnn.inference import CRNNInference
|
||||
from nightcrawler_captcha.models.model_pool import StyleModelPool
|
||||
from nightcrawler_captcha.models.types import MethodResult, SolveMethod
|
||||
|
||||
|
|
@ -34,6 +36,10 @@ class PARSeqSolveStage(PipelineStage):
|
|||
|
||||
Supports a confidence cascade that scales compute with difficulty:
|
||||
easy CAPTCHAs resolve in ~10ms, hard ones get full beam+TTA+refinement.
|
||||
|
||||
When CRNN cross-validation is enabled, low-confidence PARSeq results are
|
||||
cross-checked against CRNN predictions. Agreement between models boosts
|
||||
confidence; CRNN can rerank PARSeq top-K or contribute per-character fusion.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -45,6 +51,9 @@ class PARSeqSolveStage(PipelineStage):
|
|||
refinement_passes: int = 1,
|
||||
fast_confidence: float = 0.95,
|
||||
standard_confidence: float = 0.85,
|
||||
crnn_inference: CRNNInference | None = None,
|
||||
crnn_consistency: bool = True,
|
||||
crnn_consistency_threshold: float = 0.90,
|
||||
) -> None:
|
||||
self._pool = model_pool
|
||||
self._style_threshold = style_confidence_threshold
|
||||
|
|
@ -53,6 +62,10 @@ class PARSeqSolveStage(PipelineStage):
|
|||
self._refinement_passes = refinement_passes
|
||||
self._fast_confidence = fast_confidence
|
||||
self._standard_confidence = standard_confidence
|
||||
self._crnn_inference = crnn_inference
|
||||
self._crnn_consistency = crnn_consistency
|
||||
self._crnn_consistency_threshold = crnn_consistency_threshold
|
||||
self._ensemble_voting = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
|
@ -120,6 +133,63 @@ class PARSeqSolveStage(PipelineStage):
|
|||
await loop.run_in_executor(None, _solve)
|
||||
)
|
||||
|
||||
# CRNN consistency check: when PARSeq confidence is below threshold,
|
||||
# cross-validate against CRNN to catch errors via ensemble agreement
|
||||
crnn_inference = self._crnn_inference
|
||||
crnn_enabled = self._crnn_consistency and crnn_inference is not None
|
||||
consistency_threshold = self._crnn_consistency_threshold
|
||||
|
||||
if crnn_enabled and confidence < consistency_threshold:
|
||||
def _consistency_check():
|
||||
return pool.solve_with_consistency(
|
||||
image,
|
||||
style=style,
|
||||
beam_width=max(beam_width, 5),
|
||||
use_tta=use_tta,
|
||||
crnn_inference=crnn_inference,
|
||||
)
|
||||
|
||||
cons_text, cons_conf, cons_chars, cons_model, decision = (
|
||||
await loop.run_in_executor(None, _consistency_check)
|
||||
)
|
||||
|
||||
if cons_conf > confidence:
|
||||
text = cons_text
|
||||
confidence = cons_conf
|
||||
per_char_conf = cons_chars
|
||||
path_used = f"{path_used}+crnn_{decision}"
|
||||
logger.info(
|
||||
"CRNN consistency improved result: '%s' (%.3f → %.3f, %s)",
|
||||
text, confidence, cons_conf, decision,
|
||||
)
|
||||
else:
|
||||
path_used = f"{path_used}+crnn_no_improvement"
|
||||
|
||||
# Multi-checkpoint ensemble voting: when available and confidence
|
||||
# is still below the consistency threshold, try ensemble vote
|
||||
if self._ensemble_voting and confidence < consistency_threshold:
|
||||
ensemble_style = style
|
||||
ensemble_bw = beam_width
|
||||
ensemble_tta = use_tta
|
||||
|
||||
def _ensemble_vote():
|
||||
return pool.solve_with_ensemble_vote(
|
||||
image,
|
||||
style=ensemble_style,
|
||||
beam_width=ensemble_bw,
|
||||
use_tta=ensemble_tta,
|
||||
)
|
||||
|
||||
ens_text, ens_conf, ens_chars, ens_info = (
|
||||
await loop.run_in_executor(None, _ensemble_vote)
|
||||
)
|
||||
|
||||
if ens_conf > confidence:
|
||||
text = ens_text
|
||||
confidence = ens_conf
|
||||
per_char_conf = ens_chars
|
||||
path_used = f"{path_used}+{ens_info}"
|
||||
|
||||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||
|
||||
# Determine method type
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue