-
Notifications
You must be signed in to change notification settings - Fork 112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RFC] StableHLO v1.0 Opset Deprecations & Cleanups #2283
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
mlevesquedion
approved these changes
May 3, 2024
GleasonK
commented
May 6, 2024
This was referenced May 7, 2024
mlevesquedion
pushed a commit
that referenced
this pull request
May 9, 2024
…2296) This pass is intended to be used to test the blast radius of the proposed deprecation changes in #2283 and help with migration of passes and other tooling. Still todo - patterns for the following: - Einsum pattern - TorchIndexSelect pattern - RNG pattern (if possible) - RealDynamicSliceOp (requires dynamic_slice update first)
abhigunj
pushed a commit
to abhigunj/stablehlo
that referenced
this pull request
May 9, 2024
…penxla#2296) This pass is intended to be used to test the blast radius of the proposed deprecation changes in openxla#2283 and help with migration of passes and other tooling. Still todo - patterns for the following: - Einsum pattern - TorchIndexSelect pattern - RNG pattern (if possible) - RealDynamicSliceOp (requires dynamic_slice update first)
mlevesquedion
pushed a commit
that referenced
this pull request
May 9, 2024
Part of #2283. This op doesn't seem to have any uses in frameworks or compilers. It shouldn't cause any large issues to remove this op all together, also since this op was never specced, it is exempt from compatibility guarantees. In the case that it is used somewhere, I would recommend migration to a custom_call.
GleasonK
force-pushed
the
stablehlo-op-deprecation
branch
from
May 13, 2024 17:24
95fce0b
to
3a8947a
Compare
sdasgup3
approved these changes
May 13, 2024
mlevesquedion
approved these changes
May 13, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
A proposal to remove redundant operations from StableHLO before long-term compatibility guarantees go into place.
High level summary:
CreateTokenOp
,TraceOp
,BroadcastOp
,DotOp
,UnaryEinsumOp
,RealDynamicSliceOp
.DynamicSliceOp
.CrossReplicaSumOp
to CHLO.MapOp
,RngOp
,EinsumOp
,TorchIndexSelectOp
,GetTupleElementOp
, andtuple
type.OpenXLA Discuss post: https://groups.google.com/a/openxla.org/g/openxla-discuss/c/sBAkvnd2bcA
Related tickets: #2176, #3