diff --git a/lwgenerate/aws/aws.go b/lwgenerate/aws/aws.go index dfcbee7d7..4b30618fc 100644 --- a/lwgenerate/aws/aws.go +++ b/lwgenerate/aws/aws.go @@ -296,6 +296,15 @@ type GenerateAwsTfConfigurationArgs struct { // Default AWS Provider Tags ProviderDefaultTags map[string]interface{} + + // Add custom blocks to the root `terraform{}` block. Can be used for advanced configuration. Things like backend, etc + ExtraBlocksRootTerraform []*hclwrite.Block + + // ExtraProviderArguments allows adding more arguments to the provider block as needed (custom use cases) + ExtraProviderArguments map[string]interface{} + + // ExtraBlocks allows adding more hclwrite.Block to the root terraform document (advanced use cases) + ExtraBlocks []*hclwrite.Block } func (args *GenerateAwsTfConfigurationArgs) IsEmpty() bool { @@ -732,6 +741,29 @@ func WithS3BucketNotification(s3BucketNotifiaction bool) AwsTerraformModifier { } } +// WithExtraRootBlocks allows adding generic hcl blocks to the root `terraform{}` block +// this enables custom use cases +func WithExtraRootBlocks(blocks []*hclwrite.Block) AwsTerraformModifier { + return func(c *GenerateAwsTfConfigurationArgs) { + c.ExtraBlocksRootTerraform = blocks + } +} + +// WithExtraProviderArguments enables adding additional arguments into the `aws` provider block +// this enables custom use cases +func WithExtraProviderArguments(arguments map[string]interface{}) AwsTerraformModifier { + return func(c *GenerateAwsTfConfigurationArgs) { + c.ExtraProviderArguments = arguments + } +} + +// WithExtraBlocks enables adding additional arbitrary blocks to the root hcl document +func WithExtraBlocks(blocks []*hclwrite.Block) AwsTerraformModifier { + return func(c *GenerateAwsTfConfigurationArgs) { + c.ExtraBlocks = blocks + } +} + // Generate new Terraform code based on the supplied args. func (args *GenerateAwsTfConfigurationArgs) Generate() (string, error) { // Validate inputs @@ -740,7 +772,7 @@ func (args *GenerateAwsTfConfigurationArgs) Generate() (string, error) { } // Create blocks - requiredProviders, err := createRequiredProviders() + requiredProviders, err := createRequiredProviders(args.ExtraBlocksRootTerraform) if err != nil { return "", errors.Wrap(err, "failed to generate required providers") } @@ -788,13 +820,16 @@ func (args *GenerateAwsTfConfigurationArgs) Generate() (string, error) { configModule, cloudTrailModule, agentlessModule, - outputBlocks), + outputBlocks, + args.ExtraBlocks, + ), ) return hclBlocks, nil } -func createRequiredProviders() (*hclwrite.Block, error) { - return lwgenerate.CreateRequiredProviders( +func createRequiredProviders(extraBlocks []*hclwrite.Block) (*hclwrite.Block, error) { + return lwgenerate.CreateRequiredProvidersWithCustomBlocks( + extraBlocks, lwgenerate.NewRequiredProvider("lacework", lwgenerate.HclRequiredProviderWithSource(lwgenerate.LaceworkProviderSource), lwgenerate.HclRequiredProviderWithVersion(lwgenerate.LaceworkProviderVersion))) @@ -803,11 +838,16 @@ func createRequiredProviders() (*hclwrite.Block, error) { func createAwsProvider(args *GenerateAwsTfConfigurationArgs) ([]*hclwrite.Block, error) { blocks := []*hclwrite.Block{} - attributes := map[string]interface{}{ - "alias": "main", - "region": args.AwsRegion, + attributes := map[string]interface{}{} + + // set custom args before the required ones below to ensure expected behavior (i.e., no overrides) + for k, v := range args.ExtraProviderArguments { + attributes[k] = v } + // required defaults + attributes["alias"] = "main" + attributes["region"] = args.AwsRegion if args.AwsProfile != "" { attributes["profile"] = args.AwsProfile } diff --git a/lwgenerate/aws/aws_test.go b/lwgenerate/aws/aws_test.go index 55b3c2063..3c445947b 100644 --- a/lwgenerate/aws/aws_test.go +++ b/lwgenerate/aws/aws_test.go @@ -4,6 +4,7 @@ import ( "fmt" "testing" + "github.com/hashicorp/hcl/v2/hclwrite" "github.com/lacework/go-sdk/lwgenerate" "github.com/stretchr/testify/assert" ) @@ -91,6 +92,35 @@ func TestGenerationConfig(t *testing.T) { assert.Equal(t, reqProviderAndRegion(moduleImportConfig), hcl) } +func TestGenerationConfigWithExtraBlocks(t *testing.T) { + extraBlock, err := lwgenerate.HclCreateGenericBlock("variable", []string{"var_name"}, nil) + assert.NoError(t, err) + + hcl, err := NewTerraform(false, false, true, false, + WithAwsRegion("us-east-2"), WithExtraBlocks([]*hclwrite.Block{extraBlock})).Generate() + assert.Nil(t, err) + assert.NotNil(t, hcl) + assert.Equal(t, requiredProviders+"\n"+awsProvider+"\n"+moduleImportConfig+"\n"+testVariable, hcl) +} + +func TestGenerationConfigWithCustomBackendBlock(t *testing.T) { + customBlock, err := lwgenerate.HclCreateGenericBlock("backend", []string{"s3"}, nil) + assert.NoError(t, err) + hcl, err := NewTerraform(false, false, true, false, WithAwsRegion("us-east-2"), + WithExtraRootBlocks([]*hclwrite.Block{customBlock})).Generate() + assert.Nil(t, err) + assert.NotNil(t, hcl) + assert.Equal(t, requiredProvidersWithCustomBlock+"\n"+awsProvider+"\n"+moduleImportConfig, hcl) +} + +func TestGenerationConfigWithCustomProviderAttributes(t *testing.T) { + hcl, err := NewTerraform(false, false, true, false, WithAwsRegion("us-east-2"), + WithExtraProviderArguments(map[string]interface{}{"foo": "bar"})).Generate() + assert.Nil(t, err) + assert.NotNil(t, hcl) + assert.Equal(t, requiredProviders+"\n"+awsProviderExtraArguments+"\n"+moduleImportConfig, hcl) +} + func TestGenerationConfigWithOutputs(t *testing.T) { hcl, err := NewTerraform( false, false, true, false, WithAwsRegion("us-east-2"), @@ -390,6 +420,18 @@ func TestGenerationCloudTrailS3BucketNotification(t *testing.T) { ) } +var requiredProvidersWithCustomBlock = `terraform { + required_providers { + lacework = { + source = "lacework/lacework" + version = "~> 1.0" + } + } + backend "s3" { + } +} +` + var requiredProviders = `terraform { required_providers { lacework = { @@ -399,6 +441,12 @@ var requiredProviders = `terraform { } } ` +var awsProviderExtraArguments = `provider "aws" { + alias = "main" + foo = "bar" + region = "us-east-2" +} +` var awsProvider = `provider "aws" { alias = "main" @@ -826,3 +874,7 @@ var moduleImportCtWithS3BucketNotification = `module "main_cloudtrail" { } } ` + +var testVariable = `variable "var_name" { +} +` diff --git a/lwgenerate/hcl.go b/lwgenerate/hcl.go index cf3c52a48..5c796ae70 100644 --- a/lwgenerate/hcl.go +++ b/lwgenerate/hcl.go @@ -556,13 +556,13 @@ func CreateHclStringOutput(blocks []*hclwrite.Block) string { return string(file.Bytes()) } -// CreateRequiredProviders Create required providers block -func CreateRequiredProviders(providers ...*HclRequiredProvider) (*hclwrite.Block, error) { - block, err := HclCreateGenericBlock("terraform", nil, nil) - if err != nil { - return nil, err - } +// rootTerraformBlock is a helper that creates the literal `terraform{}` hcl block +func rootTerraformBlock() (*hclwrite.Block, error) { + return HclCreateGenericBlock("terraform", nil, nil) +} +// createRequiredProviders is a helper that creates the `required_providers` hcl block +func createRequiredProviders(providers ...*HclRequiredProvider) (*hclwrite.Block, error) { providerDetails := map[string]interface{}{} for _, provider := range providers { details := map[string]interface{}{} @@ -579,7 +579,45 @@ func CreateRequiredProviders(providers ...*HclRequiredProvider) (*hclwrite.Block if err != nil { return nil, err } + + return requiredProviders, nil +} + +// CreateRequiredProviders Create required providers block +func CreateRequiredProviders(providers ...*HclRequiredProvider) (*hclwrite.Block, error) { + block, err := rootTerraformBlock() + if err != nil { + return nil, err + } + + requiredProviders, err := createRequiredProviders(providers...) + if err != nil { + return nil, err + } + + block.Body().AppendBlock(requiredProviders) + return block, nil +} + +// CreateRequiredProviders Create required providers block +func CreateRequiredProvidersWithCustomBlocks( + blocks []*hclwrite.Block, + providers ...*HclRequiredProvider, +) (*hclwrite.Block, error) { + block, err := rootTerraformBlock() + if err != nil { + return nil, err + } + + requiredProviders, err := createRequiredProviders(providers...) + if err != nil { + return nil, err + } + block.Body().AppendBlock(requiredProviders) + for _, customBlock := range blocks { + block.Body().AppendBlock(customBlock) + } return block, nil } diff --git a/lwgenerate/hcl_test.go b/lwgenerate/hcl_test.go index 9c360e4b0..4263e24c7 100644 --- a/lwgenerate/hcl_test.go +++ b/lwgenerate/hcl_test.go @@ -149,6 +149,22 @@ func TestRequiredProvidersBlock(t *testing.T) { assert.Equal(t, testRequiredProvider, lwgenerate.CreateHclStringOutput([]*hclwrite.Block{data})) } +func TestRequiredProvidersBlockWithCustomBlocks(t *testing.T) { + provider1 := lwgenerate.NewRequiredProvider("foo", + lwgenerate.HclRequiredProviderWithSource("test/test")) + provider2 := lwgenerate.NewRequiredProvider("bar", + lwgenerate.HclRequiredProviderWithVersion("~> 0.1")) + provider3 := lwgenerate.NewRequiredProvider("lacework", + lwgenerate.HclRequiredProviderWithSource("lacework/lacework"), + lwgenerate.HclRequiredProviderWithVersion("~> 0.1")) + + customBlock, err := lwgenerate.HclCreateGenericBlock("backend", []string{"s3"}, nil) + assert.NoError(t, err) + data, err := lwgenerate.CreateRequiredProvidersWithCustomBlocks([]*hclwrite.Block{customBlock}, provider1, provider2, provider3) + assert.Nil(t, err) + assert.Equal(t, testRequiredProviderWithCustomBlocks, lwgenerate.CreateHclStringOutput([]*hclwrite.Block{data})) +} + func TestModuleBlockWithComplexAttributes(t *testing.T) { data, err := lwgenerate.NewModule("foo", "mycorp/mycloud", @@ -192,6 +208,24 @@ func TestOutputBlockCreation(t *testing.T) { }) } +var testRequiredProviderWithCustomBlocks = `terraform { + required_providers { + bar = { + version = "~> 0.1" + } + foo = { + source = "test/test" + } + lacework = { + source = "lacework/lacework" + version = "~> 0.1" + } + } + backend "s3" { + } +} +` + var testRequiredProvider = `terraform { required_providers { bar = {