Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[air] Horovod: Use Torch.encode_data if torch is imported #28440

Merged
6 commits merged into the base branch
Sep 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions python/ray/air/_internal/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,19 @@ def load_torch_model(
f"to be of type `torch.nn.Module`, or a model "
f"state dict of type dict."
)


def contains_tensor(obj):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this? I think if torch is installed, it should be safe to always use the TorchBackend for encoding/decoding (even if the data dict does not contain a tensor). I'm worried in the worst case contains_tensor can lead to a lot of recursion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally didn't have it in, but thought the overhead wouldn't be as bad. But I agree, since it only concerns an internal communication channel and the intermediate objects are not exposed to the user, we can just do this always when torch is loaded. Updated the PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Turns out we do need it as torch.save seems to silently fail if a full model is passed (and not a state dict).

I think it should be fine - a similar lookup has to be in pickling after all, and in most cases it should finish early.

if isinstance(obj, torch.Tensor):
return True
elif isinstance(obj, dict):
for k, v in obj.items():
if contains_tensor(k):
return True
if contains_tensor(v):
return True
elif isinstance(obj, (list, tuple)):
for v in obj:
if contains_tensor(v):
return True
return False
37 changes: 35 additions & 2 deletions python/ray/train/horovod/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Optional, Set
import sys
from typing import Optional, Set, Dict

import os
from dataclasses import dataclass

import ray
from ray.train.backend import BackendConfig, Backend
from ray.air._internal.torch_utils import contains_tensor
from ray.train.backend import BackendConfig, Backend, EncodedData
from ray.train._internal.utils import update_env_vars
from ray.train._internal.worker_group import WorkerGroup, Worker

Expand Down Expand Up @@ -129,6 +131,37 @@ def on_start(self, worker_group: WorkerGroup, backend_config: HorovodConfig):

worker_group.execute(update_env_vars, coordinator_envs)

@staticmethod
def encode_data(data_dict: Dict) -> EncodedData:
"""Logic to encode a data dict before sending to the driver.

This function will be called on the workers for any data that is
sent to the driver via ``session.report()``.
"""
# If torch is imported, we can use it to serialize the data dict
# into bytes. This will prevent e.g. GPU deserialization errors.
if "torch" in sys.modules and contains_tensor(data_dict):
from ray.train.torch.config import _TorchBackend

return _TorchBackend.encode_data(data_dict)

return data_dict

@staticmethod
def decode_data(encoded_data: EncodedData) -> Dict:
"""Logic to decode an encoded data dict.

This function will be called on the driver after receiving the
encoded data dict from the worker.
"""
# See encode_data
if "torch" in sys.modules and isinstance(encoded_data, bytes):
from ray.train.torch.config import _TorchBackend

return _TorchBackend.decode_data(encoded_data)

return encoded_data


def _init_env_vars(world_rank: int, world_size: int, node_id: str):
"""Initialize Horovod environment variables."""
Expand Down
18 changes: 18 additions & 0 deletions python/ray/train/tests/test_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torch.utils.data import DataLoader, DistributedSampler

import ray
from ray.air import Checkpoint, session
from ray.cluster_utils import Cluster

import ray.train as train
Expand Down Expand Up @@ -442,6 +443,23 @@ def test_horovod_torch_mnist_gpu(ray_start_4_cpus_2_gpus):
assert result[TRAINING_ITERATION] == num_workers


def test_horovod_torch_mnist_gpu_checkpoint(ray_start_4_cpus_2_gpus):
def checkpointing_func(config):
net = torch.nn.Linear(in_features=8, out_features=16)
net.to("cuda")

checkpoint = Checkpoint.from_dict({"model": net.state_dict()})
session.report({"metric": 1}, checkpoint=checkpoint)

num_workers = 2
trainer = HorovodTrainer(
checkpointing_func,
scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=True),
)
result = trainer.fit()
assert not result.error


def test_tune_fashion_mnist_gpu(ray_start_4_cpus_2_gpus):
torch_fashion_mnist(num_workers=2, use_gpu=True, num_samples=1)

Expand Down
11 changes: 11 additions & 0 deletions python/ray/train/tests/test_torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ray.air._internal.torch_utils import (
convert_pandas_to_torch_tensor,
load_torch_model,
contains_tensor,
)

data_batch = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
Expand Down Expand Up @@ -95,6 +96,16 @@ def test_load_state_dict_fail(self):
load_torch_model(torch_module.state_dict())


def test_contains_tensor():
t = torch.tensor([0])
assert contains_tensor(t)
assert contains_tensor([1, 2, 3, t, 5, 6])
assert contains_tensor([1, 2, 3, {"dict": t}, 5, 6])
assert contains_tensor({"outer": [1, 2, 3, {"dict": t}, 5, 6]})
assert contains_tensor({t: [1, 2, 3, {"dict": 2}, 5, 6]})
assert not contains_tensor([4, 5, 6])


if __name__ == "__main__":
import sys

Expand Down