Merge branch 'main' into torchao-quantizer
a-r-r-o-w authored Dec 15, 2024
2 parents 29ec905 + 96a9097 commit 7ca64fd
Showing 25 changed files with 708 additions and 411 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/nightly_tests.yml
@@ -238,12 +238,13 @@ jobs:
run_flax_tpu_tests:
name: Nightly Flax TPU Tests
- runs-on: docker-tpu
+ runs-on:
+   group: gcp-ct5lp-hightpu-8t
if: github.event_name == 'schedule'

container:
image: diffusers/diffusers-flax-tpu
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --privileged
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
defaults:
run:
shell: bash
@@ -519,4 +520,4 @@ jobs:
# if: always()
# run: |
# pip install slack_sdk tabulate
- # python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
\ No newline at end of file
+ # python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
6 changes: 3 additions & 3 deletions .github/workflows/push_tests.yml
@@ -161,11 +161,11 @@ jobs:

flax_tpu_tests:
name: Flax TPU Tests
- runs-on: docker-tpu
+ runs-on:
+   group: gcp-ct5lp-hightpu-8t
container:
image: diffusers/diffusers-flax-tpu
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged
defaults:
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache defaults:
run:
shell: bash
steps:
176 changes: 149 additions & 27 deletions examples/community/README_community_scripts.md
@@ -241,7 +241,45 @@ from diffusers import StableDiffusionPipeline
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
from diffusers.configuration_utils import register_to_config
import torch
- from typing import Any, Dict, Optional
+ from typing import Any, Dict, Tuple, Union


class SDPromptSchedulingCallback(PipelineCallback):
@register_to_config
def __init__(
self,
encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
cutoff_step_ratio=None,
cutoff_step_index=None,
):
super().__init__(
cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
)

tensor_inputs = ["prompt_embeds"]

def callback_fn(
self, pipeline, step_index, timestep, callback_kwargs
) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
if isinstance(self.config.encoded_prompt, tuple):
prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
else:
prompt_embeds = self.config.encoded_prompt

# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
cutoff_step_index
if cutoff_step_index is not None
else int(pipeline.num_timesteps * cutoff_step_ratio)
)

if step_index == cutoff_step:
if pipeline.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
return callback_kwargs


pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
Expand All @@ -253,28 +291,73 @@ pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
pipeline.safety_checker = None
pipeline.requires_safety_checker = False

callback = MultiPipelineCallbacks(
[
SDPromptSchedulingCallback(
encoded_prompt=pipeline.encode_prompt(
prompt=f"prompt {index}",
negative_prompt=f"negative prompt {index}",
device=pipeline._execution_device,
num_images_per_prompt=1,
# pipeline.do_classifier_free_guidance can't be accessed until after the pipeline has run
do_classifier_free_guidance=True,
),
cutoff_step_index=index,
) for index in range(1, 20)
]
)

image = pipeline(
prompt="prompt"
negative_prompt="negative prompt",
callback_on_step_end=callback,
callback_on_step_end_tensor_inputs=["prompt_embeds"],
).images[0]
torch.cuda.empty_cache()
image.save('image.png')
```
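The scheduling works because diffusers-style pipelines hand the tensors named in `callback_on_step_end_tensor_inputs` to the callback at the end of every denoising step and adopt whatever tensors come back. A toy sketch of that contract (illustration only, with hypothetical names; the real pipeline loop does far more):

```python
import torch

# Toy sketch of the step-end callback contract (not the actual diffusers loop):
# after each step the pipeline passes the requested tensors to the callback
# and adopts whatever the callback returns.
def run_toy_loop(num_steps, prompt_embeds, callback):
    for i in range(num_steps):
        # ... the real loop would denoise latents here ...
        callback_kwargs = {"prompt_embeds": prompt_embeds}
        callback_kwargs = callback(None, i, i, callback_kwargs)  # (pipeline, step_index, timestep, kwargs)
        prompt_embeds = callback_kwargs["prompt_embeds"]
        print(f"step {i}: prompt_embeds mean {prompt_embeds.mean().item():.1f}")

def swap_at_step_3(pipeline, step_index, timestep, callback_kwargs):
    # Mimics the cutoff logic above: replace the embeddings at one chosen step.
    if step_index == 3:
        callback_kwargs["prompt_embeds"] = torch.ones_like(callback_kwargs["prompt_embeds"])
    return callback_kwargs

run_toy_loop(6, torch.zeros(1, 77, 768), swap_at_step_3)  # mean flips from 0.0 to 1.0 at step 3
```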

- class SDPromptScheduleCallback(PipelineCallback):
```python
from diffusers import StableDiffusionXLPipeline
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
from diffusers.configuration_utils import register_to_config
import torch
from typing import Any, Dict, Tuple, Union


class SDXLPromptSchedulingCallback(PipelineCallback):
@register_to_config
def __init__(
self,
- prompt: str,
- negative_prompt: Optional[str] = None,
- num_images_per_prompt: int = 1,
- cutoff_step_ratio=1.0,
+ encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+ add_text_embeds: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+ add_time_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+ cutoff_step_ratio=None,
cutoff_step_index=None,
):
super().__init__(
cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
)

tensor_inputs = ["prompt_embeds"]
tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]

def callback_fn(
self, pipeline, step_index, timestep, callback_kwargs
) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index
if isinstance(self.config.encoded_prompt, tuple):
prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
else:
prompt_embeds = self.config.encoded_prompt
if isinstance(self.config.add_text_embeds, tuple):
add_text_embeds, negative_add_text_embeds = self.config.add_text_embeds
else:
add_text_embeds = self.config.add_text_embeds
if isinstance(self.config.add_time_ids, tuple):
add_time_ids, negative_add_time_ids = self.config.add_time_ids
else:
add_time_ids = self.config.add_time_ids

# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
@@ -284,34 +367,73 @@ class SDPromptScheduleCallback(PipelineCallback):
)

if step_index == cutoff_step:
- prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
-     prompt=self.config.prompt,
-     negative_prompt=self.config.negative_prompt,
-     device=pipeline._execution_device,
-     num_images_per_prompt=self.config.num_images_per_prompt,
-     do_classifier_free_guidance=pipeline.do_classifier_free_guidance,
- )
if pipeline.do_classifier_free_guidance:
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+     add_text_embeds = torch.cat([negative_add_text_embeds, add_text_embeds])
+     add_time_ids = torch.cat([negative_add_time_ids, add_time_ids])
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+ callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
+ callback_kwargs[self.tensor_inputs[2]] = add_time_ids
return callback_kwargs

- callback = MultiPipelineCallbacks(
-     [
-         SDPromptScheduleCallback(
-             prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
-             negative_prompt="Deformed, ugly, bad anatomy",
-             cutoff_step_ratio=0.25,

pipeline: StableDiffusionXLPipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True,
).to("cuda")

callbacks = []
for index in range(1, 20):
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipeline.encode_prompt(
prompt=f"prompt {index}",
negative_prompt=f"prompt {index}",
device=pipeline._execution_device,
num_images_per_prompt=1,
# pipeline.do_classifier_free_guidance can't be accessed until after the pipeline has run
do_classifier_free_guidance=True,
)
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
add_time_ids = pipeline._get_add_time_ids(
(1024, 1024),
(0, 0),
(1024, 1024),
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
negative_add_time_ids = pipeline._get_add_time_ids(
(1024, 1024),
(0, 0),
(1024, 1024),
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
callbacks.append(
SDXLPromptSchedulingCallback(
encoded_prompt=(prompt_embeds, negative_prompt_embeds),
add_text_embeds=(pooled_prompt_embeds, negative_pooled_prompt_embeds),
add_time_ids=(add_time_ids, negative_add_time_ids),
cutoff_step_index=index,
)
- ]
- )
)


callback = MultiPipelineCallbacks(callbacks)

image = pipeline(
prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
negative_prompt="Deformed, ugly, bad anatomy",
prompt="prompt",
negative_prompt="negative prompt",
callback_on_step_end=callback,
callback_on_step_end_tensor_inputs=["prompt_embeds"],
callback_on_step_end_tensor_inputs=[
"prompt_embeds",
"add_text_embeds",
"add_time_ids",
],
).images[0]
torch.cuda.empty_cache()
image.save('image.png')
```
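The SDXL variant has to schedule more than `prompt_embeds` because the SDXL UNet is also conditioned on pooled text embeddings (`add_text_embeds`) and on size/crop micro-conditioning (`add_time_ids`). The time IDs are just `(original_height, original_width, crop_top, crop_left, target_height, target_width)` packed into a tensor; below is a simplified stand-in for what `_get_add_time_ids` returns for the 1024x1024 settings above (a sketch that skips the projection-dim validation the real helper performs):

```python
import torch

# Simplified stand-in for SDXL's _get_add_time_ids (illustration only):
# it concatenates original size, crop coordinates, and target size.
def make_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=torch.float16):
    ids = list(original_size) + list(crops_coords_top_left) + list(target_size)
    return torch.tensor([ids], dtype=dtype)

print(make_add_time_ids((1024, 1024), (0, 0), (1024, 1024)))
# tensor([[1024., 1024., 0., 0., 1024., 1024.]], dtype=torch.float16)
```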
11 changes: 5 additions & 6 deletions examples/community/pipeline_flux_rf_inversion.py
@@ -648,6 +648,8 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
eta: float = 1.0,
+ decay_eta: Optional[bool] = False,
+ eta_decay_power: Optional[float] = 1.0,
strength: float = 1.0,
start_timestep: float = 0,
stop_timestep: float = 0.25,
@@ -880,12 +882,9 @@ def __call__(
v_t = -noise_pred
v_t_cond = (y_0 - latents) / (1 - t_i)
- if start_timestep <= i < stop_timestep:
-     # controlled vector field
-     v_hat_t = v_t + eta * (v_t_cond - v_t)
-
- else:
-     v_hat_t = v_t
+ eta_t = eta if start_timestep <= i < stop_timestep else 0.0
+ if decay_eta:
+     eta_t = eta_t * (1 - i / num_inference_steps) ** eta_decay_power  # Decay eta over the loop
+ v_hat_t = v_t + eta_t * (v_t_cond - v_t)

# SDE Eq: 17 from https://arxiv.org/pdf/2410.10792
latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1])
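For intuition, the decay multiplies `eta_t` by `(1 - i / num_inference_steps) ** eta_decay_power`, so the controlled vector field relaxes toward the unconditional one as sampling proceeds (it only matters inside the `[start_timestep, stop_timestep)` window, since `eta_t` is zero elsewhere). A quick numeric check with assumed values:

```python
# Quick check of the eta decay schedule (eta, step count, and power are assumed values).
eta = 0.9
num_inference_steps = 28
eta_decay_power = 1.0

for i in range(0, num_inference_steps, 7):
    eta_t = eta * (1 - i / num_inference_steps) ** eta_decay_power
    print(f"step {i:2d}: eta_t = {eta_t:.3f}")
# step  0: eta_t = 0.900
# step  7: eta_t = 0.675
# step 14: eta_t = 0.450
# step 21: eta_t = 0.225
```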
2 changes: 2 additions & 0 deletions examples/flux-control/README.md
@@ -36,6 +36,7 @@ accelerate launch train_control_lora_flux.py \
--max_train_steps=5000 \
--validation_image="openpose.png" \
--validation_prompt="A couple, 4k photo, highly detailed" \
+ --offload \
--seed="0" \
--push_to_hub
```
@@ -154,6 +155,7 @@ accelerate launch --config_file=accelerate_ds2.yaml train_control_flux.py \
--validation_steps=200 \
--validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
--validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
+ --offload \
--seed="0" \
--push_to_hub
```
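The new `--offload` flag moves the frozen VAE and text encoders back to CPU whenever they are idle, trading a host-device transfer per step for a smaller peak GPU footprint; the training-script diffs below gate the existing `.cpu()` calls behind it. A minimal sketch of the pattern (hypothetical helper, not code from the scripts):

```python
import torch

# Minimal sketch of the CPU-offload pattern behind --offload (hypothetical
# helper; the real scripts inline this around the VAE and text encoders).
def encode_with_offload(module, batch, device="cuda", offload=True):
    module.to(device)                   # bring the frozen module onto the accelerator
    with torch.no_grad():
        out = module(batch.to(device))  # run the expensive encode step
    if offload:
        module.cpu()                    # release GPU memory until the next use
        torch.cuda.empty_cache()
    return out
```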
13 changes: 10 additions & 3 deletions examples/flux-control/train_control_flux.py
@@ -541,6 +541,11 @@ def parse_args(input_args=None):
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
+ parser.add_argument(
+     "--offload",
+     action="store_true",
+     help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
+ )

if input_args is not None:
args = parser.parse_args(input_args)
@@ -999,8 +1004,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
control_latents = encode_images(
batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
)
- # offload vae to CPU.
- vae.cpu()
+ if args.offload:
+     # offload vae to CPU.
+     vae.cpu()

# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -1064,7 +1070,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
prompt_embeds.zero_()
pooled_prompt_embeds.zero_()
- text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+ if args.offload:
+     text_encoding_pipeline = text_encoding_pipeline.to("cpu")

# Predict.
model_pred = flux_transformer(
14 changes: 11 additions & 3 deletions examples/flux-control/train_control_lora_flux.py
@@ -573,6 +573,11 @@ def parse_args(input_args=None):
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
+ parser.add_argument(
+     "--offload",
+     action="store_true",
+     help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
+ )

if input_args is not None:
args = parser.parse_args(input_args)
@@ -1140,8 +1145,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
control_latents = encode_images(
batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
)
- # offload vae to CPU.
- vae.cpu()
+
+ if args.offload:
+     # offload vae to CPU.
+     vae.cpu()

# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -1205,7 +1212,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
prompt_embeds.zero_()
pooled_prompt_embeds.zero_()
- text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+ if args.offload:
+     text_encoding_pipeline = text_encoding_pipeline.to("cpu")

# Predict.
model_pred = flux_transformer(