216 lines
8.5 KiB
Python
216 lines
8.5 KiB
Python
"""Unit tests for model loading and LoRA utilities."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from ml_trainer_lm.model import apply_lora, load_model_for_training
|
|
|
|
|
|
class MockConfig:
    """Mock configuration object for testing.

    Exposes, as plain attributes, the settings these tests pass to
    ``load_model_for_training`` and ``apply_lora``.
    """

    def __init__(
        self,
        base_model: str = "meta-llama/Llama-2-7b",
        quantize: bool = False,
        local_rank: int = -1,
        target_modules: list[str] | None = None,
        lora_r: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.05,
    ) -> None:
        """Store the given settings as attributes of the same name.

        Args:
            base_model: Model identifier handed to the (patched) loaders.
            quantize: Whether quantized loading is requested.
            local_rank: Distributed local rank; -1 presumably means
                "not distributed" — confirm against the real config.
            target_modules: LoRA target module names. ``None`` selects
                ``["q_proj", "v_proj"]``. Any explicit sequence, including an
                empty one, is kept as given (copied into a new list).
            lora_r: LoRA rank.
            lora_alpha: LoRA scaling alpha.
            lora_dropout: LoRA dropout probability.
        """
        self.base_model = base_model
        self.quantize = quantize
        self.local_rank = local_rank
        # Test against None rather than truthiness so an explicit [] is honoured.
        # Copy so a test mutating the attribute cannot alter a caller-owned list.
        self.target_modules = (
            ["q_proj", "v_proj"] if target_modules is None else list(target_modules)
        )
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
|
|
|
|
|
|
class TestLoadModelForTraining:
    """Tests for load_model_for_training function.

    Each test patches the Hugging Face ``from_pretrained`` loaders it expects
    the function to use, so no real tokenizer, config or weights are fetched
    as long as the function sticks to those loaders.
    """

    @staticmethod
    def _stub_loaders(mock_load, mock_config, mock_tokenizer, *, model_type="llama", tokenizer=None):
        """Give the patched loaders deterministic return values.

        Args:
            mock_load: Patched model ``from_pretrained`` (causal-LM or multimodal).
            mock_config: Patched ``AutoConfig.from_pretrained``; made to return a
                config with ``model_type`` set and ``quantization_config=None``.
            mock_tokenizer: Patched ``AutoTokenizer.from_pretrained``.
            model_type: Value exposed as ``config.model_type``.
            tokenizer: Tokenizer object to return; a fresh ``MagicMock`` if None.
        """
        mock_tokenizer.return_value = MagicMock() if tokenizer is None else tokenizer
        mock_config.return_value = MagicMock(model_type=model_type, quantization_config=None)
        mock_load.return_value = MagicMock()

    @patch("transformers.AutoTokenizer.from_pretrained")
    @patch("transformers.AutoConfig.from_pretrained")
    @patch("transformers.AutoModelForCausalLM.from_pretrained")
    @patch("torch.cuda.is_available", return_value=False)
    def test_load_model_uses_torch_dtype_kwarg(self, mock_cuda, mock_load, mock_config, mock_tokenizer):
        """Verify torch_dtype (not dtype) is used in AutoModelForCausalLM.from_pretrained.

        CUDA is patched off so the selected dtype cannot vary with the GPU of
        the machine running the test.
        """
        self._stub_loaders(mock_load, mock_config, mock_tokenizer)

        load_model_for_training(MockConfig(quantize=False))

        # Verify torch_dtype kwarg was used (not dtype)
        call_kwargs = mock_load.call_args.kwargs
        assert "torch_dtype" in call_kwargs, "torch_dtype kwarg must be present"
        assert "dtype" not in call_kwargs, "dtype kwarg should not be used"
        assert call_kwargs["torch_dtype"] in (torch.float16, torch.float32)

    @patch("transformers.AutoTokenizer.from_pretrained")
    @patch("transformers.AutoConfig.from_pretrained")
    @patch("transformers.AutoModelForCausalLM.from_pretrained")
    @patch("torch.cuda.is_available", return_value=True)
    @patch("torch.cuda.is_bf16_supported", return_value=False)
    def test_load_model_with_quantization(self, mock_bf16, mock_cuda, mock_load, mock_config, mock_tokenizer):
        """Verify a 4-bit NF4 BitsAndBytesConfig is passed when quantize=True."""
        self._stub_loaders(mock_load, mock_config, mock_tokenizer)

        load_model_for_training(MockConfig(quantize=True))

        call_kwargs = mock_load.call_args.kwargs
        assert "quantization_config" in call_kwargs, "quantization_config should be in kwargs when quantize=True"
        bnb_config = call_kwargs["quantization_config"]
        assert bnb_config.load_in_4bit is True
        assert bnb_config.bnb_4bit_quant_type == "nf4"

    @patch("transformers.AutoTokenizer.from_pretrained")
    @patch("transformers.AutoConfig.from_pretrained")
    @patch("transformers.AutoModelForCausalLM.from_pretrained")
    @patch("torch.cuda.is_available", return_value=False)
    def test_load_model_without_quantization(self, mock_cuda, mock_load, mock_config, mock_tokenizer):
        """Verify no quantization config and float32 weights when quantize=False on CPU."""
        self._stub_loaders(mock_load, mock_config, mock_tokenizer)

        load_model_for_training(MockConfig(quantize=False))

        call_kwargs = mock_load.call_args.kwargs
        assert "quantization_config" not in call_kwargs or call_kwargs.get("quantization_config") is None
        assert call_kwargs["torch_dtype"] == torch.float32

    @patch("transformers.AutoTokenizer.from_pretrained")
    @patch("transformers.AutoConfig.from_pretrained")
    @patch("transformers.AutoModelForCausalLM.from_pretrained")
    @patch("transformers.Mistral3ForConditionalGeneration.from_pretrained")
    def test_load_multimodal_model(self, mock_mistral_load, mock_causal_load, mock_config, mock_tokenizer):
        """Verify multimodal models (mistral3) use the Mistral3 loader, not AutoModelForCausalLM.

        The causal-LM loader is patched as well, so a regression that routes
        mistral3 through it fails the assertion below rather than attempting
        a real model download.
        """
        self._stub_loaders(mock_mistral_load, mock_config, mock_tokenizer, model_type="mistral3")
        mock_causal_load.return_value = MagicMock()

        load_model_for_training(MockConfig(base_model="mistral-ai/Mistral-Large-Instruct-2407"))

        mock_mistral_load.assert_called()
        mock_causal_load.assert_not_called()

    @patch("transformers.AutoTokenizer.from_pretrained")
    @patch("transformers.AutoConfig.from_pretrained")
    @patch("transformers.AutoModelForCausalLM.from_pretrained")
    def test_load_model_with_local_rank(self, mock_load, mock_config, mock_tokenizer):
        """Verify the DDP device_map follows the LOCAL_RANK env var (not config.local_rank)."""
        self._stub_loaders(mock_load, mock_config, mock_tokenizer)

        # patch.dict restores os.environ exactly on exit, so a LOCAL_RANK already
        # set by a distributed launcher survives this test.
        with patch.dict(os.environ, {"LOCAL_RANK": "1"}):
            load_model_for_training(MockConfig(local_rank=0))

        assert mock_load.call_args.kwargs["device_map"] == {"": 1}

    @patch("transformers.AutoTokenizer.from_pretrained")
    @patch("transformers.AutoConfig.from_pretrained")
    @patch("transformers.AutoModelForCausalLM.from_pretrained")
    def test_load_model_tokenizer_pad_token(self, mock_load, mock_config, mock_tokenizer):
        """Verify tokenizer pad_token set to eos_token if missing."""
        mock_tok = MagicMock()
        mock_tok.pad_token = None
        mock_tok.eos_token = "<|endoftext|>"
        self._stub_loaders(mock_load, mock_config, mock_tokenizer, tokenizer=mock_tok)

        load_model_for_training(MockConfig())

        assert mock_tok.pad_token == "<|endoftext|>"
|
|
|
|
|
|
class TestApplyLora:
    """Unit tests covering ``apply_lora``.

    ``LoraConfig`` and ``get_peft_model`` are patched on the ``peft`` module.
    NOTE(review): that only intercepts the calls if ``apply_lora`` looks these
    names up through ``peft`` at call time (e.g. a function-local import); a
    top-level ``from peft import ...`` in the module under test would bypass
    the patch — confirm.
    """

    @staticmethod
    def _fake_model(model_type):
        """Build a MagicMock model whose ``config.model_type`` is *model_type*."""
        fake = MagicMock()
        fake.config = MagicMock(model_type=model_type)
        return fake

    @patch("peft.get_peft_model")
    @patch("peft.LoraConfig")
    def test_apply_lora_basic(self, mock_lora_config, mock_get_peft):
        """LoraConfig is built once with the configured rank, alpha, dropout and bias."""
        settings = MockConfig(lora_r=8, lora_alpha=16, lora_dropout=0.05)
        target = self._fake_model("llama")
        mock_lora_config.return_value = MagicMock()
        mock_get_peft.return_value = MagicMock()

        apply_lora(target, settings)

        mock_lora_config.assert_called_once()
        passed = mock_lora_config.call_args.kwargs
        expected = {"r": 8, "lora_alpha": 16, "lora_dropout": 0.05, "bias": "none"}
        for key, value in expected.items():
            assert passed[key] == value

    @patch("peft.get_peft_model")
    @patch("peft.LoraConfig")
    def test_apply_lora_multimodal_targets(self, mock_lora_config, mock_get_peft):
        """For mistral3 models the targets become a regex scoped to language_model."""
        settings = MockConfig(target_modules=["q_proj", "v_proj"])
        target = self._fake_model("mistral3")
        mock_lora_config.return_value = MagicMock()
        mock_get_peft.return_value = MagicMock()

        apply_lora(target, settings)

        pattern = mock_lora_config.call_args.kwargs["target_modules"]
        # A single string (regex) rather than a list of module names.
        assert isinstance(pattern, str)
        for fragment in ("language_model", "q_proj", "v_proj"):
            assert fragment in pattern

    @patch("peft.get_peft_model")
    @patch("peft.LoraConfig")
    def test_apply_lora_returns_peft_model(self, mock_lora_config, mock_get_peft):
        """The object produced by get_peft_model is returned unchanged."""
        wrapped = MagicMock()
        mock_get_peft.return_value = wrapped

        outcome = apply_lora(self._fake_model("llama"), MockConfig())

        assert outcome is wrapped
        mock_get_peft.assert_called_once()
|