diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl
index 029c20ad..a491b7f6 100644
--- a/GNNLux/src/layers/conv.jl
+++ b/GNNLux/src/layers/conv.jl
@@ -887,6 +887,10 @@ function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
         @assert iszero(ein) "Using edge features and setting add_self_loops=true at the same time is not yet supported."
     end
 
+    if skip_connection
+        @assert in == (concat ? out * heads : out) "In-channels must correspond to out-channels * heads (or just out_channels if concat=false) if skip_connection is used"
+    end
+
     W1 = root_weight ? Dense(in => out * (concat ? heads : 1); use_bias=bias_root, init_weight, init_bias) : nothing
     W2 = Dense(in => out * heads; use_bias=bias_qkv, init_weight, init_bias)
     W3 = Dense(in => out * heads; use_bias=bias_qkv, init_weight, init_bias)
@@ -904,7 +908,7 @@ function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
                            skip_connection, Float32(√out), W1, W2, W3, W4, W5, W6, FF, BN1, BN2)
 end
 
-LuxCore.outputsize(l::TransformerConv) = (l.concat ? l.out_dims * l.heads : l.out_dims,)
+LuxCore.outputsize(l::TransformerConv) = (l.concat ? l.out_dims * l.heads : l.out_dims,)
 
 function (l::TransformerConv)(g, x, ps, st)
     return l(g, x, nothing, ps, st)
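
For reviewers, a minimal sketch of what the new assertion enforces. The constructor form `(in, ein) => out` is taken from this file's own signature; the channel values below are illustrative, not from the patch:

    using GNNLux

    # concat=true (the default): the skip connection adds the input to the
    # concatenated head outputs, so `in` must equal out * heads. Here 8 == 2 * 4.
    ok_concat = TransformerConv((8, 0) => 2; heads = 4, concat = true, skip_connection = true)

    # concat=false: head outputs are averaged down to `out`, so `in` must equal `out`.
    ok_mean = TransformerConv((4, 0) => 4; heads = 2, concat = false, skip_connection = true)

    # Mismatched channels now fail fast at construction time rather than at the
    # first forward pass:
    # TransformerConv((8, 0) => 3; heads = 4, skip_connection = true)  # AssertionError
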