Skip to content

Commit

Permalink
Update conv.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
rbSparky authored Sep 30, 2024
1 parent 1bc673a commit 7bc5c42
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,7 @@ function Base.show(io::IO, l::ResGatedGraphConv)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
end

@concrete struct TransformerConv <: GNNContainerLayer{(:W1, :W2, :W3, :W4, :W5, :W6, :FF, :BN1, :BN2)}
in_dims::NTuple{2, Int}
out_dims::Int
Expand All @@ -864,7 +865,7 @@ end
end

function TransformerConv(ch::Pair{Int, Int}, args...; kws...)
TransformerConv((ch[1], 0) => ch[2], args...; kws...)
return TransformerConv((ch[1], 0) => ch[2], args...; kws...)
end

function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
Expand All @@ -880,21 +881,19 @@ function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
skip_connection::Bool = false,
batch_norm::Bool = false,
ff_channels::Int = 0)

(in, ein), out = ch

if add_self_loops
@assert iszero(ein) "Using edge features and setting add_self_loops=true at the same time is not yet supported."
end

W1 = root_weight ?
Dense(in => out * (concat ? heads : 1); use_bias = bias_root, init_weight, init_bias) : nothing
W2 = Dense(in => out * heads; use_bias = bias_qkv, init_weight, init_bias)
W3 = Dense(in => out * heads; use_bias = bias_qkv, init_weight, init_bias)
W4 = Dense(in => out * heads; use_bias = bias_qkv, init_weight, init_bias)
W1 = root_weight ? Dense(in => out * (concat ? heads : 1); use_bias=bias_root, init_weight, init_bias) : nothing
W2 = Dense(in => out * heads; use_bias=bias_qkv, init_weight, init_bias)
W3 = Dense(in => out * heads; use_bias=bias_qkv, init_weight, init_bias)
W4 = Dense(in => out * heads; use_bias=bias_qkv, init_weight, init_bias)
out_mha = out * (concat ? heads : 1)
W5 = gating ? Dense(3 * out_mha => 1, sigmoid; use_bias = false, init_weight, init_bias) : nothing
W6 = ein > 0 ? Dense(ein => out * heads; use_bias = bias_qkv, init_weight, init_bias) : nothing
W5 = gating ? Dense(3 * out_mha => 1, sigmoid; use_bias=false, init_weight, init_bias) : nothing
W6 = ein > 0 ? Dense(ein => out * heads; use_bias=bias_qkv, init_weight, init_bias) : nothing
FF = ff_channels > 0 ?
Chain(Dense(out_mha => ff_channels, relu; init_weight, init_bias),
Dense(ff_channels => out_mha; init_weight, init_bias)) : nothing
Expand All @@ -905,10 +904,10 @@ function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
skip_connection, Float32(out), W1, W2, W3, W4, W5, W6, FF, BN1, BN2)
end

LuxCore.outputsize(l::TransformerConv) = (l.out_dims,)
LuxCore.outputsize(l::TransformerConv) = (l.concat ? l.out_dims * l.heads : l.out_dims,)

function (l::TransformerConv)(g, x, ps, st)
l(g, x, nothing, ps, st)
return l(g, x, nothing, ps, st)
end

function (l::TransformerConv)(g, x, e, ps, st)
Expand All @@ -933,21 +932,21 @@ function (l::TransformerConv)(g, x, e, ps, st)
end

function LuxCore.parameterlength(l::TransformerConv)
n = parameterlength(l.W1) + parameterlength(l.W2) +
parameterlength(l.W3) + parameterlength(l.W4) +
parameterlength(l.W5) + parameterlength(l.W6)

n = parameterlength(l.W2) + parameterlength(l.W3) + parameterlength(l.W4)
n += l.W1 === nothing ? 0 : parameterlength(l.W1)
n += l.W5 === nothing ? 0 : parameterlength(l.W5)
n += l.W6 === nothing ? 0 : parameterlength(l.W6)
n += l.FF === nothing ? 0 : parameterlength(l.FF)
n += l.BN1 === nothing ? 0 : parameterlength(l.BN1)
n += l.BN2 === nothing ? 0 : parameterlength(l.BN2)
return n
end

function LuxCore.statelength(l::TransformerConv)
n = statelength(l.W1) + statelength(l.W2) +
statelength(l.W3) + statelength(l.W4) +
statelength(l.W5) + statelength(l.W6)

n = statelength(l.W2) + statelength(l.W3) + statelength(l.W4)
n += l.W1 === nothing ? 0 : statelength(l.W1)
n += l.W5 === nothing ? 0 : statelength(l.W5)
n += l.W6 === nothing ? 0 : statelength(l.W6)
n += l.FF === nothing ? 0 : statelength(l.FF)
n += l.BN1 === nothing ? 0 : statelength(l.BN1)
n += l.BN2 === nothing ? 0 : statelength(l.BN2)
Expand Down

0 comments on commit 7bc5c42

Please sign in to comment.