421 lines
14 KiB
Python
421 lines
14 KiB
Python
"""
Diffusers model loader for Stable Diffusion and SDXL.

Loads image generation pipelines from HuggingFace or local paths.

Defines ``DiffusersLoader``, registered with the loader registry under the
name "diffusers" (aliases: "stable-diffusion", "sdxl", "sd", "flux"). It
accepts HuggingFace repo IDs, local diffusers-format directories, and local
single-file checkpoints (.safetensors, .ckpt, .pt). ``torch`` and
``diffusers`` are imported lazily inside the loader's methods rather than at
module import time.
"""
|
|
|
|
from pathlib import Path
|
|
from typing import Optional, Any, Union, Literal, List
|
|
import time
|
|
import logging
|
|
|
|
from .base import BaseModelLoader, ModelInfo, ModelLoadError
|
|
from .device import DeviceManager, get_best_device
|
|
from .registry import register_loader
|
|
|
|
# ``diffusers`` is only imported inside the loader's methods, so pipeline
# objects are typed loosely at module level.
DiffusionPipeline = Any

# Pipeline families that ``DiffusersLoader.load`` accepts as ``pipeline_type``.
PipelineType = Literal[
    "sdxl", "sd15", "sd21", "flux", "kandinsky",
    "if", "controlnet", "img2img", "inpaint",
]

