Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update add_loss calls in DGI task to reduce them by dividing with the global_batch_size before passing to Keras. #174

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion tensorflow_gnn/runner/tasks/BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library")
load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "py_strict_test")
load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "distribute_py_test")

licenses(["notice"])

Expand Down Expand Up @@ -40,13 +41,16 @@ pytype_strict_library(
],
)

py_strict_test(
distribute_py_test(
name = "dgi_test",
srcs = ["dgi_test.py"],
srcs_version = "PY3",
xla_enable_strict_auto_jit = False,
deps = [
":dgi",
"//:expect_absl_installed",
"//:expect_tensorflow_installed",
"//:expect_tensorflow_installed:tensorflow_no_contrib",
"//tensorflow_gnn",
"//tensorflow_gnn/runner:orchestration",
],
Expand Down
42 changes: 29 additions & 13 deletions tensorflow_gnn/runner/tasks/dgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@
class AddLossDeepGraphInfomax(tf.keras.layers.Layer):
""""A bilinear layer with losses and metrics for Deep Graph Infomax."""

def __init__(self, units: int):
def __init__(self, units: int, global_batch_size: int, **kwargs):
"""Builds the bilinear layer weights.

Args:
units: Units for the bilinear layer.
global_batch_size: Global batch size to compute the average loss.
**kwargs: Extra arguments needed for serialization.
"""
super().__init__()
super().__init__(**kwargs)
self._bilinear = tf.keras.layers.Dense(units, use_bias=False)
self._global_batch_size = global_batch_size

def get_config(self) -> Mapping[Any, Any]:
"""Returns the config of the layer.
Expand Down Expand Up @@ -58,13 +61,19 @@ def call(self, inputs: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
y_clean, y_corrupted = inputs
# Summary
summary = tf.math.reduce_mean(y_clean, axis=0, keepdims=True)
per_replica_batch_size = (
self._global_batch_size //
tf.distribute.get_strategy().num_replicas_in_sync)
# Clean losses and metrics
logits_clean = tf.matmul(y_clean, self._bilinear(summary), transpose_b=True)
self.add_loss(tf.keras.losses.BinaryCrossentropy(
loss_clean = tf.keras.losses.BinaryCrossentropy(
from_logits=True,
name="binary_crossentropy_clean")(
tf.ones_like(logits_clean),
logits_clean))
name="binary_crossentropy_clean",
reduction=tf.keras.losses.Reduction.NONE)
self.add_loss(
tf.nn.compute_average_loss(
loss_clean(tf.ones_like(logits_clean), logits_clean),
global_batch_size=per_replica_batch_size))
self.add_metric(
tf.keras.metrics.binary_crossentropy(
tf.ones_like(logits_clean),
Expand All @@ -81,11 +90,14 @@ def call(self, inputs: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
y_corrupted,
self._bilinear(summary),
transpose_b=True)
self.add_loss(tf.keras.losses.BinaryCrossentropy(
loss_corrupted = tf.keras.losses.BinaryCrossentropy(
from_logits=True,
name="binary_crossentropy_corrupted")(
tf.zeros_like(logits_corrupted),
logits_corrupted))
name="binary_crossentropy_corrupted",
reduction=tf.keras.losses.Reduction.NONE)
self.add_loss(
tf.nn.compute_average_loss(
loss_corrupted(tf.zeros_like(logits_corrupted), logits_corrupted),
global_batch_size=per_replica_batch_size))
self.add_metric(
tf.keras.metrics.binary_crossentropy(
tf.zeros_like(logits_corrupted),
Expand Down Expand Up @@ -125,18 +137,21 @@ class DeepGraphInfomax:
def __init__(self,
node_set_name: str,
*,
global_batch_size: int,
state_name: str = tfgnn.HIDDEN_STATE,
seed: Optional[int] = None):
"""Captures arguments for the task.

Args:
node_set_name: The node set for activations.
global_batch_size: Global batch size(not per-replica) for the training.
state_name: The state name of any activations.
seed: A seed for corrupted representations.
"""
self._state_name = state_name
self._node_set_name = node_set_name
self._seed = seed
self._global_batch_size = global_batch_size

def adapt(self, model: tf.keras.Model) -> tf.keras.Model:
"""Adapt a `tf.keras.Model` for Deep Graph Infomax.
Expand Down Expand Up @@ -164,15 +179,16 @@ def adapt(self, model: tf.keras.Model) -> tf.keras.Model:
feature_name=self._state_name)(model.output)

# Corrupted representations: shuffling, model application and readout
shuffled = tfgnn.shuffle_features_globally(model.input)
shuffled = tfgnn.shuffle_features_globally(model.input, seed=self._seed)
y_corrupted = tfgnn.keras.layers.ReadoutFirstNode(
node_set_name=self._node_set_name,
feature_name=self._state_name)(model(shuffled))

return tf.keras.Model(
model.input,
AddLossDeepGraphInfomax(
y_clean.get_shape()[-1])((y_clean, y_corrupted)))
AddLossDeepGraphInfomax(y_clean.get_shape()[-1],
self._global_batch_size)(
(y_clean, y_corrupted)))

