diff --git a/poetry.lock b/poetry.lock index 8e17c6812..8c4ca7bb7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -772,7 +772,7 @@ scipy = ">=0.9" name = "exceptiongroup" version = "1.1.1" description = "Backport of PEP 654 (exception groups)" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -1381,7 +1381,7 @@ testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packag name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2095,7 +2095,7 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest- name = "pluggy" version = "1.0.0" description = "plugin and hook calling mechanisms for python" -category = "dev" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -2363,14 +2363,14 @@ diagrams = ["jinja2", "railroad-diagrams"] [[package]] name = "pytest" -version = "7.3.1" +version = "7.3.2" description = "pytest: simple powerful testing with Python" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "pytest-7.3.1-py3-none-any.whl", hash = "sha256:3799fa815351fea3a5e96ac7e503a96fa51cc9942c3753cda7651b93c1cfa362"}, - {file = "pytest-7.3.1.tar.gz", hash = "sha256:434afafd78b1d78ed0addf160ad2b77a30d35d4bdf8af234fe621919d9ed15e3"}, + {file = "pytest-7.3.2-py3-none-any.whl", hash = "sha256:cdcbd012c9312258922f8cd3f1b62a6580fdced17db6014896053d47cddf9295"}, + {file = "pytest-7.3.2.tar.gz", hash = "sha256:ee990a3cc55ba808b80795a79944756f315c67c12b56abd3ac993a7b8c17030b"}, ] [package.dependencies] @@ -2383,7 +2383,7 @@ pluggy = ">=0.12,<2.0" tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] -testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] [[package]] name = "pytest-cov" @@ -2736,7 +2736,7 @@ test = ["tox (>=1.8.1)"] name = "tomli" version = "2.0.1" description = "A lil' TOML parser" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" files = [ @@ -2899,4 +2899,4 @@ solver = ["cvxpy", "graphviz", "matplotlib", "numpy"] [metadata] lock-version = "2.0" python-versions = ">=3.7.1,<3.12" -content-hash = "b399e30ba267365760659d157388a4cf91af396621d7712b32dadae144e420d1" +content-hash = "3bd46488c6cc92dabe1262e664c7a227016cc42e7c2cbc203ce4abda7c5826e9" diff --git a/pyproject.toml b/pyproject.toml index 6443f3ff3..d50254d58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ pynacl = { version = "^1.5.0", optional = true } pyopenssl = { version = "^22.0.0", optional = true } werkzeug = { version = "^2.1.2", optional = true } pyarrow = "^10.0.1" +pytest = "^7.3.2" [tool.poetry.extras] aws = ["boto3"] diff --git a/scripts/gen_data/gen_many_small.py b/scripts/gen_data/gen_many_small.py index 24640d59e..78dde2436 100644 --- a/scripts/gen_data/gen_many_small.py +++ b/scripts/gen_data/gen_many_small.py @@ -26,4 +26,4 @@ def make_file(data, fname): files = [f"{outdir}/{i:08d}.bin" for i in range(args.nfiles)] data = np.arange(args.size // 4, dtype=np.int32).tobytes() - do_parallel(partial(make_file, data), files, desc="Generating files", spinner=True, spinner_persist=True) + do_parallel(partial(make_file, data), files, 
desc="Generating files", spinner=True, spinner_persist=True, n=16) diff --git a/skyplane/api/dataplane.py b/skyplane/api/dataplane.py index d009b5eee..eebe7e42f 100644 --- a/skyplane/api/dataplane.py +++ b/skyplane/api/dataplane.py @@ -246,7 +246,7 @@ def copy_gateway_logs(self): # copy logs from all gateways in parallel do_parallel(self.copy_gateway_log, self.bound_nodes.values(), n=-1) - def deprovision(self, max_jobs: int = 64, spinner: bool = False): + def deprovision(self, max_jobs: int = 64, spinner: bool = True): """ Deprovision the remote gateways @@ -267,6 +267,8 @@ def deprovision(self, max_jobs: int = 64, spinner: bool = False): for task in self.pending_transfers: logger.fs.warning(f"Before deprovisioning, waiting for jobs to finish: {list(task.jobs.keys())}") task.join() + for thread in threading.enumerate(): + assert "_run_multipart_chunk_thread" not in thread.name, f"thread {thread.name} is still running" except KeyboardInterrupt: logger.warning("Interrupted while waiting for transfers to finish, deprovisioning anyway.") raise diff --git a/skyplane/api/tracker.py b/skyplane/api/tracker.py index 9f9d2e4a7..120e59d8a 100644 --- a/skyplane/api/tracker.py +++ b/skyplane/api/tracker.py @@ -1,10 +1,12 @@ import functools +import signal + from pprint import pprint import json import time from abc import ABC from datetime import datetime -from threading import Thread +from threading import Thread, Event import urllib3 from typing import TYPE_CHECKING, Dict, List, Optional, Set @@ -97,6 +99,14 @@ def __init__(self, dataplane, jobs: List["TransferJob"], transfer_config: Transf self.jobs = {job.uuid: job for job in jobs} self.transfer_config = transfer_config + # exit handling + self.exit_flag = Event() + + def signal_handler(signal, frame): + self.exit_flag.set() + + signal.signal(signal.SIGINT, signal_handler) + if hooks is None: self.hooks = EmptyTransferHook() else: @@ -138,9 +148,7 @@ def run(self): session_start_timestamp_ms = int(time.time() * 1000) try: # pre-dispatch chunks to begin pre-buffering chunks - chunk_streams = { - job_uuid: job.dispatch(self.dataplane, transfer_config=self.transfer_config) for job_uuid, job in self.jobs.items() - } + chunk_streams = {job_uuid: job.dispatch(self.dataplane) for job_uuid, job in self.jobs.items()} for job_uuid, job in self.jobs.items(): logger.fs.debug(f"[TransferProgressTracker] Dispatching job {job.uuid}") self.job_chunk_requests[job_uuid] = {} @@ -148,6 +156,12 @@ def run(self): self.job_complete_chunk_ids[job_uuid] = {region: set() for region in self.dataplane.topology.dest_region_tags} for chunk in chunk_streams[job_uuid]: + if self.exit_flag.is_set(): + logger.fs.debug(f"[TransferProgressTracker] Exiting due to signal") + self.hooks.on_dispatch_end() + self.hooks.on_transfer_end() + job.stop() # stop threads in chunk stream + return chunks_dispatched = [chunk] self.job_chunk_requests[job_uuid][chunk.chunk_id] = chunk self.hooks.on_chunk_dispatched(chunks_dispatched) diff --git a/skyplane/api/transfer_job.py b/skyplane/api/transfer_job.py index 300a01b85..44d2ae76c 100644 --- a/skyplane/api/transfer_job.py +++ b/skyplane/api/transfer_job.py @@ -1,4 +1,5 @@ import json +import signal import time import time import typer @@ -69,7 +70,7 @@ def __init__( self, src_iface: StorageInterface, dst_ifaces: List[StorageInterface], - transfer_config: TransferConfig, + transfer_config: Optional[TransferConfig] = None, concurrent_multipart_chunk_threads: Optional[int] = 64, num_partitions: Optional[int] = 1, ): @@ -89,6 +90,28 @@ def __init__( 
self.multipart_upload_requests = [] self.concurrent_multipart_chunk_threads = concurrent_multipart_chunk_threads self.num_partitions = num_partitions + if transfer_config is None: + self.transfer_config = TransferConfig() + + # threads for multipart uploads + self.multipart_send_queue: Queue[TransferPair] = Queue() + self.multipart_chunk_queue: Queue[GatewayMessage] = Queue() + self.multipart_exit_event = threading.Event() + self.multipart_chunk_threads = [] + + # handle exit signal + def signal_handler(signal, frame): + self.multipart_exit_event.set() + for t in self.multipart_chunk_threads: + t.join() + + signal.signal(signal.SIGINT, signal_handler) + + def stop(self): + """Stops all threads""" + self.multipart_exit_event.set() + for t in self.multipart_chunk_threads: + t.join() def _run_multipart_chunk_thread( self, @@ -304,27 +327,22 @@ def chunk(self, transfer_pair_generator: Generator[TransferPair, None, None]) -> :param transfer_pair_generator: generator of pairs of objects to transfer :type transfer_pair_generator: Generator """ - multipart_send_queue: Queue[TransferPair] = Queue() - multipart_chunk_queue: Queue[GatewayMessage] = Queue() - multipart_exit_event = threading.Event() - multipart_chunk_threads = [] - # start chunking threads if self.transfer_config.multipart_enabled: for _ in range(self.concurrent_multipart_chunk_threads): t = threading.Thread( target=self._run_multipart_chunk_thread, - args=(multipart_exit_event, multipart_send_queue, multipart_chunk_queue), + args=(self.multipart_exit_event, self.multipart_send_queue, self.multipart_chunk_queue), daemon=False, ) t.start() - multipart_chunk_threads.append(t) + self.multipart_chunk_threads.append(t) # begin chunking loop for transfer_pair in transfer_pair_generator: src_obj = transfer_pair.src_obj if self.transfer_config.multipart_enabled and src_obj.size > self.transfer_config.multipart_threshold_mb * MB: - multipart_send_queue.put(transfer_pair) + self.multipart_send_queue.put(transfer_pair) else: if transfer_pair.src_obj.size == 0: logger.fs.debug(f"Skipping empty object {src_obj.key}") @@ -341,25 +359,25 @@ def chunk(self, transfer_pair_generator: Generator[TransferPair, None, None]) -> if self.transfer_config.multipart_enabled: # drain multipart chunk queue and yield with updated chunk IDs - while not multipart_chunk_queue.empty(): - yield multipart_chunk_queue.get() + while not self.multipart_chunk_queue.empty(): + yield self.multipart_chunk_queue.get() if self.transfer_config.multipart_enabled: # wait for processing multipart requests to finish logger.fs.debug("Waiting for multipart threads to finish") # while not multipart_send_queue.empty(): # TODO: may be an issue waiting for this in case of force-quit - while not multipart_send_queue.empty(): - logger.fs.debug(f"Remaining in multipart queue: sent {multipart_send_queue.qsize()}") + while not self.multipart_send_queue.empty(): + logger.fs.debug(f"Remaining in multipart queue: sent {self.multipart_send_queue.qsize()}") time.sleep(0.1) # send sentinel to all threads - multipart_exit_event.set() - for thread in multipart_chunk_threads: + self.multipart_exit_event.set() + for thread in self.multipart_chunk_threads: thread.join() # drain multipart chunk queue and yield with updated chunk IDs - while not multipart_chunk_queue.empty(): - yield multipart_chunk_queue.get() + while not self.multipart_chunk_queue.empty(): + yield self.multipart_chunk_queue.get() @staticmethod def batch_generator(gen_in: Generator[T, None, None], batch_size: int) -> Generator[List[T], 
None, None]: @@ -377,8 +395,8 @@ def batch_generator(gen_in: Generator[T, None, None], batch_size: int) -> Genera if len(batch) > 0: yield batch - @staticmethod - def prefetch_generator(gen_in: Generator[T, None, None], buffer_size: int) -> Generator[T, None, None]: + # @staticmethod + def prefetch_generator(self, gen_in: Generator[T, None, None], buffer_size: int) -> Generator[T, None, None]: """ Prefetches from generator while handing StopIteration to ensure items yield immediately. Start a thread to prefetch items from the generator and put them in a queue. Upon StopIteration, @@ -394,6 +412,8 @@ def prefetch_generator(gen_in: Generator[T, None, None], buffer_size: int) -> Ge def prefetch(): for item in gen_in: + if self.multipart_exit_event.is_set(): # exit with exit event + break queue.put(item) queue.put(sentinel) @@ -444,13 +464,15 @@ def __init__( dst_paths: List[str] or str, recursive: bool = False, requester_pays: bool = False, + transfer_config: Optional[TransferConfig] = None, uuid: str = field(init=False, default_factory=lambda: str(uuid.uuid4())), ): self.src_path = src_path - self.dst_paths = dst_paths + self.dst_paths = dst_paths if isinstance(dst_paths, list) else [dst_paths] self.recursive = recursive self.requester_pays = requester_pays self.uuid = uuid + self.transfer_config = transfer_config if transfer_config else TransferConfig() @property def transfer_type(self) -> str: @@ -543,11 +565,16 @@ def __init__( dst_paths: List[str] or str, recursive: bool = False, requester_pays: bool = False, + transfer_config: Optional[TransferConfig] = None, uuid: str = field(init=False, default_factory=lambda: str(uuid.uuid4())), ): - super().__init__(src_path, dst_paths, recursive, requester_pays, uuid) + super().__init__(src_path, dst_paths, recursive, requester_pays, transfer_config, uuid) self.transfer_list = [] self.multipart_transfer_list = [] + self.chunker = Chunker(self.src_iface, self.dst_ifaces, self.transfer_config) # TODO: should read in existing transfer config + + def stop(self): + self.chunker.stop() @property def http_pool(self): @@ -559,23 +586,18 @@ def http_pool(self): def gen_transfer_pairs( self, - chunker: Optional[Chunker] = None, - transfer_config: Optional[TransferConfig] = field(init=False, default_factory=lambda: TransferConfig()), ) -> Generator[TransferPair, None, None]: """Generate transfer pairs for the transfer job. :param chunker: chunker that makes the chunk requests :type chunker: Chunker """ - if chunker is None: # used for external access to transfer pair list - chunker = Chunker(self.src_iface, self.dst_ifaces, transfer_config) # TODO: should read in existing transfer config - yield from chunker.transfer_pair_generator(self.src_prefix, self.dst_prefixes, self.recursive, self._pre_filter_fn) + yield from self.chunker.transfer_pair_generator(self.src_prefix, self.dst_prefixes, self.recursive, self._pre_filter_fn) def dispatch( self, dataplane: "Dataplane", dispatch_batch_size: int = 100, # 6.4 GB worth of chunks - transfer_config: Optional[TransferConfig] = field(init=False, default_factory=lambda: TransferConfig()), ) -> Generator[Chunk, None, None]: """Dispatch transfer job to specified gateways. 
@@ -586,12 +608,12 @@ def dispatch( :param dispatch_batch_size: maximum size of the buffer to temporarily store the generators (default: 1000) :type dispatch_batch_size: int """ - chunker = Chunker(self.src_iface, self.dst_ifaces, transfer_config) - transfer_pair_generator = self.gen_transfer_pairs(chunker) # returns TransferPair objects - gen_transfer_list = chunker.tail_generator(transfer_pair_generator, self.transfer_list) - chunks = chunker.chunk(gen_transfer_list) - batches = chunker.batch_generator( - chunker.prefetch_generator(chunks, buffer_size=dispatch_batch_size * 32), batch_size=dispatch_batch_size + # chunker = Chunker(self.src_iface, self.dst_ifaces, transfer_config) + transfer_pair_generator = self.gen_transfer_pairs() # returns TransferPair objects + gen_transfer_list = self.chunker.tail_generator(transfer_pair_generator, self.transfer_list) + chunks = self.chunker.chunk(gen_transfer_list) + batches = self.chunker.batch_generator( + self.chunker.prefetch_generator(chunks, buffer_size=dispatch_batch_size * 32), batch_size=dispatch_batch_size ) # dispatch chunk requests @@ -657,8 +679,8 @@ def dispatch( yield from chunk_batch # copy new multipart transfers to the multipart transfer list - updated_len = len(chunker.multipart_upload_requests) - self.multipart_transfer_list.extend(chunker.multipart_upload_requests[n_multiparts:updated_len]) + updated_len = len(self.chunker.multipart_upload_requests) + self.multipart_transfer_list.extend(self.chunker.multipart_upload_requests[n_multiparts:updated_len]) n_multiparts = updated_len def finalize(self): @@ -739,11 +761,12 @@ def __init__( dst_paths: List[str] or str, recursive: bool = False, requester_pays: bool = False, + transfer_config: Optional[TransferConfig] = None, uuid: str = field(init=False, default_factory=lambda: str(uuid.uuid4())), num_chunks: int = 10, chunk_size_bytes: int = 1024, ): - super().__init__(src_path, dst_paths, recursive, requester_pays, uuid) + super().__init__(src_path, dst_paths, recursive, requester_pays, transfer_config, uuid) self.num_chunks = num_chunks self.chunk_size_bytes = chunk_size_bytes @@ -757,9 +780,10 @@ def __init__( src_path: str, dst_paths: List[str] or str, requester_pays: bool = False, + transfer_config: Optional[TransferConfig] = None, uuid: str = field(init=False, default_factory=lambda: str(uuid.uuid4())), ): - super().__init__(src_path, dst_paths, True, requester_pays, uuid) + super().__init__(src_path, dst_paths, True, requester_pays, transfer_config, uuid) self.transfer_list = [] self.multipart_transfer_list = [] @@ -770,17 +794,13 @@ def __init__( def gen_transfer_pairs( self, - chunker: Optional[Chunker] = None, - transfer_config: Optional[TransferConfig] = field(init=False, default_factory=lambda: TransferConfig()), ) -> Generator[TransferPair, None, None]: """Generate transfer pairs for the transfer job. 
:param chunker: chunker that makes the chunk requests :type chunker: Chunker """ - if chunker is None: # used for external access to transfer pair list - chunker = Chunker(self.src_iface, self.dst_ifaces, transfer_config) - transfer_pair_gen = chunker.transfer_pair_generator(self.src_prefix, self.dst_prefixes, self.recursive, self._pre_filter_fn) + transfer_pair_gen = self.chunker.transfer_pair_generator(self.src_prefix, self.dst_prefixes, self.recursive, self._pre_filter_fn) # only single destination supported assert len(self.dst_ifaces) == 1, "Only single destination supported for sync job" diff --git a/skyplane/cli/cli.py b/skyplane/cli/cli.py index eb1d38f5d..a44f374db 100644 --- a/skyplane/cli/cli.py +++ b/skyplane/cli/cli.py @@ -44,7 +44,7 @@ def deprovision( instances = query_instances() if filter_client_id: instances = [instance for instance in instances if instance.tags().get("skyplaneclientid") == filter_client_id] - + print(instances) if instances: typer.secho(f"Deprovisioning {len(instances)} instances", fg="yellow", bold=True) do_parallel(lambda instance: instance.terminate_instance(), instances, desc="Deprovisioning", spinner=True, spinner_persist=True) diff --git a/skyplane/cli/cli_cloud.py b/skyplane/cli/cli_cloud.py index 71f8be7e7..084e62faa 100644 --- a/skyplane/cli/cli_cloud.py +++ b/skyplane/cli/cli_cloud.py @@ -3,6 +3,7 @@ """ import json +import uuid import subprocess import time from collections import defaultdict @@ -260,6 +261,7 @@ def azure_check( role_idx = [i for i, r in enumerate(roles) if r["scope"] == f"/subscriptions/{account_subscription}"] check_assert(len(role_idx) >= 1, "Skyplane storage account role assigned to UMI") role_names = [roles[i]["roleDefinitionName"] for i in role_idx] + print(role_names) rprint(f"[bright_black]Skyplane storage account roles: {role_names}[/bright_black]") check_assert("Storage Blob Data Contributor" in role_names, "Skyplane storage account has Blob Data Contributor role assigned to UMI") check_assert("Storage Account Contributor" in role_names, "Skyplane storage account has Account Contributor role assigned to UMI") @@ -281,6 +283,30 @@ def azure_check( iface = AzureBlobInterface(account, container) print(iface.container_client.get_container_properties()) + # check if writeable + rprint(f"\n{hline}\n[bold]Checking Skyplane AzureBlobInterface write access[/bold]\n{hline}") + import tempfile + import random + import string + + def generate_random_string(length): + """Generate a random string of given length""" + letters = string.ascii_letters + return "".join(random.choice(letters) for _ in range(length)) + + def create_temp_file(size): + """Create a temporary file with random data""" + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + temp_file_path = temp_file.name + random_data = generate_random_string(size) + temp_file.write(random_data.encode()) + return temp_file_path + + tmp_file = create_temp_file(1024) + tmp_object = f"skyplane-{uuid.uuid4()}" + iface.upload_object(tmp_file, tmp_object) + iface.delete_objects([tmp_object]) + @app.command() def gcp_check( diff --git a/skyplane/cli/cli_init.py b/skyplane/cli/cli_init.py index e946de451..59cf95951 100644 --- a/skyplane/cli/cli_init.py +++ b/skyplane/cli/cli_init.py @@ -512,7 +512,7 @@ def init( # load AWS config if not (reinit_azure or reinit_gcp or reinit_ibm): - typer.secho("\n(1) configuring AWS:", fg="yellow", bold=True) + typer.secho("\n(1) Configuring AWS:", fg="yellow", bold=True) if not disable_config_aws: cloud_config = load_aws_config(cloud_config, 
non_interactive=non_interactive) diff --git a/skyplane/cli/cli_transfer.py b/skyplane/cli/cli_transfer.py index a1b30d963..c49edef31 100644 --- a/skyplane/cli/cli_transfer.py +++ b/skyplane/cli/cli_transfer.py @@ -7,6 +7,7 @@ import typer from rich.progress import Progress, TextColumn, SpinnerColumn +from rich import print as rprint import skyplane from skyplane.api.config import TransferConfig, AWSConfig, GCPConfig, AzureConfig, IBMCloudConfig @@ -292,8 +293,9 @@ def estimate_small_transfer(self, job: TransferJob, size_threshold_bytes: float, def force_deprovision(dp: skyplane.Dataplane): + rprint(f"\n:x: [bold red]Force deprovisioning dataplane[/bold red]") s = signal.signal(signal.SIGINT, signal.SIG_IGN) - dp.deprovision() + dp.deprovision(spinner=True) signal.signal(signal.SIGINT, s) @@ -371,12 +373,12 @@ def run_transfer( elif cloud_config.get_flag("native_cmd_enabled"): # fallback option: transfer is too small if cli.args["cmd"] == "cp": - job = CopyJob(src, [dst], recursive=recursive) # TODO: rever to using pipeline + job = CopyJob(src, [dst], recursive=recursive, transfer_config=cli.transfer_config) # TODO: revert to using pipeline if cli.estimate_small_transfer(job, cloud_config.get_flag("native_cmd_threshold_gb") * GB): small_transfer_status = cli.transfer_cp_small(src, dst, recursive) return 0 if small_transfer_status else 1 else: - job = SyncJob(src, [dst]) + job = SyncJob(src, [dst], transfer_config=cli.transfer_config) if cli.estimate_small_transfer(job, cloud_config.get_flag("native_cmd_threshold_gb") * GB): small_transfer_status = cli.transfer_sync_small(src, dst) return 0 if small_transfer_status else 1 diff --git a/skyplane/compute/aws/aws_server.py b/skyplane/compute/aws/aws_server.py index a2c1d2a22..3a087e14f 100644 --- a/skyplane/compute/aws/aws_server.py +++ b/skyplane/compute/aws/aws_server.py @@ -24,8 +24,8 @@ class AWSServer(Server): def __init__(self, region_tag, instance_id, log_dir=None): super().__init__(region_tag, log_dir=log_dir) assert self.region_tag.split(":")[0] == "aws" - self.auth = AWSAuthentication() - self.key_manager = AWSKeyManager(self.auth) + assert "aws" in self.auth, f"AWS Server created but not authenticated with AWS" + self.key_manager = AWSKeyManager(self.auth["aws"]) self.aws_region = self.region_tag.split(":")[1] self.instance_id = instance_id @@ -33,7 +33,7 @@ def __init__(self, region_tag, instance_id, log_dir=None): @functools.lru_cache(maxsize=None) def login_name(self) -> str: # update the login name according to AMI - ec2 = self.auth.get_boto3_resource("ec2", self.aws_region) + ec2 = self.auth["aws"].get_boto3_resource("ec2", self.aws_region) ec2client = ec2.meta.client image_info = ec2client.describe_images(ImageIds=[ec2.Instance(self.instance_id).image_id]) if [r["Name"] for r in image_info["Images"]][0].split("/")[0] == "ubuntu": @@ -52,7 +52,7 @@ def uuid(self): return f"{self.region_tag}:{self.instance_id}" def get_boto3_instance_resource(self): - ec2 = self.auth.get_boto3_resource("ec2", self.aws_region) + ec2 = self.auth["aws"].get_boto3_resource("ec2", self.aws_region) return ec2.Instance(self.instance_id) @ignore_lru_cache() @@ -101,7 +101,7 @@ def __repr__(self): return f"AWSServer(region_tag={self.region_tag}, instance_id={self.instance_id})" def terminate_instance_impl(self): - iam = self.auth.get_boto3_resource("iam") + iam = self.auth["aws"].get_boto3_resource("iam") # get instance profile name that is associated with this instance profile = self.get_boto3_instance_resource().iam_instance_profile diff --git
a/skyplane/compute/azure/azure_server.py b/skyplane/compute/azure/azure_server.py index 4f995c50d..9538a826f 100644 --- a/skyplane/compute/azure/azure_server.py +++ b/skyplane/compute/azure/azure_server.py @@ -20,7 +20,6 @@ class AzureServer(Server): resource_group_location = "westus2" def __init__(self, name: str, key_root: PathLike = key_root / "azure", log_dir=None, ssh_private_key=None, assume_exists=True): - self.auth = AzureAuthentication() self.name = name self.location = None @@ -32,6 +31,7 @@ def __init__(self, name: str, key_root: PathLike = key_root / "azure", log_dir=N region_tag = "azure:UNKNOWN" super().__init__(region_tag, log_dir=log_dir) + assert "azure" in self.auth, f"Azure Server created but not authenticated with Azure" key_root = Path(key_root) key_root.mkdir(parents=True, exist_ok=True) @@ -85,7 +85,7 @@ def nic_name(name): return AzureServer.vm_name(name) + "-nic" def get_virtual_machine(self): - compute_client = self.auth.get_compute_client() + compute_client = self.auth["azure"].get_compute_client() vm = compute_client.virtual_machines.get(AzureServer.resource_group_name, AzureServer.vm_name(self.name)) # Sanity checks @@ -106,7 +106,7 @@ def uuid(self): return f"{self.region_tag}:{self.name}" def instance_state(self) -> ServerState: - compute_client = self.auth.get_compute_client() + compute_client = self.auth["azure"].get_compute_client() vm_instance_view = compute_client.virtual_machines.instance_view(AzureServer.resource_group_name, AzureServer.vm_name(self.name)) statuses = vm_instance_view.statuses for status in statuses: @@ -116,7 +116,7 @@ def instance_state(self) -> ServerState: @ignore_lru_cache() def public_ip(self): - network_client = self.auth.get_network_client() + network_client = self.auth["azure"].get_network_client() public_ip = network_client.public_ip_addresses.get(AzureServer.resource_group_name, AzureServer.ip_name(self.name)) # Sanity checks @@ -147,10 +147,10 @@ def network_tier(self): return "PREMIUM" def terminate_instance_impl(self): - compute_client = self.auth.get_compute_client() - network_client = self.auth.get_network_client() + compute_client = self.auth["azure"].get_compute_client() + network_client = self.auth["azure"].get_network_client() - self.auth.get_authorization_client() + self.auth["azure"].get_authorization_client() self.get_virtual_machine() vm_poller = compute_client.virtual_machines.begin_delete(AzureServer.resource_group_name, self.vm_name(self.name)) diff --git a/skyplane/compute/gcp/gcp_server.py b/skyplane/compute/gcp/gcp_server.py index ed4029d74..397e7dd87 100644 --- a/skyplane/compute/gcp/gcp_server.py +++ b/skyplane/compute/gcp/gcp_server.py @@ -19,7 +19,7 @@ def __init__(self, region_tag: str, instance_name: str, key_root: PathLike = key super().__init__(region_tag, log_dir=log_dir) assert self.region_tag.split(":")[0] == "gcp", f"Region name doesn't match pattern gcp: {self.region_tag}" self.gcp_region = self.region_tag.split(":")[1] - self.auth = GCPAuthentication() + assert "gcp" in self.auth, f"GCP Server created but not authenticated with GCP" self.gcp_instance_name = instance_name key_root = Path(key_root) key_root.mkdir(parents=True, exist_ok=True) @@ -33,7 +33,7 @@ def uuid(self): @lru_cache(maxsize=1) def get_gcp_instance(self): - instances = self.auth.get_gcp_instances(self.gcp_region) + instances = self.auth["gcp"].get_gcp_instances(self.gcp_region) if "items" in instances: for i in instances["items"]: if i["name"] == self.gcp_instance_name: @@ -79,8 +79,8 @@ def __repr__(self): return 
f"GCPServer(region_tag={self.region_tag}, instance_name={self.gcp_instance_name})" def terminate_instance_impl(self): - self.auth.get_gcp_client().instances().delete( - project=self.auth.project_id, zone=self.gcp_region, instance=self.instance_name() + self.auth["gcp"].get_gcp_client().instances().delete( + project=self.auth["gcp"].project_id, zone=self.gcp_region, instance=self.instance_name() ).execute() def get_ssh_client_impl(self, uname="skyplane", ssh_key_password="skyplane"): diff --git a/skyplane/planner/planner.py b/skyplane/planner/planner.py index 6acedbeaf..8d9ed87c7 100644 --- a/skyplane/planner/planner.py +++ b/skyplane/planner/planner.py @@ -33,6 +33,7 @@ class Planner: def __init__(self, transfer_config: TransferConfig): self.transfer_config = transfer_config self.config = SkyplaneConfig.load_config(config_path) + self.n_instances = self.config.get_flag("max_instances") # Loading the quota information, add ibm cloud when it is supported self.quota_limits = {} @@ -116,11 +117,14 @@ def _calculate_vm_types(self, region_tag: str) -> Optional[Tuple[str, int]]: cloud_provider=cloud_provider, region=region, spot=getattr(self.transfer_config, f"{cloud_provider}_use_spot_instances") ) + config_vm_type = getattr(self.transfer_config, f"{cloud_provider}_instance_class") + # No quota limits (quota limits weren't initialized properly during skyplane init) if quota_limit is None: - return None + logger.warning(f"Quota limit file not found for {region_tag}") + # return default instance type and number of instances + return config_vm_type, self.n_instances - config_vm_type = getattr(self.transfer_config, f"{cloud_provider}_instance_class") config_vcpus = self._vm_to_vcpus(cloud_provider, config_vm_type) if config_vcpus <= quota_limit: return config_vm_type, quota_limit // config_vcpus @@ -144,9 +148,7 @@ def _calculate_vm_types(self, region_tag: str) -> Optional[Tuple[str, int]]: ) return (vm_type, n_instances) - def _get_vm_type_and_instances( - self, src_region_tag: Optional[str] = None, dst_region_tags: Optional[List[str]] = None - ) -> Tuple[Dict[str, str], int]: + def _get_vm_type_and_instances(self, src_region_tag: str, dst_region_tags: List[str]) -> Tuple[Dict[str, str], int]: """Dynamically calculates the vm type each region can use (both the source region and all destination regions) based on their quota limits and calculates the number of vms to launch in all regions by conservatively taking the minimum of all regions to stay consistent. 
@@ -156,10 +158,16 @@ def _get_vm_type_and_instances( :param dst_region_tags: a list of the destination region tags (defualt: None) :type dst_region_tags: Optional[List[str]] """ + # One of them has to provided - assert src_region_tag is not None or dst_region_tags is not None, "There needs to be at least one source or destination" - src_tags = [src_region_tag] if src_region_tag is not None else [] - dst_tags = dst_region_tags or [] + # assert src_region_tag is not None or dst_region_tags is not None, "There needs to be at least one source or destination" + src_tags = [src_region_tag] # if src_region_tag is not None else [] + dst_tags = dst_region_tags # or [] + + assert len(src_region_tag.split(":")) == 2, f"Source region tag {src_region_tag} must be in the form of `cloud_provider:region`" + assert ( + len(dst_region_tags[0].split(":")) == 2 + ), f"Destination region tag {dst_region_tags} must be in the form of `cloud_provider:region`" # do_parallel returns tuples of (region_tag, (vm_type, n_instances)) vm_info = do_parallel(self._calculate_vm_types, src_tags + dst_tags) @@ -184,6 +192,12 @@ def plan(self, jobs: List[TransferJob]) -> TopologyPlan: src_region_tag = jobs[0].src_iface.region_tag() dst_region_tag = jobs[0].dst_ifaces[0].region_tag() + + assert len(src_region_tag.split(":")) == 2, f"Source region tag {src_region_tag} must be in the form of `cloud_provider:region`" + assert ( + len(dst_region_tag.split(":")) == 2 + ), f"Destination region tag {dst_region_tag} must be in the form of `cloud_provider:region`" + # jobs must have same sources and destinations for job in jobs[1:]: assert job.src_iface.region_tag() == src_region_tag, "All jobs must have same source region" @@ -242,9 +256,9 @@ def plan(self, jobs: List[TransferJob]) -> TopologyPlan: class MulticastDirectPlanner(Planner): def __init__(self, n_instances: int, n_connections: int, transfer_config: TransferConfig): + super().__init__(transfer_config) self.n_instances = n_instances self.n_connections = n_connections - super().__init__(transfer_config) def plan(self, jobs: List[TransferJob]) -> TopologyPlan: src_region_tag = jobs[0].src_iface.region_tag() @@ -368,7 +382,7 @@ def plan(self, jobs: List[TransferJob]) -> TopologyPlan: plan = TopologyPlan(src_region_tag=src_region_tag, dest_region_tags=dst_region_tags) # Dynammically calculate n_instances based on quota limits - vm_types, n_instances = self._get_vm_type_and_instances(src_region_tag=src_region_tag) + vm_types, n_instances = self._get_vm_type_and_instances(src_region_tag, dst_region_tags) # TODO: support on-sided transfers but not requiring VMs to be created in source/destination regions for i in range(n_instances): @@ -429,7 +443,7 @@ def plan(self, jobs: List[TransferJob]) -> TopologyPlan: plan = TopologyPlan(src_region_tag=src_region_tag, dest_region_tags=dst_region_tags) # Dynammically calculate n_instances based on quota limits - vm_types, n_instances = self._get_vm_type_and_instances(dst_region_tags=dst_region_tags) + vm_types, n_instances = self._get_vm_type_and_instances(src_region_tag, dst_region_tags) # TODO: support on-sided transfers but not requiring VMs to be created in source/destination regions for i in range(n_instances): diff --git a/tests/gateway/test_gateway_obj_store.py b/tests/gateway/test_gateway_obj_store.py index 81f46ac8c..5d7c7ff9b 100644 --- a/tests/gateway/test_gateway_obj_store.py +++ b/tests/gateway/test_gateway_obj_store.py @@ -63,7 +63,8 @@ def check_container_running(container_name): def run(gateway_docker_image: str, 
restart_gateways: bool): """Run the gateway docker image locally""" - job = CopyJob("s3://feature-store-datasets/yahoo/processed_yahoo/A1/", ["gs://38046a6749df436886491a95cacdebb8/yahoo/"], recursive=True) + # job = CopyJob("s3://feature-store-datasets/yahoo/processed_yahoo/A1/", ["gs://38046a6749df436886491a95cacdebb8/yahoo/"], recursive=True) + job = CopyJob("gs://skyplane-test-bucket/files_100000_size_4_mb/", "s3://integrationus-west-2-4450f073/", recursive=True) topology = MulticastDirectPlanner(1, 64, TransferConfig()).plan([job]) print([g.region_tag for g in topology.get_gateways()]) diff --git a/tests/interface_util.py b/tests/interface_util.py index f6d1f85d9..5b9084c85 100644 --- a/tests/interface_util.py +++ b/tests/interface_util.py @@ -12,7 +12,10 @@ def interface_test_framework(region, bucket, multipart: bool, test_delete_bucket: bool = False, file_size_mb: int = 1): interface = ObjectStoreInterface.create(region, bucket) interface.create_bucket(region.split(":")[1]) - time.sleep(5) + # poll until the bucket actually exists instead of sleeping for a fixed interval + while not interface.bucket_exists(): + print("waiting for bucket to exist") + time.sleep(1) # generate file and upload obj_name = f"test_{uuid.uuid4()}.txt" @@ -63,7 +66,9 @@ def interface_test_framework(region, bucket, multipart: bool, test_delete_bucket assert obj_name in objs[0].key, f"{objs[0].key} != {obj_name}" assert objs[0].size == file_size_mb * MB, f"{objs[0].size} != {file_size_mb * MB}" - # interface.delete_objects([obj_name]) - # if test_delete_bucket: - # interface.delete_bucket() + # delete bucket + if test_delete_bucket: + interface.delete_bucket() + assert not interface.bucket_exists(), "Bucket should not exist" + return True