ml-trainer-lm/tests/test_model.py

216 lines
8.5 KiB
Python

"""Unit tests for model loading and LoRA utilities."""
from __future__ import annotations
import os
from unittest.mock import MagicMock, patch
import pytest
import torch
from ml_trainer_lm.model import apply_lora, load_model_for_training
class MockConfig:
"""Mock configuration object for testing."""
def __init__(
self,
base_model: str = "meta-llama/Llama-2-7b",
quantize: bool = False,
local_rank: int = -1,
target_modules: list[str] | None = None,
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
):
self.base_model = base_model
self.quantize = quantize
self.local_rank = local_rank
self.target_modules = target_modules or ["q_proj", "v_proj"]
self.lora_r = lora_r
self.lora_alpha = lora_alpha
self.lora_dropout = lora_dropout
class TestLoadModelForTraining:
"""Tests for load_model_for_training function."""
@patch("transformers.AutoTokenizer.from_pretrained")
@patch("transformers.AutoConfig.from_pretrained")
@patch("transformers.AutoModelForCausalLM.from_pretrained")
def test_load_model_uses_torch_dtype_kwarg(self, mock_load, mock_config, mock_tokenizer):
"""Verify torch_dtype (not dtype) is used in AutoModelForCausalLM.from_pretrained."""
config = MockConfig(quantize=False)
mock_tokenizer.return_value = MagicMock()
mock_config.return_value = MagicMock(model_type="llama", quantization_config=None)
mock_load.return_value = MagicMock()
load_model_for_training(config)
# Verify torch_dtype kwarg was used (not dtype)
call_kwargs = mock_load.call_args[1]
assert "torch_dtype" in call_kwargs, "torch_dtype kwarg must be present"
assert "dtype" not in call_kwargs, "dtype kwarg should not be used"
assert call_kwargs["torch_dtype"] in (torch.float16, torch.float32)
@patch("transformers.AutoTokenizer.from_pretrained")
@patch("transformers.AutoConfig.from_pretrained")
@patch("transformers.AutoModelForCausalLM.from_pretrained")
@patch("torch.cuda.is_available", return_value=True)
@patch("torch.cuda.is_bf16_supported", return_value=False)
def test_load_model_with_quantization(self, mock_bf16, mock_cuda, mock_load, mock_config, mock_tokenizer):
"""Verify BitsAndBytesConfig applied when quantize=True."""
config = MockConfig(quantize=True)
mock_tokenizer.return_value = MagicMock()
mock_config.return_value = MagicMock(model_type="llama", quantization_config=None)
mock_load.return_value = MagicMock()
load_model_for_training(config)
call_kwargs = mock_load.call_args[1]
assert "quantization_config" in call_kwargs, "quantization_config should be in kwargs when quantize=True"
bnb_config = call_kwargs["quantization_config"]
assert bnb_config.load_in_4bit is True
assert bnb_config.bnb_4bit_quant_type == "nf4"
@patch("transformers.AutoTokenizer.from_pretrained")
@patch("transformers.AutoConfig.from_pretrained")
@patch("transformers.AutoModelForCausalLM.from_pretrained")
@patch("torch.cuda.is_available", return_value=False)
def test_load_model_without_quantization(self, mock_cuda, mock_load, mock_config, mock_tokenizer):
"""Verify fp16 loading when quantize=False."""
config = MockConfig(quantize=False)
mock_tokenizer.return_value = MagicMock()
mock_config.return_value = MagicMock(model_type="llama", quantization_config=None)
mock_load.return_value = MagicMock()
load_model_for_training(config)
call_kwargs = mock_load.call_args[1]
assert "quantization_config" not in call_kwargs or call_kwargs.get("quantization_config") is None
assert call_kwargs["torch_dtype"] == torch.float32
@patch("transformers.AutoTokenizer.from_pretrained")
@patch("transformers.AutoConfig.from_pretrained")
@patch("transformers.Mistral3ForConditionalGeneration.from_pretrained")
def test_load_multimodal_model(self, mock_mistral_load, mock_config, mock_tokenizer):
"""Verify multimodal models (mistral3) use correct loader."""
config = MockConfig(base_model="mistral-ai/Mistral-Large-Instruct-2407")
mock_tokenizer.return_value = MagicMock()
mock_config.return_value = MagicMock(model_type="mistral3", quantization_config=None)
mock_mistral_load.return_value = MagicMock()
load_model_for_training(config)
# Verify Mistral3ForConditionalGeneration was called, not AutoModelForCausalLM
assert mock_mistral_load.called
@patch("transformers.AutoTokenizer.from_pretrained")
@patch("transformers.AutoConfig.from_pretrained")
@patch("transformers.AutoModelForCausalLM.from_pretrained")
def test_load_model_with_local_rank(self, mock_load, mock_config, mock_tokenizer):
"""Verify DDP device_map when LOCAL_RANK is set."""
config = MockConfig(local_rank=0)
mock_tokenizer.return_value = MagicMock()
mock_config.return_value = MagicMock(model_type="llama", quantization_config=None)
mock_load.return_value = MagicMock()
# Set LOCAL_RANK environment variable
os.environ["LOCAL_RANK"] = "1"
try:
load_model_for_training(config)
call_kwargs = mock_load.call_args[1]
assert call_kwargs["device_map"] == {"": 1}
finally:
del os.environ["LOCAL_RANK"]
@patch("transformers.AutoTokenizer.from_pretrained")
@patch("transformers.AutoConfig.from_pretrained")
@patch("transformers.AutoModelForCausalLM.from_pretrained")
def test_load_model_tokenizer_pad_token(self, mock_load, mock_config, mock_tokenizer):
"""Verify tokenizer pad_token set to eos_token if missing."""
config = MockConfig()
mock_tok = MagicMock()
mock_tok.pad_token = None
mock_tok.eos_token = "<|endoftext|>"
mock_tokenizer.return_value = mock_tok
mock_config.return_value = MagicMock(model_type="llama", quantization_config=None)
mock_load.return_value = MagicMock()
load_model_for_training(config)
assert mock_tok.pad_token == "<|endoftext|>"
class TestApplyLora:
"""Tests for apply_lora function."""
@patch("peft.get_peft_model")
@patch("peft.LoraConfig")
def test_apply_lora_basic(self, mock_lora_config, mock_get_peft):
"""Verify LoraConfig created with correct parameters."""
config = MockConfig(lora_r=8, lora_alpha=16, lora_dropout=0.05)
model = MagicMock()
model.config = MagicMock(model_type="llama")
mock_lora_config_instance = MagicMock()
mock_lora_config.return_value = mock_lora_config_instance
mock_get_peft.return_value = MagicMock()
apply_lora(model, config)
# Verify LoraConfig was called with correct params
mock_lora_config.assert_called_once()
call_kwargs = mock_lora_config.call_args[1]
assert call_kwargs["r"] == 8
assert call_kwargs["lora_alpha"] == 16
assert call_kwargs["lora_dropout"] == 0.05
assert call_kwargs["bias"] == "none"
@patch("peft.get_peft_model")
@patch("peft.LoraConfig")
def test_apply_lora_multimodal_targets(self, mock_lora_config, mock_get_peft):
"""Verify LoRA targets scoped to language_model for multimodal models."""
config = MockConfig(target_modules=["q_proj", "v_proj"])
model = MagicMock()
model.config = MagicMock(model_type="mistral3")
mock_lora_config_instance = MagicMock()
mock_lora_config.return_value = mock_lora_config_instance
mock_get_peft.return_value = MagicMock()
apply_lora(model, config)
call_kwargs = mock_lora_config.call_args[1]
target_modules = call_kwargs["target_modules"]
# Should be a regex pattern targeting language_model
assert isinstance(target_modules, str)
assert "language_model" in target_modules
assert "q_proj" in target_modules
assert "v_proj" in target_modules
@patch("peft.get_peft_model")
@patch("peft.LoraConfig")
def test_apply_lora_returns_peft_model(self, mock_lora_config, mock_get_peft):
"""Verify return value is PeftModel from get_peft_model."""
config = MockConfig()
model = MagicMock()
model.config = MagicMock(model_type="llama")
mock_peft_model = MagicMock()
mock_get_peft.return_value = mock_peft_model
result = apply_lora(model, config)
assert result is mock_peft_model
mock_get_peft.assert_called_once()