Add util to create a torch ddp process group for a list of workers. #34202
Conversation
Is it possible to consolidate on a single code path?
The goal is to not share code with Train and to make this a generic AIR util for now.
Three high-level questions:
- What's the difference between this and torch DDP in Ray Train?
- How can we distribute workers onto different nodes?
- What's the whole pipeline for integrating DeepSpeed?
Pretty much no difference.
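Both presumably reduce to the same standard torch.distributed initialization on every worker; here is a minimal, hedged sketch of that common core (the function name and argument set are illustrative, not this PR's actual code):

```python
import os

import torch.distributed as dist


def setup_torch_process_group(
    world_rank: int,
    world_size: int,
    master_addr: str,
    master_port: int,
    backend: str = "gloo",
) -> None:
    """Illustrative only: join the default process group on one worker."""
    # The default "env://" init method reads the rendezvous address from env vars.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    dist.init_process_group(backend=backend, rank=world_rank, world_size=world_size)
```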
Hello, I'm really impressed with this feature! Would it be possible for me to obtain a copy of the example notebook?
# Wait for all workers to join the process group.
ray.get(setup_futures)

return local_ranks
Should we also return a dict mapping IPs to world ranks?
I didn't find the need, because I want to use each worker's index in the list as its global rank. Do you see any flaw in this approach?
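For illustration, a small sketch of that convention (the helper below is hypothetical, not the PR's code): the world rank is the worker's index in the list, and the local rank is its position among the workers that share its node.

```python
from collections import defaultdict
from typing import List


def assign_local_ranks(worker_node_ids: List[str]) -> List[int]:
    """Hypothetical helper: worker_node_ids[i] is the node id of the i-th worker.

    The i-th worker's world rank is simply i; its local rank is its position
    among the workers that landed on the same node.
    """
    next_local_rank = defaultdict(int)
    local_ranks = []
    for node_id in worker_node_ids:
        local_ranks.append(next_local_rank[node_id])
        next_local_rank[node_id] += 1
    return local_ranks


# Two nodes with two workers each: world ranks 0-3, local ranks [0, 1, 0, 1].
assert assign_local_ranks(["node-a", "node-a", "node-b", "node-b"]) == [0, 1, 0, 1]
```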
LGTM, one question.
I would really like to unify on a common code path. All the differences between the Ray Train logic and this logic can be abstracted away via function arguments:
- A common utility function for sharing CUDA visible devices that can handle the multiple-GPUs-per-worker case as well (a rough sketch follows this list).
- A common utility function for getting the local rank, local world size, and node rank.
- TorchBackend.on_start should call the init_torch_dist_process_group function.
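As a rough sketch of the helper described in the first bullet above (the name and shape are assumptions, not the PR's code):

```python
import os
from typing import Dict, List


def share_cuda_visible_devices(
    node_to_gpu_ids: Dict[str, List[int]], node_id: str
) -> None:
    """Hypothetical helper: expose every GPU used by any co-located worker.

    Each worker on a node sets CUDA_VISIBLE_DEVICES to the union of the GPU ids
    assigned to all workers on that node, which also covers the
    multiple-GPUs-per-worker case.
    """
    visible = sorted(set(node_to_gpu_ids[node_id]))
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in visible)
```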
local_world_size: int,
master_addr: str,
master_port: str,
gpu_ids: List[int],
This may not be List[int] if using multiple GPUs per worker.
Added logic to flatten the GPU ids for multiple workers, and changed the unit test to include this case too.
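A minimal sketch of the flattening step described here, assuming each worker now reports a list of GPU ids (illustrative, not the exact PR code):

```python
from typing import List


def flatten_gpu_ids(per_worker_gpu_ids: List[List[int]]) -> List[int]:
    """Flatten the per-worker GPU id lists reported on one node into a flat list."""
    return [gpu_id for gpu_ids in per_worker_gpu_ids for gpu_id in gpu_ids]


# e.g. two workers on the same node, each assigned two GPUs.
assert flatten_gpu_ids([[0, 1], [2, 3]]) == [0, 1, 2, 3]
```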
python/ray/air/util/torch_dist.py
Outdated
# All the workers on a specific node.
node_to_workers = {}
# All the gpu ids visible to all the workers on a specific node.
node_to_gpu_ids = {}
You can use defaultdict(list) instead of needing to call setdefault every time.
Sure.
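For reference, the two equivalent patterns being discussed:

```python
from collections import defaultdict

# With a plain dict, every insertion needs setdefault:
node_to_workers = {}
node_to_workers.setdefault("node-a", []).append("worker-0")

# With defaultdict(list), missing keys start out as empty lists automatically:
node_to_workers = defaultdict(list)
node_to_workers["node-a"].append("worker-0")
```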
python/ray/air/util/torch_dist.py
Outdated
return func(*args, **kwargs)
except Exception as e:
skipped = skip_exceptions(e)
raise skipped from exception_cause(skipped)
Do we want to remove all this skip_exceptions stuff? It only works when used with Train/Tune.
Ah, good to know. Thanks for the comment.
return node_id, gpu_ids


def init_torch_dist_process_group(
Do we also need to add corresponding shutdown logic?
Good point. Done.
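The teardown counterpart presumably just destroys the default process group on each worker; a minimal sketch (the function name and the surrounding actor plumbing are assumptions):

```python
import torch.distributed as dist


def shutdown_torch_process_group() -> None:
    # Destroy the default process group if this worker ever joined one.
    if dist.is_initialized():
        dist.destroy_process_group()
```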
return node_id, gpu_ids


def init_torch_dist_process_group(
Is this intended to be a public API? If not, let's move it into the _internal package.
This will be included in our public-facing examples. If we move those predictors into ray/air/, we can mark this private, or get rid of it altogether and refactor Train to share this logic.
raise RuntimeError("Distributed torch is not available.")

# Build a map from node_id to workers on that node.
node_and_gpu_ids = ray.get(
Would it be possible to make sure that we sort the workers by GPU id, to avoid the issue fixed in #33159?
I'm using a set to collect the per-node visible GPUs now and sorting it, so the result is always in order. Added a comment about this.
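A tiny illustration of the ordering concern: converting a set to a list does not by itself guarantee any particular order, but sorting the collected ids gives every worker on the node the same deterministic view.

```python
gpu_ids = set()
for worker_gpu_ids in ([1, 0], [3, 2]):  # GPU ids reported by two co-located workers
    gpu_ids.update(worker_gpu_ids)

# sorted() gives a deterministic ordering; a bare list(gpu_ids) would not.
assert sorted(gpu_ids) == [0, 1, 2, 3]
```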
Thanks for all the comments. PTAL.
Unit tests all pass now. The lint error is not related.
Why are these changes needed?
For running DeepSpeed jobs.
Related issue number
Checks
- I've signed off every commit (git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I've introduced a new API method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.