diff --git a/python/ray/air/tests/test_new_dataset_config.py b/python/ray/air/tests/test_new_dataset_config.py index b4c7ea091790..f475e2699e91 100644 --- a/python/ray/air/tests/test_new_dataset_config.py +++ b/python/ray/air/tests/test_new_dataset_config.py @@ -2,6 +2,7 @@ import random import pytest +from unittest.mock import MagicMock import ray from ray import train @@ -168,6 +169,32 @@ def test_configure_execution_options_carryover_context(ray_start_4_cpus): assert ingest_options.verbose_progress is True +@pytest.mark.parametrize("enable_locality", [True, False]) +def test_configure_locality(enable_locality): + options = DataConfig.default_ingest_options() + options.locality_with_output = enable_locality + data_config = DataConfig(execution_options=options) + + mock_ds = MagicMock() + mock_ds.streaming_split = MagicMock() + mock_ds.copy = MagicMock(return_value=mock_ds) + world_size = 2 + worker_handles = [MagicMock() for _ in range(world_size)] + worker_node_ids = ["node" + str(i) for i in range(world_size)] + data_config.configure( + datasets={"train": mock_ds}, + world_size=world_size, + worker_handles=worker_handles, + worker_node_ids=worker_node_ids, + ) + mock_ds.streaming_split.assert_called_once() + mock_ds.streaming_split.assert_called_with( + world_size, + equal=True, + locality_hints=worker_node_ids if enable_locality else None, + ) + + class CustomConfig(DataConfig): def __init__(self): pass diff --git a/python/ray/train/_internal/data_config.py b/python/ray/train/_internal/data_config.py index fb85dce793ea..2c26195c8a8d 100644 --- a/python/ray/train/_internal/data_config.py +++ b/python/ray/train/_internal/data_config.py @@ -91,6 +91,9 @@ def configure( else: datasets_to_split = set(self._datasets_to_split) + locality_hints = ( + worker_node_ids if self._execution_options.locality_with_output else None + ) for name, ds in datasets.items(): ds = ds.copy(ds) ds.context.execution_options = copy.deepcopy(self._execution_options) @@ -107,7 +110,7 @@ def configure( if name in datasets_to_split: for i, split in enumerate( ds.streaming_split( - world_size, equal=True, locality_hints=worker_node_ids + world_size, equal=True, locality_hints=locality_hints ) ): output[i][name] = split