CS224W - Bag of Tricks for Node Classification with GNN - Non interactive GAT #9832

Open
wants to merge 10 commits into base: master
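For context before the diff: a minimal usage sketch of the new flag. Shapes mirror the updated tests; the snippet is illustrative and assumes this PR's branch is installed, since released PyTorch Geometric does not have `interactive_attn`.

```python
import torch
from torch_geometric.nn import GATConv, GATv2Conv

x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

# Default: interactive attention, identical to the previous behavior.
conv = GATConv(8, 32, heads=2, interactive_attn=True)
out = conv(x, edge_index)        # [4, 2 * 32] = [4, 64]

# Non-interactive attention: the target-side score a_t^T Theta_t x_i is
# fixed to zero (see the updated GATConv docstring), so the attention
# logits depend only on the neighbor features.
conv_ni = GATConv(8, 32, heads=2, interactive_attn=False)
out_ni = conv_ni(x, edge_index)  # [4, 64]

# GATv2Conv gains the same flag.
conv_v2 = GATv2Conv(8, 32, heads=2, interactive_attn=False)
```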
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `interactive_attn` parameter to `GATConv` and `GATv2Conv` ([#9832](https://github.com/pyg-team/pytorch_geometric/pull/9832))
- Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))
- Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
9 changes: 6 additions & 3 deletions test/nn/conv/test_gat_conv.py
@@ -12,13 +12,15 @@


@pytest.mark.parametrize('residual', [False, True])
def test_gat_conv(residual):
@pytest.mark.parametrize('interactive_attn', [False, True])
def test_gat_conv(residual, interactive_attn):
x1 = torch.randn(4, 8)
x2 = torch.randn(2, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))

conv = GATConv(8, 32, heads=2, residual=residual)
conv = GATConv(8, 32, heads=2, residual=residual,
interactive_attn=interactive_attn)
assert str(conv) == 'GATConv(8, 32, heads=2)'
out = conv(x1, edge_index)
assert out.size() == (4, 64)
@@ -114,7 +116,8 @@ def forward(
# Test bipartite message passing:
adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))

conv = GATConv((8, 16), 32, heads=2, residual=residual)
conv = GATConv((8, 16), 32, heads=2, residual=residual,
interactive_attn=interactive_attn)
assert str(conv) == 'GATConv((8, 16), 32, heads=2)'

out1 = conv((x1, x2), edge_index)
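The second half of this test (the bipartite branch) exercises the tuple `in_channels` path together with the new flag. A standalone sketch of that call pattern, with tensor sizes copied from the test and the remaining behavior inferred from the layer diff below:

```python
import torch
from torch_geometric.nn import GATConv

x_src = torch.randn(4, 8)    # 4 source nodes with 8 features
x_dst = torch.randn(2, 16)   # 2 target nodes with 16 features
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

# Tuple in_channels -> separate Theta_s / Theta_t. With
# interactive_attn=False, lin_dst and att_dst are never created and only
# source-side scores enter the softmax; x_dst is still needed for sizes
# (and for the residual, if enabled).
conv = GATConv((8, 16), 32, heads=2, interactive_attn=False)
out = conv((x_src, x_dst), edge_index)   # [2, 2 * 32] = [2, 64]
```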
6 changes: 4 additions & 2 deletions test/nn/conv/test_gatv2_conv.py
@@ -12,13 +12,15 @@


@pytest.mark.parametrize('residual', [False, True])
def test_gatv2_conv(residual):
@pytest.mark.parametrize('interactive_attn', [False, True])
def test_gatv2_conv(residual, interactive_attn):
x1 = torch.randn(4, 8)
x2 = torch.randn(2, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))

conv = GATv2Conv(8, 32, heads=2, residual=residual)
conv = GATv2Conv(8, 32, heads=2, residual=residual,
interactive_attn=interactive_attn)
assert str(conv) == 'GATv2Conv(8, 32, heads=2)'
out = conv(x1, edge_index)
assert out.size() == (4, 64)
86 changes: 51 additions & 35 deletions torch_geometric/nn/conv/gat_conv.py
@@ -46,41 +46,42 @@ class GATConv(MessagePassing):
\alpha_{i,j} =
\frac{
\exp\left(\mathrm{LeakyReLU}\left(
\mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i
+ \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_j
\mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_i
+ \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_j
\right)\right)}
{\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
\exp\left(\mathrm{LeakyReLU}\left(
\mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i
+ \mathbf{a}^{\top}_{t}\mathbf{\Theta}_{t}\mathbf{x}_k
\mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_i
+ \mathbf{a}^{\top}_{s}\mathbf{\Theta}_{s}\mathbf{x}_k
\right)\right)}.

If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`,
If the graph has multi-dimensional edge features :math:`\mathbf{e}_{j,i}`,
the attention coefficients :math:`\alpha_{i,j}` are computed as

.. math::
\alpha_{i,j} =
\frac{
\exp\left(\mathrm{LeakyReLU}\left(
\mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i
+ \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_j
+ \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,j}
\mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_i
+ \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_j
+ \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{j,i}
\right)\right)}
{\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
\exp\left(\mathrm{LeakyReLU}\left(
\mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i
+ \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_k
+ \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,k}
\mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_i
+ \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_k
+ \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{k,i}
\right)\right)}.

If the graph is not bipartite, :math:`\mathbf{\Theta}_{s} =
\mathbf{\Theta}_{t}`.
If an integer is passed for :obj:`in_channels`, :math:`\mathbf{\Theta}_{s}
= \mathbf{\Theta}_{t}`.

Args:
in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
derive the size from the first input(s) to the forward method.
A tuple corresponds to the sizes of source and target
dimensionalities in case of a bipartite graph.
dimensionalities and distinct :math:`\mathbf{\Theta}`, for example,
in the case of a bipartite graph.
out_channels (int): Size of each output sample.
heads (int, optional): Number of multi-head-attentions.
(default: :obj:`1`)
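A small numeric sketch of the attention logits defined above, for a single head and no edge features. The weight shapes and `negative_slope=0.2` follow the layer's defaults; the random tensors are stand-ins, and the only point is which term `interactive_attn=False` removes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
F_in, F_out = 8, 32
theta_s = torch.randn(F_out, F_in)   # Theta_s (source / neighbor transform)
theta_t = torch.randn(F_out, F_in)   # Theta_t (target transform)
a_s = torch.randn(F_out)
a_t = torch.randn(F_out)

x = torch.randn(5, F_in)             # node features
neighbors_of_0 = [0, 1, 2]           # N(0) plus the self-loop

interactive_attn = True              # flip to False for the ablated variant

logits = []
for j in neighbors_of_0:
    e_0j = a_s @ (theta_s @ x[j])                 # neighbor term
    if interactive_attn:
        e_0j = e_0j + a_t @ (theta_t @ x[0])      # target term, dropped if False
    logits.append(F.leaky_relu(e_0j, negative_slope=0.2))

alpha = torch.softmax(torch.stack(logits), dim=0)  # alpha_{0,j} over j
```

Because the LeakyReLU sits between the sum and the softmax, the target term does not simply cancel as a constant shift, which is why removing it changes the attention distribution.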
@@ -109,23 +110,26 @@ class GATConv(MessagePassing):
an additive bias. (default: :obj:`True`)
residual (bool, optional): If set to :obj:`True`, the layer will add
a learnable skip-connection. (default: :obj:`False`)
interactive_attn (bool, optional): If set to :obj:`False`, fixes
:math:`\mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_i = 0`.
(default: :obj:`True`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.

Shapes:
- **input:**
node features :math:`(|\mathcal{V}|, F_{in})` or
:math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
if bipartite,
for distinct source and target features (e.g. bipartite),
edge indices :math:`(2, |\mathcal{E}|)`,
edge features :math:`(|\mathcal{E}|, D)` *(optional)*
- **output:** node features :math:`(|\mathcal{V}|, H * F_{out})` or
:math:`((|\mathcal{V}_t|, H * F_{out})` if bipartite.
:math:`(|\mathcal{V}_t|, H * F_{out})` if passed a tuple.
If :obj:`return_attention_weights=True`, then
:math:`((|\mathcal{V}|, H * F_{out}),
((2, |\mathcal{E}|), (|\mathcal{E}|, H)))`
or :math:`((|\mathcal{V_t}|, H * F_{out}), ((2, |\mathcal{E}|),
(|\mathcal{E}|, H)))` if bipartite
(|\mathcal{E}|, H)))` if passed a tuple
"""
def __init__(
self,
@@ -140,6 +144,7 @@ def __init__(
fill_value: Union[float, Tensor, str] = 'mean',
bias: bool = True,
residual: bool = False,
interactive_attn: bool = True,
**kwargs,
):
kwargs.setdefault('aggr', 'add')
@@ -155,22 +160,25 @@
self.edge_dim = edge_dim
self.fill_value = fill_value
self.residual = residual
self.interactive_attn = interactive_attn

# In case we are operating in bipartite graphs, we apply separate
# transformations 'lin_src' and 'lin_dst' to source and target nodes:
# In case of tuple in_channels, we apply separate transformations
# 'lin_src' and 'lin_dst' to source and target nodes:
self.lin = self.lin_src = self.lin_dst = None
if isinstance(in_channels, int):
self.lin = Linear(in_channels, heads * out_channels, bias=False,
weight_initializer='glorot')
else:
self.lin_src = Linear(in_channels[0], heads * out_channels, False,
weight_initializer='glorot')
self.lin_dst = Linear(in_channels[1], heads * out_channels, False,
weight_initializer='glorot')
if interactive_attn:
self.lin_dst = Linear(in_channels[1], heads * out_channels,
False, weight_initializer='glorot')

# The learnable parameters to compute attention coefficients:
self.att_src = Parameter(torch.empty(1, heads, out_channels))
self.att_dst = Parameter(torch.empty(1, heads, out_channels))
if interactive_attn:
self.att_dst = Parameter(torch.empty(1, heads, out_channels))

if edge_dim is not None:
self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False,
@@ -214,7 +222,8 @@ def reset_parameters(self):
if self.res is not None:
self.res.reset_parameters()
glorot(self.att_src)
glorot(self.att_dst)
if self.interactive_attn:
glorot(self.att_dst)
glorot(self.att_edge)
zeros(self.bias)

@@ -299,13 +308,15 @@ def forward( # noqa: F811
res = self.res(x)

if self.lin is not None:
x_src = x_dst = self.lin(x).view(-1, H, C)
x_src = self.lin(x).view(-1, H, C)
x_dst = x_src if self.interactive_attn else None
else:
# If the module is initialized as bipartite, transform source
# and destination node features separately:
assert self.lin_src is not None and self.lin_dst is not None
# If the module is initialized with tuple in_channels,
# transform source and destination node features separately:
assert self.lin_src is not None
x_src = self.lin_src(x).view(-1, H, C)
x_dst = self.lin_dst(x).view(-1, H, C)
x_dst = (self.lin_dst(x).view(-1, H, C)
if self.interactive_attn else None)

else: # Tuple of source and target node features:
x_src, x_dst = x
@@ -314,26 +325,31 @@
if x_dst is not None and self.res is not None:
res = self.res(x_dst)

# In the case of non-interactive attention, we do not update x_dst
# below. Apart from the residual computed above, its value is not
# used, but its size may still be used when adding self-loops.
if self.lin is not None:
# If the module is initialized as non-bipartite, we expect that
# source and destination node features have the same shape and
# that they their transformations are shared:
# If the module is initialized with integer in_channels, we
# expect that source and destination node features have the
# same shape and that their transformations are shared:
x_src = self.lin(x_src).view(-1, H, C)
if x_dst is not None:
if x_dst is not None and self.interactive_attn:
x_dst = self.lin(x_dst).view(-1, H, C)
else:
assert self.lin_src is not None and self.lin_dst is not None
assert self.lin_src is not None

x_src = self.lin_src(x_src).view(-1, H, C)
if x_dst is not None:
if x_dst is not None and self.interactive_attn:
assert self.lin_dst is not None
x_dst = self.lin_dst(x_dst).view(-1, H, C)

x = (x_src, x_dst)

# Next, we compute node-level attention coefficients, both for source
# and target nodes (if present):
alpha_src = (x_src * self.att_src).sum(dim=-1)
alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1)
alpha_dst = ((x_dst * self.att_dst).sum(-1)
if x_dst is not None and self.interactive_attn else None)
alpha = (alpha_src, alpha_dst)

if self.add_self_loops:
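The hunk ends before the softmax/`edge_update` stage, so the following is only a sketch of the intended effect, assuming the unchanged downstream code treats a `None` `alpha_dst` as a zero contribution (the same path already used when destination features are absent):

```python
import torch
import torch.nn.functional as F

H, C = 2, 32
x_src = torch.randn(4, H, C)                 # already-transformed source features
att_src = torch.randn(1, H, C)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

alpha_src = (x_src * att_src).sum(dim=-1)    # [num_src_nodes, H]
alpha_dst = None                             # interactive_attn=False

alpha_j = alpha_src[edge_index[0]]           # per-edge source scores, [E, H]
alpha_i = None if alpha_dst is None else alpha_dst[edge_index[1]]
alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
alpha = F.leaky_relu(alpha, negative_slope=0.2)
# A softmax over the incoming edges of each target node would follow,
# e.g. via torch_geometric.utils.softmax grouped by edge_index[1].
```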