Skip to content

Commit

Permalink
Fix rule instantiation deletions
Browse files Browse the repository at this point in the history
  • Loading branch information
JAORMX committed Nov 8, 2023
1 parent 2ddb937 commit f81c523
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 25 deletions.
15 changes: 15 additions & 0 deletions database/mock/store.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions database/query/profiles.sql
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ WHERE profile_id = $1 AND entity = $2 RETURNING *;
-- name: DeleteProfileForEntity :exec
DELETE FROM entity_profiles WHERE profile_id = $1 AND entity = $2;

-- name: GetProfileForEntity :one
SELECT * FROM entity_profiles WHERE profile_id = $1 AND entity = $2;

-- name: GetProfileByProjectAndID :many
SELECT * FROM profiles JOIN entity_profiles ON profiles.id = entity_profiles.profile_id
WHERE profiles.project_id = $1 AND profiles.id = $2;
Expand Down
80 changes: 55 additions & 25 deletions internal/controlplane/handlers_profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ import (
minderv1 "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1"
)

// this is a tuple that allows us track rule instantiations
// and the entity they're associated with
type entityAndRuleTuple struct {
Entity minderv1.Entity
RuleID uuid.UUID
}

// authAndContextValidation is a helper function to initialize entity context info and validate input
// It also sets up the needed information in the `in` entity context that's needed for the rest of the flow
// Note that this also does an authorization check.
Expand Down Expand Up @@ -133,7 +140,7 @@ func (s *Server) CreateProfile(ctx context.Context,
return nil, status.Errorf(codes.InvalidArgument, "invalid profile: %v", err)
}

