Skip to content

Commit

Permalink
[air] Fix: Gracefully fail in file stat lookup race conditions (ray-p…
Browse files Browse the repository at this point in the history
…roject#27169)

See ray-project#27168, this fixes a race condition where a file is os.lstat'ed while being deleted.

Signed-off-by: Kai Fricke <[email protected]>
  • Loading branch information
krfricke authored and Kai Fricke committed Jul 28, 2022
1 parent 0fa2e1b commit edb1527
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
31 changes: 31 additions & 0 deletions python/ray/air/tests/test_remote_storage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os
import threading
from unittest.mock import patch

import pytest
import shutil
import tempfile
Expand All @@ -7,6 +10,7 @@
upload_to_uri,
download_from_uri,
)
from ray.tune.utils.file_transfer import _get_recursive_files_and_stats


@pytest.fixture
Expand Down Expand Up @@ -126,6 +130,33 @@ def test_upload_exclude_multimatch(temp_data_dirs):
assert_file(False, tmp_target, "subdir_exclude/something/somewhere.txt")


def test_get_recursive_files_race_con(temp_data_dirs):
tmp_source, _ = temp_data_dirs

def run(event):
lst = os.lstat

def waiting_lstat(*args, **kwargs):
event.wait()
return lst(*args, **kwargs)

with patch("os.lstat", wraps=waiting_lstat):
_get_recursive_files_and_stats(tmp_source)

event = threading.Event()

get_thread = threading.Thread(target=run, args=(event,))
get_thread.start()

os.remove(os.path.join(tmp_source, "level0.txt"))
event.set()

get_thread.join()

assert_file(False, tmp_source, "level0.txt")
assert_file(True, tmp_source, "level0_exclude.txt")


if __name__ == "__main__":
import sys

Expand Down
11 changes: 8 additions & 3 deletions python/ray/tune/utils/file_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,14 @@ def _get_recursive_files_and_stats(path: str) -> Dict[str, Tuple[float, int]]:
for root, dirs, files in os.walk(path, topdown=False):
rel_root = os.path.relpath(root, path)
for file in files:
key = os.path.join(rel_root, file)
stat = os.lstat(os.path.join(path, key))
files_stats[key] = stat.st_mtime, stat.st_size
try:
key = os.path.join(rel_root, file)
stat = os.lstat(os.path.join(path, key))
files_stats[key] = stat.st_mtime, stat.st_size
except FileNotFoundError:
# Race condition: If a file is deleted while executing this
# method, just continue and don't include the file in the stats
pass

return files_stats

Expand Down

0 comments on commit edb1527

Please sign in to comment.