diff --git a/tensorflow_gnn/runner/tasks/BUILD b/tensorflow_gnn/runner/tasks/BUILD
index a2b9c4f6..5db61aba 100644
--- a/tensorflow_gnn/runner/tasks/BUILD
+++ b/tensorflow_gnn/runner/tasks/BUILD
@@ -1,5 +1,6 @@
 load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library")
 load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "py_strict_test")
+load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "distribute_py_test")

 licenses(["notice"])

@@ -40,13 +41,16 @@ pytype_strict_library(
     ],
 )

-py_strict_test(
+distribute_py_test(
     name = "dgi_test",
     srcs = ["dgi_test.py"],
     srcs_version = "PY3",
+    xla_enable_strict_auto_jit = False,
     deps = [
         ":dgi",
+        "//:expect_absl_installed",
         "//:expect_tensorflow_installed",
+        "//:expect_tensorflow_installed:tensorflow_no_contrib",
         "//tensorflow_gnn",
         "//tensorflow_gnn/runner:orchestration",
     ],
diff --git a/tensorflow_gnn/runner/tasks/dgi.py b/tensorflow_gnn/runner/tasks/dgi.py
index 9286755a..b6e47faf 100644
--- a/tensorflow_gnn/runner/tasks/dgi.py
+++ b/tensorflow_gnn/runner/tasks/dgi.py
@@ -23,14 +23,17 @@ class AddLossDeepGraphInfomax(tf.keras.layers.Layer):
   """"A bilinear layer with losses and metrics for Deep Graph Infomax."""

-  def __init__(self, units: int):
+  def __init__(self, units: int, global_batch_size: int, **kwargs):
     """Builds the bilinear layer weights.

     Args:
       units: Units for the bilinear layer.
+      global_batch_size: Global batch size to compute the average loss.
+      **kwargs: Extra arguments needed for serialization.
     """
-    super().__init__()
+    super().__init__(**kwargs)
     self._bilinear = tf.keras.layers.Dense(units, use_bias=False)
+    self._global_batch_size = global_batch_size

   def get_config(self) -> Mapping[Any, Any]:
     """Returns the config of the layer.

@@ -58,13 +61,19 @@ def call(self, inputs: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
     y_clean, y_corrupted = inputs
     # Summary
     summary = tf.math.reduce_mean(y_clean, axis=0, keepdims=True)
+    per_replica_batch_size = (
+        self._global_batch_size //
+        tf.distribute.get_strategy().num_replicas_in_sync)
     # Clean losses and metrics
     logits_clean = tf.matmul(y_clean, self._bilinear(summary), transpose_b=True)
-    self.add_loss(tf.keras.losses.BinaryCrossentropy(
+    loss_clean = tf.keras.losses.BinaryCrossentropy(
         from_logits=True,
-        name="binary_crossentropy_clean")(
-            tf.ones_like(logits_clean),
-            logits_clean))
+        name="binary_crossentropy_clean",
+        reduction=tf.keras.losses.Reduction.NONE)
+    self.add_loss(
+        tf.nn.compute_average_loss(
+            loss_clean(tf.ones_like(logits_clean), logits_clean),
+            global_batch_size=per_replica_batch_size))
     self.add_metric(
         tf.keras.metrics.binary_crossentropy(
             tf.ones_like(logits_clean),
@@ -81,11 +90,14 @@ def call(self, inputs: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
         y_corrupted,
         self._bilinear(summary),
         transpose_b=True)
-    self.add_loss(tf.keras.losses.BinaryCrossentropy(
+    loss_corrupted = tf.keras.losses.BinaryCrossentropy(
         from_logits=True,
-        name="binary_crossentropy_corrupted")(
-            tf.zeros_like(logits_corrupted),
-            logits_corrupted))
+        name="binary_crossentropy_corrupted",
+        reduction=tf.keras.losses.Reduction.NONE)
+    self.add_loss(
+        tf.nn.compute_average_loss(
+            loss_corrupted(tf.zeros_like(logits_corrupted), logits_corrupted),
+            global_batch_size=per_replica_batch_size))
     self.add_metric(
         tf.keras.metrics.binary_crossentropy(
             tf.zeros_like(logits_corrupted),
@@ -125,18 +137,21 @@ class DeepGraphInfomax:
   def __init__(self,
                node_set_name: str,
                *,
+               global_batch_size: int,
                state_name: str = tfgnn.HIDDEN_STATE,
                seed: Optional[int] = None):
"""Captures arguments for the task. Args: node_set_name: The node set for activations. + global_batch_size: Global batch size(not per-replica) for the training. state_name: The state name of any activations. seed: A seed for corrupted representations. """ self._state_name = state_name self._node_set_name = node_set_name self._seed = seed + self._global_batch_size = global_batch_size def adapt(self, model: tf.keras.Model) -> tf.keras.Model: """Adapt a `tf.keras.Model` for Deep Graph Infomax. @@ -164,15 +179,16 @@ def adapt(self, model: tf.keras.Model) -> tf.keras.Model: feature_name=self._state_name)(model.output) # Corrupted representations: shuffling, model application and readout - shuffled = tfgnn.shuffle_features_globally(model.input) + shuffled = tfgnn.shuffle_features_globally(model.input, seed=self._seed) y_corrupted = tfgnn.keras.layers.ReadoutFirstNode( node_set_name=self._node_set_name, feature_name=self._state_name)(model(shuffled)) return tf.keras.Model( model.input, - AddLossDeepGraphInfomax( - y_clean.get_shape()[-1])((y_clean, y_corrupted))) + AddLossDeepGraphInfomax(y_clean.get_shape()[-1], + self._global_batch_size)( + (y_clean, y_corrupted))) def preprocess(self, gt: tfgnn.GraphTensor) -> tfgnn.GraphTensor: """Returns the input GraphTensor.""" diff --git a/tensorflow_gnn/runner/tasks/dgi_test.py b/tensorflow_gnn/runner/tasks/dgi_test.py index 5011b023..e42ec99d 100644 --- a/tensorflow_gnn/runner/tasks/dgi_test.py +++ b/tensorflow_gnn/runner/tasks/dgi_test.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== """Tests for dgi.""" +import os + +from absl.testing import parameterized import tensorflow as tf +import tensorflow.__internal__.distribute as tfdistribute +import tensorflow.__internal__.test as tftest import tensorflow_gnn as tfgnn from tensorflow_gnn.runner import orchestration @@ -42,10 +47,59 @@ """ % tfgnn.HIDDEN_STATE -class DeepGraphInfomaxTest(tf.test.TestCase): - +def _all_eager_distributed_strategy_combinations(): + strategies = [ + # MirroredStrategy + tfdistribute.combinations.mirrored_strategy_with_gpu_and_cpu, + tfdistribute.combinations.mirrored_strategy_with_one_cpu, + tfdistribute.combinations.mirrored_strategy_with_one_gpu, + """ # MultiWorkerMirroredStrategy + tfdistribute.combinations.multi_worker_mirrored_2x1_cpu, + tfdistribute.combinations.multi_worker_mirrored_2x1_gpu, + # TPUStrategy + tfdistribute.combinations.tpu_strategy, + tfdistribute.combinations.tpu_strategy_one_core, + tfdistribute.combinations.tpu_strategy_packed_var, + # ParameterServerStrategy + tfdistribute.combinations.parameter_server_strategy_3worker_2ps_cpu, + tfdistribute.combinations.parameter_server_strategy_3worker_2ps_1gpu, + tfdistribute.combinations.parameter_server_strategy_1worker_2ps_cpu, + tfdistribute.combinations.parameter_server_strategy_1worker_2ps_1gpu, """ + ] + return tftest.combinations.combine(distribution=strategies) + + +class DeepGraphInfomaxTest(tf.test.TestCase, parameterized.TestCase): + + global_batch_size = 2 gtspec = tfgnn.create_graph_spec_from_schema_pb(tfgnn.parse_schema(SCHEMA)) - task = dgi.DeepGraphInfomax("node", seed=8191) + seed = 8191 + task = dgi.DeepGraphInfomax( + "node", global_batch_size=global_batch_size, seed=seed) + + def get_graph_tensor(self): + gt = tfgnn.GraphTensor.from_pieces( + node_sets={ + "node": + tfgnn.NodeSet.from_fields( + features={ + tfgnn.HIDDEN_STATE: + tf.convert_to_tensor([[1., 2., 3., 4.], + [11., 11., 11., 11.], + [19., 19., 
+                    },
+                    sizes=tf.convert_to_tensor([3])),
+        },
+        edge_sets={
+            "edge":
+                tfgnn.EdgeSet.from_fields(
+                    sizes=tf.convert_to_tensor([2]),
+                    adjacency=tfgnn.Adjacency.from_indices(
+                        ("node", tf.convert_to_tensor([0, 1], dtype=tf.int32)),
+                        ("node", tf.convert_to_tensor([2, 0], dtype=tf.int32)),
+                    )),
+        })
+    return gt

   def build_model(self):
     graph = inputs = tf.keras.layers.Input(type_spec=self.gtspec)
