Skip to content

Commit

Permalink
fix: phantom attributes on multi-resources (CLOUD-1843)
Browse files Browse the repository at this point in the history
We were failing to resolve the
`bucket = aws_s3_bucket.nativebucket[0].id` attribute in the following snippet:

    resource "aws_s3_bucket" "nativebucket" {
        count = 1
        bucket = "test"
    }

    resource "aws_s3_bucket_versioning" "nativebucket" {
        count = 1
        bucket = aws_s3_bucket.nativebucket[0].id
        versioning_configuration {
          status = "Enabled"
        }
    }

This is problematic since it is extremely common to use
`count = var.create ? 1 : 0` in terraform modules.

This is fixed by a number of changes:

1.  We currently reference attributes by `LocalName`, which is a list of
    strings.  This cannot represent an accessor that also has numbers in it,
    like `aws_s3_bucket.nativebucket[0].id`.

    A new `accessor` is added to take care of this case, including conversion
    functions to and from `LocalName`.

2.  We take care to keep the trailing part of the `accessor`s around in any
    conversion functions; so that
    `aws_s3_bucket.nativebucket[0].id` gets split into
    `aws_s3_bucket.nativebucket` and `[0].id`.

3.  We refactor the `phantomAttrs` type so it receives a way to tell whether
    or not a resource is a multi-resource (rather than passing this in through
    an argument in `phantomAttrs.add()`).

4.  Finally, now that we have all the machinery from 1-3 available, we can
    adjust `phantomAttrs` to remove one element from the trailing accessors
    when applicable.

I added this snippet as a golden test, and included one for the `for_each`
version as well.
  • Loading branch information
jaspervdj-snyk committed Nov 22, 2023
1 parent fca577d commit 1f77093
Show file tree
Hide file tree
Showing 10 changed files with 358 additions and 46 deletions.
3 changes: 3 additions & 0 deletions changes/unreleased/Fixed-20231122-134818.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Fixed
body: phantom attributes on multi-resources
time: 2023-11-22T13:48:18.832055057+01:00
116 changes: 116 additions & 0 deletions pkg/hcl_interpreter/accessor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package hcl_interpreter

import (
"fmt"
"strconv"
"strings"

"github.com/hashicorp/hcl/v2"
"github.com/zclconf/go-cty/cty"
)

// accessor represents paths in HCL that can contain both string or int parts,
// e.g. "foo.bar[3].qux".
type accessor []interface{}

func (a accessor) toString() string {
buf := &strings.Builder{}
for i, p := range a {
switch p := p.(type) {
case int:
fmt.Fprintf(buf, "[%d]", p)
case string:
if i == 0 {
fmt.Fprintf(buf, "%s", p)
} else {
fmt.Fprintf(buf, ".%s", p)
}
}
}
return buf.String()
}

func stringToAccessor(input string) (accessor, error) {
parts := []interface{}{}
for len(input) > 0 {
if input[0] == '[' {
end := strings.IndexByte(input, ']')
if end < 0 {
return nil, fmt.Errorf("unmatched [")
}
num, err := strconv.Atoi(input[1:end])
if err != nil {
return nil, err
}
parts = append(parts, num)
input = input[end+1:]
if len(input) > 0 && input[0] == '.' {
input = input[1:] // Consume extra '.' after ']'
}
} else {
end := strings.IndexAny(input, ".[")
if end < 0 {
parts = append(parts, input)
input = ""
} else {
parts = append(parts, input[:end])
if input[end] == '.' {
input = input[end+1:]
} else {
input = input[end:]
}
}
}
}
return parts, nil
}

func traversalToAccessor(traversal hcl.Traversal) (accessor, error) {
parts := make(accessor, 0)
for _, traverser := range traversal {
switch t := traverser.(type) {
case hcl.TraverseRoot:
parts = append(parts, t.Name)
case hcl.TraverseAttr:
parts = append(parts, t.Name)
case hcl.TraverseIndex:
val := t.Key
if val.IsKnown() {
if val.Type() == cty.Number {
n := val.AsBigFloat()
if n.IsInt() {
i, _ := n.Int64()
parts = append(parts, int(i))
} else {
return nil, fmt.Errorf("Non-int number type in TraverseIndex")
}
} else if val.Type() == cty.String {
parts = append(parts, val.AsString())
} else {
return nil, fmt.Errorf("Unsupported type in TraverseIndex: %s", val.Type().GoString())
}
} else {
return nil, fmt.Errorf("Unknown value in TraverseIndex")
}
}
}
return parts, nil
}