# Module-scoped logger shared by everything in this file.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@register_loader("diffusers", aliases=["stable-diffusion", "sdxl", "sd", "flux"])
class DiffusersLoader(BaseModelLoader[DiffusionPipeline]):
    """
    Diffusers model loader for image generation.

    Supports SDXL, SD 1.5, SD 2.1, Flux, and other diffusion models.

    Example (inside a coroutine)::

        loader = DiffusersLoader()

        # Load SDXL from the HuggingFace Hub
        pipeline = await loader.load(
            "stabilityai/stable-diffusion-xl-base-1.0",
            dtype="float16",
            device="cuda:0",
        )

        # Generate an image
        image = pipeline("a photo of a cat").images[0]

        # Load from a single checkpoint file (safetensors/ckpt/pt)
        pipeline = await loader.load(
            "/path/to/model.safetensors",
            pipeline_type="sdxl",
        )
    """

    # Accepted spellings of the ``dtype`` argument (matched case-insensitively),
    # mapped to the name of the corresponding ``torch`` dtype attribute.
    _DTYPE_NAMES = {
        "float16": "float16",
        "fp16": "float16",
        "bfloat16": "bfloat16",
        "bf16": "bfloat16",
        "float32": "float32",
        "fp32": "float32",
    }

    # Extensions treated as single-file checkpoints rather than Hub repo IDs
    # or diffusers-format directories.
    _SINGLE_FILE_SUFFIXES = (".safetensors", ".ckpt", ".pt")

    # URL schemes accepted for remote single-file checkpoints.
    _REMOTE_PREFIXES = ("http://", "https://")

    def __init__(self) -> None:
        super().__init__()
        self._scheduler: Any = None
        self._vae: Any = None

    @property
    def scheduler(self) -> Any:
        """Get the current scheduler."""
        return self._scheduler

    @property
    def vae(self) -> Any:
        """Get the current VAE."""
        return self._vae

    async def load(
        self,
        model_id: str,
        *,
        pipeline_type: Optional[PipelineType] = None,
        device: Optional[str] = None,
        dtype: Optional[str] = None,
        enable_attention_slicing: bool = True,
        enable_vae_slicing: bool = True,
        enable_vae_tiling: bool = False,
        enable_model_cpu_offload: bool = False,
        enable_sequential_cpu_offload: bool = False,
        use_safetensors: bool = True,
        variant: Optional[str] = None,
        custom_pipeline: Optional[str] = None,
        scheduler: Optional[str] = None,
        **kwargs: Any,
    ) -> DiffusionPipeline:
        """
        Load a diffusion model pipeline, replacing any currently loaded one.

        Args:
            model_id: HuggingFace model ID, local path, or URL to single file
            pipeline_type: Type of pipeline (auto-detected if not specified);
                required for single-file checkpoints ("sdxl", "sd15", "sd21")
            device: Device to load on (default: auto-detect)
            dtype: Data type ("float16", "bfloat16", "float32" or the short
                forms "fp16", "bf16", "fp32"; case-insensitive). Defaults to
                float16 on non-CPU devices and float32 on CPU. Unknown values
                fall back to float32 with a warning.
            enable_attention_slicing: Reduce VRAM usage for attention
            enable_vae_slicing: Reduce VRAM usage for VAE
            enable_vae_tiling: Enable tiled VAE for very high resolution
            enable_model_cpu_offload: Offload model to CPU when not in use
            enable_sequential_cpu_offload: Aggressively offload to CPU (takes
                precedence if both offload flags are set)
            use_safetensors: Prefer safetensors format
            variant: Model variant (e.g., "fp16")
            custom_pipeline: Custom pipeline class name
            scheduler: Scheduler to use (e.g., "DPMSolverMultistep")
            **kwargs: Additional arguments for pipeline

        Returns:
            Loaded diffusion pipeline

        Raises:
            ModelLoadError: If another load is in progress, torch/diffusers
                cannot be imported, or the pipeline cannot be built. A
                previously loaded pipeline is unloaded before the new one is
                built, so after a failure no pipeline is loaded.
        """
        if self._loading:
            raise ModelLoadError(model_id, "Another load operation is in progress")

        self._loading = True
        start_time = time.perf_counter()

        try:
            # Only a failure of these imports means the libraries are missing;
            # ImportErrors raised later while building the pipeline keep their
            # own (more precise) message via the generic handler below.
            try:
                import torch  # noqa: F401  (availability check)
                import diffusers  # noqa: F401  (availability check)
            except ImportError as e:
                raise ModelLoadError(
                    model_id,
                    "diffusers library not installed. Install with: pip install diffusers",
                    cause=e,
                ) from e

            if device is None:
                device = get_best_device()

            torch_dtype = self._resolve_torch_dtype(dtype, device)

            # Unload existing model
            if self._model is not None:
                await self.unload()

            model_path = Path(model_id)
            is_local_file = model_path.exists()
            is_single_file = model_path.suffix.lower() in self._SINGLE_FILE_SUFFIXES and (
                is_local_file or model_id.startswith(self._REMOTE_PREFIXES)
            )

            # Baseline taken after unloading so the reported figure covers only
            # what this load allocates on the target device.
            memory_before = self._cuda_memory_mb(device)

            if is_single_file:
                # Path() would mangle "https://" into "https:/", so URLs are
                # passed through unchanged.
                source = str(model_path) if is_local_file else model_id
                pipe = self._load_single_file(
                    source, pipeline_type, torch_dtype, use_safetensors, kwargs
                )
            else:
                load_kwargs = {
                    "torch_dtype": torch_dtype,
                    "use_safetensors": use_safetensors,
                    **kwargs,
                }
                if variant:
                    load_kwargs["variant"] = variant
                if custom_pipeline:
                    load_kwargs["custom_pipeline"] = custom_pipeline
                pipe = self._load_pretrained(model_id, pipeline_type, load_kwargs)

            pipe = self._place_and_optimize(
                pipe,
                device=device,
                enable_attention_slicing=enable_attention_slicing,
                enable_vae_slicing=enable_vae_slicing,
                enable_vae_tiling=enable_vae_tiling,
                enable_model_cpu_offload=enable_model_cpu_offload,
                enable_sequential_cpu_offload=enable_sequential_cpu_offload,
            )

            if scheduler:
                self._set_scheduler(scheduler, pipeline=pipe)

            # Commit to loader state only after every step has succeeded, so a
            # failed load never leaves a half-configured pipeline behind.
            self._model = pipe
            self._scheduler = getattr(pipe, "scheduler", None)
            self._vae = getattr(pipe, "vae", None)

            load_time = time.perf_counter() - start_time
            self._model_info = ModelInfo(
                model_id=model_id,
                path=model_path if is_single_file and is_local_file else None,
                device=device,
                # Record the dtype actually used, e.g. "float16", rather than
                # the raw (possibly omitted or aliased) argument.
                dtype=str(torch_dtype).replace("torch.", ""),
                memory_used_mb=max(0.0, self._cuda_memory_mb(device) - memory_before),
                load_time_seconds=load_time,
                metadata={
                    "pipeline_type": pipeline_type,
                    "scheduler": scheduler,
                    "is_single_file": is_single_file,
                },
            )

            logger.info("Loaded %s on %s in %.2fs", model_id, device, load_time)
            return pipe

        except ModelLoadError:
            # Already carries a specific message; don't wrap it again.
            raise
        except Exception as e:
            raise ModelLoadError(model_id, str(e), cause=e) from e
        finally:
            self._loading = False

    def _resolve_torch_dtype(self, dtype: Optional[str], device: str) -> Any:
        """Map the user-facing ``dtype`` string (or its absence) to a torch dtype."""
        import torch

        if not dtype:
            # No explicit dtype: half precision off-CPU, full precision on CPU.
            return torch.float32 if device == "cpu" else torch.float16

        name = self._DTYPE_NAMES.get(dtype.lower())
        if name is None:
            logger.warning("Unknown dtype %r; falling back to float32", dtype)
            return torch.float32
        return getattr(torch, name)

    def _load_single_file(
        self,
        source: str,
        pipeline_type: Optional[PipelineType],
        torch_dtype: Any,
        use_safetensors: bool,
        extra_kwargs: dict,
    ) -> Any:
        """Build a pipeline from one checkpoint file (local path or URL)."""
        pipeline_class = self._get_pipeline_class_for_type(pipeline_type)
        if not hasattr(pipeline_class, "from_single_file"):
            raise ValueError(
                f"{pipeline_class.__name__} cannot load single-file checkpoints; "
                f"pass pipeline_type='sdxl', 'sd15' or 'sd21' (got {pipeline_type!r})"
            )
        return pipeline_class.from_single_file(
            source,
            torch_dtype=torch_dtype,
            use_safetensors=use_safetensors,
            **extra_kwargs,
        )

    def _load_pretrained(
        self,
        model_id: str,
        pipeline_type: Optional[PipelineType],
        load_kwargs: dict,
    ) -> Any:
        """Build a pipeline from a Hub repo ID or diffusers-format directory."""
        from diffusers import (
            AutoPipelineForImage2Image,
            AutoPipelineForInpainting,
            AutoPipelineForText2Image,
        )

        auto_class = {
            "img2img": AutoPipelineForImage2Image,
            "inpaint": AutoPipelineForInpainting,
        }.get(pipeline_type, AutoPipelineForText2Image)

        try:
            return auto_class.from_pretrained(model_id, **load_kwargs)
        except Exception as auto_error:
            # The Auto class may not resolve every model; retry with the
            # explicit class. If that also fails, the Auto error stays
            # attached as the exception context.
            pipeline_class = self._get_pipeline_class_for_type(pipeline_type)
            logger.info(
                "%s could not load %s (%s); falling back to %s",
                auto_class.__name__,
                model_id,
                auto_error,
                pipeline_class.__name__,
            )
            return pipeline_class.from_pretrained(model_id, **load_kwargs)

    def _place_and_optimize(
        self,
        pipe: Any,
        *,
        device: str,
        enable_attention_slicing: bool,
        enable_vae_slicing: bool,
        enable_vae_tiling: bool,
        enable_model_cpu_offload: bool,
        enable_sequential_cpu_offload: bool,
    ) -> Any:
        """Move *pipe* to *device* (unless offloading) and apply memory options."""
        if enable_model_cpu_offload and enable_sequential_cpu_offload:
            # The two offload modes install conflicting hooks; keep the more
            # aggressive one.
            logger.warning(
                "Both model and sequential CPU offload requested; "
                "using sequential CPU offload only"
            )
            enable_model_cpu_offload = False

        # Offloading handles device placement itself, so only move the whole
        # pipeline when no offload mode is active.
        if not (enable_model_cpu_offload or enable_sequential_cpu_offload):
            pipe = pipe.to(device)

        if enable_attention_slicing:
            self._call_if_supported(pipe, "enable_attention_slicing")
        if enable_vae_slicing:
            self._call_if_supported(pipe, "enable_vae_slicing")
        if enable_vae_tiling:
            self._call_if_supported(pipe, "enable_vae_tiling")

        if enable_model_cpu_offload:
            pipe.enable_model_cpu_offload()
        if enable_sequential_cpu_offload:
            pipe.enable_sequential_cpu_offload()

        return pipe

    @staticmethod
    def _call_if_supported(pipe: Any, method_name: str) -> bool:
        """Call ``pipe.<method_name>()`` if it exists; log and skip otherwise."""
        method = getattr(pipe, method_name, None)
        if method is None:
            logger.info(
                "%s does not support %s(); skipping",
                type(pipe).__name__,
                method_name,
            )
            return False
        method()
        return True

    @staticmethod
    def _cuda_memory_mb(device: str) -> float:
        """Return MiB allocated by torch on *device*; 0.0 for non-CUDA devices."""
        import torch

        if not device.startswith("cuda") or not torch.cuda.is_available():
            return 0.0
        return torch.cuda.memory_allocated(device) / (1024 * 1024)

    def _get_pipeline_class_for_type(
        self, pipeline_type: Optional[PipelineType]
    ) -> type:
        """Get the appropriate pipeline class for the type (generic fallback)."""
        from diffusers import (
            StableDiffusionPipeline,
            StableDiffusionXLPipeline,
            DiffusionPipeline,
        )

        type_to_class = {
            "sdxl": StableDiffusionXLPipeline,
            "sd15": StableDiffusionPipeline,
            "sd21": StableDiffusionPipeline,
        }

        return type_to_class.get(pipeline_type, DiffusionPipeline)

    def _set_scheduler(self, scheduler_name: str, pipeline: Any = None) -> None:
        """
        Replace a pipeline's scheduler with the named one, reusing its config.

        Args:
            scheduler_name: Case-insensitive key such as "euler_a" or
                "dpm-solver"; hyphens are treated as underscores. Unknown
                names are logged and ignored.
            pipeline: Pipeline to modify; defaults to the loaded pipeline.
        """
        from diffusers import (
            DDIMScheduler,
            DDPMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
            DPMSolverSinglestepScheduler,
            HeunDiscreteScheduler,
            KDPM2DiscreteScheduler,
            KDPM2AncestralDiscreteScheduler,
            UniPCMultistepScheduler,
        )

        schedulers = {
            "ddim": DDIMScheduler,
            "ddpm": DDPMScheduler,
            "pndm": PNDMScheduler,
            "lms": LMSDiscreteScheduler,
            "euler": EulerDiscreteScheduler,
            "euler_a": EulerAncestralDiscreteScheduler,
            "euler_ancestral": EulerAncestralDiscreteScheduler,
            "dpm": DPMSolverMultistepScheduler,
            "dpm_solver": DPMSolverMultistepScheduler,
            "dpm_solver_multistep": DPMSolverMultistepScheduler,
            "dpm_solver_singlestep": DPMSolverSinglestepScheduler,
            "heun": HeunDiscreteScheduler,
            "kdpm2": KDPM2DiscreteScheduler,
            "kdpm2_a": KDPM2AncestralDiscreteScheduler,
            "kdpm2_ancestral": KDPM2AncestralDiscreteScheduler,
            "unipc": UniPCMultistepScheduler,
        }

        scheduler_class = schedulers.get(scheduler_name.lower().replace("-", "_"))
        if scheduler_class is None:
            logger.warning("Unknown scheduler: %s", scheduler_name)
            return

        target = self._model if pipeline is None else pipeline
        if target is None or getattr(target, "scheduler", None) is None:
            logger.warning(
                "Cannot set scheduler %s: pipeline has no scheduler", scheduler_name
            )
            return

        target.scheduler = scheduler_class.from_config(target.scheduler.config)

    async def unload(self) -> None:
        """Unload the pipeline and free GPU memory (no-op if nothing is loaded)."""
        if self._model is None:
            return

        # Drop every reference this loader holds to pipeline components.
        self._model = None
        self._scheduler = None
        self._vae = None
        self._model_info = None

        import gc

        # Collect first so objects kept alive only by reference cycles are
        # freed before the CUDA caching allocator is asked to release memory.
        gc.collect()

        try:
            import torch
        except ImportError:
            return

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.debug("Pipeline unloaded and GPU cache cleared")

    def generate(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        width: int = 1024,
        height: int = 1024,
        num_images: int = 1,
        seed: Optional[int] = None,
        **kwargs: Any,
    ) -> List[Any]:
        """
        Generate images from a prompt with the loaded pipeline.

        Args:
            prompt: Text prompt for generation
            negative_prompt: Things to avoid in the image
            num_inference_steps: Number of denoising steps
            guidance_scale: How closely to follow the prompt
            width: Image width
            height: Image height
            num_images: Number of images to generate
            seed: Random seed for reproducibility; the generator is created
                on the device recorded at load time (CPU if unknown)
            **kwargs: Additional pipeline arguments

        Returns:
            List of generated images (the pipeline output's ``images``)

        Raises:
            ValueError: If no pipeline is loaded.
        """
        if self._model is None:
            raise ValueError("No pipeline loaded")

        import torch

        generator = None
        if seed is not None:
            device = self._model_info.device if self._model_info else "cpu"
            generator = torch.Generator(device=device).manual_seed(seed)

        result = self._model(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=width,
            height=height,
            num_images_per_prompt=num_images,
            generator=generator,
            **kwargs,
        )

        return result.images
|