Add XENetConv convolution layer #9869

Open · wants to merge 1 commit into master
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added XENetConv - a convolution layer based on the XENet paper ([#8257](https://github.com/pyg-team/pytorch_geometric/issues/8257))
- Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))
- Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
169 changes: 169 additions & 0 deletions test/nn/conv/test_xenet.py
@@ -0,0 +1,169 @@
import unittest

import torch

from torch_geometric.data import Data
from torch_geometric.nn import XENetConv


class TestXENetConv(unittest.TestCase):
def setUp(self):
# Set random seed for reproducibility
torch.manual_seed(42)

# Define test dimensions
self.num_nodes = 4
self.in_node_channels = 3
self.in_edge_channels = 2
self.node_channels = 5
self.edge_channels = 4
self.stack_channels = [8, 16]

# Create a simple graph for testing
self.x = torch.randn(self.num_nodes, self.in_node_channels)
self.edge_index = torch.tensor(
[[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]], dtype=torch.long)
self.edge_attr = torch.randn(self.edge_index.size(1),
self.in_edge_channels)

# Create different variants of the layer for testing
self.conv_attention = XENetConv(in_node_channels=self.in_node_channels,
in_edge_channels=self.in_edge_channels,
stack_channels=self.stack_channels,
node_channels=self.node_channels,
edge_channels=self.edge_channels,
attention=True)

self.conv_no_attention = XENetConv(
in_node_channels=self.in_node_channels,
in_edge_channels=self.in_edge_channels,
stack_channels=self.stack_channels,
node_channels=self.node_channels, edge_channels=self.edge_channels,
attention=False)

def test_basic_forward(self):
"""Test basic forward pass with attention."""
out_x, out_edge_attr = self.conv_attention(self.x, self.edge_index,
self.edge_attr)

# Check output shapes
self.assertEqual(out_x.shape, (self.num_nodes, self.node_channels))
self.assertEqual(out_edge_attr.shape,
(self.edge_index.size(1), self.edge_channels))

# Check that outputs contain no NaN values
self.assertFalse(torch.isnan(out_x).any())
self.assertFalse(torch.isnan(out_edge_attr).any())

def test_no_attention_forward(self):
"""Test forward pass without attention."""
out_x, out_edge_attr = self.conv_no_attention(self.x, self.edge_index,
self.edge_attr)

# Check output shapes
self.assertEqual(out_x.shape, (self.num_nodes, self.node_channels))
self.assertEqual(out_edge_attr.shape,
(self.edge_index.size(1), self.edge_channels))

# Check that outputs contain no NaN values
self.assertFalse(torch.isnan(out_x).any())
self.assertFalse(torch.isnan(out_edge_attr).any())

def test_custom_activation(self):
"""Test with custom activation functions."""
conv = XENetConv(in_node_channels=self.in_node_channels,
in_edge_channels=self.in_edge_channels,
stack_channels=self.stack_channels,
node_channels=self.node_channels,
edge_channels=self.edge_channels, attention=True,
node_activation=torch.tanh,
edge_activation=torch.relu)

out_x, out_edge_attr = conv(self.x, self.edge_index, self.edge_attr)

# Check output ranges for activations
self.assertTrue(torch.all(out_x >= -1)
and torch.all(out_x <= 1)) # tanh range
self.assertTrue(torch.all(out_edge_attr >= 0)) # ReLU range

def test_single_stack_channel(self):
"""Test with a single stack channel instead of a list."""
conv = XENetConv(
in_node_channels=self.in_node_channels,
in_edge_channels=self.in_edge_channels,
stack_channels=32, # single integer
node_channels=self.node_channels,
edge_channels=self.edge_channels)

out_x, out_edge_attr = conv(self.x, self.edge_index, self.edge_attr)

# Check output shapes
self.assertEqual(out_x.shape, (self.num_nodes, self.node_channels))
self.assertEqual(out_edge_attr.shape,
(self.edge_index.size(1), self.edge_channels))

def test_batch_processing(self):
"""Test processing of batched graphs."""
# Create two graphs with different sizes
x1 = torch.randn(3, self.in_node_channels)
edge_index1 = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]],
dtype=torch.long)
edge_attr1 = torch.randn(edge_index1.size(1), self.in_edge_channels)

x2 = torch.randn(4, self.in_node_channels)
edge_index2 = torch.tensor([[0, 1, 2, 2, 3], [1, 2, 1, 3, 2]],
dtype=torch.long)
edge_attr2 = torch.randn(edge_index2.size(1), self.in_edge_channels)

# Create PyG Data objects
data1 = Data(x=x1, edge_index=edge_index1, edge_attr=edge_attr1)
data2 = Data(x=x2, edge_index=edge_index2, edge_attr=edge_attr2)

# Process each graph separately
out_x1, out_edge_attr1 = self.conv_attention(data1.x, data1.edge_index,
data1.edge_attr)
out_x2, out_edge_attr2 = self.conv_attention(data2.x, data2.edge_index,
data2.edge_attr)

# Check output shapes
self.assertEqual(out_x1.shape, (3, self.node_channels))
self.assertEqual(out_edge_attr1.shape, (4, self.edge_channels))
self.assertEqual(out_x2.shape, (4, self.node_channels))
self.assertEqual(out_edge_attr2.shape, (5, self.edge_channels))

def test_isolated_nodes(self):
"""Test handling of isolated nodes."""
# Create a graph with an isolated node
x = torch.randn(4, self.in_node_channels)
edge_index = torch.tensor([[0, 1], [1, 2]],
dtype=torch.long) # Node 3 is isolated
edge_attr = torch.randn(edge_index.size(1), self.in_edge_channels)

out_x, out_edge_attr = self.conv_attention(x, edge_index, edge_attr)

# Check that isolated node features are updated
self.assertFalse(torch.isnan(out_x[3]).any())
self.assertEqual(out_x.shape, (4, self.node_channels))
self.assertEqual(out_edge_attr.shape, (2, self.edge_channels))

def test_gradients(self):
"""Test gradient computation."""
self.x.requires_grad_()
self.edge_attr.requires_grad_()

out_x, out_edge_attr = self.conv_attention(self.x, self.edge_index,
self.edge_attr)

# Compute gradients
loss = out_x.sum() + out_edge_attr.sum()
loss.backward()

# Check that gradients are computed
self.assertIsNotNone(self.x.grad)
self.assertIsNotNone(self.edge_attr.grad)
self.assertFalse(torch.isnan(self.x.grad).any())
self.assertFalse(torch.isnan(self.edge_attr.grad).any())


if __name__ == '__main__':
unittest.main()
2 changes: 2 additions & 0 deletions torch_geometric/nn/conv/__init__.py
@@ -61,6 +61,7 @@
from .antisymmetric_conv import AntiSymmetricConv
from .dir_gnn_conv import DirGNNConv
from .mixhop_conv import MixHopConv
from .xenet_conv import XENetConv

import torch_geometric.nn.conv.utils # noqa

@@ -131,6 +132,7 @@
'AntiSymmetricConv',
'DirGNNConv',
'MixHopConv',
'XENetConv',
]

classes = __all__
155 changes: 155 additions & 0 deletions torch_geometric/nn/conv/xenet_conv.py
@@ -0,0 +1,155 @@
from typing import Callable, List, Optional, Tuple, Union

import torch
from torch import Tensor, nn

from torch_geometric.nn.conv import MessagePassing


class XENetConv(MessagePassing):
r"""Implementation of XENet convolution layer from the paper.

"XENet: Using a new graph convolution to accelerate the timeline for
protein design on quantum computers.

Based on original implementation here:
https://github.com/danielegrattarola/spektral/blob/master/spektral/ \
layers/convolutional/xenet_conv.py"

Args:
in_node_channels (int): Size of input node features
in_edge_channels (int): Size of input edge features
stack_channels (Union[int, List[int]]): Number of channels for the
hidden stack layers
node_channels (int): Number of output node features
edge_channels (int): Number of output edge features
attention (bool, optional): Whether to use attention when aggregating
messages. (default: True)
node_activation (Optional[Callable], optional): Activation function for
nodes. (default: None)
edge_activation (Optional[Callable], optional): Activation function for
edges. (default: None)
"""
def __init__(self, in_node_channels: int, in_edge_channels: int,
stack_channels: Union[int, List[int]], node_channels: int,
edge_channels: int, attention: bool = True,
node_activation: Optional[Callable] = None,
edge_activation: Optional[Callable] = None, **kwargs):
super().__init__(aggr='add', node_dim=0, **kwargs)

self.in_node_channels = in_node_channels
self.in_edge_channels = in_edge_channels
self.stack_channels = stack_channels if isinstance(
stack_channels, list) else [stack_channels]
self.node_channels = node_channels
self.edge_channels = edge_channels
self.attention = attention

# Node and edge activation functions
self.node_activation = node_activation if node_activation is not None \
else lambda x: x
self.edge_activation = edge_activation if edge_activation is not None \
else lambda x: x

# Stack MLPs
stack_input_size = 2 * in_node_channels + 2 * in_edge_channels
self.stack_layers = nn.ModuleList()
current_channels = stack_input_size

for i, channels in enumerate(self.stack_channels):
self.stack_layers.append(nn.Linear(current_channels, channels))
if i != len(self.stack_channels) - 1:
self.stack_layers.append(nn.ReLU())
else:
self.stack_layers.append(nn.PReLU())
current_channels = channels

# Final node and edge MLPs
node_input_size = in_node_channels + 2 * self.stack_channels[-1]
self.node_mlp = nn.Linear(node_input_size, node_channels)
self.edge_mlp = nn.Linear(self.stack_channels[-1], edge_channels)

# Attention layers
if self.attention:
self.att_in = nn.Sequential(nn.Linear(self.stack_channels[-1], 1),
nn.Sigmoid())
self.att_out = nn.Sequential(nn.Linear(self.stack_channels[-1], 1),
nn.Sigmoid())

def forward(self, x: Tensor, edge_index: Tensor,
edge_attr: Tensor) -> Tuple[Tensor, Tensor]:
"""Args
x (Tensor): Node feature matrix of shape [num_nodes,
in_node_channels] edge_index (Tensor): Graph connectivity matrix of
shape [2, num_edges] edge_attr (Tensor): Edge feature matrix of
shape [num_edges, in_edge_channels]

Returns:
tuple[Tensor, Tensor]: Updated node features [num_nodes,
node_channels] and edge features [num_edges, edge_channels]
"""
# Propagate messages
out_dict = self.propagate(edge_index, x=x, edge_attr=edge_attr,
size=(x.size(0), x.size(0)))

# Update node features
x_new = self.node_mlp(
torch.cat([x, out_dict['incoming'], out_dict['outgoing']], dim=-1))
x_new = self.node_activation(x_new)

# Update edge features
edge_features = out_dict['edge_features']
edge_attr_new = self.edge_mlp(edge_features)
edge_attr_new = self.edge_activation(edge_attr_new)

return x_new, edge_attr_new

def message(self, x_i: Tensor, x_j: Tensor, edge_attr: Tensor) -> dict:
"""Constructs messages for each edge."""
# Reverse the row order of the edge features
edge_attr_rev = edge_attr.flip(0)

# Concatenate all features
stack = torch.cat([x_i, x_j, edge_attr, edge_attr_rev], dim=-1)

# Apply stack MLPs
for layer in self.stack_layers:
stack = layer(stack)

# Apply attention if needed
if self.attention:
att_in = self.att_in(stack)
att_out = self.att_out(stack)
stack_in = stack * att_in
stack_out = stack * att_out
else:
stack_in = stack_out = stack

return {
'incoming': stack_in,
'outgoing': stack_out,
'edge_features': stack
}

def aggregate(self, inputs: dict, index: Tensor,
dim_size: Optional[int] = None) -> dict:
"""Aggregates messages from neighbors."""
incoming = self.aggr_module(inputs['incoming'], index,
dim_size=dim_size)
outgoing = self.aggr_module(inputs['outgoing'], index,
dim_size=dim_size)

return {
'incoming': incoming,
'outgoing': outgoing,
'edge_features': inputs['edge_features']
}

def __repr__(self) -> str:
return (f'{self.__class__.__name__}('
f'in_node_channels={self.in_node_channels}, '
f'in_edge_channels={self.in_edge_channels}, '
f'stack_channels={self.stack_channels}, '
f'node_channels={self.node_channels}, '
f'edge_channels={self.edge_channels}, '
f'attention={self.attention})')
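
For context, below is a minimal usage sketch of the new layer; it is not part of the diff. It instantiates XENetConv with the same hyperparameters the unit tests above use and runs a single forward pass on a toy graph. It assumes the re-export added to torch_geometric/nn/conv/__init__.py in this PR, so that XENetConv is importable from torch_geometric.nn.

import torch

from torch_geometric.nn import XENetConv

# Same dimensions as in test_xenet.py above.
conv = XENetConv(
    in_node_channels=3,      # input node feature size
    in_edge_channels=2,      # input edge feature size
    stack_channels=[8, 16],  # hidden sizes of the per-edge stack MLP
    node_channels=5,         # output node feature size
    edge_channels=4,         # output edge feature size
    attention=True,
)

x = torch.randn(4, 3)                            # 4 nodes
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])  # 6 directed edges
edge_attr = torch.randn(6, 2)

x_out, edge_attr_out = conv(x, edge_index, edge_attr)
print(x_out.shape)          # torch.Size([4, 5])
print(edge_attr_out.shape)  # torch.Size([6, 4])

The two returned tensors are the updated node features and updated edge features, matching the shapes asserted in test_xenet.py.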