Skip to content

Commit

Permalink
[gateway] Add new API call to spawn and close server (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
parasj authored Dec 27, 2021
1 parent 58dec4a commit bd93d25
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 182 deletions.
8 changes: 8 additions & 0 deletions scripts/checkpy.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/bin/bash
cd $(dirname $0)/..

set -xe
pip install -e .
pip install pytype black
black --line-length 140 skylark
pytype skylark
14 changes: 9 additions & 5 deletions skylark/compute/aws/aws_cloud_provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional
import uuid

from loguru import logger
Expand Down Expand Up @@ -55,7 +55,7 @@ def get_transfer_cost(src_key, dst_key):
else:
raise NotImplementedError

def get_instance_list(self, region) -> List[AWSServer]:
def get_instance_list(self, region: str) -> List[AWSServer]:
ec2 = AWSServer.get_boto3_resource("ec2", region)
instances = ec2.instances.filter(
Filters=[
Expand All @@ -69,7 +69,9 @@ def get_instance_list(self, region) -> List[AWSServer]:
instances = [AWSServer(f"aws:{region}", i) for i in instance_ids]
return instances

def add_ip_to_security_group(self, aws_region, security_group_id: str = None, ip="0.0.0.0/0", from_port=0, to_port=65535):
def add_ip_to_security_group(
self, aws_region: str, security_group_id: Optional[str] = None, ip="0.0.0.0/0", from_port=0, to_port=65535
):
"""Add IP to security group. If security group ID is None, use default."""
ec2 = AWSServer.get_boto3_resource("ec2", aws_region)
if security_group_id is None:
Expand All @@ -84,7 +86,7 @@ def add_ip_to_security_group(self, aws_region, security_group_id: str = None, ip
logger.info(f"({aws_region}) Added IP {ip} to security group {security_group_id}")

@staticmethod
def get_ubuntu_ami_id(region):
def get_ubuntu_ami_id(region: str) -> str:
client = AWSServer.get_boto3_client("ec2", region)
response = client.describe_images(
Filters=[
Expand All @@ -109,7 +111,9 @@ def get_ubuntu_ami_id(region):
image_list = sorted(response["Images"], key=lambda x: x["CreationDate"], reverse=True)
return image_list[0]["ImageId"]

def provision_instance(self, region, instance_class, name=None, ami_id=None, tags={"skylark": "true"}) -> AWSServer:
def provision_instance(
self, region: str, instance_class: str, name: Optional[str] = None, ami_id: Optional[str] = None, tags={"skylark": "true"}
) -> AWSServer:
assert not region.startswith("aws:"), "Region should be AWS region"
if name is None:
name = f"skylark-aws-{str(uuid.uuid4()).replace('-', '')}"
Expand Down
3 changes: 0 additions & 3 deletions skylark/compute/aws/aws_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@
from pathlib import Path

import boto3
from boto3 import session
import paramiko
from loguru import logger
import questionary

from skylark.compute.server import Server, ServerState
from skylark import key_root
from tqdm import tqdm


class AWSServer(Server):
Expand Down
10 changes: 7 additions & 3 deletions skylark/compute/cloud_providers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import threading
from typing import List, Union
from typing import List, Optional, Union

from loguru import logger

Expand All @@ -14,7 +14,7 @@ def name():
raise NotImplementedError

@staticmethod
def region_list(self):
def region_list():
raise NotImplementedError

@staticmethod
Expand All @@ -37,7 +37,11 @@ def get_instance_list(self, region) -> List[Server]:
raise NotImplementedError

def get_matching_instances(
self, region=None, instance_type=None, state: Union[ServerState, List[ServerState]] = None, tags={"skylark": "true"}
self,
region: Optional[str] = None,
instance_type: Optional[str] = None,
state: Optional[Union[ServerState, List[ServerState]]] = None,
tags={"skylark": "true"},
) -> List[Server]:
if isinstance(region, str):
region = [region]
Expand Down
2 changes: 0 additions & 2 deletions skylark/compute/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,6 @@ def wait_for_ready(self, timeout=120, verbose=False) -> bool:
except Exception as e:
continue
logger.warning(f"({self.region_tag}) Timeout waiting for server to be ready")
if e is not None:
logger.exception(e)
return False

def close_server(self):
Expand Down
66 changes: 66 additions & 0 deletions skylark/gateway/chunk_header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from dataclasses import dataclass
import socket


@dataclass
class ChunkHeader:
# sent over wire in order:
# magic
# chunk_id
# chunk_size_bytes
# chunk_offset_bytes
# end_of_stream
# chunk_hash_sha256
chunk_id: int # unsigned long
chunk_size_bytes: int # unsigned long
chunk_offset_bytes: int # unsigned long
chunk_hash_sha256: str # 64-byte checksum
end_of_stream: bool = False # false by default, but true if this is the last chunk

@staticmethod
def magic_hex():
return 0x534B595F4C41524B # "SKY_LARK"

@staticmethod
def length_bytes():
# magic (8) + chunk_id (8) + chunk_size_bytes (8) + chunk_offset_bytes (8) + end_of_stream (1) + chunk_hash_sha256 (64)
return 8 + 8 + 8 + 8 + 1 + 64

@staticmethod
def from_bytes(data: bytes):
assert len(data) == ChunkHeader.length_bytes()
magic = int.from_bytes(data[:8], byteorder="big")
if magic != ChunkHeader.magic_hex():
raise ValueError("Invalid magic number")
chunk_id = int.from_bytes(data[8:16], byteorder="big")
chunk_size_bytes = int.from_bytes(data[16:24], byteorder="big")
chunk_offset_bytes = int.from_bytes(data[24:32], byteorder="big")
chunk_end_of_stream = bool(data[32])
chunk_hash_sha256 = data[33:].decode("utf-8")
return ChunkHeader(
chunk_id=chunk_id,
chunk_size_bytes=chunk_size_bytes,
chunk_offset_bytes=chunk_offset_bytes,
chunk_hash_sha256=chunk_hash_sha256,
end_of_stream=chunk_end_of_stream,
)

def to_bytes(self):
out_bytes = b""
out_bytes += self.magic_hex().to_bytes(8, byteorder="big")
out_bytes += self.chunk_id.to_bytes(8, byteorder="big")
out_bytes += self.chunk_size_bytes.to_bytes(8, byteorder="big")
out_bytes += self.chunk_offset_bytes.to_bytes(8, byteorder="big")
out_bytes += bytes([int(self.end_of_stream)])
assert len(self.chunk_hash_sha256) == 64
out_bytes += self.chunk_hash_sha256.encode("utf-8")
assert len(out_bytes) == ChunkHeader.length_bytes()
return out_bytes

@staticmethod
def from_socket(sock: socket.socket):
header_bytes = sock.recv(ChunkHeader.length_bytes())
return ChunkHeader.from_bytes(header_bytes)

def to_socket(self, sock: socket.socket):
assert sock.sendall(self.to_bytes()) == None
Loading

0 comments on commit bd93d25

Please sign in to comment.