Improve Calc README and FLOPs Script #1
base: main
Conversation
Quentin-Anthony commented on Dec 16, 2024
- Improves the calc README, both the overall writing and some inconsistencies in how we present FLOPs/params.
- Completely reworks the mamba FLOPs calc script; some errors were introduced in commits after the initial push. A sketch of the per-matmul FLOPs convention the script uses follows below.
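For reference, FLOPs calculators like this one typically count each dense matmul as roughly 2 FLOPs per multiply-accumulate. A minimal sketch of that convention, assuming an expansion factor of 2 for d_inner and argument names that mirror the script's CLI flags (this helper is illustrative, not the PR's actual code):

    from argparse import Namespace

    def linear_flops(tokens, d_in, d_out):
        # A (tokens x d_in) @ (d_in x d_out) matmul costs ~2 * tokens * d_in * d_out FLOPs
        # (one multiply plus one add per multiply-accumulate).
        return 2 * tokens * d_in * d_out

    # Hypothetical args mirroring the script's flags.
    args = Namespace(batch_size=1, sequence_length=2048, hidden_size=4096)
    tokens = args.batch_size * args.sequence_length
    d_inner = 2 * args.hidden_size  # assumed expansion factor of 2

    # e.g. the mamba2 input and output projections:
    in_proj_flops = linear_flops(tokens, args.hidden_size, d_inner)
    out_proj_flops = linear_flops(tokens, d_inner, args.hidden_size)
    print(in_proj_flops, out_proj_flops)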
calc/calc_mamba_flops.py (outdated)
    # State updates
    mamba2_block_flops += 2 * args.batch_size * args.sequence_length * d_inner * args.state_size
    # Output projections
    mamba2_block_flops += 2 * args.batch_size * args.sequence_length * d_inner * args.state_size * args.hidden_size
Suggested change:
-   mamba2_block_flops += 2 * args.batch_size * args.sequence_length * d_inner * args.state_size * args.hidden_size
+   mamba2_block_flops += 2 * args.batch_size * args.sequence_length * d_inner * args.hidden_size
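The out projection is a (d_inner, hidden_size) weight, so its cost scales with d_inner * hidden_size only; the state_size factor belongs to the state-update term, not here. A small sanity check of the shapes as I read the suggestion (the dimension names follow the script; the sizes are made up):

    import numpy as np

    batch_size, sequence_length = 1, 16
    hidden_size, state_size = 64, 8
    d_inner = 2 * hidden_size  # assumed expansion factor

    x = np.random.randn(batch_size * sequence_length, d_inner)
    w_out = np.random.randn(d_inner, hidden_size)  # out projection: d_inner -> hidden_size

    y = x @ w_out
    # FLOPs for this matmul: 2 * tokens * d_inner * hidden_size (no state_size factor),
    # matching the suggested line above.
    flops = 2 * x.shape[0] * d_inner * hidden_size
    print(y.shape, flops)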
calc/calc_mamba_flops.py (outdated)
    # Output projections
    mamba2_block_flops += 2 * args.batch_size * args.sequence_length * d_inner * args.state_size * args.hidden_size
    # Final gating
    mamba2_block_flops += args.batch_size * args.sequence_length * args.hidden_size
Suggested change (delete this line):
-   mamba2_block_flops += args.batch_size * args.sequence_length * args.hidden_size
Deleted the last two lines because gating is now included before the output projections.
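The gate in mamba-style blocks is an elementwise multiply over the d_inner-wide activations applied before the output projection, so its cost is already accounted for in that part of the block and a separate hidden_size-sized term is not needed. A rough sketch of the ordering (variable names are illustrative, not the script's exact code):

    import numpy as np

    tokens, hidden_size = 16, 64
    d_inner = 2 * hidden_size  # assumed expansion factor

    ssm_out = np.random.randn(tokens, d_inner)
    gate = np.random.randn(tokens, d_inner)
    w_out = np.random.randn(d_inner, hidden_size)

    # Gating happens on the d_inner-wide activations, before the out projection
    # (the real block applies an activation such as SiLU to the gate first).
    gated = ssm_out * gate                 # elementwise: ~tokens * d_inner FLOPs
    y = gated @ w_out                      # matmul: ~2 * tokens * d_inner * hidden_size FLOPs
    print(y.shape)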
calc/calc_mamba_flops.py (outdated)
    total_attention_flops += shared_attention_flops
    total_flops += shared_ffn_flops
    total_ffn_flops = shared_ffn_flops
    total_ffn_flops += shared_ffn_flops
    args.hidden_size = original_hidden_size
    # final downprojector matrix
    total_flops += 4 * args.batch_size * args.sequence_length * args.hidden_size * args.hidden_size
Suggested change:
-   total_flops += 4 * args.batch_size * args.sequence_length * args.hidden_size * args.hidden_size
+   total_flops += 2 * args.batch_size * args.sequence_length * args.hidden_size * args.hidden_size * iter_factor
This is a square matrix; the downproj happens inside attention, as commented above.
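Since the final down projector maps hidden_size -> hidden_size, the square-matrix matmul costs 2 * tokens * hidden_size^2 forward FLOPs, and iter_factor should be applied to it just like the other terms. A rough illustration, assuming the common convention that iter_factor is 1 for forward-only and 3 when the backward pass is included (not verified against this repo):

    batch_size, sequence_length, hidden_size = 1, 2048, 4096
    tokens = batch_size * sequence_length

    # Assumed convention: 1 = forward only, 3 = forward + backward.
    iter_factor = 3

    downproj_flops = 2 * tokens * hidden_size * hidden_size * iter_factor
    print(downproj_flops)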
calc/calc_mamba_flops.py (outdated)
    mamba_flops = compute_mamba_flops(args)
    # Calculate component FLOPs
    mamba1_flops = compute_mamba1_flops(args, iter_factor)
    mamba2_flops = compute_mamba2_flops(args)
Suggested change:
-   mamba2_flops = compute_mamba2_flops(args)
+   mamba2_flops = iter_factor * compute_mamba2_flops(args)
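This keeps the mamba2 term consistent with mamba1, which already receives iter_factor. A hedged sketch of the pattern (the compute_* functions below are stand-ins, not the script's real implementations):

    def compute_mamba1_flops(args, iter_factor):
        # Stand-in: pretend this already returns FLOPs scaled by iter_factor.
        return iter_factor * 100

    def compute_mamba2_flops(args):
        # Stand-in: returns forward-only FLOPs, so the caller must scale it.
        return 100

    def total_component_flops(args, iter_factor):
        mamba1_flops = compute_mamba1_flops(args, iter_factor)
        # Scale mamba2 the same way so both components use the same accounting.
        mamba2_flops = iter_factor * compute_mamba2_flops(args)
        return mamba1_flops + mamba2_flops

    print(total_component_flops(None, 3))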
    # final downprojector matrix
    total_flops += 4 * args.batch_size * args.sequence_length * args.hidden_size * args.hidden_size
    total_flops += 4 * args.hidden_size * args.hidden_size
Suggested change:
-   total_flops += 4 * args.hidden_size * args.hidden_size
+   total_flops += 2 * args.hidden_size * args.hidden_size
This layer maps hidden_size -> hidden_size.
    # final downprojector matrix
    total_flops += 4 * args.batch_size * args.sequence_length * args.hidden_size * args.hidden_size
    total_flops += 4 * args.hidden_size * args.hidden_size
    else:
Suggested change:
-   else:
+   total_flops += mamba2_flops
+   total_mamba2_flops += mamba2_flops
+   else:
Layer type "g" contains a mamba2 layer after the transformer, so its FLOPs need to be added here.
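The branch being edited walks a hybrid layer pattern: when a layer type includes a mamba2 block alongside the attention/FFN path, its FLOPs have to be accumulated into both the running total and the mamba2 subtotal. A simplified sketch of that bookkeeping (the layer-type string and per-layer costs are made up for illustration; only the "g" meaning comes from the review comment above):

    # Hypothetical per-layer forward FLOPs; real values come from the script's compute_* functions.
    attention_flops = 400
    ffn_flops = 300
    mamba2_flops = 250

    total_flops = 0
    total_mamba2_flops = 0

    # "g" is assumed (per the review comment) to be a layer type that places a
    # mamba2 block after the transformer block; other types are plain mamba2 layers.
    layer_pattern = "mgmg"

    for layer_type in layer_pattern:
        if layer_type == "g":
            total_flops += attention_flops + ffn_flops
            total_flops += mamba2_flops
            total_mamba2_flops += mamba2_flops
        else:
            total_flops += mamba2_flops
            total_mamba2_flops += mamba2_flops

    print(total_flops, total_mamba2_flops)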
    mamba2_block_flops += 2 * d_inner * args.state_size * args.hidden_size
    mamba2_block_flops += args.hidden_size
Suggested change:
-   mamba2_block_flops += 2 * d_inner * args.state_size * args.hidden_size
-   mamba2_block_flops += args.hidden_size
+   mamba2_block_flops += 2 * d_inner * args.hidden_size
There is no bias in the out projector.
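Because the out projector has no bias, its parameter count is just d_inner * hidden_size, and its per-token cost is 2 * d_inner * hidden_size with no extra hidden_size term. A small check under those assumptions (the expansion factor is assumed, not taken from the script):

    hidden_size = 4096
    expansion_factor = 2                   # assumed; the script derives d_inner from its own flags
    d_inner = expansion_factor * hidden_size

    out_proj_params = d_inner * hidden_size            # weight only, no bias vector
    out_proj_flops_per_token = 2 * d_inner * hidden_size

    print(out_proj_params, out_proj_flops_per_token)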