add flash attn 2 #99

Open · wants to merge 1 commit into main
23 changes: 23 additions & 0 deletions muse/modeling_movq.py
@@ -17,6 +17,13 @@
except ImportError:
    is_xformers_available = False

try:
    import flash_attn

    is_flash_attn_available = True
except ImportError:
    is_flash_attn_available = False


class SpatialNorm(nn.Module):
    def __init__(
@@ -173,6 +180,8 @@ def __init__(self, in_channels, zq_ch=None, add_conv=False):
        self.use_memory_efficient_attention_xformers = False
        self.xformers_attention_op = None

        self.use_flash_attn = False

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
@@ -181,6 +190,18 @@ def set_use_memory_efficient_attention_xformers(
        self.use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        self.xformers_attention_op = attention_op

        if self.use_flash_attn and self.use_memory_efficient_attention_xformers:
            raise ValueError("flash attention and xformers cannot be enabled at the same time")

    def set_use_flash_attn(self, use_flash_attn: bool):
        if use_flash_attn and not is_flash_attn_available:
            raise ImportError("Please install flash attention")

        self.use_flash_attn = use_flash_attn

        if self.use_flash_attn and self.use_memory_efficient_attention_xformers:
            raise ValueError("flash attention and xformers cannot be enabled at the same time")

    def forward(self, hidden_states, zq=None):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape
@@ -201,6 +222,8 @@ def forward(self, hidden_states, zq=None):
            hidden_states = xops.memory_efficient_attention(
                query, key, value, attn_bias=None, op=self.xformers_attention_op
            )
        elif self.use_flash_attn:
            hidden_states = flash_attn.flash_attn_func(query, key, value)
        else:
            attention_scores = torch.baddbmm(
                torch.empty(
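Editor's note (not part of the diff): flash_attn.flash_attn_func from flash-attn 2 expects query, key and value as CUDA tensors of shape (batch, seq_len, num_heads, head_dim) in fp16 or bf16 and returns a tensor with the same layout. The snippet below is a minimal sketch of that calling convention with a plain-PyTorch fallback; the attention helper, shapes, and fallback are illustrative and not taken from this PR.

import torch

try:
    from flash_attn import flash_attn_func

    is_flash_attn_available = True
except ImportError:
    is_flash_attn_available = False


def attention(query, key, value):
    # query/key/value: (batch, seq_len, num_heads, head_dim)
    if is_flash_attn_available and query.is_cuda and query.dtype in (torch.float16, torch.bfloat16):
        # flash-attn 2 kernel; output keeps the (batch, seq_len, num_heads, head_dim) layout
        return flash_attn_func(query, key, value)
    # Fallback: torch SDPA expects (batch, num_heads, seq_len, head_dim)
    q, k, v = (t.transpose(1, 2) for t in (query, key, value))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    return out.transpose(1, 2)
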
24 changes: 24 additions & 0 deletions muse/modeling_transformer.py
@@ -36,6 +36,13 @@
except ImportError:
    is_xformers_available = False

try:
    import flash_attn

    is_flash_attn_available = True
except ImportError:
    is_flash_attn_available = False


# classifier free guidance functions

@@ -382,6 +389,8 @@ def __init__(self, hidden_size, num_heads, encoder_hidden_size=None, attention_d
        self.use_memory_efficient_attention_xformers = False
        self.xformers_attention_op = None

        self.use_flash_attn = False

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
@@ -390,6 +399,18 @@ def set_use_memory_efficient_attention_xformers(
        self.use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        self.xformers_attention_op = attention_op

        if self.use_flash_attn and self.use_memory_efficient_attention_xformers:
            raise ValueError("flash attention and xformers cannot be enabled at the same time")

    def set_use_flash_attn(self, use_flash_attn: bool):
        if use_flash_attn and not is_flash_attn_available:
            raise ImportError("Please install flash attention")

        self.use_flash_attn = use_flash_attn

        if self.use_flash_attn and self.use_memory_efficient_attention_xformers:
            raise ValueError("flash attention and xformers cannot be enabled at the same time")

    def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None):
        if encoder_attention_mask is not None and self.use_memory_efficient_attention_xformers:
            raise ValueError("Memory efficient attention does not yet support encoder attention mask")
@@ -409,6 +430,9 @@ def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_m
        if self.use_memory_efficient_attention_xformers:
            attn_output = xops.memory_efficient_attention(query, key, value, op=self.xformers_attention_op)
            attn_output = attn_output.view(batch, q_seq_len, self.hidden_size)
        elif self.use_flash_attn:
            attn_output = flash_attn.flash_attn_func(query, key, value)
            attn_output = attn_output.view(batch, q_seq_len, self.hidden_size)
        else:
            attention_mask = None
            if encoder_attention_mask is not None:
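Editor's note (not part of the diff): in the new elif branch, flash_attn.flash_attn_func returns a tensor of shape (batch, q_seq_len, num_heads, head_dim), so the following .view(batch, q_seq_len, self.hidden_size) relies on num_heads * head_dim == hidden_size and on the output being contiguous in that layout; like the xformers path, flash_attn_func takes no padding-mask argument. A small self-contained shape check with illustrative sizes:

import torch

batch, q_seq_len, num_heads, head_dim = 2, 16, 8, 64
hidden_size = num_heads * head_dim

# Stand-in for the flash-attn output: (batch, seq_len, num_heads, head_dim).
attn_output = torch.randn(batch, q_seq_len, num_heads, head_dim)

# The .view in the new branch flattens the head dimensions back to hidden_size.
attn_output = attn_output.view(batch, q_seq_len, hidden_size)
print(attn_output.shape)  # torch.Size([2, 16, 512])
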
21 changes: 21 additions & 0 deletions muse/modeling_utils.py
@@ -290,6 +290,21 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if isinstance(module, torch.nn.Module):
                fn_recursive_set_mem_eff(module)

    def set_use_flash_attn(self, valid: bool) -> None:
        # Recursively walk through all the children.
        # Any child that exposes the set_use_flash_attn method
        # gets the message.
        def fn_recursive_set_flash_attn(module: torch.nn.Module):
            if hasattr(module, "set_use_flash_attn"):
                module.set_use_flash_attn(valid)

            for child in module.children():
                fn_recursive_set_flash_attn(child)

        for module in self.children():
            if isinstance(module, torch.nn.Module):
                fn_recursive_set_flash_attn(module)

    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
        r"""
        Enable memory efficient attention as implemented in xformers.
@@ -328,6 +343,12 @@ def disable_xformers_memory_efficient_attention(self):
"""
self.set_use_memory_efficient_attention_xformers(False)

    def enable_flash_attn(self):
        self.set_use_flash_attn(True)

    def disable_flash_attn(self):
        self.set_use_flash_attn(False)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
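Editor's note (not part of the diff): a minimal, self-contained illustration of how the new recursive set_use_flash_attn / enable_flash_attn plumbing is meant to propagate the flag to any submodule that exposes a set_use_flash_attn method. DummyAttention and DummyModel are stand-ins, not classes from the muse codebase.

import torch


class DummyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.use_flash_attn = False

    def set_use_flash_attn(self, use_flash_attn: bool):
        self.use_flash_attn = use_flash_attn


class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block = torch.nn.Sequential(DummyAttention(), torch.nn.Linear(4, 4))

    def set_use_flash_attn(self, valid: bool) -> None:
        # Same recursion pattern as the method added to modeling_utils.py.
        def fn_recursive_set_flash_attn(module: torch.nn.Module):
            if hasattr(module, "set_use_flash_attn"):
                module.set_use_flash_attn(valid)

            for child in module.children():
                fn_recursive_set_flash_attn(child)

        for module in self.children():
            fn_recursive_set_flash_attn(module)

    def enable_flash_attn(self):
        self.set_use_flash_attn(True)

    def disable_flash_attn(self):
        self.set_use_flash_attn(False)


model = DummyModel()
model.enable_flash_attn()
print(model.block[0].use_flash_attn)  # True
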
3 changes: 3 additions & 0 deletions training/train_muse.py
@@ -339,6 +339,9 @@ def save_model_hook(models, weights, output_dir):
    if config.model.enable_xformers_memory_efficient_attention:
        model.enable_xformers_memory_efficient_attention()

    if config.model.get("enable_flash_attn", False):
        model.enable_flash_attn()

    optimizer_config = config.optimizer.params
    learning_rate = optimizer_config.learning_rate
    if optimizer_config.scale_lr:
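Editor's note (not part of the diff): config.model.get("enable_flash_attn", False) suggests an OmegaConf-style training config, so the new flag can be set in the model section while older configs without the key keep working. A hedged sketch with an assumed minimal config; only enable_flash_attn and enable_xformers_memory_efficient_attention correspond to keys actually read in train_muse.py.

from omegaconf import OmegaConf

# Hypothetical, minimal model config; enable_flash_attn is the key this PR reads.
config = OmegaConf.create(
    {
        "model": {
            "enable_xformers_memory_efficient_attention": False,
            "enable_flash_attn": True,
        }
    }
)

# Mirrors the check added in train_muse.py; .get() keeps configs without the new key working.
if config.model.get("enable_flash_attn", False):
    print("model.enable_flash_attn() would be called here")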