Skip to content

Commit

Permalink
provider/aws: Security Rules drift and sorting changes
Browse files Browse the repository at this point in the history
This commit adds failing tests to demonstrate the problem presented with AWS
aggregating the security group rules
  • Loading branch information
catsby authored and bigkraig committed Mar 1, 2016
1 parent d91432f commit 0cfff67
Show file tree
Hide file tree
Showing 3 changed files with 1,135 additions and 3 deletions.
235 changes: 232 additions & 3 deletions builtin/providers/aws/resource_aws_security_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ func resourceAwsSecurityGroupCreate(d *schema.ResourceData, meta interface{}) er
}

// AWS defaults all Security Groups to have an ALLOW ALL egress rule. Here we
// revoke that rule, so users don't unknowningly have/use it.
// revoke that rule, so users don't unknowingly have/use it.
group := resp.(*ec2.SecurityGroup)
if group.VpcId != nil && *group.VpcId != "" {
log.Printf("[DEBUG] Revoking default egress rule for Security Group for %s", d.Id())
Expand Down Expand Up @@ -273,13 +273,26 @@ func resourceAwsSecurityGroupRead(d *schema.ResourceData, meta interface{}) erro

sg := sgRaw.(*ec2.SecurityGroup)

ingressRules := resourceAwsSecurityGroupIPPermGather(d.Id(), sg.IpPermissions)
egressRules := resourceAwsSecurityGroupIPPermGather(d.Id(), sg.IpPermissionsEgress)
remoteIngressRules := resourceAwsSecurityGroupIPPermGather(d.Id(), sg.IpPermissions)
remoteEgressRules := resourceAwsSecurityGroupIPPermGather(d.Id(), sg.IpPermissionsEgress)

//
// TODO enforce the seperation of ips and security_groups in a rule block
//

localIngressRules := d.Get("ingress").(*schema.Set).List()
localEgressRules := d.Get("egress").(*schema.Set).List()

// Loop through the local state of rules, doing a match against the remote
// ruleSet we built above.
ingressRules := matchRules("ingress", localIngressRules, remoteIngressRules)
egressRules := matchRules("egress", localEgressRules, remoteEgressRules)

d.Set("description", sg.Description)
d.Set("name", sg.GroupName)
d.Set("vpc_id", sg.VpcId)
d.Set("owner_id", sg.OwnerId)

if err := d.Set("ingress", ingressRules); err != nil {
log.Printf("[WARN] Error setting Ingress rule set for (%s): %s", d.Id(), err)
}
Expand Down Expand Up @@ -593,3 +606,219 @@ func SGStateRefreshFunc(conn *ec2.EC2, id string) resource.StateRefreshFunc {
return group, "exists", nil
}
}

