[Train] Decouple device-related modules and add Huawei NPU support to Ray Train #44086
@@ -0,0 +1,5 @@
from ray.air._internal.device_manager.device_manager import DeviceManager

__all__ = [
    "DeviceManager",
]
@@ -0,0 +1,21 @@
from abc import ABC, abstractmethod


class DeviceManager(ABC):
    """This class contains the functions needed to support
    an accelerator family in the Ray AI Library.
    """

    @staticmethod
    @abstractmethod
    def get_accelerator_name() -> str:
        """Gets the corresponding accelerator type, e.g. GPU, NPU."""
        ...

    @staticmethod
    @abstractmethod
    def get_device_type() -> str:
        """Gets the device type in the deep learning framework,
        e.g. cuda, hpu, npu in torch.
        """
        ...
@@ -0,0 +1,42 @@
from typing import List, Union

import torch

from ray._private.accelerators.hpu import HPU_PACKAGE_AVAILABLE
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager

if HPU_PACKAGE_AVAILABLE:
    import habana_frameworks.torch.hpu as torch_hpu


class HPUTorchDeviceManager(TorchDeviceManager):
    """HPU device manager."""

    @staticmethod
    def get_accelerator_name() -> str:
        return "HPU"

    @staticmethod
    def get_device_type() -> str:
        return "hpu"

    def is_device_available(self) -> bool:
        if not HPU_PACKAGE_AVAILABLE:
            return False

        return torch_hpu.is_available()

    def get_devices(self) -> List[torch.device]:
        if HPU_PACKAGE_AVAILABLE and torch_hpu.is_available():
            devices = [torch.device("hpu")]
        else:
            devices = [torch.device("cpu")]

        return devices

    def set_device(self, device: Union[torch.device, int, str, None]):
        torch_hpu.set_device(device)

    def is_support_stream(self) -> bool:
        """Validate if the device type supports creating a stream."""
        return False
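A brief usage sketch (assumed, not part of the diff): on a machine without the Habana stack, `get_devices()` falls back to a CPU device, so the same call path works with or without HPUs.

import torch

from ray.air._internal.device_manager.hpu import HPUTorchDeviceManager

# Assumed usage: resolve the training device through the manager; without the
# Habana stack installed this resolves to torch.device("cpu").
manager = HPUTorchDeviceManager()
device = manager.get_devices()[0]
model = torch.nn.Linear(8, 1).to(device)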
@@ -0,0 +1,103 @@
import os
from functools import lru_cache
from importlib.util import find_spec
from typing import List, Union

import torch

import ray
from ray._private.accelerators.npu import ASCEND_RT_VISIBLE_DEVICES_ENV_VAR
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager


@lru_cache()
def is_package_present(package_name: str) -> bool:
    try:
        return find_spec(package_name) is not None
    except ModuleNotFoundError:
        return False


NPU_TORCH_PACKAGE_AVAILABLE = is_package_present("torch_npu")


if NPU_TORCH_PACKAGE_AVAILABLE:
    import torch_npu  # noqa: F401


class NPUTorchDeviceManager(TorchDeviceManager):
    """Ascend NPU device manager."""

    @staticmethod
    def get_accelerator_name() -> str:
        return "NPU"

    @staticmethod
    def get_device_type() -> str:
        return "npu"

    def is_device_available(self) -> bool:
        if not NPU_TORCH_PACKAGE_AVAILABLE:
            return False

        return torch.npu.is_available()

    def get_devices(self) -> List[torch.device]:
        """Gets the correct torch device list configured for this process.

        Returns a list of torch NPU devices allocated for the current worker.
        If no NPUs are assigned, then it returns a list with a single CPU device.
        """
        if NPU_TORCH_PACKAGE_AVAILABLE and torch.npu.is_available():
            npu_ids = [
                str(id)
                for id in ray.get_runtime_context().get_accelerator_ids()["NPU"]
            ]

            device_ids = []

            if len(npu_ids) > 0:
                npu_visible_str = os.environ.get(ASCEND_RT_VISIBLE_DEVICES_ENV_VAR, "")
                if npu_visible_str and npu_visible_str != "NoDevFiles":
                    npu_visible_list = npu_visible_str.split(",")
                else:
                    npu_visible_list = []

                for npu_id in npu_ids:
                    try:
                        device_ids.append(npu_visible_list.index(npu_id))
                    except ValueError:
                        raise RuntimeError(
                            "ASCEND_RT_VISIBLE_DEVICES set incorrectly. "
                            f"Got {npu_visible_str}, expected to include {npu_id}. "
                            "Did you override the `ASCEND_RT_VISIBLE_DEVICES` "
                            "environment variable?"
                        )
            else:
                # If called on the driver or outside of Ray Train, return the
                # 0th device.
                device_ids.append(0)

            devices = [torch.device(f"npu:{device_id}") for device_id in device_ids]
        else:
            devices = [torch.device("cpu")]

        return devices

    def set_device(self, device: Union[torch.device, int]):
        torch.npu.set_device(device)

    def is_support_stream(self) -> bool:
        """Validate if the device type supports creating a stream."""
        return True

    def create_stream(self, device):
        """Create an NPU Stream."""
        return torch.npu.Stream(device)

    def get_stream_context(self, stream):
        """Get a torch.npu.stream context."""
        return torch.npu.stream(stream)

    def get_current_stream(self):
        """Get the current stream for NPU."""
        return torch.npu.current_stream()

Review discussion (on the `@lru_cache()` helper):
Reviewer: How does this play with multi-node settings? Does this lru_cache var get shipped over to other nodes? For example, a CPU head node could have this as False, but we don't want to carry that over to the worker nodes.
Author: Thanks for the reminder; the cached var won't be shipped over to other nodes. Even so, it has been removed to avoid unnecessary code complexity.

Review discussion (on `get_accelerator_ids()["NPU"]`):
Reviewer: Replace all occurrences of this with the NPU constant from ray_constants.
Author: Done.
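To make the visible-devices bookkeeping in `get_devices` concrete, here is a small illustrative snippet (the environment value and assigned ID are assumed) showing how a Ray-assigned NPU ID maps to the worker's local torch device index:

# Illustrative only; values below are assumed examples.
npu_visible_list = "4,5,6,7".split(",")                # e.g. ASCEND_RT_VISIBLE_DEVICES
assigned_npu_id = "6"                                  # e.g. from get_accelerator_ids()["NPU"]
local_index = npu_visible_list.index(assigned_npu_id)  # -> 2
print(f"npu:{local_index}")                            # the device string for this worker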
@@ -0,0 +1,91 @@
import os
from typing import List, Union

import torch

import ray
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager


class CUDATorchDeviceManager(TorchDeviceManager):
    """CUDA device manager."""

    @staticmethod
    def get_accelerator_name() -> str:
        return "GPU"

    @staticmethod
    def get_device_type() -> str:
        return "cuda"

    def is_device_available(self) -> bool:
        return torch.cuda.is_available()

    def get_devices(self) -> List[torch.device]:
        """Gets the correct torch device list configured for this process.

        Returns a list of torch CUDA devices allocated for the current worker.
        If no GPUs are assigned, then it returns a list with a single CPU device.

        Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
        superset of the `ray.get_gpu_ids()`.
        """
        if torch.cuda.is_available():
            # GPU IDs are assigned by Ray after you specify "use_gpu".
            # `ray.get_gpu_ids()` may return ints or strings,
            # so always convert to strings.
            gpu_ids = [str(id) for id in ray.get_gpu_ids()]

            device_ids = []

            if len(gpu_ids) > 0:
                cuda_visible_str = os.environ.get("CUDA_VISIBLE_DEVICES", "")
                if cuda_visible_str and cuda_visible_str != "NoDevFiles":
                    cuda_visible_list = cuda_visible_str.split(",")
                else:
                    cuda_visible_list = []

                # By default, there should only be one GPU ID if `use_gpu=True`.
                # If there are multiple GPUs, return a list of devices.
                # If using fractional GPUs, these IDs are not guaranteed
                # to be unique across different processes.
                for gpu_id in gpu_ids:
                    try:
                        device_ids.append(cuda_visible_list.index(gpu_id))
                    except ValueError:
                        raise RuntimeError(
                            "CUDA_VISIBLE_DEVICES set incorrectly. "
                            f"Got {cuda_visible_str}, expected to include {gpu_id}. "
                            "Did you override the `CUDA_VISIBLE_DEVICES` environment"
                            " variable? If not, please help file an issue on Github."
                        )

            else:
                # If called on the driver or outside of Ray Train, return the
                # 0th device.
                device_ids.append(0)

            devices = [torch.device(f"cuda:{device_id}") for device_id in device_ids]
        else:
            devices = [torch.device("cpu")]

        return devices

    def set_device(self, device: Union[torch.device, int, str, None]):
        torch.cuda.set_device(device)

    def is_support_stream(self) -> bool:
        """Validate if the device type supports creating a stream."""
        return True

    def create_stream(self, device: torch.device) -> torch.cuda.Stream:
        """Create a CUDA Stream."""
        return torch.cuda.Stream(device)

    def get_stream_context(self, stream):
        """Get a torch.cuda.stream context."""
        return torch.cuda.stream(stream)

    def get_current_stream(self) -> torch.cuda.Stream:
        """Get the current stream for CUDA."""
        return torch.cuda.current_stream()

Review discussion (on falling back to a CPU device):
Reviewer: This may have been discussed already, but why don't we just have a CPU device manager instead of falling back to CPU in every class?
Author: Currently, the CPU scheduling and GPU scheduling logic are coupled, and we hope to start a separate PR later to decouple them. For now, we just decouple NPU/HPU from CPU, which means they will not fall back to CPU.
Reviewer: Can you elaborate on this? Can we just fall back to a CPU device manager if no GPUs are assigned to the workers?
Author: Have implemented a CPU device manager, serving as an alternative when no accelerator is available.
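Per the review exchange above, a CPU device manager was later added as the fallback when no accelerator is assigned. Its code is not part of this revision; a minimal hypothetical sketch (class name and details assumed) might look like:

from typing import List, Union

import torch

from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager


class CPUTorchDeviceManager(TorchDeviceManager):
    """Hypothetical CPU fallback manager (not part of this diff)."""

    @staticmethod
    def get_accelerator_name() -> str:
        return "CPU"

    @staticmethod
    def get_device_type() -> str:
        return "cpu"

    def is_device_available(self) -> bool:
        return True

    def get_devices(self) -> List[torch.device]:
        return [torch.device("cpu")]

    def set_device(self, device: Union[torch.device, int, str, None]):
        pass  # nothing to select on CPU

    def is_support_stream(self) -> bool:
        return False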
@@ -0,0 +1,39 @@
from typing import List, Union

import torch

from ray.air._internal.device_manager.device_manager import DeviceManager


class TorchDeviceManager(DeviceManager):
    """This class contains the functions needed to support
    an accelerator family in the Ray AI Library.
    """

    def is_device_available(self) -> bool:
        """Validate if the device is available."""
        ...

    def get_devices(self) -> List[torch.device]:
        """Gets the correct torch device configured for this process."""
        ...

    def set_device(self, device: Union[torch.device, int, str, None]):
        """Set the correct device for this process."""
        ...

    def is_support_stream(self) -> bool:
        """Validate if the device type supports creating a stream."""
        ...

    def create_stream(self, device: torch.device):
        """Create a device stream."""
        ...

    def get_stream_context(self, stream):
        """Get a stream context like torch.cuda.stream."""
        ...

    def get_current_stream(self):
        """Get a torch stream like torch.cuda.current_stream."""
        ...
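For context, a minimal sketch (assumed usage, not part of the diff) of how the stream-related hooks above are intended to compose:

import torch


def run_on_stream(manager, device: torch.device, fn):
    """Sketch: execute `fn` on a side stream when the manager supports streams."""
    if manager.is_support_stream():
        stream = manager.create_stream(device)
        with manager.get_stream_context(stream):
            return fn()
    # e.g. HPUTorchDeviceManager reports is_support_stream() == False,
    # so the work simply runs on the default stream.
    return fn()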
@@ -0,0 +1,51 @@
import logging
from typing import Optional, Type

from ray._private.accelerators.hpu import HPU_PACKAGE_AVAILABLE
from ray.air._internal.device_manager.hpu import HPUTorchDeviceManager
from ray.air._internal.device_manager.npu import (
    NPU_TORCH_PACKAGE_AVAILABLE,
    NPUTorchDeviceManager,
)
from ray.air._internal.device_manager.nvidia_gpu import CUDATorchDeviceManager
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager

logger = logging.getLogger(__name__)


SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER = {
    "GPU": CUDATorchDeviceManager,
    "HPU": HPUTorchDeviceManager,
    "NPU": NPUTorchDeviceManager,
}


def try_register_torch_accelerator_module() -> None:
    try:
        if NPU_TORCH_PACKAGE_AVAILABLE:
            import torch_npu  # noqa: F401

        if HPU_PACKAGE_AVAILABLE:
            import habana_frameworks.torch.hpu as torch_hpu  # noqa: F401

    except ImportError:
        raise ImportError("Could not import PyTorch")


def get_torch_device_manager_cls_by_resources(
    resources: Optional[dict],
) -> Type[TorchDeviceManager]:
    device_manager = None

    # The input resources may be None.
    if not resources:
        return CUDATorchDeviceManager

    # Select the correct accelerator type from the resources.
    for resource_type, resource_value in resources.items():
        if resource_value and resource_type != "CPU":
            device_manager = SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER.get(
                resource_type, None
            )

    return device_manager or CUDATorchDeviceManager

Review discussion (on try_register_torch_accelerator_module):
Reviewer: What's the purpose of this? I do not want to raise if someone doesn't…
Author: It is used to import/register the relevant PyTorch extensions required by accelerators. Currently, for users who do not use NPU or HPU, no errors will occur. The error message has been optimized to avoid misunderstandings.
Reviewer: Can we do this on demand rather than at the file level in…
Review discussion (on get_torch_device_manager_cls_by_resources):
Reviewer: Should we raise if attempting to use multiple accelerators instead of just taking the last seen accelerator type?
Author: Sure, now it will raise a clear…
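Illustrative only (the resource dicts are assumed examples, and the definitions above are taken to be in scope): how the factory resolves a manager class from a worker's assigned resources.

get_torch_device_manager_cls_by_resources({"CPU": 1, "NPU": 1})  # -> NPUTorchDeviceManager
get_torch_device_manager_cls_by_resources({"CPU": 4, "GPU": 2})  # -> CUDATorchDeviceManager
get_torch_device_manager_cls_by_resources(None)                  # -> CUDATorchDeviceManager (default)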
Review discussion (on the `DeviceManager` abstraction and where it should live):
Reviewer: Can we remove this high-level `DeviceManager` abstraction and just keep the `TorchDeviceManager`? No need for the extra abstraction for now.
Reviewer: Also, move this entire folder to `ray/train/torch/_internal` instead.
Author: OK, just introduce the `TorchDeviceManager` for now.
Author: We are trying to extend the `DeviceManager` in `ray.air` to support more third-party devices, and the plan is to not only use it for Ray Train, but also include RLlib and others. So it seems more reasonable to maintain it within `air`. What do you think?
Reviewer: Let's just keep it in Train for now.
Author: Currently, Train relies on the `get_devices` in AIR, so it's natural for us to implement the `DeviceManager` in AIR to return the correct device to Train. If we move the `DeviceManager` to Train, it would create a weird dependency where Train calls AIR's `get_devices`, which in turn calls back to Train's `DeviceManager`. Would you mind elaborating on your thoughts about this part?
Reviewer: @matthewdeng WDYT?
Reviewer: Let's just put it in `ray.air` for now. Can restructure the package in the future if needed.