# Source: ml-model-loader/src_python/build/lib/tqftw_model_loader/diffusers_loader.py
# (422 lines, 14 KiB, Python)

"""
Diffusers model loader for Stable Diffusion and SDXL.
Loads image generation pipelines from HuggingFace or local paths.
"""
from pathlib import Path
from typing import Optional, Any, Union, Literal, List
import time
import logging
from .base import BaseModelLoader, ModelInfo, ModelLoadError
from .device import DeviceManager, get_best_device
from .registry import register_loader
# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Type alias for diffusion pipelines. diffusers is only imported inside the
# loader methods, so pipelines are typed loosely as ``Any`` at module level.
DiffusionPipeline = Any

# Closed set of pipeline kinds a caller may request via ``pipeline_type``.
PipelineType = Literal[
    "sdxl", "sd15", "sd21",
    "flux", "kandinsky", "if",
    "controlnet", "img2img", "inpaint",
]
@register_loader("diffusers", aliases=["stable-diffusion", "sdxl", "sd", "flux"])
class DiffusersLoader(BaseModelLoader[DiffusionPipeline]):
    """
    Diffusers model loader for image generation.

    Supports SDXL, SD 1.5, SD 2.1, Flux, and other diffusion models.

    Example:
        >>> loader = DiffusersLoader()
        >>> # Load SDXL
        >>> pipeline = await loader.load(
        ...     "stabilityai/stable-diffusion-xl-base-1.0",
        ...     dtype="float16",
        ...     device="cuda:0"
        ... )
        >>> # Generate image
        >>> image = pipeline("a photo of a cat").images[0]
        >>> # Load from single file (safetensors/ckpt)
        >>> pipeline = await loader.load(
        ...     "/path/to/model.safetensors",
        ...     pipeline_type="sdxl"
        ... )
    """

    # File extensions that mark a single-file checkpoint.
    _SINGLE_FILE_SUFFIXES = (".safetensors", ".ckpt", ".pt")

    # Accepted ``dtype`` spellings mapped to the torch attribute that holds
    # the dtype. Names are resolved lazily so torch is not needed at import.
    _DTYPE_NAMES = {
        "float16": "float16",
        "fp16": "float16",
        "bfloat16": "bfloat16",
        "bf16": "bfloat16",
        "float32": "float32",
        "fp32": "float32",
    }

    def __init__(self) -> None:
        super().__init__()
        self._scheduler: Any = None
        self._vae: Any = None

    @property
    def scheduler(self) -> Any:
        """Get the current scheduler."""
        return self._scheduler

    @property
    def vae(self) -> Any:
        """Get the current VAE."""
        return self._vae

    async def load(
        self,
        model_id: str,
        *,
        pipeline_type: Optional[PipelineType] = None,
        device: Optional[str] = None,
        dtype: Optional[str] = None,
        enable_attention_slicing: bool = True,
        enable_vae_slicing: bool = True,
        enable_vae_tiling: bool = False,
        enable_model_cpu_offload: bool = False,
        enable_sequential_cpu_offload: bool = False,
        use_safetensors: bool = True,
        variant: Optional[str] = None,
        custom_pipeline: Optional[str] = None,
        scheduler: Optional[str] = None,
        **kwargs: Any,
    ) -> DiffusionPipeline:
        """
        Load a diffusion model pipeline.

        Any previously loaded pipeline is unloaded first. If loading fails
        after that point, the loader is left empty rather than holding a
        partially configured pipeline.

        Args:
            model_id: HuggingFace model ID, local path, or URL to single file
            pipeline_type: Type of pipeline. Single-file checkpoints need one
                of "sdxl", "sd15" or "sd21"; for Hub/directory models
                "img2img" and "inpaint" select the matching auto pipeline and
                anything else loads a text-to-image pipeline.
            device: Device to load on (default: auto-detect)
            dtype: Data type ("float16", "bfloat16", "float32" or the short
                forms "fp16", "bf16", "fp32"). Unknown values fall back to
                float32 with a warning. When omitted, float16 is used on
                non-CPU devices and float32 on CPU.
            enable_attention_slicing: Reduce VRAM usage for attention
            enable_vae_slicing: Reduce VRAM usage for VAE
            enable_vae_tiling: Enable tiled VAE for very high resolution
            enable_model_cpu_offload: Offload model to CPU when not in use
            enable_sequential_cpu_offload: Aggressively offload to CPU
            use_safetensors: Prefer safetensors format
            variant: Model variant (e.g., "fp16")
            custom_pipeline: Custom pipeline class name
            scheduler: Scheduler to use (e.g., "DPMSolverMultistep" or
                "euler_a")
            **kwargs: Additional arguments for pipeline

        Returns:
            Loaded diffusion pipeline

        Raises:
            ModelLoadError: If another load is in progress, a required
                library is missing, or loading/configuring the pipeline fails.
        """
        if self._loading:
            raise ModelLoadError(model_id, "Another load operation is in progress")

        self._loading = True
        start_time = time.time()
        # Becomes True once the previous pipeline has been unloaded; from then
        # on a failure must not leave a half-configured pipeline behind.
        owns_state = False
        try:
            # Only the import itself is guarded here, so an ImportError raised
            # deeper inside diffusers is reported with its real message.
            try:
                import torch
                from diffusers import (
                    AutoPipelineForImage2Image,
                    AutoPipelineForInpainting,
                    AutoPipelineForText2Image,
                )
            except ImportError as e:
                raise ModelLoadError(
                    model_id,
                    f"Required library missing ({e}). "
                    "Install with: pip install diffusers torch",
                    cause=e,
                ) from e

            if device is None:
                device = get_best_device()

            torch_dtype = self._resolve_torch_dtype(torch, dtype, device)

            # Unload existing model
            if self._model is not None:
                await self.unload()
            owns_state = True

            single_file_source, local_path = self._single_file_source(model_id)
            is_single_file = single_file_source is not None

            if is_single_file:
                pipeline_class = self._get_pipeline_class_for_type(pipeline_type)
                if not hasattr(pipeline_class, "from_single_file"):
                    raise ModelLoadError(
                        model_id,
                        f"{pipeline_class.__name__} cannot load single-file "
                        "checkpoints; pass pipeline_type='sdxl', 'sd15' or 'sd21'",
                    )
                self._model = pipeline_class.from_single_file(
                    single_file_source,
                    torch_dtype=torch_dtype,
                    use_safetensors=use_safetensors,
                    **kwargs,
                )
            else:
                # Load from HuggingFace or directory
                load_kwargs = {
                    "torch_dtype": torch_dtype,
                    "use_safetensors": use_safetensors,
                    **kwargs,
                }
                if variant:
                    load_kwargs["variant"] = variant
                if custom_pipeline:
                    load_kwargs["custom_pipeline"] = custom_pipeline

                if pipeline_type == "img2img":
                    auto_class = AutoPipelineForImage2Image
                elif pipeline_type == "inpaint":
                    auto_class = AutoPipelineForInpainting
                else:
                    auto_class = AutoPipelineForText2Image

                # Try auto pipeline first
                try:
                    self._model = auto_class.from_pretrained(model_id, **load_kwargs)
                except Exception as auto_error:
                    # Fall back to specific pipeline class, but keep a record
                    # of why the auto pipeline was rejected.
                    pipeline_class = self._get_pipeline_class_for_type(pipeline_type)
                    logger.warning(
                        "%s failed for %s (%s); falling back to %s",
                        auto_class.__name__,
                        model_id,
                        auto_error,
                        pipeline_class.__name__,
                    )
                    self._model = pipeline_class.from_pretrained(
                        model_id, **load_kwargs
                    )

            # Move to device (skipped when an offload mode manages placement)
            if not (enable_model_cpu_offload or enable_sequential_cpu_offload):
                self._model = self._model.to(device)

            # Apply memory optimizations
            if enable_attention_slicing:
                self._model.enable_attention_slicing()
            if enable_vae_slicing:
                self._model.enable_vae_slicing()
            if enable_vae_tiling:
                self._model.enable_vae_tiling()
            if enable_model_cpu_offload:
                self._model.enable_model_cpu_offload()
            if enable_sequential_cpu_offload:
                self._model.enable_sequential_cpu_offload()

            # Change scheduler if requested
            if scheduler:
                self._set_scheduler(scheduler)

            # Store references
            if hasattr(self._model, "scheduler"):
                self._scheduler = self._model.scheduler
            if hasattr(self._model, "vae"):
                self._vae = self._model.vae

            # Calculate memory usage on the device that was actually requested
            memory_used = 0.0
            if device.startswith("cuda") and torch.cuda.is_available():
                allocated = torch.cuda.memory_allocated(torch.device(device))
                memory_used = allocated / 1024 / 1024

            # Store model info
            self._model_info = ModelInfo(
                model_id=model_id,
                path=local_path,
                device=device,
                dtype=dtype,
                memory_used_mb=memory_used,
                load_time_seconds=time.time() - start_time,
                metadata={
                    "pipeline_type": pipeline_type,
                    "scheduler": scheduler,
                    "is_single_file": is_single_file,
                    # Resolved dtype, useful when ``dtype`` was left as None.
                    "torch_dtype": str(torch_dtype),
                },
            )
            logger.info(
                "Loaded %s on %s in %.2fs",
                model_id,
                device,
                self._model_info.load_time_seconds,
            )
            return self._model
        except Exception as e:
            if owns_state:
                # Drop whatever was built so the loader does not report a
                # broken pipeline as loaded or keep its memory alive.
                self._clear_state()
                self._free_accelerator_cache()
            if isinstance(e, ModelLoadError):
                raise
            raise ModelLoadError(model_id, str(e), cause=e) from e
        finally:
            self._loading = False

    def _resolve_torch_dtype(
        self, torch_module: Any, dtype: Optional[str], device: str
    ) -> Any:
        """Translate the user-facing ``dtype`` string into a torch dtype."""
        if not dtype:
            # Default to float16 on GPU, float32 on CPU
            attr = "float32" if device == "cpu" else "float16"
            return getattr(torch_module, attr)
        attr = self._DTYPE_NAMES.get(dtype.strip().lower())
        if attr is None:
            logger.warning("Unknown dtype %r; falling back to float32", dtype)
            attr = "float32"
        return getattr(torch_module, attr)

    def _single_file_source(self, model_id: str) -> tuple:
        """
        Decide whether ``model_id`` points at a single-file checkpoint.

        Returns:
            ``(source, local_path)``. ``source`` is the string to hand to
            ``from_single_file`` or None when ``model_id`` is not a single
            file; ``local_path`` is a ``Path`` for local files and None for
            URLs and non-single-file models.
        """
        if model_id.startswith(("http://", "https://")):
            from urllib.parse import urlparse

            # Look at the URL path only, so query strings do not hide the suffix.
            suffix = Path(urlparse(model_id).path).suffix.lower()
            if suffix in self._SINGLE_FILE_SUFFIXES:
                return model_id, None
            return None, None

        candidate = Path(model_id)
        if candidate.is_file() and candidate.suffix.lower() in self._SINGLE_FILE_SUFFIXES:
            return str(candidate), candidate
        return None, None

    def _get_pipeline_class_for_type(
        self, pipeline_type: Optional[PipelineType]
    ) -> type:
        """
        Get the appropriate pipeline class for the type.

        "sdxl" maps to ``StableDiffusionXLPipeline``; "sd15" and "sd21" map to
        ``StableDiffusionPipeline``; every other value (including None) maps
        to the generic ``DiffusionPipeline``.
        """
        from diffusers import (
            StableDiffusionPipeline,
            StableDiffusionXLPipeline,
            DiffusionPipeline,
        )

        type_to_class = {
            "sdxl": StableDiffusionXLPipeline,
            "sd15": StableDiffusionPipeline,
            "sd21": StableDiffusionPipeline,
        }
        return type_to_class.get(pipeline_type, DiffusionPipeline)

    @staticmethod
    def _normalize_scheduler_name(name: str) -> str:
        """Reduce a scheduler name to lowercase alphanumerics without a trailing 'scheduler'."""
        key = "".join(ch for ch in name.lower() if ch.isalnum())
        suffix = "scheduler"
        if key.endswith(suffix) and len(key) > len(suffix):
            key = key[: -len(suffix)]
        return key

    def _set_scheduler(self, scheduler_name: str) -> None:
        """
        Set a different scheduler.

        Accepts the short aliases listed below (case-insensitive, with "-",
        "_" or spaces ignored) as well as diffusers class-style names such as
        "DPMSolverMultistep" or "EulerDiscreteScheduler". Unknown names are
        logged and leave the current scheduler untouched.
        """
        from diffusers import (
            DDIMScheduler,
            DDPMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
            DPMSolverSinglestepScheduler,
            HeunDiscreteScheduler,
            KDPM2DiscreteScheduler,
            KDPM2AncestralDiscreteScheduler,
            UniPCMultistepScheduler,
        )

        aliases = {
            "ddim": DDIMScheduler,
            "ddpm": DDPMScheduler,
            "pndm": PNDMScheduler,
            "lms": LMSDiscreteScheduler,
            "euler": EulerDiscreteScheduler,
            "euler_a": EulerAncestralDiscreteScheduler,
            "euler_ancestral": EulerAncestralDiscreteScheduler,
            "dpm": DPMSolverMultistepScheduler,
            "dpm_solver": DPMSolverMultistepScheduler,
            "dpm_solver_multistep": DPMSolverMultistepScheduler,
            "dpm_solver_singlestep": DPMSolverSinglestepScheduler,
            "heun": HeunDiscreteScheduler,
            "kdpm2": KDPM2DiscreteScheduler,
            "kdpm2_a": KDPM2AncestralDiscreteScheduler,
            "kdpm2_ancestral": KDPM2AncestralDiscreteScheduler,
            "unipc": UniPCMultistepScheduler,
        }

        normalize = self._normalize_scheduler_name
        lookup = {normalize(alias): cls for alias, cls in aliases.items()}
        # Also accept the class names themselves; aliases win on any overlap.
        for cls in set(aliases.values()):
            lookup.setdefault(normalize(cls.__name__), cls)

        scheduler_class = lookup.get(normalize(scheduler_name))
        if scheduler_class is None:
            logger.warning("Unknown scheduler: %s", scheduler_name)
            return

        current = getattr(self._model, "scheduler", None)
        if current is None:
            logger.warning(
                "Pipeline has no scheduler to replace; ignoring %r", scheduler_name
            )
            return

        # Build the new scheduler from the current one's config so the
        # model-specific settings carry over.
        self._model.scheduler = scheduler_class.from_config(current.config)

    def _clear_state(self) -> None:
        """Drop every reference this loader holds to the pipeline and its parts."""
        self._model = None
        self._scheduler = None
        self._vae = None
        self._model_info = None

    @staticmethod
    def _free_accelerator_cache() -> None:
        """Run garbage collection and, when torch with CUDA is present, empty its cache."""
        import gc

        # Collect first so cyclic references to the pipeline are released
        # before the CUDA caching allocator is asked to give memory back.
        gc.collect()
        try:
            import torch
        except ImportError:
            return
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    async def unload(self) -> None:
        """Unload the pipeline and free GPU memory. Does nothing when no pipeline is loaded."""
        if self._model is None:
            return
        self._clear_state()
        self._free_accelerator_cache()
        logger.debug("Pipeline unloaded and GPU cache cleared")

    def generate(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        width: int = 1024,
        height: int = 1024,
        num_images: int = 1,
        seed: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Any]:
        """
        Generate images from a prompt.

        Args:
            prompt: Text prompt for generation
            negative_prompt: Things to avoid in the image. Only forwarded to
                the pipeline when given, so pipelines without this argument
                still work.
            num_inference_steps: Number of denoising steps
            guidance_scale: How closely to follow the prompt
            width: Image width
            height: Image height
            num_images: Number of images to generate
            seed: Random seed for reproducibility
            **kwargs: Additional pipeline arguments

        Returns:
            The ``images`` attribute of the pipeline result

        Raises:
            ValueError: If no pipeline is loaded.
        """
        if self._model is None:
            raise ValueError("No pipeline loaded")

        generator = None
        if seed is not None:
            import torch

            # Seed on the device recorded at load time; CPU when unknown.
            device = self._model_info.device if self._model_info else "cpu"
            generator = torch.Generator(device=device).manual_seed(seed)

        if negative_prompt is not None:
            kwargs["negative_prompt"] = negative_prompt

        result = self._model(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
            num_images_per_prompt=num_images,
            generator=generator,
            **kwargs,
        )
        return result.images