diff --git a/GNNGraphs/src/GNNGraphs.jl b/GNNGraphs/src/GNNGraphs.jl index b82ea39e..d969afd1 100644 --- a/GNNGraphs/src/GNNGraphs.jl +++ b/GNNGraphs/src/GNNGraphs.jl @@ -96,7 +96,7 @@ export rand_graph, rand_temporal_hyperbolic_graph include("sampling.jl") -export sample_neighbors +export sample_neighbors, induced_subgraph include("operators.jl") # Base.intersect diff --git a/GNNGraphs/src/sampling.jl b/GNNGraphs/src/sampling.jl index 01a601f5..d88790ce 100644 --- a/GNNGraphs/src/sampling.jl +++ b/GNNGraphs/src/sampling.jl @@ -116,3 +116,48 @@ function sample_neighbors(g::GNNGraph{<:COO_T}, nodes, K = -1; end return gnew end + +""" + Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) -> GNNGraph + +Generates a subgraph from the original graph using the provided `nodes`. +The function includes the nodes' neighbors and creates edges between nodes that are connected in the original graph. +If a node has no neighbors, an isolated node will be added to the subgraph. + +# Arguments: +- `graph::GNNGraph`: The original graph containing nodes, edges, and node features. +- `nodes::Vector{Int}`: A vector of node indices to include in the subgraph. + +# Returns: +A new `GNNGraph` containing the subgraph with the specified nodes and their features. +""" +function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) + if isempty(nodes) + return GNNGraph() # Return empty graph if no nodes are provided + end + + node_map = Dict(node => i for (i, node) in enumerate(nodes)) + + # Collect edges to add + source = Int[] + target = Int[] + eindices = Int[] + for node in nodes + neighbors = Graphs.neighbors(graph, node, dir = :in) + for neighbor in neighbors + if neighbor in keys(node_map) + push!(target, node_map[node]) + push!(source, node_map[neighbor]) + + eindex = findfirst(x -> x == [neighbor, node], edge_index(graph)) + push!(eindices, eindex) + end + end + end + + # Extract features for the new nodes + new_ndata = getobs(graph.ndata, nodes) + new_edata = getobs(graph.edata, eindices) + + return GNNGraph(source, target, num_nodes = length(node_map), ndata = new_ndata, edata = new_edata) +end diff --git a/GNNGraphs/test/sampling.jl b/GNNGraphs/test/sampling.jl index 658cee8d..c9c12634 100644 --- a/GNNGraphs/test/sampling.jl +++ b/GNNGraphs/test/sampling.jl @@ -45,4 +45,28 @@ if GRAPH_T == :coo @test sg.ndata.x1 == g.ndata.x1[sg.ndata.NID] @test length(union(sg.ndata.NID)) == length(sg.ndata.NID) end + + @testset "induced_subgraph" begin + s = [1, 2] + t = [2, 3] + + graph = GNNGraph((s, t), ndata = (; x=rand(Float32, 32, 3), y=rand(Float32, 3)), edata = rand(Float32, 2)) + + nodes = [1, 2, 3] + subgraph = Graphs.induced_subgraph(graph, nodes) + + @test subgraph.num_nodes == 3 + @test subgraph.num_edges == 2 + @test subgraph.ndata.x == graph.ndata.x + @test subgraph.ndata.y == graph.ndata.y + @test subgraph.edata == graph.edata + + graph = GNNGraph(2) + graph = add_edges(graph, ([2], [1])) + nodes = [1] + subgraph = Graphs.induced_subgraph(graph, nodes) + + @test subgraph.num_nodes == 1 + @test subgraph.num_edges == 0 + end end \ No newline at end of file diff --git a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl index cebf7b7d..6a8cc0dc 100644 --- a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl +++ b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl @@ -11,7 +11,8 @@ using ChainRulesCore using Reexport using MLUtils: zeros_like -using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T, +using GNNGraphs: induced_subgraph, + COO_T, ADJMAT_T, SPARSE_T, check_num_nodes, check_num_edges, EType, NType # for heteroconvs