Skip to content

Commit 7ca64fd

Browse files
authored
Merge branch 'main' into torchao-quantizer
2 parents 29ec905 + 96a9097 commit 7ca64fd

25 files changed

+708
-411
lines changed

.github/workflows/nightly_tests.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,13 @@ jobs:
238238
239239
run_flax_tpu_tests:
240240
name: Nightly Flax TPU Tests
241-
runs-on: docker-tpu
241+
runs-on:
242+
group: gcp-ct5lp-hightpu-8t
242243
if: github.event_name == 'schedule'
243244

244245
container:
245246
image: diffusers/diffusers-flax-tpu
246-
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --privileged
247+
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
247248
defaults:
248249
run:
249250
shell: bash
@@ -519,4 +520,4 @@ jobs:
519520
# if: always()
520521
# run: |
521522
# pip install slack_sdk tabulate
522-
# python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
523+
# python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

.github/workflows/push_tests.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,11 @@ jobs:
161161

162162
flax_tpu_tests:
163163
name: Flax TPU Tests
164-
runs-on: docker-tpu
164+
runs-on:
165+
group: gcp-ct5lp-hightpu-8t
165166
container:
166167
image: diffusers/diffusers-flax-tpu
167-
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged
168-
defaults:
168+
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
defaults:
169169
run:
170170
shell: bash
171171
steps:

examples/community/README_community_scripts.md

Lines changed: 149 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,45 @@ from diffusers import StableDiffusionPipeline
241241
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
242242
from diffusers.configuration_utils import register_to_config
243243
import torch
244-
from typing import Any, Dict, Optional
244+
from typing import Any, Dict, Tuple, Union
245+
246+
247+
class SDPromptSchedulingCallback(PipelineCallback):
248+
@register_to_config
249+
def __init__(
250+
self,
251+
encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
252+
cutoff_step_ratio=None,
253+
cutoff_step_index=None,
254+
):
255+
super().__init__(
256+
cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
257+
)
258+
259+
tensor_inputs = ["prompt_embeds"]
260+
261+
def callback_fn(
262+
self, pipeline, step_index, timestep, callback_kwargs
263+
) -> Dict[str, Any]:
264+
cutoff_step_ratio = self.config.cutoff_step_ratio
265+
cutoff_step_index = self.config.cutoff_step_index
266+
if isinstance(self.config.encoded_prompt, tuple):
267+
prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
268+
else:
269+
prompt_embeds = self.config.encoded_prompt
270+
271+
# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
272+
cutoff_step = (
273+
cutoff_step_index
274+
if cutoff_step_index is not None
275+
else int(pipeline.num_timesteps * cutoff_step_ratio)
276+
)
277+
278+
if step_index == cutoff_step:
279+
if pipeline.do_classifier_free_guidance:
280+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
281+
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
282+
return callback_kwargs
245283

246284

247285
pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
@@ -253,28 +291,73 @@ pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
253291
pipeline.safety_checker = None
254292
pipeline.requires_safety_checker = False
255293

294+
callback = MultiPipelineCallbacks(
295+
[
296+
SDPromptSchedulingCallback(
297+
encoded_prompt=pipeline.encode_prompt(
298+
prompt=f"prompt {index}",
299+
negative_prompt=f"negative prompt {index}",
300+
device=pipeline._execution_device,
301+
num_images_per_prompt=1,
302+
# pipeline.do_classifier_free_guidance can't be accessed until after the pipeline has been run
303+
do_classifier_free_guidance=True,
304+
),
305+
cutoff_step_index=index,
306+
) for index in range(1, 20)
307+
]
308+
)
309+
310+
image = pipeline(
311+
prompt="prompt",
312+
negative_prompt="negative prompt",
313+
callback_on_step_end=callback,
314+
callback_on_step_end_tensor_inputs=["prompt_embeds"],
315+
).images[0]
316+
torch.cuda.empty_cache()
317+
image.save('image.png')
318+
```
256319

257-
class SDPromptScheduleCallback(PipelineCallback):
320+
```python
321+
from diffusers import StableDiffusionXLPipeline
322+
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
323+
from diffusers.configuration_utils import register_to_config
324+
import torch
325+
from typing import Any, Dict, Tuple, Union
326+
327+
328+
class SDXLPromptSchedulingCallback(PipelineCallback):
258329
@register_to_config
259330
def __init__(
260331
self,
261-
prompt: str,
262-
negative_prompt: Optional[str] = None,
263-
num_images_per_prompt: int = 1,
264-
cutoff_step_ratio=1.0,
332+
encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
333+
add_text_embeds: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
334+
add_time_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
335+
cutoff_step_ratio=None,
265336
cutoff_step_index=None,
266337
):
267338
super().__init__(
268339
cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
269340
)
270341

