Fixed typo and reverted removal of skip_layers in SD3Transformer2DModel
guiyrt committed Dec 6, 2024
1 parent 50d09d9 commit 30e0dda
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/diffusers/models/transformers/transformer_sd3.py
@@ -341,6 +341,7 @@ def forward(
block_controlnet_hidden_states: List = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
skip_layers: Optional[List[int]] = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`SD3Transformer2DModel`] forward method.
@@ -363,6 +364,8 @@ def forward(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
skip_layers (`list` of `int`, *optional*):
A list of layer indices to skip during the forward pass.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -390,7 +393,10 @@ def forward(
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
# Skip specified layers
is_skip = True if skip_layers is not None and index_block in skip_layers else False

if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -410,8 +416,7 @@ def custom_forward(*inputs):
joint_attention_kwargs,
**ckpt_kwargs,
)

else:
elif not is_skip:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb,
joint_attention_kwargs=joint_attention_kwargs,
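For context, here is a minimal sketch of how the restored `skip_layers` argument can be exercised. The snippet is not part of the commit: the tiny model configuration, dummy tensor shapes, and chosen block indices are illustrative assumptions, picked only so the example runs quickly on CPU with a randomly initialized model.

```python
# Minimal sketch (not part of the commit): exercising the `skip_layers`
# argument on a deliberately tiny, randomly initialized SD3Transformer2DModel.
# All config values and tensor shapes below are illustrative assumptions.
import torch

from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel(
    sample_size=32,
    patch_size=2,
    in_channels=4,
    num_layers=4,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    caption_projection_dim=32,  # matches num_attention_heads * attention_head_dim
    pooled_projection_dim=64,
    out_channels=4,
    pos_embed_max_size=64,
)

# Dummy inputs matching the tiny config above.
hidden_states = torch.randn(1, 4, 32, 32)        # latent input
encoder_hidden_states = torch.randn(1, 16, 32)   # (batch, seq_len, joint_attention_dim)
pooled_projections = torch.randn(1, 64)
timestep = torch.tensor([10])

# Skip transformer blocks 1 and 2 for this forward pass.
out = model(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    pooled_projections=pooled_projections,
    timestep=timestep,
    skip_layers=[1, 2],
).sample

print(out.shape)  # torch.Size([1, 4, 32, 32])
```

Roughly speaking, this argument backs skip-layer guidance in the SD3.5 pipelines: the transformer is run a second time with a few blocks skipped, and the difference between the normal and skipped predictions is used as an extra guidance term.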
