-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Train] Decouple device-related modules and add Huawei NPU support to Ray Train #44086
Changes from 35 commits
ebcccca
e0b8117
de34e81
e3ebd13
c80ec32
ddee918
d0a4e73
6583bd4
a4839d2
798304a
82eecc7
c88e29e
9b1ada5
13e4914
4665b22
d51d5d0
2f7b9c6
81ba9a4
79f7d43
15e99c6
16927f2
0e81c8e
dee4745
e6a2f4a
9c5a296
68bc4e1
16094f1
82f27b1
7973ddf
b83d1ff
199e572
1c0d115
2c4c5cc
d8a0900
93a9bcf
761ac5c
657b1f9
a5549cc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import logging | ||
import threading | ||
from typing import Optional, Type | ||
|
||
import ray | ||
import ray._private.ray_constants as ray_constants | ||
from ray.air._internal.device_manager.cpu import CPUTorchDeviceManager | ||
from ray.air._internal.device_manager.hpu import HPUTorchDeviceManager | ||
from ray.air._internal.device_manager.npu import NPUTorchDeviceManager | ||
from ray.air._internal.device_manager.nvidia_gpu import CUDATorchDeviceManager | ||
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
DEFAULT_TORCH_DEVICE_MANAGER_CLS = CPUTorchDeviceManager | ||
|
||
|
||
SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER = { | ||
ray_constants.GPU: CUDATorchDeviceManager, | ||
ray_constants.HPU: HPUTorchDeviceManager, | ||
ray_constants.NPU: NPUTorchDeviceManager, | ||
} | ||
|
||
|
||
def register_custom_torch_dist_backend(backend: Optional[str] = None) -> None: | ||
if backend == "hccl": | ||
# The name for the communication backend of Habana and torch-npu is the same. | ||
HPUTorchDeviceManager.register_custom_torch_dist_backend() | ||
|
||
NPUTorchDeviceManager.register_custom_torch_dist_backend() | ||
|
||
|
||
def get_torch_device_manager_cls_by_resources( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any reason not to define these functions in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently the dependency is that these methods depend on various If we put these methods into the torch_device_manager.py, it may lead to circular import issues. We refer the design of My understanding may be limited, could you please provide more information? |
||
resources: Optional[dict], | ||
) -> Type[TorchDeviceManager]: | ||
existing_device_manager = None | ||
|
||
# input resources may be None | ||
if not resources: | ||
return DEFAULT_TORCH_DEVICE_MANAGER_CLS | ||
|
||
# select correct accelerator type from resources | ||
for resource_type, resource_value in resources.items(): | ||
device_manager = SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER.get( | ||
resource_type, None | ||
) | ||
if resource_value and device_manager: | ||
# An error will raise when multiple accelerators are specified. | ||
if existing_device_manager: | ||
raise RuntimeError( | ||
"Unable to determine the appropriate DeviceManager " | ||
f"for the specified resources {resources}." | ||
) | ||
else: | ||
existing_device_manager = device_manager | ||
|
||
return existing_device_manager or DEFAULT_TORCH_DEVICE_MANAGER_CLS | ||
|
||
|
||
def get_torch_device_manager_cls_by_device_type(device_type: str): | ||
if device_type.lower() == ray_constants.GPU.lower() or device_type == "cuda": | ||
return CUDATorchDeviceManager | ||
elif device_type.lower() == ray_constants.NPU.lower(): | ||
return NPUTorchDeviceManager | ||
elif device_type.lower() == ray_constants.HPU.lower(): | ||
return HPUTorchDeviceManager | ||
elif device_type.lower() == "cpu": | ||
return CPUTorchDeviceManager | ||
|
||
raise RuntimeError(f"Device type {device_type} cannot be recognized.") | ||
|
||
|
||
_torch_device_manager = None | ||
_torch_device_manager_lock = threading.Lock() | ||
Comment on lines
+34
to
+35
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need a global variable to track this? Can this not just be tracked as an instance variable by the caller? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One issue with global variables is that we want these functions to be called by the individual workers, so they aren't pointing to these references. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This addresses previous reviewer's feedback. Additionally, we will check for |
||
|
||
|
||
def get_torch_device_manager(device_type: Optional[str] = None) -> TorchDeviceManager: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO we should consider removing this function. It's convenient to have a single function like this, but for all the usages it seems we want one explicit path (either with or without a device), so just calling that explicit logic directly is easier to follow. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good suggestion, this has been modified to have two clear function entry points |
||
if device_type: | ||
# Specify the device type to retrieve the device manager directly, | ||
# rather than relying on the remote environment to determine it. | ||
return get_torch_device_manager_cls_by_device_type(device_type)() | ||
|
||
with _torch_device_manager_lock: | ||
if not _torch_device_manager: | ||
init_torch_device_manager() | ||
|
||
return _torch_device_manager | ||
|
||
|
||
def init_torch_device_manager() -> None: | ||
global _torch_device_manager | ||
|
||
resources = ray.get_runtime_context().get_accelerator_ids() | ||
|
||
_torch_device_manager = get_torch_device_manager_cls_by_resources(resources)() | ||
|
||
|
||
__all__ = [ | ||
TorchDeviceManager, | ||
CPUTorchDeviceManager, | ||
CUDATorchDeviceManager, | ||
HPUTorchDeviceManager, | ||
NPUTorchDeviceManager, | ||
register_custom_torch_dist_backend, | ||
get_torch_device_manager, | ||
init_torch_device_manager, | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from contextlib import contextmanager | ||
from typing import List | ||
|
||
import torch | ||
|
||
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager | ||
|
||
|
||
class CPUTorchDeviceManager(TorchDeviceManager): | ||
"""CPU device manager""" | ||
|
||
def is_available(self) -> bool(): | ||
return True | ||
|
||
def get_devices(self) -> List[torch.device]: | ||
"""Gets the correct torch device list configured for this process.""" | ||
return [torch.device("cpu")] | ||
|
||
def supports_stream(self) -> bool: | ||
"""Validate if the device type support create a stream""" | ||
return False | ||
|
||
def get_stream_context(self, stream): | ||
"""Return empty context mananger for CPU.""" | ||
|
||
@contextmanager | ||
def default_context_manager(): | ||
yield | ||
|
||
return default_context_manager() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from contextlib import contextmanager | ||
from typing import List, Union | ||
|
||
import torch | ||
|
||
from ray._private.accelerators.hpu import HPU_PACKAGE_AVAILABLE | ||
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager | ||
|
||
if HPU_PACKAGE_AVAILABLE: | ||
import habana_frameworks.torch.hpu as torch_hpu | ||
|
||
|
||
class HPUTorchDeviceManager(TorchDeviceManager): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
"""HPU device manager""" | ||
|
||
@staticmethod | ||
def register_custom_torch_dist_backend(): | ||
if HPU_PACKAGE_AVAILABLE: | ||
import habana_frameworks.torch.core # noqa: F401 | ||
import habana_frameworks.torch.distributed.hccl # noqa: F401 | ||
|
||
def is_available(self) -> bool(): | ||
if not HPU_PACKAGE_AVAILABLE: | ||
return False | ||
|
||
return torch_hpu.is_available() | ||
|
||
def get_devices(self) -> List[torch.device]: | ||
if not self.is_available(): | ||
raise RuntimeError( | ||
"Using HPUTorchDeviceManager but torch hpu is not available." | ||
) | ||
|
||
return [torch.device("hpu")] | ||
|
||
def set_device(self, device: Union[torch.device, int, str, None]): | ||
torch_hpu.set_device(device) | ||
|
||
def supports_stream(self) -> bool: | ||
"""Validate if the device type support create a stream""" | ||
return False | ||
|
||
def get_stream_context(self, stream): | ||
"""Get HPU stream context manager, empty so far.""" | ||
|
||
@contextmanager | ||
def default_context_manager(): | ||
yield | ||
|
||
return default_context_manager() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import os | ||
from importlib.util import find_spec | ||
from typing import List, Union | ||
|
||
import torch | ||
|
||
import ray | ||
import ray._private.ray_constants as ray_constants | ||
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager | ||
|
||
|
||
def is_package_present(package_name: str) -> bool: | ||
try: | ||
return find_spec(package_name) is not None | ||
except ModuleNotFoundError: | ||
return False | ||
|
||
|
||
NPU_TORCH_PACKAGE_AVAILABLE = is_package_present("torch_npu") | ||
|
||
|
||
if NPU_TORCH_PACKAGE_AVAILABLE: | ||
import torch_npu # noqa: F401 | ||
|
||
|
||
class NPUTorchDeviceManager(TorchDeviceManager): | ||
"""Ascend NPU device manager""" | ||
|
||
@staticmethod | ||
def register_custom_torch_dist_backend(): | ||
if NPU_TORCH_PACKAGE_AVAILABLE: | ||
import torch_npu # noqa: F401, F811 | ||
|
||
def is_available(self) -> bool: | ||
if not NPU_TORCH_PACKAGE_AVAILABLE: | ||
return False | ||
|
||
return torch.npu.is_available() | ||
|
||
def get_devices(self) -> List[torch.device]: | ||
"""Gets the correct torch device list configured for this process. | ||
|
||
Returns a list of torch NPU devices allocated for the current worker. | ||
If no NPUs are assigned, then it returns a list with a single CPU device. | ||
""" | ||
if NPU_TORCH_PACKAGE_AVAILABLE and torch.npu.is_available(): | ||
npu_ids = [ | ||
str(id) | ||
for id in ray.get_runtime_context().get_accelerator_ids()[ | ||
ray_constants.NPU | ||
] | ||
] | ||
|
||
device_ids = [] | ||
|
||
if len(npu_ids) > 0: | ||
npu_visible_str = os.environ.get( | ||
ray_constants.NPU_RT_VISIBLE_DEVICES_ENV_VAR, "" | ||
) | ||
if npu_visible_str and npu_visible_str != "NoDevFiles": | ||
npu_visible_list = npu_visible_str.split(",") | ||
else: | ||
npu_visible_list = [] | ||
|
||
for npu_id in npu_ids: | ||
try: | ||
device_ids.append(npu_visible_list.index(npu_id)) | ||
except IndexError: | ||
raise RuntimeError( | ||
"ASCEND_RT_VISIBLE_DEVICES set incorrectly. " | ||
f"Got {npu_visible_str}, expected to include {npu_id}. " | ||
"Did you override the `ASCEND_RT_VISIBLE_DEVICES` " | ||
"environment variable?" | ||
) | ||
else: | ||
# If called on the driver or outside of Ray Train, return the | ||
# 0th device. | ||
device_ids.append(0) | ||
|
||
devices = [torch.device(f"npu:{device_id}") for device_id in device_ids] | ||
else: | ||
raise RuntimeError( | ||
"Using NPUTorchDeviceManager but torch npu is not available." | ||
) | ||
|
||
return devices | ||
|
||
def set_device(self, device: Union[torch.device, int]): | ||
torch.npu.set_device(device) | ||
|
||
def supports_stream(self) -> bool: | ||
"""Validate if the device type support to create a stream""" | ||
return True | ||
|
||
def create_stream(self, device): | ||
"""Create a stream on NPU device""" | ||
return torch.npu.Stream(device) | ||
|
||
def get_stream_context(self, stream): | ||
"""Get a torch.stream context on NPU device""" | ||
return torch.npu.stream(stream) | ||
|
||
def get_current_stream(self): | ||
"""Get current stream for NPU device""" | ||
return torch.npu.current_stream() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import os | ||
from typing import List, Union | ||
|
||
import torch | ||
|
||
import ray | ||
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager | ||
|
||
|
||
class CUDATorchDeviceManager(TorchDeviceManager): | ||
"""CUDA device manager""" | ||
|
||
def is_available(self) -> bool(): | ||
return torch.cuda.is_available() | ||
|
||
def get_devices(self) -> List[torch.device]: | ||
"""Gets the correct torch device list configured for this process. | ||
|
||
Returns a list of torch CUDA devices allocated for the current worker. | ||
If no GPUs are assigned, then it returns a list with a single CPU device. | ||
|
||
Assumes that `CUDA_VISIBLE_DEVICES` is set and is a | ||
superset of the `ray.get_gpu_ids()`. | ||
""" | ||
|
||
# GPU IDs are assigned by Ray after you specify "use_gpu" | ||
# GPU `ray.get_gpu_ids()` may return ints or may return strings. | ||
# We should always convert to strings. | ||
gpu_ids = [str(id) for id in ray.get_gpu_ids()] | ||
|
||
device_ids = [] | ||
|
||
if len(gpu_ids) > 0: | ||
cuda_visible_str = os.environ.get("CUDA_VISIBLE_DEVICES", "") | ||
if cuda_visible_str and cuda_visible_str != "NoDevFiles": | ||
cuda_visible_list = cuda_visible_str.split(",") | ||
else: | ||
cuda_visible_list = [] | ||
|
||
# By default, there should only be one GPU ID if `use_gpu=True`. | ||
# If there are multiple GPUs, return a list of devices. | ||
# If using fractional GPUs, these IDs are not guaranteed | ||
# to be unique across different processes. | ||
for gpu_id in gpu_ids: | ||
try: | ||
device_ids.append(cuda_visible_list.index(gpu_id)) | ||
except IndexError: | ||
raise RuntimeError( | ||
"CUDA_VISIBLE_DEVICES set incorrectly. " | ||
f"Got {cuda_visible_str}, expected to include {gpu_id}. " | ||
"Did you override the `CUDA_VISIBLE_DEVICES` environment" | ||
" variable? If not, please help file an issue on Github." | ||
) | ||
|
||
else: | ||
# If called on the driver or outside of Ray Train, return the | ||
# 0th device. | ||
device_ids.append(0) | ||
|
||
return [torch.device(f"cuda:{device_id}") for device_id in device_ids] | ||
|
||
def set_device(self, device: Union[torch.device, int, str, None]): | ||
torch.cuda.set_device(device) | ||
|
||
def supports_stream(self) -> bool: | ||
"""Validate if the device type support create a stream""" | ||
return True | ||
|
||
def create_stream(self, device: torch.device) -> torch.cuda.Stream: | ||
"""Create a stream on cuda device""" | ||
return torch.cuda.Stream(device) | ||
|
||
def get_stream_context(self, stream): | ||
"""Get a stream context for cuda device""" | ||
return torch.cuda.stream(stream) | ||
|
||
def get_current_stream(self) -> torch.cuda.Stream: | ||
"""Get current stream for cuda device""" | ||
return torch.cuda.current_stream() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we remove this high level
DeviceManager
abstraction and just keep theTorchDeviceManager
. No need for the extra abstraction for now.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, move this entire folder to
ray/train/torch/_internal
instead.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, just introduce the
TorchDeviceManager
for now.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are trying to extend the
DeviceManager
inray.air
to support more third-party devices, and the plan is to not only use it for Ray Train, but also include RLlib and others. So it seems more reasonable to maintain it withinair
. What do you think?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's just keep it in Train for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently, Train rely on the
get_devices
in AIR, so it's natural for us to implement theDeviceManager
in AIR to return the correct device to Train. If we move theDeviceManager
to Train, it would create a weird dependency where Train calls AIR'sget_devices
, which in turn calls back to Train'sDeviceManager
. Would you mind elaborating on your thoughts about this part?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@matthewdeng WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's just put it in
ray.air
for now. Can restructure the package in the future if needed.