""" Diffusers model loader for Stable Diffusion and SDXL. Loads image generation pipelines from HuggingFace or local paths. """ from pathlib import Path from typing import Optional, Any, Union, Literal, List import time import logging from .base import BaseModelLoader, ModelInfo, ModelLoadError from .device import DeviceManager, get_best_device from .registry import register_loader logger = logging.getLogger(__name__) # Type alias for diffusion pipelines DiffusionPipeline = Any PipelineType = Literal[ "sdxl", "sd15", "sd21", "flux", "kandinsky", "if", "controlnet", "img2img", "inpaint", ] @register_loader("diffusers", aliases=["stable-diffusion", "sdxl", "sd", "flux"]) class DiffusersLoader(BaseModelLoader[DiffusionPipeline]): """ Diffusers model loader for image generation. Supports SDXL, SD 1.5, SD 2.1, Flux, and other diffusion models. Example: >>> loader = DiffusersLoader() >>> # Load SDXL >>> pipeline = await loader.load( ... "stabilityai/stable-diffusion-xl-base-1.0", ... dtype="float16", ... device="cuda:0" ... ) >>> # Generate image >>> image = pipeline("a photo of a cat").images[0] >>> # Load from single file (safetensors/ckpt) >>> pipeline = await loader.load( ... "/path/to/model.safetensors", ... pipeline_type="sdxl" ... ) """ def __init__(self) -> None: super().__init__() self._scheduler: Any = None self._vae: Any = None @property def scheduler(self) -> Any: """Get the current scheduler.""" return self._scheduler @property def vae(self) -> Any: """Get the current VAE.""" return self._vae async def load( self, model_id: str, *, pipeline_type: Optional[PipelineType] = None, device: Optional[str] = None, dtype: Optional[str] = None, enable_attention_slicing: bool = True, enable_vae_slicing: bool = True, enable_vae_tiling: bool = False, enable_model_cpu_offload: bool = False, enable_sequential_cpu_offload: bool = False, use_safetensors: bool = True, variant: Optional[str] = None, custom_pipeline: Optional[str] = None, scheduler: Optional[str] = None, **kwargs: Any, ) -> DiffusionPipeline: """ Load a diffusion model pipeline. Args: model_id: HuggingFace model ID, local path, or URL to single file pipeline_type: Type of pipeline (auto-detected if not specified) device: Device to load on (default: auto-detect) dtype: Data type ("float16", "bfloat16", "float32") enable_attention_slicing: Reduce VRAM usage for attention enable_vae_slicing: Reduce VRAM usage for VAE enable_vae_tiling: Enable tiled VAE for very high resolution enable_model_cpu_offload: Offload model to CPU when not in use enable_sequential_cpu_offload: Aggressively offload to CPU use_safetensors: Prefer safetensors format variant: Model variant (e.g., "fp16") custom_pipeline: Custom pipeline class name scheduler: Scheduler to use (e.g., "DPMSolverMultistep") **kwargs: Additional arguments for pipeline Returns: Loaded diffusion pipeline """ if self._loading: raise ModelLoadError(model_id, "Another load operation is in progress") self._loading = True start_time = time.time() try: import torch from diffusers import ( StableDiffusionPipeline, StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, DiffusionPipeline as BaseDiffusionPipeline, AutoPipelineForText2Image, AutoPipelineForImage2Image, AutoPipelineForInpainting, ) # Determine device if device is None: device = get_best_device() # Determine dtype torch_dtype = torch.float32 if dtype: dtype_map = { "float16": torch.float16, "fp16": torch.float16, "bfloat16": torch.bfloat16, "bf16": torch.bfloat16, "float32": torch.float32, "fp32": torch.float32, } torch_dtype = dtype_map.get(dtype, torch.float32) elif device != "cpu": # Default to float16 on GPU torch_dtype = torch.float16 # Unload existing model if self._model is not None: await self.unload() # Check if it's a single file model_path = Path(model_id) is_single_file = model_path.exists() and model_path.suffix in ( ".safetensors", ".ckpt", ".pt", ) if is_single_file: # Load from single file pipeline_class = self._get_pipeline_class_for_type(pipeline_type) self._model = pipeline_class.from_single_file( str(model_path), torch_dtype=torch_dtype, use_safetensors=use_safetensors, **kwargs, ) else: # Load from HuggingFace or directory load_kwargs = { "torch_dtype": torch_dtype, "use_safetensors": use_safetensors, **kwargs, } if variant: load_kwargs["variant"] = variant if custom_pipeline: load_kwargs["custom_pipeline"] = custom_pipeline # Try auto pipeline first try: if pipeline_type == "img2img": self._model = AutoPipelineForImage2Image.from_pretrained( model_id, **load_kwargs ) elif pipeline_type == "inpaint": self._model = AutoPipelineForInpainting.from_pretrained( model_id, **load_kwargs ) else: self._model = AutoPipelineForText2Image.from_pretrained( model_id, **load_kwargs ) except Exception: # Fall back to specific pipeline class pipeline_class = self._get_pipeline_class_for_type(pipeline_type) self._model = pipeline_class.from_pretrained( model_id, **load_kwargs ) # Move to device if not enable_model_cpu_offload and not enable_sequential_cpu_offload: self._model = self._model.to(device) # Apply memory optimizations if enable_attention_slicing: self._model.enable_attention_slicing() if enable_vae_slicing: self._model.enable_vae_slicing() if enable_vae_tiling: self._model.enable_vae_tiling() if enable_model_cpu_offload: self._model.enable_model_cpu_offload() if enable_sequential_cpu_offload: self._model.enable_sequential_cpu_offload() # Change scheduler if requested if scheduler: self._set_scheduler(scheduler) # Store references if hasattr(self._model, "scheduler"): self._scheduler = self._model.scheduler if hasattr(self._model, "vae"): self._vae = self._model.vae # Calculate memory usage memory_used = 0.0 if device.startswith("cuda") and torch.cuda.is_available(): memory_used = torch.cuda.memory_allocated() / 1024 / 1024 # Store model info self._model_info = ModelInfo( model_id=model_id, path=model_path if is_single_file else None, device=device, dtype=dtype, memory_used_mb=memory_used, load_time_seconds=time.time() - start_time, metadata={ "pipeline_type": pipeline_type, "scheduler": scheduler, "is_single_file": is_single_file, }, ) logger.info( f"Loaded {model_id} on {device} in {self._model_info.load_time_seconds:.2f}s" ) return self._model except ImportError as e: raise ModelLoadError( model_id, "diffusers library not installed. Install with: pip install diffusers", cause=e, ) except Exception as e: raise ModelLoadError(model_id, str(e), cause=e) finally: self._loading = False def _get_pipeline_class_for_type( self, pipeline_type: Optional[PipelineType] ) -> type: """Get the appropriate pipeline class for the type.""" from diffusers import ( StableDiffusionPipeline, StableDiffusionXLPipeline, DiffusionPipeline, ) type_to_class = { "sdxl": StableDiffusionXLPipeline, "sd15": StableDiffusionPipeline, "sd21": StableDiffusionPipeline, } return type_to_class.get(pipeline_type, DiffusionPipeline) def _set_scheduler(self, scheduler_name: str) -> None: """Set a different scheduler.""" from diffusers import ( DDIMScheduler, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, HeunDiscreteScheduler, KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler, UniPCMultistepScheduler, ) schedulers = { "ddim": DDIMScheduler, "ddpm": DDPMScheduler, "pndm": PNDMScheduler, "lms": LMSDiscreteScheduler, "euler": EulerDiscreteScheduler, "euler_a": EulerAncestralDiscreteScheduler, "euler_ancestral": EulerAncestralDiscreteScheduler, "dpm": DPMSolverMultistepScheduler, "dpm_solver": DPMSolverMultistepScheduler, "dpm_solver_multistep": DPMSolverMultistepScheduler, "dpm_solver_singlestep": DPMSolverSinglestepScheduler, "heun": HeunDiscreteScheduler, "kdpm2": KDPM2DiscreteScheduler, "kdpm2_a": KDPM2AncestralDiscreteScheduler, "kdpm2_ancestral": KDPM2AncestralDiscreteScheduler, "unipc": UniPCMultistepScheduler, } scheduler_key = scheduler_name.lower().replace("-", "_") if scheduler_key in schedulers: scheduler_class = schedulers[scheduler_key] self._model.scheduler = scheduler_class.from_config( self._model.scheduler.config ) else: logger.warning(f"Unknown scheduler: {scheduler_name}") async def unload(self) -> None: """Unload the pipeline and free GPU memory.""" if self._model is None: return try: import torch # Delete pipeline and components del self._model self._model = None self._scheduler = None self._vae = None # Clear GPU cache if torch.cuda.is_available(): torch.cuda.empty_cache() self._model_info = None logger.debug("Pipeline unloaded and GPU cache cleared") except ImportError: self._model = None self._scheduler = None self._vae = None self._model_info = None def generate( self, prompt: str, negative_prompt: Optional[str] = None, num_inference_steps: int = 30, guidance_scale: float = 7.5, width: int = 1024, height: int = 1024, num_images: int = 1, seed: Optional[int] = None, **kwargs: Any, ) -> List[Any]: """ Generate images from a prompt. Args: prompt: Text prompt for generation negative_prompt: Things to avoid in the image num_inference_steps: Number of denoising steps guidance_scale: How closely to follow the prompt width: Image width height: Image height num_images: Number of images to generate seed: Random seed for reproducibility **kwargs: Additional pipeline arguments Returns: List of generated PIL images """ if not self._model: raise ValueError("No pipeline loaded") import torch generator = None if seed is not None: device = self._model_info.device if self._model_info else "cpu" generator = torch.Generator(device=device).manual_seed(seed) result = self._model( prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, width=width, height=height, num_images_per_prompt=num_images, generator=generator, **kwargs, ) return result.images