271-
tensor_inputs = ["prompt_embeds"]
342+
tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
272343

273344
def callback_fn(
274345
self, pipeline, step_index, timestep, callback_kwargs
275346
) -> Dict[str, Any]:
276347
cutoff_step_ratio = self.config.cutoff_step_ratio
277348
cutoff_step_index = self.config.cutoff_step_index
349+
if isinstance(self.config.encoded_prompt, tuple):
350+
prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
351+
else:
352+
prompt_embeds = self.config.encoded_prompt
353+
if isinstance(self.config.add_text_embeds, tuple):
354+
add_text_embeds, negative_add_text_embeds = self.config.add_text_embeds
355+
else:
356+
add_text_embeds = self.config.add_text_embeds
357+
if isinstance(self.config.add_time_ids, tuple):
358+
add_time_ids, negative_add_time_ids = self.config.add_time_ids
359+
else:
360+
add_time_ids = self.config.add_time_ids
278361

279362
# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
280363
cutoff_step = (
@@ -284,34 +367,73 @@ class SDPromptScheduleCallback(PipelineCallback):
284367
)
285368

286369
if step_index == cutoff_step:
287-
prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
288-
prompt=self.config.prompt,
289-
negative_prompt=self.config.negative_prompt,
290-
device=pipeline._execution_device,
291-
num_images_per_prompt=self.config.num_images_per_prompt,
292-
do_classifier_free_guidance=pipeline.do_classifier_free_guidance,
293-
)
294370
if pipeline.do_classifier_free_guidance:
295371
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
372+
add_text_embeds = torch.cat([negative_add_text_embeds, add_text_embeds])
373+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids])
296374
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
375+
callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
376+
callback_kwargs[self.tensor_inputs[2]] = add_time_ids
297377
return callback_kwargs
298378

