Skip to content

Commit

Permalink
This PR shards the Dataloader across both depth tensor and data paral…
Browse files Browse the repository at this point in the history
…lel ranks (#74)
  • Loading branch information
siddharth9820 authored May 7, 2024
1 parent 13c310c commit 7ea48f9
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions axonn/axonn.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ def is_zeroth_rank():

def create_dataloader(
dataset: torch.utils.data.Dataset,
batch_size: int,
micro_batch_size: int,
global_batch_size: int,
micro_batch_size: int = 1,
num_workers: int = 0,
*args,
**kwargs,
Expand All @@ -183,7 +183,7 @@ def create_dataloader(
Arguments:
dataset (torch.utils.data.Dataset): a PyTorch dataset object
batch_size (int): batch size for dataloading
global_batch_size (int): global batch size over all GPUs
micro_batch_size (int): microbatch size for inter-layer parallelism
num_workers (int): number of worker processes in the dataloader
Expand All @@ -194,18 +194,23 @@ def create_dataloader(
"""
assert is_initialized
config.micro_batch_size = micro_batch_size
config.batch_size = batch_size
config.batch_size_per_network = batch_size // config.G_data
config.global_batch_size = global_batch_size
config.batch_size_per_network_instance = global_batch_size // (
config.G_data * config.G_intra_d
)
assert (
batch_size % (config.G_data * micro_batch_size) == 0
global_batch_size % (config.G_data * micro_batch_size) == 0
), "Batch Size should be divisible by the G_data*micro_batch_size"

sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=config.G_data, rank=config.data_parallel_rank
dataset,
num_replicas=config.G_data * config.G_intra_d,
rank=config.G_intra_d * config.data_parallel_rank
+ config.intra_layer_depth_parallel_rank,
)
data_loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=config.batch_size_per_network,
batch_size=config.batch_size_per_network_instance,
shuffle=False,
num_workers=num_workers,
sampler=sampler,
Expand Down

0 comments on commit 7ea48f9

Please sign in to comment.