From 90acb1b55fd4eebdd4a33c2e696af4ed19ebdb0e Mon Sep 17 00:00:00 2001
From: Paras Jain
Date: Thu, 16 Dec 2021 22:46:52 +0000
Subject: [PATCH] Fix AWS and GCP key generation

---
 setup.py                                  |  1 +
 skylark/__init__.py                       |  2 ++
 skylark/benchmark/network/throughput.py   |  5 +---
 skylark/benchmark/utils.py                | 15 ++----------
 skylark/compute/aws/aws_cloud_provider.py |  2 +-
 skylark/compute/aws/aws_server.py         | 30 ++++++++++++++++-------
 skylark/compute/gcp/gcp_cloud_provider.py | 12 +++++----
 skylark/compute/gcp/gcp_server.py         | 15 +++++++-----
 skylark/compute/server.py                 |  6 ++++-
 skylark/utils.py                          |  2 +-
 10 files changed, 50 insertions(+), 40 deletions(-)

diff --git a/setup.py b/setup.py
index efe72e55b..73331cdf8 100644
--- a/setup.py
+++ b/setup.py
@@ -20,6 +20,7 @@
         "pandas",
         "paramiko",
         "tqdm",
+        "questionary",
     ],
     extras_require={"test": ["black", "pytest", "ipython", "jupyter_console"]}
 )
\ No newline at end of file
diff --git a/skylark/__init__.py b/skylark/__init__.py
index a41f07aa2..9068e1ceb 100644
--- a/skylark/__init__.py
+++ b/skylark/__init__.py
@@ -1,3 +1,5 @@
 from pathlib import Path
 
 skylark_root = Path(__file__).parent.parent
+data_root = skylark_root / "data"
+key_root = skylark_root / "data" / "keys"
diff --git a/skylark/benchmark/network/throughput.py b/skylark/benchmark/network/throughput.py
index bf754610c..849530b3c 100644
--- a/skylark/benchmark/network/throughput.py
+++ b/skylark/benchmark/network/throughput.py
@@ -48,11 +48,8 @@ def main(args):
     log_dir = data_dir / "logs" / "throughput" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
     log_dir.mkdir(exist_ok=True, parents=True)
 
-    gcp_private_key = str(data_dir / "keys" / "gcp-cert.pem")
-    gcp_public_key = str(data_dir / "keys" / "gcp-cert.pub")
-
     aws = AWSCloudProvider()
-    gcp = GCPCloudProvider(args.gcp_project, gcp_private_key, gcp_public_key)
+    gcp = GCPCloudProvider(args.gcp_project)
     aws_instances, gcp_instances = provision(
         aws=aws,
         gcp=gcp,
diff --git a/skylark/benchmark/utils.py b/skylark/benchmark/utils.py
index 3386bf39e..43895cf87 100644
--- a/skylark/benchmark/utils.py
+++ b/skylark/benchmark/utils.py
@@ -14,12 +14,7 @@
 
 def refresh_instance_list(provider: CloudProvider, region_list=[None], instance_filter=None) -> Dict[str, List[Server]]:
     if instance_filter is None:
         instance_filter = {"tags": {"skylark": "true"}}
-    results = do_parallel(
-        lambda region: provider.get_matching_instances(region=region, **instance_filter),
-        region_list,
-        progress_bar=True,
-        # desc=f"refresh {provider.name}",
-    )
+    results = do_parallel(lambda region: provider.get_matching_instances(region=region, **instance_filter), region_list, progress_bar=False)
     return {r: ilist for r, ilist in results if ilist}
 
@@ -69,9 +64,6 @@ def provision(
         results = do_parallel(aws_provisioner, missing_aws_regions, progress_bar=True, desc="provision aws")
         for region, result in results:
             aws_instances[region] = [result]
-            logger.info(f"(aws:{region}) provisioned {result}, waiting for ready")
-            result.wait_for_ready()
-            logger.info(f"(aws:{region}) ready")
         aws_instances = refresh_instance_list(aws, aws_regions_to_provision, aws_instance_filter)
 
     if len(gcp_regions_to_provision) > 0:
@@ -86,8 +78,6 @@ def provision(
         gcp.configure_default_network()
         gcp.configure_default_firewall()
         gcp_instances = refresh_instance_list(gcp, gcp_regions_to_provision, gcp_instance_filter)
-
-        # provision missing regions using do_parallel
         missing_gcp_regions = set(gcp_regions_to_provision) - set(gcp_instances.keys())
         if missing_gcp_regions:
             logger.info(f"(gcp) provisioning missing regions: {missing_gcp_regions}")
@@ -97,7 +87,6 @@ def provision(
             )
             for region, result in results:
                 gcp_instances[region] = [result]
-                logger.info(f"(gcp:{region}) provisioned {result}")
             gcp_instances = refresh_instance_list(gcp, gcp_regions_to_provision, gcp_instance_filter)
 
     # init log files
@@ -108,5 +97,5 @@ def init(i: Server):
         i.copy_and_run_script(setup_script)
 
     all_instances = [i for ilist in aws_instances.values() for i in ilist] + [i for ilist in gcp_instances.values() for i in ilist]
-    do_parallel(init, all_instances, progress_bar=True, desc="init instances")
+    do_parallel(init, all_instances, progress_bar=True, desc="Provisioning init")
     return aws_instances, gcp_instances
diff --git a/skylark/compute/aws/aws_cloud_provider.py b/skylark/compute/aws/aws_cloud_provider.py
index 80b609a62..367e8760d 100644
--- a/skylark/compute/aws/aws_cloud_provider.py
+++ b/skylark/compute/aws/aws_cloud_provider.py
@@ -121,7 +121,7 @@ def provision_instance(self, region, instance_class, name=None, ami_id=None, tag
             InstanceType=instance_class,
             MinCount=1,
             MaxCount=1,
-            KeyName=region,
+            KeyName=f"skylark-{region}",
             TagSpecifications=[
                 {
                     "ResourceType": "instance",
diff --git a/skylark/compute/aws/aws_server.py b/skylark/compute/aws/aws_server.py
index 3c28ece13..2778b3741 100644
--- a/skylark/compute/aws/aws_server.py
+++ b/skylark/compute/aws/aws_server.py
@@ -1,23 +1,27 @@
 import os
 from functools import lru_cache
+from pathlib import Path
 
 import boto3
 from boto3 import session
 import paramiko
 from loguru import logger
+import questionary
 
 from skylark.compute.server import Server, ServerState
+from skylark import key_root
+from tqdm import tqdm
 
 
 class AWSServer(Server):
     """AWS Server class to support basic SSH operations"""
 
-    def __init__(self, region_tag, instance_id, log_dir=None):
+    def __init__(self, region_tag, instance_id, log_dir=None, key_root=key_root / "aws"):
         super().__init__(region_tag, log_dir=log_dir)
         assert self.region_tag.split(":")[0] == "aws"
         self.aws_region = self.region_tag.split(":")[1]
         self.instance_id = instance_id
-        self.local_keyfile = self.make_keyfile()
+        self.local_keyfile = self.make_keyfile(key_root)
 
     def uuid(self):
         return f"{self.region_tag}:{self.instance_id}"
@@ -49,15 +53,23 @@ def get_boto3_client(cls, service_name, aws_region):
             setattr(cls.ns, ns_key, client)
         return getattr(cls.ns, ns_key)
 
-    def make_keyfile(self):
-        local_key_file = os.path.expanduser(f"~/.ssh/{self.aws_region}.pem")
+    def make_keyfile(self, prefix):
+        prefix = Path(prefix)
+        key_name = f"skylark-{self.aws_region}"
+        local_key_file = prefix / f"{key_name}.pem"
         ec2 = AWSServer.get_boto3_resource("ec2", self.aws_region)
-        if not os.path.exists(local_key_file):
-            key_pair = ec2.create_key_pair(KeyName=self.aws_region)
-            with open(local_key_file, "w") as f:
+        ec2_client = AWSServer.get_boto3_client("ec2", self.aws_region)
+        if not local_key_file.exists():
+            prefix.mkdir(parents=True, exist_ok=True)
+            # delete key pair from ec2 if it exists
+            keys_in_region = set(p["KeyName"] for p in ec2_client.describe_key_pairs()["KeyPairs"])
+            if key_name in keys_in_region:
+                logger.warning(f"Deleting key {key_name} in region {self.aws_region}")
+                ec2_client.delete_key_pair(KeyName=key_name)
+            key_pair = ec2.create_key_pair(KeyName=f"skylark-{self.aws_region}")
+            with local_key_file.open("w") as f:
                 f.write(key_pair.key_material)
                 os.chmod(local_key_file, 0o600)
-                logger.info(f"({self.aws_region}) Created keypair and saved to {local_key_file}")
         return local_key_file
 
     @property
@@ -115,5 +127,5 @@ def terminate_instance_impl(self):
     def get_ssh_client_impl(self):
         client = paramiko.SSHClient()
         client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-        client.connect(self.public_ip, username="ubuntu", key_filename=self.local_keyfile)
+        client.connect(self.public_ip, username="ubuntu", key_filename=str(self.local_keyfile))
         return client
diff --git a/skylark/compute/gcp/gcp_cloud_provider.py b/skylark/compute/gcp/gcp_cloud_provider.py
index f3a22b5fe..144a13c55 100644
--- a/skylark/compute/gcp/gcp_cloud_provider.py
+++ b/skylark/compute/gcp/gcp_cloud_provider.py
@@ -9,17 +9,19 @@
 
 import paramiko
 
-from skylark.compute.gcp.gcp_server import GCPServer, DEFAULT_GCP_PRIVATE_KEY_PATH, DEFAULT_GCP_PUBLIC_KEY_PATH
+from skylark.compute.gcp.gcp_server import GCPServer
 from skylark.compute.cloud_providers import CloudProvider
-from skylark.compute.server import Server, ServerState
+from skylark.compute.server import Server
+from skylark import key_root
 
 
 class GCPCloudProvider(CloudProvider):
-    def __init__(self, gcp_project, private_key_path=DEFAULT_GCP_PRIVATE_KEY_PATH, public_key_path=DEFAULT_GCP_PUBLIC_KEY_PATH):
+    def __init__(self, gcp_project, key_root=key_root / "gcp"):
         super().__init__()
         self.gcp_project = gcp_project
-        self.private_key_path = os.path.expanduser(private_key_path)
-        self.public_key_path = os.path.expanduser(public_key_path)
+        key_root.mkdir(parents=True, exist_ok=True)
+        self.private_key_path = key_root / "gcp-cert.pem"
+        self.public_key_path = key_root / "gcp-cert.pub"
 
     @property
     def name(self):
diff --git a/skylark/compute/gcp/gcp_server.py b/skylark/compute/gcp/gcp_server.py
index 2d76e6bc4..91886375c 100644
--- a/skylark/compute/gcp/gcp_server.py
+++ b/skylark/compute/gcp/gcp_server.py
@@ -7,19 +7,18 @@
 from tqdm import tqdm
 
 from skylark.compute.server import Server, ServerState
-
-DEFAULT_GCP_PRIVATE_KEY_PATH = os.path.expanduser("~/.ssh/google_compute_engine")
-DEFAULT_GCP_PUBLIC_KEY_PATH = os.path.expanduser("~/.ssh/google_compute_engine.pub")
+from skylark import key_root
 
 
 class GCPServer(Server):
-    def __init__(self, region_tag, gcp_project, instance_name, ssh_private_key=DEFAULT_GCP_PRIVATE_KEY_PATH, log_dir=None):
+    def __init__(self, region_tag, gcp_project, instance_name, key_root=key_root / "gcp", log_dir=None):
         super().__init__(region_tag, log_dir=log_dir)
         assert self.region_tag.split(":")[0] == "gcp", f"Region name doesn't match pattern gcp: {self.region_tag}"
         self.gcp_region = self.region_tag.split(":")[1]
         self.gcp_project = gcp_project
         self.gcp_instance_name = instance_name
-        self.ssh_private_key = os.path.expanduser(ssh_private_key)
+        key_root.mkdir(parents=True, exist_ok=True)
+        self.ssh_private_key = key_root / f"gcp.pem"
 
     def uuid(self):
         return f"{self.region_tag}:{self.gcp_instance_name}"
@@ -97,6 +96,10 @@ def get_ssh_client_impl(self, uname=os.environ.get("USER"), ssh_key_password="sk
         ssh_client = paramiko.SSHClient()
         ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
         ssh_client.connect(
-            hostname=self.public_ip, username=uname, key_filename=self.ssh_private_key, passphrase=ssh_key_password, look_for_keys=False
+            hostname=self.public_ip,
+            username=uname,
+            key_filename=str(self.ssh_private_key),
+            passphrase=ssh_key_password,
+            look_for_keys=False,
         )
         return ssh_client
diff --git a/skylark/compute/server.py b/skylark/compute/server.py
index f5cdc9ac5..48dacf052 100644
--- a/skylark/compute/server.py
+++ b/skylark/compute/server.py
@@ -122,7 +122,11 @@ def wait_for_ready(self, timeout=120) -> bool:
         while (time.time() - start_time) < timeout:
             try:
                 if self.instance_state == ServerState.RUNNING:
-                    return True
+                    try:
+                        self.run_command("true")
+                        return True
+                    except Exception as e:
+                        logger.warning(f"{self.instance_name} is not ready: {e}")
                 time.sleep(wait_intervals.pop(0))
             except Exception as e:
                 print(f"Error waiting for server to be ready: {e}")
diff --git a/skylark/utils.py b/skylark/utils.py
index 9fe45ebd5..cb6e60602 100644
--- a/skylark/utils.py
+++ b/skylark/utils.py
@@ -23,7 +23,7 @@ def elapsed(self):
 def do_parallel(func, args_list, n=8, progress_bar=False, leave_pbar=True, desc=None, arg_fmt=None):
     """Run list of jobs in parallel with tqdm progress bar"""
     if arg_fmt is None:
-        arg_fmt = lambda x: x
+        arg_fmt = lambda x: x.region_tag if hasattr(x, "region_tag") else x
 
     def wrapped_fn(args):
         return args, func(args)
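
Usage sketch (illustrative only; "my-gcp-project" below is a placeholder, not from the patch): after this change, callers no longer pass SSH key paths explicitly. Both providers derive them from key_root (data/keys under the repository root), matching the updated call in throughput.py:

    from skylark.compute.aws.aws_cloud_provider import AWSCloudProvider
    from skylark.compute.gcp.gcp_cloud_provider import GCPCloudProvider

    aws = AWSCloudProvider()
    # servers provisioned on AWS now create per-region key pairs named skylark-<region> under data/keys/aws
    gcp = GCPCloudProvider("my-gcp-project")
    # GCP certs default to data/keys/gcp/gcp-cert.pem and gcp-cert.pub instead of ~/.ssh paths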