299-
callback = MultiPipelineCallbacks(
300-
[
301-
SDPromptScheduleCallback(
302-
prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
303-
negative_prompt="Deformed, ugly, bad anatomy",
304-
cutoff_step_ratio=0.25,
379+
380+
pipeline: StableDiffusionXLPipeline = StableDiffusionXLPipeline.from_pretrained(
381+
"stabilityai/stable-diffusion-xl-base-1.0",
382+
torch_dtype=torch.float16,
383+
variant="fp16",
384+
use_safetensors=True,
385+
).to("cuda")
386+
387+
callbacks = []
388+
for index in range(1, 20):
389+
(
390+
prompt_embeds,
391+
negative_prompt_embeds,
392+
pooled_prompt_embeds,
393+
negative_pooled_prompt_embeds,
394+
) = pipeline.encode_prompt(
395+
prompt=f"prompt {index}",
396+
negative_prompt=f"prompt {index}",
397+
device=pipeline._execution_device,
398+
num_images_per_prompt=1,
399+
# pipeline.do_classifier_free_guidance can't be accessed until after the pipeline has been run
400+
do_classifier_free_guidance=True,
401+
)
402+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
403+
add_time_ids = pipeline._get_add_time_ids(
404+
(1024, 1024),
405+
(0, 0),
406+
(1024, 1024),
407+
dtype=prompt_embeds.dtype,
408+
text_encoder_projection_dim=text_encoder_projection_dim,
409+
)
410+
negative_add_time_ids = pipeline._get_add_time_ids(
411+
(1024, 1024),
412+
(0, 0),
413+
(1024, 1024),
414+
dtype=prompt_embeds.dtype,
415+
text_encoder_projection_dim=text_encoder_projection_dim,
416+
)
417+
callbacks.append(
418+
SDXLPromptSchedulingCallback(
419+
encoded_prompt=(prompt_embeds, negative_prompt_embeds),
420+
add_text_embeds=(pooled_prompt_embeds, negative_pooled_prompt_embeds),
421+
add_time_ids=(add_time_ids, negative_add_time_ids),
422+
cutoff_step_index=index,
305423
)
306-
]
307-
)
424+
)
425+
426+
427+
callback = MultiPipelineCallbacks(callbacks)
308428

309429
image = pipeline(
310-
prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
311-
negative_prompt="Deformed, ugly, bad anatomy",
430+
prompt="prompt",
431+
negative_prompt="negative prompt",
312432
callback_on_step_end=callback,
313-
callback_on_step_end_tensor_inputs=["prompt_embeds"],
433+
callback_on_step_end_tensor_inputs=[
434+
"prompt_embeds",
435+
"add_text_embeds",
436+
"add_time_ids",
437+
],
314438
).images[0]
315-
torch.cuda.empty_cache()
316-
image.save('image.png')
317439
```

examples/community/pipeline_flux_rf_inversion.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,8 @@ def __call__(
648648
height: Optional[int] = None,
649649
width: Optional[int] = None,
650650
eta: float = 1.0,
651+
decay_eta: Optional[bool] = False,
652+
eta_decay_power: Optional[float] = 1.0,
651653
strength: float = 1.0,
652654
start_timestep: float = 0,
653655
stop_timestep: float = 0.25,
@@ -880,12 +882,9 @@ def __call__(
880882
v_t = -noise_pred
881883
v_t_cond = (y_0 - latents) / (1 - t_i)
882884
eta_t = eta if start_timestep <= i < stop_timestep else 0.0
883-
if start_timestep <= i < stop_timestep:
884-
# controlled vector field
885-
v_hat_t = v_t + eta * (v_t_cond - v_t)
886-
887-
else:
888-
v_hat_t = v_t
885+
if decay_eta:
886+
eta_t = eta_t * (1 - i / num_inference_steps) ** eta_decay_power # Decay eta over the loop
887+
v_hat_t = v_t + eta_t * (v_t_cond - v_t)
889888

890889
# SDE Eq: 17 from https://arxiv.org/pdf/2410.10792
891890
latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1])

examples/flux-control/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ accelerate launch train_control_lora_flux.py \
3636
--max_train_steps=5000 \
3737
--validation_image="openpose.png" \
3838
--validation_prompt="A couple, 4k photo, highly detailed" \
39+
--offload \
3940
--seed="0" \
4041
--push_to_hub
4142
```
@@ -154,6 +155,7 @@ accelerate launch --config_file=accelerate_ds2.yaml train_control_flux.py \
154155
--validation_steps=200 \
155156
--validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
156157
--validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
158+
--offload \
157159
--seed="0" \
158160
--push_to_hub
159161
```

examples/flux-control/train_control_flux.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,11 @@ def parse_args(input_args=None):
541541
default=1.29,
542542
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
543543
)
544+
parser.add_argument(
545+
"--offload",
546+
action="store_true",
547+
help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
548+
)
544549

545550
if input_args is not None:
546551
args = parser.parse_args(input_args)
@@ -999,8 +1004,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
9991004
control_latents = encode_images(
10001005
batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
10011006
)
1002-
# offload vae to CPU.
1003-
vae.cpu()
1007+
if args.offload:
1008+
# offload vae to CPU.
1009+
vae.cpu()
10041010

10051011
# Sample a random timestep for each image
10061012
# for weighting schemes where we sample timesteps non-uniformly
@@ -1064,7 +1070,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
10641070
if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
10651071
prompt_embeds.zero_()
10661072
pooled_prompt_embeds.zero_()
1067-
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
1073+
if args.offload:
1074+
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
10681075

10691076
# Predict.
10701077
model_pred = flux_transformer(

examples/flux-control/train_control_lora_flux.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,11 @@ def parse_args(input_args=None):
573573
default=1.29,
574574
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
575575
)
576+
parser.add_argument(
577+
"--offload",
578+
action="store_true",
579+
help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
580+
)
576581

577582
if input_args is not None:
578583
args = parser.parse_args(input_args)
@@ -1140,8 +1145,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
11401145
control_latents = encode_images(
11411146
batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
11421147
)
1143-
# offload vae to CPU.
1144-
vae.cpu()
1148+
1149+
if args.offload:
1150+
# offload vae to CPU.
1151+
vae.cpu()
11451152

11461153
# Sample a random timestep for each image
11471154
# for weighting schemes where we sample timesteps non-uniformly
@@ -1205,7 +1212,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
12051212
if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
12061213
prompt_embeds.zero_()
12071214
pooled_prompt_embeds.zero_()
1208-
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
1215+
if args.offload:
1216+
text_encoding_pipeline = text_encoding_pipeline.to("cpu")
12091217

12101218
# Predict.
12111219
model_pred = flux_transformer(

0 commit comments

Comments
 (0)