Improve Calc README and FLOPs Script #1

Open · wants to merge 14 commits into main

Conversation

Quentin-Anthony (Contributor):

  • Improves the calc README: both the overall writing and some inconsistencies in how we present FLOPs/params.
  • Completely reworks the Mamba FLOPs calculation script; some errors were introduced in commits after the initial push.

@Quentin-Anthony changed the title from "Update calc readme and script" to "Improve Calc README and FLOPs Script" on Dec 16, 2024.

# State updates
mamba2_block_flops += 2 * args.batch_size * args.sequence_length * d_inner * args.state_size
# Output projections
mamba2_block_flops += 2 * args.batch_size * args.sequence_length * d_inner * args.state_size * args.hidden_size

Contributor:

Suggested change
- mamba2_block_flops += 2 * args.batch_size * args.sequence_length * d_inner * args.state_size * args.hidden_size
+ mamba2_block_flops += 2 * args.batch_size * args.sequence_length * d_inner * args.hidden_size
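
For context on the arithmetic behind this suggestion: a bias-free dense projection applied to every token follows the standard 2 x (tokens) x (in features) x (out features) rule, and the Mamba2 output projection maps d_inner to hidden_size, so state_size does not belong in its cost. A minimal sketch (the helper name is illustrative, not from the script):

```python
# Illustrative helper, not part of the PR: FLOPs of a bias-free dense projection
# applied to batch_size * sequence_length tokens (one multiply + one add per weight).
def linear_flops(batch_size, sequence_length, in_features, out_features):
    return 2 * batch_size * sequence_length * in_features * out_features

# The Mamba2 output projection maps d_inner -> hidden_size, so its cost is
# linear_flops(args.batch_size, args.sequence_length, d_inner, args.hidden_size),
# with no state_size factor, which is what the suggested change encodes.
```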

# Output projections
mamba2_block_flops += 2 * args.batch_size * args.sequence_length * d_inner * args.state_size * args.hidden_size
# Final gating
mamba2_block_flops += args.batch_size * args.sequence_length * args.hidden_size

Contributor:

Suggested change
- mamba2_block_flops += args.batch_size * args.sequence_length * args.hidden_size

Deleted the last two lines because gating is now included before the output projections.
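
A short sketch of the accounting this refers to (illustrative only; exactly where the script folds the gate in is an assumption based on the comment above): the elementwise gate acts on the (batch, sequence, d_inner) activations just before the output projection, so it costs one multiply per element there and should not be double-counted as a separate hidden_size-sized term.

```python
# Illustrative sketch, not the script's code: cost of an elementwise gate applied
# to the pre-output-projection activations of shape (batch, sequence, d_inner).
def gating_flops(batch_size, sequence_length, d_inner):
    return batch_size * sequence_length * d_inner  # one multiply per element
```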

total_attention_flops += shared_attention_flops
total_flops += shared_ffn_flops
total_ffn_flops = shared_ffn_flops
total_ffn_flops += shared_ffn_flops
args.hidden_size = original_hidden_size
# final downprojector matrix
total_flops += 4 * args.batch_size * args.sequence_length * args.hidden_size * args.hidden_size

Contributor:

Suggested change
- total_flops += 4 * args.batch_size * args.sequence_length * args.hidden_size * args.hidden_size
+ total_flops += 2 * args.batch_size * args.sequence_length * args.hidden_size * args.hidden_size * iter_factor

This is a square matrix; the downprojection happens inside attention, as commented above.
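
A sketch of the accounting behind this suggestion (what iter_factor stands for is an assumption; the script's own definition governs): the projection is a hidden_size x hidden_size matmul applied to every token, and iter_factor scales the forward-only count up to a full training step.

```python
# Sketch only: FLOPs of the square (hidden_size -> hidden_size) projection over a batch.
def square_proj_flops(batch_size, sequence_length, hidden_size, iter_factor):
    forward = 2 * batch_size * sequence_length * hidden_size * hidden_size
    # iter_factor is assumed here to be the forward/backward multiplier used by the
    # script (commonly 3, since the backward pass costs roughly 2x the forward pass).
    return forward * iter_factor
```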


mamba_flops = compute_mamba_flops(args)
# Calculate component FLOPs
mamba1_flops = compute_mamba1_flops(args, iter_factor)
mamba2_flops = compute_mamba2_flops(args)

Contributor:

Suggested change
- mamba2_flops = compute_mamba2_flops(args)
+ mamba2_flops = iter_factor * compute_mamba2_flops(args)
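
The point here is consistency: compute_mamba1_flops is already called with iter_factor, while compute_mamba2_flops appears to return forward-only FLOPs, so the scaling has to happen at the call site. A hypothetical illustration of the two equivalent conventions (names and values below are made up, not from the script):

```python
# Hypothetical illustration of the two scaling conventions; names are not from the script.
def block_forward_flops(tokens, hidden_size):
    return 2 * tokens * hidden_size * hidden_size        # stand-in for any block's forward cost

def flops_scaled_inside(tokens, hidden_size, iter_factor):
    return block_forward_flops(tokens, hidden_size) * iter_factor   # mamba1-style: scaled in the helper

def flops_scaled_by_caller(tokens, hidden_size):
    return block_forward_flops(tokens, hidden_size)                  # mamba2-style: caller multiplies

iter_factor = 3   # assumed forward+backward multiplier
assert flops_scaled_inside(4096, 1024, iter_factor) == iter_factor * flops_scaled_by_caller(4096, 1024)
```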

# final downprojector matrix
total_flops += 4 * args.batch_size * args.sequence_length * args.hidden_size * args.hidden_size
total_flops += 4 * args.hidden_size * args.hidden_size

Contributor:

Suggested change
- total_flops += 4 * args.hidden_size * args.hidden_size
+ total_flops += 2 * args.hidden_size * args.hidden_size

This layer maps hidden_size -> hidden_size.
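
The factor follows from the layer shape: a linear map from hidden_size to hidden_size has hidden_size**2 weights, i.e. 2 * hidden_size**2 FLOPs per token (one multiply and one add per weight), not 4 * hidden_size**2. A minimal sketch (whether batch and sequence factors are applied elsewhere in the script is not shown here):

```python
# Sketch only: per-token cost of a square (hidden_size -> hidden_size) linear layer.
def square_layer_flops_per_token(hidden_size):
    return 2 * hidden_size * hidden_size   # one multiply + one add per weight
```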

# final downprojector matrix
total_flops += 4 * args.batch_size * args.sequence_length * args.hidden_size * args.hidden_size
total_flops += 4 * args.hidden_size * args.hidden_size
else:

Contributor:

Suggested change
- else:
+ total_flops += mamba2_flops
+ total_mamba2_flops += mamba2_flops
+ else:

Layer type 'g' contains a Mamba2 layer after the transformer.
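
A hypothetical sketch of how that hybrid accounting could look (the layer-spec encoding and the handling of other layer types are assumptions, not the script's actual logic):

```python
# Hypothetical sketch; the real layer-type handling lives in the script.
def hybrid_layer_flops(layer_spec, transformer_flops, mamba2_flops):
    total = 0
    for layer_type in layer_spec:
        if layer_type == 'g':
            # type 'g': transformer block followed by a Mamba2 layer,
            # so both contributions are accumulated for this layer
            total += transformer_flops + mamba2_flops
        else:
            # other layer types are handled by the real script; treated
            # here as plain Mamba2 layers purely for illustration
            total += mamba2_flops
    return total

# e.g. hybrid_layer_flops("mgmg", transformer_flops=1.2e12, mamba2_flops=8.0e11)
```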

Comment on lines +142 to +143
mamba2_block_flops += 2 * d_inner * args.state_size * args.hidden_size
mamba2_block_flops += args.hidden_size

Contributor:

Suggested change
- mamba2_block_flops += 2 * d_inner * args.state_size * args.hidden_size
- mamba2_block_flops += args.hidden_size
+ mamba2_block_flops += 2 * d_inner * args.hidden_size

There is no bias in the out projector.
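
For completeness, the per-token arithmetic these last two suggestions rely on (a sketch; the with_bias branch is only there to show what the removed term would have been):

```python
# Sketch only: per-token cost of the Mamba2 out projection, d_inner -> hidden_size.
def out_proj_flops_per_token(d_inner, hidden_size, with_bias=False):
    flops = 2 * d_inner * hidden_size   # weight multiply-adds; no state_size factor
    if with_bias:
        flops += hidden_size            # a bias would add this, but the out projector has none
    return flops
```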
