New DataPartitionType DATA #567

Open · wants to merge 1 commit into main

Conversation

@apoorvtintin commented Jul 1, 2024

Increases memory efficiency during large-scale training: input batches and labels are sharded along the 'data' axis.
Adds a new input data sharding option, DataPartitionType.DATA.

# Data are fully replicated across all devices.
REPLICATED = "replicated"
# Data are partitioned across the data axis and replicated along other mesh axes.
DATA = "data"
Contributor


A high-level question: what is the purpose of this change?
I see that we already have FULL partition support, which partitions on axis=0 (the data axis). How is DATA different from FULL?


DATA replicates over the sequence dimension, so the spec is ("data", None) versus ("data", "model") for FULL.
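To make the difference concrete, here is a small, hypothetical sketch (plain Python, not AXLearn or JAX code) that computes the per-device shard shape of a (batch, seq_len) input under each spec. The mesh sizes and shapes are illustrative assumptions; the `spec` tuples mirror `jax.sharding.PartitionSpec` semantics, where each entry names the mesh axis a dimension is split over, or None for replication.

```python
def shard_shape(global_shape, spec, mesh):
    """Per-device shard shape for an array partitioned by `spec`.

    `spec` maps each array dimension to a mesh axis name (split across
    that axis) or None (replicated on every device along that axis).
    """
    return tuple(
        dim if axis is None else dim // mesh[axis]
        for dim, axis in zip(global_shape, spec)
    )

# Illustrative 2-D mesh: 4-way data parallel, 8-way tensor parallel.
mesh = {"data": 4, "model": 8}
global_batch = (32, 2048)  # (batch, seq_len)

# FULL: ("data", "model") -> sequence dim is split across TP workers.
print(shard_shape(global_batch, ("data", "model"), mesh))  # (8, 256)

# DATA: ("data", None) -> sequence dim is replicated on each TP worker.
print(shard_shape(global_batch, ("data", None), mesh))     # (8, 2048)
```

Under DATA, each TP worker holds the full sequence, which is what allows sequence parallelism over TP workers and avoids the partitioner inserting collectives to regather the sequence dimension.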

@ruomingp (Contributor) left a comment

> Increases memory efficiency

Do you have measurements on how DATA improves memory efficiency? Thanks.

@ptoulme-aws

> Increases memory efficiency
>
> Do you have measurements on how DATA improves memory efficiency? Thanks.

By replicating along the sequence dimension over TP workers, we limit the collectives and dynamic-slices introduced by the SPMD partitioner. This lowers overall step time and also lets us run sequence parallelism over TP workers.

@ruomingp (Contributor) commented Jul 8, 2024


Thanks. Do you have quantitative measurements?

@ptoulme-aws


No, we do not. It is more that when we inspect the HLO after the SPMD partition pass, we see much more optimal sharding: fewer all-to-alls and fewer dynamic-slices on the right-hand side.
