Skip to content

Commit

Permalink
Add option to provision spot instances to CLI for skyplane cp and sky…
Browse files Browse the repository at this point in the history
…plane sync
  • Loading branch information
parasj committed Sep 4, 2022
1 parent ceaae05 commit 8514534
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 6 deletions.
6 changes: 6 additions & 0 deletions skyplane/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,12 @@ def cp(
use_e2ee=cloud_config.get_flag("encrypt_e2e") if src_region != dst_region else False,
use_socket_tls=cloud_config.get_flag("encrypt_socket_tls") if src_region != dst_region else False,
aws_instance_class=cloud_config.get_flag("aws_instance_class"),
aws_use_spot_instances=cloud_config.get_flag("aws_use_spot_instances"),
azure_instance_class=cloud_config.get_flag("azure_instance_class"),
azure_use_spot_instances=cloud_config.get_flag("azure_use_spot_instances"),
gcp_instance_class=cloud_config.get_flag("gcp_instance_class"),
gcp_use_premium_network=cloud_config.get_flag("gcp_use_premium_network"),
gcp_use_spot_instances=cloud_config.get_flag("gcp_use_spot_instances"),
multipart_enabled=multipart,
multipart_min_threshold_mb=cloud_config.get_flag("multipart_min_threshold_mb"),
multipart_min_size_mb=cloud_config.get_flag("multipart_min_size_mb"),
Expand Down Expand Up @@ -367,9 +370,12 @@ def sync(
use_e2ee=cloud_config.get_flag("encrypt_e2e") if src_region != dst_region else False,
use_socket_tls=cloud_config.get_flag("encrypt_socket_tls") if src_region != dst_region else False,
aws_instance_class=cloud_config.get_flag("aws_instance_class"),
aws_use_spot_instances=cloud_config.get_flag("aws_use_spot_instances"),
azure_instance_class=cloud_config.get_flag("azure_instance_class"),
azure_use_spot_instances=cloud_config.get_flag("azure_use_spot_instances"),
gcp_instance_class=cloud_config.get_flag("gcp_instance_class"),
gcp_use_premium_network=cloud_config.get_flag("gcp_use_premium_network"),
gcp_use_spot_instances=cloud_config.get_flag("gcp_use_spot_instances"),
multipart_enabled=multipart,
multipart_min_threshold_mb=cloud_config.get_flag("multipart_min_threshold_mb"),
multipart_min_size_mb=cloud_config.get_flag("multipart_min_size_mb"),
Expand Down
12 changes: 11 additions & 1 deletion skyplane/cli/cli_impl/cp_replicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,11 @@ def launch_replication_job(
multipart_min_size_mb: int = 8,
multipart_max_chunks: int = 9990,
# cloud provider specific options
aws_use_spot_instances: bool = False,
aws_instance_class: str = "m5.8xlarge",
azure_use_spot_instances: bool = False,
azure_instance_class: str = "Standard_D32_v4",
gcp_use_spot_instances: bool = False,
gcp_instance_class: str = "n2-standard-32",
gcp_use_premium_network: bool = True,
# logging options
Expand Down Expand Up @@ -289,7 +292,14 @@ def launch_replication_job(
stats = TransferStats.empty()
try:
rc.provision_gateways(
reuse_gateways, use_bbr=use_bbr, use_compression=use_compression, use_e2ee=use_e2ee, use_socket_tls=use_socket_tls
reuse_gateways,
use_bbr=use_bbr,
use_compression=use_compression,
use_e2ee=use_e2ee,
use_socket_tls=use_socket_tls,
aws_use_spot_instances=aws_use_spot_instances,
azure_use_spot_instances=azure_use_spot_instances,
gcp_use_spot_instances=gcp_use_spot_instances,
)
for node, gw in rc.bound_nodes.items():
logger.fs.info(f"Log URLs for {gw.uuid()} ({node.region}:{node.instance})")
Expand Down
2 changes: 2 additions & 0 deletions skyplane/compute/aws/aws_cloud_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def provision_instance(
tags={"skyplane": "true"},
ebs_volume_size: int = 128,
iam_name: str = "skyplane_gateway",
use_spot_instances: bool = False,
) -> AWSServer:

assert not region.startswith("aws:"), "Region should be AWS region"
Expand Down Expand Up @@ -399,6 +400,7 @@ def start_instance(subnet_id: str):
],
IamInstanceProfile={"Name": iam_instance_profile_name},
InstanceInitiatedShutdownBehavior="terminate",
InstanceMarketOptions={"MarketType": "spot" if use_spot_instances else "on-demand"},
)

backoff = 1
Expand Down
6 changes: 5 additions & 1 deletion skyplane/compute/azure/azure_cloud_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,9 @@ def set_up_resource_group(self, clean_up_orphans=True):

# This code, along with some code in azure_server.py, is based on
# https://github.com/ucbrise/mage-scripts/blob/main/azure_cloud.py.
def provision_instance(self, location: str, vm_size: str, name: Optional[str] = None, uname: str = "skyplane") -> AzureServer:
def provision_instance(
self, location: str, vm_size: str, name: Optional[str] = None, uname: str = "skyplane", use_spot_instances: bool = False
) -> AzureServer:
assert ":" not in location, "invalid colon in Azure location"

if name is None:
Expand Down Expand Up @@ -375,6 +377,8 @@ def provision_instance(self, location: str, vm_size: str, name: Optional[str] =
}
],
},
# use spot instances if use_spot_instances is set
"priority": "Spot" if use_spot_instances else "Regular",
},
)
vm_result = poller.result()
Expand Down
12 changes: 11 additions & 1 deletion skyplane/compute/gcp/gcp_cloud_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,14 @@ def wait_for_operation_to_complete(self, zone, operation_name, timeout=120):
time.sleep(time_intervals.pop(0))

