[Train] Decouple device-related modules and add Huawei NPU support to Ray Train #44086
Conversation
Force-pushed from c72e9cc to e9a7311.
This is our plan to enhance support for third-party devices in Ray Train, which will also contribute to expanding device compatibility in RLlib. Looking forward to your feedback, @woshiyyya. We kindly invite developers working with HPU, NPU, and AMD GPUs to stay informed and engaged with these updates. @kira-lin @matthewdeng @vickytsang @nemo9cby @Bye-legumes
Thanks for the contribution! We'll review it soon :)
Hi @liuxsh9 , thanks for the contribution! Left some comments.
I personally like the idea of abstracting out the accelerator concept, which will be easier to generalize to other new accelerator types.
Force-pushed from 1e9fc35 to dce7e54.
Hi @woshiyyya, may I ask if the modifications made above have met your expectations? If you have any other concerns, we would be happy to provide further information.
@liuxsh9 Sure! Will take another look these days!
@woshiyyya can you follow up? I'm marking it as p2 for now @liuxsh9 as I'm not seeing any users blocked on this; do tell me if I'm wrong though and we can juggle priority here.
Hi @anyscalesam @liuxsh9 , This PR involves major changes to Ray Train device management, which needs to be fully tested to ensure stability. We will discuss with the Train team for the next steps and give an update soon.
@woshiyyya This looks like a nice accelerator abstraction, and we plan to refactor the Intel GPU backend as well.
Hi @liuxsh9 , left some comments. Here are some suggestions:
- Instead of defining all of the methods in `DeviceManager` as global static methods, let's make some of them instance methods (e.g. `get_devices`, `is_device_available`), and initialize a `DeviceManager` object for each worker.
- Specify the allocated device object while initializing the `DeviceManager` object.
- Explicitly choose the `DeviceManager` type, instead of determining it from the available resources on the current worker.

Any thoughts?
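A minimal sketch of the instance-based design suggested above; the class names mirror those discussed in this PR, but the method signatures and the CUDA example are illustrative assumptions, not the merged API:

```python
from typing import List, Optional

import torch


class TorchDeviceManager:
    """Per-worker manager for one accelerator type (illustrative sketch)."""

    def __init__(self, device: Optional[torch.device] = None):
        # The allocated device is passed in explicitly when the worker
        # constructs its manager, instead of being discovered globally.
        self.device = device

    def is_device_available(self) -> bool:
        raise NotImplementedError

    def get_devices(self) -> List[torch.device]:
        raise NotImplementedError


class CUDATorchDeviceManager(TorchDeviceManager):
    def is_device_available(self) -> bool:
        return torch.cuda.is_available()

    def get_devices(self) -> List[torch.device]:
        # Prefer the device explicitly assigned to this worker; otherwise
        # fall back to enumerating all visible CUDA devices.
        if self.device is not None:
            return [self.device]
        return [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]


# Each worker explicitly chooses its manager type for the device it was
# allocated, rather than inferring it from whatever resources are visible.
manager = CUDATorchDeviceManager(device=torch.device("cuda:0"))
print(manager.is_device_available(), manager.get_devices())
```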
…ay Train. Signed-off-by: liuxsh9 <[email protected]>
Force-pushed from c70fae0 to ebcccca.
accident -- some changes are still required
1. Replace hard-coded strings with constants from ray_constants. 2. Adjust the timing of registering the torch accelerator module. 3. Adjust the test cases. Signed-off-by: liuxsh9 <[email protected]>
…RCH_DEVICE_MANAGER_CLS` to `CPUTorchDeviceManager` Signed-off-by: liuxsh9 <[email protected]>
Signed-off-by: liuxsh9 <[email protected]>
Signed-off-by: liuxsh9 <[email protected]>
Signed-off-by: liuxsh9 <[email protected]>
Thanks! Just a few nits here -- I will run some GPU release tests and then we can merge.
(Two resolved review threads on python/ray/air/_internal/device_manager/torch_device_manager.py, now outdated.)
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager


@lru_cache()
How does this play with multi-node settings? Does this lru_cache var get shipped over to other nodes?
For example, a CPU head node would possibly have this as false, but we don't want to carry that over to the worker nodes.
Thank you for the reminder; the cached var won't be shipped over to other nodes. But even so, it has been removed to avoid unnecessary code complexity.
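For context, a hedged sketch of the kind of cached check under discussion. The key point is that an `@lru_cache` result lives only in the process that computed it; since Ray serializes functions by reference, a value cached on a CPU-only head node is not shipped to the worker nodes (the function body below is illustrative):

```python
from functools import lru_cache

import torch


@lru_cache()
def cuda_is_available() -> bool:
    # Cached per process: each Ray worker process evaluates this on its first
    # call, so a False result cached on a CPU-only head node is never reused
    # on a GPU worker node.
    return torch.cuda.is_available()
```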
@@ -16,6 +17,7 @@
if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor

try_register_torch_accelerator_module()
can we remove this?
Yes, it has been removed.
1. Fix nits. 2. Raise a clear runtime error when an accelerator is allocated but unavailable. 3. Remove redundant module registrations. Signed-off-by: liuxsh9 <[email protected]>
Signed-off-by: liuxsh9 <[email protected]>
# reset device is needed for npu in a new thread so far.
if device.type == "npu":
    self.device_manager.set_device(device)
Why are we running in a new thread here? Can we remove this?
This logic probably has something to do with the test failure.
> Why are we running in a new thread here? Can we remove this?

It has been removed.

> This logic probably has something to do with the test failure.

The test called `get_torch_device_manager` locally (not inside a `@remote` task), where there is no `RuntimeContext`. Fixed it by getting the device type directly via `device.type` in `_WrappedDataLoader`.
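A hedged illustration of the described fix: branch on `device.type` directly instead of constructing a device manager outside a Ray worker. The wrapper below is a simplified stand-in, not the real `_WrappedDataLoader`:

```python
import torch


def move_batch_to_device(batch: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Checking the device type string requires no RuntimeContext or device
    # manager lookup, so it also works when called outside a Ray remote task.
    if device.type == "npu":
        # NPU-specific handling (e.g. a non-blocking host-to-device copy)
        # would go here.
        return batch.to(device, non_blocking=True)
    return batch.to(device)
```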
@matthewdeng WDYT?
Signed-off-by: liuxsh9 <[email protected]>
Signed-off-by: liuxsh9 <[email protected]>
Signed-off-by: liuxsh9 <[email protected]>
Signed-off-by: liuxsh9 <[email protected]>
Great job, LGTM! Thanks for the patience in this effort! 🚢 🚀
Starting a release test sanity check here: https://buildkite.com/ray-project/release/builds/21289
def register_custom_torch_dist_backend(backend: Optional[str] = None) -> None:
    if backend == "hccl":
        # The name for the communication backend of Habana and torch-npu is the same.
        HPUTorchDeviceManager.register_custom_torch_dist_backend()

        NPUTorchDeviceManager.register_custom_torch_dist_backend()


def get_torch_device_manager_cls_by_resources(
Any reason not to define these functions in `torch_device_manager.py` instead?
Currently these methods depend on the various `XPUDeviceManager` implementations, and those implementations depend on `TorchDeviceManager` in `torch_device_manager.py`. If we put these methods into `torch_device_manager.py`, it may lead to circular import issues.
We followed the design of `AcceleratorManager`, which puts the higher-level API functions in the `__init__.py`.
My understanding may be limited, could you please provide more information?
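A layout sketch of the structure described above, using stand-in module names: the concrete managers import only the base class, while the package `__init__.py` hosts the higher-level dispatch helpers, so no circular import arises. File paths and contents here are illustrative assumptions, not the actual Ray modules:

```python
# device_manager/torch_device_manager.py -- base class only, no subclass imports.
class TorchDeviceManager:
    """Base interface shared by all backends."""


# device_manager/npu.py -- concrete backend, depends only on the base module:
# from device_manager.torch_device_manager import TorchDeviceManager
class NPUTorchDeviceManager(TorchDeviceManager):
    """Ascend NPU backend (stub)."""


# device_manager/__init__.py -- higher-level API, imports the concrete backends:
# from device_manager.torch_device_manager import TorchDeviceManager
# from device_manager.npu import NPUTorchDeviceManager
def get_torch_device_manager_cls_by_resources(resources: dict) -> type:
    # Dispatch lives at the package level, so torch_device_manager.py never
    # needs to import its own subclasses.
    return NPUTorchDeviceManager if resources.get("NPU", 0) > 0 else TorchDeviceManager
```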
_torch_device_manager = None
_torch_device_manager_lock = threading.Lock()
Why do we need a global variable to track this? Can this not just be tracked as an instance variable by the caller?
One issue with global variables is that we want these functions to be called by the individual workers, and since each worker is a separate process, they won't be pointing to these references.
This addresses the previous reviewer's feedback. Additionally, we will check for `_torch_device_manager` within the worker and initialize it if it doesn't exist. Does this meet the requirements?
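A minimal sketch of the per-worker lazy initialization described here; the module-level names mirror the diff, while the factory argument is a placeholder for the real selection logic:

```python
import threading

_torch_device_manager = None
_torch_device_manager_lock = threading.Lock()


def _get_or_create_torch_device_manager(factory):
    """Create the device manager lazily in whichever worker process first asks for it."""
    global _torch_device_manager
    with _torch_device_manager_lock:
        if _torch_device_manager is None:
            # Each Ray worker is a separate process, so this global is private
            # to that worker and gets initialized on first use there.
            _torch_device_manager = factory()
        return _torch_device_manager
```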
_torch_device_manager_lock = threading.Lock()


def get_torch_device_manager(device_type: Optional[str] = None) -> TorchDeviceManager:
IMO we should consider removing this function. It's convenient to have a single function like this, but for all the usages it seems we want one explicit path (either with or without a device), so just calling that explicit logic directly is easier to follow.
Good suggestion. This has been modified to provide two clear function entry points.
1. Clearly categorize the methods for obtaining the device manager into two categories. 2. Lazily instantiate the device manager on the first call to get_torch_device_manager. Signed-off-by: liuxsh9 <[email protected]>
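A hedged sketch of the two explicit entry points that the commit above describes. The function names approximate those in the diff; the stub classes, registry, and selection logic are simplified assumptions:

```python
import torch


class CPUTorchDeviceManager:
    """Stub standing in for the real CPU manager."""


class CUDATorchDeviceManager:
    """Stub standing in for the real CUDA manager."""


_MANAGER_BY_TYPE = {"cpu": CPUTorchDeviceManager, "cuda": CUDATorchDeviceManager}


def get_torch_device_manager_by_device_type(device_type: str):
    # Explicit path: the caller already knows which device it holds.
    try:
        return _MANAGER_BY_TYPE[device_type]()
    except KeyError:
        raise ValueError(f"Unsupported device type: {device_type}")


def get_torch_device_manager_by_context():
    # Context path: used inside a training worker, where the allocated
    # accelerator is known; CUDA availability stands in for the real
    # resource lookup here.
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    return get_torch_device_manager_by_device_type(device_type)
```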
Signed-off-by: matthewdeng <[email protected]>
thanks!
Signed-off-by: liuxsh9 <[email protected]>
Hi @woshiyyya , I think I've addressed all the concerns you raised in the previous review. However, the PR is blocked on your approval. Could you please take another look and let me know if everything looks good now?
Approved. @liuxsh9 Thanks for the contribution!
… Ray Train (ray-project#44086) We are looking to expand the hardware support range of Ray Train by incorporating Huawei Ascend NPU support. However, as the number of hardware types increases, scattered and device-specific modifications have been made to the code, which can impact future compatibility and maintainability. To address this, we have extracted the device-related modules from Ray Train and consolidated them into the `accelerator_utils`. This allows for greater independence among the device-specific code, resulting in improved maintainability. Signed-off-by: liuxsh9 <[email protected]> Signed-off-by: matthewdeng <[email protected]> Co-authored-by: matthewdeng <[email protected]> Signed-off-by: ujjawal-khare <[email protected]>
Why are these changes needed?
We are looking to expand the hardware support range of Ray Train by incorporating Huawei Ascend NPU support.
However, as the number of hardware types increases, scattered and device-specific modifications have been made to the code, which can impact future compatibility and maintainability.
To address this, we have extracted the device-related modules from Ray Train and consolidated them into the `accelerator_utils`. This allows for greater independence among the device-specific code, resulting in improved maintainability.
Related issue number
Checks
- I've signed off every commit (`git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've introduced a new method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.