Commit ca9ec5d

Limit batch size tuning to ≤20% of dataset size (#3003)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
geoffreyangus and pre-commit-ci[bot] authored Jan 30, 2023
1 parent 4b9d5fc commit ca9ec5d
Showing 4 changed files with 40 additions and 15 deletions.
9 changes: 6 additions & 3 deletions ludwig/constants.py
@@ -186,9 +186,12 @@
 BATCH_SIZE = "batch_size"
 EVAL_BATCH_SIZE = "eval_batch_size"
 DEFAULT_BATCH_SIZE = "auto"
-MAX_POSSIBLE_BATCH_SIZE = (
-    1099511627776  # 2^40. Used for `max_batch_size` config param. Not a hard constraint for `batch_size` config param.
-)
+# 2^40. Used for `max_batch_size` config param. Not a hard constraint for `batch_size` config param.
+MAX_POSSIBLE_BATCH_SIZE = 1099511627776
+# min batch size. Used as a floor for batch size tuning. Not a hard constraint for `batch_size` config params.
+MIN_POSSIBLE_BATCH_SIZE = 2
+# max batch size for dataset is 20% of dataset size
+MAX_BATCH_SIZE_DATASET_FRACTION = 0.2
 LEARNING_RATE = "learning_rate"
 INPUT_SIZE = "input_size"
 USE_BIAS = "use_bias"
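
For reference, a minimal sketch (not part of the commit) of how these three constants interact when deriving a ceiling for batch size tuning. The helper name effective_max_batch_size is hypothetical; the constant values mirror ludwig/constants.py after this change.

# Hypothetical helper illustrating the bound introduced above.
MIN_POSSIBLE_BATCH_SIZE = 2
MAX_BATCH_SIZE_DATASET_FRACTION = 0.2

def effective_max_batch_size(dataset_len: int, max_batch_size: int) -> int:
    """Largest batch size the tuner should be allowed to try for `dataset_len` rows."""
    dataset_cap = int(MAX_BATCH_SIZE_DATASET_FRACTION * dataset_len)
    return max(MIN_POSSIBLE_BATCH_SIZE, min(dataset_cap, max_batch_size))

# A 200-row dataset with max_batch_size=256 is now capped at 40 rows per batch.
assert effective_max_batch_size(200, 256) == 40
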
19 changes: 14 additions & 5 deletions ludwig/utils/batch_size_tuner.py
@@ -8,6 +8,7 @@
 import torch
 
 from ludwig.api_annotations import DeveloperAPI
+from ludwig.constants import MAX_BATCH_SIZE_DATASET_FRACTION, MIN_POSSIBLE_BATCH_SIZE
 
 logger = logging.getLogger(__name__)
 
@@ -26,19 +27,21 @@ def select_best_batch_size(
         max_batch_size = max_batch_size or dataset_len
 
         def _is_valid_batch_size(batch_size):
-            # make sure that batch size is valid (e.g. less than size of ds)
-            is_smaller_than_training_set = batch_size < dataset_len
+            # make sure that batch size is valid (e.g. less than 20% of ds size and max_batch_size)
+            is_smaller_than_training_set = batch_size <= MAX_BATCH_SIZE_DATASET_FRACTION * dataset_len
             is_under_max_batch_size = batch_size <= max_batch_size
             is_valid = is_smaller_than_training_set and is_under_max_batch_size
             if not is_valid:
                 logger.info(
-                    f"Batch size {batch_size} is invalid, must be smaller than training set size "
-                    f"{dataset_len} and less than or equal to max batch size {max_batch_size}"
+                    f"Batch size {batch_size} is invalid, must be less than or equal to "
+                    f"{MAX_BATCH_SIZE_DATASET_FRACTION * 100}% dataset size "
+                    f"({int(MAX_BATCH_SIZE_DATASET_FRACTION * dataset_len)} samples "
+                    f"of {dataset_len}) and less than or equal to max batch size {max_batch_size}"
                 )
             return is_valid
 
         # Set 2 as the minimum batch size to account for batch norm.
-        batch_size = 2
+        batch_size = MIN_POSSIBLE_BATCH_SIZE
 
         best_samples_per_sec = 0
         best_batch_size = None
@@ -71,6 +74,12 @@ def _is_valid_batch_size(batch_size):
                     raise
                 break
 
+        # Ensure that some batch size is found.
+        # `best_batch_size` can be None if the first batch size is invalid.
+        if best_batch_size is None:
+            logger.info("Could not tune batch size, using minimum batch size of 2")
+            best_batch_size = MIN_POSSIBLE_BATCH_SIZE
+
         logger.info(f"Selected batch_size={best_batch_size}")
         return best_batch_size

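
In effect, the tuner now doubles the candidate batch size from MIN_POSSIBLE_BATCH_SIZE, rejects any candidate above 20% of the dataset (or above max_batch_size), and falls back to the minimum if even the first candidate is rejected. Below is a simplified, self-contained sketch of that loop under those assumptions; it is not Ludwig's actual evaluator, and measure_throughput is a stand-in callback for the real timing run against the model.

# Simplified sketch of the post-change tuning loop; assumed structure, not the real class.
MIN_POSSIBLE_BATCH_SIZE = 2        # mirrors ludwig/constants.py
MAX_BATCH_SIZE_DATASET_FRACTION = 0.2

def tune_batch_size(dataset_len, max_batch_size, measure_throughput, max_trials=10):
    def is_valid(batch_size):
        # A candidate must stay within 20% of the dataset and under max_batch_size.
        return (
            batch_size <= MAX_BATCH_SIZE_DATASET_FRACTION * dataset_len
            and batch_size <= max_batch_size
        )

    batch_size = MIN_POSSIBLE_BATCH_SIZE
    best_batch_size, best_samples_per_sec = None, 0.0
    for _ in range(max_trials):
        if not is_valid(batch_size):
            break
        samples_per_sec = measure_throughput(batch_size)  # stand-in for a real throughput measurement
        if samples_per_sec > best_samples_per_sec:
            best_samples_per_sec, best_batch_size = samples_per_sec, batch_size
        batch_size *= 2

    # If even the smallest candidate was invalid, fall back to the minimum.
    return best_batch_size if best_batch_size is not None else MIN_POSSIBLE_BATCH_SIZE

# Example: with a 100-row dataset only 2, 4, 8, and 16 stay under the 20-row ceiling.
print(tune_batch_size(100, 1024, measure_throughput=lambda bs: float(bs)))  # -> 16
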
18 changes: 15 additions & 3 deletions tests/integration_tests/test_ray.py
@@ -37,6 +37,7 @@
     DATE,
     H3,
     IMAGE,
+    MAX_BATCH_SIZE_DATASET_FRACTION,
     NAME,
     NUMBER,
     PREPROCESSING,
@@ -49,6 +50,7 @@
     VECTOR,
 )
 from ludwig.data.preprocessing import balance_data
+from ludwig.data.split import DEFAULT_PROBABILITIES
 from ludwig.utils.data_utils import read_parquet
 from ludwig.utils.misc_utils import merge_dict
 from tests.integration_tests.utils import (
@@ -795,7 +797,7 @@ def _run_train_gpu_load_cpu(config, data_parquet):
 @pytest.mark.distributed
 @pytest.mark.parametrize(
     ("max_batch_size", "expected_final_batch_size", "expected_final_learning_rate"),
-    [(256, 128, 0.001), (32, 32, 0.001)],
+    [(256, None, 0.001), (8, 8, 0.001)],
 )
 def test_tune_batch_size_lr_cpu(
     tmpdir, ray_cluster_2cpu, max_batch_size, expected_final_batch_size, expected_final_learning_rate
@@ -818,11 +820,21 @@ def test_tune_batch_size_lr_cpu(
 
     backend_config = {**RAY_BACKEND_CONFIG}
 
+    num_samples = 200
     csv_filename = os.path.join(tmpdir, "dataset.csv")
-    dataset_csv = generate_data(config["input_features"], config["output_features"], csv_filename, num_examples=200)
+    dataset_csv = generate_data(
+        config["input_features"], config["output_features"], csv_filename, num_examples=num_samples
+    )
     dataset_parquet = create_data_set_to_use("parquet", dataset_csv)
     model = run_api_experiment(config, dataset=dataset_parquet, backend_config=backend_config)
-    assert model.config[TRAINER]["batch_size"] == expected_final_batch_size
+
+    if expected_final_batch_size is not None:
+        assert model.config[TRAINER]["batch_size"] == expected_final_batch_size
+    else:
+        # If we don't specify a batch size, we should validate the batch size against the training dataset size
+        num_train_samples = num_samples * DEFAULT_PROBABILITIES[0]
+        assert 2 < model.config[TRAINER]["batch_size"] <= MAX_BATCH_SIZE_DATASET_FRACTION * num_train_samples
+
     assert model.config[TRAINER]["learning_rate"] == expected_final_learning_rate


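
A rough check of why the max_batch_size=256 case now expects None rather than a fixed value. This assumes the default train split fraction DEFAULT_PROBABILITIES[0] is 0.7; treat that exact figure as an assumption rather than something stated in this diff.

# Back-of-the-envelope bound for the max_batch_size=256 parametrization above.
num_samples = 200
train_fraction = 0.7                              # assumed value of DEFAULT_PROBABILITIES[0]
num_train_samples = num_samples * train_fraction  # 140 training rows
cap = 0.2 * num_train_samples                     # 28: the new per-dataset ceiling
# Doubling from 2 yields candidates 2, 4, 8, 16; 32 already exceeds the ceiling,
# and the winner among them depends on measured throughput on the test machine,
# so the test can only assert the bound rather than an exact value.
assert 16 <= cap < 32
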
9 changes: 5 additions & 4 deletions tests/integration_tests/test_trainer.py
@@ -10,7 +10,7 @@
 
 from ludwig.api import LudwigModel
 from ludwig.callbacks import Callback
-from ludwig.constants import BATCH_SIZE, TRAINER
+from ludwig.constants import BATCH_SIZE, MAX_BATCH_SIZE_DATASET_FRACTION, TRAINER
 from tests.integration_tests.utils import (
     binary_feature,
     category_feature,
@@ -94,8 +94,9 @@ def test_tune_batch_size_and_lr(tmpdir, eval_batch_size, is_cpu):
         vector_feature(),
     ]
 
+    num_samples = 30
     csv_filename = os.path.join(tmpdir, "training.csv")
-    data_csv = generate_data(input_features, output_features, csv_filename)
+    data_csv = generate_data(input_features, output_features, csv_filename, num_examples=num_samples)
     val_csv = shutil.copyfile(data_csv, os.path.join(tmpdir, "validation.csv"))
     test_csv = shutil.copyfile(data_csv, os.path.join(tmpdir, "test.csv"))
 
@@ -133,8 +134,8 @@ def check_postconditions(model):
         assert model.config_obj.trainer.batch_size != "auto"
         assert model.config_obj.trainer.batch_size > 1
 
-        # 16 is the largest possible batch size for this dataset
-        assert model.config_obj.trainer.batch_size == 16
+        # 4 is the largest possible batch size for this dataset (20% of dataset size)
+        assert model.config_obj.trainer.batch_size <= MAX_BATCH_SIZE_DATASET_FRACTION * num_samples
 
         assert model.config_obj.trainer.eval_batch_size != "auto"
         assert model.config_obj.trainer.eval_batch_size > 1
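
The updated postcondition follows directly from the new fraction. A hypothetical mini-check, assuming the 30 generated rows are the relevant dataset size and the tuner doubles candidates starting from 2:

# 20% of a 30-row dataset allows at most 6 rows per batch; doubling from 2
# means 4 is the largest candidate the tuner can accept, matching the comment above.
assert int(0.2 * 30) == 6
assert max(bs for bs in (2, 4, 8, 16) if bs <= 6) == 4
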
