diff --git a/python/ray/train/tests/test_torch_trainer.py b/python/ray/train/tests/test_torch_trainer.py
index cecbbd2c8baf..a2e2dc306f1a 100644
--- a/python/ray/train/tests/test_torch_trainer.py
+++ b/python/ray/train/tests/test_torch_trainer.py
@@ -1,4 +1,6 @@
 import contextlib
+import uuid
+
 import pytest
 import time
 import torch
@@ -11,7 +13,7 @@
 from ray.train.batch_predictor import BatchPredictor
 from ray.train.constants import DISABLE_LAZY_CHECKPOINTING_ENV
 from ray.train.torch import TorchPredictor, TorchTrainer
-from ray.air.config import ScalingConfig
+from ray.air.config import RunConfig, ScalingConfig
 from ray.train.torch import TorchConfig
 from ray.train.trainer import TrainingFailedError
 import ray.train as train
@@ -258,7 +260,6 @@ def test_tune_torch_get_device_gpu(num_gpus_per_worker):
     (for example when used with Tune).
     """
     from ray.air.config import ScalingConfig
-    import time
 
     num_samples = 2
     num_workers = 2
@@ -269,6 +270,7 @@ def test_tune_torch_get_device_gpu(num_gpus_per_worker):
     # Divide by two because of a 2 node cluster.
     gpus_per_node = total_gpus_required // 2
 
+    exception = None
     # Use the same number of cpus per node as gpus per node.
     with ray_start_2_node_cluster(
         num_cpus_per_node=gpus_per_node, num_gpus_per_node=gpus_per_node
@@ -290,12 +292,14 @@ def train_fn():
         @ray.remote(num_cpus=0)
         class TrialActor:
             def __init__(self, warmup_steps):
-                # adding warmup_steps to the config
-                # to avoid the error of checkpoint name conflict
-                time.sleep(2 * warmup_steps)
                 self.trainer = TorchTrainer(
                     train_fn,
                     torch_config=TorchConfig(backend="gloo"),
+                    run_config=RunConfig(
+                        # Use a unique name to avoid using the same
+                        # experiment directory
+                        name=f"test_tune_torch_get_device_gpu_{uuid.uuid4()}"
+                    ),
                     scaling_config=ScalingConfig(
                         num_workers=num_workers,
                         use_gpu=True,
@@ -313,8 +317,15 @@ def __init__(self, warmup_steps):
             def run(self):
                 return self.trainer.fit()
 
-        actors = [TrialActor.remote(1) for _ in range(num_samples)]
-        ray.get([actor.run.remote() for actor in actors])
+        try:
+            actors = [TrialActor.remote(1) for _ in range(num_samples)]
+            ray.get([actor.run.remote() for actor in actors])
+        except Exception as exc:
+            exception = exc
+
+    # Raise exception after Ray cluster has been shutdown to avoid corrupted state
+    if exception:
+        raise exception
 
 
 def test_torch_auto_unwrap(ray_start_4_cpus):
diff --git a/python/ray/tune/execution/trial_runner.py b/python/ray/tune/execution/trial_runner.py
index c3cc20abe0e5..f912fdf6f0c6 100644
--- a/python/ray/tune/execution/trial_runner.py
+++ b/python/ray/tune/execution/trial_runner.py
@@ -1,3 +1,4 @@
+import uuid
 from typing import Any, Dict, List, Optional, Union, Tuple, Set
 
 from datetime import datetime
@@ -364,7 +365,9 @@ def save_to_dir(self, experiment_dir: Optional[str] = None):
             },
         }
 
-        tmp_file_name = os.path.join(experiment_dir, ".tmp_experiment_state")
+        tmp_file_name = os.path.join(
+            experiment_dir, f".tmp_experiment_state_{uuid.uuid4()}"
+        )
 
         with open(tmp_file_name, "w") as f:
             json.dump(runner_state, f, indent=2, cls=TuneFunctionEncoder)