ml-model-loader/src_python/tqftw_model_loader/hf_loader.py

362 lines
12 KiB
Python
Raw Permalink Normal View History

2025-12-28 04:32:35 -08:00
"""
HuggingFace Transformers model loader.
Loads models from HuggingFace Hub or local paths using the transformers library.
"""
from pathlib import Path
from typing import Optional, Any, Union, Literal
import time
import logging
from .base import BaseModelLoader, ModelInfo, ModelLoadError, ModelNotFoundError
from .device import DeviceManager, get_best_device
from .registry import register_loader
# Annotation-only alias for whatever the loader returns; the concrete object
# could be a PreTrainedModel, a Pipeline, etc. (transformers itself is only
# imported lazily, inside the loader).
TransformersModel = Any

# Pipeline task names accepted by the loader's `task` argument.
TaskType = Literal[
    # text
    "text-generation", "text-classification", "token-classification",
    "question-answering", "summarization", "translation", "fill-mask",
    # vision
    "image-classification", "object-detection", "image-segmentation",
    # audio
    "automatic-speech-recognition", "audio-classification",
    # zero-shot / embeddings
    "zero-shot-classification", "feature-extraction",
]

logger = logging.getLogger(__name__)
@register_loader("hf", aliases=["huggingface", "transformers", "hf-transformers"])
class HFModelLoader(BaseModelLoader[TransformersModel]):
    """
    HuggingFace Transformers model loader.

    Supports loading models as pipelines or raw model/tokenizer pairs.

    Example:
        >>> loader = HFModelLoader()
        >>> # Load as pipeline (recommended for inference)
        >>> classifier = await loader.load(
        ...     "Marqo/nsfw-image-detection-384",
        ...     task="image-classification"
        ... )
        >>> result = classifier(image)
        >>> # Load raw model and tokenizer
        >>> model = await loader.load(
        ...     "Qwen/Qwen2.5-7B-Instruct",
        ...     as_pipeline=False
        ... )
        >>> model, tokenizer = loader.get_model_and_tokenizer()
    """

    def __init__(self) -> None:
        super().__init__()
        self._tokenizer: Any = None
        self._processor: Any = None
        # True when the loaded model is a transformers pipeline. generate()
        # needs this because raw torch modules are callable as well, so
        # callability cannot tell the two apart.
        self._is_pipeline: bool = False

    @property
    def tokenizer(self) -> Any:
        """Get the tokenizer for the loaded model (raw mode only; None otherwise)."""
        return self._tokenizer

    @property
    def processor(self) -> Any:
        """Get the processor for the loaded model (for multimodal models)."""
        return self._processor

    def get_model_and_tokenizer(self) -> tuple[Any, Any]:
        """Get both model and tokenizer."""
        return self._model, self._tokenizer

    @staticmethod
    def _from_pretrained_or_none(
        factory: Any, model_id: str, what: str, **kwargs: Any
    ) -> Any:
        """Best-effort ``factory.from_pretrained``; return None if it fails.

        Used for tokenizers/processors, which some models do not ship.
        """
        try:
            return factory.from_pretrained(model_id, **kwargs)
        except Exception as exc:  # deliberate best-effort; keep a trace
            logger.debug("No %s loaded for %s: %s", what, model_id, exc)
            return None

    async def load(
        self,
        model_id: str,
        *,
        task: Optional[TaskType] = None,
        device: Optional[str] = None,
        dtype: Optional[str] = None,
        as_pipeline: bool = True,
        trust_remote_code: bool = False,
        use_fast_tokenizer: bool = True,
        low_cpu_mem_usage: bool = True,
        torch_compile: bool = False,
        **kwargs: Any,
    ) -> TransformersModel:
        """
        Load a HuggingFace model, replacing any model already loaded.

        Args:
            model_id: HuggingFace model ID or local path
            task: Pipeline task type (required if as_pipeline=True). In raw
                mode it selects the AutoModel class; tasks without a mapping
                fall back to ``AutoModel``.
            device: Device to load on (default: auto-detect). Not applied when
                ``device_map`` is passed in ``kwargs``: placement is then left
                to transformers.
            dtype: Data type ("float16"/"fp16", "bfloat16"/"bf16",
                "float32"/"fp32", "auto"); other strings are passed through
                unchanged. "auto" is only forwarded in raw mode.
            as_pipeline: Load as a pipeline (True) or raw model (False)
            trust_remote_code: Allow running remote code
            use_fast_tokenizer: Use fast tokenizer if available (raw mode)
            low_cpu_mem_usage: Reduce CPU memory during loading (raw mode)
            torch_compile: Apply torch.compile() for faster inference (raw mode)
            **kwargs: Additional arguments for AutoModel/pipeline

        Returns:
            Loaded model (Pipeline or PreTrainedModel)

        Raises:
            ModelLoadError: if another load is in progress, ``task`` is missing
                in pipeline mode, torch/transformers cannot be imported, or
                loading fails for any other reason (the original exception is
                chained).
        """
        if self._loading:
            raise ModelLoadError(model_id, "Another load operation is in progress")
        # Validate before touching any state so a bad call does not unload
        # the model currently in use.
        if as_pipeline and task is None:
            raise ModelLoadError(
                model_id,
                "task is required when loading as pipeline"
            )

        self._loading = True
        start_time = time.time()
        try:
            # Narrow try: only a missing torch/transformers gets the
            # "not installed" message. ImportErrors raised later while loading
            # (e.g. remote code needing an extra package) keep their own text.
            try:
                import torch
                from transformers import (
                    AutoModel,
                    AutoModelForCausalLM,
                    AutoModelForSeq2SeqLM,
                    AutoModelForSequenceClassification,
                    AutoModelForTokenClassification,
                    AutoModelForQuestionAnswering,
                    AutoTokenizer,
                    AutoProcessor,
                    pipeline,
                )
            except ImportError as e:
                raise ModelLoadError(
                    model_id,
                    "transformers library not installed. Install with: pip install transformers",
                    cause=e,
                ) from e

            if device is None:
                device = get_best_device()

            torch_dtype: Any = None
            if dtype:
                dtype_map = {
                    "float16": torch.float16,
                    "fp16": torch.float16,
                    "bfloat16": torch.bfloat16,
                    "bf16": torch.bfloat16,
                    "float32": torch.float32,
                    "fp32": torch.float32,
                    "auto": "auto",
                }
                torch_dtype = dtype_map.get(dtype, dtype)

            # With an explicit device_map, transformers/accelerate place the
            # weights. Also passing `device` to pipeline() overrides that map,
            # and `.to()` fails on offloaded/quantized dispatched models.
            uses_device_map = kwargs.get("device_map") is not None

            # Unload existing model if any
            if self._model is not None:
                await self.unload()

            tokenizer: Any = None
            processor: Any = None
            if as_pipeline:
                pipe_kwargs: dict[str, Any] = {
                    "model": model_id,
                    "task": task,
                    "trust_remote_code": trust_remote_code,
                    **kwargs,
                }
                if not uses_device_map:
                    # transformers pipelines use -1 to mean CPU.
                    pipe_kwargs["device"] = device if device != "cpu" else -1
                if torch_dtype is not None and torch_dtype != "auto":
                    pipe_kwargs["torch_dtype"] = torch_dtype
                model = pipeline(**pipe_kwargs)
            else:
                model_kwargs: dict[str, Any] = {
                    "trust_remote_code": trust_remote_code,
                    "low_cpu_mem_usage": low_cpu_mem_usage,
                    **kwargs,
                }
                if torch_dtype is not None:
                    model_kwargs["torch_dtype"] = torch_dtype

                # Pick the model head from the task; anything unmapped
                # (including task=None) loads the bare AutoModel.
                task_to_class = {
                    "text-generation": AutoModelForCausalLM,
                    "text-classification": AutoModelForSequenceClassification,
                    "token-classification": AutoModelForTokenClassification,
                    "question-answering": AutoModelForQuestionAnswering,
                    "summarization": AutoModelForSeq2SeqLM,
                    "translation": AutoModelForSeq2SeqLM,
                }
                model_class = task_to_class.get(task, AutoModel)

                model = model_class.from_pretrained(model_id, **model_kwargs)
                if device != "cpu" and not uses_device_map:
                    model = model.to(device)
                if torch_compile and hasattr(torch, "compile"):
                    model = torch.compile(model)

                tokenizer = self._from_pretrained_or_none(
                    AutoTokenizer,
                    model_id,
                    "tokenizer",
                    trust_remote_code=trust_remote_code,
                    use_fast=use_fast_tokenizer,
                )
                processor = self._from_pretrained_or_none(
                    AutoProcessor,
                    model_id,
                    "processor",
                    trust_remote_code=trust_remote_code,
                )

            # Memory currently allocated on the target CUDA device (includes
            # any other tensors living there, not only this model).
            memory_used = 0.0
            if str(device).startswith("cuda"):
                memory_used = torch.cuda.memory_allocated(device) / 1024 / 1024

            # Commit state only after everything succeeded, so a failed load
            # never leaves this loader holding a half-initialised model.
            self._model = model
            self._tokenizer = tokenizer
            self._processor = processor
            self._is_pipeline = as_pipeline
            self._model_info = ModelInfo(
                model_id=model_id,
                device=device,
                dtype=dtype,
                memory_used_mb=memory_used,
                load_time_seconds=time.time() - start_time,
                metadata={
                    "task": task,
                    "as_pipeline": as_pipeline,
                    "torch_compile": torch_compile,
                },
            )
            logger.info(
                "Loaded %s on %s in %.2fs",
                model_id,
                device,
                self._model_info.load_time_seconds,
            )
            return self._model
        except ModelLoadError:
            # Already carries a precise message; do not re-wrap it.
            raise
        except Exception as e:
            raise ModelLoadError(model_id, str(e), cause=e) from e
        finally:
            self._loading = False

    async def unload(self) -> None:
        """Unload the model and free GPU memory.

        Drops this loader's references to the model, tokenizer, processor and
        model info. If torch is importable, it also runs a garbage-collection
        pass and empties the CUDA cache. Does nothing if no model is loaded.
        """
        if self._model is None:
            return

        self._model = None
        self._tokenizer = None
        self._processor = None
        self._model_info = None
        self._is_pipeline = False

        try:
            import torch
        except ImportError:
            return

        import gc

        # Reference cycles can keep tensors alive until a GC pass runs;
        # collect first so empty_cache() can actually release those blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.debug("Model unloaded and GPU cache cleared")

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        **kwargs: Any,
    ) -> str:
        """
        Generate text using the loaded model.

        Only works if model was loaded with task="text-generation" or as a
        causal LM. Sampling is enabled when ``temperature > 0``; otherwise
        greedy decoding is used and ``temperature``/``top_p`` are not sent.
        The same generation arguments are used for pipelines and raw models.

        Args:
            prompt: Input prompt
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature (<= 0 means greedy)
            top_p: Nucleus sampling parameter (used only when sampling)
            **kwargs: Additional generation arguments; these override the
                values derived from the parameters above.

        Returns:
            Generated text. For a pipeline this is the first result's
            ``generated_text`` field, or ``str(result)`` for other result
            shapes. For a raw model it is the decoded first output sequence.

        Raises:
            ValueError: if no model is loaded, or a raw model has no tokenizer.
        """
        if self._model is None:
            raise ValueError("No model loaded")

        do_sample = temperature > 0
        gen_kwargs: dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
        }
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p
        gen_kwargs.update(kwargs)

        # Raw torch modules are callable too, so rely on the flag recorded at
        # load time to decide between the pipeline and raw-model paths.
        if self._is_pipeline:
            result = self._model(prompt, **gen_kwargs)
            if isinstance(result, list) and result and isinstance(result[0], dict):
                return result[0].get("generated_text", "")
            return str(result)

        if self._tokenizer is None:
            raise ValueError("No tokenizer available for generation")

        import torch

        inputs = self._tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self._model.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self._model.generate(**inputs, **gen_kwargs)
        return self._tokenizer.decode(outputs[0], skip_special_tokens=True)