Skip to content

Commit

Permalink
Fix AWS and GCP key generation (#9)
Browse files: browse the repository at this point in the history
  • Loading branch information
parasj authored Dec 16, 2021
1 parent 2bf3d5d commit e4db252
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 40 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"pandas",
"paramiko",
"tqdm",
"questionary",
],
extras_require={"test": ["black", "pytest", "ipython", "jupyter_console"]}
)
2 changes: 2 additions & 0 deletions skylark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

# Repository root: two levels up from this file, i.e. the directory that
# contains the `skylark` package directory.
skylark_root = Path(__file__).parent.parent

# Shared data directory located directly under the repository root.
data_root = skylark_root.joinpath("data")

# Directory for SSH key material, kept inside the data directory.
key_root = data_root.joinpath("keys")
5 changes: 1 addition & 4 deletions skylark/benchmark/network/throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,8 @@ def main(args):
log_dir = data_dir / "logs" / "throughput" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_dir.mkdir(exist_ok=True, parents=True)

gcp_private_key = str(data_dir / "keys" / "gcp-cert.pem")
gcp_public_key = str(data_dir / "keys" / "gcp-cert.pub")

aws = AWSCloudProvider()
gcp = GCPCloudProvider(args.gcp_project, gcp_private_key, gcp_public_key)
gcp = GCPCloudProvider(args.gcp_project)
aws_instances, gcp_instances = provision(
aws=aws,
gcp=gcp,
Expand Down
15 changes: 2 additions & 13 deletions skylark/benchmark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@
def refresh_instance_list(provider: CloudProvider, region_list=[None], instance_filter=None) -> Dict[str, List[Server]]:
if instance_filter is None:
instance_filter = {"tags": {"skylark": "true"}}
results = do_parallel(
lambda region: provider.get_matching_instances(region=region, **instance_filter),
region_list,
progress_bar=True,
# desc=f"refresh {provider.name}",
)
results = do_parallel(lambda region: provider.get_matching_instances(region=region, **instance_filter), region_list, progress_bar=False)
return {r: ilist for r, ilist in results if ilist}


Expand Down Expand Up @@ -69,9 +64,6 @@ def provision(
results = do_parallel(aws_provisioner, missing_aws_regions, progress_bar=True, desc="provision aws")
for region, result in results:
aws_instances[region] = [result]
logger.info(f"(aws:{region}) provisioned {result}, waiting for ready")
result.wait_for_ready()
logger.info(f"(aws:{region}) ready")
aws_instances = refresh_instance_list(aws, aws_regions_to_provision, aws_instance_filter)

if len(gcp_regions_to_provision) > 0:
Expand All @@ -86,8 +78,6 @@ def provision(
gcp.configure_default_network()
gcp.configure_default_firewall()
gcp_instances = refresh_instance_list(gcp, gcp_regions_to_provision, gcp_instance_filter)

# provision missing regions using do_parallel
missing_gcp_regions = set(gcp_regions_to_provision) - set(gcp_instances.keys())
if missing_gcp_regions:
logger.info(f"(gcp) provisioning missing regions: {missing_gcp_regions}")
Expand All @@ -97,7 +87,6 @@ def provision(
)
for region, result in results:
gcp_instances[region] = [result]
logger.info(f"(gcp:{region}) provisioned {result}")
gcp_instances = refresh_instance_list(gcp, gcp_regions_to_provision, gcp_instance_filter)

# init log files
Expand All @@ -108,5 +97,5 @@ def init(i: Server):
i.copy_and_run_script(setup_script)

all_instances = [i for ilist in aws_instances.values() for i in ilist] + [i for ilist in gcp_instances.values() for i in ilist]
do_parallel(init, all_instances, progress_bar=True, desc="init instances")
do_parallel(init, all_instances, progress_bar=True, desc="Provisioning init")
return aws_instances, gcp_instances
2 changes: 1 addition & 1 deletion skylark/compute/aws/aws_cloud_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def provision_instance(self, region, instance_class, name=None, ami_id=None, tag
InstanceType=instance_class,
MinCount=1,
MaxCount=1,
KeyName=region,
KeyName=f"skylark-{region}",
TagSpecifications=[
{
"ResourceType": "instance",
Expand Down
30 changes: 21 additions & 9 deletions skylark/compute/aws/aws_server.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
import os
from functools import lru_cache
from pathlib import Path

import boto3
from boto3 import session
import paramiko
from loguru import logger
import questionary

from skylark.compute.server import Server, ServerState
from skylark import key_root
from tqdm import tqdm


class AWSServer(Server):
"""AWS Server class to support basic SSH operations"""

def __init__(self, region_tag, instance_id, log_dir=None):
def __init__(self, region_tag, instance_id, log_dir=None, key_root=key_root / "aws"):
super().__init__(region_tag, log_dir=log_dir)
assert self.region_tag.split(":")[0] == "aws"
self.aws_region = self.region_tag.split(":")[1]
self.instance_id = instance_id
self.local_keyfile = self.make_keyfile()
self.local_keyfile = self.make_keyfile(key_root)

def uuid(self):
return f"{self.region_tag}:{self.instance_id}"
Expand Down Expand Up @@ -49,15 +53,23 @@ def get_boto3_client(cls, service_name, aws_region):
setattr(cls.ns, ns_key, client)
return getattr(cls.ns, ns_key)

def make_keyfile(self):
local_key_file = os.path.expanduser(f"~/.ssh/{self.aws_region}.pem")
def make_keyfile(self, prefix):
prefix = Path(prefix)
key_name = f"skylark-{self.aws_region}"
local_key_file = prefix / f"{key_name}.pem"
ec2 = AWSServer.get_boto3_resource("ec2", self.aws_region)
if not os.path.exists(local_key_file):
key_pair = ec2.create_key_pair(KeyName=self.aws_region)
with open(local_key_file, "w") as f:
ec2_client = AWSServer.get_boto3_client("ec2", self.aws_region)
if not local_key_file.exists():
prefix.mkdir(parents=True, exist_ok=True)
# delete key pair from ec2 if it exists
keys_in_region = set(p["KeyName"] for p in ec2_client.describe_key_pairs()["KeyPairs"])
if key_name in keys_in_region:
logger.warning(f"Deleting key {key_name} in region {self.aws_region}")
ec2_client.delete_key_pair(KeyName=key_name)
key_pair = ec2.create_key_pair(KeyName=f"skylark-{self.aws_region}")
with local_key_file.open("w") as f:
f.write(key_pair.key_material)
os.chmod(local_key_file, 0o600)
logger.info(f"({self.aws_region}) Created keypair and saved to {local_key_file}")
return local_key_file

@property
Expand Down Expand Up @@ -115,5 +127,5 @@ def terminate_instance_impl(self):
def get_ssh_client_impl(self):
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(self.public_ip, username="ubuntu", key_filename=self.local_keyfile)
client.connect(self.public_ip, username="ubuntu", key_filename=str(self.local_keyfile))
return client
12 changes: 7 additions & 5 deletions skylark/compute/gcp/gcp_cloud_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@

import paramiko

from skylark.compute.gcp.gcp_server import GCPServer, DEFAULT_GCP_PRIVATE_KEY_PATH, DEFAULT_GCP_PUBLIC_KEY_PATH
from skylark.compute.gcp.gcp_server import GCPServer
from skylark.compute.cloud_providers import CloudProvider
from skylark.compute.server import Server, ServerState
from skylark.compute.server import Server
from skylark import key_root


class GCPCloudProvider(CloudProvider):
def __init__(self, gcp_project, private_key_path=DEFAULT_GCP_PRIVATE_KEY_PATH, public_key_path=DEFAULT_GCP_PUBLIC_KEY_PATH):
def __init__(self, gcp_project, key_root=key_root / "gcp"):
super().__init__()
self.gcp_project = gcp_project
self.private_key_path = os.path.expanduser(private_key_path)
self.public_key_path = os.path.expanduser(public_key_path)
key_root.mkdir(parents=True, exist_ok=True)
self.private_key_path = key_root / "gcp-cert.pem"
self.public_key_path = key_root / "gcp-cert.pub"

@property
def name(self):
Expand Down
15 changes: 9 additions & 6 deletions skylark/compute/gcp/gcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,18 @@
from tqdm import tqdm

from skylark.compute.server import Server, ServerState

DEFAULT_GCP_PRIVATE_KEY_PATH = os.path.expanduser("~/.ssh/google_compute_engine")
DEFAULT_GCP_PUBLIC_KEY_PATH = os.path.expanduser("~/.ssh/google_compute_engine.pub")
from skylark import key_root


class GCPServer(Server):
def __init__(self, region_tag, gcp_project, instance_name, ssh_private_key=DEFAULT_GCP_PRIVATE_KEY_PATH, log_dir=None):
def __init__(self, region_tag, gcp_project, instance_name, key_root=key_root / "gcp", log_dir=None):
super().__init__(region_tag, log_dir=log_dir)
assert self.region_tag.split(":")[0] == "gcp", f"Region name doesn't match pattern gcp:<region> {self.region_tag}"
self.gcp_region = self.region_tag.split(":")[1]
self.gcp_project = gcp_project
self.gcp_instance_name = instance_name
self.ssh_private_key = os.path.expanduser(ssh_private_key)
key_root.mkdir(parents=True, exist_ok=True)
self.ssh_private_key = key_root / f"gcp.pem"

def uuid(self):
return f"{self.region_tag}:{self.gcp_instance_name}"
Expand Down Expand Up @@ -97,6 +96,10 @@ def get_ssh_client_impl(self, uname=os.environ.get("USER"), ssh_key_password="sk
ssh_client = paramiko.SSHClient()
ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh_client.connect(
hostname=self.public_ip, username=uname, key_filename=self.ssh_private_key, passphrase=ssh_key_password, look_for_keys=False
hostname=self.public_ip,
username=uname,
key_filename=str(self.ssh_private_key),
passphrase=ssh_key_password,
look_for_keys=False,
)
return ssh_client
6 changes: 5 additions & 1 deletion skylark/compute/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,11 @@ def wait_for_ready(self, timeout=120) -> bool:
while (time.time() - start_time) < timeout:
try:
if self.instance_state == ServerState.RUNNING:
return True
try:
self.run_command("true")
return True
except Exception as e:
logger.warning(f"{self.instance_name} is not ready: {e}")
time.sleep(wait_intervals.pop(0))
except Exception as e:
print(f"Error waiting for server to be ready: {e}")
Expand Down
2 changes: 1 addition & 1 deletion skylark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def elapsed(self):
def do_parallel(func, args_list, n=8, progress_bar=False, leave_pbar=True, desc=None, arg_fmt=None):
"""Run list of jobs in parallel with tqdm progress bar"""
if arg_fmt is None:
arg_fmt = lambda x: x
arg_fmt = lambda x: x.region_tag if hasattr(x, "region_tag") else x

def wrapped_fn(args):
return args, func(args)
Expand Down

0 comments on commit e4db252

Please sign in to comment.