Skip to content

Commit

Permalink
Update add_loss calls in DGI task to reduce them by dividing with the…
Browse files Browse the repository at this point in the history
… global_batch_size before passing to Keras.

PiperOrigin-RevId: 487897250
  • Loading branch information
Neslihans authored and tensorflower-gardener committed Nov 15, 2022
1 parent d7a9659 commit 1b152b2
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 26 deletions.
6 changes: 5 additions & 1 deletion tensorflow_gnn/runner/tasks/BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library")
load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "py_strict_test")
load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "distribute_py_test")

licenses(["notice"])

Expand Down Expand Up @@ -40,13 +41,16 @@ pytype_strict_library(
],
)

py_strict_test(
distribute_py_test(
name = "dgi_test",
srcs = ["dgi_test.py"],
srcs_version = "PY3",
xla_enable_strict_auto_jit = False,
deps = [
":dgi",
"//:expect_absl_installed",
"//:expect_tensorflow_installed",
"//:expect_tensorflow_installed:tensorflow_no_contrib",
"//tensorflow_gnn",
"//tensorflow_gnn/runner:orchestration",
],
Expand Down
42 changes: 29 additions & 13 deletions tensorflow_gnn/runner/tasks/dgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@
class AddLossDeepGraphInfomax(tf.keras.layers.Layer):
""""A bilinear layer with losses and metrics for Deep Graph Infomax."""

def __init__(self, units: int):
def __init__(self, units: int, global_batch_size: int, **kwargs):
"""Builds the bilinear layer weights.
Args:
units: Units for the bilinear layer.
global_batch_size: Global batch size to compute the average loss.
**kwargs: Extra arguments needed for serialization.
"""
super().__init__()
super().__init__(**kwargs)
self._bilinear = tf.keras.layers.Dense(units, use_bias=False)
self._global_batch_size = global_batch_size

def get_config(self) -> Mapping[Any, Any]:
"""Returns the config of the layer.
Expand Down Expand Up @@ -58,13 +61,19 @@ def call(self, inputs: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
y_clean, y_corrupted = inputs
# Summary
summary = tf.math.reduce_mean(y_clean, axis=0, keepdims=True)
per_replica_batch_size = (
self._global_batch_size //
tf.distribute.get_strategy().num_replicas_in_sync)
# Clean losses and metrics
logits_clean = tf.matmul(y_clean, self._bilinear(summary), transpose_b=True)
self.add_loss(tf.keras.losses.BinaryCrossentropy(
loss_clean = tf.keras.losses.BinaryCrossentropy(
from_logits=True,
name="binary_crossentropy_clean")(
tf.ones_like(logits_clean),
logits_clean))
name="binary_crossentropy_clean",
reduction=tf.keras.losses.Reduction.NONE)
self.add_loss(
tf.nn.compute_average_loss(
loss_clean(tf.ones_like(logits_clean), logits_clean),
global_batch_size=per_replica_batch_size))
self.add_metric(
tf.keras.metrics.binary_crossentropy(
tf.ones_like(logits_clean),
Expand All @@ -81,11 +90,14 @@ def call(self, inputs: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
y_corrupted,
self._bilinear(summary),
transpose_b=True)
self.add_loss(tf.keras.losses.BinaryCrossentropy(
loss_corrupted = tf.keras.losses.BinaryCrossentropy(
from_logits=True,
name="binary_crossentropy_corrupted")(
tf.zeros_like(logits_corrupted),
logits_corrupted))
name="binary_crossentropy_corrupted",
reduction=tf.keras.losses.Reduction.NONE)
self.add_loss(
tf.nn.compute_average_loss(
loss_corrupted(tf.zeros_like(logits_corrupted), logits_corrupted),
global_batch_size=per_replica_batch_size))
self.add_metric(
tf.keras.metrics.binary_crossentropy(
tf.zeros_like(logits_corrupted),
Expand Down Expand Up @@ -125,18 +137,21 @@ class DeepGraphInfomax:
def __init__(self,
node_set_name: str,
*,
global_batch_size: int,
state_name: str = tfgnn.HIDDEN_STATE,
seed: Optional[int] = None):
"""Captures arguments for the task.
Args:
node_set_name: The node set for activations.
global_batch_size: Global batch size(not per-replica) for the training.
state_name: The state name of any activations.
seed: A seed for corrupted representations.
"""
self._state_name = state_name
self._node_set_name = node_set_name
self._seed = seed
self._global_batch_size = global_batch_size

def adapt(self, model: tf.keras.Model) -> tf.keras.Model:
"""Adapt a `tf.keras.Model` for Deep Graph Infomax.
Expand Down Expand Up @@ -164,15 +179,16 @@ def adapt(self, model: tf.keras.Model) -> tf.keras.Model:
feature_name=self._state_name)(model.output)

# Corrupted representations: shuffling, model application and readout
shuffled = tfgnn.shuffle_features_globally(model.input)
shuffled = tfgnn.shuffle_features_globally(model.input, seed=self._seed)
y_corrupted = tfgnn.keras.layers.ReadoutFirstNode(
node_set_name=self._node_set_name,
feature_name=self._state_name)(model(shuffled))

return tf.keras.Model(
model.input,
AddLossDeepGraphInfomax(
y_clean.get_shape()[-1])((y_clean, y_corrupted)))
AddLossDeepGraphInfomax(y_clean.get_shape()[-1],
self._global_batch_size)(
(y_clean, y_corrupted)))

