Skip to content

Commit

Permalink
Support gateways behind ALB with ACM certificate (#1264)
Browse files Browse the repository at this point in the history
* Support gateways behind ALB with ACM certificate

* Minor enhancements
  • Loading branch information
r4victor authored May 23, 2024
1 parent 8c1e195 commit 576d1af
Show file tree
Hide file tree
Showing 15 changed files with 438 additions and 83 deletions.
16 changes: 16 additions & 0 deletions docs/docs/reference/dstack.yml/gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,19 @@ domain: '*.example.com'
show_root_heading: false
type:
required: true
## `certificate[type=lets-encrypt]`

#SCHEMA# dstack._internal.core.models.gateways.LetsEncryptGatewayCertificate
overrides:
show_root_heading: false
type:
required: true

## `certificate[type=acm]`

#SCHEMA# dstack._internal.core.models.gateways.ACMGatewayCertificate
overrides:
show_root_heading: false
type:
required: true
6 changes: 3 additions & 3 deletions src/dstack/_internal/cli/utils/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def print_gateways_table(gateways: List[Gateway], verbose: bool = False):
table = Table(box=None)
table.add_column("BACKEND")
table.add_column("REGION")
table.add_column("NAME")
table.add_column("ADDRESS")
table.add_column("NAME", no_wrap=True)
table.add_column("HOSTNAME", no_wrap=True)
table.add_column("DOMAIN")
table.add_column("DEFAULT")
table.add_column("STATUS")
Expand All @@ -29,7 +29,7 @@ def print_gateways_table(gateways: List[Gateway], verbose: bool = False):
backend.value if i == 0 else "",
gateway.region,
gateway.name,
gateway.ip_address,
gateway.hostname,
gateway.wildcard_domain,
"✓" if gateway.default else "",
gateway.status,
Expand Down
188 changes: 157 additions & 31 deletions src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import boto3
import botocore.client
import botocore.exceptions
from pydantic import ValidationError

import dstack._internal.core.backends.aws.resources as aws_resources
from dstack._internal import settings
Expand All @@ -18,8 +19,10 @@
from dstack._internal.core.errors import ComputeError, NoCapacityError
from dstack._internal.core.models.backends.aws import AWSAccessKeyCreds
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import is_core_model_instance
from dstack._internal.core.models.gateways import GatewayComputeConfiguration
from dstack._internal.core.models.common import CoreModel, is_core_model_instance
from dstack._internal.core.models.gateways import (
GatewayComputeConfiguration,
)
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceConfiguration,
Expand All @@ -34,6 +37,12 @@
logger = get_logger(__name__)


# Identifiers of the load-balancer resources that `create_gateway` provisions
# for a gateway using an ACM certificate. `create_gateway` serializes this model
# with `.json()` into the `backend_data` of the returned `LaunchedGatewayInfo`;
# `terminate_gateway` parses it back with `parse_raw` to delete the resources.
class GatewayAWSBackendData(CoreModel):
# `LoadBalancerArn` of the ALB returned by `create_load_balancer`.
lb_arn: str
# `TargetGroupArn` of the HTTP:80 target group the gateway instance is registered in.
tg_arn: str
# `ListenerArn` of the HTTPS:443 listener that forwards to the target group.
listener_arn: str


class AWSCompute(Compute):
def __init__(self, config: AWSConfig):
self.config = config
Expand Down Expand Up @@ -114,12 +123,13 @@ def create_instance(
{"Key": "dstack_user", "Value": instance_config.user},
]
try:
vpc_id, subnet_id = get_vpc_id_subnet_id_or_error(
vpc_id, subnets_ids = get_vpc_id_subnet_id_or_error(
ec2_client=ec2_client,
config=self.config,
region=instance_offer.region,
allocate_public_ip=allocate_public_ip,
)
subnet_id = subnets_ids[0]
disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
response = ec2.create_instances(
**aws_resources.create_instances_struct(
Expand Down Expand Up @@ -149,10 +159,7 @@ def create_instance(
ec2_client.cancel_spot_instance_requests(
SpotInstanceRequestIds=[instance.spot_instance_request_id]
)
if allocate_public_ip:
hostname = instance.public_ip_address
else:
hostname = instance.private_ip_address
hostname = _get_instance_ip(instance, allocate_public_ip)
return JobProvisioningData(
backend=instance_offer.backend,
instance_type=instance_offer.instance,
Expand Down Expand Up @@ -197,19 +204,27 @@ def create_gateway(
) -> LaunchedGatewayInfo:
ec2 = self.session.resource("ec2", region_name=configuration.region)
ec2_client = self.session.client("ec2", region_name=configuration.region)

tags = [
{"Key": "Name", "Value": configuration.instance_name},
{"Key": "owner", "Value": "dstack"},
{"Key": "dstack_project", "Value": configuration.project_name},
]
if settings.DSTACK_VERSION is not None:
tags.append({"Key": "dstack_version", "Value": settings.DSTACK_VERSION})
vpc_id, subnet_id = get_vpc_id_subnet_id_or_error(

vpc_id, subnets_ids = get_vpc_id_subnet_id_or_error(
ec2_client=ec2_client,
config=self.config,
region=configuration.region,
allocate_public_ip=configuration.public_ip,
)
subnet_id = subnets_ids[0]
security_group_id = aws_resources.create_gateway_security_group(
ec2_client=ec2_client,
project_id=configuration.project_name,
vpc_id=vpc_id,
)
response = ec2.create_instances(
**aws_resources.create_instances_struct(
disk_size=10,
Expand All @@ -218,11 +233,7 @@ def create_gateway(
iam_instance_profile_arn=None,
user_data=get_gateway_user_data(configuration.ssh_key_pub),
tags=tags,
security_group_id=aws_resources.create_gateway_security_group(
ec2_client=ec2_client,
project_id=configuration.project_name,
vpc_id=vpc_id,
),
security_group_id=security_group_id,
spot=False,
subnet_id=subnet_id,
allocate_public_ip=configuration.public_ip,
Expand All @@ -231,15 +242,124 @@ def create_gateway(
instance = response[0]
instance.wait_until_running()
instance.reload() # populate instance.public_ip_address
if configuration.public_ip:
ip_address = instance.public_ip_address
else:
ip_address = instance.private_ip_address
if configuration.certificate is None or configuration.certificate.type != "acm":
ip_address = _get_instance_ip(instance, configuration.public_ip)
return LaunchedGatewayInfo(
instance_id=instance.instance_id,
region=configuration.region,
ip_address=ip_address,
)

elb_client = self.session.client("elbv2", region_name=configuration.region)

if len(subnets_ids) < 2:
raise ComputeError(
"Deploying gateway with ACM certificate requires at least two subnets in different AZs"
)

logger.debug("Creating ALB for gateway %s...", configuration.instance_name)
response = elb_client.create_load_balancer(
Name=f"{configuration.instance_name}-lb",
Subnets=subnets_ids,
SecurityGroups=[security_group_id],
Scheme="internet-facing" if configuration.public_ip else "internal",
Tags=tags,
Type="application",
IpAddressType="ipv4",
)
lb = response["LoadBalancers"][0]
lb_arn = lb["LoadBalancerArn"]
lb_dns_name = lb["DNSName"]
logger.debug("Created ALB for gateway %s.", configuration.instance_name)

logger.debug("Creating Target Group for gateway %s...", configuration.instance_name)
response = elb_client.create_target_group(
Name=f"{configuration.instance_name}-tg",
Protocol="HTTP",
Port=80,
VpcId=vpc_id,
TargetType="instance",
)
tg_arn = response["TargetGroups"][0]["TargetGroupArn"]
logger.debug("Created Target Group for gateway %s", configuration.instance_name)

logger.debug("Registering ALB target for gateway %s...", configuration.instance_name)
elb_client.register_targets(
TargetGroupArn=tg_arn,
Targets=[
{"Id": instance.instance_id, "Port": 80},
],
)
logger.debug("Registered ALB target for gateway %s", configuration.instance_name)

logger.debug("Creating ALB Listener for gateway %s...", configuration.instance_name)
response = elb_client.create_listener(
LoadBalancerArn=lb_arn,
Protocol="HTTPS",
Port=443,
SslPolicy="ELBSecurityPolicy-2016-08",
Certificates=[
{"CertificateArn": configuration.certificate.arn},
],
DefaultActions=[
{
"Type": "forward",
"TargetGroupArn": tg_arn,
}
],
)
listener_arn = response["Listeners"][0]["ListenerArn"]
logger.debug("Created ALB Listener for gateway %s", configuration.instance_name)

ip_address = _get_instance_ip(instance, configuration.public_ip)
return LaunchedGatewayInfo(
instance_id=instance.instance_id,
region=configuration.region,
ip_address=ip_address,
hostname=lb_dns_name,
backend_data=GatewayAWSBackendData(
lb_arn=lb_arn,
tg_arn=tg_arn,
listener_arn=listener_arn,
).json(),
)

def terminate_gateway(
    self,
    instance_id,
    configuration: GatewayComputeConfiguration,
    backend_data: Optional[str] = None,
):
    """Terminate a gateway instance and, for ACM gateways, its ALB resources.

    The EC2 instance is always terminated via `self.terminate_instance`.
    When the gateway was configured with an ACM certificate, the listener,
    target group and load balancer recorded in `backend_data` are deleted too.

    Args:
        instance_id: ID of the gateway EC2 instance.
        configuration: The gateway compute configuration; its `region`,
            `certificate` and `instance_name` are read here.
        backend_data: JSON produced from `GatewayAWSBackendData` when the
            gateway was created. If it is missing or cannot be parsed, the
            failure is logged and the ALB resources are left in place.
    """
    self.terminate_instance(
        instance_id=instance_id,
        region=configuration.region,
        backend_data=None,
    )
    # Only ACM gateways have load-balancer resources to clean up.
    if configuration.certificate is None or configuration.certificate.type != "acm":
        return

    if backend_data is None:
        logger.error(
            "Failed to terminate all gateway %s resources. backend_data is None.",
            configuration.instance_name,
        )
        return

    try:
        backend_data_parsed = GatewayAWSBackendData.parse_raw(backend_data)
    except ValidationError:
        logger.exception(
            "Failed to terminate all gateway %s resources. backend_data parsing error.",
            configuration.instance_name,
        )
        # Without this return, `backend_data_parsed` would be unbound below and
        # the cleanup would crash with UnboundLocalError instead of stopping
        # after the logged error.
        return

    elb_client = self.session.client("elbv2", region_name=configuration.region)

    logger.debug("Deleting ALB resources for gateway %s...", configuration.instance_name)
    # Reverse order of creation in `create_gateway`: the listener forwards to
    # the target group and is attached to the load balancer, so it goes first.
    elb_client.delete_listener(ListenerArn=backend_data_parsed.listener_arn)
    elb_client.delete_target_group(TargetGroupArn=backend_data_parsed.tg_arn)
    elb_client.delete_load_balancer(LoadBalancerArn=backend_data_parsed.lb_arn)
    logger.debug("Deleted ALB resources for gateway %s", configuration.instance_name)


def _has_quota(quotas: Dict[str, int], instance_name: str) -> bool:
Expand All @@ -262,7 +382,7 @@ def get_vpc_id_subnet_id_or_error(
config: AWSConfig,
region: str,
allocate_public_ip: bool,
) -> Tuple[str, str]:
) -> Tuple[str, List[str]]:
if config.vpc_ids is not None:
vpc_id = config.vpc_ids.get(region)
if vpc_id is None:
Expand All @@ -271,14 +391,14 @@ def get_vpc_id_subnet_id_or_error(
if vpc is None:
raise ComputeError(f"Failed to find VPC {vpc_id} in region {region}")

subnet_id = aws_resources.get_subnet_id_for_vpc(
subnets_ids = aws_resources.get_subnets_ids_for_vpc(
ec2_client=ec2_client,
vpc_id=vpc_id,
allocate_public_ip=allocate_public_ip,
)
if subnet_id is not None:
return vpc_id, subnet_id
raise ComputeError(f"Failed to find public subnet for VPC {vpc_id}")
if len(subnets_ids) > 0:
return vpc_id, subnets_ids
raise ComputeError(f"Failed to find public subnets for VPC {vpc_id}")

return _get_vpc_id_subnet_id_by_vpc_name_or_error(
ec2_client=ec2_client,
Expand All @@ -293,7 +413,7 @@ def _get_vpc_id_subnet_id_by_vpc_name_or_error(
vpc_name: Optional[str],
region: str,
allocate_public_ip: bool,
) -> Tuple[str, str]:
) -> Tuple[str, List[str]]:
if vpc_name is not None:
vpc_id = aws_resources.get_vpc_id_by_name(
ec2_client=ec2_client,
Expand All @@ -305,23 +425,29 @@ def _get_vpc_id_subnet_id_by_vpc_name_or_error(
vpc_id = aws_resources.get_default_vpc_id(ec2_client=ec2_client)
if vpc_id is None:
raise ComputeError(f"No default VPC in region {region}")
subnet_id = aws_resources.get_subnet_id_for_vpc(
subnets_ids = aws_resources.get_subnets_ids_for_vpc(
ec2_client=ec2_client,
vpc_id=vpc_id,
allocate_public_ip=allocate_public_ip,
)
if subnet_id is not None:
return vpc_id, subnet_id
if len(subnets_ids) > 0:
return vpc_id, subnets_ids
if vpc_name is not None:
if allocate_public_ip:
raise ComputeError(
f"Failed to find public subnet for VPC {vpc_name} in region {region}"
f"Failed to find public subnets for VPC {vpc_name} in region {region}"
)
raise ComputeError(
f"Failed to find private subnet with NAT for VPC {vpc_name} in region {region}"
f"Failed to find private subnets with NAT for VPC {vpc_name} in region {region}"
)
if allocate_public_ip:
raise ComputeError(f"Failed to find public subnet for default VPC in region {region}")
raise ComputeError(f"Failed to find public subnets for default VPC in region {region}")
raise ComputeError(
f"Failed to find private subnet with NAT for default VPC in region {region}"
f"Failed to find private subnets with NAT for default VPC in region {region}"
)


def _get_instance_ip(instance: Any, public_ip: bool) -> str:
    """Return the address of `instance` that matches the requested visibility.

    Args:
        instance: Object exposing `public_ip_address` and `private_ip_address`.
        public_ip: When truthy, the public address is returned; otherwise
            the private one.
    """
    return instance.public_ip_address if public_ip else instance.private_ip_address
17 changes: 9 additions & 8 deletions src/dstack/_internal/core/backends/aws/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,35 +256,36 @@ def get_vpc_by_vpc_id(ec2_client: botocore.client.BaseClient, vpc_id: str) -> Op
return None


def get_subnets_ids_for_vpc(
    ec2_client: botocore.client.BaseClient,
    vpc_id: str,
    allocate_public_ip: bool,
) -> List[str]:
    """
    If `allocate_public_ip` is True, returns public subnets found in the VPC.
    If `allocate_public_ip` is False, returns subnets with NAT found in the VPC.

    Returns an empty list when the VPC has no subnets or none of them qualify.
    """
    # The eligibility check depends only on `allocate_public_ip`, so choose it
    # once; both helpers are called with the same keyword arguments.
    is_eligible = _is_public_subnet if allocate_public_ip else _is_subnet_behind_nat
    subnets = _get_subnets_by_vpc_id(ec2_client=ec2_client, vpc_id=vpc_id)
    return [
        subnet["SubnetId"]
        for subnet in subnets
        if is_eligible(ec2_client=ec2_client, vpc_id=vpc_id, subnet_id=subnet["SubnetId"])
    ]


def _add_ingress_security_group_rule_if_missing(
Expand Down
Loading

0 comments on commit 576d1af

Please sign in to comment.