Skip to content

Commit

Permalink
feat: add induced_subgraph functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
askorupka committed Sep 29, 2024
1 parent a034753 commit 5859d3d
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
1 change: 1 addition & 0 deletions GNNGraphs/src/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ export rand_graph,

include("sampling.jl")
export sample_neighbors
export induced_subgraph

include("operators.jl")
# Base.intersect
Expand Down
50 changes: 50 additions & 0 deletions GNNGraphs/src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,53 @@ function sample_neighbors(g::GNNGraph{<:COO_T}, nodes, K = -1;
end
return gnew
end

"""
induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) -> GNNGraph
Generates a subgraph from the original graph using the provided `nodes`.
The function includes the nodes' neighbors and creates edges between nodes that are connected in the original graph.
If a node has no neighbors, an isolated node will be added to the subgraph.
# Arguments:
- `graph::GNNGraph`: The original graph containing nodes, edges, and node features.
- `nodes::Vector{Int}`: A vector of node indices to include in the subgraph.
# Returns:
A new `GNNGraph` containing the subgraph with the specified nodes and their features.
"""
function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int})
if isempty(nodes)
return GNNGraph() # Return empty graph if no nodes are provided
end

node_map = Dict(node => i for (i, node) in enumerate(nodes))

# Collect edges to add
source = Int[]
target = Int[]
backup_gnn = GNNGraph()
for node in nodes
neighbors = Graphs.neighbors(graph, node, dir = :in)
if isempty(neighbors)
backup_gnn = add_nodes(backup_gnn, 1)
end
for neighbor in neighbors
if neighbor in keys(node_map)
push!(source, node_map[node])
push!(target, node_map[neighbor])
end
end
end

# Extract features for the new nodes
#new_features = graph.x[:, nodes]

if isempty(source) && isempty(target)
#backup_gnn.ndata.x = new_features ### TODO fix & add edges data (probably push themto the new vector?)
return backup_gnn # Return empty graph if no nodes are provided
end

return GNNGraph(source, target)
#, ndata = new_features) # Return the new GNNGraph with subgraph and features
end
16 changes: 16 additions & 0 deletions GNNGraphs/test/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,20 @@ if GRAPH_T == :coo
@test sg.ndata.x1 == g.ndata.x1[sg.ndata.NID]
@test length(union(sg.ndata.NID)) == length(sg.ndata.NID)
end

@testset "induced_subgraph" begin
# Create a simple GNNGraph with two nodes and one edge
graph = GNNGraph() # Initialize graph
add_nodes!(graph, 2) # Add 2 nodes
add_edge!(graph, 1, 2) # Add an edge from node 1 to node 2
graph.x = rand(10, 2) # Assign random features to both nodes (10 features per node)

# Induce subgraph on both nodes
nodes = [1, 2]
subgraph = induced_subgraph(graph, nodes)

@test num_nodes(subgraph) == 2 # Subgraph should have 2 nodes
@test num_nodes(subgraph) == 1 # Subgraph should have 1 edge
### TODO @test subgraph.ndata.x == graph.x[:, nodes] # Features should match the original graph
end
end

0 comments on commit 5859d3d

Please sign in to comment.