Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dtensor][1/n] refactor op dispatch logic to reduce overhead #107305

Closed
wants to merge 6 commits into from

Conversation

wanchaol
Copy link
Contributor

@wanchaol wanchaol commented Aug 16, 2023

Stack from ghstack (oldest at bottom):

This PR is the first change of a series of refactors to the op dispatch logic to:

  1. remove the redundant logic in the op dispatch, simplify the error
    checking
  2. reduce the number of tree_map/tree_flatten/unflatten needed to reduce
    the overhead coming from those operations
  3. remove the CachedShardingPropagator by using lru_cache from functools
    directly, this makes it not only helps TP, but general DTensor
    operations could be faster!
  4. change the view ops behavior by inplace changing the op_schema, which
    is dangerous for sharding prop caching, model the view op as one type
    of resharding too
  5. enrich output sharding to include whether the op needs redistribute
    so that we don't need explicit op schema comparison to know it.

This should help with further reducing the CPU overhead, benchmark
results:
before (without this change), aten.addmm latency: 0.476ms
Screenshot 2023-08-16 at 10 46 26 AM

after (with this change), aten.addmm latency: 0.341ms
Screenshot 2023-08-16 at 11 05 49 AM

overall one layer of mlp time reduced from 13.535 -> 9.665ms

Apart from overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactor needed to make things more clean, which will be done in later PRs)

This PR is the first change of a series of refactors to the op dispatch logic to:
1. remove the redundant logic in the op dispatch, simplify the error
checking
2. reduce the number of tree_map/tree_flatten/unflatten needed to reduce
the overhead coming from those operations
3. remove the CachedShardingPropagator by using lru_cache from functools
directly
4. change the view ops behavior by inplace changing the op_schema, which
is dangerous for sharding prop caching, model the view op as one type
of resharding too
5. enrich output sharding to include whether the op needs redistribute
so that we don't need explicit op schema comparison to know it.

This should help with further reducing the CPU overhead, benchmark
results coming soon

By refactoring the dispatch logic, it could possibly enable us to have
more features later (i.e. add IrregularShard placement easier)

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Aug 16, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107305

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 1 Unrelated Failure

