diff --git a/internal/flag/internal_flag.go b/internal/flag/internal_flag.go index 4c89953611a..b95c5ae61c4 100644 --- a/internal/flag/internal_flag.go +++ b/internal/flag/internal_flag.go @@ -295,13 +295,13 @@ func (f *InternalFlag) IsValid() error { const isDefaultRule = true // Validate rules - if err := f.GetDefaultRule().IsValid(isDefaultRule); err != nil { + if err := f.GetDefaultRule().IsValid(isDefaultRule, f.GetVariations()); err != nil { return err } ruleNames := map[string]interface{}{} for _, rule := range f.GetRules() { - if err := rule.IsValid(!isDefaultRule); err != nil { + if err := rule.IsValid(!isDefaultRule, f.GetVariations()); err != nil { return err } diff --git a/internal/flag/internal_flag_test.go b/internal/flag/internal_flag_test.go index 3133c7f489c..812a63062b0 100644 --- a/internal/flag/internal_flag_test.go +++ b/internal/flag/internal_flag_test.go @@ -1993,7 +1993,7 @@ func TestInternalFlag_GetVariationValue(t *testing.T) { want interface{} }{ { - name: "Should return nil if variation does not exists", + name: "Should return nil if variation does not exist", flag: flag.InternalFlag{ Variations: &map[string]*interface{}{ "varA": testconvert.Interface("valueA"), @@ -2210,6 +2210,7 @@ func TestInternalFlag_IsValid(t *testing.T) { fields: fields{ Variations: &map[string]*interface{}{ "A": testconvert.Interface("A"), + "B": testconvert.Interface("B"), }, DefaultRule: &flag.Rule{ Percentages: &map[string]float64{ @@ -2264,6 +2265,7 @@ func TestInternalFlag_IsValid(t *testing.T) { fields: fields{ Variations: &map[string]*interface{}{ "A": testconvert.Interface("A"), + "B": testconvert.Interface("B"), }, DefaultRule: &flag.Rule{ Percentages: &map[string]float64{ @@ -2366,6 +2368,194 @@ func TestInternalFlag_IsValid(t *testing.T) { errorMsg: "", wantErr: assert.NoError, }, + { + name: "should error if default rule referencing a variation that does not exist", + fields: fields{ + Variations: &map[string]*interface{}{ + "A": testconvert.Interface("A"), + "B": testconvert.Interface("B"), + }, + DefaultRule: &flag.Rule{ + VariationResult: testconvert.String("C"), + }, + }, + errorMsg: "invalid variation: C does not exist", + wantErr: assert.Error, + }, + { + name: "should error if default percentage rule referencing a variation that does not exist", + fields: fields{ + Variations: &map[string]*interface{}{ + "A": testconvert.Interface("A"), + "B": testconvert.Interface("B"), + }, + DefaultRule: &flag.Rule{ + Percentages: &map[string]float64{ + "A": 90, + "B": 5, + "C": 5, + }, + }, + }, + errorMsg: "invalid percentage: variation C does not exist", + wantErr: assert.Error, + }, + { + name: "should error if default progressive rule end rollout step referencing a variation that does not exist", + fields: fields{ + Variations: &map[string]*interface{}{ + "A": testconvert.Interface("A"), + "B": testconvert.Interface("B"), + }, + DefaultRule: &flag.Rule{ + ProgressiveRollout: &flag.ProgressiveRollout{ + Initial: &flag.ProgressiveRolloutStep{ + Variation: testconvert.String("A"), + Percentage: testconvert.Float64(0), + Date: testconvert.Time(time.Now().Add(1 * time.Second)), + }, + End: &flag.ProgressiveRolloutStep{ + Variation: testconvert.String("C"), + Percentage: testconvert.Float64(100), + Date: testconvert.Time(time.Now().Add(2 * time.Second)), + }, + }, + }, + }, + errorMsg: "invalid progressive rollout, end variation C does not exist", + wantErr: assert.Error, + }, + { + name: "should error if default progressive rule initial rollout step referencing a variation that does not exist", + fields: fields{ + Variations: &map[string]*interface{}{ + "A": testconvert.Interface("A"), + "B": testconvert.Interface("B"), + }, + DefaultRule: &flag.Rule{ + ProgressiveRollout: &flag.ProgressiveRollout{ + Initial: &flag.ProgressiveRolloutStep{ + Variation: testconvert.String("C"), + Percentage: testconvert.Float64(0), + Date: testconvert.Time(time.Now().Add(1 * time.Second)), + }, + End: &flag.ProgressiveRolloutStep{ + Variation: testconvert.String("A"), + Percentage: testconvert.Float64(100), + Date: testconvert.Time(time.Now().Add(2 * time.Second)), + }, + }, + }, + }, + errorMsg: "invalid progressive rollout, initial variation C does not exist", + wantErr: assert.Error, + }, + { + name: "should error if targeting rule referencing a variation that does not exist", + fields: fields{ + Variations: &map[string]*interface{}{ + "A": testconvert.Interface("A"), + "B": testconvert.Interface("B"), + }, + Rules: &[]flag.Rule{ + { + Query: testconvert.String("targetingKey eq 1"), + VariationResult: testconvert.String("C"), + }, + }, + DefaultRule: &flag.Rule{ + VariationResult: testconvert.String("A"), + }, + }, + errorMsg: "invalid variation: C does not exist", + wantErr: assert.Error, + }, + { + name: "should error if percentage in targeting rule referencing a variation that does not exist", + fields: fields{ + Variations: &map[string]*interface{}{ + "A": testconvert.Interface("A"), + "B": testconvert.Interface("B"), + }, + Rules: &[]flag.Rule{ + { + Query: testconvert.String("targetingKey eq 1"), + Percentages: &map[string]float64{ + "A": 90, + "B": 5, + "C": 5, + }, + }, + }, + DefaultRule: &flag.Rule{ + VariationResult: testconvert.String("A"), + }, + }, + errorMsg: "invalid percentage: variation C does not exist", + wantErr: assert.Error, + }, + { + name: "should error if progressive rollout in targeting rule referencing an initial variation that does not exist", + fields: fields{ + Variations: &map[string]*interface{}{ + "A": testconvert.Interface("A"), + "B": testconvert.Interface("B"), + }, + Rules: &[]flag.Rule{ + { + Query: testconvert.String("targetingKey eq 1"), + ProgressiveRollout: &flag.ProgressiveRollout{ + Initial: &flag.ProgressiveRolloutStep{ + Variation: testconvert.String("C"), + Percentage: testconvert.Float64(0), + Date: testconvert.Time(time.Now().Add(1 * time.Second)), + }, + End: &flag.ProgressiveRolloutStep{ + Variation: testconvert.String("A"), + Percentage: testconvert.Float64(100), + Date: testconvert.Time(time.Now().Add(2 * time.Second)), + }, + }, + }, + }, + DefaultRule: &flag.Rule{ + VariationResult: testconvert.String("A"), + }, + }, + errorMsg: "invalid progressive rollout, initial variation C does not exist", + wantErr: assert.Error, + }, + { + name: "should error if progressive rollout in targeting rule referencing an end variation that does not exist", + fields: fields{ + Variations: &map[string]*interface{}{ + "A": testconvert.Interface("A"), + "B": testconvert.Interface("B"), + }, + Rules: &[]flag.Rule{ + { + Query: testconvert.String("targetingKey eq 1"), + ProgressiveRollout: &flag.ProgressiveRollout{ + Initial: &flag.ProgressiveRolloutStep{ + Variation: testconvert.String("A"), + Percentage: testconvert.Float64(0), + Date: testconvert.Time(time.Now().Add(1 * time.Second)), + }, + End: &flag.ProgressiveRolloutStep{ + Variation: testconvert.String("C"), + Percentage: testconvert.Float64(100), + Date: testconvert.Time(time.Now().Add(2 * time.Second)), + }, + }, + }, + }, + DefaultRule: &flag.Rule{ + VariationResult: testconvert.String("A"), + }, + }, + errorMsg: "invalid progressive rollout, end variation C does not exist", + wantErr: assert.Error, + }, } for _, tt := range tests { diff --git a/internal/flag/rule.go b/internal/flag/rule.go index 6fba42418b6..265e606f2a2 100644 --- a/internal/flag/rule.go +++ b/internal/flag/rule.go @@ -223,7 +223,7 @@ func (r *Rule) MergeRules(updatedRule Rule) { } // IsValid is checking if the rule is valid -func (r *Rule) IsValid(defaultRule bool) error { +func (r *Rule) IsValid(defaultRule bool, variations map[string]*interface{}) error { if !defaultRule && r.IsDisable() { return nil } @@ -240,8 +240,11 @@ func (r *Rule) IsValid(defaultRule bool) error { // Validate the percentage of the rule if r.Percentages != nil { count := float64(0) - for _, p := range r.GetPercentages() { + for k, p := range r.GetPercentages() { count += p + if _, ok := variations[k]; !ok { + return fmt.Errorf("invalid percentage: variation %s does not exist", k) + } } if len(r.GetPercentages()) == 0 { @@ -254,11 +257,29 @@ func (r *Rule) IsValid(defaultRule bool) error { } // Progressive rollout: check that initial is lower than end - if r.ProgressiveRollout != nil && - (r.GetProgressiveRollout().End.getPercentage() < r.GetProgressiveRollout().Initial.getPercentage()) { - return fmt.Errorf("invalid progressive rollout, initial percentage should be lower "+ - "than end percentage: %v/%v", - r.GetProgressiveRollout().Initial.getPercentage(), r.GetProgressiveRollout().End.getPercentage()) + if r.ProgressiveRollout != nil { + if r.GetProgressiveRollout().End.getPercentage() < r.GetProgressiveRollout().Initial.getPercentage() { + return fmt.Errorf("invalid progressive rollout, initial percentage should be lower "+ + "than end percentage: %v/%v", + r.GetProgressiveRollout().Initial.getPercentage(), r.GetProgressiveRollout().End.getPercentage()) + } + + endVar := r.GetProgressiveRollout().End.getVariation() + if _, ok := variations[endVar]; !ok { + return fmt.Errorf("invalid progressive rollout, end variation %s does not exist", endVar) + } + + initialVar := r.GetProgressiveRollout().Initial.getVariation() + if _, ok := variations[initialVar]; !ok { + return fmt.Errorf("invalid progressive rollout, initial variation %s does not exist", initialVar) + } + } + + // Check that the variation exists + if r.Percentages == nil && r.ProgressiveRollout == nil && r.VariationResult != nil { + if _, ok := variations[r.GetVariationResult()]; !ok { + return fmt.Errorf("invalid variation: %s does not exist", r.GetVariationResult()) + } } return nil } diff --git a/variation_all_flags_test.go b/variation_all_flags_test.go index 68d4ae10e05..3042bdb706b 100644 --- a/variation_all_flags_test.go +++ b/variation_all_flags_test.go @@ -31,17 +31,6 @@ func TestAllFlagsState(t *testing.T) { jsonOutput: "./testdata/ffclient/all_flags/marshal_json/valid_multiple_types.json", initModule: true, }, - { - name: "Error in flag-0", - config: Config{ - Retriever: &fileretriever.Retriever{ - Path: "./testdata/ffclient/all_flags/config_flag/flag-config-with-error.yaml", - }, - }, - valid: false, - jsonOutput: "./testdata/ffclient/all_flags/marshal_json/error_in_flag_0.json", - initModule: true, - }, { name: "module not init", config: Config{