
[DTensor] Change Sharding algorithm to be in line with torch.chunk() #98722

Closed
wants to merge 16 commits

Conversation


@wz337 wz337 commented Apr 10, 2023

As the functional collectives are being updated, using tensor_split() as the underlying sharding algorithm would require padding and unpadding on multiple ranks. Therefore, we are changing the sharding algorithm to be in line with torch.chunk(), which limits padding to the last two ranks in most scenarios.
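For illustration, a minimal sketch of the difference (the size-10, 4-chunk numbers are illustrative assumptions, not from the PR):

```python
import torch

t = torch.arange(10)

# tensor_split spreads the remainder over the leading chunks, so several
# ranks can end up with undersized shards that would need padding:
print([c.numel() for c in torch.tensor_split(t, 4)])  # [3, 3, 2, 2]

# chunk gives every chunk ceil(10 / 4) elements except the trailing ones,
# so only the last rank(s) ever need padding:
print([c.numel() for c in torch.chunk(t, 4)])         # [3, 3, 3, 1]
```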

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Apr 10, 2023

pytorch-bot bot commented Apr 10, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/98722

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 1fdbbf2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


@wanchaol wanchaol left a comment


First pass: looks pretty good already. I have a few suggestions inlined. Mainly, I think we should use new_zeros to create the empty shard, which could help us consolidate the padding logic and avoid having to infer the device :)

)
return torch.tensor([], device=device)
else:
return tensor

def _unpad_concat_tensor(

I think we can probably delete this _unpad_concat_tensor, as it's only used by a test; we can fix that test instead :)

device = torch.device(
tensor.get_device() if torch.cuda.is_available() else "cpu"
)
empty_tensor = torch.tensor([], device=device)

Hmmm, actually I think we should not "infer" the device like this. Given that we already have the tensor, we should always use the tensor itself to create a new tensor. I think we can do the following to correctly get an empty tensor:

tensor.new_zeros((0, 3)) -> tensor([], size=(0, 3)) # this gives an empty tensor, but with the correct shape!
tensor.new_zeros(shape) -> tensor([..]) # this gives a tensor filled with zeros according to the new `shape`, with the same dtype/device as the original tensor
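A quick sketch of that behavior (the shapes here are illustrative):

```python
import torch

t = torch.randn(4, 3)  # any existing tensor

# new_zeros inherits dtype and device from `t`, so there is no need to
# infer the device separately. A zero-length leading dim still keeps the
# trailing dims:
empty = t.new_zeros((0, 3))
print(empty.shape, empty.dtype, empty.device)  # torch.Size([0, 3]) torch.float32 cpu

# A full shape gives a zero-filled tensor of that shape:
zeros = t.new_zeros((2, 3))
```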

)
tensor_size = list(reference_tensor.size())
tensor_size = [dim if dim >= self.dim else 0 for dim in tensor_size] # type: ignore[attr-defined]
return torch.zeros(tensor_size, device=device)

Similarly here, by using new_zeros we can create a new zeros tensor directly with the expected shape! i.e.

b = torch.empty((0, 3))  # tensor([], size=(0, 3))
b.new_zeros((3, 3))  # works!

device = torch.device(
tensor.get_device() if torch.cuda.is_available() else "cpu"
)
return torch.tensor([], device=device)

same here, sth like: tensor.new_zeros((0, other_dims))

self,
tensor: torch.Tensor,
pad_size: int,
reference_tensor: Optional[torch.Tensor] = None,

I think we probably don't need this reference_tensor anymore if we create the empty tensor with something like tensor.new_zeros((0, 3)); we only need pad_size, and that would make the padding logic consistent too.

else:
pad = [0, 0] * (tensor.ndim - self.dim)
pad[-1] = pad_size
return torch.nn.functional.pad(tensor, pad)

I think padding also works for a tensor like tensor([], size=(0, 3)).
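A minimal check of that claim (the (0, 3) shape is just an example):

```python
import torch
import torch.nn.functional as F

empty = torch.empty((0, 3))

# The pad list holds (left, right) pairs starting from the last dim, so
# [0, 0, 0, 2] pads two rows onto dim 0 of a 2-D tensor. Padding an empty
# tensor works and yields zero-filled rows:
padded = F.pad(empty, [0, 0, 0, 2])
print(padded.shape)  # torch.Size([2, 3])
```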

local_offset_on_dim = -1
if return_offset:
# QQ: what would be the offset of an empty shard? -1?

Hmmm, that's a good point... Do you know what the sharded tensor offset looks like for an empty shard? I think we could return the "global tensor dim size" for an empty shard (representing the end of that tensor dim), but I'd like to check that this makes sense for existing use cases.
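A sketch of that idea, assuming torch.chunk-style shard sizes (the size-5, 8-rank numbers and the chunk_offsets helper are hypothetical):

```python
import math

def chunk_offsets(size_on_dim: int, num_chunks: int):
    # Offset of each shard under torch.chunk-style sharding; empty shards
    # clamp to the global dim size, marking the end of that dim.
    full = math.ceil(size_on_dim / num_chunks)
    return [min(i * full, size_on_dim) for i in range(num_chunks)]

print(chunk_offsets(5, 8))  # [0, 1, 2, 3, 4, 5, 5, 5] -> ranks 5-7 are empty
```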

wz337 (Contributor, Author) replied:

Let me double check on this.

@@ -39,7 +39,7 @@ def _split_tensor(
     *,
     with_padding: bool = True,
     contiguous: bool = True,
-) -> Tuple[List[torch.Tensor], int]:
+) -> Tuple[List[torch.Tensor], int, List[int]]:

So it looks like all the call sites (except test_device_mesh.py) don't need the second return value anymore (I think that's because it's embedded in the last return value). Shall we delete the second return value and return Tuple[List[torch.Tensor], List[int]] instead?

wz337 (Contributor, Author) replied:

Good point! I was actually thinking about the same thing!

for idx in range(num_chunks)
]
# Get idx start to pad
idx_start_to_pad = next(

do we really need this given that we only need pad_sizes? pad_sizes can actually be computed without this iiuc.

# Compute pad size on each chunk
pad_sizes = [
full_chunk_size - chunk_size if idx >= idx_start_to_pad else 0
for idx, chunk_size in enumerate(chunk_sizes)

i.e. we don't really need to check idx >= idx_start_to_pad? It could always be full_chunk_size - chunk_size; for ranks that don't need padding, the subtraction becomes 0 automatically.
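A small sketch of why the guard is redundant, assuming torch.chunk-style chunk sizes (the numbers are illustrative):

```python
import math

size_on_dim, num_chunks = 10, 4
full_chunk_size = math.ceil(size_on_dim / num_chunks)  # 3

# torch.chunk-style sizes: every chunk is full except the trailing ones.
chunk_sizes = [
    max(min(size_on_dim, full_chunk_size * (i + 1)) - full_chunk_size * i, 0)
    for i in range(num_chunks)
]  # [3, 3, 3, 1]

# For full chunks the subtraction is already 0, so no index check is needed.
pad_sizes = [full_chunk_size - s for s in chunk_sizes]  # [0, 0, 0, 2]
```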

@wanchaol wanchaol added release notes: distributed (dtensor) release notes category and removed release notes: distributed (fsdp) release notes category labels Apr 12, 2023

wz337 commented Apr 18, 2023

@pytorchmergebot rebase

@pytorchmergebot

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot

Successfully rebased dtensor_update onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout dtensor_update && git pull --rebase)


@wanchaol wanchaol left a comment


Looks great! Thanks for working on it! I have a few more suggestions inlined.

# Explicitly return an empty tensor. Otherwise, even if the
# tensor is empty, the size won't be 0.
if tensor.numel() == 0:
return tensor.new_zeros([0])

Hmmm, could you explain why we need to explicitly return an empty tensor? IIUC, after tensor.narrow some ranks would have a tensor with shape ([0, 8]), for example; that's still considered an empty tensor with no data, so that should work for us.

@wz337 (Contributor, Author) commented Apr 20, 2023


The reason is that we have a test that compares an unpadded tensor with the original tensor_to_split.

https://github.com/pytorch/pytorch/blob/ada67c9d8a3ac61fa0af5b8a186d5ecb31765af5/test/distributed/_tensor/test_device_mesh.py#L210-L212

In this case, the sizes of the two would fail the assert, as one could be ([0, 8]) while the other is ([0]), although both are just empty tensors. I guess I could update the test cases to make _unpad() more consistent.
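For illustration, a minimal reproduction of the mismatch, using the shapes mentioned above:

```python
import torch

a = torch.empty((0, 8))  # e.g. what tensor.narrow leaves on an empty shard
b = torch.empty((0,))    # e.g. what tensor.new_zeros([0]) returns

print(a.numel() == 0 and b.numel() == 0)  # True: both carry no data
print(a.shape == b.shape)                 # False: (0, 8) vs (0,)
print(torch.equal(a, b))                  # False, so an equality assert fails
```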

]
# Compute pad size on each chunk
pad_sizes = [
full_chunk_size - chunk_size for idx, chunk_size in enumerate(chunk_sizes)

nit: no need to have idx here?

for idx in range(num_chunks)
]
pad_sizes = [full_chunk_size - chunk_size for chunk_size in chunk_sizes]
is_padded = not all(pad_size == 0 for pad_size in pad_sizes)

nit: why don't we use a similar condition for is_padded as in reduce_scatter_tensor? i.e. size[self.dim] % num_chunks != 0
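A quick check that the two conditions agree under torch.chunk-style sizing (the test values are arbitrary assumptions):

```python
import math

for size, n in [(10, 4), (12, 4), (5, 8)]:
    full = math.ceil(size / n)
    sizes = [max(min(size, full * (i + 1)) - full * i, 0) for i in range(n)]
    pads = [full - s for s in sizes]
    # Some pad is nonzero exactly when the dim doesn't divide evenly.
    assert (not all(p == 0 for p in pads)) == (size % n != 0)
```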


wz337 commented Apr 20, 2023

@pytorchmergebot rebase

@pytorchmergebot

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot

Successfully rebased dtensor_update onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout dtensor_update && git pull --rebase)


wz337 commented Apr 20, 2023

@pytorchmergebot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 20, 2023
@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorchmergebot pushed a commit that referenced this pull request May 14, 2023
When tensor.size(self.dim) < num_chunks, we will fill the empty chunks with empty tensors (#98722). Therefore, we no longer need this assert.

For example, when sharding a tensor with 1 element across 2 ranks along dim 0, the results would be as follows:
```
rank:0, dtensor:DTensor(local_tensor=tensor([0.4963], device='cuda:0'), device_mesh=DeviceMesh:([0, 1]), placements=[Shard(dim=0)])
rank:1, dtensor:DTensor(local_tensor=tensor([], device='cuda:1'), device_mesh=DeviceMesh:([0, 1]), placements=[Shard(dim=0)])
```
Pull Request resolved: #101218
Approved by: https://github.com/wanchaol
jcaip pushed a commit that referenced this pull request May 23, 2023