Skip to content

Commit

Permalink
Fix BFSWithDepth visit callback invoked with incorrect depth
Browse files Browse the repository at this point in the history
  • Loading branch information
s111 committed Sep 20, 2023
1 parent 854cb3c commit b3f630d
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 17 deletions.
26 changes: 14 additions & 12 deletions traversal.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,23 +132,25 @@ func BFSWithDepth[K comparable, T any](g Graph[K, T], start K, visit func(K, int
depth := 0

for len(queue) > 0 {
currentHash := queue[0]

queue = queue[1:]
depth++

// Stop traversing the graph if the visit function returns true.
if stop := visit(currentHash, depth); stop {
break
}
for verticesAtDepth := len(queue); verticesAtDepth > 0; verticesAtDepth-- {
currentHash := queue[0]

queue = queue[1:]

for adjacency := range adjacencyMap[currentHash] {
if _, ok := visited[adjacency]; !ok {
visited[adjacency] = true
queue = append(queue, adjacency)
// Stop traversing the graph if the visit function returns true.
if stop := visit(currentHash, depth); stop {
break
}
}

for adjacency := range adjacencyMap[currentHash] {
if _, ok := visited[adjacency]; !ok {
visited[adjacency] = true
queue = append(queue, adjacency)
}
}
}
}

return nil
Expand Down
91 changes: 86 additions & 5 deletions traversal_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package graph

import (
"log"
"testing"
)

Expand Down Expand Up @@ -319,10 +318,81 @@ func TestDirectedBFS(t *testing.T) {
t.Errorf("%s: expected vertex %v to be visited, but it isn't", name, expectedVisit)
}
}
}
}

visitWithDepth := func(value int, depth int) bool {
visited[value] = struct{}{}
log.Printf("cur depth: %d", depth)
func TestDirectedBFSWithDepth(t *testing.T) {
tests := map[string]struct {
vertices []int
edges []Edge[int]
startHash int
expectedVisits map[int]int
stopAtVertex int
}{
"traverse entire graph with 3 vertices": {
vertices: []int{1, 2, 3},
edges: []Edge[int]{
{Source: 1, Target: 2},
{Source: 1, Target: 3},
},
startHash: 1,
expectedVisits: map[int]int{
1: 1,
2: 2,
3: 2,
},
stopAtVertex: -1,
},
"traverse graph with 6 vertices until vertex 4": {
vertices: []int{1, 2, 3, 4, 5, 6},
edges: []Edge[int]{
{Source: 1, Target: 2},
{Source: 1, Target: 3},
{Source: 2, Target: 4},
{Source: 2, Target: 5},
{Source: 3, Target: 6},
},
startHash: 1,
expectedVisits: map[int]int{
1: 1,
2: 2,
3: 2,
4: 3,
},
stopAtVertex: 4,
},
"traverse a disconnected graph": {
vertices: []int{1, 2, 3, 4},
edges: []Edge[int]{
{Source: 1, Target: 2},
{Source: 3, Target: 4},
},
startHash: 1,
expectedVisits: map[int]int{
1: 1,
2: 2,
},
stopAtVertex: -1,
},
}

for name, test := range tests {
graph := New(IntHash, Directed())

for _, vertex := range test.vertices {
_ = graph.AddVertex(vertex)
}

for _, edge := range test.edges {
if err := graph.AddEdge(edge.Source, edge.Target); err != nil {
t.Fatalf("%s: failed to add edge: %s", name, err.Error())
}
}

visited := make(map[int]int)

visit := func(value, depth int) bool {
visited[value] = depth

if test.stopAtVertex != -1 {
if value == test.stopAtVertex {
Expand All @@ -331,7 +401,18 @@ func TestDirectedBFS(t *testing.T) {
}
return false
}
_ = BFSWithDepth(graph, test.startHash, visitWithDepth)

_ = BFSWithDepth(graph, test.startHash, visit)

for expectedVisit, expectedDepth := range test.expectedVisits {
actualDepth, ok := visited[expectedVisit]
if !ok {
t.Errorf("%s: expected vertex %v to be visited, but it isn't", name, expectedVisit)
}
if expectedDepth != actualDepth {
t.Errorf("%s: vertex depth don't match: expected %v, got %v", name, expectedDepth, actualDepth)
}
}
}
}

Expand Down

0 comments on commit b3f630d

Please sign in to comment.