Skip to content

Commit

Permalink
Move GCP networking + pricing logic out of cloud provider (#614)
Browse files: browse the repository at this point in the history
  • Loading branch information
parasj authored Nov 4, 2022
1 parent 9e6c937 commit 2d128e2
Show file tree
Hide file tree
Showing 12 changed files with 346 additions and 258 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ include = ["skyplane/data/*"]
python = ">=3.7.1,<3.12"
boto3 = ">=1.16.0"
cachetools = ">=4.1.0"
cryptography = ">=1.4.0"
pandas = ">=1.0.0"
paramiko = ">=2.7.2"
questionary = ">=1.8.0"
Expand Down
22 changes: 13 additions & 9 deletions skyplane/api/impl/provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def __init__(
self.pending_provisioner_tasks: List[ProvisionerTask] = []
self.provisioned_vms: Dict[str, compute.Server] = {}

# store GCP firewall rules to be deleted upon exit
self.gcp_firewall_rules: Set[str] = set()

def _make_cloud_providers(self):
self.aws = compute.AWSCloudProvider(
key_prefix=f"skyplane{'-'+self.host_uuid.replace('-', '') if self.host_uuid else ''}", auth=self.aws_auth
Expand All @@ -59,9 +62,7 @@ def init_global(self, aws: bool = True, azure: bool = True, gcp: bool = True):
jobs.append(self.azure.create_ssh_key)
jobs.append(self.azure.set_up_resource_group)
if gcp:
jobs.append(self.gcp.create_ssh_key)
jobs.append(self.gcp.configure_skyplane_network)
jobs.append(self.gcp.configure_skyplane_firewall)
jobs.append(self.gcp.setup_global)
do_parallel(lambda fn: fn(), jobs, spinner=False)

def add_task(
Expand Down Expand Up @@ -151,7 +152,12 @@ def provision(self, authorize_firewall: bool = True, max_jobs: int = 16, spinner
if aws_provisioned:
authorize_ip_jobs.extend([partial(self.aws.add_ips_to_security_group, r, public_ips) for r in set(aws_regions)])
if gcp_provisioned:
authorize_ip_jobs.append(partial(self.gcp.add_ips_to_firewall, public_ips + private_ips))

def authorize_gcp_gateways():
self.gcp_firewall_rules.add(self.gcp.authorize_gateways(public_ips + private_ips))

authorize_ip_jobs.append(authorize_gcp_gateways)

do_parallel(
lambda fn: fn(),
authorize_ip_jobs,
Expand Down Expand Up @@ -208,16 +214,14 @@ def deprovision_gateway_instance(server: compute.Server):
if deauthorize_firewall:
# todo remove firewall rules for Azure
public_ips = [s.public_ip() for s in servers]
# deauthorize access to private IPs for GCP VMs due to global VPC
private_ips = [s.private_ip() for s in servers if s.provider == "gcp"]
jobs = []
if aws_deprovisioned:
aws_regions = set([s.region() for s in servers if s.provider == "aws"])
jobs.extend([partial(self.aws.remove_ips_from_security_group, r, public_ips) for r in set(aws_regions)])
logger.fs.info(f"[Provisioner.deprovision] Deauthorizing AWS gateways with firewalls: {public_ips=}")
if gcp_deprovisioned:
jobs.append(partial(self.gcp.remove_ips_from_firewall, public_ips + private_ips))
logger.fs.info(f"[Provisioner.deprovision] Deauthorizing AWS gateways with firewalls: {public_ips=}")
logger.fs.info(f"[Provisioner.deprovision] Deauthorizing GCP gateways with firewalls: {public_ips=}, {private_ips=}")
jobs.extend([partial(self.gcp.remove_gateway_rule, rule) for rule in self.gcp_firewall_rules])
logger.fs.info(f"[Provisioner.deprovision] Deauthorizing GCP gateways with firewalls: {self.gcp_firewall_rules=}")
do_parallel(
lambda fn: fn(), jobs, n=max_jobs, spinner=spinner, spinner_persist=False, desc="Deauthorizing gateways from firewalls"
)
7 changes: 1 addition & 6 deletions skyplane/cli/experiments/provision.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ def provision(
jobs.append(azure.create_ssh_key)
jobs.append(azure.set_up_resource_group)
if gcp_regions_to_provision:
jobs.append(gcp.create_ssh_key)
jobs.append(gcp.configure_skyplane_network)
jobs.append(gcp.configure_skyplane_firewall)
jobs.append(gcp.setup_global)
with Timer("Cloud SSH key initialization"):
do_parallel(lambda fn: fn(), jobs)

Expand Down Expand Up @@ -100,9 +98,6 @@ def azure_provisioner(r):
"state": [compute.ServerState.PENDING, compute.ServerState.RUNNING],
"network_tier": "PREMIUM" if gcp_use_premium_network else "STANDARD",
}
gcp.create_ssh_key()
gcp.configure_skyplane_network()
gcp.configure_skyplane_firewall()
gcp_instances = refresh_instance_list(gcp, gcp_regions_to_provision, gcp_instance_filter, n=4)
missing_gcp_regions = set(gcp_regions_to_provision) - set(gcp_instances.keys())

Expand Down
10 changes: 4 additions & 6 deletions skyplane/compute/azure/azure_cloud_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from cryptography.utils import CryptographyDeprecationWarning
from typing import List, Optional

from skyplane.compute.key_utils import generate_keypair

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning)
import paramiko

from skyplane import exceptions
from skyplane.config_paths import cloud_config
Expand Down Expand Up @@ -192,15 +193,12 @@ def get_instance_list(self, region: str) -> List[AzureServer]:
)
return server_list

# Copied from gcp_cloud_provider.py --- consolidate later?
def create_ssh_key(self):
public_key_path = Path(self.public_key_path)
private_key_path = Path(self.private_key_path)
if not private_key_path.exists():
private_key_path.parent.mkdir(parents=True, exist_ok=True)
key = paramiko.RSAKey.generate(4096)
key.write_private_key_file(self.private_key_path, password="skyplane")
with open(self.public_key_path, "w") as f:
f.write(f"{key.get_name()} {key.get_base64()}\n")
generate_keypair(public_key_path, private_key_path)

def set_up_resource_group(self, clean_up_orphans=True):
resource_client = self.auth.get_resource_client()
Expand Down
20 changes: 20 additions & 0 deletions skyplane/compute/gcp/gcp_auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import os
import time
from pathlib import Path

from typing import Optional
Expand Down Expand Up @@ -94,6 +95,25 @@ def get_adc_credential(google_auth, project_id=None):
inferred_project = project_id
return inferred_cred, inferred_project

def get_operation_state(self, zone, operation_name):
    """Fetch the current state of a GCP compute operation.

    Args:
        zone: Zone the operation belongs to, or the literal string
            ``"global"`` to look the operation up via ``globalOperations``
            instead of ``zoneOperations``.
        operation_name: Name of the operation to look up.

    Returns:
        Whatever ``execute()`` returns for the operations ``get`` request.
    """
    client = self.get_gcp_client()
    if zone != "global":
        # Zonal operations must be queried with their zone.
        request = client.zoneOperations().get(project=self.project_id, zone=zone, operation=operation_name)
    else:
        request = client.globalOperations().get(project=self.project_id, operation=operation_name)
    return request.execute()

def wait_for_operation_to_complete(self, zone, operation_name, timeout=120):
    """Poll a GCP operation until its status is ``"DONE"`` or ``timeout`` elapses.

    Args:
        zone: Zone of the operation, passed through to
            ``get_operation_state``.
        operation_name: Name of the operation to poll.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        The operation state once its ``"status"`` is ``"DONE"`` and it has no
        ``"error"`` key. Returns ``None`` if the timeout elapses first.

    Raises:
        Exception: If the finished operation state contains an ``"error"``
            key; the error payload is the exception argument.
    """
    # Backoff schedule: ten 0.1s waits, ten 0.2s waits, then 1s waits.
    delays = [0.1] * 10 + [0.2] * 10 + [1.0] * int(timeout)
    polls = 0
    started = time.time()
    while time.time() - started < timeout:
        state = self.get_operation_state(zone, operation_name)
        if state["status"] != "DONE":
            # Not finished yet: wait for the next backoff interval.
            time.sleep(delays[polls])
            polls += 1
            continue
        if "error" in state:
            raise Exception(state["error"])
        return state
    return None

@property
def service_account_name(self):
return self.config.get_flag("gcp_service_account_name")
Expand Down
Loading

0 comments on commit 2d128e2

Please sign in to comment.