Fix gateway server bug and interrupt handling for chunk dispatching (#868)

This fixes two bugs:
* During chunk dispatching for CLI transfers, Ctrl-C could not be used to exit the transfer because the dispatch threads were not properly stopped. This is fixed by adding an exit signal to the tracker and stopping the multipart threads when the tracker exits (sketched below).
* Previously, `self.auth` in the `Server` instance was modified to contain credentials for all initialized clouds, but this change was not propagated to child classes.
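A minimal standalone sketch of the interrupt-handling pattern this commit introduces, assuming a SIGINT handler that sets a threading.Event which the dispatch loop polls (`dispatch_chunks` and its chunk range are illustrative stand-ins, not Skyplane's API):

import signal
import time
from threading import Event

exit_flag = Event()

def signal_handler(signum, frame):
    # Ctrl-C lands here instead of raising KeyboardInterrupt mid-dispatch.
    exit_flag.set()

signal.signal(signal.SIGINT, signal_handler)

def dispatch_chunks(chunks):
    # Hypothetical stand-in for the tracker's dispatch loop.
    for chunk in chunks:
        if exit_flag.is_set():
            print("exit flag set, stopping dispatch")
            return
        time.sleep(0.1)  # stand-in for dispatching one chunk
    print("all chunks dispatched")

dispatch_chunks(range(100))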
sarahwooders authored Jun 15, 2023
1 parent fecd572 commit 2696494
Showing 16 changed files with 180 additions and 95 deletions.
20 changes: 10 additions & 10 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -62,6 +62,7 @@
 pynacl = { version = "^1.5.0", optional = true }
 pyopenssl = { version = "^22.0.0", optional = true }
 werkzeug = { version = "^2.1.2", optional = true }
 pyarrow = "^10.0.1"
+pytest = "^7.3.2"

 [tool.poetry.extras]
 aws = ["boto3"]
2 changes: 1 addition & 1 deletion scripts/gen_data/gen_many_small.py
@@ -26,4 +26,4 @@ def make_file(data, fname):

 files = [f"{outdir}/{i:08d}.bin" for i in range(args.nfiles)]
 data = np.arange(args.size // 4, dtype=np.int32).tobytes()
-do_parallel(partial(make_file, data), files, desc="Generating files", spinner=True, spinner_persist=True)
+do_parallel(partial(make_file, data), files, desc="Generating files", spinner=True, spinner_persist=True, n=16)
4 changes: 3 additions & 1 deletion skyplane/api/dataplane.py
@@ -246,7 +246,7 @@ def copy_gateway_logs(self):
         # copy logs from all gateways in parallel
         do_parallel(self.copy_gateway_log, self.bound_nodes.values(), n=-1)

-    def deprovision(self, max_jobs: int = 64, spinner: bool = False):
+    def deprovision(self, max_jobs: int = 64, spinner: bool = True):
         """
         Deprovision the remote gateways
@@ -267,6 +267,8 @@ def deprovision(self, max_jobs: int = 64, spinner: bool = False):
                 for task in self.pending_transfers:
                     logger.fs.warning(f"Before deprovisioning, waiting for jobs to finish: {list(task.jobs.keys())}")
                     task.join()
+                for thread in threading.enumerate():
+                    assert "_run_multipart_chunk_thread" not in thread.name, f"thread {thread.name} is still running"
             except KeyboardInterrupt:
                 logger.warning("Interrupted while waiting for transfers to finish, deprovisioning anyway.")
                 raise
22 changes: 18 additions & 4 deletions skyplane/api/tracker.py
@@ -1,10 +1,12 @@
 import functools
+import signal
+
 from pprint import pprint
 import json
 import time
 from abc import ABC
 from datetime import datetime
-from threading import Thread
+from threading import Thread, Event

 import urllib3
 from typing import TYPE_CHECKING, Dict, List, Optional, Set
@@ -97,6 +99,14 @@ def __init__(self, dataplane, jobs: List["TransferJob"], transfer_config: Transf
         self.jobs = {job.uuid: job for job in jobs}
         self.transfer_config = transfer_config

+        # exit handling
+        self.exit_flag = Event()
+
+        def signal_handler(signal, frame):
+            self.exit_flag.set()
+
+        signal.signal(signal.SIGINT, signal_handler)
+
         if hooks is None:
             self.hooks = EmptyTransferHook()
         else:
@@ -138,16 +148,20 @@ def run(self):
         session_start_timestamp_ms = int(time.time() * 1000)
         try:
             # pre-dispatch chunks to begin pre-buffering chunks
-            chunk_streams = {
-                job_uuid: job.dispatch(self.dataplane, transfer_config=self.transfer_config) for job_uuid, job in self.jobs.items()
-            }
+            chunk_streams = {job_uuid: job.dispatch(self.dataplane) for job_uuid, job in self.jobs.items()}
             for job_uuid, job in self.jobs.items():
                 logger.fs.debug(f"[TransferProgressTracker] Dispatching job {job.uuid}")
                 self.job_chunk_requests[job_uuid] = {}
                 self.job_pending_chunk_ids[job_uuid] = {region: set() for region in self.dataplane.topology.dest_region_tags}
                 self.job_complete_chunk_ids[job_uuid] = {region: set() for region in self.dataplane.topology.dest_region_tags}

                 for chunk in chunk_streams[job_uuid]:
+                    if self.exit_flag.is_set():
+                        logger.fs.debug(f"[TransferProgressTracker] Exiting due to signal")
+                        self.hooks.on_dispatch_end()
+                        self.hooks.on_transfer_end()
+                        job.stop()  # stop threads in chunk stream
+                        return
                     chunks_dispatched = [chunk]
                     self.job_chunk_requests[job_uuid][chunk.chunk_id] = chunk
                     self.hooks.on_chunk_dispatched(chunks_dispatched)
