
Commit ec21094

[Train] Fix lightning trainer devices setting (ray-project#34419)
Signed-off-by: woshiyyya <[email protected]>
Signed-off-by: elliottower <[email protected]>
woshiyyya authored and elliottower committed Apr 22, 2023
1 parent ac4ead4 commit ec21094
Showing 2 changed files with 12 additions and 3 deletions.
11 changes: 10 additions & 1 deletion python/ray/train/lightning/_lightning_utils.py

@@ -21,12 +21,21 @@
 LIGHTNING_REPORT_STAGE_KEY = "_report_on"
 
 
+def get_worker_root_device():
+    """Get the first torch device of the current worker if there are multiple."""
+    devices = ray.train.torch.get_device()
+    if isinstance(devices, list):
+        return devices[0]
+    else:
+        return devices
+
+
 class RayDDPStrategy(DDPStrategy):
     """Subclass of DDPStrategy to ensure compatibility with Ray orchestration."""
 
     @property
     def root_device(self) -> torch.device:
-        return ray.train.torch.get_device()
+        return get_worker_root_device()
 
     @property
     def distributed_sampler_kwargs(self) -> Dict[str, Any]:
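
For context, a minimal standalone sketch (not part of the commit; the helper name first_device is made up here) of the pattern the new get_worker_root_device() follows: when a worker is assigned several GPUs, ray.train.torch.get_device() can return a list of devices, while Lightning's root_device must be a single torch.device, so the first entry is used.

import torch

def first_device(devices):
    """Normalize a handle that may be one torch.device or a list of them."""
    # Same idea as the commit's get_worker_root_device(): take the first
    # device when the worker holds several, otherwise pass it through.
    if isinstance(devices, list):
        return devices[0]
    return devices

print(first_device(torch.device("cpu")))                                 # cpu
print(first_device([torch.device("cuda", 0), torch.device("cuda", 1)]))  # cuda:0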
4 changes: 2 additions & 2 deletions python/ray/train/lightning/lightning_trainer.py

@@ -1,5 +1,4 @@
 import os
-import ray
 from inspect import isclass
 from typing import Any, Dict, Optional, Type
 import pytorch_lightning as pl
@@ -26,6 +25,7 @@
     RayEnvironment,
     RayDataModule,
     RayModelCheckpoint,
+    get_worker_root_device,
 )
 
 
@@ -503,7 +503,7 @@ def _lightning_train_loop_per_worker(config):
 
     # Setup trainer's parallel devices
    if trainer_config.get("accelerator", None) == "gpu":
-        current_device = ray.train.torch.get_device()
+        current_device = get_worker_root_device()
         trainer_config["devices"] = [current_device.index]
 
     # Setup ray cluster environment info
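
To illustrate how the fixed loop configures Lightning (a sketch under the assumption that the worker's root device is a CUDA device; devices_for_trainer is a hypothetical helper, not part of the commit): Lightning takes a list of local device indices when accelerator="gpu", so the config stores [current_device.index].

import torch

def devices_for_trainer(root_device: torch.device):
    """Build the `devices` value passed to pl.Trainer for a GPU accelerator."""
    # root_device.index is the local CUDA index (e.g. 0 for "cuda:0");
    # Lightning accepts a list of such indices for the current node.
    return [root_device.index]

trainer_config = {"accelerator": "gpu"}
trainer_config["devices"] = devices_for_trainer(torch.device("cuda", 1))
print(trainer_config)  # {'accelerator': 'gpu', 'devices': [1]}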
