diff --git a/python/ray/train/tests/test_gpu.py b/python/ray/train/tests/test_gpu.py
index feaf383bdad8..1b6ec2c699d3 100644
--- a/python/ray/train/tests/test_gpu.py
+++ b/python/ray/train/tests/test_gpu.py
@@ -232,8 +232,8 @@ def train_fn():
     trainer.fit()
 
 
-@pytest.mark.parametrize("use_gpu", (False, True))
-def test_enable_reproducibility(ray_start_4_cpus_2_gpus, use_gpu):
+@pytest.mark.parametrize("data_loader_num_workers", (0, 2))
+def test_enable_reproducibility(ray_start_4_cpus_2_gpus, data_loader_num_workers):
     # NOTE: Reproducible results aren't guaranteed between seeded executions, even with
     # identical hardware and software dependencies. This test should be okay given that
     # it only runs for two epochs on a small dataset.
@@ -251,7 +251,11 @@ def train_func():
             torch.randn(dataset_length, 3, 32, 32),
             torch.randint(low=0, high=1000, size=(dataset_length,)),
         )
-        dataloader = torch.utils.data.DataLoader(dataset, batch_size=64)
+
+        # num_workers > 0 tests for https://github.com/ray-project/ray/issues/30247
+        dataloader = torch.utils.data.DataLoader(
+            dataset, batch_size=64, num_workers=data_loader_num_workers
+        )
         dataloader = train.torch.prepare_data_loader(dataloader)
 
         optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
diff --git a/python/ray/train/torch/train_loop_utils.py b/python/ray/train/torch/train_loop_utils.py
index 702a81889ba1..d8cacfa4d6fe 100644
--- a/python/ray/train/torch/train_loop_utils.py
+++ b/python/ray/train/torch/train_loop_utils.py
@@ -6,7 +6,7 @@
 import collections
 from distutils.version import LooseVersion
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Callable
 
 import ray
 from ray.air import session
@@ -415,19 +415,22 @@ def with_sampler(loader):
             # shuffling is enabled by checking the default sampler type.
             shuffle = not isinstance(loader.sampler, SequentialSampler)
 
-            def seeded_worker_init_fn(worker_init_fn):
-                def wrapper(worker_id):
+            def seeded_worker_init_fn(
+                worker_init_fn: Optional[Callable[[int], None]]
+            ):
+                def wrapper(worker_id: int):
                     worker_seed = torch.initial_seed() % 2**32
                     np.random.seed(worker_seed)
                     random.seed(worker_seed)
-                    worker_init_fn(worker_id)
+                    if worker_init_fn:
+                        worker_init_fn(worker_id)
 
                 return wrapper
 
-            worker_init_fn = loader.worker_init_fn
-            generator = loader.generator
+            worker_init_fn: Optional[Callable[[int], None]] = loader.worker_init_fn
+            generator: Optional[torch.Generator] = loader.generator
             if self._seed is not None:
-                worker_init_fn = seeded_worker_init_fn(loader.worker_init_fn)
+                worker_init_fn = seeded_worker_init_fn(worker_init_fn)
                 generator = torch.Generator()
                 generator.manual_seed(self._seed)
 
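
Note (not part of the patch): below is a minimal standalone sketch of the failure mode that the new `data_loader_num_workers=2` test case exercises. `DataLoader.worker_init_fn` defaults to `None`, so the old wrapper called `None(worker_id)` inside every worker process whenever `num_workers > 0`; the `if worker_init_fn:` guard added above only invokes a user-supplied hook. Dataset shapes and seeds here are illustrative, not taken from the Ray test.

```python
# Standalone illustration (hypothetical example, not from the Ray repo).
import random
from typing import Callable, Optional

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seeded_worker_init_fn(worker_init_fn: Optional[Callable[[int], None]]):
    """Mirrors the patched wrapper: seed each worker's RNGs, then call
    the user hook only if one was actually provided."""

    def wrapper(worker_id: int):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)
        if worker_init_fn:  # the fix: previously this call was unconditional
            worker_init_fn(worker_id)

    return wrapper


if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))
    generator = torch.Generator()
    generator.manual_seed(0)

    # loader.worker_init_fn is None by default; with the old unconditional
    # call this raised "TypeError: 'NoneType' object is not callable" in
    # each worker process as soon as num_workers > 0.
    loader = DataLoader(
        dataset,
        batch_size=4,
        num_workers=2,
        worker_init_fn=seeded_worker_init_fn(None),
        generator=generator,
    )
    for _batch in loader:
        pass
```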