Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal change #275

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 33 additions & 8 deletions tensorflow_gnn/models/gcn/gcn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,14 @@ class GCNConv(tf.keras.layers.Layer):
paper doesn't use a bias, but this defaults to True to be consistent
with Keras and other implementations.
add_self_loops: Whether to compute the result as if a loop from each node
to itself had been added to the edge set.
to itself had been added to the edge set. The self-loop edges are added
with an edge weight of one.
normalize: Whether to normalize the node features by in-degree.
kernel_initializer: initializer of type tf.keras.initializers .
node_feature: Name of the node feature to transform.
edge_weight_feature_name: Can be set to the name of a feature on the edge
set that supplies a scalar weight for each edge. The GCN computation uses
it as the edge's entry in the adjacency matrix, instead of the default 1.
**kwargs: additional arguments for the Layer.

Call arguments:
Expand Down Expand Up @@ -99,6 +103,7 @@ def __init__(self,
kernel_initializer: bool = None,
node_feature: Optional[str] = tfgnn.HIDDEN_STATE,
kernel_regularizer: Optional[_RegularizerType] = None,
edge_weight_feature_name: Optional[tfgnn.FieldName] = None,
**kwargs):

super().__init__(**kwargs)
Expand All @@ -113,6 +118,7 @@ def __init__(self,
self._node_feature = node_feature
self._receiver = receiver_tag
self._sender = tfgnn.reverse_tag(receiver_tag)
self._edge_weight_feature_name = edge_weight_feature_name

def get_config(self):
filter_config = self._filter.get_config()
Expand All @@ -126,6 +132,7 @@ def get_config(self):
use_bias=filter_config['use_bias'],
kernel_initializer=filter_config['kernel_initializer'],
kernel_regularizer=filter_config['kernel_regularizer'],
edge_weight_feature_name=self._edge_weight_feature_name,
**super().get_config())

def call(
Expand All @@ -148,13 +155,29 @@ def call(

if self._normalize:
edge_set = graph.edge_sets[edge_set_name]
edge_ones = tf.ones([edge_set.total_size, 1])
in_degree = tf.squeeze(tfgnn.pool_edges_to_node(
graph,
edge_set_name,
self._receiver,
'sum',
feature_value=edge_ones), -1)
if self._edge_weight_feature_name is not None:
try:
edge_weights = graph.edge_sets[edge_set_name][
self._edge_weight_feature_name]
except KeyError as e:
raise ValueError(f'{self._edge_weight_feature_name} is not given '
f'for edge set {edge_set_name} ') from e
if edge_weights.shape.rank != 1:
# GraphTensor guarantees it is not None.
raise ValueError('Expecting vector for edge weights. Received rank '
f'{tf.rank(edge_weights)}.')
edge_weights = tf.expand_dims(
edge_weights, axis=1) # Align with state feature.
else:
edge_weights = tf.ones([edge_set.total_size, 1])

in_degree = tf.squeeze(
tfgnn.pool_edges_to_node(
graph,
edge_set_name,
self._receiver,
'sum',
feature_value=edge_weights), -1)
# Degree matrix is the sum of rows of adjacency
# Adding self-loops adds an identity matrix to the adjacency
# This adds 1 to each diagonal element of the degree matrix
Expand All @@ -176,6 +199,8 @@ def call(
self._sender,
feature_value=normalized_values,
)
if self._edge_weight_feature_name is not None:
source_bcast = source_bcast * edge_weights
pooled = tfgnn.pool_edges_to_node(
graph, edge_set_name, self._receiver, 'sum', feature_value=source_bcast)

Expand Down
177 changes: 174 additions & 3 deletions tensorflow_gnn/models/gcn/gcn_conv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================
"""Tests for gcn_conv."""
import enum
import math
import os

from absl.testing import parameterized
Expand Down Expand Up @@ -193,9 +194,179 @@ def test_gcnconv_heterogeneous(self):
lambda: conv(graph, edge_set_name=tfgnn.EDGES))

@parameterized.named_parameters(
('', ReloadModel.SKIP),
('Restored', ReloadModel.SAVED_MODEL),
('RestoredKeras', ReloadModel.KERAS))
dict(
testcase_name='noSelfLoops_noBias',
use_bias=False,
add_self_loops=False,
),)
def test_gcnconv_with_edge_weights_ones(self, use_bias, add_self_loops):
"""Tests that gcn_conv returns the correct values with edge weights."""
graph = tfgnn.GraphTensor.from_pieces(
node_sets={
tfgnn.NODES:
tfgnn.NodeSet.from_fields(
sizes=[2],
features={
tfgnn.HIDDEN_STATE: tf.constant([[1., 0.], [0., 1.]])
},
)
},
edge_sets={
tfgnn.EDGES:
tfgnn.EdgeSet.from_fields(
sizes=[2],
features={
'weights': tf.constant([1.0, 1.0], dtype=tf.float32)
},
adjacency=tfgnn.Adjacency.from_indices(
source=(tfgnn.NODES, tf.constant([0, 1],
dtype=tf.int64)),
target=(tfgnn.NODES, tf.constant([1, 0],
dtype=tf.int64)),
))
})
conv_with_edge_weights = gcn_conv.GCNConv(
units=2,
use_bias=use_bias,
add_self_loops=add_self_loops,
kernel_initializer=tf.keras.initializers.Constant(tf.eye(2)),
edge_weight_feature_name='weights')
conv_without_edge_weights = gcn_conv.GCNConv(
units=2,
use_bias=use_bias,
add_self_loops=add_self_loops,
kernel_initializer=tf.keras.initializers.Constant(tf.eye(2)))
self.assertAllClose(
conv_with_edge_weights(graph, edge_set_name=tfgnn.EDGES),
conv_without_edge_weights(graph, edge_set_name=tfgnn.EDGES),
rtol=1e-06,
atol=1e-06)

