Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transfer Server VPC Endpoint Type and User Home Directory Type / Mapping #12599

Merged
merged 13 commits into from
Sep 24, 2020
281 changes: 258 additions & 23 deletions aws/resource_aws_transfer_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ package aws
import (
"fmt"
"log"
"regexp"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/transfer"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
Expand Down Expand Up @@ -41,6 +41,7 @@ func resourceAwsTransferServer() *schema.Resource {
Default: transfer.EndpointTypePublic,
ValidateFunc: validation.StringInSlice([]string{
transfer.EndpointTypePublic,
transfer.EndpointTypeVpc,
transfer.EndpointTypeVpcEndpoint,
}, false),
},
Expand All @@ -52,23 +53,48 @@ func resourceAwsTransferServer() *schema.Resource {
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"vpc_endpoint_id": {
Type: schema.TypeString,
Required: true,
ValidateFunc: func(v interface{}, k string) (ws []string, errors []error) {
value := v.(string)
validNamePattern := "^vpce-[0-9a-f]{17}$"
validName, nameMatchErr := regexp.MatchString(validNamePattern, value)
if !validName || nameMatchErr != nil {
errors = append(errors, fmt.Errorf(
"%q must match regex '%v'", k, validNamePattern))
Type: schema.TypeString,
Optional: true,
ConflictsWith: []string{"endpoint_details.0.address_allocation_ids", "endpoint_details.0.subnet_ids", "endpoint_details.0.vpc_id"},
DiffSuppressFunc: func(k, o, n string, d *schema.ResourceData) bool {
if n == "" && d.Get("endpoint_type").(string) == transfer.EndpointTypeVpc {
return true
}
return
return false
},
},
"address_allocation_ids": {
Type: schema.TypeSet,
Optional: true,
Elem: &schema.Schema{Type: schema.TypeString},
Set: schema.HashString,
ConflictsWith: []string{"endpoint_details.0.vpc_endpoint_id"},
},
"subnet_ids": {
Type: schema.TypeSet,
Optional: true,
Elem: &schema.Schema{Type: schema.TypeString},
Set: schema.HashString,
ConflictsWith: []string{"endpoint_details.0.vpc_endpoint_id"},
},
"vpc_id": {
Type: schema.TypeString,
Optional: true,
ValidateFunc: validation.NoZeroValues,
ConflictsWith: []string{"endpoint_details.0.vpc_endpoint_id"},
},
},
},
},

"vpce_security_group_ids": {
Type: schema.TypeSet,
Optional: true,
Computed: true,
Elem: &schema.Schema{Type: schema.TypeString},
Set: schema.HashString,
},

"host_key": {
Type: schema.TypeString,
Optional: true,
Expand Down Expand Up @@ -121,6 +147,7 @@ func resourceAwsTransferServer() *schema.Resource {
}

func resourceAwsTransferServerCreate(d *schema.ResourceData, meta interface{}) error {
updateAfterCreate := false
conn := meta.(*AWSClient).transferconn
tags := keyvaluetags.New(d.Get("tags").(map[string]interface{})).IgnoreAws().TransferTags()
createOpts := &transfer.CreateServerInput{}
Expand Down Expand Up @@ -156,6 +183,11 @@ func resourceAwsTransferServerCreate(d *schema.ResourceData, meta interface{}) e

if attr, ok := d.GetOk("endpoint_details"); ok {
createOpts.EndpointDetails = expandTransferServerEndpointDetails(attr.([]interface{}))

if createOpts.EndpointDetails.AddressAllocationIds != nil {
createOpts.EndpointDetails.AddressAllocationIds = nil
updateAfterCreate = true
}
}

if attr, ok := d.GetOk("host_key"); ok {
Expand All @@ -171,6 +203,17 @@ func resourceAwsTransferServerCreate(d *schema.ResourceData, meta interface{}) e

d.SetId(*resp.ServerId)

if updateAfterCreate {
updateOpts := &transfer.UpdateServerInput{
ServerId: aws.String(d.Id()),
EndpointDetails: expandTransferServerEndpointDetails(d.Get("endpoint_details").([]interface{})),
}

if err := doTransferServerUpdate(d, updateOpts, conn, meta.(*AWSClient).ec2conn, true); err != nil {
return err
}
}

return resourceAwsTransferServerRead(d, meta)
}

Expand Down Expand Up @@ -210,6 +253,15 @@ func resourceAwsTransferServerRead(d *schema.ResourceData, meta interface{}) err
d.Set("logging_role", resp.Server.LoggingRole)
d.Set("host_key_fingerprint", resp.Server.HostKeyFingerprint)

if resp.Server.EndpointDetails.VpcEndpointId != nil {
out, err := describeTransferServerVPCEndpoint(meta.(*AWSClient).ec2conn, resp.Server.EndpointDetails.VpcEndpointId)
if err != nil {
return err
}

d.Set("vpce_security_group_ids", flattenVpcEndpointSecurityGroupIds(out.Groups))
}

if err := d.Set("tags", keyvaluetags.TransferKeyValueTags(resp.Server.Tags).IgnoreAws().IgnoreConfig(ignoreTagsConfig).Map()); err != nil {
return fmt.Errorf("Error setting tags: %s", err)
}
Expand All @@ -219,6 +271,7 @@ func resourceAwsTransferServerRead(d *schema.ResourceData, meta interface{}) err
func resourceAwsTransferServerUpdate(d *schema.ResourceData, meta interface{}) error {
conn := meta.(*AWSClient).transferconn
updateFlag := false
stopFlag := false
updateOpts := &transfer.UpdateServerInput{
ServerId: aws.String(d.Id()),
}
Expand Down Expand Up @@ -253,6 +306,10 @@ func resourceAwsTransferServerUpdate(d *schema.ResourceData, meta interface{}) e
if attr, ok := d.GetOk("endpoint_details"); ok {
updateOpts.EndpointDetails = expandTransferServerEndpointDetails(attr.([]interface{}))
}

if d.HasChange("endpoint_details.0.address_allocation_ids") {
stopFlag = true
}
}

if d.HasChange("host_key") {
Expand All @@ -263,14 +320,8 @@ func resourceAwsTransferServerUpdate(d *schema.ResourceData, meta interface{}) e
}

if updateFlag {
_, err := conn.UpdateServer(updateOpts)
if err != nil {
if isAWSErr(err, transfer.ErrCodeResourceNotFoundException, "") {
log.Printf("[WARN] Transfer Server (%s) not found, removing from state", d.Id())
d.SetId("")
return nil
}
return fmt.Errorf("error updating Transfer Server (%s): %s", d.Id(), err)
if err := doTransferServerUpdate(d, updateOpts, conn, meta.(*AWSClient).ec2conn, stopFlag); err != nil {
return err
}
}

Expand Down Expand Up @@ -392,19 +443,203 @@ func expandTransferServerEndpointDetails(l []interface{}) *transfer.EndpointDeta
}
e := l[0].(map[string]interface{})

return &transfer.EndpointDetails{
VpcEndpointId: aws.String(e["vpc_endpoint_id"].(string)),
out := &transfer.EndpointDetails{}

if v, ok := e["vpc_endpoint_id"]; ok && v != "" {
out.VpcEndpointId = aws.String(v.(string))
}

if v, ok := e["address_allocation_ids"]; ok {
out.AddressAllocationIds = expandStringSet(v.(*schema.Set))
}

if v, ok := e["subnet_ids"]; ok {
out.SubnetIds = expandStringSet(v.(*schema.Set))
}

if v, ok := e["vpc_id"]; ok {
out.VpcId = aws.String(v.(string))
}

return out
}

func flattenTransferServerEndpointDetails(endpointDetails *transfer.EndpointDetails) []interface{} {
if endpointDetails == nil {
return []interface{}{}
}

e := map[string]interface{}{
"vpc_endpoint_id": aws.StringValue(endpointDetails.VpcEndpointId),
e := make(map[string]interface{})
if endpointDetails.VpcEndpointId != nil {
e["vpc_endpoint_id"] = aws.StringValue(endpointDetails.VpcEndpointId)
}
if endpointDetails.AddressAllocationIds != nil {
e["address_allocation_ids"] = flattenStringSet(endpointDetails.AddressAllocationIds)
}
if endpointDetails.SubnetIds != nil {
e["subnet_ids"] = flattenStringSet(endpointDetails.SubnetIds)
}
if endpointDetails.VpcId != nil {
e["vpc_id"] = aws.StringValue(endpointDetails.VpcId)
}

return []interface{}{e}
}

func doTransferServerUpdate(d *schema.ResourceData, updateOpts *transfer.UpdateServerInput, transferConn *transfer.Transfer, ec2Conn *ec2.EC2, stopFlag bool) error {
if stopFlag {
if err := waitForTransferServerVPCEndpointState(transferConn, ec2Conn, d.Id(), d.Timeout(schema.TimeoutCreate)); err != nil {
return fmt.Errorf("error waiting for Transfer Server VPC Endpoint (%s) to start: %s", d.Id(), err)
}

if err := stopAndWaitForTransferServer(d.Id(), transferConn, d.Timeout(schema.TimeoutCreate)); err != nil {
return err
}
}

_, err := transferConn.UpdateServer(updateOpts)
if err != nil {
if isAWSErr(err, transfer.ErrCodeResourceNotFoundException, "") {
log.Printf("[WARN] Transfer Server (%s) not found, removing from state", d.Id())
d.SetId("")
return nil
}
return fmt.Errorf("error updating Transfer Server (%s): %s", d.Id(), err)
}

if stopFlag {
if err := startAndWaitForTransferServer(d.Id(), transferConn, d.Timeout(schema.TimeoutCreate)); err != nil {
return err
}
}

if err := updateTransferServerVPCEndpointSecurityGroup(d, transferConn, ec2Conn); err != nil {
return err
}

return nil
}

func stopAndWaitForTransferServer(serverId string, conn *transfer.Transfer, timeout time.Duration) error {
stopReq := &transfer.StopServerInput{
ServerId: aws.String(serverId),
}
if _, err := conn.StopServer(stopReq); err != nil {
return fmt.Errorf("error stopping Transfer Server (%s): %s", serverId, err)
}

stateChangeConf := &resource.StateChangeConf{
Pending: []string{transfer.StateStarting, transfer.StateOnline, transfer.StateStopping},
Target: []string{transfer.StateOffline},
Refresh: refreshTransferServerStatus(conn, serverId),
Timeout: timeout,
Delay: 10 * time.Second,
}

if _, err := stateChangeConf.WaitForState(); err != nil {
return fmt.Errorf("error waiting for Transfer Server (%s) to stop: %s", serverId, err)
}

return nil
}

func startAndWaitForTransferServer(serverId string, conn *transfer.Transfer, timeout time.Duration) error {
stopReq := &transfer.StartServerInput{
ServerId: aws.String(serverId),
}

if _, err := conn.StartServer(stopReq); err != nil {
return fmt.Errorf("error starting Transfer Server (%s): %s", serverId, err)
}

stateChangeConf := &resource.StateChangeConf{
Pending: []string{transfer.StateStarting, transfer.StateOffline, transfer.StateStopping},
Target: []string{transfer.StateOnline},
Refresh: refreshTransferServerStatus(conn, serverId),
Timeout: timeout,
Delay: 10 * time.Second,
}

if _, err := stateChangeConf.WaitForState(); err != nil {
return fmt.Errorf("error waiting for Transfer Server (%s) to start: %s", serverId, err)
}

return nil
}

func refreshTransferServerStatus(conn *transfer.Transfer, serverId string) resource.StateRefreshFunc {
return func() (interface{}, string, error) {
server, err := describeTransferServer(conn, serverId)

if server == nil {
return 42, "destroyed", nil
}

return server, aws.StringValue(server.State), err
}
}

func describeTransferServer(conn *transfer.Transfer, serverId string) (*transfer.DescribedServer, error) {
params := &transfer.DescribeServerInput{
ServerId: aws.String(serverId),
}

resp, err := conn.DescribeServer(params)

return resp.Server, err
}

func waitForTransferServerVPCEndpointState(transferConn *transfer.Transfer, ec2Conn *ec2.EC2, serverId string, timeout time.Duration) error {
server, err := describeTransferServer(transferConn, serverId)

if err != nil {
return err
}

if err := vpcEndpointWaitUntilAvailable(ec2Conn, *server.EndpointDetails.VpcEndpointId, timeout); err != nil {
return err
}

return nil
}

func describeTransferServerVPCEndpoint(conn *ec2.EC2, vpceId *string) (*ec2.VpcEndpoint, error) {
params := &ec2.DescribeVpcEndpointsInput{
VpcEndpointIds: []*string{vpceId},
}

resp, err := conn.DescribeVpcEndpoints(params)

return resp.VpcEndpoints[0], err
}

func updateTransferServerVPCEndpointSecurityGroup(d *schema.ResourceData, transferConn *transfer.Transfer, ec2conn *ec2.EC2) error {
server, err := describeTransferServer(transferConn, d.Id())

if err != nil {
return fmt.Errorf("error describing Transfer Server VPC Endpoint: %s", err)
}

req := &ec2.ModifyVpcEndpointInput{
VpcEndpointId: aws.String(*server.EndpointDetails.VpcEndpointId),
}

if d.IsNewResource() {
out, err := describeTransferServerVPCEndpoint(ec2conn, server.EndpointDetails.VpcEndpointId)
if err != nil {
return err
}

req.RemoveSecurityGroupIds = append(req.RemoveSecurityGroupIds, out.Groups[0].GroupId)

setVpcEndpointCreateList(d, "vpce_security_group_ids", &req.AddSecurityGroupIds)
} else {
setVpcEndpointUpdateLists(d, "vpce_security_group_ids", &req.AddSecurityGroupIds, &req.RemoveSecurityGroupIds)
}

if _, err := ec2conn.ModifyVpcEndpoint(req); err != nil {
return fmt.Errorf("error updating VPC Endpoint: %s", err)
}

return nil
}
Loading