diff --git a/python/ray/_private/ray_constants.py b/python/ray/_private/ray_constants.py
index 81e32029b6a1..28643f8d540c 100644
--- a/python/ray/_private/ray_constants.py
+++ b/python/ray/_private/ray_constants.py
@@ -424,10 +424,13 @@ def env_set_by_user(key):
 CUDA_VISIBLE_DEVICES_ENV_VAR = "CUDA_VISIBLE_DEVICES"
 NEURON_RT_VISIBLE_CORES_ENV_VAR = "NEURON_RT_VISIBLE_CORES"
 TPU_VISIBLE_CHIPS_ENV_VAR = "TPU_VISIBLE_CHIPS"
+NPU_RT_VISIBLE_DEVICES_ENV_VAR = "ASCEND_RT_VISIBLE_DEVICES"

 NEURON_CORES = "neuron_cores"
 GPU = "GPU"
 TPU = "TPU"
+NPU = "NPU"
+HPU = "HPU"

 RAY_WORKER_NICENESS = "RAY_worker_niceness"
diff --git a/python/ray/air/_internal/device_manager/__init__.py b/python/ray/air/_internal/device_manager/__init__.py
new file mode 100644
index 000000000000..824833df03ff
--- /dev/null
+++ b/python/ray/air/_internal/device_manager/__init__.py
@@ -0,0 +1,92 @@
+import logging
+import threading
+from typing import Optional
+
+import ray
+import ray._private.ray_constants as ray_constants
+from ray.air._internal.device_manager.cpu import CPUTorchDeviceManager
+from ray.air._internal.device_manager.hpu import HPUTorchDeviceManager
+from ray.air._internal.device_manager.npu import NPUTorchDeviceManager
+from ray.air._internal.device_manager.nvidia_gpu import CUDATorchDeviceManager
+from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager
+
+logger = logging.getLogger(__name__)
+
+
+DEFAULT_TORCH_DEVICE_MANAGER_CLS = CPUTorchDeviceManager
+
+
+SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER = {
+    ray_constants.GPU: CUDATorchDeviceManager,
+    ray_constants.HPU: HPUTorchDeviceManager,
+    ray_constants.NPU: NPUTorchDeviceManager,
+}
+
+
+def register_custom_torch_dist_backend(backend: Optional[str] = None) -> None:
+    if backend == "hccl":
+        # Habana (HPU) and torch-npu both name their backend "hccl"; register both.
+        HPUTorchDeviceManager.register_custom_torch_dist_backend()
+
+        NPUTorchDeviceManager.register_custom_torch_dist_backend()
+
+
+_torch_device_manager = None
+_torch_device_manager_lock = threading.Lock()
+
+
+def get_torch_device_manager_by_context() -> TorchDeviceManager:
+    global _torch_device_manager
+
+    with _torch_device_manager_lock:
+        if not _torch_device_manager:
+            existing_device_manager_cls = None
+            resources = ray.get_runtime_context().get_accelerator_ids()
+
+            # Select the correct accelerator type from the assigned resources.
+            for resource_type, resource_value in resources.items():
+                device_manager_cls = SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER.get(
+                    resource_type, None
+                )
+                if resource_value and device_manager_cls:
+                    # An error is raised when multiple accelerator types are specified.
+                    if existing_device_manager_cls:
+                        raise RuntimeError(
+                            "Unable to determine the appropriate DeviceManager "
+                            f"for the specified resources {resources}."
+                        )
+                    else:
+                        existing_device_manager_cls = device_manager_cls
+
+            device_manager_cls = (
+                existing_device_manager_cls or DEFAULT_TORCH_DEVICE_MANAGER_CLS
+            )
+
+            _torch_device_manager = device_manager_cls()
+
+    return _torch_device_manager
+
+
+def get_torch_device_manager_by_device_type(device_type: str) -> TorchDeviceManager:
+    if device_type.lower() == ray_constants.GPU.lower() or device_type == "cuda":
+        return CUDATorchDeviceManager()
+    elif device_type.lower() == ray_constants.NPU.lower():
+        return NPUTorchDeviceManager()
+    elif device_type.lower() == ray_constants.HPU.lower():
+        return HPUTorchDeviceManager()
+    elif device_type.lower() == "cpu":
+        return CPUTorchDeviceManager()
+
+    raise RuntimeError(f"Device type {device_type} cannot be recognized.")
+
+
+__all__ = [
+    "TorchDeviceManager",
+    "CPUTorchDeviceManager",
+    "CUDATorchDeviceManager",
+    "HPUTorchDeviceManager",
+    "NPUTorchDeviceManager",
+    "register_custom_torch_dist_backend",
+    "get_torch_device_manager_by_context",
+    "get_torch_device_manager_by_device_type",
+]
diff --git a/python/ray/air/_internal/device_manager/cpu.py b/python/ray/air/_internal/device_manager/cpu.py
new file mode 100644
index 000000000000..76fa73765287
--- /dev/null
+++ b/python/ray/air/_internal/device_manager/cpu.py
@@ -0,0 +1,30 @@
+from contextlib import contextmanager
+from typing import List
+
+import torch
+
+from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager
+
+
+class CPUTorchDeviceManager(TorchDeviceManager):
+    """CPU device manager"""
+
+    def is_available(self) -> bool:
+        return True
+
+    def get_devices(self) -> List[torch.device]:
+        """Gets the correct torch device list configured for this process."""
+        return [torch.device("cpu")]
+
+    def supports_stream(self) -> bool:
+        """Validate if the device type supports creating a stream."""
+        return False
+
+    def get_stream_context(self, stream):
+        """Return an empty context manager for CPU."""
+
+        @contextmanager
+        def default_context_manager():
+            yield
+
+        return default_context_manager()
diff --git a/python/ray/air/_internal/device_manager/hpu.py b/python/ray/air/_internal/device_manager/hpu.py
new file mode 100644
index 000000000000..bb402ea65b0d
--- /dev/null
+++ b/python/ray/air/_internal/device_manager/hpu.py
@@ -0,0 +1,50 @@
+from contextlib import contextmanager
+from typing import List, Union
+
+import torch
+
+from ray._private.accelerators.hpu import HPU_PACKAGE_AVAILABLE
+from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager
+
+if HPU_PACKAGE_AVAILABLE:
+    import habana_frameworks.torch.hpu as torch_hpu
+
+
+class HPUTorchDeviceManager(TorchDeviceManager):
+    """HPU device manager"""
+
+    @staticmethod
+    def register_custom_torch_dist_backend():
+        if HPU_PACKAGE_AVAILABLE:
+            import habana_frameworks.torch.core  # noqa: F401
+            import habana_frameworks.torch.distributed.hccl  # noqa: F401
+
+    def is_available(self) -> bool:
+        if not HPU_PACKAGE_AVAILABLE:
+            return False
+
+        return torch_hpu.is_available()
+
+    def get_devices(self) -> List[torch.device]:
+        if not self.is_available():
+            raise RuntimeError(
+                "Using HPUTorchDeviceManager but torch hpu is not available."
+            )
+
+        return [torch.device("hpu")]
+
+    def set_device(self, device: Union[torch.device, int, str, None]):
+        torch_hpu.set_device(device)
+
+    def supports_stream(self) -> bool:
+        """Validate if the device type supports creating a stream."""
+        return False
+
+    def get_stream_context(self, stream):
+        """Get HPU stream context manager, empty so far."""
+
+        @contextmanager
+        def default_context_manager():
+            yield
+
+        return default_context_manager()
diff --git a/python/ray/air/_internal/device_manager/npu.py b/python/ray/air/_internal/device_manager/npu.py
new file mode 100644
index 000000000000..aa6d7bad2408
--- /dev/null
+++ b/python/ray/air/_internal/device_manager/npu.py
@@ -0,0 +1,105 @@
+import os
+from importlib.util import find_spec
+from typing import List, Union
+
+import torch
+
+import ray
+import ray._private.ray_constants as ray_constants
+from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager
+
+
+def is_package_present(package_name: str) -> bool:
+    try:
+        return find_spec(package_name) is not None
+    except ModuleNotFoundError:
+        return False
+
+
+NPU_TORCH_PACKAGE_AVAILABLE = is_package_present("torch_npu")
+
+
+if NPU_TORCH_PACKAGE_AVAILABLE:
+    import torch_npu  # noqa: F401
+
+
+class NPUTorchDeviceManager(TorchDeviceManager):
+    """Ascend NPU device manager"""
+
+    @staticmethod
+    def register_custom_torch_dist_backend():
+        if NPU_TORCH_PACKAGE_AVAILABLE:
+            import torch_npu  # noqa: F401, F811
+
+    def is_available(self) -> bool:
+        if not NPU_TORCH_PACKAGE_AVAILABLE:
+            return False
+
+        return torch.npu.is_available()
+
+    def get_devices(self) -> List[torch.device]:
+        """Gets the correct torch device list configured for this process.
+
+        Returns a list of torch NPU devices allocated for the current worker.
+        If no NPUs are assigned to the worker, the 0th NPU device is returned.
+        """
+        if NPU_TORCH_PACKAGE_AVAILABLE and torch.npu.is_available():
+            npu_ids = [
+                str(id)
+                for id in ray.get_runtime_context().get_accelerator_ids()[
+                    ray_constants.NPU
+                ]
+            ]
+
+            device_ids = []
+
+            if len(npu_ids) > 0:
+                npu_visible_str = os.environ.get(
+                    ray_constants.NPU_RT_VISIBLE_DEVICES_ENV_VAR, ""
+                )
+                if npu_visible_str and npu_visible_str != "NoDevFiles":
+                    npu_visible_list = npu_visible_str.split(",")
+                else:
+                    npu_visible_list = []
+
+                for npu_id in npu_ids:
+                    try:
+                        device_ids.append(npu_visible_list.index(npu_id))
+                    except ValueError:
+                        raise RuntimeError(
+                            "ASCEND_RT_VISIBLE_DEVICES set incorrectly. "
+                            f"Got {npu_visible_str}, expected to include {npu_id}. "
+                            "Did you override the `ASCEND_RT_VISIBLE_DEVICES` "
+                            "environment variable?"
+                        )
+            else:
+                # If called on the driver or outside of Ray Train, return the
+                # 0th device.
+                device_ids.append(0)
+
+            devices = [torch.device(f"npu:{device_id}") for device_id in device_ids]
+        else:
+            raise RuntimeError(
+                "Using NPUTorchDeviceManager but torch npu is not available."
+            )
+
+        return devices
+
+    def set_device(self, device: Union[torch.device, int]):
+        torch.npu.set_device(device)
+
+    def supports_stream(self) -> bool:
+        """Validate if the device type supports creating a stream."""
+        return True
+
+    def create_stream(self, device):
+        """Create a stream on an NPU device."""
+        return torch.npu.Stream(device)
+
+    def get_stream_context(self, stream):
+        """Get a torch.npu.stream context on an NPU device."""
+        return torch.npu.stream(stream)
+
+    def get_current_stream(self):
+        """Get the current stream for the NPU device."""
+        return torch.npu.current_stream()
diff --git a/python/ray/air/_internal/device_manager/nvidia_gpu.py b/python/ray/air/_internal/device_manager/nvidia_gpu.py
new file mode 100644
index 000000000000..f4bb1b54097e
--- /dev/null
+++ b/python/ray/air/_internal/device_manager/nvidia_gpu.py
@@ -0,0 +1,79 @@
+import os
+from typing import List, Union
+
+import torch
+
+import ray
+from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager
+
+
+class CUDATorchDeviceManager(TorchDeviceManager):
+    """CUDA device manager"""
+
+    def is_available(self) -> bool:
+        return torch.cuda.is_available()
+
+    def get_devices(self) -> List[torch.device]:
+        """Gets the correct torch device list configured for this process.
+
+        Returns a list of torch CUDA devices allocated for the current worker.
+        If no GPUs are assigned to the worker, the 0th CUDA device is returned.
+
+        Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
+        superset of the `ray.get_gpu_ids()`.
+        """
+
+        # GPU IDs are assigned by Ray after you specify "use_gpu".
+        # `ray.get_gpu_ids()` may return ints or strings,
+        # so always convert the IDs to strings.
+        gpu_ids = [str(id) for id in ray.get_gpu_ids()]
+
+        device_ids = []
+
+        if len(gpu_ids) > 0:
+            cuda_visible_str = os.environ.get("CUDA_VISIBLE_DEVICES", "")
+            if cuda_visible_str and cuda_visible_str != "NoDevFiles":
+                cuda_visible_list = cuda_visible_str.split(",")
+            else:
+                cuda_visible_list = []
+
+            # By default, there should only be one GPU ID if `use_gpu=True`.
+            # If there are multiple GPUs, return a list of devices.
+            # If using fractional GPUs, these IDs are not guaranteed
+            # to be unique across different processes.
+            for gpu_id in gpu_ids:
+                try:
+                    device_ids.append(cuda_visible_list.index(gpu_id))
+                except ValueError:
+                    raise RuntimeError(
+                        "CUDA_VISIBLE_DEVICES set incorrectly. "
+                        f"Got {cuda_visible_str}, expected to include {gpu_id}. "
+                        "Did you override the `CUDA_VISIBLE_DEVICES` environment"
+                        " variable? If not, please help file an issue on Github."
+                    )
+
+        else:
+            # If called on the driver or outside of Ray Train, return the
+            # 0th device.
+            device_ids.append(0)
+
+        return [torch.device(f"cuda:{device_id}") for device_id in device_ids]
+
+    def set_device(self, device: Union[torch.device, int, str, None]):
+        torch.cuda.set_device(device)
+
+    def supports_stream(self) -> bool:
+        """Validate if the device type supports creating a stream."""
+        return True
+
+    def create_stream(self, device: torch.device) -> torch.cuda.Stream:
+        """Create a stream on a CUDA device."""
+        return torch.cuda.Stream(device)
+
+    def get_stream_context(self, stream):
+        """Get a stream context for a CUDA device."""
+        return torch.cuda.stream(stream)
+
+    def get_current_stream(self) -> torch.cuda.Stream:
+        """Get the current stream for the CUDA device."""
+        return torch.cuda.current_stream()
diff --git a/python/ray/air/_internal/device_manager/torch_device_manager.py b/python/ray/air/_internal/device_manager/torch_device_manager.py
new file mode 100644
index 000000000000..d522a477ef58
--- /dev/null
+++ b/python/ray/air/_internal/device_manager/torch_device_manager.py
@@ -0,0 +1,40 @@
+from abc import ABC
+from typing import List, Union
+
+import torch
+
+
+class TorchDeviceManager(ABC):
+    """This class contains the functions needed to support
+    an accelerator family in Ray AI Libraries.
+    """
+
+    def is_available(self) -> bool:
+        """Validate if the device is available."""
+        ...
+
+    def get_devices(self) -> List[torch.device]:
+        """Gets the correct torch device list configured for this process."""
+        ...
+
+    def set_device(self, device: Union[torch.device, int, str, None]):
+        """Set the correct device for this process."""
+        ...
+
+    def supports_stream(self) -> bool:
+        """Validate if the device type supports creating a stream."""
+        ...
+
+    def create_stream(self, device: torch.device):
+        """Create a device stream."""
+        ...
+
+    def get_stream_context(self, stream):
+        """Get a stream context for the device. If the device doesn't support
+        streams, this should return an empty context manager instead of None.
+        """
+        ...
+
+    def get_current_stream(self):
+        """Get the current stream on the accelerator, like torch.cuda.current_stream."""
+        ...
diff --git a/python/ray/air/_internal/torch_utils.py b/python/ray/air/_internal/torch_utils.py
index c4c4d57c5e10..caeb27a20a30 100644
--- a/python/ray/air/_internal/torch_utils.py
+++ b/python/ray/air/_internal/torch_utils.py
@@ -1,4 +1,3 @@
-import os
 import warnings
 from typing import Any, Dict, List, Optional, Union
 
@@ -6,65 +5,18 @@ import pandas as pd
 import torch
 
-import ray
-from ray._private.accelerators.hpu import HPU_PACKAGE_AVAILABLE
+from ray.air._internal.device_manager import get_torch_device_manager_by_context
 from ray.air.util.data_batch_conversion import _unwrap_ndarray_object_type_if_needed
 
-if HPU_PACKAGE_AVAILABLE:
-    import habana_frameworks.torch.hpu as torch_hpu
-
 
 def get_devices() -> List[torch.device]:
     """Gets the correct torch device list configured for this process.
 
-    Returns a list of torch CUDA devices allocated for the current worker.
-    If no GPUs are assigned, then it returns a list with a single CPU device.
-
-    Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
-    superset of the `ray.get_gpu_ids()`.
+    Returns a list of torch accelerator (GPU, HPU, NPU...) devices allocated for
+    the current worker.
+    If no accelerators are assigned, then it returns a list with a single CPU device.
     """
-    if torch.cuda.is_available():
-        # GPU IDs are assigned by Ray after you specify "use_gpu"
-        # GPU `ray.get_gpu_ids()` may return ints or may return strings.
-        # We should always convert to strings.
- gpu_ids = [str(id) for id in ray.get_gpu_ids()] - - device_ids = [] - - if len(gpu_ids) > 0: - cuda_visible_str = os.environ.get("CUDA_VISIBLE_DEVICES", "") - if cuda_visible_str and cuda_visible_str != "NoDevFiles": - cuda_visible_list = cuda_visible_str.split(",") - else: - cuda_visible_list = [] - - # By default, there should only be one GPU ID if `use_gpu=True`. - # If there are multiple GPUs, return a list of devices. - # If using fractional GPUs, these IDs are not guaranteed - # to be unique across different processes. - for gpu_id in gpu_ids: - try: - device_ids.append(cuda_visible_list.index(gpu_id)) - except IndexError: - raise RuntimeError( - "CUDA_VISIBLE_DEVICES set incorrectly. " - f"Got {cuda_visible_str}, expected to include {gpu_id}. " - "Did you override the `CUDA_VISIBLE_DEVICES` environment" - " variable? If not, please help file an issue on Github." - ) - - else: - # If called on the driver or outside of Ray Train, return the - # 0th device. - device_ids.append(0) - - devices = [torch.device(f"cuda:{device_id}") for device_id in device_ids] - elif HPU_PACKAGE_AVAILABLE and torch_hpu.is_available(): - devices = [torch.device("hpu")] - else: - devices = [torch.device("cpu")] - - return devices + return get_torch_device_manager_by_context().get_devices() def convert_pandas_to_torch_tensor( diff --git a/python/ray/train/BUILD b/python/ray/train/BUILD index b7557d9cf90e..8a566348a922 100644 --- a/python/ray/train/BUILD +++ b/python/ray/train/BUILD @@ -559,6 +559,14 @@ py_test( deps = [":train_lib", ":conftest"] ) +py_test( + name = "test_torch_device_manager", + size = "small", + srcs = ["tests/test_torch_device_manager.py"], + tags = ["team:ml", "exclusive", "ray_air", "gpu_only"], + deps = [":train_lib", ":conftest"] +) + py_test( name = "test_torch_trainer", size = "large", diff --git a/python/ray/train/_internal/backend_executor.py b/python/ray/train/_internal/backend_executor.py index ee6041a24248..d3552c1cf379 100644 --- a/python/ray/train/_internal/backend_executor.py +++ b/python/ray/train/_internal/backend_executor.py @@ -26,6 +26,7 @@ ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV, ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV, + ENABLE_SHARE_NPU_RT_VISIBLE_DEVICES_ENV, RAY_TRAIN_ENABLE_STATE_TRACKING, TRAIN_ENABLE_WORKER_SPREAD_ENV, TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV, @@ -117,7 +118,12 @@ def __init__( ray_constants.NEURON_CORES, ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV, ray_constants.NEURON_RT_VISIBLE_CORES_ENV_VAR, - ) + ), + ResourceConfig( + ray_constants.NPU, + ENABLE_SHARE_NPU_RT_VISIBLE_DEVICES_ENV, + ray_constants.NPU_RT_VISIBLE_DEVICES_ENV_VAR, + ), ] # Record the initialization time of BackendExecutor, which is diff --git a/python/ray/train/constants.py b/python/ray/train/constants.py index 530de9d3a2d5..62611a6060e5 100644 --- a/python/ray/train/constants.py +++ b/python/ray/train/constants.py @@ -62,6 +62,10 @@ def _get_ray_train_session_dir() -> str: "TRAIN_ENABLE_SHARE_NEURON_CORES_ACCELERATOR" ) +# Integer value which if set will not share npu visible devices +# across workers. 1 for True (default), 0 for False. +ENABLE_SHARE_NPU_RT_VISIBLE_DEVICES_ENV = "TRAIN_ENABLE_SHARE_ASCEND_RT_VISIBLE_DEVICES" + # Integer value which indicates the number of seconds to wait when creating # the worker placement group before timing out. 
 TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV = "TRAIN_PLACEMENT_GROUP_TIMEOUT_S"
diff --git a/python/ray/train/tests/test_torch_device_manager.py b/python/ray/train/tests/test_torch_device_manager.py
new file mode 100644
index 000000000000..f4fd224bcdfc
--- /dev/null
+++ b/python/ray/train/tests/test_torch_device_manager.py
@@ -0,0 +1,95 @@
+import pytest
+import torch
+
+import ray
+from ray.air._internal.device_manager import (
+    CUDATorchDeviceManager,
+    NPUTorchDeviceManager,
+    get_torch_device_manager_by_context,
+)
+from ray.air._internal.device_manager.npu import NPU_TORCH_PACKAGE_AVAILABLE
+from ray.cluster_utils import Cluster
+from ray.train import ScalingConfig
+from ray.train.torch import TorchTrainer
+
+if NPU_TORCH_PACKAGE_AVAILABLE:
+    import torch_npu  # noqa: F401
+
+
+@pytest.fixture
+def ray_2_node_2_npus():
+    cluster = Cluster()
+    for _ in range(2):
+        cluster.add_node(num_cpus=4, resources={"NPU": 2})
+
+    ray.init(address=cluster.address)
+
+    yield
+
+    ray.shutdown()
+    cluster.shutdown()
+
+
+@pytest.fixture
+def ray_1_node_1_gpu_1_npu():
+    cluster = Cluster()
+    cluster.add_node(num_cpus=4, num_gpus=1, resources={"NPU": 1})
+    ray.init(address=cluster.address)
+
+    yield
+
+    ray.shutdown()
+    cluster.shutdown()
+
+
+def test_cuda_device_manager(ray_2_node_2_gpu):
+    def train_fn():
+        assert isinstance(get_torch_device_manager_by_context(), CUDATorchDeviceManager)
+
+    trainer = TorchTrainer(
+        train_loop_per_worker=train_fn,
+        scaling_config=ScalingConfig(
+            num_workers=1, use_gpu=True, resources_per_worker={"GPU": 1}
+        ),
+    )
+
+    trainer.fit()
+
+
+def test_npu_device_manager(ray_2_node_2_npus):
+    def train_fn():
+        assert isinstance(get_torch_device_manager_by_context(), NPUTorchDeviceManager)
+
+    trainer = TorchTrainer(
+        train_loop_per_worker=train_fn,
+        scaling_config=ScalingConfig(num_workers=1, resources_per_worker={"NPU": 1}),
+    )
+
+    if NPU_TORCH_PACKAGE_AVAILABLE and torch.npu.is_available():
+        # Expect the test to run successfully when torch npu is available.
+        trainer.fit()
+    else:
+        # A RuntimeError is expected when NPU resources are declared
+        # but torch npu is actually not available.
+        with pytest.raises(RuntimeError):
+            trainer.fit()
+
+
+def test_device_manager_conflict(ray_1_node_1_gpu_1_npu):
+    trainer = TorchTrainer(
+        train_loop_per_worker=lambda: None,
+        scaling_config=ScalingConfig(
+            num_workers=1, use_gpu=True, resources_per_worker={"GPU": 1, "NPU": 1}
+        ),
+    )
+    # TODO: Do validation at the `ScalingConfig.__post_init__` level instead.
+    with pytest.raises(RuntimeError):
+        trainer.fit()
+
+
+if __name__ == "__main__":
+    import sys
+
+    import pytest
+
+    sys.exit(pytest.main(["-v", "-x", __file__]))
diff --git a/python/ray/train/torch/config.py b/python/ray/train/torch/config.py
index a334829fb9b2..a0ecc61e3b87 100644
--- a/python/ray/train/torch/config.py
+++ b/python/ray/train/torch/config.py
@@ -9,7 +9,7 @@
 from packaging.version import Version
 
 import ray
-from ray._private.accelerators.hpu import HPU_PACKAGE_AVAILABLE
+from ray.air._internal.device_manager import register_custom_torch_dist_backend
 from ray.train._internal.utils import get_address_and_port
 from ray.train._internal.worker_group import WorkerGroup
 from ray.train.backend import Backend, BackendConfig
@@ -109,9 +109,8 @@ def _setup_torch_process_group(
             f"To override this behavior, you can set {TORCH_NCCL_ASYNC_ERROR_HANDLING_ENV_VAR}=0."
# noqa: E501 ) os.environ[TORCH_NCCL_ASYNC_ERROR_HANDLING_ENV_VAR] = "1" - elif backend == "hccl" and HPU_PACKAGE_AVAILABLE: - import habana_frameworks.torch.core as htcore # noqa: F401 - import habana_frameworks.torch.distributed.hccl as hpu_dist # noqa: F401 + elif backend == "hccl": + register_custom_torch_dist_backend(backend) dist.init_process_group( backend=backend, diff --git a/python/ray/train/torch/train_loop_utils.py b/python/ray/train/torch/train_loop_utils.py index 0c421ba2aa3d..465eed45a4a8 100644 --- a/python/ray/train/torch/train_loop_utils.py +++ b/python/ray/train/torch/train_loop_utils.py @@ -20,6 +20,10 @@ ) from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag +from ray.air._internal.device_manager import ( + get_torch_device_manager_by_context, + get_torch_device_manager_by_device_type, +) from ray.train._internal import session from ray.train._internal.accelerator import Accelerator from ray.train._internal.session import get_accelerator, set_accelerator @@ -365,6 +369,7 @@ def __init__(self, amp: bool = False): self.amp_is_enabled = amp self.scaler = GradScaler() if amp else None self._seed = None + self.device_manager = get_torch_device_manager_by_context() def prepare_model( self, @@ -402,8 +407,8 @@ def prepare_model( if isinstance(device, list): device = device[0] - if torch.cuda.is_available(): - torch.cuda.set_device(device) + if self.device_manager.is_available(): + self.device_manager.set_device(device) if move_to_device: if rank == 0: @@ -451,7 +456,7 @@ def model_get_state(self): if parallel_strategy and world_size > 1: if parallel_strategy == "ddp": DataParallel = DistributedDataParallel - if torch.cuda.is_available(): + if self.device_manager.is_available() and device.type != "cpu": parallel_strategy_kwargs = { "device_ids": [device], "output_device": device, @@ -632,12 +637,18 @@ def __init__( self._dataloader = base_dataloader self.dataloader_iter = None self.device = device + + self.device_manager = get_torch_device_manager_by_device_type(device.type) + # disable auto transfer (host->device) if cpu is used - self._auto_transfer = auto_transfer if device.type == "cuda" else False - # create a new CUDA stream to move data from host to device concurrently + if device.type != "cpu" and self.device_manager.supports_stream(): + self._auto_transfer = auto_transfer + else: + self._auto_transfer = False + # create a new device stream to move data from host to device concurrently self._memcpy_stream = ( - torch.cuda.Stream(device) - if device.type == "cuda" and self._auto_transfer + self.device_manager.create_stream(device) + if device.type != "cpu" and self._auto_transfer else None ) self.next_batch = None @@ -653,7 +664,7 @@ def try_move_device(i): logger.debug(f"Item {i} cannot be moved to device " f"{self.device}.") return i - with torch.cuda.stream(self._memcpy_stream): + with self.device_manager.get_stream_context(self._memcpy_stream): if isinstance(item, collections.abc.Mapping): item_on_device = {k: self._move_to_device(v) for k, v in item.items()} elif isinstance(item, tuple): @@ -677,7 +688,7 @@ def _wait_for_batch(self, item): # https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # The training stream (current) needs to wait until # the memory copy stream finishes. 
- curr_stream = torch.cuda.current_stream() + curr_stream = self.device_manager.get_current_stream() curr_stream.wait_stream(self._memcpy_stream) # When a tensor is used by CUDA streams different from # its original allocator, we need to call ``record_stream``
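
For context on how the pieces above fit together, here is a minimal end-to-end usage sketch. It is not part of this diff and rests on a few assumptions: a Ray cluster that advertises "NPU" resources, a working torch_npu installation, and only public Ray Train APIs (TorchTrainer, TorchConfig, ScalingConfig, ray.train.torch.get_device / prepare_model). With one NPU assigned per worker, get_torch_device_manager_by_context() resolves to NPUTorchDeviceManager, the "hccl" backend registration runs before init_process_group, and get_device() returns npu:<index> devices; the device_manager module itself is internal and is never called directly here.

# Hypothetical usage sketch (not part of this PR): Ray Train on Ascend NPUs.
import torch
import torch.nn as nn

import ray.train.torch
from ray.train import ScalingConfig
from ray.train.torch import TorchConfig, TorchTrainer


def train_loop_per_worker():
    # Resolved through the worker's accelerator IDs, so this is an
    # `npu:<i>` device rather than `cuda:<i>` on an NPU worker.
    device = ray.train.torch.get_device()

    model = nn.Linear(8, 1)
    # prepare_model() uses the selected device manager to call set_device()
    # and wraps the model in DistributedDataParallel on that device.
    model = ray.train.torch.prepare_model(model)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for _ in range(10):
        x = torch.randn(4, 8, device=device)
        loss = model(x).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    # "hccl" makes register_custom_torch_dist_backend() run before
    # torch.distributed.init_process_group().
    torch_config=TorchConfig(backend="hccl"),
    # Requesting NPU resources (and no GPUs) is what makes
    # get_torch_device_manager_by_context() pick NPUTorchDeviceManager.
    scaling_config=ScalingConfig(num_workers=2, resources_per_worker={"NPU": 1}),
)
trainer.fit()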