#!/usr/bin/env bash
# Train — CAPTCHA model training lifecycle management
#
# Launches the training coordinator, which leases GPUs from model-boss and runs
# DDP training on them via torchrun.
# Handles resume from checkpoints, status monitoring, and graceful stop.

# Absolute directory of this script, resolved through any relative invocation.
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# Shared helpers; presumably defines log/ok/warn/err and BOLD/RESET used below.
. "$SCRIPT_DIR/core"

# Repository root is two levels above this script; everything else hangs off
# the ML service package inside it.
TALENT_SCOUT_ROOT=$(cd "$SCRIPT_DIR/../.." && pwd)
ML_SERVICE_ROOT=$TALENT_SCOUT_ROOT/packages/captcha-solver/ml-service
MODELS_DIR=$ML_SERVICE_ROOT/models
TRAINING_HISTORY=$MODELS_DIR/.training-history
VENV_PYTHON=$ML_SERVICE_ROOT/.venv/bin/python
PIDFILE=$ML_SERVICE_ROOT/.training.pid

# Succeed (status 0) when the ML service venv interpreter exists; otherwise
# report how to create it and return 1.
_require_venv() {
  [ -f "$VENV_PYTHON" ] && return 0
  err "ML service venv not found at $ML_SERVICE_ROOT/.venv"
  err "Run: cd $ML_SERVICE_ROOT && python -m venv .venv && .venv/bin/pip install -e ."
  return 1
}

# Print how many CUDA devices torch reports in the venv; print 0 when the
# interpreter or the torch import fails (its stderr is discarded).
_gpu_count() {
  local probe='import torch; print(torch.cuda.device_count())'
  if ! "$VENV_PYTHON" -c "$probe" 2>/dev/null; then
    echo "0"
  fi
}