@parameterized.named_parameters(
dict(
testcase_name='noSelfLoops_noBias',
use_bias=False,
add_self_loops=False,
expected_result=tf.constant([[0., 4. / (2. * 3.)],
[9. / (2. * 3.), 0.]])),
dict(
testcase_name='selfLoops_noBias',
use_bias=False,
add_self_loops=True,
expected_result=tf.constant(
[[
1. / (math.sqrt(5.) * math.sqrt(5.)),
4. / (math.sqrt(10.) * math.sqrt(5.))
],
[
9. / (math.sqrt(10.) * math.sqrt(5.)),
1. / (math.sqrt(10.) * math.sqrt(10.))
]])),
)
def test_gcnconv_with_edge_weights(self, use_bias, add_self_loops,
expected_result):
"""Tests that gcn_conv returns the correct values with edge weights."""
graph = tfgnn.GraphTensor.from_pieces(
node_sets={
tfgnn.NODES:
tfgnn.NodeSet.from_fields(
sizes=[2],
features={
tfgnn.HIDDEN_STATE: tf.constant([[1., 0.], [0., 1.]])
},
)
},
edge_sets={
tfgnn.EDGES:
tfgnn.EdgeSet.from_fields(
sizes=[2],
features={
'weights': tf.constant([9.0, 4.0], dtype=tf.float32)
},
adjacency=tfgnn.Adjacency.from_indices(
source=(tfgnn.NODES, tf.constant([0, 1],
dtype=tf.int64)),
target=(tfgnn.NODES, tf.constant([1, 0],
dtype=tf.int64)),
))
})
conv = gcn_conv.GCNConv(
units=2,
use_bias=use_bias,
add_self_loops=add_self_loops,
kernel_initializer=tf.keras.initializers.Constant(tf.eye(2)),
edge_weight_feature_name='weights')

self.assertAllClose(
expected_result,
conv(graph, edge_set_name=tfgnn.EDGES),
rtol=1e-06,
atol=1e-06)

def test_gcnconv_with_edge_weights_missing(self):
"""Tests that missing given edge weights feature name in the graph tensor throws an error."""
graph = tfgnn.GraphTensor.from_pieces(
node_sets={
tfgnn.NODES:
tfgnn.NodeSet.from_fields(
sizes=[2],
features={
tfgnn.HIDDEN_STATE: tf.constant([[1., 0.], [0., 1.]])
},
)
},
edge_sets={
tfgnn.EDGES:
tfgnn.EdgeSet.from_fields(
sizes=[2],
adjacency=tfgnn.Adjacency.from_indices(
source=(tfgnn.NODES, tf.constant([0, 1],
dtype=tf.int64)),
target=(tfgnn.NODES, tf.constant([1, 0],
dtype=tf.int64)),
))
})

conv = gcn_conv.GCNConv(units=2, edge_weight_feature_name='weights')
self.assertRaisesRegex(ValueError,
'weights is not given for edge set edges ',
lambda: conv(graph, edge_set_name=tfgnn.EDGES))