// toLocalName tries to convert the accessor to a local name starting from the
// front. As soon as a non-string part is encountered, we stop and return the
// trailing accessor as well.
func (a accessor) toLocalName() (LocalName, accessor) {
name := make(LocalName, 0)
trailing := make(accessor, len(a))
copy(trailing, a)
for len(trailing) > 0 {
if str, ok := trailing[0].(string); ok {
name = append(name, str)
trailing = trailing[1:]
} else {
break
}
}
return name, trailing
}
93 changes: 93 additions & 0 deletions pkg/hcl_interpreter/accessor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package hcl_interpreter

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestAccessorToString(t *testing.T) {
tests := []struct {
input accessor
expected string
}{
{
input: accessor{"foo", "bar", 3, "qux"},
expected: "foo.bar[3].qux",
},
}
for i, test := range tests {
t.Run(fmt.Sprintf("case%02d", i), func(t *testing.T) {
actual := test.input.toString()
assert.Equal(t, test.expected, actual)
})
}
}

func TestStringToAccessor(t *testing.T) {
tests := []struct {
input string
expected accessor
err bool
}{
{
input: "foo.bar[3].qux",
expected: accessor{"foo", "bar", 3, "qux"},
},
{
input: "[1][2][3]",
expected: accessor{1, 2, 3},
},
{
input: "foo[3.qux",
err: true,
},
{
input: "foo.bar[three].qux",
err: true,
},
}
for i, test := range tests {
t.Run(fmt.Sprintf("case%02d", i), func(t *testing.T) {
actual, err := stringToAccessor(test.input)
if test.err {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, test.expected, actual)
}
})
}
}

