| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import inspect |
| | from typing import Any, Callable, Dict, List, Optional, Union |
| |
|
| | import intel_extension_for_pytorch as ipex |
| | import torch |
| | from packaging import version |
| | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer |
| |
|
| | from diffusers.configuration_utils import FrozenDict |
| | from diffusers.models import AutoencoderKL, UNet2DConditionModel |
| | from diffusers.pipeline_utils import DiffusionPipeline |
| | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput |
| | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker |
| | from diffusers.schedulers import KarrasDiffusionSchedulers |
| | from diffusers.utils import ( |
| | deprecate, |
| | is_accelerate_available, |
| | is_accelerate_version, |
| | logging, |
| | randn_tensor, |
| | replace_example_docstring, |
| | ) |
| |
|
| |
|
# Module-wide logger created through diffusers' logging utilities.
logger = logging.get_logger(name=__name__)
| |
|
# Usage example injected into the `__call__` docstring through `@replace_example_docstring`.
# The height/width passed to `prepare_for_ipex()` must match the ones later used for inference.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import DiffusionPipeline

        >>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex")
        >>> prompt = "a photo of an astronaut riding a horse on mars"
        >>> num_inference_steps = 20

        >>> # Optimize once, either for Float32 ...
        >>> pipe.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512)  # height/width must match the inference call
        >>> # ... or for BFloat16
        >>> pipe.prepare_for_ipex(prompt, dtype=torch.bfloat16, height=512, width=512)  # height/width must match the inference call

        >>> # Float32 inference
        >>> image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images[0]  # height/width must match 'prepare_for_ipex()'
        >>> # BFloat16 inference
        >>> with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
        ...     image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images[0]  # height/width must match 'prepare_for_ipex()'
        ```
"""
| |
|
| |
|
| | class StableDiffusionIPEXPipeline(DiffusionPipeline): |
| | r""" |
| | Pipeline for text-to-image generation using Stable Diffusion on IPEX. |
| | |
| | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the |
| | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) |
| | |
| | Args: |
| | vae ([`AutoencoderKL`]): |
| | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. |
| | text_encoder ([`CLIPTextModel`]): |
| | Frozen text-encoder. Stable Diffusion uses the text portion of |
| | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically |
| | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. |
| | tokenizer (`CLIPTokenizer`): |
| | Tokenizer of class |
| | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). |
| | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. |
| | scheduler ([`SchedulerMixin`]): |
| | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of |
| | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. |
| | safety_checker ([`StableDiffusionSafetyChecker`]): |
| | Classification module that estimates whether generated images could be considered offensive or harmful. |
| | Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. |
| | feature_extractor ([`CLIPFeatureExtractor`]): |
| | Model that extracts features from generated images to be used as inputs for the `safety_checker`. |
| | """ |
| | _optional_components = ["safety_checker", "feature_extractor"] |
| |
|
| | def __init__( |
| | self, |
| | vae: AutoencoderKL, |
| | text_encoder: CLIPTextModel, |
| | tokenizer: CLIPTokenizer, |
| | unet: UNet2DConditionModel, |
| | scheduler: KarrasDiffusionSchedulers, |
| | safety_checker: StableDiffusionSafetyChecker, |
| | feature_extractor: CLIPFeatureExtractor, |
| | requires_safety_checker: bool = True, |
| | ): |
| | super().__init__() |
| |
|
| | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: |
| | deprecation_message = ( |
| | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" |
| | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " |
| | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" |
| | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," |
| | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" |
| | " file" |
| | ) |
| | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) |
| | new_config = dict(scheduler.config) |
| | new_config["steps_offset"] = 1 |
| | scheduler._internal_dict = FrozenDict(new_config) |
| |
|
| | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: |
| | deprecation_message = ( |
| | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." |
| | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" |
| | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" |
| | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" |
| | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" |
| | ) |
| | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) |
| | new_config = dict(scheduler.config) |
| | new_config["clip_sample"] = False |
| | scheduler._internal_dict = FrozenDict(new_config) |
| |
|
| | if safety_checker is None and requires_safety_checker: |
| | logger.warning( |
| | f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" |
| | " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" |
| | " results in services or applications open to the public. Both the diffusers team and Hugging Face" |
| | " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" |
| | " it only for use-cases that involve analyzing network behavior or auditing its results. For more" |
| | " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." |
| | ) |
| |
|
| | if safety_checker is not None and feature_extractor is None: |
| | raise ValueError( |
| | "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" |
| | " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." |
| | ) |
| |
|
| | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( |
| | version.parse(unet.config._diffusers_version).base_version |
| | ) < version.parse("0.9.0.dev0") |
| | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 |
| | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: |
| | deprecation_message = ( |
| | "The configuration file of the unet has set the default `sample_size` to smaller than" |
| | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" |
| | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" |
| | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" |
| | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" |
| | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" |
| | " in the config might lead to incorrect results in future versions. If you have downloaded this" |
| | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" |
| | " the `unet/config.json` file" |
| | ) |
| | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) |
| | new_config = dict(unet.config) |
| | new_config["sample_size"] = 64 |
| | unet._internal_dict = FrozenDict(new_config) |
| |
|
| | self.register_modules( |
| | vae=vae, |
| | text_encoder=text_encoder, |
| | tokenizer=tokenizer, |
| | unet=unet, |
| | scheduler=scheduler, |
| | safety_checker=safety_checker, |
| | feature_extractor=feature_extractor, |
| | ) |
| | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) |
| | self.register_to_config(requires_safety_checker=requires_safety_checker) |
| |
|
| | def get_input_example(self, prompt, height=None, width=None, guidance_scale=7.5, num_images_per_prompt=1): |
| | prompt_embeds = None |
| | negative_prompt_embeds = None |
| | negative_prompt = None |
| | callback_steps = 1 |
| | generator = None |
| | latents = None |
| |
|
| | |
| | height = height or self.unet.config.sample_size * self.vae_scale_factor |
| | width = width or self.unet.config.sample_size * self.vae_scale_factor |
| |
|
| | |
| | self.check_inputs( |
| | prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds |
| | ) |
| |
|
| | |
| | if prompt is not None and isinstance(prompt, str): |
| | batch_size = 1 |
| | elif prompt is not None and isinstance(prompt, list): |
| | batch_size = len(prompt) |
| |
|
| | device = "cpu" |
| | |
| | |
| | |
| | do_classifier_free_guidance = guidance_scale > 1.0 |
| |
|
| | |
| | prompt_embeds = self._encode_prompt( |
| | prompt, |
| | device, |
| | num_images_per_prompt, |
| | do_classifier_free_guidance, |
| | negative_prompt, |
| | prompt_embeds=prompt_embeds, |
| | negative_prompt_embeds=negative_prompt_embeds, |
| | ) |
| |
|
| | |
| | latents = self.prepare_latents( |
| | batch_size * num_images_per_prompt, |
| | self.unet.in_channels, |
| | height, |
| | width, |
| | prompt_embeds.dtype, |
| | device, |
| | generator, |
| | latents, |
| | ) |
| | dummy = torch.ones(1, dtype=torch.int32) |
| | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
| | latent_model_input = self.scheduler.scale_model_input(latent_model_input, dummy) |
| |
|
| | unet_input_example = (latent_model_input, dummy, prompt_embeds) |
| | vae_decoder_input_example = latents |
| |
|
| | return unet_input_example, vae_decoder_input_example |
| |
|
| | def prepare_for_ipex(self, promt, dtype=torch.float32, height=None, width=None, guidance_scale=7.5): |
| | self.unet = self.unet.to(memory_format=torch.channels_last) |
| | self.vae.decoder = self.vae.decoder.to(memory_format=torch.channels_last) |
| | self.text_encoder = self.text_encoder.to(memory_format=torch.channels_last) |
| | if self.safety_checker is not None: |
| | self.safety_checker = self.safety_checker.to(memory_format=torch.channels_last) |
| |
|
| | unet_input_example, vae_decoder_input_example = self.get_input_example(promt, height, width, guidance_scale) |
| |
|
| | |
| | if dtype == torch.bfloat16: |
| | self.unet = ipex.optimize( |
| | self.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=unet_input_example |
| | ) |
| | self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.bfloat16, inplace=True) |
| | self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.bfloat16, inplace=True) |
| | if self.safety_checker is not None: |
| | self.safety_checker = ipex.optimize(self.safety_checker.eval(), dtype=torch.bfloat16, inplace=True) |
| | elif dtype == torch.float32: |
| | self.unet = ipex.optimize( |
| | self.unet.eval(), |
| | dtype=torch.float32, |
| | inplace=True, |
| | sample_input=unet_input_example, |
| | level="O1", |
| | weights_prepack=True, |
| | auto_kernel_selection=False, |
| | ) |
| | self.vae.decoder = ipex.optimize( |
| | self.vae.decoder.eval(), |
| | dtype=torch.float32, |
| | inplace=True, |
| | level="O1", |
| | weights_prepack=True, |
| | auto_kernel_selection=False, |
| | ) |
| | self.text_encoder = ipex.optimize( |
| | self.text_encoder.eval(), |
| | dtype=torch.float32, |
| | inplace=True, |
| | level="O1", |
| | weights_prepack=True, |
| | auto_kernel_selection=False, |
| | ) |
| | if self.safety_checker is not None: |
| | self.safety_checker = ipex.optimize( |
| | self.safety_checker.eval(), |
| | dtype=torch.float32, |
| | inplace=True, |
| | level="O1", |
| | weights_prepack=True, |
| | auto_kernel_selection=False, |
| | ) |
| | else: |
| | raise ValueError(" The value of 'dtype' should be 'torch.bfloat16' or 'torch.float32' !") |
| |
|
| | |
| | with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad(): |
| | unet_trace_model = torch.jit.trace(self.unet, unet_input_example, check_trace=False, strict=False) |
| | unet_trace_model = torch.jit.freeze(unet_trace_model) |
| | self.unet.forward = unet_trace_model.forward |
| |
|
| | |
| | with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad(): |
| | ave_decoder_trace_model = torch.jit.trace( |
| | self.vae.decoder, vae_decoder_input_example, check_trace=False, strict=False |
| | ) |
| | ave_decoder_trace_model = torch.jit.freeze(ave_decoder_trace_model) |
| | self.vae.decoder.forward = ave_decoder_trace_model.forward |
| |
|
| | def enable_vae_slicing(self): |
| | r""" |
| | Enable sliced VAE decoding. |
| | |
| | When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several |
| | steps. This is useful to save some memory and allow larger batch sizes. |
| | """ |
| | self.vae.enable_slicing() |
| |
|
| | def disable_vae_slicing(self): |
| | r""" |
| | Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to |
| | computing decoding in one step. |
| | """ |
| | self.vae.disable_slicing() |
| |
|
| | def enable_vae_tiling(self): |
| | r""" |
| | Enable tiled VAE decoding. |
| | |
| | When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in |
| | several steps. This is useful to save a large amount of memory and to allow the processing of larger images. |
| | """ |
| | self.vae.enable_tiling() |
| |
|
| | def disable_vae_tiling(self): |
| | r""" |
| | Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to |
| | computing decoding in one step. |
| | """ |
| | self.vae.disable_tiling() |
| |
|
| | def enable_sequential_cpu_offload(self, gpu_id=0): |
| | r""" |
| | Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, |
| | text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a |
| | `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. |
| | Note that offloading happens on a submodule basis. Memory savings are higher than with |
| | `enable_model_cpu_offload`, but performance is lower. |
| | """ |
| | if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): |
| | from accelerate import cpu_offload |
| | else: |
| | raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") |
| |
|
| | device = torch.device(f"cuda:{gpu_id}") |
| |
|
| | if self.device.type != "cpu": |
| | self.to("cpu", silence_dtype_warnings=True) |
| | torch.cuda.empty_cache() |
| |
|
| | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: |
| | cpu_offload(cpu_offloaded_model, device) |
| |
|
| | if self.safety_checker is not None: |
| | cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) |
| |
|
| | def enable_model_cpu_offload(self, gpu_id=0): |
| | r""" |
| | Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared |
| | to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` |
| | method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with |
| | `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. |
| | """ |
| | if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): |
| | from accelerate import cpu_offload_with_hook |
| | else: |
| | raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") |
| |
|
| | device = torch.device(f"cuda:{gpu_id}") |
| |
|
| | if self.device.type != "cpu": |
| | self.to("cpu", silence_dtype_warnings=True) |
| | torch.cuda.empty_cache() |
| |
|
| | hook = None |
| | for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: |
| | _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) |
| |
|
| | if self.safety_checker is not None: |
| | _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) |
| |
|
| | |
| | self.final_offload_hook = hook |
| |
|
| | @property |
| | def _execution_device(self): |
| | r""" |
| | Returns the device on which the pipeline's models will be executed. After calling |
| | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module |
| | hooks. |
| | """ |
| | if not hasattr(self.unet, "_hf_hook"): |
| | return self.device |
| | for module in self.unet.modules(): |
| | if ( |
| | hasattr(module, "_hf_hook") |
| | and hasattr(module._hf_hook, "execution_device") |
| | and module._hf_hook.execution_device is not None |
| | ): |
| | return torch.device(module._hf_hook.execution_device) |
| | return self.device |
| |
|
| | def _encode_prompt( |
| | self, |
| | prompt, |
| | device, |
| | num_images_per_prompt, |
| | do_classifier_free_guidance, |
| | negative_prompt=None, |
| | prompt_embeds: Optional[torch.FloatTensor] = None, |
| | negative_prompt_embeds: Optional[torch.FloatTensor] = None, |
| | ): |
| | r""" |
| | Encodes the prompt into text encoder hidden states. |
| | |
| | Args: |
| | prompt (`str` or `List[str]`, *optional*): |
| | prompt to be encoded |
| | device: (`torch.device`): |
| | torch device |
| | num_images_per_prompt (`int`): |
| | number of images that should be generated per prompt |
| | do_classifier_free_guidance (`bool`): |
| | whether to use classifier free guidance or not |
| | negative_prompt (`str` or `List[str]`, *optional*): |
| | The prompt or prompts not to guide the image generation. If not defined, one has to pass |
| | `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. |
| | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). |
| | prompt_embeds (`torch.FloatTensor`, *optional*): |
| | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not |
| | provided, text embeddings will be generated from `prompt` input argument. |
| | negative_prompt_embeds (`torch.FloatTensor`, *optional*): |
| | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt |
| | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input |
| | argument. |
| | """ |
| | if prompt is not None and isinstance(prompt, str): |
| | batch_size = 1 |
| | elif prompt is not None and isinstance(prompt, list): |
| | batch_size = len(prompt) |
| | else: |
| | batch_size = prompt_embeds.shape[0] |
| |
|
| | if prompt_embeds is None: |
| | text_inputs = self.tokenizer( |
| | prompt, |
| | padding="max_length", |
| | max_length=self.tokenizer.model_max_length, |
| | truncation=True, |
| | return_tensors="pt", |
| | ) |
| | text_input_ids = text_inputs.input_ids |
| | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids |
| |
|
| | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( |
| | text_input_ids, untruncated_ids |
| | ): |
| | removed_text = self.tokenizer.batch_decode( |
| | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] |
| | ) |
| | logger.warning( |
| | "The following part of your input was truncated because CLIP can only handle sequences up to" |
| | f" {self.tokenizer.model_max_length} tokens: {removed_text}" |
| | ) |
| |
|
| | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: |
| | attention_mask = text_inputs.attention_mask.to(device) |
| | else: |
| | attention_mask = None |
| |
|
| | prompt_embeds = self.text_encoder( |
| | text_input_ids.to(device), |
| | attention_mask=attention_mask, |
| | ) |
| | prompt_embeds = prompt_embeds[0] |
| |
|
| | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) |
| |
|
| | bs_embed, seq_len, _ = prompt_embeds.shape |
| | |
| | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) |
| | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) |
| |
|
| | |
| | if do_classifier_free_guidance and negative_prompt_embeds is None: |
| | uncond_tokens: List[str] |
| | if negative_prompt is None: |
| | uncond_tokens = [""] * batch_size |
| | elif type(prompt) is not type(negative_prompt): |
| | raise TypeError( |
| | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" |
| | f" {type(prompt)}." |
| | ) |
| | elif isinstance(negative_prompt, str): |
| | uncond_tokens = [negative_prompt] |
| | elif batch_size != len(negative_prompt): |
| | raise ValueError( |
| | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" |
| | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" |
| | " the batch size of `prompt`." |
| | ) |
| | else: |
| | uncond_tokens = negative_prompt |
| |
|
| | max_length = prompt_embeds.shape[1] |
| | uncond_input = self.tokenizer( |
| | uncond_tokens, |
| | padding="max_length", |
| | max_length=max_length, |
| | truncation=True, |
| | return_tensors="pt", |
| | ) |
| |
|
| | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: |
| | attention_mask = uncond_input.attention_mask.to(device) |
| | else: |
| | attention_mask = None |
| |
|
| | negative_prompt_embeds = self.text_encoder( |
| | uncond_input.input_ids.to(device), |
| | attention_mask=attention_mask, |
| | ) |
| | negative_prompt_embeds = negative_prompt_embeds[0] |
| |
|
| | if do_classifier_free_guidance: |
| | |
| | seq_len = negative_prompt_embeds.shape[1] |
| |
|
| | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) |
| |
|
| | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) |
| | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) |
| |
|
| | |
| | |
| | |
| | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) |
| |
|
| | return prompt_embeds |
| |
|
| | def run_safety_checker(self, image, device, dtype): |
| | if self.safety_checker is not None: |
| | safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) |
| | image, has_nsfw_concept = self.safety_checker( |
| | images=image, clip_input=safety_checker_input.pixel_values.to(dtype) |
| | ) |
| | else: |
| | has_nsfw_concept = None |
| | return image, has_nsfw_concept |
| |
|
| | def decode_latents(self, latents): |
| | latents = 1 / self.vae.config.scaling_factor * latents |
| | image = self.vae.decode(latents).sample |
| | image = (image / 2 + 0.5).clamp(0, 1) |
| | |
| | image = image.cpu().permute(0, 2, 3, 1).float().numpy() |
| | return image |
| |
|
| | def prepare_extra_step_kwargs(self, generator, eta): |
| | |
| | |
| | |
| | |
| |
|
| | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) |
| | extra_step_kwargs = {} |
| | if accepts_eta: |
| | extra_step_kwargs["eta"] = eta |
| |
|
| | |
| | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) |
| | if accepts_generator: |
| | extra_step_kwargs["generator"] = generator |
| | return extra_step_kwargs |
| |
|
| | def check_inputs( |
| | self, |
| | prompt, |
| | height, |
| | width, |
| | callback_steps, |
| | negative_prompt=None, |
| | prompt_embeds=None, |
| | negative_prompt_embeds=None, |
| | ): |
| | if height % 8 != 0 or width % 8 != 0: |
| | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") |
| |
|
| | if (callback_steps is None) or ( |
| | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) |
| | ): |
| | raise ValueError( |
| | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" |
| | f" {type(callback_steps)}." |
| | ) |
| |
|
| | if prompt is not None and prompt_embeds is not None: |
| | raise ValueError( |
| | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" |
| | " only forward one of the two." |
| | ) |
| | elif prompt is None and prompt_embeds is None: |
| | raise ValueError( |
| | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." |
| | ) |
| | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): |
| | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") |
| |
|
| | if negative_prompt is not None and negative_prompt_embeds is not None: |
| | raise ValueError( |
| | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" |
| | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." |
| | ) |
| |
|
| | if prompt_embeds is not None and negative_prompt_embeds is not None: |
| | if prompt_embeds.shape != negative_prompt_embeds.shape: |
| | raise ValueError( |
| | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" |
| | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" |
| | f" {negative_prompt_embeds.shape}." |
| | ) |
| |
|
| | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): |
| | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) |
| | if isinstance(generator, list) and len(generator) != batch_size: |
| | raise ValueError( |
| | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" |
| | f" size of {batch_size}. Make sure the batch size matches the length of the generators." |
| | ) |
| |
|
| | if latents is None: |
| | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) |
| | else: |
| | latents = latents.to(device) |
| |
|
| | |
| | latents = latents * self.scheduler.init_noise_sigma |
| | return latents |
| |
|
| | @torch.no_grad() |
| | @replace_example_docstring(EXAMPLE_DOC_STRING) |
| | def __call__( |
| | self, |
| | prompt: Union[str, List[str]] = None, |
| | height: Optional[int] = None, |
| | width: Optional[int] = None, |
| | num_inference_steps: int = 50, |
| | guidance_scale: float = 7.5, |
| | negative_prompt: Optional[Union[str, List[str]]] = None, |
| | num_images_per_prompt: Optional[int] = 1, |
| | eta: float = 0.0, |
| | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
| | latents: Optional[torch.FloatTensor] = None, |
| | prompt_embeds: Optional[torch.FloatTensor] = None, |
| | negative_prompt_embeds: Optional[torch.FloatTensor] = None, |
| | output_type: Optional[str] = "pil", |
| | return_dict: bool = True, |
| | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
| | callback_steps: int = 1, |
| | cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
| | ): |
| | r""" |
| | Function invoked when calling the pipeline for generation. |
| | |
| | Args: |
| | prompt (`str` or `List[str]`, *optional*): |
| | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. |
| | instead. |
| | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): |
| | The height in pixels of the generated image. |
| | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): |
| | The width in pixels of the generated image. |
| | num_inference_steps (`int`, *optional*, defaults to 50): |
| | The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
| | expense of slower inference. |
| | guidance_scale (`float`, *optional*, defaults to 7.5): |
| | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). |
| | `guidance_scale` is defined as `w` of equation 2. of [Imagen |
| | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > |
| | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, |
| | usually at the expense of lower image quality. |
| | negative_prompt (`str` or `List[str]`, *optional*): |
| | The prompt or prompts not to guide the image generation. If not defined, one has to pass |
| | `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. |
| | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). |
| | num_images_per_prompt (`int`, *optional*, defaults to 1): |
| | The number of images to generate per prompt. |
| | eta (`float`, *optional*, defaults to 0.0): |
| | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to |
| | [`schedulers.DDIMScheduler`], will be ignored for others. |
| | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): |
| | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) |
| | to make generation deterministic. |
| | latents (`torch.FloatTensor`, *optional*): |
| | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image |
| | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents |
| | tensor will ge generated by sampling using the supplied random `generator`. |
| | prompt_embeds (`torch.FloatTensor`, *optional*): |
| | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not |
| | provided, text embeddings will be generated from `prompt` input argument. |
| | negative_prompt_embeds (`torch.FloatTensor`, *optional*): |
| | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt |
| | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input |
| | argument. |
| | output_type (`str`, *optional*, defaults to `"pil"`): |
| | The output format of the generate image. Choose between |
| | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. |
| | return_dict (`bool`, *optional*, defaults to `True`): |
| | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a |
| | plain tuple. |
| | callback (`Callable`, *optional*): |
| | A function that will be called every `callback_steps` steps during inference. The function will be |
| | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. |
| | callback_steps (`int`, *optional*, defaults to 1): |
| | The frequency at which the `callback` function will be called. If not specified, the callback will be |
| | called at every step. |
| | cross_attention_kwargs (`dict`, *optional*): |
| | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under |
| | `self.processor` in |
| | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). |
| | |
| | Examples: |
| | |
| | Returns: |
| | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: |
| | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. |
| | When returning a tuple, the first element is a list with the generated images, and the second element is a |
| | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" |
| | (nsfw) content, according to the `safety_checker`. |
| | """ |
| | |
| | height = height or self.unet.config.sample_size * self.vae_scale_factor |
| | width = width or self.unet.config.sample_size * self.vae_scale_factor |
| |
|
| | |
| | self.check_inputs( |
| | prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds |
| | ) |
| |
|
| | |
| | if prompt is not None and isinstance(prompt, str): |
| | batch_size = 1 |
| | elif prompt is not None and isinstance(prompt, list): |
| | batch_size = len(prompt) |
| | else: |
| | batch_size = prompt_embeds.shape[0] |
| |
|
| | device = self._execution_device |
| | |
| | |
| | |
| | do_classifier_free_guidance = guidance_scale > 1.0 |
| |
|
| | |
| | prompt_embeds = self._encode_prompt( |
| | prompt, |
| | device, |
| | num_images_per_prompt, |
| | do_classifier_free_guidance, |
| | negative_prompt, |
| | prompt_embeds=prompt_embeds, |
| | negative_prompt_embeds=negative_prompt_embeds, |
| | ) |
| |
|
| | |
| | self.scheduler.set_timesteps(num_inference_steps, device=device) |
| | timesteps = self.scheduler.timesteps |
| |
|
| | |
| | num_channels_latents = self.unet.in_channels |
| | latents = self.prepare_latents( |
| | batch_size * num_images_per_prompt, |
| | num_channels_latents, |
| | height, |
| | width, |
| | prompt_embeds.dtype, |
| | device, |
| | generator, |
| | latents, |
| | ) |
| |
|
| | |
| | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
| |
|
| | |
| | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order |
| | with self.progress_bar(total=num_inference_steps) as progress_bar: |
| | for i, t in enumerate(timesteps): |
| | |
| | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
| | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
| |
|
| | |
| | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds)["sample"] |
| |
|
| | |
| | if do_classifier_free_guidance: |
| | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
| | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
| |
|
| | |
| | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
| |
|
| | |
| | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
| | progress_bar.update() |
| | if callback is not None and i % callback_steps == 0: |
| | callback(i, t, latents) |
| |
|
| | if output_type == "latent": |
| | image = latents |
| | has_nsfw_concept = None |
| | elif output_type == "pil": |
| | |
| | image = self.decode_latents(latents) |
| |
|
| | |
| | image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) |
| |
|
| | |
| | image = self.numpy_to_pil(image) |
| | else: |
| | |
| | image = self.decode_latents(latents) |
| |
|
| | |
| | image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) |
| |
|
| | |
| | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: |
| | self.final_offload_hook.offload() |
| |
|
| | if not return_dict: |
| | return (image, has_nsfw_concept) |
| |
|
| | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) |
| |
|