def test_gcnconv_with_bad_shaped_edge_weights(self):
"""Tests that given edge weights feature with bad shape throws an error."""
graph = tfgnn.GraphTensor.from_pieces(
node_sets={
tfgnn.NODES:
tfgnn.NodeSet.from_fields(
sizes=[2],
features={
tfgnn.HIDDEN_STATE: tf.constant([[1., 0.], [0., 1.]])
},
)
},
edge_sets={
tfgnn.EDGES:
tfgnn.EdgeSet.from_fields(
sizes=[2],
features={
'weights': tf.constant([[9.0], [4.0]], dtype=tf.float32)
},
adjacency=tfgnn.Adjacency.from_indices(
source=(tfgnn.NODES, tf.constant([0, 1],
dtype=tf.int64)),
target=(tfgnn.NODES, tf.constant([1, 0],
dtype=tf.int64)),
))
})

conv = gcn_conv.GCNConv(units=2, edge_weight_feature_name='weights')
self.assertRaisesRegex(
ValueError, 'Expecting vector for edge weights. Received rank 2.',
lambda: conv(graph, edge_set_name=tfgnn.EDGES))

@parameterized.named_parameters(('', ReloadModel.SKIP),
('Restored', ReloadModel.SAVED_MODEL),
('RestoredKeras', ReloadModel.KERAS))
def test_full_model(self, reload_model):
"""Tests GCNGraphUpdate in a full Model (incl. saving) with edge input."""
gt_input = tfgnn.GraphTensor.from_pieces(
Expand Down
1 change: 1 addition & 0 deletions tensorflow_gnn/runner/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pytype_strict_library(
"//tensorflow_gnn",
"//tensorflow_gnn/runner/utils:model",
"//tensorflow_gnn/runner/utils:model_export",
"//tensorflow_gnn/runner/utils:parsing",
],
)

Expand Down
32 changes: 3 additions & 29 deletions tensorflow_gnn/runner/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
import functools
import itertools
import os
from typing import Any, Callable, Optional, Sequence, Tuple, Union
from typing import Callable, Optional, Sequence, Tuple, Union

import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.runner import interfaces
from tensorflow_gnn.runner.utils import model as model_utils
from tensorflow_gnn.runner.utils import model_export
from tensorflow_gnn.runner.utils import parsing as parsing_utils

GraphTensor = tfgnn.GraphTensor
GraphTensorAndField = Tuple[GraphTensor, tfgnn.Field]
Expand Down Expand Up @@ -172,33 +173,6 @@ def _make_preprocessing_model(
num_parallel_calls=tf.data.experimental.AUTOTUNE)


def _maybe_parse(gtspec: GraphTensorSpec) -> Callable[[Any], GraphTensor]:
"""Returns a callable to parse (or assert the spec of) dataset elements."""
parse_example = tfgnn.keras.layers.ParseExample(gtspec)
# Relax the spec for potential comparisons.
relaxed = gtspec.relax(num_components=True, num_nodes=True, num_edges=True)
def fn(element):
# Use `getattr` to account for types without a `dtype` (e.g. `GraphTensor`).
if getattr(element, "dtype", None) == tf.string:
gt = parse_example(element)
elif not tfgnn.is_graph_tensor(element):
raise ValueError(f"Expected `GraphTensor` (got {element})")
else:
# Access protected member `_unbatch` to avoid any potential
# `merge_batch_to_components` work.
actual = element.spec._unbatch().relax( # pylint: disable=protected-access
num_components=True,
num_nodes=True,
num_edges=True)
if actual != relaxed:
raise ValueError(
f"Expected a `GraphTensor` of spec {relaxed} (got {actual})")
else:
gt = element
return gt
return fn


def _per_replica_batch_size(global_batch_size: int, num_replicas: int) -> int:
if global_batch_size % num_replicas != 0:
raise ValueError(f"The `global_batch_size` {global_batch_size} is not "
Expand Down Expand Up @@ -287,7 +261,7 @@ def apply_fn(ds,
*,
filter_fn: Optional[Callable[..., bool]] = None,
size_constraints: Optional[SizeConstraints] = None):
ds = _map_over_dataset(ds, _maybe_parse(gtspec))
ds = parsing_utils.maybe_parse_graph_tensor_dataset(ds, gtspec)
if filter_fn is not None:
ds = ds.filter(filter_fn)
if size_constraints is not None:
Expand Down
Loading