Figure out the future of dynamic_slice vs real_dynamic_slice #2176

Open
ghpvnist opened this issue Apr 8, 2024 · 0 comments


ghpvnist (Member) commented Apr 8, 2024

Currently we have the specced `dynamic_slice` and the unspecced `real_dynamic_slice`, which is inherited from MHLO. The name `dynamic_slice` is a misnomer because the op is not fully dynamic; `real_dynamic_slice` is the fully dynamic version of `slice`. `dynamic_slice` accepts dynamic `start_indices`, but the size of the slice (`slice_sizes`) must still be known statically, and there is no way to specify strides other than one. `real_dynamic_slice`, on the other hand, provides the full capabilities of `slice` with dynamism support for every parameter: `start_indices`, `limit_indices`, and `strides`.

That is the current state, but we should decide how to move forward with these two ops. Some options are:

  1. Add `real_dynamic_slice` to the opset and keep `dynamic_slice`.
  2. Add `real_dynamic_slice` to the opset and deprecate `dynamic_slice`.
  3. Keep the current status quo.

Further investigation is needed to decide which path to take.

GleasonK self-assigned this Apr 8, 2024
GleasonK added a commit that referenced this issue May 13, 2024
A proposal to remove redundant operations from StableHLO before
long-term compatibility guarantees go into place.

High level summary:
- Remove `CreateTokenOp`, `TraceOp`, `BroadcastOp`, `DotOp`,
`UnaryEinsumOp`, `RealDynamicSliceOp`.
- Enhance `DynamicSliceOp`.
- Move `CrossReplicaSumOp` to CHLO.
- Hopefully remove/move to CHLO (need feedback) `MapOp`, `RngOp`,
`EinsumOp`, `TorchIndexSelectOp`, `GetTupleElementOp`, `tuple` and `tuple` type.

OpenXLA Discuss post:
https://groups.google.com/a/openxla.org/g/openxla-discuss/c/sBAkvnd2bcA

Related tickets: #2176, #3
Projects
Status: In progress