Skip to content

Commit

Permalink
feat(GROW-2908): lwgenerate enable custom root terraform blocks and p…
Browse files Browse the repository at this point in the history
…rovider arguments (#1626)

* feat(GROW-2908): enable adding generic hcl block to root terraform block

* feat(GROW-2908): enable adding custom provider arguments

* feat(GROW-2908): enable setting arbitrary blocks on root hcl doc
  • Loading branch information
Matt Cadorette committed May 16, 2024
1 parent 8c76d48 commit 8d50d31
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 13 deletions.
54 changes: 47 additions & 7 deletions lwgenerate/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,15 @@ type GenerateAwsTfConfigurationArgs struct {

// Default AWS Provider Tags
ProviderDefaultTags map[string]interface{}

// Add custom blocks to the root `terraform{}` block. Can be used for advanced configuration. Things like backend, etc
ExtraBlocksRootTerraform []*hclwrite.Block

// ExtraProviderArguments allows adding more arguments to the provider block as needed (custom use cases)
ExtraProviderArguments map[string]interface{}

// ExtraBlocks allows adding more hclwrite.Block to the root terraform document (advanced use cases)
ExtraBlocks []*hclwrite.Block
}

func (args *GenerateAwsTfConfigurationArgs) IsEmpty() bool {
Expand Down Expand Up @@ -732,6 +741,29 @@ func WithS3BucketNotification(s3BucketNotifiaction bool) AwsTerraformModifier {
}
}

// WithExtraRootBlocks allows adding generic hcl blocks to the root `terraform{}` block
// this enables custom use cases
func WithExtraRootBlocks(blocks []*hclwrite.Block) AwsTerraformModifier {
return func(c *GenerateAwsTfConfigurationArgs) {
c.ExtraBlocksRootTerraform = blocks
}
}

// WithExtraProviderArguments enables adding additional arguments into the `aws` provider block
// this enables custom use cases
func WithExtraProviderArguments(arguments map[string]interface{}) AwsTerraformModifier {
return func(c *GenerateAwsTfConfigurationArgs) {
c.ExtraProviderArguments = arguments
}
}

// WithExtraBlocks enables adding additional arbitrary blocks to the root hcl document
func WithExtraBlocks(blocks []*hclwrite.Block) AwsTerraformModifier {
return func(c *GenerateAwsTfConfigurationArgs) {
c.ExtraBlocks = blocks
}
}

// Generate new Terraform code based on the supplied args.
func (args *GenerateAwsTfConfigurationArgs) Generate() (string, error) {
// Validate inputs
Expand All @@ -740,7 +772,7 @@ func (args *GenerateAwsTfConfigurationArgs) Generate() (string, error) {
}

// Create blocks
requiredProviders, err := createRequiredProviders()
requiredProviders, err := createRequiredProviders(args.ExtraBlocksRootTerraform)
if err != nil {
return "", errors.Wrap(err, "failed to generate required providers")
}
Expand Down Expand Up @@ -788,13 +820,16 @@ func (args *GenerateAwsTfConfigurationArgs) Generate() (string, error) {
configModule,
cloudTrailModule,
agentlessModule,
outputBlocks),
outputBlocks,
args.ExtraBlocks,
),
)
return hclBlocks, nil
}

