
Use flash-attn fused cross entropy loss to reduce metric memory usage #2987

Draft
wants to merge 1 commit into main

Conversation

cli99 (Contributor) commented Feb 9, 2024

This PR uses the fused cross entropy loss from flash attention in the LanguageCrossEntropy metric (and therefore LanguagePerplexity as well).
The current torch.nn.CrossEntropyLoss call needs 6 * seq_len * vocab_size GPU memory and can become the memory bottleneck when the sequence length is long (a regime where activation checkpointing is typically already in use). Using the cross entropy loss from flash attn resolves this problem.

Example test model with a long sequence and full activation checkpointing:
with torch loss fn: [memory profile screenshot]
with flash_attn loss fn: [memory profile screenshot]
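For a rough sense of scale, here is a back-of-envelope calculation based on the figure quoted in the description. Interpreting 6 * seq_len * vocab_size as bytes is an assumption (the PR does not state units), and the sequence length and vocab size below are illustrative values, not measurements from this PR:

# Rough estimate of the transient memory used by the metric's loss call,
# treating the 6 * seq_len * vocab_size figure from the description as bytes.
# The sizes below are illustrative examples, not taken from this PR's test.
seq_len = 32_768      # long-sequence regime where act ckpt is typically enabled
vocab_size = 50_000

loss_bytes = 6 * seq_len * vocab_size
print(f"~{loss_bytes / 1e9:.1f} GB for the loss computation")  # ~9.8 GB

At that scale the metric's cross entropy call alone can rival the checkpointed activations, which is why swapping in the fused loss helps.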

dakinggg (Contributor) commented Feb 9, 2024

@cli99 consider doing it as in https://github.com/mosaicml/llm-foundry/pull/575/files, to avoid introducing flash attn as a composer dependency.

Comment on lines +100 to +104
try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss
    self.loss_fn = FusedCrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
except ImportError:
    self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
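For context, a minimal sketch of how a metric like LanguageCrossEntropy would consume whichever loss_fn the try/except selects. The class name, state names, and update signature here are assumptions for illustration, not Composer's actual implementation:

import torch
from torchmetrics import Metric


class LanguageCrossEntropySketch(Metric):
    """Sum-reduced cross entropy metric sketch (not Composer's actual class)."""

    def __init__(self, ignore_index: int = -100):
        super().__init__()
        self.ignore_index = ignore_index
        try:
            # Prefer flash-attn's fused loss when installed; its kernel expects
            # CUDA half-precision logits.
            from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss
            self.loss_fn = FusedCrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
        except ImportError:
            self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
        self.add_state('sum_loss', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('total_items', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, logits: torch.Tensor, target: torch.Tensor) -> None:
        # Flatten (batch, seq, vocab) logits to (batch * seq, vocab) for the loss call.
        losses = self.loss_fn(logits.view(-1, logits.size(-1)), target.view(-1))
        self.sum_loss += losses
        self.total_items += (target != self.ignore_index).sum()

    def compute(self) -> torch.Tensor:
        # Mean loss per non-ignored token; exp() of this is the perplexity.
        return self.sum_loss / self.total_items

Because both implementations accept the same (logits, target) call and support reduction='sum', nothing downstream of the try/except needs to change when flash-attn is not installed.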
Should this live in llm-foundry? CC: @dakinggg


see above


I have a very old PR that I never merged.

mvpatel2000 (Contributor) commented
@cli99 should we close this, or is it still WIP?