# Print the path of the highest-numbered epoch checkpoint
# ($MODELS_DIR/svtrv2_<style>.epoch<N>.pt) for a style, or an empty line if
# none exists. Only files whose <N> is purely numeric are considered.
#
# Epochs are compared numerically (epoch10 > epoch9). Parsing `ls` output and
# sorting on a dot-separated field would break on dots in the directory path
# and fall back to a lexical compare.
_find_latest_checkpoint() {
  local style="$1"
  local best_path="" best_epoch=-1
  local path num
  for path in "$MODELS_DIR"/svtrv2_"${style}".epoch*.pt; do
    [ -f "$path" ] || continue   # also skips the literal pattern when nothing matches
    num=${path##*.epoch}
    num=${num%.pt}
    [[ "$num" =~ ^[0-9]+$ ]] || continue
    # 10# forces base 10 so zero-padded epochs (e.g. 08) are not read as octal.
    if (( 10#$num > best_epoch )); then
      best_epoch=$(( 10#$num ))
      best_path=$path
    fi
  done
  echo "$best_path"
}

# Print the `./run train start` usage and option summary through err.
_train_start_usage() {
  err "Usage: ./run train start <style> [options]"
  err ""
  err "Styles: line-strike, classic, color-mesh, gradient-wave, etc."
  err "  Use 'all' to train all styles."
  err ""
  err "Model options:"
  err "  --model-size SIZE       tiny|small|base (default: tiny)"
  err "  --epochs N              Total epochs (default: 100)"
  err "  --batch-size N          Per-GPU batch size (default: 64)"
  err "  --samples-per-phase N   Samples per curriculum phase (default: 200000)"
  err "  --lr RATE               Base learning rate (default: 5e-4)"
  err "  --num-workers N         DataLoader workers (default: 4)"
  err ""
  err "Training recipe:"
  err "  --amp                   Enable automatic mixed precision"
  err "  --scheduler TYPE        cosine|onecycle (default: cosine)"
  err "  --grad-accum-steps N    Gradient accumulation steps (default: 1)"
  err "  --difficulty-range L H  Continuous difficulty range (e.g. 1.0 5.0)"
  err "  --start-phase N         Skip to phase N (1-indexed, for resuming curriculum)"
  err "  --no-resume             Start fresh instead of auto-resuming from a checkpoint"
  err ""
  err "Presets (shorthand for common recipes):"
  err "  --preset paper          SVTRv2 paper recipe: small, AMP, OneCycleLR, batch 1024"
}

# Launch a background training run for one style (or "all").
#
# Arguments:
#   $1  style name, or "all" (usage is printed and 1 returned when missing)
#   $@  remaining options forwarded to train_svtrv2_by_style.py, after
#       expanding `--preset NAME` and stripping the wrapper-only `--no-resume`
# Globals: PIDFILE (read/written); MODELS_DIR, ML_SERVICE_ROOT, VENV_PYTHON (read)
# Returns: 0 once the coordinator has been started in the background, 1 on any
#          precondition failure. Changes the current directory to ML_SERVICE_ROOT.
train_start() {
  _require_venv || return 1

  local style="${1:-}"
  if [ -z "$style" ]; then
    _train_start_usage
    return 1
  fi
  shift

  # Check if training is already running
  if [ -f "$PIDFILE" ] && kill -0 "$(cat "$PIDFILE")" 2>/dev/null; then
    err "Training already running (PID $(cat "$PIDFILE"))"
    err "Use './run train stop' to stop it first, or './run train status' to check progress."
    return 1
  fi

  local num_gpus
  num_gpus=$(_gpu_count)
  # Anything but a bare integer (e.g. stray interpreter output) would make the
  # `-eq` test below error out and silently fall through to a launch.
  if ! [[ "$num_gpus" =~ ^[0-9]+$ ]]; then
    err "Could not determine GPU count (probe printed: '$num_gpus')"
    return 1
  fi
  if [ "$num_gpus" -eq 0 ]; then
    err "No CUDA GPUs detected"
    return 1
  fi

  # Expand --preset NAME into concrete flags before any other arg processing.
  local extra_args=("$@")
  local expanded_args=()
  local has_preset=false
  local accum_steps=1
  local preset_name
  local i=0
  while [ "$i" -lt "${#extra_args[@]}" ]; do
    if [ "${extra_args[$i]}" = "--preset" ]; then
      has_preset=true
      i=$((i + 1))
      preset_name="${extra_args[$i]:-}"
      case "$preset_name" in
        paper)
          # SVTRv2 paper recipe: Small model, AMP, OneCycleLR, effective batch 1024
          log "Applying preset: paper (SVTRv2 paper recipe)"
          expanded_args+=(--model-size small --amp --scheduler onecycle)
          # Effective batch = per-GPU batch (128) * num_gpus * grad_accum.
          # With 2 GPUs: 128 * 2 * 4 = 1024. Integer division hits 0 above
          # 8 GPUs, so clamp to 1.
          accum_steps=$(( 1024 / (128 * num_gpus) ))
          if (( accum_steps < 1 )); then
            accum_steps=1
          fi
          expanded_args+=(--batch-size 128 --grad-accum-steps "$accum_steps")
          expanded_args+=(--lr 3.25e-4 --num-workers 32 --samples-per-phase 2000000)
          ;;
        *)
          err "Unknown preset: $preset_name"
          err "Available presets: paper"
          return 1
          ;;
      esac
    else
      expanded_args+=("${extra_args[$i]}")
    fi
    i=$((i + 1))
  done

  # --no-resume is handled here and is not forwarded (it is not a
  # train_svtrv2_by_style.py flag). Record it and filter it out in one pass.
  local no_resume=false
  local filtered_args=()
  local arg
  for arg in "${expanded_args[@]}"; do
    if [ "$arg" = "--no-resume" ]; then
      no_resume=true
    else
      filtered_args+=("$arg")
    fi
  done

  # Build args
  local args=()
  if [ "$style" = "all" ]; then
    args+=(--online)
  else
    args+=(--styles "$style" --skip-universal --online)
  fi

  # Single-style runs auto-resume from the newest epoch checkpoint unless --no-resume.
  local checkpoint
  if [ "$style" != "all" ] && [ "$no_resume" = false ]; then
    checkpoint=$(_find_latest_checkpoint "$style")
    if [ -n "$checkpoint" ]; then
      log "Auto-resuming from checkpoint: $(basename "$checkpoint")"
      args+=(--resume-from "$checkpoint")
    fi
  fi

  args+=(--dataset-dir ~/.cache/captcha-gen --output-dir "$MODELS_DIR/")
  args+=("${filtered_args[@]}")

  log "Starting DDP training on $num_gpus GPU(s) (model-boss lease coordinator)..."
  log "Style: $style"
  if [ "$has_preset" = true ]; then
    # Computed from the preset's values; a later --batch-size or
    # --grad-accum-steps on the command line is not reflected here.
    log "Effective batch size: $((128 * num_gpus * accum_steps))"
  fi
  log "Command: python train_svtrv2_by_style.py ${args[*]}"

  cd "$ML_SERVICE_ROOT" || { err "Cannot cd to $ML_SERVICE_ROOT"; return 1; }

  # Log file for this training run
  local logdir="$ML_SERVICE_ROOT/logs"
  mkdir -p "$logdir" || { err "Cannot create log directory $logdir"; return 1; }
  local logfile
  logfile="$logdir/train_${style}_$(date +%Y%m%d_%H%M%S).log"

  # Launch the coordinator (outer process) directly — it acquires GPU leases from
  # model-boss, then spawns torchrun internally with CUDA_VISIBLE_DEVICES pinned to
  # the leased GPUs.  Pass --no-gpu-lease explicitly to skip model-boss for direct
  # use (e.g. debugging) or when running inside an already-leased torchrun context.
  PYTHONPATH=src "$VENV_PYTHON" \
    train_svtrv2_by_style.py "${args[@]}" \
    > "$logfile" 2>&1 &

  local pid=$!
  echo "$pid" > "$PIDFILE"
  ok "Training launched (PID $pid)"
  echo ""
  echo "  tail -f $logfile"
  echo "  ./run train status $style"
  echo "  ./run train check $style"
  echo "  ./run train stop"
  echo ""
}

# Stop the tracked training run. Sends SIGTERM, waits up to 30s, then SIGKILLs
# it if still alive. Afterwards, SIGTERMs any remaining processes whose command
# line matches train_svtrv2_by_style.
# Globals: PIDFILE (read, then removed)
# Returns: 0 — best effort; kill failures are ignored.
train_stop() {
  if [ ! -f "$PIDFILE" ]; then
    warn "No training PID file found"
    # Nothing is tracked, so only surface up to five matches for manual cleanup.
    local stray
    stray=$(pgrep -f "train_svtrv2_by_style" 2>/dev/null | head -5)
    if [ -n "$stray" ]; then
      warn "Found training processes: $stray"
      warn "Kill manually with: kill $stray"
    fi
    return 0
  fi

  local target
  target=$(cat "$PIDFILE")
  if ! kill -0 "$target" 2>/dev/null; then
    warn "Training process $target is not running"
  else
    log "Sending SIGTERM to training (PID $target)..."
    kill "$target"
    # Graceful-shutdown window (the coordinator saves an emergency checkpoint).
    local elapsed
    for (( elapsed = 0; elapsed < 30; elapsed++ )); do
      kill -0 "$target" 2>/dev/null || break
      sleep 1
    done
    if kill -0 "$target" 2>/dev/null; then
      warn "Training did not exit after 30s, sending SIGKILL..."
      kill -9 "$target" 2>/dev/null
    fi
    ok "Training stopped"
  fi
  rm -f "$PIDFILE"

  # The coordinator forwards SIGTERM to torchrun. If it died abruptly instead
  # (SIGKILL/crash), its workers are orphaned and keep holding VRAM, so TERM
  # whatever still matches.
  local leftovers
  leftovers=$(pgrep -f "train_svtrv2_by_style" 2>/dev/null)
  if [ -n "$leftovers" ]; then
    log "Killing orphaned training workers: $leftovers"
    # Intentional word-splitting: one PID per word.
    # shellcheck disable=SC2086
    kill -TERM $leftovers 2>/dev/null || true
  fi
}

# Print "epoch,total_epochs,phase,exact_acc_pct" for the last non-blank row of
# a training-history CSV. Columns: timestamp,model_id,epoch,total_epochs,
# phase,train_loss,val_loss,exact_acc,char_acc,lr,...
# exact_acc_pct is exact_acc * 100, formatted with one decimal place.
# Returns 1 when there is no data row (empty file, or only the header line).
_history_last_row() {
  local csv="$1" last
  local _ts _model epoch total phase _tl _vl acc _rest
  last=$(awk 'NF { row = $0 } END { print row }' "$csv")
  IFS=, read -r _ts _model epoch total phase _tl _vl acc _rest <<<"$last"
  [[ "$epoch" =~ ^[0-9]+([.][0-9]+)?$ ]] || return 1
  # Hand the value to awk with -v instead of splicing it into the program text,
  # so a blank or unexpected field cannot become an awk syntax error.
  printf '%s,%s,%s,%s\n' "$epoch" "$total" "$phase" \
    "$(awk -v acc="$acc" 'BEGIN { printf "%.1f", acc * 100 }')"
}

# Show GPU utilisation, whether the tracked run is alive (a stale PID file is
# removed), and the latest history row for one style ($1) or for every style.
train_status() {
  local style="${1:-}"

  # Show GPU status
  log "GPU Status:"
  nvidia-smi --query-gpu=index,name,utilization.gpu,memory.used,memory.total --format=csv,noheader 2>/dev/null | while read -r line; do
    echo "  $line"
  done

  # Show running training
  if [ -f "$PIDFILE" ] && kill -0 "$(cat "$PIDFILE")" 2>/dev/null; then
    ok "Training running (PID $(cat "$PIDFILE"))"
  else
    warn "No training running"
    [ -f "$PIDFILE" ] && rm -f "$PIDFILE"
  fi

  # Show training history for requested style (or all)
  echo ""
  local csv row epoch total_epochs phase acc_pct
  if [ -n "$style" ]; then
    csv="$TRAINING_HISTORY/svtrv2_${style}.csv"
    if [ -f "$csv" ] && row=$(_history_last_row "$csv"); then
      IFS=, read -r epoch total_epochs phase acc_pct <<<"$row"
      log "Style: $style — Epoch $epoch/$total_epochs, Phase $phase/5, Exact acc: ${acc_pct}%"
    else
      warn "No training history for style: $style"
    fi
  else
    local name
    for csv in "$TRAINING_HISTORY"/svtrv2_*.csv; do
      [ -f "$csv" ] || continue
      row=$(_history_last_row "$csv") || continue   # no data rows yet
      name=${csv##*/}
      name=${name%.csv}
      name=${name#svtrv2_}
      IFS=, read -r epoch total_epochs phase acc_pct <<<"$row"
      log "  $name — Epoch $epoch/$total_epochs, Exact acc: ${acc_pct}%"
    done
  fi
}

# Run `ckpt_health.py check` from the ML service root, optionally limited to
# one style ($1). Returns the checker's exit status.
train_check() {
  _require_venv || return 1
  local style="${1:-}"
  local -a check_opts=()
  if [ -n "$style" ]; then
    check_opts=(--style "$style")
  fi
  cd "$ML_SERVICE_ROOT"
  "$VENV_PYTHON" ckpt_health.py --models-dir "$MODELS_DIR" check "${check_opts[@]}"
}

# Run `ckpt_health.py clean` from the ML service root. `--dry-run`, if present
# anywhere in the arguments, is forwarded once; all other arguments are ignored.
# Returns the checker's exit status.
train_clean() {
  _require_venv || return 1
  local -a clean_opts=()
  local opt
  for opt in "$@"; do
    case "$opt" in
      --dry-run) clean_opts=(--dry-run) ;;
    esac
  done
  cd "$ML_SERVICE_ROOT"
  "$VENV_PYTHON" ckpt_health.py --models-dir "$MODELS_DIR" clean "${clean_opts[@]}"
}

# List processes whose command line matches train_svtrv2_by_style and, with
# `--kill`, terminate them: SIGTERM, then SIGKILL for survivors after 2s.
#
# pgrep cannot distinguish a live, tracked run's processes from real orphans.
# When the PID in PIDFILE is still alive, the listing says so and `--kill` is
# refused, rather than killing the active training run.
#
# Arguments: $1 - optional "--kill"
# Globals:   PIDFILE (read)
# Returns:   1 when --kill is refused because the tracked run is alive;
#            0 otherwise (assuming the ok/warn helpers succeed).
train_orphans() {
  local pids
  pids=$(pgrep -f "train_svtrv2_by_style" 2>/dev/null)
  if [ -z "$pids" ]; then
    ok "No orphaned training processes found"
    return 0
  fi

  local tracked=""
  if [ -f "$PIDFILE" ] && kill -0 "$(cat "$PIDFILE")" 2>/dev/null; then
    tracked=$(cat "$PIDFILE")
  fi

  if [ -n "$tracked" ]; then
    warn "Training processes (tracked run PID $tracked is still running; these may belong to it): $pids"
  else
    warn "Orphaned training processes: $pids"
  fi

  if [ "${1:-}" != "--kill" ]; then
    warn "Re-run with --kill to terminate them"
    return 0
  fi

  if [ -n "$tracked" ]; then
    err "Refusing to kill: training PID $tracked is still running."
    err "Use './run train stop' to stop it first."
    return 1
  fi

  # Intentional word-splitting: one PID per word.
  # shellcheck disable=SC2086
  kill -TERM $pids 2>/dev/null || true
  sleep 2
  local survivors
  survivors=$(pgrep -f "train_svtrv2_by_style" 2>/dev/null)
  if [ -n "$survivors" ]; then
    warn "Survivors after SIGTERM, sending SIGKILL: $survivors"
    # shellcheck disable=SC2086
    kill -9 $survivors 2>/dev/null || true
  fi
  ok "Orphaned processes killed"
}

# Print the `./run train` help text to stdout.
train_usage() {
  # The title may carry escape sequences (BOLD/RESET); %b interprets them the
  # same way `echo -e` does.
  printf '%b\n' "${BOLD}CAPTCHA Model Training${RESET}"
  cat <<'USAGE'

Usage: ./run train <command> [args]

Commands:
  start <style> [opts]   Start DDP training (auto-resumes from latest checkpoint)
  stop                   Gracefully stop training (saves emergency checkpoint)
  status [style]         Show GPU status and training progress
  check [style]          Scan model files and checkpoints for NaN/Inf corruption
  clean [--dry-run]      Delete corrupt checkpoints from models/checkpoints/
  orphans [--kill]       List (or kill) orphaned training worker processes

Examples:
  ./run train start color-mesh --preset paper --difficulty-range 1.0 5.0 --epochs 100
  ./run train start color-mesh --model-size small --amp --scheduler onecycle
  ./run train start all --epochs 50
  ./run train status color-mesh
  ./run train stop
  ./run train check color-mesh
  ./run train clean --dry-run
  ./run train orphans --kill

Presets:
  --preset paper    SVTRv2 paper recipe: Small model, AMP, OneCycleLR, batch 1024

USAGE
}

# Dispatch
train_command="${1:-}"
shift || true

case "$train_command" in
  start)   train_start "$@" ;;
  stop)    train_stop "$@" ;;
  status)  train_status "$@" ;;
  check)   train_check "$@" ;;
  clean)   train_clean "$@" ;;
  orphans) train_orphans "$@" ;;
  -h|--help|help|"") train_usage ;;
  *)
    err "Unknown train command: $train_command"
    train_usage
    exit 1
    ;;
esac
