From ca560e52c2d5013c45504273e0ec1c2c80e1cfee Mon Sep 17 00:00:00 2001 From: Matthew Hayes Date: Wed, 6 Nov 2024 02:41:39 -0800 Subject: [PATCH 1/9] Initial non-interactive implementation for GATConv --- torch_geometric/nn/conv/gat_conv.py | 31 +++++++++++++++++------------ 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index 720dfb09811c..cd2d811b5818 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -73,14 +73,17 @@ class GATConv(MessagePassing): + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,k} \right)\right)}. - If the graph is not bipartite, :math:`\mathbf{\Theta}_{s} = + If an integer is passed for :obj:`in_channels`, :math:`\mathbf{\Theta}_{s} = \mathbf{\Theta}_{t}`. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. - A tuple corresponds to the sizes of source and target - dimensionalities in case of a bipartite graph. + A tuple dictates the dimensionality of distinct + :math:`\mathbf{\Theta}` to be used for source and target features, + for example, in case of a bipartite graph. A value of :obj:`None|0` + for the source dimensionality fixes + :math:`\mathbf{\Theta}_s=\mathbf{0}`. out_channels (int): Size of each output sample. heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) @@ -156,10 +159,10 @@ def __init__( self.fill_value = fill_value self.residual = residual - # In case we are operating in bipartite graphs, we apply separate - # transformations 'lin_src' and 'lin_dst' to source and target nodes: + # In case of e.g. bipartite graphs, we apply separate transformations + # 'lin_src' and 'lin_dst' to source and target nodes: self.lin = self.lin_src = self.lin_dst = None - if isinstance(in_channels, int): + if isinstance(in_channels, int) or not in_channels[0]: self.lin = Linear(in_channels, heads * out_channels, bias=False, weight_initializer='glorot') else: @@ -169,8 +172,10 @@ def __init__( weight_initializer='glorot') # The learnable parameters to compute attention coefficients: - self.att_src = Parameter(torch.empty(1, heads, out_channels)) self.att_dst = Parameter(torch.empty(1, heads, out_channels)) + self.att_src = None + if isinstance(in_channels, int) or in_channels[0]: + self.att_src = Parameter(torch.empty(1, heads, out_channels)) if edge_dim is not None: self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False, @@ -285,7 +290,6 @@ def forward( # noqa: F811 # `torch.jit._overload` decorator, as we can only change the output # arguments conditioned on type (`None` or `bool`), not based on its # actual value. 
- H, C = self.heads, self.out_channels res: Optional[Tensor] = None @@ -301,8 +305,9 @@ def forward( # noqa: F811 if self.lin is not None: x_src = x_dst = self.lin(x).view(-1, H, C) else: - # If the module is initialized as bipartite, transform source - # and destination node features separately: + # If the module is initialized with a tuple of positive + # in_channels, transform source and destination node features + # separately: assert self.lin_src is not None and self.lin_dst is not None x_src = self.lin_src(x).view(-1, H, C) x_dst = self.lin_dst(x).view(-1, H, C) @@ -315,9 +320,9 @@ def forward( # noqa: F811 res = self.res(x_dst) if self.lin is not None: - # If the module is initialized as non-bipartite, we expect that - # source and destination node features have the same shape and - # that they their transformations are shared: + # If the module is initialized with integer in_channels, we + # expect that source and destination node features have the same + # shape and that they their transformations are shared: x_src = self.lin(x_src).view(-1, H, C) if x_dst is not None: x_dst = self.lin(x_dst).view(-1, H, C) From cf3d48888f7d33dc6c73ef20a29e984d67da68fe Mon Sep 17 00:00:00 2001 From: Matthew Hayes Date: Thu, 7 Nov 2024 22:37:41 -0800 Subject: [PATCH 2/9] fix last commit aws wrong version --- torch_geometric/nn/conv/gat_conv.py | 58 +++++++++++++++++------------ 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index cd2d811b5818..ff83a54af06b 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -73,17 +73,17 @@ class GATConv(MessagePassing): + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,k} \right)\right)}. - If an integer is passed for :obj:`in_channels`, :math:`\mathbf{\Theta}_{s} = - \mathbf{\Theta}_{t}`. + If an integer is passed for :obj:`in_channels`, :math:`\mathbf{\Theta}_{s} + = \mathbf{\Theta}_{t}`. Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. - A tuple dictates the dimensionality of distinct + A tuple dictates the dimensionality of distinct :math:`\mathbf{\Theta}` to be used for source and target features, - for example, in case of a bipartite graph. A value of :obj:`None|0` - for the source dimensionality fixes - :math:`\mathbf{\Theta}_s=\mathbf{0}`. + for example, in case of a bipartite graph. A value of + :obj:`None|0` for the target dimensionality fixes + :math:`\mathbf{\Theta}_t=\mathbf{0}`. out_channels (int): Size of each output sample. heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) @@ -119,16 +119,17 @@ class GATConv(MessagePassing): - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` - if bipartite, + if :obj:`in_channels` is a tuple of positive integers, edge indices :math:`(2, |\mathcal{E}|)`, edge features :math:`(|\mathcal{E}|, D)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, H * F_{out})` or - :math:`((|\mathcal{V}_t|, H * F_{out})` if bipartite. + :math:`(|\mathcal{V}_t|, H * F_{out})` if bipartite. If :obj:`return_attention_weights=True`, then :math:`((|\mathcal{V}|, H * F_{out}), ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))` or :math:`((|\mathcal{V_t}|, H * F_{out}), ((2, |\mathcal{E}|), - (|\mathcal{E}|, H)))` if bipartite + (|\mathcal{E}|, H)))` if :obj:`in_channels` is a tuple of positive + integers. 
""" def __init__( self, @@ -159,23 +160,25 @@ def __init__( self.fill_value = fill_value self.residual = residual - # In case of e.g. bipartite graphs, we apply separate transformations + # In case of e.g. bipartite graphs, we apply separate transformations # 'lin_src' and 'lin_dst' to source and target nodes: self.lin = self.lin_src = self.lin_dst = None - if isinstance(in_channels, int) or not in_channels[0]: + if isinstance(in_channels, int): self.lin = Linear(in_channels, heads * out_channels, bias=False, weight_initializer='glorot') else: self.lin_src = Linear(in_channels[0], heads * out_channels, False, weight_initializer='glorot') - self.lin_dst = Linear(in_channels[1], heads * out_channels, False, - weight_initializer='glorot') + if in_channels[1]: + self.lin_dst = Linear(in_channels[1], heads * out_channels, + False, weight_initializer='glorot') # The learnable parameters to compute attention coefficients: - self.att_dst = Parameter(torch.empty(1, heads, out_channels)) - self.att_src = None - if isinstance(in_channels, int) or in_channels[0]: - self.att_src = Parameter(torch.empty(1, heads, out_channels)) + self.att_src = Parameter(torch.empty(1, heads, out_channels)) + if isinstance(in_channels, int) or in_channels[1]: + self.att_dst = Parameter(torch.empty(1, heads, out_channels)) + else: + self.att_dst = None if edge_dim is not None: self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False, @@ -305,12 +308,15 @@ def forward( # noqa: F811 if self.lin is not None: x_src = x_dst = self.lin(x).view(-1, H, C) else: - # If the module is initialized with a tuple of positive + # If the module is initialized with a tuple of positive # in_channels, transform source and destination node features # separately: - assert self.lin_src is not None and self.lin_dst is not None + assert self.lin_src is not None x_src = self.lin_src(x).view(-1, H, C) - x_dst = self.lin_dst(x).view(-1, H, C) + if self.lin_dst is not None: + x_dst = self.lin_dst(x).view(-1, H, C) + else: + x_dst = None else: # Tuple of source and target node features: x_src, x_dst = x @@ -321,17 +327,21 @@ def forward( # noqa: F811 if self.lin is not None: # If the module is initialized with integer in_channels, we - # expect that source and destination node features have the same - # shape and that they their transformations are shared: + # expect that source and destination node features have the + # same shape and that they their transformations are shared: x_src = self.lin(x_src).view(-1, H, C) if x_dst is not None: x_dst = self.lin(x_dst).view(-1, H, C) else: - assert self.lin_src is not None and self.lin_dst is not None + assert self.lin_src is not None x_src = self.lin_src(x_src).view(-1, H, C) - if x_dst is not None: + if self.lin_dst is not None and x_dst is not None: x_dst = self.lin_dst(x_dst).view(-1, H, C) + else: + # TODO maybe warn user they passed dest features that won't + # be used + x_dst = None x = (x_src, x_dst) From d3facf58b9c67a79d38f2e53de841a7ebc692efc Mon Sep 17 00:00:00 2001 From: Matthew Hayes Date: Thu, 7 Nov 2024 22:48:25 -0800 Subject: [PATCH 3/9] swap indices in docs so s maps to j and t to i --- torch_geometric/nn/conv/gat_conv.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index ff83a54af06b..53c14a0872a3 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -46,13 +46,13 @@ class GATConv(MessagePassing): \alpha_{i,j} = \frac{ 
\exp\left(\mathrm{LeakyReLU}\left( - \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i - + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_j + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_i + + \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_j \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left( - \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i - + \mathbf{a}^{\top}_{t}\mathbf{\Theta}_{t}\mathbf{x}_k + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_i + + \mathbf{a}^{\top}_{s}\mathbf{\Theta}_{s}\mathbf{x}_k \right)\right)}. If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`, @@ -62,14 +62,14 @@ class GATConv(MessagePassing): \alpha_{i,j} = \frac{ \exp\left(\mathrm{LeakyReLU}\left( - \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i - + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_j + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_i + + \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_j + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,j} \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left( - \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i - + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_k + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_i + + \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_k + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,k} \right)\right)}. From c34e3d92bf2685e9ee27ec507fcf7c431bad2c4d Mon Sep 17 00:00:00 2001 From: Matthew Hayes Date: Sat, 30 Nov 2024 20:25:04 -0800 Subject: [PATCH 4/9] gatv2 and tests --- benchmark/citation/gat.py | 12 ++++-- test/nn/conv/test_gat_conv.py | 34 ++++++++++++++++ test/nn/conv/test_gatv2_conv.py | 34 ++++++++++++++++ torch_geometric/nn/conv/gat_conv.py | 33 ++++++++-------- torch_geometric/nn/conv/gatv2_conv.py | 56 +++++++++++++++++---------- 5 files changed, 129 insertions(+), 40 deletions(-) diff --git a/benchmark/citation/gat.py b/benchmark/citation/gat.py index f9ed5d6071af..46c2ba23fea6 100644 --- a/benchmark/citation/gat.py +++ b/benchmark/citation/gat.py @@ -24,15 +24,21 @@ parser.add_argument('--profile', action='store_true') parser.add_argument('--bf16', action='store_true') parser.add_argument('--compile', action='store_true') +parser.add_argument('--non_interactive', action='store_true') args = parser.parse_args() class Net(torch.nn.Module): def __init__(self, dataset): super().__init__() - self.conv1 = GATConv(dataset.num_features, args.hidden, - heads=args.heads, dropout=args.dropout) - self.conv2 = GATConv(args.hidden * args.heads, dataset.num_classes, + in_channels_1 = dataset.num_features + in_channels_2 = args.hidden * args.heads + if args.non_interactive: + in_channels_1 = (in_channels_1, None) + in_channels_2 = (in_channels_2, None) + self.conv1 = GATConv(in_channels_1, args.hidden, heads=args.heads, + dropout=args.dropout) + self.conv2 = GATConv(in_channels_2, dataset.num_classes, heads=args.output_heads, concat=False, dropout=args.dropout) diff --git a/test/nn/conv/test_gat_conv.py b/test/nn/conv/test_gat_conv.py index 6549911ac0d4..415d863879d5 100644 --- a/test/nn/conv/test_gat_conv.py +++ b/test/nn/conv/test_gat_conv.py @@ -111,6 +111,40 @@ def forward( assert torch.allclose(result[0], out, atol=1e-6) assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7 + # Test no target features in attention + conv = GATConv((8, None), 32, heads=2, residual=residual) + assert str(conv) == 'GATConv((8, None), 32, heads=2)' + out = 
conv(x1, edge_index) + assert out.size() == (4, 64) + assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out) + assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6) + + if torch_geometric.typing.WITH_TORCH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) + assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) + + if is_full_test(): + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = conv + + def forward( + self, + x: Tensor, + edge_index: Adj, + size: Size = None, + ) -> Tensor: + return self.conv(x, edge_index, size=size) + + jit = torch.jit.script(MyModule()) + assert torch.allclose(jit(x1, edge_index), out) + assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out) + + if torch_geometric.typing.WITH_TORCH_SPARSE: + assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6) + # Test bipartite message passing: adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) diff --git a/test/nn/conv/test_gatv2_conv.py b/test/nn/conv/test_gatv2_conv.py index 3bca6530eee9..3ee4070167cb 100644 --- a/test/nn/conv/test_gatv2_conv.py +++ b/test/nn/conv/test_gatv2_conv.py @@ -109,9 +109,43 @@ def forward( assert torch.allclose(result[0], out, atol=1e-6) assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7 + # Test no target features in attention + conv = GATv2Conv((8, None), 32, heads=2, residual=residual) + assert str(conv) == 'GATv2Conv((8, None), 32, heads=2)' + out = conv(x1, edge_index) + assert out.size() == (4, 64) + assert torch.allclose(conv(x1, edge_index), out) + assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6) + + if torch_geometric.typing.WITH_TORCH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) + assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) + + if is_full_test(): + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = conv + + def forward( + self, + x: Tensor, + edge_index: Adj, + ) -> Tensor: + return self.conv(x, edge_index) + + jit = torch.jit.script(MyModule()) + assert torch.allclose(jit(x1, edge_index), out) + + if torch_geometric.typing.WITH_TORCH_SPARSE: + assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6) + # Test bipartite message passing: adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) + conv = GATv2Conv(8, 32, heads=2, residual=residual) + out = conv((x1, x2), edge_index) assert out.size() == (2, 64) assert torch.allclose(conv((x1, x2), edge_index), out) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index 53c14a0872a3..ac05e475b063 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -79,11 +79,10 @@ class GATConv(MessagePassing): Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. - A tuple dictates the dimensionality of distinct - :math:`\mathbf{\Theta}` to be used for source and target features, - for example, in case of a bipartite graph. A value of - :obj:`None|0` for the target dimensionality fixes - :math:`\mathbf{\Theta}_t=\mathbf{0}`. + A tuple corresponds to the sizes of source and target + dimensionalities and distinct :math:`\mathbf{\Theta}`, for example, + in the case of a bipartite graph. A value of :obj:`None|0` for the + target dimensionality fixes :math:`\mathbf{\Theta}_t=\mathbf{0}`. out_channels (int): Size of each output sample. heads (int, optional): Number of multi-head-attentions. 
(default: :obj:`1`) @@ -119,17 +118,16 @@ class GATConv(MessagePassing): - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` - if :obj:`in_channels` is a tuple of positive integers, + for distinct source and target features (e.g. bipartite), edge indices :math:`(2, |\mathcal{E}|)`, edge features :math:`(|\mathcal{E}|, D)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, H * F_{out})` or - :math:`(|\mathcal{V}_t|, H * F_{out})` if bipartite. + :math:`(|\mathcal{V}_t|, H * F_{out})` if passed a tuple. If :obj:`return_attention_weights=True`, then :math:`((|\mathcal{V}|, H * F_{out}), ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))` or :math:`((|\mathcal{V_t}|, H * F_{out}), ((2, |\mathcal{E}|), - (|\mathcal{E}|, H)))` if :obj:`in_channels` is a tuple of positive - integers. + (|\mathcal{E}|, H)))` if passed a tuple """ def __init__( self, @@ -160,8 +158,8 @@ def __init__( self.fill_value = fill_value self.residual = residual - # In case of e.g. bipartite graphs, we apply separate transformations - # 'lin_src' and 'lin_dst' to source and target nodes: + # In case tuple in_channels, e.g. bipartite graphs, we apply separate + # transformations 'lin_src' and 'lin_dst' to source and target nodes: self.lin = self.lin_src = self.lin_dst = None if isinstance(in_channels, int): self.lin = Linear(in_channels, heads * out_channels, bias=False, @@ -192,9 +190,11 @@ def __init__( total_out_channels = out_channels * (heads if concat else 1) if residual: + res_in_channels = in_channels + if not isinstance(in_channels, int): + res_in_channels = in_channels[1] if in_channels[1] else -1 self.res = Linear( - in_channels - if isinstance(in_channels, int) else in_channels[1], + res_in_channels, total_out_channels, bias=False, weight_initializer='glorot', @@ -293,6 +293,7 @@ def forward( # noqa: F811 # `torch.jit._overload` decorator, as we can only change the output # arguments conditioned on type (`None` or `bool`), not based on its # actual value. + H, C = self.heads, self.out_channels res: Optional[Tensor] = None @@ -336,11 +337,11 @@ def forward( # noqa: F811 assert self.lin_src is not None x_src = self.lin_src(x_src).view(-1, H, C) - if self.lin_dst is not None and x_dst is not None: + if x_dst is not None and self.lin_dst is not None: x_dst = self.lin_dst(x_dst).view(-1, H, C) else: - # TODO maybe warn user they passed dest features that won't - # be used + # TODO maybe warn user if they passed dest features that + # won't be used x_dst = None x = (x_src, x_dst) diff --git a/torch_geometric/nn/conv/gatv2_conv.py b/torch_geometric/nn/conv/gatv2_conv.py index f3b2f4937e52..bcc12c35536c 100644 --- a/torch_geometric/nn/conv/gatv2_conv.py +++ b/torch_geometric/nn/conv/gatv2_conv.py @@ -50,11 +50,11 @@ class GATv2Conv(MessagePassing): \alpha_{i,j} = \frac{ \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( - \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_j + \mathbf{\Theta}_{t} \mathbf{x}_i + \mathbf{\Theta}_{s} \mathbf{x}_j \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( - \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_k + \mathbf{\Theta}_{t} \mathbf{x}_i + \mathbf{\Theta}_{s} \mathbf{x}_k \right)\right)}. 
If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`, @@ -64,14 +64,14 @@ class GATv2Conv(MessagePassing): \alpha_{i,j} = \frac{ \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( - \mathbf{\Theta}_{s} \mathbf{x}_i - + \mathbf{\Theta}_{t} \mathbf{x}_j + \mathbf{\Theta}_{t} \mathbf{x}_i + + \mathbf{\Theta}_{s} \mathbf{x}_j + \mathbf{\Theta}_{e} \mathbf{e}_{i,j} \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( - \mathbf{\Theta}_{s} \mathbf{x}_i - + \mathbf{\Theta}_{t} \mathbf{x}_k + \mathbf{\Theta}_{t} \mathbf{x}_i + + \mathbf{\Theta}_{s} \mathbf{x}_k + \mathbf{\Theta}_{e} \mathbf{e}_{i,k}] \right)\right)}. @@ -79,7 +79,9 @@ class GATv2Conv(MessagePassing): in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target - dimensionalities in case of a bipartite graph. + dimensionalities and distinct :math:`\mathbf{\Theta}`, for example, + in the case of a bipartite graph. A value of :obj:`None|0` for the + target dimensionality fixes :math:`\mathbf{\Theta}_t=\mathbf{0}`. out_channels (int): Size of each output sample. heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) @@ -119,16 +121,16 @@ class GATv2Conv(MessagePassing): - **input:** node features :math:`(|\mathcal{V}|, F_{in})` or :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` - if bipartite, + for distinct source and target features (e.g. bipartite), edge indices :math:`(2, |\mathcal{E}|)`, edge features :math:`(|\mathcal{E}|, D)` *(optional)* - **output:** node features :math:`(|\mathcal{V}|, H * F_{out})` or - :math:`((|\mathcal{V}_t|, H * F_{out})` if bipartite. + :math:`(|\mathcal{V}_t|, H * F_{out})` if passed a tuple. 
If :obj:`return_attention_weights=True`, then :math:`((|\mathcal{V}|, H * F_{out}), ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))` or :math:`((|\mathcal{V_t}|, H * F_{out}), ((2, |\mathcal{E}|), - (|\mathcal{E}|, H)))` if bipartite + (|\mathcal{E}|, H)))` if passed a tuple """ def __init__( self, @@ -171,11 +173,14 @@ def __init__( else: self.lin_l = Linear(in_channels[0], heads * out_channels, bias=bias, weight_initializer='glorot') - if share_weights: - self.lin_r = self.lin_l + if in_channels[1]: + if share_weights: + self.lin_r = self.lin_l + else: + self.lin_r = Linear(in_channels[1], heads * out_channels, + bias=bias, weight_initializer='glorot') else: - self.lin_r = Linear(in_channels[1], heads * out_channels, - bias=bias, weight_initializer='glorot') + self.lin_r = None self.att = Parameter(torch.empty(1, heads, out_channels)) @@ -189,9 +194,11 @@ def __init__( total_out_channels = out_channels * (heads if concat else 1) if residual: + res_in_channels = in_channels + if not isinstance(in_channels, int): + res_in_channels = in_channels[1] if in_channels[1] else -1 self.res = Linear( - in_channels - if isinstance(in_channels, int) else in_channels[1], + res_in_channels, total_out_channels, bias=False, weight_initializer='glorot', @@ -209,7 +216,8 @@ def __init__( def reset_parameters(self): super().reset_parameters() self.lin_l.reset_parameters() - self.lin_r.reset_parameters() + if self.lin_r is not None: + self.lin_r.reset_parameters() if self.lin_edge is not None: self.lin_edge.reset_parameters() if self.res is not None: @@ -282,11 +290,12 @@ def forward( # noqa: F811 if self.res is not None: res = self.res(x) + print(f"x shape {x.shape}, res shape {res.shape}") x_l = self.lin_l(x).view(-1, H, C) if self.share_weights: x_r = x_l - else: + elif self.lin_r is not None: x_r = self.lin_r(x).view(-1, H, C) else: x_l, x_r = x[0], x[1] @@ -296,11 +305,14 @@ def forward( # noqa: F811 res = self.res(x_r) x_l = self.lin_l(x_l).view(-1, H, C) - if x_r is not None: + if x_r is not None and self.lin_r is not None: x_r = self.lin_r(x_r).view(-1, H, C) + else: + # TODO maybe warn user if they passed dest features that won't + # be used + x_r = None assert x_l is not None - assert x_r is not None if self.add_self_loops: if isinstance(edge_index, Tensor): @@ -334,6 +346,7 @@ def forward( # noqa: F811 out = out.mean(dim=1) if res is not None: + print(f"out shape {out.shape} res shape {res.shape}") out = out + res if self.bias is not None: @@ -355,7 +368,8 @@ def forward( # noqa: F811 def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: OptTensor, index: Tensor, ptr: OptTensor, dim_size: Optional[int]) -> Tensor: - x = x_i + x_j + print("share weights", self.share_weights) + x = x_j if x_i is None else x_i + x_j if edge_attr is not None: if edge_attr.dim() == 1: From 2a3accbe6d9faa389cc139ab97db3a2aa35b27f2 Mon Sep 17 00:00:00 2001 From: Matthew Hayes Date: Thu, 5 Dec 2024 23:10:15 -0800 Subject: [PATCH 5/9] change to separate param --- test/nn/conv/test_gat_conv.py | 43 +++------------------ test/nn/conv/test_gatv2_conv.py | 40 ++------------------ torch_geometric/nn/conv/gat_conv.py | 54 +++++++++++++-------------- torch_geometric/nn/conv/gatv2_conv.py | 44 +++++++++------------- 4 files changed, 54 insertions(+), 127 deletions(-) diff --git a/test/nn/conv/test_gat_conv.py b/test/nn/conv/test_gat_conv.py index 415d863879d5..a6e1207eee9e 100644 --- a/test/nn/conv/test_gat_conv.py +++ b/test/nn/conv/test_gat_conv.py @@ -12,13 +12,15 @@ @pytest.mark.parametrize('residual', [False, True]) -def 
test_gat_conv(residual): +@pytest.mark.parametrize('interactive_attn', [False, True]) +def test_gat_conv(residual, interactive_attn): x1 = torch.randn(4, 8) x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) - conv = GATConv(8, 32, heads=2, residual=residual) + conv = GATConv(8, 32, heads=2, residual=residual, + interactive_attn=interactive_attn) assert str(conv) == 'GATConv(8, 32, heads=2)' out = conv(x1, edge_index) assert out.size() == (4, 64) @@ -111,44 +113,11 @@ def forward( assert torch.allclose(result[0], out, atol=1e-6) assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7 - # Test no target features in attention - conv = GATConv((8, None), 32, heads=2, residual=residual) - assert str(conv) == 'GATConv((8, None), 32, heads=2)' - out = conv(x1, edge_index) - assert out.size() == (4, 64) - assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out) - assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6) - - if torch_geometric.typing.WITH_TORCH_SPARSE: - adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) - assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) - - if is_full_test(): - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = conv - - def forward( - self, - x: Tensor, - edge_index: Adj, - size: Size = None, - ) -> Tensor: - return self.conv(x, edge_index, size=size) - - jit = torch.jit.script(MyModule()) - assert torch.allclose(jit(x1, edge_index), out) - assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out) - - if torch_geometric.typing.WITH_TORCH_SPARSE: - assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6) - # Test bipartite message passing: adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) - conv = GATConv((8, 16), 32, heads=2, residual=residual) + conv = GATConv((8, 16), 32, heads=2, residual=residual, + interactive_attn=interactive_attn) assert str(conv) == 'GATConv((8, 16), 32, heads=2)' out1 = conv((x1, x2), edge_index) diff --git a/test/nn/conv/test_gatv2_conv.py b/test/nn/conv/test_gatv2_conv.py index 3ee4070167cb..a8faaebded81 100644 --- a/test/nn/conv/test_gatv2_conv.py +++ b/test/nn/conv/test_gatv2_conv.py @@ -12,13 +12,15 @@ @pytest.mark.parametrize('residual', [False, True]) -def test_gatv2_conv(residual): +@pytest.mark.parametrize('interactive_attn', [False, True]) +def test_gatv2_conv(residual, interactive_attn): x1 = torch.randn(4, 8) x2 = torch.randn(2, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) - conv = GATv2Conv(8, 32, heads=2, residual=residual) + conv = GATv2Conv(8, 32, heads=2, residual=residual, + interactive_attn=interactive_attn) assert str(conv) == 'GATv2Conv(8, 32, heads=2)' out = conv(x1, edge_index) assert out.size() == (4, 64) @@ -109,43 +111,9 @@ def forward( assert torch.allclose(result[0], out, atol=1e-6) assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7 - # Test no target features in attention - conv = GATv2Conv((8, None), 32, heads=2, residual=residual) - assert str(conv) == 'GATv2Conv((8, None), 32, heads=2)' - out = conv(x1, edge_index) - assert out.size() == (4, 64) - assert torch.allclose(conv(x1, edge_index), out) - assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6) - - if torch_geometric.typing.WITH_TORCH_SPARSE: - adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) - assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) - - if is_full_test(): - - class 
MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = conv - - def forward( - self, - x: Tensor, - edge_index: Adj, - ) -> Tensor: - return self.conv(x, edge_index) - - jit = torch.jit.script(MyModule()) - assert torch.allclose(jit(x1, edge_index), out) - - if torch_geometric.typing.WITH_TORCH_SPARSE: - assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6) - # Test bipartite message passing: adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) - conv = GATv2Conv(8, 32, heads=2, residual=residual) - out = conv((x1, x2), edge_index) assert out.size() == (2, 64) assert torch.allclose(conv((x1, x2), edge_index), out) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index ac05e475b063..4fa3d05816d9 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -81,8 +81,7 @@ class GATConv(MessagePassing): derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities and distinct :math:`\mathbf{\Theta}`, for example, - in the case of a bipartite graph. A value of :obj:`None|0` for the - target dimensionality fixes :math:`\mathbf{\Theta}_t=\mathbf{0}`. + in the case of a bipartite graph. out_channels (int): Size of each output sample. heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`) @@ -142,6 +141,7 @@ def __init__( fill_value: Union[float, Tensor, str] = 'mean', bias: bool = True, residual: bool = False, + interactive_attn: bool = True, **kwargs, ): kwargs.setdefault('aggr', 'add') @@ -157,9 +157,10 @@ def __init__( self.edge_dim = edge_dim self.fill_value = fill_value self.residual = residual + self.interactive_attn = interactive_attn - # In case tuple in_channels, e.g. 
bipartite graphs, we apply separate - # transformations 'lin_src' and 'lin_dst' to source and target nodes: + # In case tuple in_channels, we apply separate transformations + # 'lin_src' and 'lin_dst' to source and target nodes: self.lin = self.lin_src = self.lin_dst = None if isinstance(in_channels, int): self.lin = Linear(in_channels, heads * out_channels, bias=False, @@ -167,16 +168,14 @@ def __init__( else: self.lin_src = Linear(in_channels[0], heads * out_channels, False, weight_initializer='glorot') - if in_channels[1]: + if interactive_attn: self.lin_dst = Linear(in_channels[1], heads * out_channels, False, weight_initializer='glorot') # The learnable parameters to compute attention coefficients: self.att_src = Parameter(torch.empty(1, heads, out_channels)) - if isinstance(in_channels, int) or in_channels[1]: + if interactive_attn: self.att_dst = Parameter(torch.empty(1, heads, out_channels)) - else: - self.att_dst = None if edge_dim is not None: self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False, @@ -190,11 +189,9 @@ def __init__( total_out_channels = out_channels * (heads if concat else 1) if residual: - res_in_channels = in_channels - if not isinstance(in_channels, int): - res_in_channels = in_channels[1] if in_channels[1] else -1 self.res = Linear( - res_in_channels, + in_channels + if isinstance(in_channels, int) else in_channels[1], total_out_channels, bias=False, weight_initializer='glorot', @@ -222,7 +219,8 @@ def reset_parameters(self): if self.res is not None: self.res.reset_parameters() glorot(self.att_src) - glorot(self.att_dst) + if self.interactive_attn: + glorot(self.att_dst) glorot(self.att_edge) zeros(self.bias) @@ -307,17 +305,15 @@ def forward( # noqa: F811 res = self.res(x) if self.lin is not None: - x_src = x_dst = self.lin(x).view(-1, H, C) + x_src = self.lin(x).view(-1, H, C) + x_dst = x_src if self.interactive_attn else None else: - # If the module is initialized with a tuple of positive - # in_channels, transform source and destination node features - # separately: + # If the module is initialized with tuple in_channels, + # transform source and destination node features separately: assert self.lin_src is not None x_src = self.lin_src(x).view(-1, H, C) - if self.lin_dst is not None: - x_dst = self.lin_dst(x).view(-1, H, C) - else: - x_dst = None + x_dst = (self.lin_dst(x).view(-1, H, C) + if self.interactive_attn else None) else: # Tuple of source and target node features: x_src, x_dst = x @@ -326,30 +322,32 @@ def forward( # noqa: F811 if x_dst is not None and self.res is not None: res = self.res(x_dst) + # In the case of non-interactive attention, we do not update x_dst + # below. Except in the case of a residual above, its value won't + # be used, but its size may be used when computing self loops. 
if self.lin is not None: # If the module is initialized with integer in_channels, we # expect that source and destination node features have the # same shape and that they their transformations are shared: x_src = self.lin(x_src).view(-1, H, C) - if x_dst is not None: + if x_dst is not None and self.interactive_attn: + assert self.lin_dst is not None x_dst = self.lin(x_dst).view(-1, H, C) else: assert self.lin_src is not None x_src = self.lin_src(x_src).view(-1, H, C) - if x_dst is not None and self.lin_dst is not None: + if x_dst is not None and self.interactive_attn: + assert self.lin_dst is not None x_dst = self.lin_dst(x_dst).view(-1, H, C) - else: - # TODO maybe warn user if they passed dest features that - # won't be used - x_dst = None x = (x_src, x_dst) # Next, we compute node-level attention coefficients, both for source # and target nodes (if present): alpha_src = (x_src * self.att_src).sum(dim=-1) - alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1) + alpha_dst = ((x_dst * self.att_dst).sum(-1) + if x_dst is not None and self.interactive_attn else None) alpha = (alpha_src, alpha_dst) if self.add_self_loops: diff --git a/torch_geometric/nn/conv/gatv2_conv.py b/torch_geometric/nn/conv/gatv2_conv.py index bcc12c35536c..7a3481cc2bff 100644 --- a/torch_geometric/nn/conv/gatv2_conv.py +++ b/torch_geometric/nn/conv/gatv2_conv.py @@ -80,8 +80,7 @@ class GATv2Conv(MessagePassing): derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities and distinct :math:`\mathbf{\Theta}`, for example, - in the case of a bipartite graph. A value of :obj:`None|0` for the - target dimensionality fixes :math:`\mathbf{\Theta}_t=\mathbf{0}`. + in the case of a bipartite graph. out_channels (int): Size of each output sample. heads (int, optional): Number of multi-head-attentions. 
(default: :obj:`1`) @@ -146,6 +145,7 @@ def __init__( bias: bool = True, share_weights: bool = False, residual: bool = False, + interactive_attn: bool = True, **kwargs, ): super().__init__(node_dim=0, **kwargs) @@ -161,26 +161,26 @@ def __init__( self.fill_value = fill_value self.residual = residual self.share_weights = share_weights + self.iteractive_attn = interactive_attn if isinstance(in_channels, int): self.lin_l = Linear(in_channels, heads * out_channels, bias=bias, weight_initializer='glorot') - if share_weights: - self.lin_r = self.lin_l - else: - self.lin_r = Linear(in_channels, heads * out_channels, - bias=bias, weight_initializer='glorot') + if interactive_attn: + if share_weights: + self.lin_r = self.lin_l + else: + self.lin_r = Linear(in_channels, heads * out_channels, + bias=bias, weight_initializer='glorot') else: self.lin_l = Linear(in_channels[0], heads * out_channels, bias=bias, weight_initializer='glorot') - if in_channels[1]: + if interactive_attn: if share_weights: self.lin_r = self.lin_l else: self.lin_r = Linear(in_channels[1], heads * out_channels, bias=bias, weight_initializer='glorot') - else: - self.lin_r = None self.att = Parameter(torch.empty(1, heads, out_channels)) @@ -194,11 +194,9 @@ def __init__( total_out_channels = out_channels * (heads if concat else 1) if residual: - res_in_channels = in_channels - if not isinstance(in_channels, int): - res_in_channels = in_channels[1] if in_channels[1] else -1 self.res = Linear( - res_in_channels, + in_channels + if isinstance(in_channels, int) else in_channels[1], total_out_channels, bias=False, weight_initializer='glorot', @@ -216,7 +214,7 @@ def __init__( def reset_parameters(self): super().reset_parameters() self.lin_l.reset_parameters() - if self.lin_r is not None: + if self.iteractive_attn: self.lin_r.reset_parameters() if self.lin_edge is not None: self.lin_edge.reset_parameters() @@ -290,12 +288,11 @@ def forward( # noqa: F811 if self.res is not None: res = self.res(x) - print(f"x shape {x.shape}, res shape {res.shape}") x_l = self.lin_l(x).view(-1, H, C) - if self.share_weights: + if self.share_weights or not self.iteractive_attn: x_r = x_l - elif self.lin_r is not None: + else: x_r = self.lin_r(x).view(-1, H, C) else: x_l, x_r = x[0], x[1] @@ -305,14 +302,11 @@ def forward( # noqa: F811 res = self.res(x_r) x_l = self.lin_l(x_l).view(-1, H, C) - if x_r is not None and self.lin_r is not None: + if x_r is not None and self.iteractive_attn: x_r = self.lin_r(x_r).view(-1, H, C) - else: - # TODO maybe warn user if they passed dest features that won't - # be used - x_r = None assert x_l is not None + assert x_r is not None if self.add_self_loops: if isinstance(edge_index, Tensor): @@ -346,7 +340,6 @@ def forward( # noqa: F811 out = out.mean(dim=1) if res is not None: - print(f"out shape {out.shape} res shape {res.shape}") out = out + res if self.bias is not None: @@ -368,8 +361,7 @@ def forward( # noqa: F811 def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: OptTensor, index: Tensor, ptr: OptTensor, dim_size: Optional[int]) -> Tensor: - print("share weights", self.share_weights) - x = x_j if x_i is None else x_i + x_j + x = x_i + x_j if self.iteractive_attn else x_j if edge_attr is not None: if edge_attr.dim() == 1: From 89716c36a2bd584a697ae7818b305ebe03a1f46c Mon Sep 17 00:00:00 2001 From: Matthew Hayes Date: Sat, 7 Dec 2024 17:12:43 -0800 Subject: [PATCH 6/9] minor cleanups --- benchmark/citation/gat.py | 12 +++--------- torch_geometric/nn/conv/gat_conv.py | 2 +- 2 files changed, 4 insertions(+), 10 
deletions(-) diff --git a/benchmark/citation/gat.py b/benchmark/citation/gat.py index 46c2ba23fea6..f9ed5d6071af 100644 --- a/benchmark/citation/gat.py +++ b/benchmark/citation/gat.py @@ -24,21 +24,15 @@ parser.add_argument('--profile', action='store_true') parser.add_argument('--bf16', action='store_true') parser.add_argument('--compile', action='store_true') -parser.add_argument('--non_interactive', action='store_true') args = parser.parse_args() class Net(torch.nn.Module): def __init__(self, dataset): super().__init__() - in_channels_1 = dataset.num_features - in_channels_2 = args.hidden * args.heads - if args.non_interactive: - in_channels_1 = (in_channels_1, None) - in_channels_2 = (in_channels_2, None) - self.conv1 = GATConv(in_channels_1, args.hidden, heads=args.heads, - dropout=args.dropout) - self.conv2 = GATConv(in_channels_2, dataset.num_classes, + self.conv1 = GATConv(dataset.num_features, args.hidden, + heads=args.heads, dropout=args.dropout) + self.conv2 = GATConv(args.hidden * args.heads, dataset.num_classes, heads=args.output_heads, concat=False, dropout=args.dropout) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index 4fa3d05816d9..e5807e7232ee 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -159,7 +159,7 @@ def __init__( self.residual = residual self.interactive_attn = interactive_attn - # In case tuple in_channels, we apply separate transformations + # In case of tuple in_channels, we apply separate transformations # 'lin_src' and 'lin_dst' to source and target nodes: self.lin = self.lin_src = self.lin_dst = None if isinstance(in_channels, int): From 9308f3d5b0c11a65e5936266d225d8ac523f56cd Mon Sep 17 00:00:00 2001 From: Matthew Hayes Date: Sun, 8 Dec 2024 21:42:55 -0800 Subject: [PATCH 7/9] documentation --- torch_geometric/nn/conv/gat_conv.py | 9 ++++++--- torch_geometric/nn/conv/gatv2_conv.py | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index e5807e7232ee..d578c66675ab 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -55,7 +55,7 @@ class GATConv(MessagePassing): + \mathbf{a}^{\top}_{s}\mathbf{\Theta}_{s}\mathbf{x}_k \right)\right)}. - If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`, + If the graph has multi-dimensional edge features :math:`\mathbf{e}_{j,i}`, the attention coefficients :math:`\alpha_{i,j}` are computed as .. math:: @@ -64,13 +64,13 @@ class GATConv(MessagePassing): \exp\left(\mathrm{LeakyReLU}\left( \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_i + \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_j - + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,j} + + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{j,i} \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left( \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_i + \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_k - + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,k} + + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{k,i} \right)\right)}. If an integer is passed for :obj:`in_channels`, :math:`\mathbf{\Theta}_{s} @@ -110,6 +110,9 @@ class GATConv(MessagePassing): an additive bias. (default: :obj:`True`) residual (bool, optional): If set to :obj:`True`, the layer will add a learnable skip-connection. 
(default: :obj:`False`) + interactive_attn (bool, optional): If set to :obj:`False`, fixes + :math:`\mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_i = 0`. + (default :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. diff --git a/torch_geometric/nn/conv/gatv2_conv.py b/torch_geometric/nn/conv/gatv2_conv.py index 7a3481cc2bff..e05322e97e75 100644 --- a/torch_geometric/nn/conv/gatv2_conv.py +++ b/torch_geometric/nn/conv/gatv2_conv.py @@ -57,7 +57,7 @@ class GATv2Conv(MessagePassing): \mathbf{\Theta}_{t} \mathbf{x}_i + \mathbf{\Theta}_{s} \mathbf{x}_k \right)\right)}. - If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`, + If the graph has multi-dimensional edge features :math:`\mathbf{e}_{j,i}`, the attention coefficients :math:`\alpha_{i,j}` are computed as .. math:: @@ -66,13 +66,13 @@ class GATv2Conv(MessagePassing): \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( \mathbf{\Theta}_{t} \mathbf{x}_i + \mathbf{\Theta}_{s} \mathbf{x}_j - + \mathbf{\Theta}_{e} \mathbf{e}_{i,j} + + \mathbf{\Theta}_{e} \mathbf{e}_{j,i} \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( \mathbf{\Theta}_{t} \mathbf{x}_i + \mathbf{\Theta}_{s} \mathbf{x}_k - + \mathbf{\Theta}_{e} \mathbf{e}_{i,k}] + + \mathbf{\Theta}_{e} \mathbf{e}_{k,i}] \right)\right)}. Args: @@ -113,6 +113,9 @@ class GATv2Conv(MessagePassing): (default: :obj:`False`) residual (bool, optional): If set to :obj:`True`, the layer will add a learnable skip-connection. (default: :obj:`False`) + interactive_attn (bool, optional): If set to :obj:`False`, fixes + :math:`\mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_i = 0`. + (default :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. From 95b1e57cf8fa694bc47ca53f1128d6ffe05f1c65 Mon Sep 17 00:00:00 2001 From: Matthew Hayes Date: Sun, 8 Dec 2024 22:08:05 -0800 Subject: [PATCH 8/9] changelog --- CHANGELOG.md | 1 + torch_geometric/nn/conv/gat_conv.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e6789b9a86d..78dd6dae5a44 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Added +- Added `interactive_attn` parameter to `GATConv` and `GATv2Conv` ([#9832](https://github.com/pyg-team/pytorch_geometric/pull/9832)) - Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794)) - Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) - Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index d578c66675ab..1a85e571c7e9 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -334,7 +334,6 @@ def forward( # noqa: F811 # same shape and that they their transformations are shared: x_src = self.lin(x_src).view(-1, H, C) if x_dst is not None and self.interactive_attn: - assert self.lin_dst is not None x_dst = self.lin(x_dst).view(-1, H, C) else: assert self.lin_src is not None From a863103420a48f415b449e9a8f80361a1f03c1ed Mon Sep 17 00:00:00 2001 From: Matthew Hayes Date: Sun, 8 Dec 2024 22:50:20 -0800 Subject: [PATCH 9/9] fix v2 param description --- torch_geometric/nn/conv/gatv2_conv.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch_geometric/nn/conv/gatv2_conv.py b/torch_geometric/nn/conv/gatv2_conv.py index e05322e97e75..2eb3c4c37e8d 100644 --- a/torch_geometric/nn/conv/gatv2_conv.py +++ b/torch_geometric/nn/conv/gatv2_conv.py @@ -114,8 +114,7 @@ class GATv2Conv(MessagePassing): residual (bool, optional): If set to :obj:`True`, the layer will add a learnable skip-connection. (default: :obj:`False`) interactive_attn (bool, optional): If set to :obj:`False`, fixes - :math:`\mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_i = 0`. - (default :obj:`True`) + :math:`\mathbf{\Theta}_{t}\mathbf{x}_i = 0`. (default :obj:`True`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`.
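
A minimal usage sketch of the `interactive_attn` flag introduced by this series, written against the constructor signatures and the shape assertions in the tests updated in patches 4 and 5; the node features, edge index, and sizes below are illustrative only and require the patched `GATConv`/`GATv2Conv`:

    import torch
    from torch_geometric.nn import GATConv, GATv2Conv

    x = torch.randn(4, 8)                         # 4 nodes with 8 features each
    edge_index = torch.tensor([[0, 1, 2, 3],      # source nodes
                               [0, 0, 1, 1]])     # target nodes

    # Default behaviour is unchanged: attention uses both source and target
    # node features (interactive_attn=True).
    conv = GATConv(8, 32, heads=2)
    out = conv(x, edge_index)                     # shape [4, 2 * 32] = [4, 64]

    # Non-interactive attention in GATConv: the target-side term
    # a_t^T Theta_t x_i is fixed to zero, so the attention coefficient of an
    # edge depends only on the source node (and optional edge) features.
    conv = GATConv(8, 32, heads=2, interactive_attn=False)
    out = conv(x, edge_index)                     # still shape [4, 64]

    # The same flag is added to GATv2Conv, where the Theta_t x_i term is
    # dropped from the scoring function instead.
    conv = GATv2Conv(8, 32, heads=2, interactive_attn=False)
    out = conv(x, edge_index)                     # shape [4, 64]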