// matchRules receives the group id, type of rules, and the local / remote maps
// of rules. We iterate through the local set of rules trying to find a matching
// remote rule, which may be structured differently because of how AWS
// aggregates the rules under the to, from, and type.
//
//
// Matching rules are written to state, with their elements removed from the
// remote set
//
// If no match is found, we'll write the remote rule to state and let the graph
// sort things out
func matchRules(rType string, local []interface{}, remote []map[string]interface{}) []map[string]interface{} {
// For each local ip or security_group, we need to match against the remote
// ruleSet until all ips or security_groups are found

// saves represents the rules that have been identified to be saved to state,
// in the appropriate d.Set("{ingress,egress}") call.
var saves []map[string]interface{}
for _, raw := range local {
l := raw.(map[string]interface{})

var selfVal bool
if v, ok := l["self"]; ok {
selfVal = v.(bool)
}

// matching against self is required to detect rules that only include self
// as the rule. resourceAwsSecurityGroupIPPermGather parses the group out
// and replaces it with self if it's ID is found
localHash := idHash(rType, l["protocol"].(string), int64(l["to_port"].(int)), int64(l["from_port"].(int)), selfVal)

// loop remote rules, looking for a matching hash
for _, r := range remote {
var remoteSelfVal bool
if v, ok := r["self"]; ok {
remoteSelfVal = v.(bool)
}

// hash this remote rule and compare it for a match consideration with the
// local rule we're examining
rHash := idHash(rType, r["protocol"].(string), r["to_port"].(int64), r["from_port"].(int64), remoteSelfVal)
if rHash == localHash {
var numExpectedCidrs, numExpectedSGs, numRemoteCidrs, numRemoteSGs int
var matchingCidrs []string
var matchingSGs []string

// grab the local/remote cidr and sg groups, capturing the expected and
// actual counts
lcRaw, ok := l["cidr_blocks"]
if ok {
numExpectedCidrs = len(l["cidr_blocks"].([]interface{}))
}
lsRaw, ok := l["security_groups"]
if ok {
numExpectedSGs = len(l["security_groups"].(*schema.Set).List())
}

rcRaw, ok := r["cidr_blocks"]
if ok {
numRemoteCidrs = len(r["cidr_blocks"].([]string))
}

rsRaw, ok := r["security_groups"]
if ok {
numRemoteSGs = len(r["security_groups"].(*schema.Set).List())
}

// check some early failures
if numExpectedCidrs > numRemoteCidrs {
log.Printf("[DEBUG] Local rule has more CIDR blocks, continuing (%d/%d)", numExpectedCidrs, numRemoteCidrs)
continue
}
if numExpectedSGs > numRemoteSGs {
log.Printf("[DEBUG] Local rule has more Security Groups, continuing (%d/%d)", numExpectedSGs, numRemoteSGs)
continue
}

// match CIDRs by converting both to sets, and using Set methods
var localCidrs []interface{}
if lcRaw != nil {
localCidrs = lcRaw.([]interface{})
}
localCidrSet := schema.NewSet(schema.HashString, localCidrs)

// remote cidrs are presented as a slice of strings, so we need to
// reformat them into a slice of interfaces to be used in creating the
// remote cidr set
var remoteCidrs []string
if rcRaw != nil {
remoteCidrs = rcRaw.([]string)
}
// convert remote cidrs to a set, for easy comparisions
var list []interface{}
for _, s := range remoteCidrs {
list = append(list, s)
}
remoteCidrSet := schema.NewSet(schema.HashString, list)

// Build up a list of local cidrs that are found in the remote set
for _, s := range localCidrSet.List() {
if remoteCidrSet.Contains(s) {
matchingCidrs = append(matchingCidrs, s.(string))
}
}

// match SGs. Both local and remote are already sets
var localSGSet *schema.Set
if lsRaw == nil {
localSGSet = schema.NewSet(schema.HashString, nil)
} else {
localSGSet = lsRaw.(*schema.Set)
}

var remoteSGSet *schema.Set
if rsRaw == nil {
remoteSGSet = schema.NewSet(schema.HashString, nil)
} else {
remoteSGSet = rsRaw.(*schema.Set)
}

// Build up a list of local security groups that are found in the remote set
for _, s := range localSGSet.List() {
if remoteSGSet.Contains(s) {
matchingSGs = append(matchingSGs, s.(string))
}
}

// compare equalities for matches.
// If we found the number of cidrs and number of sgs, we declare a
// match, and then remove those elements from the remote rule, so that
// this remote rule can still be considered by other local rules
if numExpectedCidrs == len(matchingCidrs) {
if numExpectedSGs == len(matchingSGs) {
// confirm that self references match
var lSelf bool
var rSelf bool
if _, ok := l["self"]; ok {
lSelf = l["self"].(bool)
}
if _, ok := r["self"]; ok {
rSelf = r["self"].(bool)
}
if rSelf == lSelf {
delete(r, "self")
// pop local cidrs from remote
diffCidr := remoteCidrSet.Difference(localCidrSet)
var newCidr []string
for _, cRaw := range diffCidr.List() {
newCidr = append(newCidr, cRaw.(string))
}

// reassigning
if len(newCidr) > 0 {
r["cidr_blocks"] = newCidr
} else {
delete(r, "cidr_blocks")
}

// pop local sgs from remote
diffSGs := remoteSGSet.Difference(localSGSet)
if len(diffSGs.List()) > 0 {
r["security_groups"] = diffSGs
} else {
delete(r, "security_groups")
}

saves = append(saves, l)
}
}
}
}
}
}

// Here we catch any remote rules that have not been stripped of all self,
// cidrs, and security groups. We'll add remote rules here that have not been
// matched locally, and let the graph sort things out. This will happen when
// rules are added externally to Terraform
for _, r := range remote {
var lenCidr, lenSGs int
if rCidrs, ok := r["cidr_blocks"]; ok {
lenCidr = len(rCidrs.([]string))
}

if rawSGs, ok := r["security_groups"]; ok {
lenSGs = len(rawSGs.(*schema.Set).List())
}

if _, ok := r["self"]; ok {
if r["self"].(bool) == true {
lenSGs++
}
}

if lenSGs+lenCidr > 0 {
log.Printf("[DEBUG] Found a remote Rule that wasn't empty: (%#v)", r)
saves = append(saves, r)
}
}

return saves
}

// Creates a unique hash for the type, ports, and protocol, used as a key in
// maps
func idHash(rType, protocol string, toPort, fromPort int64, self bool) string {
var buf bytes.Buffer
buf.WriteString(fmt.Sprintf("%s-", rType))
buf.WriteString(fmt.Sprintf("%d-", toPort))
buf.WriteString(fmt.Sprintf("%d-", fromPort))
buf.WriteString(fmt.Sprintf("%s-", protocol))
buf.WriteString(fmt.Sprintf("%t-", self))

return fmt.Sprintf("rule-%d", hashcode.String(buf.String()))
}
Loading

0 comments on commit 0cfff67

Please sign in to comment.