diff --git a/database/mock/store.go b/database/mock/store.go index b9ffb00c73..e8ce38b3f8 100644 --- a/database/mock/store.go +++ b/database/mock/store.go @@ -903,6 +903,21 @@ func (mr *MockStoreMockRecorder) GetProfileByProjectAndID(arg0, arg1 interface{} return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProfileByProjectAndID", reflect.TypeOf((*MockStore)(nil).GetProfileByProjectAndID), arg0, arg1) } +// GetProfileForEntity mocks base method. +func (m *MockStore) GetProfileForEntity(arg0 context.Context, arg1 db.GetProfileForEntityParams) (db.EntityProfile, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProfileForEntity", arg0, arg1) + ret0, _ := ret[0].(db.EntityProfile) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProfileForEntity indicates an expected call of GetProfileForEntity. +func (mr *MockStoreMockRecorder) GetProfileForEntity(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProfileForEntity", reflect.TypeOf((*MockStore)(nil).GetProfileForEntity), arg0, arg1) +} + // GetProfileStatusByIdAndProject mocks base method. func (m *MockStore) GetProfileStatusByIdAndProject(arg0 context.Context, arg1 db.GetProfileStatusByIdAndProjectParams) (db.GetProfileStatusByIdAndProjectRow, error) { m.ctrl.T.Helper() diff --git a/database/query/profiles.sql b/database/query/profiles.sql index 62d5bca1d9..22ce8eff42 100644 --- a/database/query/profiles.sql +++ b/database/query/profiles.sql @@ -26,6 +26,9 @@ WHERE profile_id = $1 AND entity = $2 RETURNING *; -- name: DeleteProfileForEntity :exec DELETE FROM entity_profiles WHERE profile_id = $1 AND entity = $2; +-- name: GetProfileForEntity :one +SELECT * FROM entity_profiles WHERE profile_id = $1 AND entity = $2; + -- name: GetProfileByProjectAndID :many SELECT * FROM profiles JOIN entity_profiles ON profiles.id = entity_profiles.profile_id WHERE profiles.project_id = $1 AND profiles.id = $2; diff --git a/internal/controlplane/handlers_profile.go b/internal/controlplane/handlers_profile.go index b9bd4491ea..ba64cae576 100644 --- a/internal/controlplane/handlers_profile.go +++ b/internal/controlplane/handlers_profile.go @@ -37,6 +37,13 @@ import ( minderv1 "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1" ) +// this is a tuple that allows us track rule instantiations +// and the entity they're associated with +type entityAndRuleTuple struct { + Entity minderv1.Entity + RuleID uuid.UUID +} + // authAndContextValidation is a helper function to initialize entity context info and validate input // It also sets up the needed information in the `in` entity context that's needed for the rest of the flow // Note that this also does an authorization check. @@ -133,7 +140,7 @@ func (s *Server) CreateProfile(ctx context.Context, return nil, status.Errorf(codes.InvalidArgument, "invalid profile: %v", err) } - ruleIDs, err := s.getAndValidateRulesFromProfile(ctx, in, entityCtx) + rulesInProf, err := s.getAndValidateRulesFromProfile(ctx, in, entityCtx) if err != nil { var violation *engine.RuleValidationError if errors.As(err, &violation) { @@ -181,7 +188,7 @@ func (s *Server) CreateProfile(ctx context.Context, minderv1.Entity_ENTITY_BUILD_ENVIRONMENTS: in.GetBuildEnvironment(), minderv1.Entity_ENTITY_PULL_REQUESTS: in.GetPullRequest(), } { - if err := createProfileRulesForEntity(ctx, ent, &profile, qtx, entRules, ruleIDs); err != nil { + if err := createProfileRulesForEntity(ctx, ent, &profile, qtx, entRules, rulesInProf); err != nil { return nil, err } } @@ -218,7 +225,7 @@ func createProfileRulesForEntity( profile *db.Profile, qtx db.Querier, rules []*minderv1.Profile_Rule, - ruleIDs map[string]uuid.UUID, + rulesInProf map[string]entityAndRuleTuple, ) error { if rules == nil { return nil @@ -239,8 +246,8 @@ func createProfileRulesForEntity( return status.Errorf(codes.Internal, "error creating profile") } - for idx := range ruleIDs { - ruleID := ruleIDs[idx] + for idx := range rulesInProf { + ruleID := rulesInProf[idx].RuleID _, err := qtx.UpsertRuleInstantiation(ctx, db.UpsertRuleInstantiationParams{ EntityProfileID: entProf.ID, @@ -596,7 +603,7 @@ func (s *Server) UpdateProfile(ctx context.Context, return nil, util.UserVisibleError(codes.InvalidArgument, "invalid profile update: %v", err) } - ruleIDs, err := s.getAndValidateRulesFromProfile(ctx, in, entityCtx) + rules, err := s.getAndValidateRulesFromProfile(ctx, in, entityCtx) if err != nil { var violation *engine.RuleValidationError if errors.As(err, &violation) { @@ -618,7 +625,7 @@ func (s *Server) UpdateProfile(ctx context.Context, return nil, status.Errorf(codes.Internal, "failed to get profile: %s", err) } - oldRuleIDs, err := s.getRulesFromProfile(ctx, oldProfile, entityCtx) + oldRules, err := s.getRulesFromProfile(ctx, oldProfile, entityCtx) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, util.UserVisibleError(codes.NotFound, "profile not found") @@ -643,15 +650,27 @@ func (s *Server) UpdateProfile(ctx context.Context, minderv1.Entity_ENTITY_BUILD_ENVIRONMENTS: in.GetBuildEnvironment(), minderv1.Entity_ENTITY_PULL_REQUESTS: in.GetPullRequest(), } { - if err := updateProfileRulesForEntity(ctx, ent, &profile, qtx, entRules, ruleIDs, oldRuleIDs); err != nil { + if err := updateProfileRulesForEntity(ctx, ent, &profile, qtx, entRules, rules, oldRules); err != nil { return nil, err } } - for _, ruleID := range oldRuleIDs { + for _, rule := range oldRules { + // get entity profile + log.Printf("getting profile for entity %s", rule.Entity) + entProf, err := qtx.GetProfileForEntity(ctx, db.GetProfileForEntityParams{ + ProfileID: profile.ID, + Entity: entities.EntityTypeToDB(rule.Entity), + }) + if err != nil { + log.Printf("error getting profile for entity %s: %v", rule.Entity, err) + return nil, status.Errorf(codes.Internal, "error creating profile") + } + + log.Printf("deleting rule instantiation for rule %s for entity profile %s", rule.RuleID, entProf.ID) if err := qtx.DeleteRuleInstantiation(ctx, db.DeleteRuleInstantiationParams{ - EntityProfileID: profile.ID, - RuleTypeID: ruleID, + EntityProfileID: entProf.ID, + RuleTypeID: rule.RuleID, }); err != nil { log.Printf("error deleting rule instantiation: %v", err) return nil, status.Errorf(codes.Internal, "error creating profile") @@ -689,10 +708,10 @@ func (s *Server) getAndValidateRulesFromProfile( ctx context.Context, prof *minderv1.Profile, entityCtx *engine.EntityContext, -) (map[string]uuid.UUID, error) { +) (map[string]entityAndRuleTuple, error) { // We capture the rule instantiations here so we can // track them in the db later. - ruleIDs := map[string]uuid.UUID{} + rulesInProf := map[string]entityAndRuleTuple{} err := engine.TraverseAllRulesForPipeline(prof, func(r *minderv1.Profile_Rule) error { // TODO: This will need to be updated to support @@ -731,7 +750,10 @@ func (s *Server) getAndValidateRulesFromProfile( return fmt.Errorf("error validating rule params: %w", err) } - ruleIDs[r.GetType()] = rtdb.ID + rulesInProf[r.GetType()] = entityAndRuleTuple{ + Entity: minderv1.EntityFromString(rtyppb.Def.InEntity), + RuleID: rtdb.ID, + } return nil }) @@ -740,17 +762,17 @@ func (s *Server) getAndValidateRulesFromProfile( return nil, err } - return ruleIDs, nil + return rulesInProf, nil } func (s *Server) getRulesFromProfile( ctx context.Context, prof *minderv1.Profile, entityCtx *engine.EntityContext, -) (map[string]uuid.UUID, error) { +) (map[string]entityAndRuleTuple, error) { // We capture the rule instantiations here so we can // track them in the db later. - ruleIDs := map[string]uuid.UUID{} + rulesInProf := map[string]entityAndRuleTuple{} err := engine.TraverseAllRulesForPipeline(prof, func(r *minderv1.Profile_Rule) error { // TODO: This will need to be updated to support @@ -764,7 +786,15 @@ func (s *Server) getRulesFromProfile( return fmt.Errorf("error getting rule type %s: %w", r.GetType(), err) } - ruleIDs[r.GetType()] = rtdb.ID + rtyppb, err := engine.RuleTypePBFromDB(&rtdb, entityCtx) + if err != nil { + return fmt.Errorf("cannot convert rule type %s to pb: %w", rtdb.Name, err) + } + + rulesInProf[r.GetType()] = entityAndRuleTuple{ + Entity: minderv1.EntityFromString(rtyppb.Def.InEntity), + RuleID: rtdb.ID, + } return nil }) @@ -773,7 +803,7 @@ func (s *Server) getRulesFromProfile( return nil, err } - return ruleIDs, nil + return rulesInProf, nil } func updateProfileRulesForEntity( @@ -782,8 +812,8 @@ func updateProfileRulesForEntity( profile *db.Profile, qtx db.Querier, rules []*minderv1.Profile_Rule, - ruleIDs map[string]uuid.UUID, - oldruleIDs map[string]uuid.UUID, + rulesInProf map[string]entityAndRuleTuple, + oldRulesInProf map[string]entityAndRuleTuple, ) error { if len(rules) == 0 { return qtx.DeleteProfileForEntity(ctx, db.DeleteProfileForEntityParams{ @@ -807,12 +837,12 @@ func updateProfileRulesForEntity( return err } - for idx := range ruleIDs { - ruleID := ruleIDs[idx] + for idx := range rulesInProf { + ruleID := rulesInProf[idx] _, err := qtx.UpsertRuleInstantiation(ctx, db.UpsertRuleInstantiationParams{ EntityProfileID: entProf.ID, - RuleTypeID: ruleID, + RuleTypeID: ruleID.RuleID, }) if errors.Is(err, sql.ErrNoRows) { log.Printf("the rule instantiation for rule already existed.") @@ -823,7 +853,7 @@ func updateProfileRulesForEntity( // Remove the rule from the old rule IDs so we // can delete the ones that are no longer needed - delete(oldruleIDs, idx) + delete(oldRulesInProf, idx) } return err diff --git a/internal/db/profiles.sql.go b/internal/db/profiles.sql.go index 17039b4021..aabe2eba8c 100644 --- a/internal/db/profiles.sql.go +++ b/internal/db/profiles.sql.go @@ -347,6 +347,29 @@ func (q *Queries) GetProfileByProjectAndID(ctx context.Context, arg GetProfileBy return items, nil } +const getProfileForEntity = `-- name: GetProfileForEntity :one +SELECT id, entity, profile_id, contextual_rules, created_at, updated_at FROM entity_profiles WHERE profile_id = $1 AND entity = $2 +` + +type GetProfileForEntityParams struct { + ProfileID uuid.UUID `json:"profile_id"` + Entity Entities `json:"entity"` +} + +func (q *Queries) GetProfileForEntity(ctx context.Context, arg GetProfileForEntityParams) (EntityProfile, error) { + row := q.db.QueryRowContext(ctx, getProfileForEntity, arg.ProfileID, arg.Entity) + var i EntityProfile + err := row.Scan( + &i.ID, + &i.Entity, + &i.ProfileID, + &i.ContextualRules, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const listProfilesByProjectID = `-- name: ListProfilesByProjectID :many SELECT profiles.id, name, provider, project_id, remediate, alert, profiles.created_at, profiles.updated_at, entity_profiles.id, entity, profile_id, contextual_rules, entity_profiles.created_at, entity_profiles.updated_at FROM profiles JOIN entity_profiles ON profiles.id = entity_profiles.profile_id WHERE profiles.project_id = $1 diff --git a/internal/db/querier.go b/internal/db/querier.go index 76680b12fe..c044a0257b 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -70,6 +70,7 @@ type Querier interface { GetProfileByIDAndLock(ctx context.Context, id uuid.UUID) (Profile, error) GetProfileByNameAndLock(ctx context.Context, arg GetProfileByNameAndLockParams) (Profile, error) GetProfileByProjectAndID(ctx context.Context, arg GetProfileByProjectAndIDParams) ([]GetProfileByProjectAndIDRow, error) + GetProfileForEntity(ctx context.Context, arg GetProfileForEntityParams) (EntityProfile, error) GetProfileStatusByIdAndProject(ctx context.Context, arg GetProfileStatusByIdAndProjectParams) (GetProfileStatusByIdAndProjectRow, error) GetProfileStatusByNameAndProject(ctx context.Context, arg GetProfileStatusByNameAndProjectParams) (GetProfileStatusByNameAndProjectRow, error) GetProfileStatusByProject(ctx context.Context, projectID uuid.UUID) ([]GetProfileStatusByProjectRow, error)