[DTensor] Change Sharding algorithm to be in line with torch.chunk()
#98722
Conversation
First pass, looks pretty good already; I have a few suggestions inlined. Mainly, I think we should use new_zeros to create the empty shard, which would help us consolidate the padding logic and remove the need to infer the device :)
```
    )
    return torch.tensor([], device=device)
else:
    return tensor

def _unpad_concat_tensor(
```
I think we can probably delete this _unpad_concat_tensor, as it's only being used by a test; we can fix that test instead :)
```
device = torch.device(
    tensor.get_device() if torch.cuda.is_available() else "cpu"
)
empty_tensor = torch.tensor([], device=device)
```
Hmmm, actually I think we should not "infer" the device like this. Given that we already have tensor, we should always use the tensor itself to create the new tensor. I think we can do the following to correctly get an empty tensor:

tensor.new_zeros((0, 3)) -> tensor([], size=(0, 3))  # gives an empty tensor, but with the correct shape!
tensor.new_zeros(shape) -> tensor([..])  # gives a tensor filled with 0 according to the new `shape`, with the same dtype/device as the original tensor
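For reference, a minimal sketch (shapes are made up, not the actual DTensor code) of how new_zeros inherits dtype and device from the source tensor:

```python
import torch

# `tensor` stands in for the local shard we already have on this rank.
tensor = torch.randn(4, 3)

# Empty shard that keeps the trailing dims and inherits dtype/device.
empty_shard = tensor.new_zeros((0, 3))
print(empty_shard.shape, empty_shard.dtype)   # torch.Size([0, 3]) torch.float32

# Zero-filled tensor of an arbitrary new shape, again inheriting dtype/device.
zeros = tensor.new_zeros((4, 3))
print(zeros.shape)                            # torch.Size([4, 3])
```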
```
)
tensor_size = list(reference_tensor.size())
tensor_size = [dim if dim >= self.dim else 0 for dim in tensor_size]  # type: ignore[attr-defined]
return torch.zeros(tensor_size, device=device)
```
Similarly here, by using new_zeros we can create a new zeros tensor directly with the expected shape! i.e.

b = torch.empty((0, 3))   # tensor([], size=(0, 3))
b.new_zeros((3, 3))       # works!
```
device = torch.device(
    tensor.get_device() if torch.cuda.is_available() else "cpu"
)
return torch.tensor([], device=device)
```
Same here, something like tensor.new_zeros((0, other_dims)).
```
self,
tensor: torch.Tensor,
pad_size: int,
reference_tensor: Optional[torch.Tensor] = None,
```
I think we probably don't need this reference_tensor anymore if we create the empty tensor with something like tensor.new_zeros((0, 3)); we only need pad_size, and that would make the padding logic consistent too.
```
else:
    pad = [0, 0] * (tensor.ndim - self.dim)
    pad[-1] = pad_size
    return torch.nn.functional.pad(tensor, pad)
```
I think padding also works for an empty tensor like tensor([], size=(0, 3)).
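If it helps, a small sanity-check sketch (made-up shapes, mirroring the pad construction above) showing that F.pad handles a zero-size sharding dim:

```python
import torch
import torch.nn.functional as F

shard_dim = 0                      # assume we shard along dim 0
pad_size = 2

empty_shard = torch.empty(0, 3)    # a rank that received no rows
pad = [0, 0] * (empty_shard.ndim - shard_dim)
pad[-1] = pad_size                 # pad the end of the sharded dim
padded = F.pad(empty_shard, pad)
print(padded.shape)                # torch.Size([2, 3])
```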
```
local_offset_on_dim = -1
if return_offset:
    # QQ: what would be the offset of an empty shard? -1?
```
Hmmm, that's a good point... do you know what the sharded tensor offset looks like for an empty shard? I think we could return the "global tensor dim size" for an empty shard (representing the end of that tensor dim), but I'd like to check whether that makes sense for the existing use cases.
Let me double check on this.
```
@@ -39,7 +39,7 @@ def _split_tensor(
     *,
     with_padding: bool = True,
     contiguous: bool = True,
-) -> Tuple[List[torch.Tensor], int]:
+) -> Tuple[List[torch.Tensor], int, List[int]]:
```
So it looks like all the call sites (except test_device_mesh.py) no longer need the second return argument (I think that's because it is embedded in the last return argument). Shall we delete the second return arg and return two values, Tuple[List[torch.Tensor], List[int]], instead?
Good point! I was actually thinking about the same thing!
```
    for idx in range(num_chunks)
]
# Get idx start to pad
idx_start_to_pad = next(
```
Do we really need this, given that we only need pad_sizes? pad_sizes can actually be computed without it, iiuc.
```
# Compute pad size on each chunk
pad_sizes = [
    full_chunk_size - chunk_size if idx >= idx_start_to_pad else 0
    for idx, chunk_size in enumerate(chunk_sizes)
```
i.e. we don't really need to check if idx >= idx_start_to_pad? It could always be full_chunk_size - chunk_size; for ranks that don't need to pad, the subtraction would become 0 automatically.
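A quick sketch of the suggested simplification (plain Python, made-up sizes; names mirror the snippet above):

```python
# Suppose a dim of size 5 is chunked across 4 ranks, torch.chunk-style.
full_chunk_size = 2                # ceil(5 / 4)
chunk_sizes = [2, 2, 1, 0]         # per-rank sizes, trailing rank empty

# No idx_start_to_pad needed: ranks that are already full simply pad by 0.
pad_sizes = [full_chunk_size - chunk_size for chunk_size in chunk_sizes]
print(pad_sizes)                   # [0, 0, 1, 2]
```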
@pytorchmergebot rebase

@pytorchbot successfully started a rebase job. Check the current status here.

Successfully rebased 781faa5 to ada67c9.
Looks great! Thanks for working on it! I have a few more suggestions inlined.
```
# Explicitly return an empty tensor. Otherwise, even if the
# tensor is empty, the size won't be 0.
if tensor.numel() == 0:
    return tensor.new_zeros([0])
```
Hmmm, could you explain why we need to explicitly return an empty tensor? iiuc, after tensor.narrow some ranks would have a tensor with shape [0, 8], for example; that's still considered an empty tensor with no data, so that should work for us.
> Hmmm, could you explain why we need to explicitly return an empty tensor? iiuc, after tensor.narrow some ranks would have a tensor with shape [0, 8], for example; that's still considered an empty tensor with no data, so that should work for us.
The reason is that we have a test that compares an unpadded tensor with the original tensor_to_split. In that case, the sizes of the two would fail the assert, as one could be [0, 8] and the other [0], although both are just empty tensors. I guess I could update the test cases to make _unpad() more consistent.
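For context, a sketch of the shape mismatch being described (illustrative sizes only):

```python
import torch

tensor_to_split = torch.randn(4, 8)

# An empty shard produced by narrow keeps the trailing dims: shape (0, 8).
narrowed = tensor_to_split.narrow(0, 2, 0)

# An explicitly constructed empty tensor is 1-D: shape (0,).
explicit = tensor_to_split.new_zeros([0])

print(narrowed.shape, explicit.shape)   # torch.Size([0, 8]) torch.Size([0])
print(torch.equal(narrowed, explicit))  # False: both empty, but shapes differ
```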
```
]
# Compute pad size on each chunk
pad_sizes = [
    full_chunk_size - chunk_size for idx, chunk_size in enumerate(chunk_sizes)
```
nit: no need to have idx here?
```
    for idx in range(num_chunks)
]
pad_sizes = [full_chunk_size - chunk_size for chunk_size in chunk_sizes]
is_padded = not all(pad_size == 0 for pad_size in pad_sizes)
```
nit: why don't we use a similar condition for is_padded in reduce_scatter_tensor? i.e. size[self.dim] % num_chunks != 0
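A tiny sketch (plain Python, made-up sizes) of why the two checks agree for torch.chunk-style chunking:

```python
size_on_dim, num_chunks = 5, 4
full_chunk_size = -(-size_on_dim // num_chunks)             # ceil division -> 2
chunk_sizes = [min(full_chunk_size, max(0, size_on_dim - i * full_chunk_size))
               for i in range(num_chunks)]                  # [2, 2, 1, 0]
pad_sizes = [full_chunk_size - s for s in chunk_sizes]      # [0, 0, 1, 2]

assert (not all(p == 0 for p in pad_sizes)) == (size_on_dim % num_chunks != 0)
```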
@pytorchmergebot rebase

@pytorchbot successfully started a rebase job. Check the current status here.

Successfully rebased 522ca0c to 1fdbbf2.
@pytorchmergebot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
When tensor.size(self.dim) < num_chunks, we will fill the empty chunks with empty tensors (#98722), so we no longer need this assert. For example, when sharding a tensor with 1 element across 2 ranks along dim 0, the results would be as follows:

```
rank:0, dtensor:DTensor(local_tensor=tensor([0.4963], device='cuda:0'), device_mesh=DeviceMesh:([0, 1]), placements=[Shard(dim=0)])
rank:1, dtensor:DTensor(local_tensor=tensor([], device='cuda:1'), device_mesh=DeviceMesh:([0, 1]), placements=[Shard(dim=0)])
```

Pull Request resolved: #101218
Approved by: https://github.com/wanchaol
As the functional collectives are being updated, using tensor_split() as the underlying sharding algorithm would require padding and unpadding on multiple ranks. Therefore, we are changing the sharding algorithm to be in line with torch.chunk(), so that in most scenarios padding is only needed on the last two ranks.
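To illustrate the difference (a sketch only, not the DTensor implementation itself), consider sharding 5 elements across 4 ranks:

```python
import torch

tensor = torch.arange(5)
num_chunks = 4

# tensor_split spreads the remainder over the leading ranks -> sizes [2, 1, 1, 1],
# so three ranks would need padding up to the full chunk size of 2.
split_sizes = [c.size(0) for c in torch.tensor_split(tensor, num_chunks)]

# torch.chunk uses ceil(5 / 4) = 2 as the chunk size -> sizes [2, 2, 1] plus an
# implicit empty shard on the last rank, so only the trailing ranks need padding.
chunks = list(torch.chunk(tensor, num_chunks))
chunks += [tensor.new_zeros(0)] * (num_chunks - len(chunks))   # fill empty ranks
chunk_sizes = [c.size(0) for c in chunks]

full_chunk_size = (tensor.size(0) + num_chunks - 1) // num_chunks
pad_sizes = [full_chunk_size - s for s in chunk_sizes]

print(split_sizes)   # [2, 1, 1, 1]
print(chunk_sizes)   # [2, 2, 1, 0]
print(pad_sizes)     # [0, 0, 1, 2]
```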