Commit d33b184: Update conv.jl

rbSparky authored Sep 30, 2024
1 parent 7bc5c42 commit d33b184
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion GNNLux/src/layers/conv.jl
@@ -887,6 +887,10 @@ function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
         @assert iszero(ein) "Using edge features and setting add_self_loops=true at the same time is not yet supported."
     end
 
+    if skip_connection
+        @assert in == (concat ? out * heads : out) "In-channels must correspond to out-channels * heads (or just out_channels if concat=false) if skip_connection is used"
+    end
+
     W1 = root_weight ? Dense(in => out * (concat ? heads : 1); use_bias=bias_root, init_weight, init_bias) : nothing
     W2 = Dense(in => out * heads; use_bias=bias_qkv, init_weight, init_bias)
     W3 = Dense(in => out * heads; use_bias=bias_qkv, init_weight, init_bias)
@@ -904,7 +908,7 @@ function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
                     skip_connection, Float32(out), W1, W2, W3, W4, W5, W6, FF, BN1, BN2)
 end
 
-LuxCore.outputsize(l::TransformerConv) = (l.concat ? l.out_dims * l.heads : l.out_dims,)
+LuxCore.outputsize(l::TransformerConv) = (l.concat ? l.out_dims * l.heads : l.out_dims,)
 
 function (l::TransformerConv)(g, x, ps, st)
     return l(g, x, nothing, ps, st)
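The new assertion encodes a shape requirement: with skip_connection = true the layer adds its input to its output, so the node-feature input dimension must equal out * heads when concat = true (or just out when concat = false). Below is a minimal sketch of a construction that satisfies the check and one that would now fail; the dimensions are illustrative, and it assumes GNNLux exports TransformerConv with the keyword arguments shown in the diff:

using GNNLux

heads, out = 4, 8
in_dim = out * heads   # 32, matches out * heads as the assertion requires

# (in, ein) => out pair form; ein = 0 means no edge features
l = TransformerConv((in_dim, 0) => out; heads, concat = true, skip_connection = true)

# This would now throw an AssertionError, since 16 != 8 * 4:
# TransformerConv((16, 0) => out; heads, concat = true, skip_connection = true)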