ruleIDs, err := s.getAndValidateRulesFromProfile(ctx, in, entityCtx)
rulesInProf, err := s.getAndValidateRulesFromProfile(ctx, in, entityCtx)
if err != nil {
var violation *engine.RuleValidationError
if errors.As(err, &violation) {
Expand Down Expand Up @@ -181,7 +188,7 @@ func (s *Server) CreateProfile(ctx context.Context,
minderv1.Entity_ENTITY_BUILD_ENVIRONMENTS: in.GetBuildEnvironment(),
minderv1.Entity_ENTITY_PULL_REQUESTS: in.GetPullRequest(),
} {
if err := createProfileRulesForEntity(ctx, ent, &profile, qtx, entRules, ruleIDs); err != nil {
if err := createProfileRulesForEntity(ctx, ent, &profile, qtx, entRules, rulesInProf); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -218,7 +225,7 @@ func createProfileRulesForEntity(
profile *db.Profile,
qtx db.Querier,
rules []*minderv1.Profile_Rule,
ruleIDs map[string]uuid.UUID,
rulesInProf map[string]entityAndRuleTuple,
) error {
if rules == nil {
return nil
Expand All @@ -239,8 +246,8 @@ func createProfileRulesForEntity(
return status.Errorf(codes.Internal, "error creating profile")
}

for idx := range ruleIDs {
ruleID := ruleIDs[idx]
for idx := range rulesInProf {
ruleID := rulesInProf[idx].RuleID

_, err := qtx.UpsertRuleInstantiation(ctx, db.UpsertRuleInstantiationParams{
EntityProfileID: entProf.ID,
Expand Down Expand Up @@ -596,7 +603,7 @@ func (s *Server) UpdateProfile(ctx context.Context,
return nil, util.UserVisibleError(codes.InvalidArgument, "invalid profile update: %v", err)
}

ruleIDs, err := s.getAndValidateRulesFromProfile(ctx, in, entityCtx)
rules, err := s.getAndValidateRulesFromProfile(ctx, in, entityCtx)
if err != nil {
var violation *engine.RuleValidationError
if errors.As(err, &violation) {
Expand All @@ -618,7 +625,7 @@ func (s *Server) UpdateProfile(ctx context.Context,
return nil, status.Errorf(codes.Internal, "failed to get profile: %s", err)
}

oldRuleIDs, err := s.getRulesFromProfile(ctx, oldProfile, entityCtx)
oldRules, err := s.getRulesFromProfile(ctx, oldProfile, entityCtx)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, util.UserVisibleError(codes.NotFound, "profile not found")
Expand All @@ -643,15 +650,27 @@ func (s *Server) UpdateProfile(ctx context.Context,
minderv1.Entity_ENTITY_BUILD_ENVIRONMENTS: in.GetBuildEnvironment(),
minderv1.Entity_ENTITY_PULL_REQUESTS: in.GetPullRequest(),
} {
if err := updateProfileRulesForEntity(ctx, ent, &profile, qtx, entRules, ruleIDs, oldRuleIDs); err != nil {
if err := updateProfileRulesForEntity(ctx, ent, &profile, qtx, entRules, rules, oldRules); err != nil {
return nil, err
}
}

for _, ruleID := range oldRuleIDs {
for _, rule := range oldRules {
// get entity profile
log.Printf("getting profile for entity %s", rule.Entity)
entProf, err := qtx.GetProfileForEntity(ctx, db.GetProfileForEntityParams{
ProfileID: profile.ID,
Entity: entities.EntityTypeToDB(rule.Entity),
})
if err != nil {
log.Printf("error getting profile for entity %s: %v", rule.Entity, err)
return nil, status.Errorf(codes.Internal, "error creating profile")
}

log.Printf("deleting rule instantiation for rule %s for entity profile %s", rule.RuleID, entProf.ID)
if err := qtx.DeleteRuleInstantiation(ctx, db.DeleteRuleInstantiationParams{
EntityProfileID: profile.ID,
RuleTypeID: ruleID,
EntityProfileID: entProf.ID,
RuleTypeID: rule.RuleID,
}); err != nil {
log.Printf("error deleting rule instantiation: %v", err)
return nil, status.Errorf(codes.Internal, "error creating profile")
Expand Down Expand Up @@ -689,10 +708,10 @@ func (s *Server) getAndValidateRulesFromProfile(
ctx context.Context,
prof *minderv1.Profile,
entityCtx *engine.EntityContext,
) (map[string]uuid.UUID, error) {
) (map[string]entityAndRuleTuple, error) {
// We capture the rule instantiations here so we can
// track them in the db later.
ruleIDs := map[string]uuid.UUID{}
rulesInProf := map[string]entityAndRuleTuple{}

err := engine.TraverseAllRulesForPipeline(prof, func(r *minderv1.Profile_Rule) error {
// TODO: This will need to be updated to support
Expand Down Expand Up @@ -731,7 +750,10 @@ func (s *Server) getAndValidateRulesFromProfile(
return fmt.Errorf("error validating rule params: %w", err)
}

ruleIDs[r.GetType()] = rtdb.ID
rulesInProf[r.GetType()] = entityAndRuleTuple{
Entity: minderv1.EntityFromString(rtyppb.Def.InEntity),
RuleID: rtdb.ID,
}

return nil
})
Expand All @@ -740,17 +762,17 @@ func (s *Server) getAndValidateRulesFromProfile(
return nil, err
}

return ruleIDs, nil
return rulesInProf, nil
}

func (s *Server) getRulesFromProfile(
ctx context.Context,
prof *minderv1.Profile,
entityCtx *engine.EntityContext,
) (map[string]uuid.UUID, error) {
) (map[string]entityAndRuleTuple, error) {
// We capture the rule instantiations here so we can
// track them in the db later.
ruleIDs := map[string]uuid.UUID{}
rulesInProf := map[string]entityAndRuleTuple{}

err := engine.TraverseAllRulesForPipeline(prof, func(r *minderv1.Profile_Rule) error {
// TODO: This will need to be updated to support
Expand All @@ -764,7 +786,15 @@ func (s *Server) getRulesFromProfile(
return fmt.Errorf("error getting rule type %s: %w", r.GetType(), err)
}

ruleIDs[r.GetType()] = rtdb.ID
rtyppb, err := engine.RuleTypePBFromDB(&rtdb, entityCtx)
if err != nil {
return fmt.Errorf("cannot convert rule type %s to pb: %w", rtdb.Name, err)
}

rulesInProf[r.GetType()] = entityAndRuleTuple{
Entity: minderv1.EntityFromString(rtyppb.Def.InEntity),
RuleID: rtdb.ID,
}

return nil
})
Expand All @@ -773,7 +803,7 @@ func (s *Server) getRulesFromProfile(
return nil, err
}

return ruleIDs, nil
return rulesInProf, nil
}

func updateProfileRulesForEntity(
Expand All @@ -782,8 +812,8 @@ func updateProfileRulesForEntity(
profile *db.Profile,
qtx db.Querier,
rules []*minderv1.Profile_Rule,
ruleIDs map[string]uuid.UUID,
oldruleIDs map[string]uuid.UUID,
rulesInProf map[string]entityAndRuleTuple,
oldRulesInProf map[string]entityAndRuleTuple,
) error {
if len(rules) == 0 {
return qtx.DeleteProfileForEntity(ctx, db.DeleteProfileForEntityParams{
Expand All @@ -807,12 +837,12 @@ func updateProfileRulesForEntity(
return err
}

for idx := range ruleIDs {
ruleID := ruleIDs[idx]
for idx := range rulesInProf {
ruleID := rulesInProf[idx]

_, err := qtx.UpsertRuleInstantiation(ctx, db.UpsertRuleInstantiationParams{
EntityProfileID: entProf.ID,
RuleTypeID: ruleID,
RuleTypeID: ruleID.RuleID,
})
if errors.Is(err, sql.ErrNoRows) {
log.Printf("the rule instantiation for rule already existed.")
Expand All @@ -823,7 +853,7 @@ func updateProfileRulesForEntity(

// Remove the rule from the old rule IDs so we
// can delete the ones that are no longer needed
delete(oldruleIDs, idx)
delete(oldRulesInProf, idx)
}

return err
Expand Down
23 changes: 23 additions & 0 deletions internal/db/profiles.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions internal/db/querier.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit f81c523

Please sign in to comment.