Skip to content

Commit

Permalink
fix: implemented correct addition and removal of network rules from n…
Browse files Browse the repository at this point in the history
…etwork policies
  • Loading branch information
AS-auxmoney committed Apr 25, 2024
1 parent 36fffe7 commit 0998ff4
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 122 deletions.
152 changes: 95 additions & 57 deletions pkg/resources/network_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ func CreateContextNetworkPolicy(ctx context.Context, d *schema.ResourceData, met
return ReadContextNetworkPolicy(ctx, d, meta)
}

// NetworkRulesSnowflakeDTO is needed to unpack the applied network rules from the JSON response from Snowflake
type NetworkRulesSnowflakeDTO struct {
FullyQualifiedRuleName string
}

func ReadContextNetworkPolicy(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
diags := diag.Diagnostics{}
policyName := d.Id()
Expand Down Expand Up @@ -152,7 +157,7 @@ func ReadContextNetworkPolicy(ctx context.Context, d *schema.ResourceData, meta
return diag.FromErr(err)
}
case "ALLOWED_NETWORK_RULE_LIST":
var networkRules []NetworkRules
var networkRules []NetworkRulesSnowflakeDTO
err := json.Unmarshal([]byte(desc.Value), &networkRules)
if err != nil {
return diag.FromErr(err)
Expand All @@ -166,7 +171,7 @@ func ReadContextNetworkPolicy(ctx context.Context, d *schema.ResourceData, meta
return diag.FromErr(err)
}
case "BLOCKED_NETWORK_RULE_LIST":
var networkRules []NetworkRules
var networkRules []NetworkRulesSnowflakeDTO
err := json.Unmarshal([]byte(desc.Value), &networkRules)
if err != nil {
return diag.FromErr(err)
Expand All @@ -186,84 +191,70 @@ func ReadContextNetworkPolicy(ctx context.Context, d *schema.ResourceData, meta
return diags
}

type NetworkRules struct {
FullyQualifiedRuleName string
}

func UpdateContextNetworkPolicy(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
name := d.Id()
client := meta.(*provider.Context).Client
baseReq := sdk.NewAlterNetworkPolicyRequest(sdk.NewAccountObjectIdentifier(name))

if d.HasChange("comment") {
comment := d.Get("comment")
comment := d.Get("comment").(string)
baseReq := sdk.NewAlterNetworkPolicyRequest(sdk.NewAccountObjectIdentifier(name))

if c := comment.(string); c == "" {
if comment == "" {
err := client.NetworkPolicies.Alter(ctx, baseReq.WithUnsetComment(sdk.Bool(true)))
if err != nil {
return diag.Diagnostics{
diag.Diagnostic{
Severity: diag.Error,
Summary: "Error updating network policy",
Detail: fmt.Sprintf("error unsetting comment for network policy %v err = %v", name, err),
},
}
return getUpdateContextDiag("unsetting comment", name, err)
}
} else {
setReq := sdk.NewNetworkPolicySetRequest().WithComment(sdk.String(comment.(string)))
setReq := sdk.NewNetworkPolicySetRequest().WithComment(sdk.String(comment))
err := client.NetworkPolicies.Alter(ctx, baseReq.WithSet(setReq))
if err != nil {
return diag.Diagnostics{
diag.Diagnostic{
Severity: diag.Error,
Summary: "Error updating network policy",
Detail: fmt.Sprintf("error updating comment for network policy %v err = %v", name, err),
},
}
return getUpdateContextDiag("updating comment", name, err)
}
}
}

// TODO: empty network rules (that is unsetting) does not work, as WithUnset is missing.
// Removing the validation in network_policies_validations_gen.go does not solve the problem, as the SDK cannot
// handle empty lists
if d.HasChange("allowed_network_rule_list") {
networkRuleIdentifiers := parseNetworkRulesList(d.Get("allowed_network_rule_list"))
oldList, newList := d.GetChange("allowed_network_rule_list")
addedNetworkRuleIdentifiers, removedNetworkRuleIdentifiers := getAddedAndRemovedIdentifiers(oldList, newList)

var err error
if len(networkRuleIdentifiers) == 0 {
removeReq := sdk.NewRemoveNetworkRuleRequest().WithAllowedNetworkRuleList(networkRuleIdentifiers)
err = client.NetworkPolicies.Alter(ctx, baseReq.WithRemove(removeReq))
} else {
addReq := sdk.NewAddNetworkRuleRequest().WithAllowedNetworkRuleList(networkRuleIdentifiers)
err = client.NetworkPolicies.Alter(ctx, baseReq.WithAdd(addReq))
}
if len(addedNetworkRuleIdentifiers) > 0 {
baseReq := sdk.NewAlterNetworkPolicyRequest(sdk.NewAccountObjectIdentifier(name))
addReq := sdk.NewAddNetworkRuleRequest().WithAllowedNetworkRuleList(addedNetworkRuleIdentifiers)
err := client.NetworkPolicies.Alter(ctx, baseReq.WithAdd(addReq))

if err != nil {
return diag.Diagnostics{
diag.Diagnostic{
Severity: diag.Error,
Summary: "Error updating network policy",
Detail: fmt.Sprintf("error updating ALLOWED_NETWORK_RULE_LIST for network policy %v err = %v", name, err),
},
if err != nil {
return getUpdateContextDiag("adding to ALLOWED_NETWORK_RULE_LIST", name, err)
}
}
if len(removedNetworkRuleIdentifiers) > 0 {
baseReq := sdk.NewAlterNetworkPolicyRequest(sdk.NewAccountObjectIdentifier(name))
removeReq := sdk.NewRemoveNetworkRuleRequest().WithAllowedNetworkRuleList(removedNetworkRuleIdentifiers)
err := client.NetworkPolicies.Alter(ctx, baseReq.WithRemove(removeReq))
if err != nil {
return getUpdateContextDiag("removing from ALLOWED_NETWORK_RULE_LIST", name, err)
}
}
}

// TODO: empty network rules (that is unsetting) does not work, as WithUnset is missing.
// Removing the validation in network_policies_validations_gen.go does not solve the problem, as the SDK cannot
// handle empty lists
if d.HasChange("blocked_network_rule_list") {
networkRuleIdentifiers := parseNetworkRulesList(d.Get("blocked_network_rule_list"))
setReq := sdk.NewNetworkPolicySetRequest().WithBlockedNetworkRuleList(networkRuleIdentifiers)
err := client.NetworkPolicies.Alter(ctx, baseReq.WithSet(setReq))
if err != nil {
return diag.Diagnostics{
diag.Diagnostic{
Severity: diag.Error,
Summary: "Error updating network policy",
Detail: fmt.Sprintf("error updating BLOCKED_NETWORK_RULE_LIST for network policy %v err = %v", name, err),
},
oldList, newList := d.GetChange("blocked_network_rule_list")
addedNetworkRuleIdentifiers, removedNetworkRuleIdentifiers := getAddedAndRemovedIdentifiers(oldList, newList)

if len(addedNetworkRuleIdentifiers) > 0 {
baseReq := sdk.NewAlterNetworkPolicyRequest(sdk.NewAccountObjectIdentifier(name))
addReq := sdk.NewAddNetworkRuleRequest().WithBlockedNetworkRuleList(addedNetworkRuleIdentifiers)
err := client.NetworkPolicies.Alter(ctx, baseReq.WithAdd(addReq))

if err != nil {
return getUpdateContextDiag("adding to BLOCKED_NETWORK_RULE_LIST", name, err)
}
}
if len(removedNetworkRuleIdentifiers) > 0 {
baseReq := sdk.NewAlterNetworkPolicyRequest(sdk.NewAccountObjectIdentifier(name))
removeReq := sdk.NewRemoveNetworkRuleRequest().WithBlockedNetworkRuleList(removedNetworkRuleIdentifiers)
err := client.NetworkPolicies.Alter(ctx, baseReq.WithRemove(removeReq))
if err != nil {
return getUpdateContextDiag("removing from BLOCKED_NETWORK_RULE_LIST", name, err)
}
}
}
Expand All @@ -272,10 +263,12 @@ func UpdateContextNetworkPolicy(ctx context.Context, d *schema.ResourceData, met
// Removing the validation in network_policies_validations_gen.go does not solve the problem, as the SDK cannot
// handle empty lists
if d.HasChange("allowed_ip_list") {
baseReq := sdk.NewAlterNetworkPolicyRequest(sdk.NewAccountObjectIdentifier(name))
ipRequests := parseIPList(d.Get("allowed_ip_list"))
log.Printf("ipRequests: %v", ipRequests)

setReq := sdk.NewNetworkPolicySetRequest().WithAllowedIpList(ipRequests)
err := client.NetworkPolicies.Alter(ctx, baseReq.WithSet(setReq))

if err != nil {
return diag.Diagnostics{
diag.Diagnostic{
Expand All @@ -291,9 +284,12 @@ func UpdateContextNetworkPolicy(ctx context.Context, d *schema.ResourceData, met
// Removing the validation in network_policies_validations_gen.go does not solve the problem, as the SDK cannot
// handle empty lists
if d.HasChange("blocked_ip_list") {
baseReq := sdk.NewAlterNetworkPolicyRequest(sdk.NewAccountObjectIdentifier(name))
ipRequests := parseIPList(d.Get("blocked_ip_list"))

setReq := sdk.NewNetworkPolicySetRequest().WithBlockedIpList(ipRequests)
err := client.NetworkPolicies.Alter(ctx, baseReq.WithSet(setReq))

if err != nil {
return diag.Diagnostics{
diag.Diagnostic{
Expand All @@ -308,6 +304,16 @@ func UpdateContextNetworkPolicy(ctx context.Context, d *schema.ResourceData, met
return ReadContextNetworkPolicy(ctx, d, meta)
}

func getUpdateContextDiag(action string, name string, err error) diag.Diagnostics {
return diag.Diagnostics{
diag.Diagnostic{
Severity: diag.Error,
Summary: "Error updating network policy",
Detail: fmt.Sprintf("error %v for network policy %v err = %v", action, name, err),
},
}
}

func DeleteContextNetworkPolicy(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
name := d.Id()
client := meta.(*provider.Context).Client
Expand Down Expand Up @@ -347,3 +353,35 @@ func parseNetworkRulesList(v interface{}) []sdk.SchemaObjectIdentifier {
}
return networkRuleIdentifiers
}

// getAddedAndRemovedIdentifiers returns the identifiers that were added and removed from the old and new network rule lists.
func getAddedAndRemovedIdentifiers(oldList interface{}, newList interface{}) ([]sdk.SchemaObjectIdentifier, []sdk.SchemaObjectIdentifier) {
oldNetworkRuleIdentifiers := parseNetworkRulesList(oldList)
newNetworkRuleIdentifiers := parseNetworkRulesList(newList)

var addedNetworkRuleIdentifiers []sdk.SchemaObjectIdentifier
var removedNetworkRuleIdentifiers []sdk.SchemaObjectIdentifier

for _, identifier := range oldNetworkRuleIdentifiers {
if !contains(newNetworkRuleIdentifiers, identifier) {
removedNetworkRuleIdentifiers = append(removedNetworkRuleIdentifiers, identifier)
}
}
log.Printf("removedNetworkRuleIdentifiers: %v", removedNetworkRuleIdentifiers)
for _, identifier := range newNetworkRuleIdentifiers {
if !contains(oldNetworkRuleIdentifiers, identifier) {
addedNetworkRuleIdentifiers = append(addedNetworkRuleIdentifiers, identifier)
}
}
return addedNetworkRuleIdentifiers, removedNetworkRuleIdentifiers
}

// contains checks if a given identifier is in a list of identifiers.
func contains(identifierList []sdk.SchemaObjectIdentifier, identifier sdk.SchemaObjectIdentifier) bool {
for _, objectIdentifier := range identifierList {
if objectIdentifier.FullyQualifiedName() == identifier.FullyQualifiedName() {
return true
}
}
return false
}
94 changes: 29 additions & 65 deletions pkg/resources/network_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@ package resources

import (
"context"
"errors"
"fmt"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider"
"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
"log"

Expand Down Expand Up @@ -79,56 +76,6 @@ func NetworkRule() *schema.Resource {
Importer: &schema.ResourceImporter{
StateContext: schema.ImportStatePassthroughContext,
},
CustomizeDiff: customdiff.Sequence(
func(ctx context.Context, diff *schema.ResourceDiff, v interface{}) error {
// Plan time validation for az_mode
// InvalidParameterCombination: Must specify at least two cache nodes in order to specify AZ Mode of 'cross-az'.

ruleTypeRaw, ok := diff.GetOk("type")
if !ok {
return nil
}
ruleType := sdk.NetworkRuleType(ruleTypeRaw.(string))
ruleModeRaw, ok := diff.GetOk("mode")
if !ok {
return nil
}
ruleMode := sdk.NetworkRuleMode(ruleModeRaw.(string))

// TODO: add valueList validators for different rule types
//valueListRaw, ok := diff.GetOk("value_list")
//if !ok {
// return nil
//}
//
//valueList := expandStringList(valueListRaw.(*schema.Set).List())

switch ruleType {
case sdk.NetworkRuleTypeIpv4:
if ruleMode != sdk.NetworkRuleModeIngress {
s := fmt.Sprintf("the network rule mode %s is not supported by the network rule type IPv4. The network rule mode must be one of [INGRESS].", ruleMode)
return errors.New(s)
}
case sdk.NetworkRuleTypeAwsVpcEndpointId:
if ruleMode != sdk.NetworkRuleModeIngress && ruleMode != sdk.NetworkRuleModeInternalStage {
s := fmt.Sprintf("the network rule mode %s is not supported by the network rule type AWSVPCEID. The network rule mode must be one of [INGRESS, INTERNAL_STAGE].", ruleMode)
return errors.New(s)
}
case sdk.NetworkRuleTypeAzureLinkId:
if ruleMode != sdk.NetworkRuleModeIngress {
s := fmt.Sprintf("the network rule mode %s is not supported by the network rule type AZURELINKID. The network rule mode must be one of [INGRESS, INTERNAL_STAGE].", ruleMode)
return errors.New(s)
}
case sdk.NetworkRuleTypeHostPort:
if ruleMode != sdk.NetworkRuleModeEgress {
s := fmt.Sprintf("the network rule mode %s is not supported by the network rule type HOST_PORT. The network rule mode must be one of [EGRESS].", ruleMode)
return errors.New(s)
}
}

return nil
},
),
}
}

Expand Down Expand Up @@ -217,23 +164,40 @@ func ReadContextNetworkRule(ctx context.Context, d *schema.ResourceData, meta in
func UpdateContextNetworkRule(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
client := meta.(*provider.Context).Client
id := helpers.DecodeSnowflakeID(d.Id()).(sdk.SchemaObjectIdentifier)
baseReq := sdk.NewAlterNetworkRuleRequest(id)

if d.HasChange("comment") || d.HasChange("value_list") {
valueList := expandStringList(d.Get("value_list").(*schema.Set).List())
networkRuleValues := make([]sdk.NetworkRuleValue, len(valueList))
for i, v := range valueList {
networkRuleValues[i] = sdk.NetworkRuleValue{Value: v}
valueList := expandStringList(d.Get("value_list").(*schema.Set).List())
networkRuleValues := make([]sdk.NetworkRuleValue, len(valueList))
for i, v := range valueList {
networkRuleValues[i] = sdk.NetworkRuleValue{Value: v}
}
comment := d.Get("comment").(string)

if d.HasChange("value_list") {
baseReq := sdk.NewAlterNetworkRuleRequest(id)
if len(valueList) == 0 {
unsetReq := sdk.NewNetworkRuleUnsetRequest().WithValueList(sdk.Bool(true))
baseReq.WithUnset(unsetReq)
} else {
setReq := sdk.NewNetworkRuleSetRequest(networkRuleValues)
baseReq.WithSet(setReq)
}

// TODO: use sdk.NewNetworkRuleUnsetRequest() if valueList is empty
setReq := sdk.NewNetworkRuleSetRequest(networkRuleValues)
if err := client.NetworkRules.Alter(ctx, baseReq); err != nil {
return diag.FromErr(err)
}
}

if d.HasChange("comment") {
comment := d.Get("comment").(string)
setReq.WithComment(sdk.String(comment))
if d.HasChange("comment") {
baseReq := sdk.NewAlterNetworkRuleRequest(id)
if len(comment) == 0 {
unsetReq := sdk.NewNetworkRuleUnsetRequest().WithComment(sdk.Bool(true))
baseReq.WithUnset(unsetReq)
} else {
setReq := sdk.NewNetworkRuleSetRequest(networkRuleValues).WithComment(sdk.String(comment))
baseReq.WithSet(setReq)
}
if err := client.NetworkRules.Alter(ctx, baseReq.WithSet(setReq)); err != nil {

if err := client.NetworkRules.Alter(ctx, baseReq); err != nil {
return diag.FromErr(err)
}
}
Expand Down

0 comments on commit 0998ff4

Please sign in to comment.