@@ -56,7 +110,9 @@ def build_model(self):
         "edge",
         tfgnn.TARGET,
         feature_name=tfgnn.HIDDEN_STATE)
-    messages = tf.keras.layers.Dense(16)(values)
+    messages = tf.keras.layers.Dense(
+        8, kernel_initializer=tf.constant_initializer(1.))(
+            values)

     pooled = tfgnn.pool_edges_to_node(
         graph,
@@ -67,7 +123,9 @@ def build_model(self):

     h_old = graph.node_sets["node"].features[tfgnn.HIDDEN_STATE]
     h_next = tf.keras.layers.Concatenate()((pooled, h_old))
-    h_next = tf.keras.layers.Dense(8)(h_next)
+    h_next = tf.keras.layers.Dense(
+        4, kernel_initializer=tf.constant_initializer(1.))(
+            h_next)

     graph = graph.replace_features(
         node_sets={"node": {
@@ -87,30 +145,71 @@ def test_adapt(self):
         feature_name=tfgnn.HIDDEN_STATE)(model(gt))
     actual = adapted(gt)

-    self.assertAllClose(actual, expected)
+    self.assertAllClose(actual, expected, rtol=1e-04, atol=1e-04)

   def test_fit(self):
-    gt = tfgnn.random_graph_tensor(self.gtspec)
-    ds = tf.data.Dataset.from_tensors(gt).repeat(8)
-    ds = ds.batch(2).map(tfgnn.GraphTensor.merge_batch_to_components)
+    ds = tf.data.Dataset.from_tensors(self.get_graph_tensor()).repeat(8)
+    ds = ds.batch(self.global_batch_size).map(
+        tfgnn.GraphTensor.merge_batch_to_components)

+    tf.random.set_seed(self.seed)
     model = self.task.adapt(self.build_model())
     model.compile()

     def get_loss():
+      tf.random.set_seed(self.seed)
       values = model.evaluate(ds)
       return dict(zip(model.metrics_names, values))["loss"]

     before = get_loss()
     model.fit(ds)
     after = get_loss()
+    self.assertAllClose(before, 21754138.0, rtol=1e-04, atol=1e-04)
+    self.assertAllClose(after, 16268301.0, rtol=1e-04, atol=1e-04)
+
+  @tfdistribute.combinations.generate(
+      tftest.combinations.combine(distribution=[
+          tfdistribute.combinations.mirrored_strategy_with_one_gpu,
+          tfdistribute.combinations.multi_worker_mirrored_2x1_gpu,
+      ]))
+  def test_distributed(self, distribution):
+    gt = self.get_graph_tensor()
+
+    def dataset_fn(input_context=None, gt=gt):
+      ds = tf.data.Dataset.from_tensors(gt).repeat(8)
+      if input_context:
+        batch_size = input_context.get_per_replica_batch_size(
+            self.global_batch_size)
+      else:
+        batch_size = self.global_batch_size
+      ds = ds.batch(batch_size).map(tfgnn.GraphTensor.merge_batch_to_components)
+      return ds
+
+    with distribution.scope():
+      tf.random.set_seed(self.seed)
+      model = self.task.adapt(self.build_model())
+      model.compile()
+
+    def get_loss():
+      tf.random.set_seed(self.seed)
+      values = model.evaluate(
+          distribution.distribute_datasets_from_function(dataset_fn), steps=4)
+      return dict(zip(model.metrics_names, values))["loss"]
+
+    before = get_loss()
+    model.fit(
+        distribution.distribute_datasets_from_function(dataset_fn),
+        steps_per_epoch=4)
+    after = get_loss()
+    self.assertAllClose(before, 21754138.0, rtol=1e-04, atol=1e-04)
+    self.assertAllClose(after, 16268301.0, rtol=1e-04, atol=1e-04)

-    self.assertAllClose(before, 250.42036, rtol=1e-04, atol=1e-04)
-    self.assertAllClose(after, 13.18533, rtol=1e-04, atol=1e-04)
+    export_dir = os.path.join(self.get_temp_dir(), "dropout-model")
+    model.save(export_dir)

   def test_protocol(self):
     self.assertIsInstance(dgi.DeepGraphInfomax, orchestration.Task)


 if __name__ == "__main__":
__name__ == "__main__": - tf.test.main() + tfdistribute.multi_process_runner.test_main()