def provision_instance(
self, region, instance_class, name=None, premium_network=False, uname="skyplane", tags={"skyplane": "true"}
self,
region,
instance_class,
name=None,
premium_network=False,
uname="skyplane",
tags={"skyplane": "true"},
use_spot_instances: bool = False,
) -> GCPServer:
assert not region.startswith("gcp:"), "Region should be GCP region"
if name is None:
Expand Down Expand Up @@ -331,6 +338,9 @@ def provision_instance(
"scheduling": {"onHostMaintenance": "TERMINATE", "automaticRestart": False},
"deletionProtection": False,
}
# use preemtible instances if use_spot_instances is True
if use_spot_instances:
req_body["scheduling"]["preemptible"] = True
try:
result = compute.instances().insert(project=self.auth.project_id, zone=region, body=req_body).execute()
self.wait_for_operation_to_complete(region, result["name"])
Expand Down
6 changes: 6 additions & 0 deletions skyplane/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
"num_connections": int,
"max_instances": int,
"autoshutdown_minutes": int,
"aws_use_spot_instances": bool,
"azure_use_spot_instances": bool,
"gcp_use_spot_instances": bool,
"aws_instance_class": str,
"azure_instance_class": str,
"gcp_instance_class": str,
Expand All @@ -40,6 +43,9 @@
"num_connections": 32,
"max_instances": 1,
"autoshutdown_minutes": 15,
"aws_use_spot_instances": False,
"azure_use_spot_instances": False,
"gcp_use_spot_instances": False,
"aws_instance_class": "m5.8xlarge",
"azure_instance_class": "Standard_D32_v5",
"gcp_instance_class": "n2-standard-32",
Expand Down
14 changes: 11 additions & 3 deletions skyplane/replicate/replicator_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ def provision_gateways(
use_compression=True,
use_e2ee=True,
use_socket_tls=False,
aws_use_spot_instances: bool = False,
azure_use_spot_instances: bool = False,
gcp_use_spot_instances: bool = False,
):
regions_to_provision = [node.region for node in self.topology.gateway_nodes]
aws_regions_to_provision = [r for r in regions_to_provision if r.startswith("aws:")]
Expand Down Expand Up @@ -190,14 +193,19 @@ def provision_gateway_instance(region: str) -> Server:
provider, subregion = region.split(":")
if provider == "aws":
assert self.aws.auth.enabled()
server = self.aws.provision_instance(subregion, self.aws_instance_class)
server = self.aws.provision_instance(subregion, self.aws_instance_class, use_spot_instances=aws_use_spot_instances)
elif provider == "azure":
assert self.azure.auth.enabled()
server = self.azure.provision_instance(subregion, self.azure_instance_class)
server = self.azure.provision_instance(subregion, self.azure_instance_class, use_spot_instances=azure_use_spot_instances)
elif provider == "gcp":
assert self.gcp.auth.enabled()
# todo specify network tier in ReplicationTopology
server = self.gcp.provision_instance(subregion, self.gcp_instance_class, premium_network=self.gcp_use_premium_network)
server = self.gcp.provision_instance(
subregion,
self.gcp_instance_class,
premium_network=self.gcp_use_premium_network,
use_spot_instances=gcp_use_spot_instances,
)
else:
raise NotImplementedError(f"Unknown provider {provider}")
server.enable_auto_shutdown()
Expand Down

0 comments on commit 8514534

Please sign in to comment.