def preprocess(self, gt: tfgnn.GraphTensor) -> tfgnn.GraphTensor:
"""Returns the input GraphTensor."""
Expand Down
123 changes: 111 additions & 12 deletions tensorflow_gnn/runner/tasks/dgi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
# limitations under the License.
# ==============================================================================
"""Tests for dgi."""
import os

from absl.testing import parameterized
import tensorflow as tf
import tensorflow.__internal__.distribute as tfdistribute
import tensorflow.__internal__.test as tftest
import tensorflow_gnn as tfgnn

from tensorflow_gnn.runner import orchestration
Expand Down Expand Up @@ -42,10 +47,59 @@
""" % tfgnn.HIDDEN_STATE


class DeepGraphInfomaxTest(tf.test.TestCase):

def _all_eager_distributed_strategy_combinations():
strategies = [
# MirroredStrategy
tfdistribute.combinations.mirrored_strategy_with_gpu_and_cpu,
tfdistribute.combinations.mirrored_strategy_with_one_cpu,
tfdistribute.combinations.mirrored_strategy_with_one_gpu,
""" # MultiWorkerMirroredStrategy
tfdistribute.combinations.multi_worker_mirrored_2x1_cpu,
tfdistribute.combinations.multi_worker_mirrored_2x1_gpu,
# TPUStrategy
tfdistribute.combinations.tpu_strategy,
tfdistribute.combinations.tpu_strategy_one_core,
tfdistribute.combinations.tpu_strategy_packed_var,
# ParameterServerStrategy
tfdistribute.combinations.parameter_server_strategy_3worker_2ps_cpu,
tfdistribute.combinations.parameter_server_strategy_3worker_2ps_1gpu,
tfdistribute.combinations.parameter_server_strategy_1worker_2ps_cpu,
tfdistribute.combinations.parameter_server_strategy_1worker_2ps_1gpu, """
]
return tftest.combinations.combine(distribution=strategies)


class DeepGraphInfomaxTest(tf.test.TestCase, parameterized.TestCase):

global_batch_size = 2
gtspec = tfgnn.create_graph_spec_from_schema_pb(tfgnn.parse_schema(SCHEMA))
task = dgi.DeepGraphInfomax("node", seed=8191)
seed = 8191
task = dgi.DeepGraphInfomax(
"node", global_batch_size=global_batch_size, seed=seed)

def get_graph_tensor(self):
gt = tfgnn.GraphTensor.from_pieces(
node_sets={
"node":
tfgnn.NodeSet.from_fields(
features={
tfgnn.HIDDEN_STATE:
tf.convert_to_tensor([[1., 2., 3., 4.],
[11., 11., 11., 11.],
[19., 19., 19., 19.]])
},
sizes=tf.convert_to_tensor([3])),
},
edge_sets={
"edge":
tfgnn.EdgeSet.from_fields(
sizes=tf.convert_to_tensor([2]),
adjacency=tfgnn.Adjacency.from_indices(
("node", tf.convert_to_tensor([0, 1], dtype=tf.int32)),
("node", tf.convert_to_tensor([2, 0], dtype=tf.int32)),
)),
})
return gt

def build_model(self):
graph = inputs = tf.keras.layers.Input(type_spec=self.gtspec)
Expand All @@ -56,7 +110,9 @@ def build_model(self):
"edge",
tfgnn.TARGET,
feature_name=tfgnn.HIDDEN_STATE)
messages = tf.keras.layers.Dense(16)(values)
messages = tf.keras.layers.Dense(
8, kernel_initializer=tf.constant_initializer(1.))(
values)

pooled = tfgnn.pool_edges_to_node(
graph,
Expand All @@ -67,7 +123,9 @@ def build_model(self):
h_old = graph.node_sets["node"].features[tfgnn.HIDDEN_STATE]

h_next = tf.keras.layers.Concatenate()((pooled, h_old))
h_next = tf.keras.layers.Dense(8)(h_next)
h_next = tf.keras.layers.Dense(
4, kernel_initializer=tf.constant_initializer(1.))(
h_next)

graph = graph.replace_features(
node_sets={"node": {
Expand All @@ -87,30 +145,71 @@ def test_adapt(self):
feature_name=tfgnn.HIDDEN_STATE)(model(gt))
actual = adapted(gt)

self.assertAllClose(actual, expected)
self.assertAllClose(actual, expected, rtol=1e-04, atol=1e-04)

def test_fit(self):
gt = tfgnn.random_graph_tensor(self.gtspec)
ds = tf.data.Dataset.from_tensors(gt).repeat(8)
ds = ds.batch(2).map(tfgnn.GraphTensor.merge_batch_to_components)
ds = tf.data.Dataset.from_tensors(self.get_graph_tensor()).repeat(8)
ds = ds.batch(self.global_batch_size).map(
tfgnn.GraphTensor.merge_batch_to_components)

tf.random.set_seed(self.seed)
model = self.task.adapt(self.build_model())
model.compile()

def get_loss():
tf.random.set_seed(self.seed)
values = model.evaluate(ds)
return dict(zip(model.metrics_names, values))["loss"]

before = get_loss()
model.fit(ds)
after = get_loss()
self.assertAllClose(before, 21754138.0, rtol=1e-04, atol=1e-04)
self.assertAllClose(after, 16268301.0, rtol=1e-04, atol=1e-04)

@tfdistribute.combinations.generate(
tftest.combinations.combine(distribution=[
tfdistribute.combinations.mirrored_strategy_with_one_gpu,
tfdistribute.combinations.multi_worker_mirrored_2x1_gpu,
]))
def test_distributed(self, distribution):
gt = self.get_graph_tensor()

def dataset_fn(input_context=None, gt=gt):
ds = tf.data.Dataset.from_tensors(gt).repeat(8)
if input_context:
batch_size = input_context.get_per_replica_batch_size(
self.global_batch_size)
else:
batch_size = self.global_batch_size
ds = ds.batch(batch_size).map(tfgnn.GraphTensor.merge_batch_to_components)
return ds

with distribution.scope():
tf.random.set_seed(self.seed)
model = self.task.adapt(self.build_model())
model.compile()

def get_loss():
tf.random.set_seed(self.seed)
values = model.evaluate(
distribution.distribute_datasets_from_function(dataset_fn), steps=4)
return dict(zip(model.metrics_names, values))["loss"]

before = get_loss()
model.fit(
distribution.distribute_datasets_from_function(dataset_fn),
steps_per_epoch=4)
after = get_loss()
self.assertAllClose(before, 21754138.0, rtol=1e-04, atol=1e-04)
self.assertAllClose(after, 16268301.0, rtol=1e-04, atol=1e-04)

self.assertAllClose(before, 250.42036, rtol=1e-04, atol=1e-04)
self.assertAllClose(after, 13.18533, rtol=1e-04, atol=1e-04)
export_dir = os.path.join(self.get_temp_dir(), "dropout-model")
model.save(export_dir)

def test_protocol(self):
self.assertIsInstance(dgi.DeepGraphInfomax, orchestration.Task)


if __name__ == "__main__":
tf.test.main()
tfdistribute.multi_process_runner.test_main()