Skip to content

Commit

Permalink
feat(shutdown): make shutdown async and parallel (shutdown in reverse…
Browse files Browse the repository at this point in the history
… invocation order)
  • Loading branch information
samber committed Feb 11, 2024
1 parent 9170ba6 commit 7eed21f
Show file tree
Hide file tree
Showing 12 changed files with 310 additions and 111 deletions.
79 changes: 51 additions & 28 deletions dag.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,59 +23,82 @@ type EdgeService struct {
// newDAG creates a new DAG (Directed Acyclic Graph) with initialized dependencies and dependents maps.
func newDAG() *DAG {
return &DAG{
dependencies: new(sync.Map),
dependents: new(sync.Map),
mu: sync.RWMutex{},
dependencies: map[EdgeService]map[EdgeService]struct{}{},
dependents: map[EdgeService]map[EdgeService]struct{}{},
}
}

// DAG represents a Directed Acyclic Graph of services, tracking dependencies and dependents.
type DAG struct {
dependencies *sync.Map
dependents *sync.Map
mu sync.RWMutex
dependencies map[EdgeService]map[EdgeService]struct{}
dependents map[EdgeService]map[EdgeService]struct{}
}

// addDependency adds a dependency relationship from one service to another in the DAG.
func (d *DAG) addDependency(fromScopeID, fromScopeName, fromServiceName, toScopeID, toScopeName, toServiceName string) {
from := newEdgeService(fromScopeID, fromScopeName, fromServiceName)
to := newEdgeService(toScopeID, toScopeName, toServiceName)

d.addToMap(d.dependencies, from, to)
d.addToMap(d.dependents, to, from)
d.mu.Lock()
defer d.mu.Unlock()

// from -> to
if _, ok := d.dependencies[from]; !ok {
d.dependencies[from] = map[EdgeService]struct{}{}
}
d.dependencies[from][to] = struct{}{}

// from <- to
if _, ok := d.dependents[to]; !ok {
d.dependents[to] = map[EdgeService]struct{}{}
}
d.dependents[to][from] = struct{}{}
}

// addToMap is a helper function to add a key-value pair to a sync.Map, creating a new sync.Map for the value if necessary.
func (d *DAG) addToMap(dependencyMap *sync.Map, key, value interface{}) {
valueMap := new(sync.Map)
valueMap.Store(value, struct{}{})
// removeService removes a dependency relationship between services in the DAG.
func (d *DAG) removeService(scopeID, scopeName, serviceName string) {
edge := newEdgeService(scopeID, scopeName, serviceName)

d.mu.Lock()
defer d.mu.Unlock()

if actual, loaded := dependencyMap.LoadOrStore(key, valueMap); loaded {
actual.(*sync.Map).Store(value, struct{}{})
dependencies, dependents := d.explainServiceImplem(edge)

for _, dependency := range dependencies {
delete(d.dependents[dependency], edge)
}

// should be empty, because we remove dependencies in the inverse invocation order
for _, dependent := range dependents {
delete(d.dependencies[dependent], edge)
}

delete(d.dependencies, edge)
delete(d.dependents, edge)
}

// explainService provides information about a service's dependencies and dependents in the DAG.
func (d *DAG) explainService(scopeID, scopeName, serviceName string) (dependencies, dependents []EdgeService) {
edge := newEdgeService(scopeID, scopeName, serviceName)

dependencies = d.getServicesFromMap(d.dependencies, edge)
dependents = d.getServicesFromMap(d.dependents, edge)
d.mu.RLock()
defer d.mu.RUnlock()

return dependencies, dependents
return d.explainServiceImplem(edge)
}

// getServicesFromMap is a helper function to retrieve services related to a specific key from a sync.Map.
func (d *DAG) getServicesFromMap(serviceMap *sync.Map, edge EdgeService) []EdgeService {
var services []EdgeService

if kv, ok := serviceMap.Load(edge); ok {
kv.(*sync.Map).Range(func(key, value interface{}) bool {
edgeService, ok := key.(EdgeService)
if ok {
services = append(services, edgeService)
}
return ok
})
func (d *DAG) explainServiceImplem(edge EdgeService) (dependencies, dependents []EdgeService) {
dependencies, dependents = []EdgeService{}, []EdgeService{}

if kv, ok := d.dependencies[edge]; ok {
dependencies = keys(kv)
}

return services
if kv, ok := d.dependents[edge]; ok {
dependents = keys(kv)
}

return dependencies, dependents
}
64 changes: 35 additions & 29 deletions dag_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package do

import (
"sync"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -24,11 +23,11 @@ func TestNewDAG(t *testing.T) {
is := assert.New(t)

dag := newDAG()
expectedDependencies := unSyncMap(new(sync.Map))
expectedDependents := unSyncMap(new(sync.Map))
expectedDependencies := map[EdgeService]map[EdgeService]struct{}{}
expectedDependents := map[EdgeService]map[EdgeService]struct{}{}

is.Equal(expectedDependencies, unSyncMap(dag.dependencies))
is.Equal(expectedDependents, unSyncMap(dag.dependents))
is.Equal(expectedDependencies, dag.dependencies)
is.Equal(expectedDependents, dag.dependents)
}

// TestDAG_addDependency checks the addition of dependencies to the DAG.
Expand All @@ -44,19 +43,42 @@ func TestDAG_addDependency(t *testing.T) {

dag.addDependency("scope1", "scope1", "service1", "scope2", "scope2", "service2")

expectedDependencies := map[interface{}]interface{}{edge1: map[interface{}]interface{}{edge2: struct{}{}}}
expectedDependents := map[interface{}]interface{}{edge2: map[interface{}]interface{}{edge1: struct{}{}}}
expectedDependencies := map[EdgeService]map[EdgeService]struct{}{edge1: {edge2: {}}}
expectedDependents := map[EdgeService]map[EdgeService]struct{}{edge2: {edge1: {}}}

is.Equal(expectedDependencies, unSyncMap(dag.dependencies))
is.Equal(expectedDependents, unSyncMap(dag.dependents))
is.Equal(expectedDependencies, dag.dependencies)
is.Equal(expectedDependents, dag.dependents)

dag.addDependency("scope3", "scope3", "service3", "scope2", "scope2", "service2")

expectedDependencies[edge3] = map[interface{}]interface{}{edge2: struct{}{}}
expectedDependents[edge2] = map[interface{}]interface{}{edge1: struct{}{}, edge3: struct{}{}}
expectedDependencies = map[EdgeService]map[EdgeService]struct{}{edge1: {edge2: {}}, edge3: {edge2: {}}}
expectedDependents = map[EdgeService]map[EdgeService]struct{}{edge2: {edge1: {}, edge3: {}}}

is.Equal(expectedDependencies, unSyncMap(dag.dependencies))
is.Equal(expectedDependents, unSyncMap(dag.dependents))
is.Equal(expectedDependencies, dag.dependencies)
is.Equal(expectedDependents, dag.dependents)
}

// TestDAG_removeService checks the removal of dependencies to the DAG.
func TestDAG_removeService(t *testing.T) {
t.Parallel()
is := assert.New(t)

edge1 := newEdgeService("scope1", "scope1", "service1")
// edge2 := newEdgeService("scope2", "scope2", "service2")
edge3 := newEdgeService("scope3", "scope3", "service3")

dag := newDAG()

dag.addDependency("scope1", "scope1", "service1", "scope2", "scope2", "service2")
dag.addDependency("scope3", "scope3", "service3", "scope2", "scope2", "service2")

dag.removeService("scope2", "scope2", "service2")

expectedDependencies := map[EdgeService]map[EdgeService]struct{}{edge1: {}, edge3: {}}
expectedDependents := map[EdgeService]map[EdgeService]struct{}{}

is.Equal(expectedDependencies, dag.dependencies)
is.Equal(expectedDependents, dag.dependents)
}

// TestDAG_explainService checks the explanation of dependencies for a service in the DAG.
Expand Down Expand Up @@ -92,19 +114,3 @@ func TestDAG_explainService(t *testing.T) {
is.ElementsMatch([]EdgeService{}, a)
is.ElementsMatch([]EdgeService{}, b)
}

func unSyncMap(syncMap *sync.Map) map[interface{}]interface{} {
result := make(map[interface{}]interface{})

syncMap.Range(func(key, value interface{}) bool {
if vSyncMap, ok := value.(*sync.Map); ok {
result[key] = unSyncMap(vSyncMap)
} else {
result[key] = value
}

return true
})

return result
}
14 changes: 6 additions & 8 deletions docs/docs/service-lifecycle/shutdowner.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ A shutdown can be triggered on a root scope:

```go
// on demand
injector.Shutdown() map[string]error
injector.ShutdownWithContext(context.Context) map[string]error
injector.Shutdown() error
injector.ShutdownWithContext(context.Context) error

// on signal
injector.ShutdownOnSignals(...os.Signal) (os.Signal, map[string]error)
injector.ShutdownOnSignalsWithContext(context.Context, ...os.Signal) (os.Signal, map[string]error)
injector.ShutdownOnSignals(...os.Signal) (os.Signal, error)
injector.ShutdownOnSignalsWithContext(context.Context, ...os.Signal) (os.Signal, error)
```

...on a single service:
Expand Down Expand Up @@ -90,9 +90,7 @@ Invoke(i, ...)

ctx := context.WithTimeout(10 * time.Second)
errors := i.ShutdownWithContext(ctx)
for _, err := range errors {
if err != nil {
log.Println("shutdown error:", err)
}
if err != nil {
log.Println("shutdown error:", err)
}
```
60 changes: 59 additions & 1 deletion errors.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,65 @@
package do

import "errors"
import (
"errors"
"fmt"
"strings"
)

var ErrServiceNotFound = errors.New("DI: could not find service")
var ErrCircularDependency = errors.New("DI: circular dependency detected")
var ErrHealthCheckTimeout = errors.New("DI: health check timeout")

func newShutdownErrors() *ShutdownErrors {
return &ShutdownErrors{}
}

type ShutdownErrors map[EdgeService]error

func (e *ShutdownErrors) Add(scopeID string, scopeName string, serviceName string, err error) {
if err != nil {
(*e)[newEdgeService(scopeID, scopeName, serviceName)] = err
}
}

func (e ShutdownErrors) Len() int {
out := 0
for _, v := range e {
if v != nil {
out++
}
}
return out
}

func (e ShutdownErrors) Error() string {
lines := []string{}
for k, v := range e {
if v != nil {
lines = append(lines, fmt.Sprintf(" - %s > %s: %s", k.ScopeName, k.Service, v.Error()))
}
}

if len(lines) == 0 {
return "DI: no shutdown errors"
}

return "DI: shutdown errors:\n" + strings.Join(lines, "\n")
}

func mergeShutdownErrors(ins ...*ShutdownErrors) *ShutdownErrors {
out := newShutdownErrors()

for _, in := range ins {
if in != nil {
se := &ShutdownErrors{}
if ok := errors.As(in, &se); ok {
for k, v := range *se {
(*out)[k] = v
}
}
}
}

return out
}
67 changes: 67 additions & 0 deletions errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package do

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestShutdownErrors_Add(t *testing.T) {
is := assert.New(t)

se := newShutdownErrors()
is.Equal(0, len(*se))
is.Equal(0, se.Len())

se.Add("scope-1", "scope-a", "service-a", nil)
is.Equal(0, len(*se))
is.Equal(0, se.Len())
is.EqualValues(&ShutdownErrors{}, se)

se.Add("scope-2", "scope-b", "service-b", assert.AnError)
is.Equal(1, len(*se))
is.Equal(1, se.Len())
is.EqualValues(&ShutdownErrors{
{ScopeID: "scope-2", ScopeName: "scope-b", Service: "service-b"}: assert.AnError,
}, se)
}

func TestShutdownErrors_Error(t *testing.T) {
is := assert.New(t)

se := newShutdownErrors()
is.Equal(0, len(*se))
is.Equal(0, se.Len())
is.EqualValues("DI: no shutdown errors", se.Error())

se.Add("scope-1", "scope-a", "service-a", nil)
is.Equal(0, len(*se))
is.Equal(0, se.Len())
is.EqualValues("DI: no shutdown errors", se.Error())

se.Add("scope-2", "scope-b", "service-b", assert.AnError)
is.Equal(1, len(*se))
is.Equal(1, se.Len())
is.EqualValues("DI: shutdown errors:\n - scope-b > service-b: assert.AnError general error for testing", se.Error())
}

func TestMergeShutdownErrors(t *testing.T) {
is := assert.New(t)

se1 := newShutdownErrors()
se2 := newShutdownErrors()
se3 := newShutdownErrors()

se1.Add("scope-1", "scope-a", "service-a", assert.AnError)
se2.Add("scope-2", "scope-b", "service-b", assert.AnError)

result := mergeShutdownErrors(se1, se2, se3, nil)
is.Equal(2, result.Len())
is.EqualValues(
&ShutdownErrors{
{ScopeID: "scope-1", ScopeName: "scope-a", Service: "service-a"}: assert.AnError,
{ScopeID: "scope-2", ScopeName: "scope-b", Service: "service-b"}: assert.AnError,
},
result,
)
}
2 changes: 1 addition & 1 deletion examples/http/std/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package main
import (
"net/http"

"github.com/samber/do/http/std"
"github.com/samber/do/http/std/v2"
)

func main() {
Expand Down
Loading

0 comments on commit 7eed21f

Please sign in to comment.