[Core][Distributed] add fast broadcast for tensor dict #4757
Conversation
The serialization of `TensorMetadata` before the change:

```python
from vllm.distributed.communication_op import TensorMetadata
import torch
d = TensorMetadata("cuda", torch.float32, torch.Size([]))
import pickle
s = pickle.dumps(d)
len(s)  # 120
import pickletools
pickletools.dis(s)  # full opcode listing omitted
```
After a8d1d3a, the serialization size is reduced by more than half (120 bytes to 52 bytes):

```python
from vllm import TensorMeta
import torch
d = TensorMeta("cuda", torch.float32, torch.Size([]))
import pickle
s = pickle.dumps(d)
len(s)  # 52
import pickletools
pickletools.dis(s)  # full opcode listing omitted
```
With all of the above optimizations, the number of bytes to broadcast is significantly reduced. This benefit will become more significant when we apply the technique to prepare-input-related data structures.
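The size reduction comes largely from the shorter qualified class name embedded in the pickle stream. A small self-contained sketch (the class names here are hypothetical stand-ins, not vLLM code) illustrating the effect:

```python
import dataclasses
import pickle

# Two stand-in classes with identical fields but different name lengths.
@dataclasses.dataclass
class M:
    device: str

@dataclasses.dataclass
class AVeryLongTensorMetadataClassName:
    device: str

s_short = pickle.dumps(M("cuda"))
s_long = pickle.dumps(AVeryLongTensorMetadataClassName("cuda"))
# The longer qualified name makes the otherwise-identical payload bigger.
assert len(s_short) < len(s_long)
```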
"improve the broadcast for prepare input in the same way." -> are you planning to do this in this PR? Also, can you tell me the perf improvement from it?
Also, can you update `test_swap`?
```python
metadata_list = []
tensor_list = []
for key, value in tensor_dict.items():
    used_keys = keys or tensor_dict.keys()
```
Should we assert `keys == len(tensor_dict)`?
Not necessary though. The current code is more flexible without the assert.
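For reference, the `keys or tensor_dict.keys()` fallback behaves like this (a small illustrative sketch, not the PR's exact code):

```python
def resolve_keys(keys, tensor_dict):
    # An explicit key list wins; otherwise fall back to the dict's own keys.
    return keys or tensor_dict.keys()

d = {"a": 1, "b": 2}
assert list(resolve_keys(None, d)) == ["a", "b"]
assert resolve_keys(["b"], d) == ["b"]
```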
```python
@dataclasses.dataclass
class TensorMeta:
    """
    This class is placed here to reduce the size of qualified name,
```
What about just creating `vllm/tensor_meta.py`? Is that still too long?
Yes, `vllm/tensor_meta.py` will lead to `vllm.tensor_meta.TensorMeta`, which is longer than `vllm.TensorMeta`.
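To see why the module path matters: pickle stores the class's module and qualified name verbatim in the stream, so every extra character in the path costs bytes on every broadcast. A quick check (the class name here is a hypothetical stand-in):

```python
import pickle

class TensorMetaDemo:
    pass

s = pickle.dumps(TensorMetaDemo())
# The qualified name is embedded as plain bytes in the pickle stream.
assert b"TensorMetaDemo" in s
```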
Does `tensor_meta` make a big difference? If it is just a small difference (double-digit microseconds), I'd prefer to avoid it...
```python
        tensor_list.append(value)
    else:
        metadata_list.append((key, value))
    if keys is not None:
```
nit: why don't we just check it at line 213?
```python
if keys is not None:
    metadata_list.append((key, value))
else:
    metadata_list.append(value)
```
I think control flow inside the loop is more expensive (N branches) than control flow outside the loop (1 branch).
```python
In [12]: def control():
    ...:     b = True
    ...:     result = []
    ...:     a = [i for i in range(3000)]
    ...:     for i in a:
    ...:         if b:
    ...:             result.append((i, i))
    ...:         else:
    ...:             result.append(i)

In [16]: def copy():
    ...:     b = True
    ...:     result = []
    ...:     a = [i for i in range(3000)]
    ...:     for i in a:
    ...:         result.append((i, i))
    ...:     result = [value for key, value in result]

In [22]: timeit copy()
192 µs ± 686 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [23]: timeit control()
159 µs ± 487 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
Hmm, actually I tried it and it looks like `control` is faster. But I think the perf diff here is not very meaningful (it is premature optimization). I was asking because I thought it would be easier to understand, but it's not a strong opinion. I'll leave it up to you.
```python
    This class represents a dictionary of tensors with bounded metadata.
    The upperbound of the buffer size is known a priori. Therefore, we can
    pre-allocate a buffer for the metadata, and invoke only one collective
```
Is this correct because we are now broadcasting using a cpu "tensor", so we don't need to broadcast the object size (which is an implementation detail of `broadcast_object_list`)?
The key idea is not the cpu "tensor", but that we know the maximum size of the serialization, so we don't need to broadcast the length. This is indeed an implementation detail of `broadcast_object_list`.
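A minimal single-process sketch of that idea (the function names and size bound are hypothetical, not vLLM's actual API): since every rank agrees on an upper bound for the serialized metadata, the payload can be padded to a fixed size and sent in one collective call, with no separate length broadcast.

```python
import pickle

MAX_METADATA_SIZE = 1024  # assumed upper bound, known a priori on every rank

def pack_fixed(obj) -> bytes:
    """Serialize into a fixed-size buffer so receivers need no length prefix."""
    data = pickle.dumps(obj)
    assert len(data) <= MAX_METADATA_SIZE, "metadata exceeded the agreed bound"
    return data + b"\x00" * (MAX_METADATA_SIZE - len(data))

def unpack_fixed(buf: bytes):
    # pickle stops at its STOP opcode, so the zero padding is ignored
    return pickle.loads(buf)

packed = pack_fixed({"dtype": "float32", "shape": (2, 3)})
assert len(packed) == MAX_METADATA_SIZE
assert unpack_fixed(packed) == {"dtype": "float32", "shape": (2, 3)}
```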
Can you add this to the comment that it relies on that implementation detail?
```python
    dtypes). If `cls` is provided, we can know the length of the metadata
    roughly and allocate a buffer for it, then broadcasting metadata requires
    only one broadcast call. Otherwise, we need to broadcast the metadata
    length first, then broadcast the metadata.
```
Can you add a simple example of how to use TensorDictWithBoundedMetadata in the docstring?
It will require another PR.
For broadcasting blocks to swap/copy, the benefit is:
Broadcast time (before): 0.38772106170654297 ms
Broadcast time (after): 0.128173828125 ms
I don't have end-to-end benchmarking.
It would require quite a large modification to the test procedure (separating the test into distributed tests). Meanwhile, correctness is already checked in https://github.com/vllm-project/vllm/pull/4757/files#diff-cba46ef2b8ff23834781fa3b43794a3f19ffc6b4f1ec2353a8d13d1cdc2d0588R110 .
@rkooo567 can you help take a look at https://buildkite.com/vllm/ci/builds/7258#018f732a-46ad-4e69-a35b-25f5200d0e19 ? The failure looks like a ray issue; the function cannot access the name.
@youkaichao it would be good to check whether there's a non-negligible performance difference in end-to-end tests before introducing the additional complexity; it's not always easy to infer this from a microbenchmark. A simple before/after test generating a decent number of tokens with a TP deployment would be sufficient, I think? Do you know how much of the latency benefit comes from compressing the number of bytes with the new TensorMeta class vs. eliminating one of the broadcasts? Especially if they're combined, I'm wondering whether we can avoid the quite convoluted abstractions for what is just a single case. Implementation-wise, instead of requiring the custom classes, what do you think about this:
Then there's no need to maintain special classes. I also don't think there's any need for special handling of the keys; we can just pass lists instead of dicts?
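A rough sketch of that suggestion (names are hypothetical, and a tiny stand-in replaces `torch.Tensor` so the snippet runs standalone): split the dict into a plain metadata list and a tensor list, with no special classes.

```python
from dataclasses import dataclass

@dataclass
class FakeTensor:
    """Stand-in for torch.Tensor so this sketch runs without torch."""
    dtype: str
    shape: tuple

def split_tensor_dict(tensor_dict):
    # Keep only lightweight metadata for serialization; tensors are
    # collected separately to be sent as raw buffers.
    metadata_list = []
    tensor_list = []
    for key, value in tensor_dict.items():
        if isinstance(value, FakeTensor):
            metadata_list.append((key, ("tensor", value.dtype, value.shape)))
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list

md, ts = split_tensor_dict({"x": FakeTensor("float32", (2,)), "step": 7})
assert md == [("x", ("tensor", "float32", (2,))), ("step", 7)]
assert ts == [FakeTensor("float32", (2,))]
```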
@youkaichao another reason the above approach might be better - IIUC the
First of all, this PR is the first step for later optimizations. It is itself a pure benefit because it reduces the broadcasts from two to one. The followup of applying the optimization to prepare input needs to come after the refactor #4681 .
This does not reduce the broadcasts. It still requires two broadcasts even if we don't have any tensor data to broadcast.
I think the nice benchmark to back it up is:
For the PR, I think it LGTM.
I'd prefer to avoid having TensorMeta inside `vllm/__init__.py` unless the perf diff is very big. I will approve it for now, but please resolve the discussion with @njhill before merging it!
@njhill has a proposal to cache the max length of metadata based on the callsite; I will wait and see how it works.
@youkaichao I've opened #4844 to show the idea, PTAL!
Closed, as #5399 will be a better solution.
An ongoing effort of #4440 .
Reduce the number of broadcasts from 2 to 1.
Broadcast time (before): 0.38772106170654297 ms
Broadcast time (after): 0.128173828125 ms
TODO: