Skip to content

Commit

Permalink
Add a PatchProvider server handler, currently just updating the config
Browse files Browse the repository at this point in the history
  • Loading branch information
jhrozek committed Jun 4, 2024
1 parent 52ea56f commit 6da5849
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 0 deletions.
39 changes: 39 additions & 0 deletions internal/controlplane/handlers_providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,45 @@ func (s *Server) DeleteProviderByID(
}, nil
}

// PatchProvider patches a provider by name from a specific project.
func (s *Server) PatchProvider(
ctx context.Context,
req *minderv1.PatchProviderRequest,
) (*minderv1.PatchProviderResponse, error) {
entityCtx := engine.EntityFromContext(ctx)
projectID := entityCtx.Project.ID
providerName := entityCtx.Provider.Name

if providerName == "" {
return nil, status.Errorf(codes.InvalidArgument, "provider name is required")
}

err := s.providerManager.PatchProviderConfig(ctx, providerName, projectID, req.GetPatch().GetConfig().AsMap())
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, util.UserVisibleError(codes.NotFound, "provider not found")
}
return nil, status.Errorf(codes.Internal, "error patching provider: %v", err)
}

dbProv, err := s.providerStore.GetByNameInSpecificProject(ctx, projectID, providerName)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, util.UserVisibleError(codes.NotFound, "provider not found")
}
return nil, status.Errorf(codes.Internal, "error getting provider: %v", err)
}

prov, err := protobufProviderFromDB(ctx, s.store, s.cryptoEngine, &s.cfg.Provider, dbProv)
if err != nil {
return nil, status.Errorf(codes.Internal, "error creating provider: %v", err)
}

return &minderv1.PatchProviderResponse{
Provider: prov,
}, nil
}

func protobufProviderFromDB(
ctx context.Context, store db.Store,
cryptoEngine crypto.Engine, pc *config.ProviderConfig, p *db.Provider,
Expand Down
31 changes: 31 additions & 0 deletions internal/providers/manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"errors"
"fmt"

"dario.cat/mergo"
"github.com/google/uuid"
"github.com/rs/zerolog"

Expand Down Expand Up @@ -62,6 +63,9 @@ type ProviderManager interface {
// Deletion will only occur if the provider is in the specified project -
// it will not attempt to find a provider elsewhere in the hierarchy.
DeleteByName(ctx context.Context, name string, projectID uuid.UUID) error
// PatchProviderConfig updates the configuration of the specified provider with the specified patch.
// All keys in the configMap will overwrite the fields in the provider config.
PatchProviderConfig(ctx context.Context, providerName string, projectID uuid.UUID, configPatch map[string]any) error
}

// ProviderClassManager describes an interface for creating instances of a
Expand Down Expand Up @@ -224,6 +228,33 @@ func (p *providerManager) DeleteByName(ctx context.Context, name string, project
return p.deleteByRecord(ctx, config)
}

func (p *providerManager) PatchProviderConfig(
ctx context.Context, providerName string, projectID uuid.UUID, configPatch map[string]any,
) error {
dbProvider, err := p.store.GetByNameInSpecificProject(ctx, projectID, providerName)
if err != nil {
return fmt.Errorf("error retrieving db record: %w", err)
}

var originalConfig map[string]any

if err := json.Unmarshal(dbProvider.Definition, &originalConfig); err != nil {
return fmt.Errorf("error unmarshalling provider config: %w", err)
}

err = mergo.Map(&originalConfig, configPatch, mergo.WithOverride)
if err != nil {
return fmt.Errorf("error merging provider config: %w", err)
}

mergedJSON, err := json.Marshal(originalConfig)
if err != nil {
return fmt.Errorf("error marshalling provider config: %w", err)
}

return p.store.Update(ctx, dbProvider.ID, dbProvider.ProjectID, mergedJSON)
}

func (p *providerManager) deleteByRecord(ctx context.Context, config *db.Provider) error {
manager, err := p.getClassManager(config.Class)
if err != nil {
Expand Down
100 changes: 100 additions & 0 deletions internal/providers/manager/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
Expand All @@ -32,6 +33,105 @@ import (
"github.com/stacklok/minder/internal/providers/mock/fixtures"
)

type configMatcher struct {
expected json.RawMessage
}

func (m *configMatcher) Matches(x interface{}) bool {
actual, ok := x.(json.RawMessage)
if !ok {
return false
}

var exp, got interface{}

if err := json.Unmarshal(m.expected, &exp); err != nil {
return false
}
if err := json.Unmarshal(actual, &got); err != nil {
return false
}
if !cmp.Equal(exp, got) {
fmt.Printf("config mismatch for %s\n", cmp.Diff(actual, m.expected))
return false
}
return true
}

func (m *configMatcher) String() string {
return fmt.Sprintf("is equal to %+v", m.expected)
}

func TestProviderManager_PatchProviderConfig(t *testing.T) {
t.Parallel()

scenarios := []struct {
Name string
FieldMask []string
Provider *db.Provider
CurrentConfig json.RawMessage
Patch map[string]any
MergedConfig json.RawMessage
ExpectedError string
}{
{
Name: "Enabling the auto_enroll field",
Provider: githubAppProvider,
CurrentConfig: json.RawMessage(`{ "github-app": {} }`),
Patch: map[string]any{
"auto_registration": map[string]any{
"enabled": []string{"repository"},
},
},
MergedConfig: json.RawMessage(`{ "auto_registration": { "enabled": ["repository"] }, "github-app": {}}`),
},
{
Name: "Disabling the auto_enroll field",
Provider: githubAppProvider,
CurrentConfig: json.RawMessage(`{ "auto_registration": { "enabled": ["repository"] }, "github-app": {}}`),
Patch: map[string]any{
"auto_registration": map[string]any{
"enabled": []string{},
},
},
MergedConfig: json.RawMessage(`{ "auto_registration": { "enabled": [] }, "github-app": {}}`),
},
}

for _, scenario := range scenarios {
t.Run(scenario.Name, func(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
defer ctrl.Finish()

ctx := context.Background()
store := fixtures.NewProviderStoreMock()(ctrl)
classManager := mockmanager.NewMockProviderClassManager(ctrl)

classManager.EXPECT().GetSupportedClasses().
Return([]db.ProviderClass{db.ProviderClassGithubApp}).
Times(1)
provManager, err := manager.NewProviderManager(store, classManager)
require.NoError(t, err)

dbProvider := providerWithClass(scenario.Provider.Class, providerWithConfig(scenario.CurrentConfig))
store.EXPECT().GetByNameInSpecificProject(ctx, scenario.Provider.ProjectID, scenario.Provider.Name).
Return(dbProvider, nil).
Times(1)

store.EXPECT().Update(ctx, dbProvider.ID, dbProvider.ProjectID, &configMatcher{expected: scenario.MergedConfig}).
Return(nil).
Times(1)

err = provManager.PatchProviderConfig(ctx, scenario.Provider.Name, scenario.Provider.ProjectID, scenario.Patch)
if scenario.ExpectedError != "" {
require.ErrorContains(t, err, scenario.ExpectedError)
} else {
require.NoError(t, err)
}
})
}
}
func TestProviderManager_CreateFromConfig(t *testing.T) {
t.Parallel()

Expand Down
14 changes: 14 additions & 0 deletions internal/providers/manager/mock/manager.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 6da5849

Please sign in to comment.