def preprocess(self, gt: tfgnn.GraphTensor) -> tfgnn.GraphTensor:
"""Returns the input GraphTensor."""
Expand Down
123 changes: 111 additions & 12 deletions tensorflow_gnn/runner/tasks/dgi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
# limitations under the License.
# ==============================================================================
"""Tests for dgi."""
import os

from absl.testing import parameterized
import tensorflow as tf
import tensorflow.__internal__.distribute as tfdistribute
import tensorflow.__internal__.test as tftest
import tensorflow_gnn as tfgnn

from tensorflow_gnn.runner import orchestration
Expand Down Expand Up @@ -42,10 +47,59 @@
""" % tfgnn.HIDDEN_STATE


class DeepGraphInfomaxTest(tf.test.TestCase):

def _all_eager_distributed_strategy_combinations():
strategies = [
# MirroredStrategy
tfdistribute.combinations.mirrored_strategy_with_gpu_and_cpu,
tfdistribute.combinations.mirrored_strategy_with_one_cpu,
tfdistribute.combinations.mirrored_strategy_with_one_gpu,
""" # MultiWorkerMirroredStrategy
tfdistribute.combinations.multi_worker_mirrored_2x1_cpu,
tfdistribute.combinations.multi_worker_mirrored_2x1_gpu,
# TPUStrategy
tfdistribute.combinations.tpu_strategy,
tfdistribute.combinations.tpu_strategy_one_core,
tfdistribute.combinations.tpu_strategy_packed_var,
# ParameterServerStrategy
tfdistribute.combinations.parameter_server_strategy_3worker_2ps_cpu,
tfdistribute.combinations.parameter_server_strategy_3worker_2ps_1gpu,
tfdistribute.combinations.parameter_server_strategy_1worker_2ps_cpu,
tfdistribute.combinations.parameter_server_strategy_1worker_2ps_1gpu, """
]
return tftest.combinations.combine(distribution=strategies)


class DeepGraphInfomaxTest(tf.test.TestCase, parameterized.TestCase):

global_batch_size = 2
gtspec = tfgnn.create_graph_spec_from_schema_pb(tfgnn.parse_schema(SCHEMA))
task = dgi.DeepGraphInfomax("node", seed=8191)
seed = 8191
task = dgi.DeepGraphInfomax(
"node", global_batch_size=global_batch_size, seed=seed)

def get_graph_tensor(self):
gt = tfgnn.GraphTensor.from_pieces(
node_sets={
"node":
tfgnn.NodeSet.from_fields(
features={
tfgnn.HIDDEN_STATE:
tf.convert_to_tensor([[1., 2., 3., 4.],
[11., 11., 11., 11.],
[19., 19., 19., 19.]])
},
sizes=tf.convert_to_tensor([3])),
},
edge_sets={
"edge":
tfgnn.EdgeSet.from_fields(
sizes=tf.convert_to_tensor([2]),
adjacency=tfgnn.Adjacency.from_indices(
("node", tf.convert_to_tensor([0, 1], dtype=tf.int32)),
("node", tf.convert_to_tensor([2, 0], dtype=tf.int32)),
)),
})
return gt

def build_model(self):
graph = inputs = tf.keras.layers.Input(type_spec=self.gtspec)
Expand All @@ -56,7 +110,9 @@ def build_model(self):
"edge",
tfgnn.TARGET,
feature_name=tfgnn.HIDDEN_STATE)
messages = tf.keras.layers.Dense(16)(values)
messages = tf.keras.layers.Dense(
8, kernel_initializer=tf.constant_initializer(1.))(
values)

pooled = tfgnn.pool_edges_to_node(
graph,
Expand All @@ -67,7 +123,9 @@ def build_model(self):
h_old = graph.node_sets["node"].features[tfgnn.HIDDEN_STATE]

h_next = tf.keras.layers.Concatenate()((pooled, h_old))
h_next = tf.keras.layers.Dense(8)(h_next)
h_next = tf.keras.layers.Dense(
4, kernel_initializer=tf.constant_initializer(1.))(
h_next)

graph = graph.replace_features(
node_sets={"node": {
Expand All @@ -87,30 +145,71 @@ def test_adapt(self):
feature_name=tfgnn.HIDDEN_STATE)(model(gt))
actual = adapted(gt)

self.assertAllClose(actual, expected)
self.assertAllClose(actual, expected, rtol=1e-04, atol=1e-04)

def test_fit(self):
gt = tfgnn.random_graph_tensor(self.gtspec)
ds = tf.data.Dataset.from_tensors(gt).repeat(8)
ds = ds.batch(2).map(tfgnn.GraphTensor.merge_batch_to_components)
ds = tf.data.Dataset.from_tensors(self.get_graph_tensor()).repeat(8)
ds = ds.batch(self.global_batch_size).map(
tfgnn.GraphTensor.merge_batch_to_components)

tf.random.set_seed(self.seed)
model = self.task.adapt(self.build_model())
model.compile()

def get_loss():
tf.random.set_seed(self.seed)
values = model.evaluate(ds)
return dict(zip(model.metrics_names, values))["loss"]

before = get_loss()
model.fit(ds)
after = get_loss()
self.assertAllClose(before, 21754138.0, rtol=1e-04, atol=1e-04)
self.assertAllClose(after, 16268301.0, rtol=1e-04, atol=1e-04)

@tfdistribute.combinations.generate(
tftest.combinations.combine(distribution=[
tfdistribute.combinations.mirrored_strategy_with_one_gpu,
tfdistribute.combinations.multi_worker_mirrored_2x1_gpu,
]))
def test_distributed(self, distribution):
gt = self.get_graph_tensor()

def dataset_fn(input_context=None, gt=gt):
ds = tf.data.Dataset.from_tensors(gt).repeat(8)
if input_context:
batch_size = input_context.get_per_replica_batch_size(
self.global_batch_size)
else:
batch_size = self.global_batch_size
ds = ds.batch(batch_size).map(tfgnn.GraphTensor.merge_batch_to_components)
return ds

with distribution.scope():
tf.random.set_seed(self.seed)
model = self.task.adapt(self.build_model())
model.compile()

def get_loss():
tf.random.set_seed(self.seed)
values = model.evaluate(
distribution.distribute_datasets_from_function(dataset_fn), steps=4)
return dict(zip(model.metrics_names, values))["loss"]

before = get_loss()
model.fit(
distribution.distribute_datasets_from_function(dataset_fn),
steps_per_epoch=4)
after = get_loss()
self.assertAllClose(before, 21754138.0, rtol=1e-04, atol=1e-04)
self.assertAllClose(after, 16268301.0, rtol=1e-04, atol=1e-04)

self.assertAllClose(before, 250.42036, rtol=1e-04, atol=1e-04)
self.assertAllClose(after, 13.18533, rtol=1e-04, atol=1e-04)
export_dir = os.path.join(self.get_temp_dir(), "dropout-model")
model.save(export_dir)

def test_protocol(self):
self.assertIsInstance(dgi.DeepGraphInfomax, orchestration.Task)


if __name__ == "__main__":
tf.test.main()
tfdistribute.multi_process_runner.test_main()

0 comments on commit 1b152b2

Please sign in to comment.