Skip to content

Commit

Permalink
Properly handle skipped plugins in plugin get
Browse files Browse the repository at this point in the history
Signed-off-by: Derek McGowan <[email protected]>
  • Loading branch information
dmcgowan committed Oct 31, 2023
1 parent b2f449e commit c7431cc
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 17 deletions.
45 changes: 28 additions & 17 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,18 +135,28 @@ func (ps *Set) GetAll() []*Plugin {
// GetByType should be used. If only one is expected, then to switch plugins,
// disable or remove the unused plugins of the same type.
func (i *InitContext) GetSingle(t Type) (interface{}, error) {
pt, ok := i.plugins.byTypeAndID[t]
if !ok || len(pt) == 0 {
return nil, fmt.Errorf("no plugins registered for %s: %w", t, ErrPluginNotFound)
}
if len(pt) > 1 {
return nil, fmt.Errorf("multiple plugins registered for %s: %w", t, ErrPluginMultipleInstances)
var (
found bool
instance interface{}
)
for _, v := range i.plugins.byTypeAndID[t] {
i, err := v.Instance()
if err != nil {
if IsSkipPlugin(err) {
continue
}
return instance, err
}
if found {
return nil, fmt.Errorf("multiple plugins registered for %s: %w", t, ErrPluginMultipleInstances)
}
instance = i
found = true
}
var p *Plugin
for _, v := range pt {
p = v
if !found {
return nil, fmt.Errorf("no plugins registered for %s: %w", t, ErrPluginNotFound)
}
return p.Instance()
return instance, nil
}

// Plugins returns plugin set
Expand All @@ -170,19 +180,20 @@ func (i *InitContext) GetByID(t Type, id string) (interface{}, error) {

// GetByType returns all plugins with the specific type.
func (i *InitContext) GetByType(t Type) (map[string]interface{}, error) {
pt, ok := i.plugins.byTypeAndID[t]
if !ok {
return nil, fmt.Errorf("no plugins registered for %s: %w", t, ErrPluginNotFound)
}

pi := make(map[string]interface{}, len(pt))
for id, p := range pt {
pi := map[string]interface{}{}
for id, p := range i.plugins.byTypeAndID[t] {
i, err := p.Instance()
if err != nil {
if IsSkipPlugin(err) {
continue
}
return nil, err
}
pi[id] = i
}
if len(pi) == 0 {
return nil, fmt.Errorf("no plugins registered for %s: %w", t, ErrPluginNotFound)
}

return pi, nil
}
135 changes: 135 additions & 0 deletions plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package plugin

import (
"errors"
"fmt"
"testing"
)

Expand Down Expand Up @@ -377,3 +379,136 @@ func TestPluginGraph(t *testing.T) {
cmpOrdered(t, ordered, testcase.expectedURI)
}
}

func TestGetPlugins(t *testing.T) {
otherError := fmt.Errorf("other error")
plugins := NewPluginSet()
for _, p := range []*Plugin{
testPlugin("type1", "id1", "id1", nil),
testPlugin("type1", "id2", "id2", ErrSkipPlugin),
testPlugin("type2", "id3", "id3", ErrSkipPlugin),
testPlugin("type3", "id4", "id4", nil),
testPlugin("type4", "id5", "id5", nil),
testPlugin("type4", "id6", "id6", nil),
testPlugin("type5", "id7", "id7", otherError),
} {
plugins.Add(p)
}

ic := InitContext{
plugins: plugins,
}

for _, tc := range []struct {
pluginType string
err error
}{
{"type1", nil},
{"type2", ErrPluginNotFound},
{"type3", nil},
{"type4", ErrPluginMultipleInstances},
{"type5", otherError},
} {
t.Run("GetSingle", func(t *testing.T) {
instance, err := ic.GetSingle(Type(tc.pluginType))
if err != nil {
if tc.err == nil {
t.Fatalf("unexpected error %v", err)
} else if !errors.Is(err, tc.err) {
t.Fatalf("unexpected error %v, expected %v", err, tc.err)
}
return
} else if tc.err != nil {
t.Fatalf("expected error %v, got no error", tc.err)
}
_, ok := instance.(string)
if !ok {
t.Fatalf("unexpected instance value %v", instance)
}
})
}

for _, tc := range []struct {
pluginType string
expected []string
err error
}{
{"type1", []string{"id1"}, nil},
{"type2", nil, ErrPluginNotFound},
{"type3", []string{"id4"}, nil},
{"type4", []string{"id5", "id6"}, nil},
{"type5", nil, otherError},
} {
t.Run("GetByType", func(t *testing.T) {
m, err := ic.GetByType(Type(tc.pluginType))
if err != nil {
if tc.err == nil {
t.Fatalf("unexpected error %v", err)
} else if !errors.Is(err, tc.err) {
t.Fatalf("unexpected error %v, expected %v", err, tc.err)
}
return
} else if tc.err != nil {
t.Fatalf("expected error %v, got no error", tc.err)
}

if len(m) != len(tc.expected) {
t.Fatalf("unexpected result %v, expected %v", m, tc.expected)
}
for _, v := range tc.expected {
instance, ok := m[v]
if !ok {
t.Errorf("missing value for %q", v)
continue
}
if instance.(string) != v {
t.Errorf("unexpected value %v, expected %v", instance, v)
}
}
})
}

for _, tc := range []struct {
pluginType string
id string
err error
}{
{"type1", "id1", nil},
{"type1", "id2", ErrSkipPlugin},
{"type2", "id3", ErrSkipPlugin},
{"type3", "id4", nil},
{"type4", "id5", nil},
{"type4", "id6", nil},
{"type5", "id7", otherError},
} {
t.Run("GetByID", func(t *testing.T) {
instance, err := ic.GetByID(Type(tc.pluginType), tc.id)
if err != nil {
if tc.err == nil {
t.Fatalf("unexpected error %v", err)
} else if !errors.Is(err, tc.err) {
t.Fatalf("unexpected error %v, expected %v", err, tc.err)
}
return
} else if tc.err != nil {
t.Fatalf("expected error %v, got no error", tc.err)
}

if instance.(string) != tc.id {
t.Errorf("unexpected value %v, expected %v", instance, tc.id)
}
})
}

}

func testPlugin(t Type, id string, i interface{}, err error) *Plugin {
return &Plugin{
Registration: Registration{
Type: t,
ID: id,
},
instance: i,
err: err,
}
}

0 comments on commit c7431cc

Please sign in to comment.