[Train] Fix prepare_data_loader with enable_reproducibility (ray-project#30266)

Calling train.torch.enable_reproducibility before train.torch.prepare_data_loader causes an exception to be raised if num_workers in the DataLoader is greater than 0 and worker_init_fn in the DataLoader is not set. The exception is caused by worker_init_fn, which defaults to None, being used as a callable inside seeded_worker_init_fn. This case was untested.

This PR fixes this oversight and ensures that this is tested in CI (and also removes a duplicate test in the process).
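
To illustrate the failure mode described above, here is a minimal sketch (not taken from the diff; the helper name make_seeded_worker_init_fn is hypothetical) of the wrapper and the guard that fixes it:

import random

import numpy as np
import torch


def make_seeded_worker_init_fn(worker_init_fn):
    # worker_init_fn defaults to None on a DataLoader; the pre-fix wrapper
    # called it unconditionally, so with num_workers > 0 every worker process
    # raised "TypeError: 'NoneType' object is not callable".
    def wrapper(worker_id):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)
        if worker_init_fn:  # the fix: only call a user-supplied init fn
            worker_init_fn(worker_id)

    return wrapper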

Signed-off-by: Antoni Baum <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
Yard1 authored and WeichenXu123 committed Dec 19, 2022
1 parent 90755b3 commit 87e4663
Showing 2 changed files with 17 additions and 10 deletions.
10 changes: 7 additions & 3 deletions python/ray/train/tests/test_gpu.py
@@ -232,8 +232,8 @@ def train_fn():
     trainer.fit()


-@pytest.mark.parametrize("use_gpu", (False, True))
-def test_enable_reproducibility(ray_start_4_cpus_2_gpus, use_gpu):
+@pytest.mark.parametrize("data_loader_num_workers", (0, 2))
+def test_enable_reproducibility(ray_start_4_cpus_2_gpus, data_loader_num_workers):
     # NOTE: Reproducible results aren't guaranteed between seeded executions, even with
     # identical hardware and software dependencies. This test should be okay given that
     # it only runs for two epochs on a small dataset.
@@ -251,7 +251,11 @@ def train_func():
             torch.randn(dataset_length, 3, 32, 32),
             torch.randint(low=0, high=1000, size=(dataset_length,)),
         )
-        dataloader = torch.utils.data.DataLoader(dataset, batch_size=64)
+
+        # num_workers > 0 tests for https://github.com/ray-project/ray/issues/30247
+        dataloader = torch.utils.data.DataLoader(
+            dataset, batch_size=64, num_workers=data_loader_num_workers
+        )
         dataloader = train.torch.prepare_data_loader(dataloader)

         optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
17 changes: 10 additions & 7 deletions python/ray/train/torch/train_loop_utils.py
@@ -6,7 +6,7 @@
 import collections
 from distutils.version import LooseVersion

-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Callable

 import ray
 from ray.air import session
@@ -415,19 +415,22 @@ def with_sampler(loader):
                 # shuffling is enabled by checking the default sampler type.
                 shuffle = not isinstance(loader.sampler, SequentialSampler)

-                def seeded_worker_init_fn(worker_init_fn):
-                    def wrapper(worker_id):
+                def seeded_worker_init_fn(
+                    worker_init_fn: Optional[Callable[[int], None]]
+                ):
+                    def wrapper(worker_id: int):
                         worker_seed = torch.initial_seed() % 2**32
                         np.random.seed(worker_seed)
                         random.seed(worker_seed)
-                        worker_init_fn(worker_id)
+                        if worker_init_fn:
+                            worker_init_fn(worker_id)

                     return wrapper

-                worker_init_fn = loader.worker_init_fn
-                generator = loader.generator
+                worker_init_fn: Optional[Callable[[int], None]] = loader.worker_init_fn
+                generator: Optional[torch.Generator] = loader.generator
                 if self._seed is not None:
-                    worker_init_fn = seeded_worker_init_fn(loader.worker_init_fn)
+                    worker_init_fn = seeded_worker_init_fn(worker_init_fn)
                     generator = torch.Generator()
                     generator.manual_seed(self._seed)
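
For context, the wrapper above follows the standard PyTorch recipe for reproducible DataLoader workers: pair a seeded torch.Generator with a worker_init_fn that reseeds numpy and random in each worker process. A standalone sketch with example values (an illustration, not the Ray implementation):

import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # Derive a per-worker seed from the torch seed assigned to this worker.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


dataset = TensorDataset(torch.randn(128, 4), torch.randint(0, 10, (128,)))
generator = torch.Generator()
generator.manual_seed(1234)
loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=2,               # worker processes: the case the fix covers
    worker_init_fn=seed_worker,  # reseeds numpy/random in every worker
    generator=generator,
)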
