chore(ml-service): 🔧 Optimize style classifier training loop with improved batching and preprocessing

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
Lilith 2026-02-24 20:08:28 -08:00
parent f63990102d
commit 5d102e08a3

View file

@ -95,6 +95,8 @@ def train_style_classifier_core(
bridge: HeartbeatBridge | None = None,
dataset_dir: str | None = None,
online: bool = False,
checkpoint_every: int = 5,
resume_from: str | None = None,
) -> dict:
"""Train a MobileNetV3-Small style classifier.
@ -114,6 +116,8 @@ def train_style_classifier_core(
output_path: File path to save the checkpoint.
bridge: Optional heartbeat bridge for preemption checks.
dataset_dir: Root directory of pre-generated dataset. None = on-the-fly.
checkpoint_every: Save resumable checkpoint every N epochs.
resume_from: Path to a resumable checkpoint to continue training from.
Returns:
Dict with keys: output_path, best_accuracy, epochs_trained.
@ -183,8 +187,44 @@ def train_style_classifier_core(
best_accuracy = 0.0
best_state = None
start_epoch = 1
for epoch in range(1, epochs + 1):
# Resume from checkpoint
if resume_from is not None:
resume_path = Path(resume_from)
if not resume_path.is_absolute():
resume_path = Path.cwd() / resume_path
if resume_path.exists():
ckpt = torch.load(resume_path, map_location=device, weights_only=False)
model.load_state_dict(ckpt["state_dict"])
meta = ckpt.get("metadata", {})
if ckpt.get("checkpoint_type") == "resumable":
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
best_accuracy = meta.get("best_accuracy", 0.0)
start_epoch = meta.get("epoch", 0) + 1
if ckpt.get("best_state") is not None:
best_state = ckpt["best_state"]
# Restore backbone unfreeze state
if meta.get("backbone_unfrozen", False):
model.unfreeze_last_n_blocks(2)
if ckpt.get("rng_states"):
rng = ckpt["rng_states"]
torch.set_rng_state(rng["torch_cpu"])
if torch.cuda.is_available() and rng.get("torch_cuda") is not None:
torch.cuda.set_rng_state(rng["torch_cuda"], device=device)
logger.info(
"Resumed from checkpoint: epoch %d, best_acc=%.1f%%",
start_epoch - 1, best_accuracy * 100,
)
else:
logger.info("Loaded weights from checkpoint (fine-tuning mode, curriculum restarts)")
else:
logger.error("Resume checkpoint not found: %s", resume_path)
output_path_obj = Path(output_path)
for epoch in range(start_epoch, epochs + 1):
if bridge is not None:
bridge.check_or_raise()
@ -280,15 +320,55 @@ def train_style_classifier_core(
)
logger.info(" New best! Per-style: %s", per_style_str)
# Periodic resumable checkpoint
if epoch % checkpoint_every == 0 or epoch == epochs:
backbone_unfrozen = epoch > warmup_epochs
rng_states = {
"torch_cpu": torch.get_rng_state(),
}
if torch.cuda.is_available():
rng_states["torch_cuda"] = torch.cuda.get_rng_state(device=device)
ckpt_data = {
"state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"best_state": best_state,
"checkpoint_type": "resumable",
"num_classes": NUM_STYLES,
"rng_states": rng_states,
"metadata": {
"epoch": epoch,
"best_accuracy": best_accuracy,
"epochs_total": epochs,
"samples_per_style": samples_per_style,
"total_samples": samples_per_style * NUM_STYLES,
"style_names": STYLE_NAMES,
"backbone_unfrozen": backbone_unfrozen,
},
}
ckpt_path = output_path_obj.with_suffix(f".epoch{epoch}.pt")
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(ckpt_data, ckpt_path)
logger.info(" Resumable checkpoint saved: %s", ckpt_path.name)
# Keep only 3 most recent epoch checkpoints
existing = sorted(
output_path_obj.parent.glob(f"{output_path_obj.stem}.epoch*.pt"),
key=lambda p: p.stat().st_mtime,
)
for stale in existing[:-3]:
stale.unlink()
logger.info(" Cleaned old checkpoint: %s", stale.name)
# Restore best model and save
if best_state is not None:
model.load_state_dict(best_state)
output_path_obj = Path(output_path)
output_path_obj.parent.mkdir(parents=True, exist_ok=True)
checkpoint = {
"state_dict": model.state_dict(),
"checkpoint_type": "final",
"num_classes": NUM_STYLES,
"metadata": {
"best_accuracy": best_accuracy,
@ -323,6 +403,8 @@ def _build_parser() -> argparse.ArgumentParser:
parser.add_argument("--val-split", type=float, default=0.2, help="Validation split ratio")
parser.add_argument("--output", type=str, default="models/style_classifier.pt", help="Output path")
parser.add_argument("--dataset-dir", type=str, default=settings.dataset_dir, help="Pre-generated dataset directory")
parser.add_argument("--checkpoint-every", type=int, default=5, help="Save resumable checkpoint every N epochs")
parser.add_argument("--resume-from", type=str, default=None, metavar="CHECKPOINT", help="Resume from checkpoint")
parser.add_argument("--no-gpu-lease", action="store_true", help="Run directly on auto-detected device (no GPUBoss lease)")
return parser
@ -344,6 +426,8 @@ def _run_direct(args: argparse.Namespace) -> None:
output_path=args.output,
dataset_dir=args.dataset_dir,
online=args.online,
checkpoint_every=args.checkpoint_every,
resume_from=args.resume_from,
)