Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core][Distributed] Refactor ipc buffer init in CustomAllreduce #10030

Merged
merged 5 commits into from
Nov 7, 2024

Conversation

hanzhi713
Copy link
Contributor

@hanzhi713 hanzhi713 commented Nov 5, 2024

As discussed with @youkaichao, we use cuda API to share tensors instead of replying _share_cuda_, which won't break with expandable segment or future pytorch upgrade.

Additional changes:

  1. Improved some comments, especially about cuda graph ipc registration.
  2. Consolidated all_reduce methods.
  3. Use vector<int64_t> instead of pytorch tensor for handle during cuda graph ipc registration process. The use of Tensor was introduced in [Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops #5047. We sticked to non-tensor type to remove this complexity.

PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

  • Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual
  • Custom operations that return Tensors require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.
  • Use torch.libary.opcheck() to test the function registration and meta-function for any registered ops. See tests/kernels for examples.
  • When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
  • If a new custom type is needed, see the following document: Custom Class Support in PT2.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

github-actions bot commented Nov 5, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@hanzhi713 hanzhi713 force-pushed the torch-ipc-share branch 2 times, most recently from e782d66 to 99a9c6e Compare November 5, 2024 08:16
@mgoin
Copy link
Collaborator

mgoin commented Nov 5, 2024

cc @tlrmchlsmth

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hanzhi713 and @youkaichao could you share a few more details on what's going on?

Looks like this is related to #9815 -- is the idea to be more stable across Pytorch versions? Do you see any downsides to this?

@hanzhi713
Copy link
Contributor Author

hanzhi713 commented Nov 5, 2024

@hanzhi713 and @youkaichao could you share a few more details on what's going on?

Looks like this is related to #9815 -- is the idea to be more stable across Pytorch versions? Do you see any downsides to this?

Yes, we want to rely less on internal API to prevent future breaking. A downside to the current approach is that it doesn't support expandable_segments:True due to a current pytorch limitation.

An alternative design I see is to do the one-time allocation of IPC-enabled buffers ourselves through CUDA C++ (i.e. cudaMalloc + ipc handle calls).

@youkaichao
Copy link
Member

#10064 will make this pr easier. we don't need to depend on pytorch's internal apis

@youkaichao
Copy link
Member

please merge main to use the functionality from #10064

@youkaichao youkaichao changed the title [Core][Distributed] Use Pytorch IPC to share tensors for custom allreduce [Core][Distributed] refactor ipc buffer init in CustomAllreduce Nov 6, 2024
Signed-off-by: Hanzhi Zhou <[email protected]>
Signed-off-by: Hanzhi Zhou <[email protected]>
@hanzhi713
Copy link
Contributor Author

Done. @youkaichao PTAL, thanks!

Signed-off-by: Hanzhi Zhou <[email protected]>
Copy link
Member

@youkaichao youkaichao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the great contribution! do you have some updated perf numbers? I assume it should not affect the performance.

@youkaichao
Copy link
Member

nit: please fix the format.

@hanzhi713
Copy link
Contributor Author

There're no perf changes from the C++. I can run the python side benchmark to be sure.

Signed-off-by: Hanzhi Zhou <[email protected]>
@hanzhi713 hanzhi713 changed the title [Core][Distributed] refactor ipc buffer init in CustomAllreduce [Core][Distributed] Refactor ipc buffer init in CustomAllreduce Nov 7, 2024
@hanzhi713
Copy link
Contributor Author

@youkaichao I can confirm that there's no perf difference with benchmarks/benchmark_latency.py

@youkaichao
Copy link
Member

RuntimeError: Inplace update to inference tensor outside InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace

there are some errors in the ci actually @hanzhi713

logger.info("Registering %d cuda graph addresses", len(offset))
all_data = [None] * dist.get_world_size(group=self.group)
dist.all_gather_object(all_data, (handle, offset), group=self.group)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use broadcast as before in _gather_ipc_meta ?

code for reference:

        for i, rank in enumerate(ranks):
            dist.broadcast_object_list(all_data[i],
                                       src=rank,
                                       group=self.group,
                                       device="cpu")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it looks like we still need to use broadcast here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to locally test it, run pytest -v -s tests/basic_correctness/test_basic_correctness.py::test_models_distributed[facebook/opt-125m-mp--A100]

Copy link
Contributor Author

@hanzhi713 hanzhi713 Nov 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My old benchmark script had --enforce-eager there which didn't catch this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be fixed now.

Signed-off-by: Hanzhi Zhou <[email protected]>
@youkaichao youkaichao merged commit 6192e9b into vllm-project:main Nov 7, 2024
16 of 21 checks passed
@youkaichao
Copy link
Member

@hanzhi713 thanks again for the great contribution!

Isotr0py pushed a commit to Isotr0py/vllm that referenced this pull request Nov 8, 2024
omer-dayan pushed a commit to omer-dayan/vllm that referenced this pull request Nov 10, 2024
JC1DA pushed a commit to JC1DA/vllm that referenced this pull request Nov 11, 2024
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants