diff --git a/.gitignore b/.gitignore index 3d1804049..13cacaa12 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ Manifest.toml LocalPreferences.toml .DS_Store /test.jl +try.jl diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 19c200e7a..76fb3ec0f 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1656,18 +1656,21 @@ function EGNNConv(ch::Pair{NTuple{2, Int}, Int}; hidden_size::Int = 2 * ch[1][1] return EGNNConv(ϕe, ϕx, ϕh, num_features, residual) end -function (l::EGNNConv)(g::GNNGraph, h::AbstractMatrix, x::AbstractMatrix, e = nothing) +function (l::EGNNConv)(g::AbstractGNNGraph, h, x, e = nothing) if l.num_features.edge > 0 @assert e!==nothing "Edge features must be provided." end + @assert size(h, 1)==l.num_features.in "Input features must match layer input size." - - x_diff = apply_edges(xi_sub_xj, g, x, x) + xj, xi = expand_srcdst(g, x) + hj, hi = expand_srcdst(g, h) #not needed since its invariant node features + + x_diff = apply_edges(xi_sub_xj, g, xi, xj) sqnorm_xdiff = sum(x_diff .^ 2, dims = 1) x_diff = x_diff ./ (sqrt.(sqnorm_xdiff) .+ 1.0f-6) msg = apply_edges(message, g, l, - xi = (; h), xj = (; h), e = (; e, x_diff, sqnorm_xdiff)) + xi = (; hi), xj = (; hj), e = (; e, x_diff, sqnorm_xdiff)) h_aggr = aggregate_neighbors(g, +, msg.h) x_aggr = aggregate_neighbors(g, mean, msg.x) diff --git a/src/layers/heteroconv.jl b/src/layers/heteroconv.jl index ec75c8922..fe8205199 100644 --- a/src/layers/heteroconv.jl +++ b/src/layers/heteroconv.jl @@ -65,6 +65,27 @@ function (hgc::HeteroGraphConv)(g::GNNHeteroGraph, x::Union{NamedTuple,Dict}) return _reduceby_node_t(hgc.aggr, outs, dst_ntypes) end + +function (hgc::HeteroGraphConv)(g::GNNHeteroGraph, x::NamedTuple, h::AbstractMatrix) + function forw(l, et) + sg = edge_type_subgraph(g, et) + node1_t, _, node2_t = et + + print(x,"\n\n", h,"before\n\n\n") + + x_features = (x[node1_t], x[node2_t]) + h_features = h # temporary + + return l(sg, h_features, x_features) + + end + outs = [forw(l, et) for (l, et) in zip(hgc.layers, hgc.etypes)] + dst_ntypes = [et[3] for et in hgc.etypes] + return _reduceby_node_t(hgc.aggr, outs, dst_ntypes) +end + + + function _reduceby_node_t(aggr, outs, ntypes) function _reduce(node_t) idxs = findall(x -> x == node_t, ntypes) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index 3d5f2c09c..c987eb58d 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -125,6 +125,20 @@ @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end + @testset "EGNNConv with Heterogeneous Graphs" begin + hin = 5 + hout = 5 + hidden = 5 + hg = rand_bipartite_heterograph((2,3), 6) + hg.num_nodes + x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) + h = (A = rand(Float32, 5, 2), B = rand(Float32, 5, 3)) + layers = HeteroGraphConv((:A, :to, :B) => EGNNConv(4 => 2), + (:B, :to, :A) => EGNNConv(4 => 2)); + y = layers(hg, x, h) + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) + end + @testset "GINConv" begin x = (A = rand(4, 2), B = rand(4, 3)) layers = HeteroGraphConv((:A, :to, :B) => GINConv(Dense(4, 2), 0.4),