diff --git a/llmfoundry/data/dataloader.py b/llmfoundry/data/dataloader.py
index 83a9a7d8ea..e7521bc343 100644
--- a/llmfoundry/data/dataloader.py
+++ b/llmfoundry/data/dataloader.py
@@ -3,7 +3,7 @@
 
 """Dataloader builder utilities."""
 
-from typing import Any, Dict
+from typing import Any, Dict, Union
 
 from composer import DataSpec
 from transformers import PreTrainedTokenizerBase
@@ -19,7 +19,7 @@
 def build_dataloader(
     cfg: Dict[str, Any],
     tokenizer: PreTrainedTokenizerBase,
-    device_batch_size: int,
+    device_batch_size: Union[int, float],
 ) -> DataSpec:
     """Builds a dataloader from a config.
 
diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index 5ab148bbe8..5c1ec9114a 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -100,7 +100,7 @@ class TrainConfig:
     optimizer: Dict[str, Any] = MISSING
     scheduler: Dict[str, Any] = MISSING
     train_loader: Dict[str, Any] = MISSING
-    device_train_batch_size: int = MISSING
+    device_train_batch_size: Union[int, float] = MISSING
     device_eval_batch_size: int = MISSING
     max_duration: Union[int, str] = MISSING
     eval_interval: Union[int, str] = MISSING
@@ -183,7 +183,6 @@ class TrainConfig:
     # Fields created by `update_batch_size_info`
     n_gpus: int = MISSING
-    device_train_batch_size: int = MISSING
     device_train_grad_accum: str = MISSING
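For context on the type widening above: a per-device batch size comes out fractional whenever the global train batch size does not divide evenly across devices, which is why `device_train_batch_size` and the `device_batch_size` parameter now accept `Union[int, float]`. The helper below is a minimal illustrative sketch of that arithmetic, not the repository's actual `update_batch_size_info` logic; the function name and behavior are assumptions for demonstration only.

```python
from typing import Union

# Hypothetical helper (not llm-foundry's update_batch_size_info):
# shows how splitting a global batch size across devices can yield a float.
def split_batch_size(global_train_batch_size: int, num_devices: int) -> Union[int, float]:
    """Divide the global batch size across devices.

    Returns an int when it divides evenly, otherwise a float,
    which motivates typing the config field as Union[int, float].
    """
    per_device = global_train_batch_size / num_devices
    return int(per_device) if per_device.is_integer() else per_device


if __name__ == "__main__":
    print(split_batch_size(512, 8))  # 64  -> int
    print(split_batch_size(3, 2))    # 1.5 -> float
```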