chore(ml-service): 🔧 Optimize style classifier training loop with improved batching and preprocessing
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
f63990102d
commit
5d102e08a3
1 changed files with 86 additions and 2 deletions
|
|
@ -95,6 +95,8 @@ def train_style_classifier_core(
|
|||
bridge: HeartbeatBridge | None = None,
|
||||
dataset_dir: str | None = None,
|
||||
online: bool = False,
|
||||
checkpoint_every: int = 5,
|
||||
resume_from: str | None = None,
|
||||
) -> dict:
|
||||
"""Train a MobileNetV3-Small style classifier.
|
||||
|
||||
|
|
@ -114,6 +116,8 @@ def train_style_classifier_core(
|
|||
output_path: File path to save the checkpoint.
|
||||
bridge: Optional heartbeat bridge for preemption checks.
|
||||
dataset_dir: Root directory of pre-generated dataset. None = on-the-fly.
|
||||
checkpoint_every: Save resumable checkpoint every N epochs.
|
||||
resume_from: Path to a resumable checkpoint to continue training from.
|
||||
|
||||
Returns:
|
||||
Dict with keys: output_path, best_accuracy, epochs_trained.
|
||||
|
|
@ -183,8 +187,44 @@ def train_style_classifier_core(
|
|||
|
||||
best_accuracy = 0.0
|
||||
best_state = None
|
||||
start_epoch = 1
|
||||
|
||||
for epoch in range(1, epochs + 1):
|
||||
# Resume from checkpoint
|
||||
if resume_from is not None:
|
||||
resume_path = Path(resume_from)
|
||||
if not resume_path.is_absolute():
|
||||
resume_path = Path.cwd() / resume_path
|
||||
if resume_path.exists():
|
||||
ckpt = torch.load(resume_path, map_location=device, weights_only=False)
|
||||
model.load_state_dict(ckpt["state_dict"])
|
||||
meta = ckpt.get("metadata", {})
|
||||
if ckpt.get("checkpoint_type") == "resumable":
|
||||
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
|
||||
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
|
||||
best_accuracy = meta.get("best_accuracy", 0.0)
|
||||
start_epoch = meta.get("epoch", 0) + 1
|
||||
if ckpt.get("best_state") is not None:
|
||||
best_state = ckpt["best_state"]
|
||||
# Restore backbone unfreeze state
|
||||
if meta.get("backbone_unfrozen", False):
|
||||
model.unfreeze_last_n_blocks(2)
|
||||
if ckpt.get("rng_states"):
|
||||
rng = ckpt["rng_states"]
|
||||
torch.set_rng_state(rng["torch_cpu"])
|
||||
if torch.cuda.is_available() and rng.get("torch_cuda") is not None:
|
||||
torch.cuda.set_rng_state(rng["torch_cuda"], device=device)
|
||||
logger.info(
|
||||
"Resumed from checkpoint: epoch %d, best_acc=%.1f%%",
|
||||
start_epoch - 1, best_accuracy * 100,
|
||||
)
|
||||
else:
|
||||
logger.info("Loaded weights from checkpoint (fine-tuning mode, curriculum restarts)")
|
||||
else:
|
||||
logger.error("Resume checkpoint not found: %s", resume_path)
|
||||
|
||||
output_path_obj = Path(output_path)
|
||||
|
||||
for epoch in range(start_epoch, epochs + 1):
|
||||
if bridge is not None:
|
||||
bridge.check_or_raise()
|
||||
|
||||
|
|
@ -280,15 +320,55 @@ def train_style_classifier_core(
|
|||
)
|
||||
logger.info(" New best! Per-style: %s", per_style_str)
|
||||
|
||||
# Periodic resumable checkpoint
|
||||
if epoch % checkpoint_every == 0 or epoch == epochs:
|
||||
backbone_unfrozen = epoch > warmup_epochs
|
||||
rng_states = {
|
||||
"torch_cpu": torch.get_rng_state(),
|
||||
}
|
||||
if torch.cuda.is_available():
|
||||
rng_states["torch_cuda"] = torch.cuda.get_rng_state(device=device)
|
||||
ckpt_data = {
|
||||
"state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
"best_state": best_state,
|
||||
"checkpoint_type": "resumable",
|
||||
"num_classes": NUM_STYLES,
|
||||
"rng_states": rng_states,
|
||||
"metadata": {
|
||||
"epoch": epoch,
|
||||
"best_accuracy": best_accuracy,
|
||||
"epochs_total": epochs,
|
||||
"samples_per_style": samples_per_style,
|
||||
"total_samples": samples_per_style * NUM_STYLES,
|
||||
"style_names": STYLE_NAMES,
|
||||
"backbone_unfrozen": backbone_unfrozen,
|
||||
},
|
||||
}
|
||||
ckpt_path = output_path_obj.with_suffix(f".epoch{epoch}.pt")
|
||||
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(ckpt_data, ckpt_path)
|
||||
logger.info(" Resumable checkpoint saved: %s", ckpt_path.name)
|
||||
|
||||
# Keep only 3 most recent epoch checkpoints
|
||||
existing = sorted(
|
||||
output_path_obj.parent.glob(f"{output_path_obj.stem}.epoch*.pt"),
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
)
|
||||
for stale in existing[:-3]:
|
||||
stale.unlink()
|
||||
logger.info(" Cleaned old checkpoint: %s", stale.name)
|
||||
|
||||
# Restore best model and save
|
||||
if best_state is not None:
|
||||
model.load_state_dict(best_state)
|
||||
|
||||
output_path_obj = Path(output_path)
|
||||
output_path_obj.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
checkpoint = {
|
||||
"state_dict": model.state_dict(),
|
||||
"checkpoint_type": "final",
|
||||
"num_classes": NUM_STYLES,
|
||||
"metadata": {
|
||||
"best_accuracy": best_accuracy,
|
||||
|
|
@ -323,6 +403,8 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||
parser.add_argument("--val-split", type=float, default=0.2, help="Validation split ratio")
|
||||
parser.add_argument("--output", type=str, default="models/style_classifier.pt", help="Output path")
|
||||
parser.add_argument("--dataset-dir", type=str, default=settings.dataset_dir, help="Pre-generated dataset directory")
|
||||
parser.add_argument("--checkpoint-every", type=int, default=5, help="Save resumable checkpoint every N epochs")
|
||||
parser.add_argument("--resume-from", type=str, default=None, metavar="CHECKPOINT", help="Resume from checkpoint")
|
||||
parser.add_argument("--no-gpu-lease", action="store_true", help="Run directly on auto-detected device (no GPUBoss lease)")
|
||||
return parser
|
||||
|
||||
|
|
@ -344,6 +426,8 @@ def _run_direct(args: argparse.Namespace) -> None:
|
|||
output_path=args.output,
|
||||
dataset_dir=args.dataset_dir,
|
||||
online=args.online,
|
||||
checkpoint_every=args.checkpoint_every,
|
||||
resume_from=args.resume_from,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue