Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EGNNConv support for HeteroGraphConv #386

Draft
wants to merge 13 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ Manifest.toml
LocalPreferences.toml
.DS_Store
/test.jl
try.jl
11 changes: 7 additions & 4 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1656,18 +1656,21 @@ function EGNNConv(ch::Pair{NTuple{2, Int}, Int}; hidden_size::Int = 2 * ch[1][1]
return EGNNConv(ϕe, ϕx, ϕh, num_features, residual)
end

function (l::EGNNConv)(g::GNNGraph, h::AbstractMatrix, x::AbstractMatrix, e = nothing)
function (l::EGNNConv)(g::AbstractGNNGraph, h, x, e = nothing)
if l.num_features.edge > 0
@assert e!==nothing "Edge features must be provided."
end

@assert size(h, 1)==l.num_features.in "Input features must match layer input size."

x_diff = apply_edges(xi_sub_xj, g, x, x)
xj, xi = expand_srcdst(g, x)
hj, hi = expand_srcdst(g, h) #not needed since its invariant node features

x_diff = apply_edges(xi_sub_xj, g, xi, xj)
sqnorm_xdiff = sum(x_diff .^ 2, dims = 1)
x_diff = x_diff ./ (sqrt.(sqnorm_xdiff) .+ 1.0f-6)

msg = apply_edges(message, g, l,
xi = (; h), xj = (; h), e = (; e, x_diff, sqnorm_xdiff))
xi = (; hi), xj = (; hj), e = (; e, x_diff, sqnorm_xdiff))
h_aggr = aggregate_neighbors(g, +, msg.h)
x_aggr = aggregate_neighbors(g, mean, msg.x)

Expand Down
21 changes: 21 additions & 0 deletions src/layers/heteroconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,27 @@ function (hgc::HeteroGraphConv)(g::GNNHeteroGraph, x::Union{NamedTuple,Dict})
return _reduceby_node_t(hgc.aggr, outs, dst_ntypes)
end


function (hgc::HeteroGraphConv)(g::GNNHeteroGraph, x::NamedTuple, h::AbstractMatrix)
function forw(l, et)
sg = edge_type_subgraph(g, et)
node1_t, _, node2_t = et

print(x,"\n\n", h,"before\n\n\n")

x_features = (x[node1_t], x[node2_t])
h_features = h # temporary

return l(sg, h_features, x_features)

end
outs = [forw(l, et) for (l, et) in zip(hgc.layers, hgc.etypes)]
dst_ntypes = [et[3] for et in hgc.etypes]
return _reduceby_node_t(hgc.aggr, outs, dst_ntypes)
end



function _reduceby_node_t(aggr, outs, ntypes)
function _reduce(node_t)
idxs = findall(x -> x == node_t, ntypes)
Expand Down
14 changes: 14 additions & 0 deletions test/layers/heteroconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,20 @@
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
end

@testset "EGNNConv with Heterogeneous Graphs" begin
hin = 5
hout = 5
hidden = 5
hg = rand_bipartite_heterograph((2,3), 6)
hg.num_nodes
x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3))
h = (A = rand(Float32, 5, 2), B = rand(Float32, 5, 3))
layers = HeteroGraphConv((:A, :to, :B) => EGNNConv(4 => 2),
(:B, :to, :A) => EGNNConv(4 => 2));
y = layers(hg, x, h)
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
end

@testset "GINConv" begin
x = (A = rand(4, 2), B = rand(4, 3))
layers = HeteroGraphConv((:A, :to, :B) => GINConv(Dense(4, 2), 0.4),
Expand Down