Skip to content

Commit

Permalink
Merge pull request hashicorp#34885 from OpenGLShaders/f-aws_db_cluste…
Browse files Browse the repository at this point in the history
…r_snapshot_sharing

Allow cross-account sharing of db_cluster_snapshot
  • Loading branch information
ewbankkit authored Jul 23, 2024
2 parents a83c617 + cbf21dd commit 0d6de4e
Show file tree
Hide file tree
Showing 11 changed files with 220 additions and 97 deletions.
3 changes: 3 additions & 0 deletions .changelog/34885.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/aws_db_cluster_snapshot: Add `shared_accounts` argument
```
186 changes: 138 additions & 48 deletions internal/service/rds/cluster_snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,28 @@ import (
"time"

"github.com/YakDriver/regexache"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/rds"
"github.com/hashicorp/aws-sdk-go-base/v2/awsv1shim/v2/tfawserr"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/rds"
"github.com/aws/aws-sdk-go-v2/service/rds/types"
"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/retry"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
"github.com/hashicorp/terraform-provider-aws/internal/conns"
"github.com/hashicorp/terraform-provider-aws/internal/errs"
"github.com/hashicorp/terraform-provider-aws/internal/errs/sdkdiag"
"github.com/hashicorp/terraform-provider-aws/internal/flex"
tfslices "github.com/hashicorp/terraform-provider-aws/internal/slices"
tftags "github.com/hashicorp/terraform-provider-aws/internal/tags"
"github.com/hashicorp/terraform-provider-aws/internal/tfresource"
"github.com/hashicorp/terraform-provider-aws/internal/verify"
"github.com/hashicorp/terraform-provider-aws/names"
)

const clusterSnapshotCreateTimeout = 2 * time.Minute

// @SDKResource("aws_db_cluster_snapshot", name="DB Cluster Snapshot")
// @Tags(identifierAttribute="db_cluster_snapshot_arn")
// @Testing(tagsTest=false)
func ResourceClusterSnapshot() *schema.Resource {
func resourceClusterSnapshot() *schema.Resource {
return &schema.Resource{
CreateWithoutTimeout: resourceClusterSnapshotCreate,
ReadWithoutTimeout: resourceClusterSnapshotRead,
Expand Down Expand Up @@ -96,6 +96,11 @@ func ResourceClusterSnapshot() *schema.Resource {
Type: schema.TypeInt,
Computed: true,
},
"shared_accounts": {
Type: schema.TypeSet,
Optional: true,
Elem: &schema.Schema{Type: schema.TypeString},
},
"source_db_cluster_snapshot_arn": {
Type: schema.TypeString,
Computed: true,
Expand Down Expand Up @@ -126,18 +131,21 @@ func ResourceClusterSnapshot() *schema.Resource {

func resourceClusterSnapshotCreate(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
var diags diag.Diagnostics
conn := meta.(*conns.AWSClient).RDSConn(ctx)
conn := meta.(*conns.AWSClient).RDSClient(ctx)

id := d.Get("db_cluster_snapshot_identifier").(string)
input := &rds.CreateDBClusterSnapshotInput{
DBClusterIdentifier: aws.String(d.Get("db_cluster_identifier").(string)),
DBClusterSnapshotIdentifier: aws.String(id),
Tags: getTagsIn(ctx),
Tags: getTagsInV2(ctx),
}

_, err := tfresource.RetryWhenAWSErrCodeEquals(ctx, clusterSnapshotCreateTimeout, func() (interface{}, error) {
return conn.CreateDBClusterSnapshotWithContext(ctx, input)
}, rds.ErrCodeInvalidDBClusterStateFault)
const (
timeout = 2 * time.Minute
)
_, err := tfresource.RetryWhenIsA[*types.InvalidDBClusterStateFault](ctx, timeout, func() (interface{}, error) {
return conn.CreateDBClusterSnapshot(ctx, input)
})

if err != nil {
return sdkdiag.AppendErrorf(diags, "creating RDS DB Cluster Snapshot (%s): %s", id, err)
Expand All @@ -149,14 +157,28 @@ func resourceClusterSnapshotCreate(ctx context.Context, d *schema.ResourceData,
return sdkdiag.AppendErrorf(diags, "waiting for RDS DB Cluster Snapshot (%s) create: %s", d.Id(), err)
}

if v, ok := d.GetOk("shared_accounts"); ok && v.(*schema.Set).Len() > 0 {
input := &rds.ModifyDBClusterSnapshotAttributeInput{
AttributeName: aws.String(clusterSnapshotAttributeNameRestore),
DBClusterSnapshotIdentifier: aws.String(d.Id()),
ValuesToAdd: flex.ExpandStringValueSet(v.(*schema.Set)),
}

_, err := conn.ModifyDBClusterSnapshotAttribute(ctx, input)

if err != nil {
return sdkdiag.AppendErrorf(diags, "modifying RDS DB Cluster Snapshot (%s) attribute: %s", d.Id(), err)
}
}

return append(diags, resourceClusterSnapshotRead(ctx, d, meta)...)
}

func resourceClusterSnapshotRead(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
var diags diag.Diagnostics
conn := meta.(*conns.AWSClient).RDSConn(ctx)
conn := meta.(*conns.AWSClient).RDSClient(ctx)

snapshot, err := FindDBClusterSnapshotByID(ctx, conn, d.Id())
snapshot, err := findDBClusterSnapshotByID(ctx, conn, d.Id())

if !d.IsNewResource() && tfresource.NotFound(err) {
log.Printf("[WARN] RDS DB Cluster Snapshot (%s) not found, removing from state", d.Id())
Expand All @@ -169,7 +191,7 @@ func resourceClusterSnapshotRead(ctx context.Context, d *schema.ResourceData, me
}

d.Set(names.AttrAllocatedStorage, snapshot.AllocatedStorage)
d.Set(names.AttrAvailabilityZones, aws.StringValueSlice(snapshot.AvailabilityZones))
d.Set(names.AttrAvailabilityZones, snapshot.AvailabilityZones)
d.Set("db_cluster_identifier", snapshot.DBClusterIdentifier)
d.Set("db_cluster_snapshot_arn", snapshot.DBClusterSnapshotArn)
d.Set("db_cluster_snapshot_identifier", snapshot.DBClusterSnapshotIdentifier)
Expand All @@ -184,29 +206,55 @@ func resourceClusterSnapshotRead(ctx context.Context, d *schema.ResourceData, me
d.Set(names.AttrStorageEncrypted, snapshot.StorageEncrypted)
d.Set(names.AttrVPCID, snapshot.VpcId)

setTagsOut(ctx, snapshot.TagList)
attribute, err := findDBClusterSnapshotAttributeByTwoPartKey(ctx, conn, d.Id(), clusterSnapshotAttributeNameRestore)
switch {
case err == nil:
d.Set("shared_accounts", attribute.AttributeValues)
case tfresource.NotFound(err):
default:
return sdkdiag.AppendErrorf(diags, "reading RDS DB Cluster Snapshot (%s) attribute: %s", d.Id(), err)
}

setTagsOutV2(ctx, snapshot.TagList)

return diags
}

func resourceClusterSnapshotUpdate(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
var diags diag.Diagnostics
conn := meta.(*conns.AWSClient).RDSClient(ctx)

if d.HasChange("shared_accounts") {
o, n := d.GetChange("shared_accounts")
os, ns := o.(*schema.Set), n.(*schema.Set)
add, del := ns.Difference(os), os.Difference(ns)
input := &rds.ModifyDBClusterSnapshotAttributeInput{
AttributeName: aws.String(clusterSnapshotAttributeNameRestore),
DBClusterSnapshotIdentifier: aws.String(d.Id()),
ValuesToAdd: flex.ExpandStringValueSet(add),
ValuesToRemove: flex.ExpandStringValueSet(del),
}

// Tags only.
_, err := conn.ModifyDBClusterSnapshotAttribute(ctx, input)

if err != nil {
return sdkdiag.AppendErrorf(diags, "modifying RDS DB Cluster Snapshot (%s) attribute: %s", d.Id(), err)
}
}

return append(diags, resourceClusterSnapshotRead(ctx, d, meta)...)
}

func resourceClusterSnapshotDelete(ctx context.Context, d *schema.ResourceData, meta interface{}) diag.Diagnostics {
var diags diag.Diagnostics
conn := meta.(*conns.AWSClient).RDSConn(ctx)
conn := meta.(*conns.AWSClient).RDSClient(ctx)

log.Printf("[DEBUG] Deleting RDS DB Cluster Snapshot: %s", d.Id())
_, err := conn.DeleteDBClusterSnapshotWithContext(ctx, &rds.DeleteDBClusterSnapshotInput{
_, err := conn.DeleteDBClusterSnapshot(ctx, &rds.DeleteDBClusterSnapshotInput{
DBClusterSnapshotIdentifier: aws.String(d.Id()),
})

if tfawserr.ErrCodeEquals(err, rds.ErrCodeDBClusterSnapshotNotFoundFault) {
if errs.IsA[*types.DBClusterSnapshotNotFoundFault](err) {
return diags
}

Expand All @@ -217,18 +265,18 @@ func resourceClusterSnapshotDelete(ctx context.Context, d *schema.ResourceData,
return diags
}

func FindDBClusterSnapshotByID(ctx context.Context, conn *rds.RDS, id string) (*rds.DBClusterSnapshot, error) {
func findDBClusterSnapshotByID(ctx context.Context, conn *rds.Client, id string) (*types.DBClusterSnapshot, error) {
input := &rds.DescribeDBClusterSnapshotsInput{
DBClusterSnapshotIdentifier: aws.String(id),
}
output, err := findDBClusterSnapshot(ctx, conn, input, tfslices.PredicateTrue[*rds.DBClusterSnapshot]())
output, err := findDBClusterSnapshot(ctx, conn, input, tfslices.PredicateTrue[*types.DBClusterSnapshot]())

if err != nil {
return nil, err
}

// Eventual consistency check.
if aws.StringValue(output.DBClusterSnapshotIdentifier) != id {
if aws.ToString(output.DBClusterSnapshotIdentifier) != id {
return nil, &retry.NotFoundError{
LastRequest: input,
}
Expand All @@ -237,50 +285,47 @@ func FindDBClusterSnapshotByID(ctx context.Context, conn *rds.RDS, id string) (*
return output, nil
}

func findDBClusterSnapshot(ctx context.Context, conn *rds.RDS, input *rds.DescribeDBClusterSnapshotsInput, filter tfslices.Predicate[*rds.DBClusterSnapshot]) (*rds.DBClusterSnapshot, error) {
func findDBClusterSnapshot(ctx context.Context, conn *rds.Client, input *rds.DescribeDBClusterSnapshotsInput, filter tfslices.Predicate[*types.DBClusterSnapshot]) (*types.DBClusterSnapshot, error) {
output, err := findDBClusterSnapshots(ctx, conn, input, filter)

if err != nil {
return nil, err
}

return tfresource.AssertSinglePtrResult(output)
return tfresource.AssertSingleValueResult(output)
}

func findDBClusterSnapshots(ctx context.Context, conn *rds.RDS, input *rds.DescribeDBClusterSnapshotsInput, filter tfslices.Predicate[*rds.DBClusterSnapshot]) ([]*rds.DBClusterSnapshot, error) {
var output []*rds.DBClusterSnapshot
func findDBClusterSnapshots(ctx context.Context, conn *rds.Client, input *rds.DescribeDBClusterSnapshotsInput, filter tfslices.Predicate[*types.DBClusterSnapshot]) ([]types.DBClusterSnapshot, error) {
var output []types.DBClusterSnapshot

err := conn.DescribeDBClusterSnapshotsPagesWithContext(ctx, input, func(page *rds.DescribeDBClusterSnapshotsOutput, lastPage bool) bool {
if page == nil {
return !lastPage
}
pages := rds.NewDescribeDBClusterSnapshotsPaginator(conn, input)
for pages.HasMorePages() {
page, err := pages.NextPage(ctx)

for _, v := range page.DBClusterSnapshots {
if v != nil && filter(v) {
output = append(output, v)
if errs.IsA[*types.DBClusterSnapshotNotFoundFault](err) {
return nil, &retry.NotFoundError{
LastError: err,
LastRequest: input,
}
}

return !lastPage
})

if tfawserr.ErrCodeEquals(err, rds.ErrCodeDBClusterSnapshotNotFoundFault) {
return nil, &retry.NotFoundError{
LastError: err,
LastRequest: input,
if err != nil {
return nil, err
}
}

if err != nil {
return nil, err
for _, v := range page.DBClusterSnapshots {
if filter(&v) {
output = append(output, v)
}
}
}

return output, nil
}

func statusDBClusterSnapshot(ctx context.Context, conn *rds.RDS, id string) retry.StateRefreshFunc {
func statusDBClusterSnapshot(ctx context.Context, conn *rds.Client, id string) retry.StateRefreshFunc {
return func() (interface{}, string, error) {
output, err := FindDBClusterSnapshotByID(ctx, conn, id)
output, err := findDBClusterSnapshotByID(ctx, conn, id)

if tfresource.NotFound(err) {
return nil, "", nil
Expand All @@ -290,11 +335,11 @@ func statusDBClusterSnapshot(ctx context.Context, conn *rds.RDS, id string) retr
return nil, "", err
}

return output, aws.StringValue(output.Status), nil
return output, aws.ToString(output.Status), nil
}
}

func waitDBClusterSnapshotCreated(ctx context.Context, conn *rds.RDS, id string, timeout time.Duration) (*rds.DBClusterSnapshot, error) {
func waitDBClusterSnapshotCreated(ctx context.Context, conn *rds.Client, id string, timeout time.Duration) (*types.DBClusterSnapshot, error) {
stateConf := &retry.StateChangeConf{
Pending: []string{clusterSnapshotStatusCreating},
Target: []string{clusterSnapshotStatusAvailable},
Expand All @@ -306,9 +351,54 @@ func waitDBClusterSnapshotCreated(ctx context.Context, conn *rds.RDS, id string,

outputRaw, err := stateConf.WaitForStateContext(ctx)

if output, ok := outputRaw.(*rds.DBClusterSnapshot); ok {
if output, ok := outputRaw.(*types.DBClusterSnapshot); ok {
return output, err
}

return nil, err
}

func findDBClusterSnapshotAttributeByTwoPartKey(ctx context.Context, conn *rds.Client, id, attributeName string) (*types.DBClusterSnapshotAttribute, error) {
input := &rds.DescribeDBClusterSnapshotAttributesInput{
DBClusterSnapshotIdentifier: aws.String(id),
}

return findDBClusterSnapshotAttribute(ctx, conn, input, func(v *types.DBClusterSnapshotAttribute) bool {
return aws.ToString(v.AttributeName) == attributeName
})
}

func findDBClusterSnapshotAttribute(ctx context.Context, conn *rds.Client, input *rds.DescribeDBClusterSnapshotAttributesInput, filter tfslices.Predicate[*types.DBClusterSnapshotAttribute]) (*types.DBClusterSnapshotAttribute, error) {
output, err := findDBClusterSnapshotAttributes(ctx, conn, input, filter)

if err != nil {
return nil, err
}

return tfresource.AssertSingleValueResult(output)
}

func findDBClusterSnapshotAttributes(ctx context.Context, conn *rds.Client, input *rds.DescribeDBClusterSnapshotAttributesInput, filter tfslices.Predicate[*types.DBClusterSnapshotAttribute]) ([]types.DBClusterSnapshotAttribute, error) {
output, err := conn.DescribeDBClusterSnapshotAttributes(ctx, input)

if errs.IsA[*types.DBClusterSnapshotNotFoundFault](err) {
return nil, &retry.NotFoundError{
LastError: err,
LastRequest: input,
}
}

if err != nil {
return nil, err
}

if output == nil || output.DBClusterSnapshotAttributesResult == nil {
return nil, tfresource.NewEmptyResultError(input)
}

f := func(v types.DBClusterSnapshotAttribute) bool {
return filter(&v)
}

return tfslices.Filter(output.DBClusterSnapshotAttributesResult.DBClusterSnapshotAttributes, f), nil
}
Loading

0 comments on commit 0d6de4e

Please sign in to comment.