[Train] Decouple device-related modules and add Huawei NPU support to Ray Train #44086

Merged · 38 commits · Sep 3, 2024
The diff below shows changes from 13 of the 38 commits.

Commits
- ebcccca: Introduce TorchDeviceManager to ray TrainSession and support NPU in R… (liuxsh9, Jun 7, 2024)
- e0b8117: Add higher abstract class to decouple device manager with torch. (liuxsh9, Jun 7, 2024)
- de34e81: fix (liuxsh9, Jun 7, 2024)
- e3ebd13: Merge branch 'master' into train-support-npu (liuxsh9, Jun 29, 2024)
- c80ec32: fix get_current_stream() (liuxsh9, Jun 29, 2024)
- ddee918: refine code (liuxsh9, Jul 3, 2024)
- d0a4e73: fix lint (liuxsh9, Jul 4, 2024)
- 6583bd4: Merge branch 'master' into train-support-npu (liuxsh9, Jul 4, 2024)
- a4839d2: Enable share npu visible devices for local process in ddp. (liuxsh9, Jul 17, 2024)
- 798304a: Merge branch 'master' into train-support-npu (liuxsh9, Jul 17, 2024)
- 82eecc7: fix lint (liuxsh9, Jul 17, 2024)
- c88e29e: change the order of init device mananger and set env to enable huggin… (liuxsh9, Jul 17, 2024)
- 9b1ada5: Merge branch 'master' into train-support-npu (liuxsh9, Jul 18, 2024)
- 13e4914: Refactor code based on the comment and feedback. (liuxsh9, Jul 27, 2024)
- 4665b22: Add unit tests for torch device mananger and npu accelerator ids sharing (liuxsh9, Jul 27, 2024)
- d51d5d0: Merge branch 'master' into train-support-npu (liuxsh9, Jul 27, 2024)
- 2f7b9c6: add gpu-only tags for unit tests (liuxsh9, Jul 27, 2024)
- 81ba9a4: remove resources_per_worker field in backend (liuxsh9, Jul 27, 2024)
- 79f7d43: Edit error message (liuxsh9, Jul 27, 2024)
- 15e99c6: refine code (liuxsh9, Jul 29, 2024)
- 16927f2: Trigger a runtime error when npu is allocated but torch npu is not av… (liuxsh9, Jul 29, 2024)
- 0e81c8e: revert hpu get device logic. (liuxsh9, Jul 29, 2024)
- dee4745: Refine the code based on the feedback from review. (liuxsh9, Aug 7, 2024)
- e6a2f4a: Introduce `CPUTorchDeviceManager` and change the value of DEFAULT_TO… (liuxsh9, Aug 7, 2024)
- 9c5a296: delete fall back logic in CUDATorchDevicaManager (liuxsh9, Aug 7, 2024)
- 68bc4e1: Merge branch 'master' into train-support-npu (liuxsh9, Aug 7, 2024)
- 16094f1: fix (liuxsh9, Aug 7, 2024)
- 82f27b1: refine code (liuxsh9, Aug 8, 2024)
- 7973ddf: Refine code based on the review feedback. (liuxsh9, Aug 15, 2024)
- b83d1ff: fix (liuxsh9, Aug 15, 2024)
- 199e572: Merge branch 'master' into train-support-npu (liuxsh9, Aug 15, 2024)
- 1c0d115: Fix device manager logic in local environment. (liuxsh9, Aug 16, 2024)
- 2c4c5cc: fix (liuxsh9, Aug 16, 2024)
- d8a0900: fix typo and remove unnecessary operation for npu. (liuxsh9, Aug 16, 2024)
- 93a9bcf: move implementation into subclass (liuxsh9, Aug 19, 2024)
- 761ac5c: Update code based on the review feedback. (liuxsh9, Aug 23, 2024)
- 657b1f9: Merge branch 'master' into train-support-npu (matthewdeng, Aug 27, 2024)
- a5549cc: fix lint (liuxsh9, Aug 28, 2024)
2 changes: 2 additions & 0 deletions python/ray/_private/ray_constants.py
@@ -408,10 +408,12 @@ def env_set_by_user(key):
CUDA_VISIBLE_DEVICES_ENV_VAR = "CUDA_VISIBLE_DEVICES"
NEURON_RT_VISIBLE_CORES_ENV_VAR = "NEURON_RT_VISIBLE_CORES"
TPU_VISIBLE_CHIPS_ENV_VAR = "TPU_VISIBLE_CHIPS"
NPU_RT_VISIBLE_DEVICES_ENV_VAR = "ASCEND_RT_VISIBLE_DEVICES"

NEURON_CORES = "neuron_cores"
GPU = "GPU"
TPU = "TPU"
NPU = "NPU"


RAY_WORKER_NICENESS = "RAY_worker_niceness"
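
For orientation, the new NPU entries mirror the existing per-accelerator pairs: a resource name plus the environment variable Ray uses to scope visible devices. A small illustrative lookup, assuming only that the constants above exist (the mapping dict itself is hypothetical, not part of ray_constants):

```python
from ray._private import ray_constants

# Hypothetical helper mapping for illustration; not defined in ray_constants itself.
VISIBLE_DEVICES_ENV_VAR_BY_RESOURCE = {
    ray_constants.GPU: ray_constants.CUDA_VISIBLE_DEVICES_ENV_VAR,
    ray_constants.NPU: ray_constants.NPU_RT_VISIBLE_DEVICES_ENV_VAR,
}

print(VISIBLE_DEVICES_ENV_VAR_BY_RESOURCE[ray_constants.NPU])  # ASCEND_RT_VISIBLE_DEVICES
```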
5 changes: 5 additions & 0 deletions python/ray/air/_internal/device_manager/__init__.py
Contributor: Can we remove this high-level DeviceManager abstraction and just keep the TorchDeviceManager? No need for the extra abstraction for now.

Contributor: Also, move this entire folder to ray/train/torch/_internal instead.

Contributor Author (replying to the first comment): OK, just introduce the TorchDeviceManager for now.

Contributor Author (replying to the second comment): We are trying to extend the DeviceManager in ray.air to support more third-party devices, and the plan is to use it not only for Ray Train but also for RLlib and others. So it seems more reasonable to maintain it within air. What do you think?

Contributor: Let's just keep it in Train for now.

Contributor Author (liuxsh9, Aug 8, 2024): Currently, Train relies on get_devices in AIR, so it's natural for us to implement the DeviceManager in AIR to return the correct device to Train. If we move the DeviceManager to Train, it would create a weird dependency where Train calls AIR's get_devices, which in turn calls back into Train's DeviceManager. Would you mind elaborating on your thoughts about this part?

Contributor: @matthewdeng WDYT?

Contributor: Let's just put it in ray.air for now. We can restructure the package in the future if needed.
@@ -0,0 +1,5 @@
from ray.air._internal.device_manager.device_manager import DeviceManager

__all__ = [
"DeviceManager",
]
21 changes: 21 additions & 0 deletions python/ray/air/_internal/device_manager/device_manager.py
@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod


class DeviceManager(ABC):
"""This class contains the function needed for supporting
an acclerator family in Ray AI Library.
"""

@staticmethod
@abstractmethod
def get_accelerator_name() -> str:
"""Gets the corresponding accelerator type, e.g. GPU, NPU."""
...

@staticmethod
@abstractmethod
def get_device_type() -> str:
"""Gets the device type in deeplearning framwork,
e.g. cuda, hpu, npu in torch.
"""
...
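
To make the split between the two names concrete, here is a minimal, hypothetical sketch (not part of the diff) of how a caller might use them: `get_accelerator_name()` returns the Ray resource name used for scheduling and accelerator-ID lookups, while `get_device_type()` returns the framework-level device string.

```python
import torch

from ray.air._internal.device_manager import DeviceManager


def build_torch_device(manager: DeviceManager, index: int = 0) -> torch.device:
    # e.g. get_device_type() == "npu" and index == 1 -> torch.device("npu:1"),
    # while get_accelerator_name() == "NPU" would be used to query Ray for assigned IDs.
    return torch.device(f"{manager.get_device_type()}:{index}")
```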
42 changes: 42 additions & 0 deletions python/ray/air/_internal/device_manager/hpu.py
@@ -0,0 +1,42 @@
from typing import List, Union

import torch

from ray._private.accelerators.hpu import HPU_PACKAGE_AVAILABLE
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager

if HPU_PACKAGE_AVAILABLE:
import habana_frameworks.torch.hpu as torch_hpu


class HPUTorchDeviceManager(TorchDeviceManager):
Member: Also invite @harborn @kira-lin to review the HPU device manager.

"""HPU device manager"""

@staticmethod
def get_accelerator_name() -> str:
return "HPU"

@staticmethod
def get_device_type() -> str:
return "hpu"

def is_device_available(self) -> bool:
if not HPU_PACKAGE_AVAILABLE:
return False

return torch_hpu.is_available()

def get_devices(self) -> List[torch.device]:
if HPU_PACKAGE_AVAILABLE and torch_hpu.is_available():
devices = [torch.device("hpu")]
else:
devices = [torch.device("cpu")]

return devices

def set_device(self, device: Union[torch.device, int, str, None]):
torch_hpu.set_device(device)

def is_support_stream(self) -> bool:
"""Validate if the device type support create a stream"""
return False
103 changes: 103 additions & 0 deletions python/ray/air/_internal/device_manager/npu.py
@@ -0,0 +1,103 @@
import os
from functools import lru_cache
from importlib.util import find_spec
from typing import List, Union

import torch

import ray
from ray._private.accelerators.npu import ASCEND_RT_VISIBLE_DEVICES_ENV_VAR
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager


@lru_cache()
Contributor: How does this play with multi-node settings? Does this lru_cache var get shipped over to other nodes? For example, a CPU head node would possibly have this as false, but we don't want to keep that around on the worker nodes.

Contributor Author: Thank you for the reminder; the cached var won't be shipped over to other nodes. But even so, it has been removed to avoid unnecessary code complexity.

def is_package_present(package_name: str) -> bool:
try:
return find_spec(package_name) is not None
except ModuleNotFoundError:
return False


NPU_TORCH_PACKAGE_AVAILABLE = is_package_present("torch_npu")


if NPU_TORCH_PACKAGE_AVAILABLE:
import torch_npu # noqa: F401


class NPUTorchDeviceManager(TorchDeviceManager):
"""Ascend NPU device manager"""

@staticmethod
def get_accelerator_name() -> str:
return "NPU"

@staticmethod
def get_device_type() -> str:
return "npu"

def is_device_available(self) -> bool:
if not NPU_TORCH_PACKAGE_AVAILABLE:
return False

return torch.npu.is_available()

def get_devices(self) -> List[torch.device]:
"""Gets the correct torch device list configured for this process.

Returns a list of torch NPU devices allocated for the current worker.
If no NPUs are assigned, then it returns a list with a single CPU device.
"""
if NPU_TORCH_PACKAGE_AVAILABLE and torch.npu.is_available():
npu_ids = [
str(id) for id in ray.get_runtime_context().get_accelerator_ids()["NPU"]
Contributor: Replace all occurrences of this with the NPU constant from ray_constants.

Contributor Author: Done.

]

device_ids = []

if len(npu_ids) > 0:
npu_visible_str = os.environ.get(ASCEND_RT_VISIBLE_DEVICES_ENV_VAR, "")
if npu_visible_str and npu_visible_str != "NoDevFiles":
npu_visible_list = npu_visible_str.split(",")
else:
npu_visible_list = []

for npu_id in npu_ids:
try:
device_ids.append(npu_visible_list.index(npu_id))
except ValueError:
raise RuntimeError(
"ASCEND_RT_VISIBLE_DEVICES set incorrectly. "
f"Got {npu_visible_str}, expected to include {npu_id}. "
"Did you override the `ASCEND_RT_VISIBLE_DEVICES` "
"environment variable?"
)
else:
# If called on the driver or outside of Ray Train, return the
# 0th device.
device_ids.append(0)

devices = [torch.device(f"npu:{device_id}") for device_id in device_ids]
else:
devices = [torch.device("cpu")]

return devices

def set_device(self, device: Union[torch.device, int]):
torch.npu.set_device(device)

def is_support_stream(self) -> bool:
"""Validate if the device type support create a stream"""
return True

def create_stream(self, device):
"""Create an NPU Stream."""
return torch.npu.Stream(device)

def get_stream_context(self, stream):
"""Get a torch.npu.stream context"""
return torch.npu.stream(stream)

def get_current_stream(self):
"""Get current stream for npu"""
return torch.npu.current_stream()
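
The index-mapping logic in `get_devices()` above can be illustrated in isolation: Ray reports the physical NPU IDs assigned to the worker, `ASCEND_RT_VISIBLE_DEVICES` lists the devices visible to the process, and the torch device index is the position of each assigned ID within that visible list. A standalone sketch with assumed values (no `torch_npu` required):

```python
import os

# Assumed values for illustration: Ray exposed physical NPUs 4-7 to this process
# and assigned NPUs 5 and 6 to the current worker.
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "4,5,6,7"
assigned_npu_ids = ["5", "6"]

visible = os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",")
device_indices = [visible.index(npu_id) for npu_id in assigned_npu_ids]
print(device_indices)  # [1, 2] -> torch.device("npu:1"), torch.device("npu:2")
```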
91 changes: 91 additions & 0 deletions python/ray/air/_internal/device_manager/nvidia_gpu.py
@@ -0,0 +1,91 @@
import os
from typing import List, Union

import torch

import ray
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager


class CUDATorchDeviceManager(TorchDeviceManager):
"""CUDA device manager"""

@staticmethod
def get_accelerator_name() -> str:
return "GPU"

@staticmethod
def get_device_type() -> str:
return "cuda"

def is_device_available(self) -> bool:
return torch.cuda.is_available()

def get_devices(self) -> List[torch.device]:
"""Gets the correct torch device list configured for this process.

Returns a list of torch CUDA devices allocated for the current worker.
If no GPUs are assigned, then it returns a list with a single CPU device.

Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
superset of the `ray.get_gpu_ids()`.
"""
if torch.cuda.is_available():
# GPU IDs are assigned by Ray after you specify "use_gpu"
# GPU `ray.get_gpu_ids()` may return ints or may return strings.
# We should always convert to strings.
gpu_ids = [str(id) for id in ray.get_gpu_ids()]

device_ids = []

if len(gpu_ids) > 0:
cuda_visible_str = os.environ.get("CUDA_VISIBLE_DEVICES", "")
if cuda_visible_str and cuda_visible_str != "NoDevFiles":
cuda_visible_list = cuda_visible_str.split(",")
else:
cuda_visible_list = []

# By default, there should only be one GPU ID if `use_gpu=True`.
# If there are multiple GPUs, return a list of devices.
# If using fractional GPUs, these IDs are not guaranteed
# to be unique across different processes.
for gpu_id in gpu_ids:
try:
device_ids.append(cuda_visible_list.index(gpu_id))
except ValueError:
raise RuntimeError(
"CUDA_VISIBLE_DEVICES set incorrectly. "
f"Got {cuda_visible_str}, expected to include {gpu_id}. "
"Did you override the `CUDA_VISIBLE_DEVICES` environment"
" variable? If not, please help file an issue on Github."
)

else:
# If called on the driver or outside of Ray Train, return the
# 0th device.
device_ids.append(0)

devices = [torch.device(f"cuda:{device_id}") for device_id in device_ids]
else:
devices = [torch.device("cpu")]
Contributor: This may have been discussed already, but why don't we just have a CPU device manager instead of falling back to CPU in every class?

Contributor Author: Currently, the CPU scheduling and GPU scheduling logic are coupled, and we hope to start a separate PR later to decouple them. For now, we just decouple NPU/HPU from CPU, which means they will not fall back to CPU.

Contributor: Can you elaborate on this? Can we just fall back to a CPU device manager if no GPUs are assigned to the workers?

Contributor Author: Have implemented a CPU device manager, serving as an alternative when no accelerator is available.


return devices

def set_device(self, device: Union[torch.device, int, str, None]):
torch.cuda.set_device(device)

def is_support_stream(self) -> bool:
"""Validate if the device type support create a stream"""
return True

def create_stream(self, device: torch.device) -> torch.cuda.Stream:
"""Create a CUDA Stream."""
return torch.cuda.Stream(device)

def get_stream_context(self, stream):
"""Get a torch.cuda.stream context"""
return torch.cuda.stream(stream)

def get_current_stream(self) -> torch.cuda.Stream:
"""Get a current stream for cuda"""
return torch.cuda.current_stream()
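
As the thread above notes, a later revision of this PR introduces a dedicated CPU device manager instead of having every accelerator manager fall back to CPU; that class is not part of the 13-commit diff shown here. A rough sketch of what such a manager could look like against the `TorchDeviceManager` interface in the next file (the merged implementation may differ):

```python
from typing import List, Union

import torch

from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager


class CPUTorchDeviceManager(TorchDeviceManager):
    """CPU device manager (sketch only)."""

    @staticmethod
    def get_accelerator_name() -> str:
        return "CPU"

    @staticmethod
    def get_device_type() -> str:
        return "cpu"

    def is_device_available(self) -> bool:
        # The CPU is always available.
        return True

    def get_devices(self) -> List[torch.device]:
        return [torch.device("cpu")]

    def set_device(self, device: Union[torch.device, int, str, None]):
        # Nothing to select for CPU.
        pass

    def is_support_stream(self) -> bool:
        # No CUDA/NPU-style stream concept for CPU.
        return False
```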
39 changes: 39 additions & 0 deletions python/ray/air/_internal/device_manager/torch_device_manager.py
@@ -0,0 +1,39 @@
from typing import List, Union

import torch

from ray.air._internal.device_manager.device_manager import DeviceManager


class TorchDeviceManager(DeviceManager):
"""This class contains the function needed for supporting
an acclerator family in Ray AI Library.
"""

def is_device_available(self) -> bool:
"""Validate if device is available."""
...

def get_devices(self) -> List[torch.device]:
"""Gets the correct torch device configured for this process"""
...

def set_device(self, device: Union[torch.device, int, str, None]):
"""Set the correct device for this process"""
...

def is_support_stream(self) -> bool:
Contributor (suggested change): rename `def is_support_stream(self) -> bool:` to `def supports_stream(self) -> bool:`.

Contributor Author (quoting a related comment: "can we do this on demand rather than at the file level in torch/config.py? The auto global import seems a bit weird to me."): Sure!

"""Validate if the device type support create a stream"""
...

def create_stream(self, device: torch.device):
"""Create a device stream"""
...

def get_stream_context(self, stream):
"""Get a stream context like torch.cuda.stream"""
...

def get_current_stream(self):
"""Get a torch stream like torch.cuda.current_stream"""
...
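
For context, a simplified sketch of how a training worker could drive this interface end to end (the function below is illustrative, not Ray Train's actual session code):

```python
import torch

from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager


def forward_on_assigned_device(
    manager: TorchDeviceManager, model: torch.nn.Module, batch: torch.Tensor
) -> torch.Tensor:
    # Pin this process to the first device Ray assigned to the worker.
    device = manager.get_devices()[0]
    manager.set_device(device)

    model = model.to(device)
    batch = batch.to(device)

    # Only enter a stream context on backends that support streams (CUDA, NPU).
    if manager.is_support_stream():
        stream = manager.create_stream(device)
        with manager.get_stream_context(stream):
            return model(batch)
    return model(batch)
```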
51 changes: 51 additions & 0 deletions python/ray/air/_internal/device_manager/utils.py
@@ -0,0 +1,51 @@
import logging
from typing import Optional, Type

from ray._private.accelerators.hpu import HPU_PACKAGE_AVAILABLE
from ray.air._internal.device_manager.hpu import HPUTorchDeviceManager
from ray.air._internal.device_manager.npu import (
NPU_TORCH_PACKAGE_AVAILABLE,
NPUTorchDeviceManager,
)
from ray.air._internal.device_manager.nvidia_gpu import CUDATorchDeviceManager
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager

logger = logging.getLogger(__name__)


SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER = {
"GPU": CUDATorchDeviceManager,
"HPU": HPUTorchDeviceManager,
"NPU": NPUTorchDeviceManager,
}


def try_register_torch_accelerator_module() -> None:
Contributor: What's the purpose of this? I do not want to raise if someone doesn't have torch_npu installed.

Contributor Author: It is used to import/register the relevant PyTorch extensions required by the accelerators. Currently, for users who do not use NPU or HPU, no errors will occur. The error message has been optimized to avoid misunderstandings.

Contributor: Can we do this on demand rather than at the file level in torch/config.py? The auto global import seems a bit weird to me.

try:
if NPU_TORCH_PACKAGE_AVAILABLE:
import torch_npu # noqa: F401

if HPU_PACKAGE_AVAILABLE:
import habana_frameworks.torch.hpu as torch_hpu # noqa: F401

except ImportError:
raise ImportError("Could not import PyTorch")


def get_torch_device_manager_cls_by_resources(
resources: Optional[dict],
) -> Type[TorchDeviceManager]:
device_manager = None

# input resources may be None
if not resources:
return CUDATorchDeviceManager

# select correct accelerator type from resources
for resource_type, resource_value in resources.items():
if resource_value and resource_type != "CPU":
device_manager = SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER.get(
resource_type, None
)

return device_manager or CUDATorchDeviceManager
Contributor: Should we raise if attempting to use multiple accelerators instead of just taking the last seen accelerator type?

Contributor Author: Sure, it now raises a clear RuntimeError when encountering multiple types of accelerators.
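
To tie the selection helper and the agreed-upon stricter behavior together, here is a hedged sketch; the exact wording and location of the merged check are not shown in this 13-commit diff.

```python
from typing import Optional, Type

from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager
from ray.air._internal.device_manager.utils import (
    SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER,
    get_torch_device_manager_cls_by_resources,
)

# Per the code above: {"CPU": 1, "NPU": 1} resolves to NPUTorchDeviceManager,
# and empty/None resources fall back to CUDATorchDeviceManager.
manager_cls = get_torch_device_manager_cls_by_resources({"CPU": 1, "NPU": 1})


def strict_device_manager_cls(
    resources: Optional[dict],
) -> Optional[Type[TorchDeviceManager]]:
    """Sketch of raising on mixed accelerator types, as discussed above."""
    found = {
        name
        for name, amount in (resources or {}).items()
        if amount and name in SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER
    }
    if len(found) > 1:
        raise RuntimeError(f"Expected a single accelerator type, got: {sorted(found)}")
    return SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER[found.pop()] if found else None
```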