func TestAccessorToLocalName(t *testing.T) {
tests := []struct {
input accessor
expected LocalName
trailing accessor
}{
{
input: accessor{"aws_s3_bucket", "my_bucket", 0, "id"},
expected: LocalName{"aws_s3_bucket", "my_bucket"},
trailing: accessor{0, "id"},
},
{
input: accessor{},
expected: LocalName{},
trailing: accessor{},
},
{
input: accessor{"aws_s3_bucket", "my_bucket", "id"},
expected: LocalName{"aws_s3_bucket", "my_bucket", "id"},
trailing: accessor{},
},
}
for i, test := range tests {
t.Run(fmt.Sprintf("case%02d", i), func(t *testing.T) {
actual, trailing := test.input.toLocalName()
assert.Equal(t, test.expected, actual)
assert.Equal(t, test.trailing, trailing)
})
}
}
14 changes: 7 additions & 7 deletions pkg/hcl_interpreter/hcl_interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (v *Analysis) dependencies(name FullName, term Term) []dependency {
deps := []dependency{}
for _, termDependency := range term.Dependencies() {
traversal := termDependency.Traversal
local, err := TraversalToLocalName(traversal)
local, _, err := TraversalToLocalName(traversal)
if err != nil {
v.badKeys[TraversalToString(traversal)] = struct{}{}
continue
Expand Down Expand Up @@ -224,11 +224,15 @@ type Evaluation struct {
}

func EvaluateAnalysis(analysis *Analysis) (*Evaluation, error) {
isMultiResource := func(name FullName) bool {
meta, ok := analysis.Resources[name.ToString()]
return ok && meta.Multiple
}
eval := &Evaluation{
Analysis: analysis,
Modules: map[string]cty.Value{},
resourceAttributes: map[string]cty.Value{},
phantomAttrs: newPhantomAttrs(),
phantomAttrs: newPhantomAttrs(isMultiResource),
}

for moduleKey := range analysis.Modules {
Expand Down Expand Up @@ -306,11 +310,7 @@ func (v *Evaluation) evaluate() error {
}

v.resourceAttributes[name.ToString()] = val
patchMultiple := false
if resourceMeta, ok := v.Analysis.Resources[name.ToString()]; ok {
patchMultiple = resourceMeta.Multiple
}
val = v.phantomAttrs.add(name, patchMultiple, val)
val = v.phantomAttrs.add(name, val)
singleton := NestVal(name.Local, val)
v.Modules[moduleKey] = MergeVal(v.Modules[moduleKey], singleton)
}
Expand Down
46 changes: 18 additions & 28 deletions pkg/hcl_interpreter/names.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func ChildModuleName(moduleName ModuleName, childName string) ModuleName {
type LocalName []string

var (
// Supported fixed paths can be checked using Equals.
// Supported fixed paths can be checked using Equals.
PathModuleName = LocalName{"path", "module"}
PathRootName = LocalName{"path", "root"}
PathCwdName = LocalName{"path", "cwd"}
Expand All @@ -101,6 +101,14 @@ func (name LocalName) Equals(other LocalName) bool {
return true
}

func (name LocalName) ToAccessor() accessor {
accessor := make(accessor, len(name))
for i := range name {
accessor[i] = name[i]
}
return accessor
}

type FullName struct {
Module ModuleName
Local LocalName
Expand Down Expand Up @@ -217,34 +225,16 @@ func (name FullName) AsResourceName() (*FullName, LocalName) {
return nil, nil
}

// TODO: Refactor to TraversalToName?
func TraversalToLocalName(traversal hcl.Traversal) (LocalName, error) {
parts := make([]string, 0)

for _, traverser := range traversal {
switch t := traverser.(type) {
case hcl.TraverseRoot:
parts = append(parts, t.Name)
case hcl.TraverseAttr:
parts = append(parts, t.Name)
case hcl.TraverseIndex:
val := t.Key
if val.IsKnown() {
if val.Type() == cty.Number {
// The other parts must be trailing accessors.
return parts, nil
} else if val.Type() == cty.String {
parts = append(parts, val.AsString())
} else {
return nil, fmt.Errorf("Unsupported type in TraverseIndex: %s", val.Type().GoString())
}
} else {
return nil, fmt.Errorf("Unknown value in TraverseIndex")
}
}
// TraversalToLocalName returns the leading LocalName (strictly-string) part of
// a traversal as well as the trailing part (string-or-int).
func TraversalToLocalName(traversal hcl.Traversal) (LocalName, accessor, error) {
// Just delegate the conversion to the accessor.
accessor, err := traversalToAccessor(traversal)
if err != nil {
return nil, nil, err
}

return parts, nil
name, trailing := accessor.toLocalName()
return name, trailing, nil
}

func TraversalToString(traversal hcl.Traversal) string {
Expand Down
35 changes: 28 additions & 7 deletions pkg/hcl_interpreter/phantom_attrs.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,18 @@ import (
)

type phantomAttrs struct {
// Whether or not this is a multiresource (count or for_each).
isMultiResource func(FullName) bool
// A set of phantom attributes per FullName.
attrs map[string]map[string]struct{}
}

func newPhantomAttrs() *phantomAttrs {
func newPhantomAttrs(
isMultiResource func(FullName) bool,
) *phantomAttrs {
return &phantomAttrs{
attrs: map[string]map[string]struct{}{},
isMultiResource: isMultiResource,
attrs: map[string]map[string]struct{}{},
}
}

Expand All @@ -52,17 +57,32 @@ func (pa *phantomAttrs) analyze(name FullName, term Term) {
term.VisitExpressions(func(expr hcl.Expression) {
exprAttrs := exprAttributes(expr)
for _, traversal := range expr.Variables() {
local, err := TraversalToLocalName(traversal)
local, localTrailing, err := TraversalToLocalName(traversal)
if err != nil {
continue
}

full := FullName{Module: name.Module, Local: local}
if asResourceName, trailing := full.AsResourceName(); asResourceName != nil {
if asResourceName, resourceTrailing := full.AsResourceName(); asResourceName != nil {
attrs := map[string]struct{}{}
if len(trailing) > 0 {

// Construct the trailing part from the resourceTrailing
// and localTrailing parts defined above.
trailingAccessor := resourceTrailing.ToAccessor()
trailingAccessor = append(trailingAccessor, localTrailing...)

// We need to strip one element from the index if this is a
// multi-resource to drop the index or key.
if pa.isMultiResource(*asResourceName) && len(trailingAccessor) > 0 {
trailingAccessor = trailingAccessor[1:]
}

// Store the trailing part that was accessed in attrs.
if trailing, _ := trailingAccessor.toLocalName(); len(trailing) > 0 {
attrs[LocalNameToString(trailing)] = struct{}{}
}

// Store the other attrs.
for _, attr := range exprAttrs {
attrs[LocalNameToString(attr)] = struct{}{}
}
Expand All @@ -81,8 +101,9 @@ func (pa *phantomAttrs) analyze(name FullName, term Term) {
})
}

func (pa *phantomAttrs) add(name FullName, multiple bool, val cty.Value) cty.Value {
func (pa *phantomAttrs) add(name FullName, val cty.Value) cty.Value {
rk := name.ToString()
multiple := pa.isMultiResource(name)

var patch func(LocalName, string, cty.Value) cty.Value
patch = func(local LocalName, ref string, val cty.Value) cty.Value {
Expand Down Expand Up @@ -136,7 +157,7 @@ func exprAttributes(expr hcl.Expression) []LocalName {
hclsyntax.VisitAll(syn, func(node hclsyntax.Node) hcl.Diagnostics {
switch e := node.(type) {
case *hclsyntax.RelativeTraversalExpr:
if name, err := TraversalToLocalName(e.Traversal); err == nil {
if name, _, err := TraversalToLocalName(e.Traversal); err == nil {
names = append(names, name)
}
case *hclsyntax.IndexExpr:
Expand Down
Loading

0 comments on commit 1f77093

Please sign in to comment.