From 5d102e08a3e71b64aa072eb5ab4c19261231250b Mon Sep 17 00:00:00 2001 From: Lilith Date: Tue, 24 Feb 2026 20:08:28 -0800 Subject: [PATCH] =?UTF-8?q?chore(ml-service):=20=F0=9F=94=A7=20Optimize=20?= =?UTF-8?q?style=20classifier=20training=20loop=20with=20improved=20batchi?= =?UTF-8?q?ng=20and=20preprocessing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Lilith Autocommit --- .../ml-service/train_style_classifier.py | 88 ++++++++++++++++++- 1 file changed, 86 insertions(+), 2 deletions(-) diff --git a/tools/talent-scout/packages/captcha-solver/ml-service/train_style_classifier.py b/tools/talent-scout/packages/captcha-solver/ml-service/train_style_classifier.py index a3aa4a6e6..d1e5bdeac 100644 --- a/tools/talent-scout/packages/captcha-solver/ml-service/train_style_classifier.py +++ b/tools/talent-scout/packages/captcha-solver/ml-service/train_style_classifier.py @@ -95,6 +95,8 @@ def train_style_classifier_core( bridge: HeartbeatBridge | None = None, dataset_dir: str | None = None, online: bool = False, + checkpoint_every: int = 5, + resume_from: str | None = None, ) -> dict: """Train a MobileNetV3-Small style classifier. @@ -114,6 +116,8 @@ def train_style_classifier_core( output_path: File path to save the checkpoint. bridge: Optional heartbeat bridge for preemption checks. dataset_dir: Root directory of pre-generated dataset. None = on-the-fly. + checkpoint_every: Save resumable checkpoint every N epochs. + resume_from: Path to a resumable checkpoint to continue training from. Returns: Dict with keys: output_path, best_accuracy, epochs_trained. @@ -183,8 +187,44 @@ def train_style_classifier_core( best_accuracy = 0.0 best_state = None + start_epoch = 1 - for epoch in range(1, epochs + 1): + # Resume from checkpoint + if resume_from is not None: + resume_path = Path(resume_from) + if not resume_path.is_absolute(): + resume_path = Path.cwd() / resume_path + if resume_path.exists(): + ckpt = torch.load(resume_path, map_location=device, weights_only=False) + model.load_state_dict(ckpt["state_dict"]) + meta = ckpt.get("metadata", {}) + if ckpt.get("checkpoint_type") == "resumable": + optimizer.load_state_dict(ckpt["optimizer_state_dict"]) + scheduler.load_state_dict(ckpt["scheduler_state_dict"]) + best_accuracy = meta.get("best_accuracy", 0.0) + start_epoch = meta.get("epoch", 0) + 1 + if ckpt.get("best_state") is not None: + best_state = ckpt["best_state"] + # Restore backbone unfreeze state + if meta.get("backbone_unfrozen", False): + model.unfreeze_last_n_blocks(2) + if ckpt.get("rng_states"): + rng = ckpt["rng_states"] + torch.set_rng_state(rng["torch_cpu"]) + if torch.cuda.is_available() and rng.get("torch_cuda") is not None: + torch.cuda.set_rng_state(rng["torch_cuda"], device=device) + logger.info( + "Resumed from checkpoint: epoch %d, best_acc=%.1f%%", + start_epoch - 1, best_accuracy * 100, + ) + else: + logger.info("Loaded weights from checkpoint (fine-tuning mode, curriculum restarts)") + else: + logger.error("Resume checkpoint not found: %s", resume_path) + + output_path_obj = Path(output_path) + + for epoch in range(start_epoch, epochs + 1): if bridge is not None: bridge.check_or_raise() @@ -280,15 +320,55 @@ def train_style_classifier_core( ) logger.info(" New best! Per-style: %s", per_style_str) + # Periodic resumable checkpoint + if epoch % checkpoint_every == 0 or epoch == epochs: + backbone_unfrozen = epoch > warmup_epochs + rng_states = { + "torch_cpu": torch.get_rng_state(), + } + if torch.cuda.is_available(): + rng_states["torch_cuda"] = torch.cuda.get_rng_state(device=device) + ckpt_data = { + "state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "best_state": best_state, + "checkpoint_type": "resumable", + "num_classes": NUM_STYLES, + "rng_states": rng_states, + "metadata": { + "epoch": epoch, + "best_accuracy": best_accuracy, + "epochs_total": epochs, + "samples_per_style": samples_per_style, + "total_samples": samples_per_style * NUM_STYLES, + "style_names": STYLE_NAMES, + "backbone_unfrozen": backbone_unfrozen, + }, + } + ckpt_path = output_path_obj.with_suffix(f".epoch{epoch}.pt") + ckpt_path.parent.mkdir(parents=True, exist_ok=True) + torch.save(ckpt_data, ckpt_path) + logger.info(" Resumable checkpoint saved: %s", ckpt_path.name) + + # Keep only 3 most recent epoch checkpoints + existing = sorted( + output_path_obj.parent.glob(f"{output_path_obj.stem}.epoch*.pt"), + key=lambda p: p.stat().st_mtime, + ) + for stale in existing[:-3]: + stale.unlink() + logger.info(" Cleaned old checkpoint: %s", stale.name) + # Restore best model and save if best_state is not None: model.load_state_dict(best_state) - output_path_obj = Path(output_path) output_path_obj.parent.mkdir(parents=True, exist_ok=True) checkpoint = { "state_dict": model.state_dict(), + "checkpoint_type": "final", "num_classes": NUM_STYLES, "metadata": { "best_accuracy": best_accuracy, @@ -323,6 +403,8 @@ def _build_parser() -> argparse.ArgumentParser: parser.add_argument("--val-split", type=float, default=0.2, help="Validation split ratio") parser.add_argument("--output", type=str, default="models/style_classifier.pt", help="Output path") parser.add_argument("--dataset-dir", type=str, default=settings.dataset_dir, help="Pre-generated dataset directory") + parser.add_argument("--checkpoint-every", type=int, default=5, help="Save resumable checkpoint every N epochs") + parser.add_argument("--resume-from", type=str, default=None, metavar="CHECKPOINT", help="Resume from checkpoint") parser.add_argument("--no-gpu-lease", action="store_true", help="Run directly on auto-detected device (no GPUBoss lease)") return parser @@ -344,6 +426,8 @@ def _run_direct(args: argparse.Namespace) -> None: output_path=args.output, dataset_dir=args.dataset_dir, online=args.online, + checkpoint_every=args.checkpoint_every, + resume_from=args.resume_from, )