Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add induced_subgraph functionality #499

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions GNNGraphs/src/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ export rand_graph,

include("sampling.jl")
export sample_neighbors
export induced_subgraph
askorupka marked this conversation as resolved.
Show resolved Hide resolved

include("operators.jl")
# Base.intersect
Expand Down
50 changes: 50 additions & 0 deletions GNNGraphs/src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,53 @@ function sample_neighbors(g::GNNGraph{<:COO_T}, nodes, K = -1;
end
return gnew
end

"""
induced_subgraph(graph::GNNGraph, nodes::Vector{Int}) -> GNNGraph

Generates a subgraph from the original graph using the provided `nodes`.
The function includes the nodes' neighbors and creates edges between nodes that are connected in the original graph.
If a node has no neighbors, an isolated node will be added to the subgraph.
askorupka marked this conversation as resolved.
Show resolved Hide resolved

# Arguments:
- `graph::GNNGraph`: The original graph containing nodes, edges, and node features.
- `nodes::Vector{Int}`: A vector of node indices to include in the subgraph.

# Returns:
A new `GNNGraph` containing the subgraph with the specified nodes and their features.
"""
function Graphs.induced_subgraph(graph::GNNGraph, nodes::Vector{Int})
if isempty(nodes)
return GNNGraph() # Return empty graph if no nodes are provided
end

node_map = Dict(node => i for (i, node) in enumerate(nodes))

# Collect edges to add
source = Int[]
target = Int[]
backup_gnn = GNNGraph()
for node in nodes
neighbors = Graphs.neighbors(graph, node, dir = :in)
if isempty(neighbors)
backup_gnn = add_nodes(backup_gnn, 1)
end
for neighbor in neighbors
if neighbor in keys(node_map)
push!(source, node_map[node])
push!(target, node_map[neighbor])
askorupka marked this conversation as resolved.
Show resolved Hide resolved
end
end
end

# Extract features for the new nodes
#new_features = graph.x[:, nodes]

if isempty(source) && isempty(target)
#backup_gnn.ndata.x = new_features ### TODO fix & add edges data (probably push themto the new vector?)
return backup_gnn # Return empty graph if no nodes are provided
end
askorupka marked this conversation as resolved.
Show resolved Hide resolved

return GNNGraph(source, target)
askorupka marked this conversation as resolved.
Show resolved Hide resolved
#, ndata = new_features) # Return the new GNNGraph with subgraph and features
end
16 changes: 16 additions & 0 deletions GNNGraphs/test/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,20 @@ if GRAPH_T == :coo
@test sg.ndata.x1 == g.ndata.x1[sg.ndata.NID]
@test length(union(sg.ndata.NID)) == length(sg.ndata.NID)
end

@testset "induced_subgraph" begin
# Create a simple GNNGraph with two nodes and one edge
s = [1]
t = [2]
### TODO add data
graph = GNNGraph((s, t))

# Induce subgraph on both nodes
nodes = [1, 2]
subgraph = induced_subgraph(graph, nodes)

@test subgraph.num_nodes == 2 # Subgraph should have 2 nodes
@test subgraph.num_edges == 1 # Subgraph should have 1 edge
### TODO @test subgraph.ndata.x == graph.x[:, nodes] # Features should match the original graph
end
end
Loading