
Add util to create a torch ddp process group for a list of workers. #34202

Merged
12 commits merged on Apr 19, 2023

Conversation

gjoliver
Member

Why are these changes needed?

For running DeepSpeed jobs.

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • [*] Unit tests
    • Release tests
    • This PR is not tested :(

@amogkam
Contributor

amogkam commented Apr 10, 2023

Is it possible to consolidate on a singular code path?

@gjoliver gjoliver requested a review from amogkam April 10, 2023 18:15
@gjoliver
Member Author

Is it possible to consolidate on a singular code path?

The goal is to not share code with Train and to keep this a generic AIR util for now.
It will be used for LLM training and serving, where each Trial / Replica is a DDP group,
so there may be things that differ from Train.
What do you think?
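
For context, a rough usage sketch of the shape described above, assuming init_torch_dist_process_group (defined in python/ray/air/util/torch_dist.py in this PR) takes a list of Ray actor handles plus a backend name and returns the local ranks; the exact signature is inferred, not confirmed in this thread:

```python
# Rough sketch only; the call signature of init_torch_dist_process_group is
# assumed here (list of actor handles + backend), not confirmed by this thread.
import ray
from ray.air.util.torch_dist import init_torch_dist_process_group


@ray.remote(num_gpus=1)
class TrainingWorker:
    def train(self):
        # After the group is set up, torch.distributed collectives / DDP work here.
        ...


# One Trial / Replica owns its own set of workers, i.e. its own DDP group.
workers = [TrainingWorker.remote() for _ in range(4)]
local_ranks = init_torch_dist_process_group(workers, backend="nccl")
```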

@amogkam amogkam self-assigned this Apr 11, 2023
Contributor
@jovany-wang jovany-wang left a comment

3 high-level questions:

  1. What's the difference between this and torch DDP in Ray Train?
  2. How can we distribute workers onto different nodes?
  3. What's the whole pipeline for integrating DeepSpeed?

python/ray/air/util/torch_dist.py
@gjoliver
Member Author

3 high-level questions:

  1. What's the difference between this and torch DDP in Ray Train?
  2. How can we distribute workers onto different nodes?
  3. What's the whole pipeline for integrating DeepSpeed?

Pretty much no difference.
I will cc you on the example notebook, and you will see how we integrate DeepSpeed in this case.
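
For readers without access to the notebook, a hypothetical sketch (not from this PR) of what the per-worker DeepSpeed step could look like once the process group exists; deepspeed.initialize reuses an already-initialized torch.distributed state:

```python
# Hypothetical per-worker step (not part of this PR): DeepSpeed picks up the
# torch.distributed process group that the AIR util already initialized.
import deepspeed
import torch
import torch.nn as nn


def build_deepspeed_engine(ds_config: dict):
    assert torch.distributed.is_initialized()  # set up by the AIR util
    model = nn.Linear(16, 1)  # placeholder model for illustration
    engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=ds_config,
    )
    return engine, optimizer
```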

@fecet
Copy link

fecet commented Apr 16, 2023

Hello, I'm really impressed with this feature! Would it be possible for me to obtain a copy of the example notebook?

# Wait for all workers to join the process group.
ray.get(setup_futures)

return local_ranks
Member

should we also return a dict mapping IPs to world ranks?

Member Author

I didn't see the need, because I want to use the index of the worker as its global rank.
Do you see any flaw in this approach?
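
To make the rank bookkeeping in this exchange concrete, an illustrative helper (names are hypothetical, not the PR's code): the global rank is the worker's index in the input list, and the local rank is its index among the workers on the same node.

```python
# Illustrative sketch (names are hypothetical, not the PR's actual code):
# global rank = index of the worker in the input list,
# local rank  = index of the worker among the workers on its node.
from collections import defaultdict
from typing import Dict, List


def compute_ranks(worker_node_ids: List[str]) -> Dict[int, Dict[str, int]]:
    node_to_count: Dict[str, int] = defaultdict(int)
    ranks = {}
    for global_rank, node_id in enumerate(worker_node_ids):
        ranks[global_rank] = {
            "rank": global_rank,
            "local_rank": node_to_count[node_id],
        }
        node_to_count[node_id] += 1
    return ranks


# e.g. two workers on node "a", one on node "b":
# compute_ranks(["a", "a", "b"]) ->
# {0: {"rank": 0, "local_rank": 0}, 1: {"rank": 1, "local_rank": 1}, 2: {"rank": 2, "local_rank": 0}}
```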

Member
@Yard1 Yard1 left a comment

lgtm, one question

Contributor
@amogkam amogkam left a comment

I would really like to unify on a common code path. All the differences between the Ray Train logic and this logic can be abstracted away via function arguments (see the sketch after this list):

  1. A common utility function for sharing CUDA visible devices that also handles the multiple-GPUs-per-worker case.
  2. A common utility function for getting the local rank, local world size, and node rank.
  3. TorchBackend.on_start should call the init_torch_dist_process_group function.
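
As a concrete reading of item 1, an illustrative helper (not Ray Train's actual implementation): every worker on a node exposes the union of GPU ids used by its co-located workers.

```python
# Hypothetical illustration of item 1 above (not Ray Train's actual implementation):
# every worker on a node should see the union of the GPU ids assigned to the
# co-located workers, so that torch can address peer devices by index.
import os
from typing import Dict, List


def share_cuda_visible_devices(node_to_gpu_ids: Dict[str, List[int]], node_id: str) -> None:
    # Expose all GPUs used by workers on this node, sorted for a deterministic order.
    visible = ",".join(str(i) for i in sorted(set(node_to_gpu_ids[node_id])))
    os.environ["CUDA_VISIBLE_DEVICES"] = visible
```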

local_world_size: int,
master_addr: str,
master_port: str,
gpu_ids: List[int],
Contributor

this may not be List[int] if using multiple GPUs per worker.

Member Author

Added logic to flatten the GPU ids for the multiple-GPUs-per-worker case.
Changed the unit test to include this case too.
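
An illustrative version of that flattening (variable names are hypothetical): with multiple GPUs per worker, each worker reports a list of ids, so the per-node collection is a list of lists.

```python
# Illustrative sketch of the flattening described above (names are hypothetical):
# with multiple GPUs per worker, each worker reports a list of GPU ids, so the
# per-node collection becomes a list of lists that has to be flattened.
from typing import List

per_worker_gpu_ids: List[List[int]] = [[0, 1], [2, 3]]  # two workers, two GPUs each
flat_gpu_ids = sorted({gpu_id for ids in per_worker_gpu_ids for gpu_id in ids})
assert flat_gpu_ids == [0, 1, 2, 3]
```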

# All the workers on a specific node.
node_to_workers = {}
# All the gpu ids visible to all the workers on a specific node.
node_to_gpu_ids = {}
Contributor

you can do defaultdict(list) instead of needing to do setdefault every time.

Member Author

sure
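
The suggested simplification, illustrated (not the PR's exact code):

```python
# The suggestion above, illustrated (not the PR's exact code).
from collections import defaultdict

# With plain dicts, every insertion needs setdefault:
node_to_workers = {}
node_to_workers.setdefault("node-1", []).append("worker-0")

# With defaultdict(list), the empty list is created automatically:
node_to_workers = defaultdict(list)
node_to_workers["node-1"].append("worker-0")
```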

return func(*args, **kwargs)
except Exception as e:
skipped = skip_exceptions(e)
raise skipped from exception_cause(skipped)
Contributor

Do we want to remove all this skip_exceptions stuff? It only works when used with Train/Tune.

Member Author

ah, good to know. thanks for the comment.

return node_id, gpu_ids


def init_torch_dist_process_group(
Contributor
@amogkam amogkam Apr 17, 2023

also need to add corresponding shutdown logic?

Member Author

good point. done.
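
For reference, a minimal sketch of such shutdown logic (the helper name is hypothetical; destroy_process_group is the standard torch.distributed teardown call):

```python
# Minimal sketch of the per-worker shutdown side (helper name is hypothetical;
# destroy_process_group is the standard torch.distributed teardown call).
import torch.distributed as dist


def shutdown_torch_dist_process_group() -> None:
    if dist.is_initialized():
        dist.destroy_process_group()
```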

return node_id, gpu_ids


def init_torch_dist_process_group(
Contributor

Is this intended to be a public API? If not, let's move it into the _internal package.

Member Author

This will be included in our public-facing examples.
If we move those predictors into ray/air/, we can mark this private, or get rid of it altogether and refactor Train to share this logic.

raise RuntimeError("Distributed torch is not available.")

# Build a map from node_id to workers on that node.
node_and_gpu_ids = ray.get(
Member

would it be possible to make sure that we sort the workers by gpu id to avoid the issue fixed in #33159?

Member Author

I'm using a set to collect the per-node visible GPUs now, and list(set) will always be sorted.
Added a comment about this.
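
A minimal sketch of that set-based collection (the data is illustrative). Note that Python does not formally guarantee that iterating a set yields sorted order, so an explicit sorted() is used here to keep the ordering deterministic:

```python
# Minimal sketch of the set-based collection discussed above (data is illustrative).
# An explicit sorted() is used because Python does not guarantee that iterating
# a set yields the elements in sorted order.
from collections import defaultdict
from typing import Dict, List, Set, Tuple

node_and_gpu_ids: List[Tuple[str, List[int]]] = [
    ("node-1", [1, 0]),
    ("node-1", [2]),
    ("node-2", [0]),
]

node_to_gpu_ids: Dict[str, Set[int]] = defaultdict(set)
for node_id, gpu_ids in node_and_gpu_ids:
    node_to_gpu_ids[node_id].update(gpu_ids)

sorted_gpu_ids = {node: sorted(ids) for node, ids in node_to_gpu_ids.items()}
# {"node-1": [0, 1, 2], "node-2": [0]}
```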

Jun Gong added 8 commits April 18, 2023 09:59
Signed-off-by: Jun Gong <[email protected]>
Signed-off-by: Jun Gong <[email protected]>
Signed-off-by: Jun Gong <[email protected]>
Signed-off-by: Jun Gong <[email protected]>
Signed-off-by: Jun Gong <[email protected]>
Jun Gong added 3 commits April 18, 2023 11:26
Signed-off-by: Jun Gong <[email protected]>
fix
Signed-off-by: Jun Gong <[email protected]>
Signed-off-by: Jun Gong <[email protected]>
Member Author
@gjoliver gjoliver left a comment

Thanks for all the comments. PTAL.

@gjoliver
Member Author

Unit tests all pass now; the lint error is not related.
@amogkam can I get your blessing? Thanks.

python/ray/air/util/torch_dist.py
Co-authored-by: Amog Kamsetty <[email protected]>
Signed-off-by: Jun Gong <[email protected]>
@gjoliver gjoliver merged commit a0255e5 into ray-project:master Apr 19, 2023
elliottower pushed a commit to elliottower/ray that referenced this pull request Apr 22, 2023
ProjectsByJackHe pushed a commit to ProjectsByJackHe/ray that referenced this pull request May 4, 2023
architkulkarni pushed a commit to architkulkarni/ray that referenced this pull request May 16, 2023