Skip to content

Commit

Permalink
Pass batch size to dataset class
Browse files: browse the repository at this point in the history
  • Loading branch information
corystephenson-db committed Oct 7, 2024
1 parent b5aa661 commit 829465c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
1 change: 1 addition & 0 deletions diffusion/datasets/image_caption_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def build_streaming_image_caption_latents_dataloader(
text_latent_shapes=text_latent_shapes,
attention_mask_keys=attention_mask_keys,
latent_dtype=dtype,
batch_size=batch_size,
**streaming_kwargs,
)

Expand Down
16 changes: 8 additions & 8 deletions diffusion/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,19 @@ def train(config: DictConfig) -> None:
else:
optimizer = hydra.utils.instantiate(config.optimizer, params=model.parameters())

# Load train dataset. Need to ensure that the per-device batch size is added as a streaming kwarg
per_device_train_batch_size = config.dataset.train_batch_size // dist.get_world_size()
# Load train dataset. Currently this expects to load according to the datasetHparam method.
# This means adding external datasets is currently not super easy. Will refactor or check for
# upstream composer changes that could make this easier.
if tokenizer:
train_dataloader: Union[Iterable, DataSpec, Dict[str, Any]] = hydra.utils.instantiate(
config.dataset.train_dataset,
tokenizer=tokenizer,
batch_size=per_device_train_batch_size,
batch_size=config.dataset.train_batch_size // dist.get_world_size(),
)
else:
train_dataloader: Union[Iterable, DataSpec, Dict[str, Any]] = hydra.utils.instantiate(
config.dataset.train_dataset,
batch_size=per_device_train_batch_size,
batch_size=config.dataset.train_batch_size // dist.get_world_size(),
)
# Need to sleep for a bit to avoid dataloader crash
time.sleep(10)
Expand Down Expand Up @@ -147,14 +148,13 @@ def train(config: DictConfig) -> None:
eval_set = evaluators

else:
# Need to ensure that the per-device batch size is added as a streaming kwarg
per_device_eval_batch_size = config.dataset.eval_batch_size // dist.get_world_size()
if tokenizer:
eval_set = hydra.utils.instantiate(config.dataset.eval_dataset,
tokenizer=model.tokenizer,
batch_size=per_device_eval_batch_size)
batch_size=config.dataset.eval_batch_size // dist.get_world_size())
else:
eval_set = hydra.utils.instantiate(config.dataset.eval_dataset, batch_size=per_device_eval_batch_size)
eval_set = hydra.utils.instantiate(config.dataset.eval_dataset,
batch_size=config.dataset.eval_batch_size // dist.get_world_size())

# Need to sleep for a bit to avoid dataloader crash
time.sleep(10)
Expand Down

0 comments on commit 829465c

Please sign in to comment.