Revert "Revert "[Train] Add support for handling multiple batch data …
Browse files Browse the repository at this point in the history
…types for prepare_data_loader"" (#26491)

Signed-off-by: Amog Kamsetty <[email protected]>

* Revert "Revert "[Train] Add support for handling multiple batch data types for prepare_data_loader (#26386)" (#26483)"

This reverts commit e6c0403.
amogkam authored Jul 26, 2022
1 parent 5bcaf4f commit 68670e3
Showing 2 changed files with 59 additions and 13 deletions.
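
For context, a rough usage sketch of what the restored change enables: prepare_data_loader() can now move dict-shaped batches to each worker's GPU, not just tuple/list batches of tensors. This sketch uses the same legacy Trainer API the test below exercises; DictDataset and train_func are illustrative names, not part of the commit.

import torch
from torch.utils.data import DataLoader, Dataset

from ray import train
from ray.train import Trainer


class DictDataset(Dataset):
    def __init__(self, size=32):
        self.x = torch.randn(size, 1)
        self.y = torch.randn(size, 1)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        # Each sample is a dict, so default collation yields dict batches.
        return {"x": self.x[index], "y": self.y[index]}


def train_func():
    loader = train.torch.prepare_data_loader(DataLoader(DictDataset(), batch_size=8))
    for batch in loader:
        # With this change, both entries are already on the worker's GPU.
        assert batch["x"].is_cuda and batch["y"].is_cuda


trainer = Trainer("torch", num_workers=2, use_gpu=True)
trainer.start()
trainer.run(train_func)
trainer.shutdown()
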
55 changes: 43 additions & 12 deletions python/ray/train/tests/test_gpu.py
@@ -6,10 +6,6 @@
import pytest
import torch
import torchvision
from test_tune import (
    torch_fashion_mnist,
    tune_tensorflow_mnist,
)
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler

@@ -31,6 +27,10 @@
)
from ray.train.examples.torch_linear_example import LinearDataset
from ray.train.horovod.horovod_trainer import HorovodTrainer
from ray.train.tests.test_tune import (
    torch_fashion_mnist,
    tune_tensorflow_mnist,
)
from ray.train.tensorflow.tensorflow_trainer import TensorflowTrainer
from ray.train.torch import TorchConfig
from ray.train.torch.torch_trainer import TorchTrainer
@@ -65,6 +65,20 @@ def ray_2_node_4_gpu():
    cluster.shutdown()


class LinearDatasetDict(LinearDataset):
    """Modifies the LinearDataset to return a Dict instead of a Tuple."""

    def __getitem__(self, index):
        return {"x": self.x[index, None], "y": self.y[index, None]}


class NonTensorDataset(LinearDataset):
    """Modifies the LinearDataset to also return non-tensor objects."""

    def __getitem__(self, index):
        return {"x": self.x[index, None], "y": 2}


# TODO: Refactor as a backend test.
@pytest.mark.parametrize("num_gpus_per_worker", [0.5, 1])
def test_torch_get_device(ray_start_4_cpus_2_gpus, num_gpus_per_worker):
@@ -149,8 +163,11 @@ def train_fn():


# TODO: Refactor as a backend test.
def test_torch_prepare_dataloader(ray_start_4_cpus_2_gpus):
    data_loader = DataLoader(LinearDataset(a=1, b=2, size=10))
@pytest.mark.parametrize(
    "dataset", (LinearDataset, LinearDatasetDict, NonTensorDataset)
)
def test_torch_prepare_dataloader(ray_start_4_cpus_2_gpus, dataset):
    data_loader = DataLoader(dataset(a=1, b=2, size=10))

    def train_fn():
        wrapped_data_loader = train.torch.prepare_data_loader(data_loader)
@@ -159,12 +176,26 @@ def train_fn():
        assert isinstance(wrapped_data_loader.sampler, DistributedSampler)

        # Make sure you can properly iterate through the DataLoader.
        for batch in wrapped_data_loader:
            X = batch[0]
            y = batch[1]

            # Make sure the data is on the correct device.
            assert X.is_cuda and y.is_cuda
        # Case where the dataset returns a tuple or list from __getitem__.
        if dataset is LinearDataset:
            for batch in wrapped_data_loader:
                x = batch[0]
                y = batch[1]

                # Make sure the data is on the correct device.
                assert x.is_cuda and y.is_cuda
        # Case where the dataset returns a dict from __getitem__.
        elif dataset is LinearDatasetDict:
            for batch in wrapped_data_loader:
                for x, y in zip(batch["x"], batch["y"]):
                    # Make sure the data is on the correct device.
                    assert x.is_cuda and y.is_cuda

        elif dataset is NonTensorDataset:
            for batch in wrapped_data_loader:
                for x, y in zip(batch["x"], batch["y"]):
                    # Make sure the data is on the correct device.
                    assert x.is_cuda and y == 2

    trainer = Trainer("torch", num_workers=2, use_gpu=True)
    trainer.start()
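
A side note on why the dict branches above can iterate zip(batch["x"], batch["y"]) and still see y == 2: PyTorch's default collate function batches dict samples key by key, stacking tensors and turning plain ints into a tensor. A tiny standalone illustration (TinyDictDataset is a made-up name, not part of the diff):

import torch
from torch.utils.data import DataLoader, Dataset


class TinyDictDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, index):
        # Mirrors NonTensorDataset: a tensor under "x", a plain int under "y".
        return {"x": torch.tensor([float(index)]), "y": 2}


batch = next(iter(DataLoader(TinyDictDataset(), batch_size=4)))
print(batch["x"].shape)  # torch.Size([4, 1]) -- per-sample tensors stacked along dim 0
print(batch["y"])        # tensor([2, 2, 2, 2]) -- the plain ints are collated into one tensor
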
17 changes: 16 additions & 1 deletion python/ray/train/torch/train_loop_utils.py
@@ -4,6 +4,7 @@
import random
import types
import warnings
import collections

from pathlib import Path
from typing import Any, Dict, Optional
@@ -585,7 +586,21 @@ def try_move_device(i):
            return i

        with torch.cuda.stream(self._memcpy_stream):
            return tuple(try_move_device(i) for i in item)
            if isinstance(item, collections.abc.Mapping):
                item_on_device = {k: self._move_to_device(v) for k, v in item.items()}
            elif isinstance(item, tuple):
                item_on_device = tuple(self._move_to_device(i) for i in item)
            elif isinstance(item, list):
                item_on_device = [self._move_to_device(i) for i in item]
            elif isinstance(item, torch.Tensor):
                item_on_device = try_move_device(item)
            else:
                logger.info(
                    f"Data type {type(item)} doesn't support being moved to device."
                )
                item_on_device = item

            return item_on_device

    def _wait_for_batch(self, item):
        if self._memcpy_stream is None:
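
Stripped of the CUDA-stream bookkeeping, the recursion added to _move_to_device above boils down to the following simplified sketch; move_to_device here is an illustrative stand-in, not the Ray implementation:

import collections.abc

import torch


def move_to_device(item, device):
    # Walk dicts, tuples, and lists recursively; move tensors; pass anything
    # else (e.g. plain Python ints or strings) through unchanged.
    if isinstance(item, collections.abc.Mapping):
        return {k: move_to_device(v, device) for k, v in item.items()}
    elif isinstance(item, tuple):
        return tuple(move_to_device(i, device) for i in item)
    elif isinstance(item, list):
        return [move_to_device(i, device) for i in item]
    elif isinstance(item, torch.Tensor):
        return item.to(device, non_blocking=True)
    else:
        return item


device = "cuda" if torch.cuda.is_available() else "cpu"
batch = {"x": torch.randn(4, 1), "y": 2, "meta": [torch.zeros(4), "label"]}
print(move_to_device(batch, device))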