func createRequiredProviders() (*hclwrite.Block, error) {
return lwgenerate.CreateRequiredProviders(
func createRequiredProviders(extraBlocks []*hclwrite.Block) (*hclwrite.Block, error) {
return lwgenerate.CreateRequiredProvidersWithCustomBlocks(
extraBlocks,
lwgenerate.NewRequiredProvider("lacework",
lwgenerate.HclRequiredProviderWithSource(lwgenerate.LaceworkProviderSource),
lwgenerate.HclRequiredProviderWithVersion(lwgenerate.LaceworkProviderVersion)))
Expand All @@ -803,11 +838,16 @@ func createRequiredProviders() (*hclwrite.Block, error) {
func createAwsProvider(args *GenerateAwsTfConfigurationArgs) ([]*hclwrite.Block, error) {
blocks := []*hclwrite.Block{}

attributes := map[string]interface{}{
"alias": "main",
"region": args.AwsRegion,
attributes := map[string]interface{}{}

// set custom args before the required ones below to ensure expected behavior (i.e., no overrides)
for k, v := range args.ExtraProviderArguments {
attributes[k] = v
}

// required defaults
attributes["alias"] = "main"
attributes["region"] = args.AwsRegion
if args.AwsProfile != "" {
attributes["profile"] = args.AwsProfile
}
Expand Down
52 changes: 52 additions & 0 deletions lwgenerate/aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"testing"

"github.com/hashicorp/hcl/v2/hclwrite"
"github.com/lacework/go-sdk/lwgenerate"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -91,6 +92,35 @@ func TestGenerationConfig(t *testing.T) {
assert.Equal(t, reqProviderAndRegion(moduleImportConfig), hcl)
}

func TestGenerationConfigWithExtraBlocks(t *testing.T) {
extraBlock, err := lwgenerate.HclCreateGenericBlock("variable", []string{"var_name"}, nil)
assert.NoError(t, err)

hcl, err := NewTerraform(false, false, true, false,
WithAwsRegion("us-east-2"), WithExtraBlocks([]*hclwrite.Block{extraBlock})).Generate()
assert.Nil(t, err)
assert.NotNil(t, hcl)
assert.Equal(t, requiredProviders+"\n"+awsProvider+"\n"+moduleImportConfig+"\n"+testVariable, hcl)
}

func TestGenerationConfigWithCustomBackendBlock(t *testing.T) {
customBlock, err := lwgenerate.HclCreateGenericBlock("backend", []string{"s3"}, nil)
assert.NoError(t, err)
hcl, err := NewTerraform(false, false, true, false, WithAwsRegion("us-east-2"),
WithExtraRootBlocks([]*hclwrite.Block{customBlock})).Generate()
assert.Nil(t, err)
assert.NotNil(t, hcl)
assert.Equal(t, requiredProvidersWithCustomBlock+"\n"+awsProvider+"\n"+moduleImportConfig, hcl)
}

func TestGenerationConfigWithCustomProviderAttributes(t *testing.T) {
hcl, err := NewTerraform(false, false, true, false, WithAwsRegion("us-east-2"),
WithExtraProviderArguments(map[string]interface{}{"foo": "bar"})).Generate()
assert.Nil(t, err)
assert.NotNil(t, hcl)
assert.Equal(t, requiredProviders+"\n"+awsProviderExtraArguments+"\n"+moduleImportConfig, hcl)
}

func TestGenerationConfigWithOutputs(t *testing.T) {
hcl, err := NewTerraform(
false, false, true, false, WithAwsRegion("us-east-2"),
Expand Down Expand Up @@ -390,6 +420,18 @@ func TestGenerationCloudTrailS3BucketNotification(t *testing.T) {
)
}

var requiredProvidersWithCustomBlock = `terraform {
required_providers {
lacework = {
source = "lacework/lacework"
version = "~> 1.0"
}
}
backend "s3" {
}
}
`

var requiredProviders = `terraform {
required_providers {
lacework = {
Expand All @@ -399,6 +441,12 @@ var requiredProviders = `terraform {
}
}
`
var awsProviderExtraArguments = `provider "aws" {
alias = "main"
foo = "bar"
region = "us-east-2"
}
`

var awsProvider = `provider "aws" {
alias = "main"
Expand Down Expand Up @@ -826,3 +874,7 @@ var moduleImportCtWithS3BucketNotification = `module "main_cloudtrail" {
}
}
`

var testVariable = `variable "var_name" {
}
`
50 changes: 44 additions & 6 deletions lwgenerate/hcl.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,13 +556,13 @@ func CreateHclStringOutput(blocks []*hclwrite.Block) string {
return string(file.Bytes())
}

// CreateRequiredProviders Create required providers block
func CreateRequiredProviders(providers ...*HclRequiredProvider) (*hclwrite.Block, error) {
block, err := HclCreateGenericBlock("terraform", nil, nil)
if err != nil {
return nil, err
}
// rootTerraformBlock is a helper that creates the literal `terraform{}` hcl block
func rootTerraformBlock() (*hclwrite.Block, error) {
return HclCreateGenericBlock("terraform", nil, nil)
}

// createRequiredProviders is a helper that creates the `required_providers` hcl block
func createRequiredProviders(providers ...*HclRequiredProvider) (*hclwrite.Block, error) {
providerDetails := map[string]interface{}{}
for _, provider := range providers {
details := map[string]interface{}{}
Expand All @@ -579,7 +579,45 @@ func CreateRequiredProviders(providers ...*HclRequiredProvider) (*hclwrite.Block
if err != nil {
return nil, err
}

return requiredProviders, nil
}

// CreateRequiredProviders Create required providers block
func CreateRequiredProviders(providers ...*HclRequiredProvider) (*hclwrite.Block, error) {
block, err := rootTerraformBlock()
if err != nil {
return nil, err
}

requiredProviders, err := createRequiredProviders(providers...)
if err != nil {
return nil, err
}

block.Body().AppendBlock(requiredProviders)
return block, nil
}

// CreateRequiredProviders Create required providers block
func CreateRequiredProvidersWithCustomBlocks(
blocks []*hclwrite.Block,
providers ...*HclRequiredProvider,
) (*hclwrite.Block, error) {
block, err := rootTerraformBlock()
if err != nil {
return nil, err
}

requiredProviders, err := createRequiredProviders(providers...)
if err != nil {
return nil, err
}

block.Body().AppendBlock(requiredProviders)
for _, customBlock := range blocks {
block.Body().AppendBlock(customBlock)
}

return block, nil
}
Expand Down
34 changes: 34 additions & 0 deletions lwgenerate/hcl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,22 @@ func TestRequiredProvidersBlock(t *testing.T) {
assert.Equal(t, testRequiredProvider, lwgenerate.CreateHclStringOutput([]*hclwrite.Block{data}))
}

func TestRequiredProvidersBlockWithCustomBlocks(t *testing.T) {
provider1 := lwgenerate.NewRequiredProvider("foo",
lwgenerate.HclRequiredProviderWithSource("test/test"))
provider2 := lwgenerate.NewRequiredProvider("bar",
lwgenerate.HclRequiredProviderWithVersion("~> 0.1"))
provider3 := lwgenerate.NewRequiredProvider("lacework",
lwgenerate.HclRequiredProviderWithSource("lacework/lacework"),
lwgenerate.HclRequiredProviderWithVersion("~> 0.1"))

customBlock, err := lwgenerate.HclCreateGenericBlock("backend", []string{"s3"}, nil)
assert.NoError(t, err)
data, err := lwgenerate.CreateRequiredProvidersWithCustomBlocks([]*hclwrite.Block{customBlock}, provider1, provider2, provider3)
assert.Nil(t, err)
assert.Equal(t, testRequiredProviderWithCustomBlocks, lwgenerate.CreateHclStringOutput([]*hclwrite.Block{data}))
}

func TestModuleBlockWithComplexAttributes(t *testing.T) {
data, err := lwgenerate.NewModule("foo",
"mycorp/mycloud",
Expand Down Expand Up @@ -192,6 +208,24 @@ func TestOutputBlockCreation(t *testing.T) {
})
}

var testRequiredProviderWithCustomBlocks = `terraform {
required_providers {
bar = {
version = "~> 0.1"
}
foo = {
source = "test/test"
}
lacework = {
source = "lacework/lacework"
version = "~> 0.1"
}
}
backend "s3" {
}
}
`

var testRequiredProvider = `terraform {
required_providers {
bar = {
Expand Down

0 comments on commit 8d50d31

Please sign in to comment.