Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(trainers): Add support for DistributedDatasetsFromFunction in data adapters #20829

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/src/trainers/data_adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def is_tf_dataset(x):
if parent.__name__ in (
"DatasetV2",
"DistributedDataset",
"DistributedDatasetsFromFunction",
) and "tensorflow.python." in str(parent.__module__):
return True
return False
Expand Down
65 changes: 65 additions & 0 deletions keras/src/trainers/data_adapters/tf_dataset_adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import tensorflow as tf
import torch

from keras.src import Sequential
from keras.src import backend
from keras.src import layers
from keras.src import testing
from keras.src.trainers.data_adapters import tf_dataset_adapter

Expand Down Expand Up @@ -286,3 +288,66 @@ def test_tf_sparse_tensors(self):
self.assertIsInstance(by, expected_class)
self.assertEqual(bx.shape, (2, 4))
self.assertEqual(by.shape, (2, 2))

def test_distributed_datasets_from_function_adapter_properties(self):
strategy = tf.distribute.MirroredStrategy()

def dataset_fn(input_context):
batch_size = input_context.get_per_replica_batch_size(
global_batch_size=2
)
x = tf.random.uniform((32, 4))
y = tf.random.uniform((32, 2))
return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
adapter = tf_dataset_adapter.TFDatasetAdapter(dist_dataset)
self.assertEqual(adapter.num_batches, 16)
self.assertIsNone(adapter.batch_size)
self.assertIsNone(adapter.has_partial_batch)
self.assertIsNone(adapter.partial_batch_size)

if backend.backend() == "numpy":
it = adapter.get_numpy_iterator()
expected_class = np.ndarray
elif backend.backend() == "tensorflow":
it = adapter.get_tf_dataset()
expected_class = tf.Tensor
elif backend.backend() == "jax":
it = adapter.get_jax_iterator()
expected_class = np.ndarray
elif backend.backend() == "torch":
it = adapter.get_torch_dataloader()
expected_class = torch.Tensor

batch_count = 0
for batch in it:
batch_count += 1
self.assertEqual(len(batch), 2)
data, labels = batch
self.assertIsInstance(data, expected_class)
self.assertIsInstance(labels, expected_class)
self.assertEqual(data.shape, (2, 4))
self.assertEqual(labels.shape, (2, 2))

self.assertEqual(batch_count, 16)

@pytest.mark.requires_trainable_backend
def test_distributed_datasets_from_function_model_integration(self):
strategy = tf.distribute.MirroredStrategy()

def dataset_fn(input_context):
batch_size = input_context.get_per_replica_batch_size(
global_batch_size=2
)
x = tf.random.uniform((4, 1))
y = tf.random.uniform((4, 2))
return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)

model = Sequential([layers.Dense(2, input_shape=(1,))])
model.compile(optimizer="adam", loss="mse")
model.fit(dist_dataset, epochs=1)
history = model.fit(dist_dataset, epochs=1)
self.assertIn("loss", history.history)
Loading