As of commit dc70183 with merge base 5b9b816 (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

This PR is the first change of a series of refactors to the op dispatch logic to:
1. remove the redundant logic in the op dispatch, simplify the error
checking
2. reduce the number of tree_map/tree_flatten/unflatten needed to reduce
the overhead coming from those operations
3. remove the CachedShardingPropagator by using lru_cache from functools
directly, this makes it not only helps TP, but general DTensor
operations could be faster!
4. change the view ops behavior by inplace changing the op_schema, which
is dangerous for sharding prop caching, model the view op as one type
of resharding too
5. enrich output sharding to include whether the op needs redistribute
so that we don't need explicit op schema comparison to know it.

This should help with further reducing the CPU overhead, benchmark
results coming soon

By refactoring the dispatch logic, it could possibly enable us to have
more features later (i.e. add IrregularShard placement easier)

[ghstack-poisoned]
This PR is the first change of a series of refactors to the op dispatch logic to:
1. remove the redundant logic in the op dispatch, simplify the error
checking
2. reduce the number of tree_map/tree_flatten/unflatten needed to reduce
the overhead coming from those operations
3. remove the CachedShardingPropagator by using lru_cache from functools
directly, this makes it not only helps TP, but general DTensor
operations could be faster!
4. change the view ops behavior by inplace changing the op_schema, which
is dangerous for sharding prop caching, model the view op as one type
of resharding too
5. enrich output sharding to include whether the op needs redistribute
so that we don't need explicit op schema comparison to know it.

This should help with further reducing the CPU overhead, benchmark
results:
before (without this change), aten.addmm latency: 0.476ms
![Screenshot 2023-08-16 at 10 46 26 AM](https://github.com/pytorch/pytorch/assets/9443650/7692e6c1-1936-4c7f-bf9c-6c8c9b8f6c76)

after (with this change), aten.addmm latency: 0.341ms
![Screenshot 2023-08-16 at 11 05 49 AM](https://github.com/pytorch/pytorch/assets/9443650/15a53f0b-7a95-444e-ab2f-3ee0ad2fa47f)


overall one layer of mlp time reduced from 13.535 -> 9.665ms


[ghstack-poisoned]
This PR is the first change of a series of refactors to the op dispatch logic to:
1. remove the redundant logic in the op dispatch, simplify the error
checking
2. reduce the number of tree_map/tree_flatten/unflatten needed to reduce
the overhead coming from those operations
3. remove the CachedShardingPropagator by using lru_cache from functools
directly, this makes it not only helps TP, but general DTensor
operations could be faster!
4. change the view ops behavior by inplace changing the op_schema, which
is dangerous for sharding prop caching, model the view op as one type
of resharding too
5. enrich output sharding to include whether the op needs redistribute
so that we don't need explicit op schema comparison to know it.

This should help with further reducing the CPU overhead, benchmark
results:
before (without this change), aten.addmm latency: 0.476ms
![Screenshot 2023-08-16 at 10 46 26 AM](https://github.com/pytorch/pytorch/assets/9443650/7692e6c1-1936-4c7f-bf9c-6c8c9b8f6c76)

after (with this change), aten.addmm latency: 0.341ms
![Screenshot 2023-08-16 at 11 05 49 AM](https://github.com/pytorch/pytorch/assets/9443650/15a53f0b-7a95-444e-ab2f-3ee0ad2fa47f)


overall one layer of mlp time reduced from 13.535 -> 9.665ms


[ghstack-poisoned]
Copy link
Contributor

@fduwjj fduwjj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work!! I have some questions around naming and cache miss stats, also is it possible to change the how much we cache in the lru cache?

torch/distributed/_tensor/sharding_prop.py Show resolved Hide resolved
torch/distributed/_tensor/op_schema.py Show resolved Hide resolved
torch/distributed/_tensor/dispatch.py Show resolved Hide resolved
torch/distributed/_tensor/dispatch.py Show resolved Hide resolved
Comment on lines +167 to +168
tree_unflatten(flat_args_schema, args_spec),
tree_unflatten(flat_kwargs_schema, kwargs_spec),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there anyway that we can avoid this unflatten? To me this seems to be unnecessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I realized that so added a TODO here https://github.com/pytorch/pytorch/pull/107305/files#diff-8aca68cfab443c93335bbe5e6a1c3c3cb34df117fc08a2330b7966752a049b47R81

We can possibly keep the op schema be flattened, but we will need to change all of our ops first to behave like this, I'll do more refactor later to help us gradually move the op registration to use flattened schema

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Make sense, if treemap is unavoidable, ideally we only want to do it once.

torch/distributed/_tensor/dispatch.py Outdated Show resolved Hide resolved
# compute locally with redistribute first if needed
assert output_sharding.schema_suggestions is not None
suggested_input_schema = output_sharding.schema_suggestions[0]
redistribute_local_args(op_info, suggested_input_schema)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This name is a little bit confusing to me because the concept local to me is more like tensor rather DTensor while redistribute is a unique thing for DTensor. Would it be possible that we can just call it redistribute_args or redistribute_n_update_local_args? I understand the logic here is to first redistribute dtensor and update local_args so that we pass the correct args to the final Aten ops. But the naming seems to suggest we are doing redistribute directly on top of a Tensor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you look at the refactor I did to this function https://github.com/pytorch/pytorch/pull/107305/files#diff-8aca68cfab443c93335bbe5e6a1c3c3cb34df117fc08a2330b7966752a049b47R75

It is actually indeed directly redistributing on the local tensors, so when redistributing we just have a local tensor + the dtensor spec we want to redistribute to, we don't need to make the redistribute work on the dtensor wrapper, hence that's why I renamed this function to redistributed_local_args. Let me know if this make sense or not

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After reading your comment below and a second thought... I think this name is fine.

@@ -72,19 +73,25 @@ def _decompose_reshard(val: List[_PlacementItem]) -> List[_PlacementItem]:


# Intentionally expose this API to trace ops on local tensors
def _redistribute_with_local_tensor(
def redistribute_local_tensor(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto, like I mentioned in the dispatch.py, maybe the old name redistribute_with_local_tensor or redistribute_with_local_tensor_updated` is better? because the name here seems to suggest that we directly redistribute a Tensor, which is somewhat confusing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm this function/api is indeed directly redistribute a torch.Tensor (i.e. the first arg is the local_tensor) with src/dst dtensor spec, maybe we should add more comment on the API to clarify?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, essentially there is no DTensor involved indeed. We just call collectives directly (and with no-autograd). I just think redistributing a tensor is a little confusing. Well since we redistribute a local tensor of a DTensor, this sounds make sense to me. If we can add some comments here, that would be much appreciated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep will add in the follow up PR

This PR is the first change of a series of refactors to the op dispatch logic to:
1. remove the redundant logic in the op dispatch, simplify the error
checking
2. reduce the number of tree_map/tree_flatten/unflatten needed to reduce
the overhead coming from those operations
3. remove the CachedShardingPropagator by using lru_cache from functools
directly, this makes it not only helps TP, but general DTensor
operations could be faster!
4. change the view ops behavior by inplace changing the op_schema, which
is dangerous for sharding prop caching, model the view op as one type
of resharding too
5. enrich output sharding to include whether the op needs redistribute
so that we don't need explicit op schema comparison to know it.

This should help with further reducing the CPU overhead, benchmark
results:
before (without this change), aten.addmm latency: 0.476ms
![Screenshot 2023-08-16 at 10 46 26 AM](https://github.com/pytorch/pytorch/assets/9443650/7692e6c1-1936-4c7f-bf9c-6c8c9b8f6c76)

after (with this change), aten.addmm latency: 0.341ms
![Screenshot 2023-08-16 at 11 05 49 AM](https://github.com/pytorch/pytorch/assets/9443650/15a53f0b-7a95-444e-ab2f-3ee0ad2fa47f)


overall one layer of mlp time reduced from 13.535 -> 9.665ms

Apart from overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactor needed to make things more clean, which will be done in later PRs)


[ghstack-poisoned]
@wanchaol wanchaol requested a review from fduwjj August 17, 2023 21:30
This PR is the first change of a series of refactors to the op dispatch logic to:
1. remove the redundant logic in the op dispatch, simplify the error
checking
2. reduce the number of tree_map/tree_flatten/unflatten needed to reduce
the overhead coming from those operations
3. remove the CachedShardingPropagator by using lru_cache from functools
directly, this makes it not only helps TP, but general DTensor
operations could be faster!
4. change the view ops behavior by inplace changing the op_schema, which
is dangerous for sharding prop caching, model the view op as one type
of resharding too
5. enrich output sharding to include whether the op needs redistribute
so that we don't need explicit op schema comparison to know it.

This should help with further reducing the CPU overhead, benchmark
results:
before (without this change), aten.addmm latency: 0.476ms
![Screenshot 2023-08-16 at 10 46 26 AM](https://github.com/pytorch/pytorch/assets/9443650/7692e6c1-1936-4c7f-bf9c-6c8c9b8f6c76)

after (with this change), aten.addmm latency: 0.341ms
![Screenshot 2023-08-16 at 11 05 49 AM](https://github.com/pytorch/pytorch/assets/9443650/15a53f0b-7a95-444e-ab2f-3ee0ad2fa47f)


overall one layer of mlp time reduced from 13.535 -> 9.665ms

Apart from overhead reduction, this PR simplifies the op dispatching logic and the resharding logic (more refactor needed to make things more clean, which will be done in later PRs)


[ghstack-poisoned]
@wanchaol wanchaol added ciflow/trunk Trigger trunk jobs on your pull request ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR release notes: distributed (dtensor) release notes category labels Aug 18, 2023
from torch.distributed._tensor.api import DTensor


def get_sharding_prop_cache_info():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this. It would be nice to add a comment or a reference on Python LRU so that user know what information we can get from this API. (This can be done in a follow-up PR as a BE work)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah sure I'll add comments in a follow up PR :)

Copy link
Contributor

@fduwjj fduwjj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM and thanks for the refactoring.

@wanchaol
Copy link
Contributor Author

@pytorchbot merge -i

wanchaol added a commit that referenced this pull request Aug 21, 2023
This update some comments from the follow up of #107305

[ghstack-poisoned]
wanchaol added a commit that referenced this pull request Aug 21, 2023
This update some comments from the follow up of #107305

ghstack-source-id: eb1c9aed2a49243d415448a2f94b4c57b16403b4
Pull Request resolved: #107608
@huydhn
Copy link
Contributor

huydhn commented Aug 21, 2023

@pytorchbot drci

1 similar comment
@huydhn
Copy link
Contributor

huydhn commented Aug 21, 2023

@pytorchbot drci

@facebook-github-bot facebook-github-bot deleted the gh/wanchaol/341/head branch August 22, 2023 14:16
wanchaol added a commit that referenced this pull request Aug 22, 2023
This update some comments from the follow up of #107305

[ghstack-poisoned]
wanchaol added a commit that referenced this pull request Aug 22, 2023
This update some comments from the follow up of #107305

ghstack-source-id: aa4ff3add0ff431ee25ead4cfd419ad6925a1302
Pull Request resolved: #107608
pytorchmergebot pushed a commit that referenced this pull request Aug 22, 2023
This update some comments from the follow up of #107305
Pull Request resolved: #107608
Approved by: https://github.com/fduwjj
ghstack dependencies: #107606
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: distributed (dtensor) release notes category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants