Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add induced_subgraph functionality #499

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion GNNGraphs/src/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ export rand_graph,
rand_temporal_hyperbolic_graph

include("sampling.jl")
export sample_neighbors
export sample_neighbors, induced_subgraph

include("operators.jl")
# Base.intersect
Expand Down
45 changes: 45 additions & 0 deletions GNNGraphs/src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,48 @@ function sample_neighbors(g::GNNGraph{<:COO_T}, nodes, K = -1;
end
return gnew
end

"""
Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) -> GNNGraph

Generates a subgraph from the original graph using the provided `nodes`.
The function includes the nodes' neighbors and creates edges between nodes that are connected in the original graph.
If a node has no neighbors, an isolated node will be added to the subgraph.

# Arguments:
- `graph::GNNGraph`: The original graph containing nodes, edges, and node features.
- `nodes::Vector{Int}`: A vector of node indices to include in the subgraph.

# Returns:
A new `GNNGraph` containing the subgraph with the specified nodes and their features.
"""
function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int})
if isempty(nodes)
return GNNGraph() # Return empty graph if no nodes are provided
end

node_map = Dict(node => i for (i, node) in enumerate(nodes))

# Collect edges to add
source = Int[]
target = Int[]
eindices = Int[]
for node in nodes
neighbors = Graphs.neighbors(graph, node, dir = :in)
for neighbor in neighbors
if neighbor in keys(node_map)
push!(target, node_map[node])
push!(source, node_map[neighbor])

eindex = findfirst(x -> x == [neighbor, node], edge_index(graph))
push!(eindices, eindex)
end
end
end

# Extract features for the new nodes
new_ndata = getobs(graph.ndata, nodes)
new_edata = getobs(graph.edata, eindices)

return GNNGraph(source, target, num_nodes = length(node_map), ndata = new_ndata, edata = new_edata)
end
24 changes: 24 additions & 0 deletions GNNGraphs/test/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,28 @@ if GRAPH_T == :coo
@test sg.ndata.x1 == g.ndata.x1[sg.ndata.NID]
@test length(union(sg.ndata.NID)) == length(sg.ndata.NID)
end

@testset "induced_subgraph" begin
s = [1, 2]
t = [2, 3]

graph = GNNGraph((s, t), ndata = (; x=rand(Float32, 32, 3), y=rand(Float32, 3)), edata = rand(Float32, 2))

nodes = [1, 2, 3]
subgraph = Graphs.induced_subgraph(graph, nodes)

@test subgraph.num_nodes == 3
@test subgraph.num_edges == 2
@test subgraph.ndata.x == graph.ndata.x
@test subgraph.ndata.y == graph.ndata.y
@test subgraph.edata == graph.edata

graph = GNNGraph(2)
graph = add_edges(graph, ([2], [1]))
nodes = [1]
subgraph = Graphs.induced_subgraph(graph, nodes)

@test subgraph.num_nodes == 1
@test subgraph.num_edges == 0
end
end
3 changes: 2 additions & 1 deletion GraphNeuralNetworks/src/GraphNeuralNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ using ChainRulesCore
using Reexport
using MLUtils: zeros_like

using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
using GNNGraphs: induced_subgraph,
COO_T, ADJMAT_T, SPARSE_T,
check_num_nodes, check_num_edges,
EType, NType # for heteroconvs

Expand Down
Loading