Status: Approved
Initial version: 7/4/2023
Last updated: 4/15/2024
Discussion thread: openxla-discuss.
This RFC aims to leverage existing support for dynamism in the StableHLO dialect, discuss improvements to the existing design and then formalize the improved design in the StableHLO opset. To that end, this document proposes the following:
- (P1) Move TensorFlow-specific operations out of StableHLO.
- (P2) Ratify the existing convention for shape mismatches constituting undefined behavior.
- (P3) Ratify the existing convention for relaxed constraints already implemented in the StableHLO dialect.
- (P4) Enable shape-dependent dynamism for all size-related program elements but keep all axis-related program elements static.
- (P5) Drop support for unranked dynamism.
These proposals address a considerable chunk of community feedback on the existing design, but some feedback is deliberately left out of scope for this RFC to enable incremental progress while areas which require additional alignment are developed in parallel. See sections "Community feedback" and "Out of scope" for details.
This is an RFC that affects the entirety of the StableHLO opset, and there are multiple groups of operations which are affected differently. Taking into account all these operations appropriately was quite laborious, but the StableHLO Ops spreadsheet was a big help. Referring to this spreadsheet will likely help in reviewing this RFC as well.
Before beginning, let's align on terminology. In discussions around MLIR, it is
fairly conventional to use the word "dimension" ambiguously - to refer to either
an actual dimension or the size of a dimension. Most of the time, the exact
meaning of the word "dimension" is clear from the context, but sometimes it's
ambiguous. For example, it is not obvious whether
mlir::ShapedTypeComponents::getDims
returns dimension numbers or dimension
sizes (for what it's worth, it is the latter - it returns dimension sizes).
To avoid ambiguity, this document will be using the following terminology:
- Dimension numbers (or axes) to refer to actual dimensions, e.g. "tensor<16x?xf32> has two axes".
- Dimension sizes (or sizes) to refer to sizes of dimensions, e.g. "tensor<16x?xf32> has the following dimension sizes: 16 and unknown".
- Unqualified "dimension" will not be used at all.
By the virtue of being bootstrapped from MHLO, the StableHLO dialect already has support for dynamism within types:
- Unbounded dynamism: In a tensor type, some or all dimension sizes may be unknown (aka "dynamic"). In MLIR syntax, these dimension sizes are expressed via question marks.
- Bounded dynamism: The same, but some of the dynamic dimension sizes may have known upper bounds. In MLIR syntax, these dimension sizes are expressed via question marks, and bounds are expressed via #stablehlo.bounds.
- Unranked dynamism: In a tensor type, the rank may be unknown. In MLIR syntax, this fact is expressed via asterisks.
// Static shapes:
// All dimension sizes are known.
%0 = stablehlo.add %arg0, %arg1 : tensor<16x16xf32>
// Unbounded dynamism:
// First dimension size is unknown (?).
// Second dimension size is known (16).
%1 = stablehlo.add %arg0, %arg1 : tensor<?x16xf32>
// Bounded dynamism:
// First dimension size is unknown (?), but its bound is known (16).
// Second dimension size is known (16), so its bound is N/A (?).
%2 = stablehlo.add %arg0, %arg1 : tensor<?x16xf32, #stablehlo.bounds<16, ?>>
// Unranked dynamism:
// The rank is unknown (*).
%3 = stablehlo.add %arg0, %arg1 : tensor<*xf32>
Similarly, the StableHLO dialect already has support for dynamism within operations, with some ops such as PadOp having dynamic counterparts such as DynamicPadOp. For example:
// "vanilla" PadOp:
// low, high and interior paddings are implemented via MLIR attributes,
// i.e. they are static.
%0 = stablehlo.pad %arg0, %arg1, low = [1, 1], high = [1, 1], interior = [0, 0]
: (tensor<16x16xf32>, tensor<f32>) -> tensor<20x20xf32>
// DynamicPadOp:
// low, high and interior paddings are implemented via MLIR operands,
// i.e. they are dynamic.
%1 = stablehlo.dynamic_pad %arg0, %arg1, %arg2, %arg3, %arg4
  : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
Finally, dynamism (both within types and within operations) creates new
opportunities for errors. For example, the semantics of elementwise operations
like add
are only defined when inputs and outputs have the same shape.
For the static shape case, this can be checked at compile time (i.e. before
the execution of the StableHLO program). However, when dynamism is involved,
this can only be checked at run time.
If some operation doesn't make sense at run time because some unknown ranks and/or unknown dimension sizes turned out to be incompatible with the operation and/or with each other, an error condition called "shape mismatch" occurs. In order to guard against shape mismatches, StableHLO programs may employ shape checks.
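For illustration, here is a minimal sketch of a shape check that a producer might insert ahead of a dynamically shaped add. The @check_same_shape custom call target is hypothetical - StableHLO does not prescribe a particular shape check mechanism (see P2):
// Producer-inserted shape check (illustrative sketch only).
%dim0_lhs = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
%dim0_rhs = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor<?xf32>) -> tensor<i32>
%ok = stablehlo.compare EQ, %dim0_lhs, %dim0_rhs, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
// Hypothetical producer-specific assertion; fails at run time if %ok is false.
stablehlo.custom_call @check_same_shape(%ok) {has_side_effect = true} : (tensor<i1>) -> ()
%0 = stablehlo.add %arg0, %arg1 : tensor<?xf32>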
The current design of dynamism in MHLO and StableHLO has been practically useful. There are success stories of it connecting JAX, PyTorch and TensorFlow to a number of compilers, in a mix of research and production environments. However, this design can be improved. Over the years of using this design in practice, the community has provided the following feedback which is summarized below (see #8 for the full version):
This RFC aims to address a considerable part of community feedback on the existing design of dynamism in MHLO and StableHLO, but some feedback is deliberately out of scope because further discussion is needed to obtain alignment in the corresponding areas.
However, this is a very new area for the OpenXLA stack, so a considerable amount of design exploration will be needed before these ideas can be turned into concrete proposals. Meanwhile, this RFC is focused on making design improvements which have been already socialized within the community.
(O2) Shape computations are considerably improved by this RFC, but not all of the feedback from F5/F6 is addressed.
More specifically, this RFC proposes to represent shape computations as
StableHLO operations on 0-dimensional tensors, which is a practically important
simplification for StableHLO programs and their producers/consumers.
However, interoperability with other dialects, including arith and shape, which are oftentimes used by the community for shape computations, is out of scope, and P5 goes into details about the rationale.
(O3) Working out a common solution for shape checks would be nice. Per F7, there are multiple incompatible approaches, so a common solution would result in a simplification of the overall stack.
However, based on the author's experience, alignment in this area requires non-trivial effort (different approaches have different benefits, plus they have limited interoperability), so this is left for future work.
Much of this is already implemented in the StableHLO dialect (e.g. verifiers and shape functions already support dynamism in a fairly robust manner, although some open questions still remain), but formalizing this area is left for future work, because that would require formalizing the general notion of shape inference which represents a significant amount of work.
(O5) Value inference is an algorithm implemented in the XLA compiler. It "analyzes values in XlaOp answers following questions: 1) What's the upper-bound of each value in a tensor, 2) What's the lower-bound of each value in a tensor, 3) What's the constant value of each tensor, 4) Whether or not each value in a tensor is dynamic". This algorithm is used in the TF/XLA bridge to automatically compute bounds for bounded types in HLO programs.
Similarly to shape inference, this area is out of scope for this RFC. From the implementation standpoint, value inference is currently not available for the StableHLO dialect, but this may change in the future depending on community feedback.
However, similarly to formalizing shape inference, formalizing shape specialization is a significant amount of work which is not on the critical path to StableHLO v1.0, so it is left for future work. In the meantime, this refinement machinery, along with verification which ensures that programs are refinable, will be made available to frameworks and hardware teams. Per the feedback in F9, this RFC aims to only support shape-dependent dynamism. Anything beyond shape-dependent dynamism will be left to future RFCs.
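As an illustration of refinement (a hypothetical example, not taken from the RFC), a shape-dependent dynamic program like the following can be refined into a fully static program once concrete input types are provided:
// Shape-dependent dynamic program: the output shape of dynamic_reshape is
// computed from the shape of the operand itself.
func.func @dynamic(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
  %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x?xf32>) -> tensor<i32>
  %1 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<?x?xf32>) -> tensor<i32>
  %2 = stablehlo.multiply %0, %1 : tensor<i32>
  %3 = stablehlo.reshape %2 : (tensor<i32>) -> tensor<1xi32>
  %4 = stablehlo.dynamic_reshape %arg0, %3 : (tensor<?x?xf32>, tensor<1xi32>) -> tensor<?xf32>
  func.return %4 : tensor<?xf32>
}
// Sketch of the form one would expect after refinement and canonicalization,
// given that %arg0 is known to be tensor<4x4xf32>.
func.func @refined(%arg0: tensor<4x4xf32>) -> tensor<16xf32> {
  %0 = stablehlo.reshape %arg0 : (tensor<4x4xf32>) -> tensor<16xf32>
  func.return %0 : tensor<16xf32>
}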
There are several operations in the StableHLO dialect which provide dynamic versions of existing StableHLO operations. For example, PadOp defines edge_padding_low, edge_padding_high and interior_padding as static attributes, whereas DynamicPadOp has the same contract except that those arguments are dynamic values. Unifying these ops may be nice, but prior to that, significant investigation is needed to determine the usability of unified ops, as well as the impact of having unified ops on DRR users. In this RFC, F2 remains unaddressed.
Baseline: The MHLO and CHLO dialects have been initially co-designed with the TensorFlow dialect, the MLIR-based TF/XLA bridge and TensorFlow's KernelGen toolchain. This has been inherited by the StableHLO repository when the StableHLO dialect was bootstrapped from MHLO and the CHLO dialect was moved to the StableHLO repository.
As a result, there are 2 StableHLO ops (compute_reshape_shape and cstr_reshapable) as well as 4 CHLO ops (dynamic_reshape, minimum_broadcast_shapes, rank_specialization_cluster and rank_specialization_cluster_yield) which appear specific to TensorFlow, i.e. it looks like they are only used in the legalize_tf.cc and rank_specialization.cc passes within the TensorFlow ecosystem. The TensorFlow-specific nature of these ops does not meet the bar for inclusion in the StableHLO repository.
Proposal: This RFC proposes to immediately remove these 6 operations from their current dialects and move them into a new dialect called chlo_legacy. StableHLO doesn't guarantee compatibility for these operations (see the "Unspecced features" clause). However, this RFC proposes to provide compatibility guarantees on an exceptional basis given that they don't appear to involve a lot of work. More specifically, the proposal is to keep the chlo_legacy dialect in the StableHLO repository for 6 months from the date of its creation and only then remove it.
Baseline: As mentioned in F7, there is broad consensus that shape mismatches in StableHLO operations should constitute undefined behavior, i.e. that guarding against shape mismatches should be made a producer responsibility. This has been a subject of many informal conversations as well as one of the discussion items at the StableHLO dynamism breakout session during the OpenXLA Dev Summit in April 2023.
Proposal: This RFC proposes to ratify this convention as a StableHLO design principle, so that the folklore consensus becomes codified in the StableHLO specification.
Discussion: A) The rationale for this proposal is that lower-level abstractions (e.g. HLO or Linalg) typically don't want to concern themselves with shape checks - they are focused on generating high-performance code, so they want shape checks to be handled at a higher level, i.e. at StableHLO or above.
Furthermore, different producers have different requirements and different preferences for expressing shape checks (e.g. JAX's type system enables it to need fewer checks than TensorFlow's type system), so a specific way of performing shape checks doesn't look like something that should be standardized at the StableHLO level either.
B) From the implementation standpoint, this proposal would involve updating
the "Errors" section of the specification, and
changing the StableHLO dialect to implement the ConditionallySpeculatable
interface, prohibiting speculation for StableHLO ops that involve dynamism
(documentation).
(P3) Ratify the existing convention for relaxed constraints already implemented in the StableHLO dialect
Baseline: At the moment, the StableHLO dialect supports relaxed constraints that were inherited from the MHLO dialect. For example, the code below is valid in the StableHLO dialect:
func.func @main(%arg0: tensor<?xf32>, %arg1: tensor<1xf32>) {
%0 = stablehlo.add %arg0, %arg0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%1 = stablehlo.add %arg0, %arg0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<1xf32>
%2 = stablehlo.add %arg0, %arg1 : (tensor<?xf32>, tensor<1xf32>) -> tensor<?xf32>
%3 = stablehlo.add %arg0, %arg1 : (tensor<?xf32>, tensor<1xf32>) -> tensor<1xf32>
%4 = stablehlo.add %arg1, %arg0 : (tensor<1xf32>, tensor<?xf32>) -> tensor<?xf32>
%5 = stablehlo.add %arg1, %arg0 : (tensor<1xf32>, tensor<?xf32>) -> tensor<1xf32>
%6 = stablehlo.add %arg1, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<?xf32>
%7 = stablehlo.add %arg1, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
func.return
}
More formally, these relaxed constraints generalize the constraints that are already documented in the StableHLO specification. If a constraint for a specific operation cannot be evaluated at compile time because it involves an unknown rank or an unknown dimension size, it gets deferred until the run time of the operation. If at that point the constraint fails, a shape mismatch occurs, and P2 discusses what should happen in that case.
// Static case - constraints I1-I5 and C1-C4 can be evaluated at compile time.
%0 = stablehlo.pad %arg0, %arg1, low = [1, 1], high = [1, 1], interior = [0, 0]
: (tensor<16x16xf32>, tensor<f32>) -> tensor<20x20xf32>
// Dynamic case - constraints C2 and C4 cannot be evaluated at compile time.
// C2 depends on rank(operand) which is unknown.
// C4 depends on shape(operand) and shape(result) which are both unknown.
// These constraints are deferred until run time.
%1 = stablehlo.pad %arg0, %arg1, low = [1, 1], high = [1, 1], interior = [0, 0]
: (tensor<*xf32>, tensor<f32>) -> tensor<?x20xf32>
// Dynamic case - constraints C3 and C4 cannot be evaluated at compile time.
// C3 depends on the value of interior_padding which is unknown.
// C4 depends on a number of shapes and values which are all unknown.
// Note that C2 can be evaluated at compile time - even though the values of
// edge_padding_low, edge_padding_high, interior_padding and operand are
// unknown, their size and rank are actually known.
%2 = stablehlo.dynamic_pad %arg0, %arg1, %arg2, %arg3, %arg4
  : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
Proposal: A) This RFC proposes to ratify this convention as a StableHLO design principle, given that: I) it allows the flexibility for producers to mix static and dynamic elements in StableHLO programs at their convenience (which has been proven to be practically useful, e.g. for JAX native serialization), II) it has a concise formulation that is demonstrably sound (it only affects constraints, and no constraints are disregarded - they just move from compile time to run time).
B) From the implementation standpoint, this proposal would involve updating the "Notation" section of the specification, and auditing the existing verifiers and shape functions in the StableHLO dialect to identify specification compliance issues, update status.md accordingly and file Specification Compliance tickets.
(P4) Enable shape-dependent dynamism for all size-related program elements but keep all axis-related program elements static
Baseline: The StableHLO dialect has inherited a considerable degree of support for dynamism from the MHLO dialect. All MLIR operands can have dynamic shapes, almost all MLIR results can have dynamic shapes and many size-related MLIR attributes have dynamic counterparts.
1) As far as MLIR results go, for 7 operations in the StableHLO dialect - BroadcastInDimOp, ConstantOp, InfeedOp, IotaOp, RecvOp, ReshapeOp and RngBitGeneratorOp - the static shape of the results is "load-bearing", i.e. allowing dynamic shapes there would not make sense as is. Operations that do not have "load-bearing" result shapes can infer their result shapes from static operand shapes.
// Static result type - makes sense.
%0 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<1xf32>) -> tensor<1x2xf32>
// Dynamic result type - doesn't make sense as is.
// How does the operation know what result to produce? 1x1xf32? 1x2xf32? etc.
// Resolving this would need an additional argument - see below.
%1 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<1xf32>) -> tensor<1x?xf32>
2) As far as MLIR attributes go, there are 9 operations in the StableHLO dialect which have size-related attributes (operations are grouped together using the categories from the StableHLO Ops spreadsheet, operations from the Dynamism category are analyzed below, operations from the Not in HLO category are not included because they are on their way out of the StableHLO dialect):
- Data Movement: DynamicSliceOp, GatherOp, PadOp, ScatterOp, SliceOp.
- Miscellaneous: FftOp.
- Reduction: ConvolutionOp, ReduceWindowOp, SelectAndScatterOp.
Furthermore, there are 18 operations in the StableHLO dialect which have axis-related attributes (some of these operations overlap with the operations from the previous section):
- Data Movement: BroadcastInDimOp, ConcatenateOp, GatherOp, ReverseOp, ScatterOp, SortOp, TransposeOp.
- Distribution: AllGatherOp, AllToAllOp, ReduceScatterOp.
- Elementwise: MapOp.
- Miscellaneous: BatchNormGradOp, BatchNormInferenceOp, BatchNormTrainingOp, IotaOp.
- Reduction: ConvolutionOp, DotGeneralOp, ReduceOp.
3) Finally, there are 7 operations in the StableHLO dialect which provide dynamic versions of existing StableHLO operations: DynamicBroadcastInDimOp, DynamicConvOp, DynamicGatherOp, DynamicIotaOp, DynamicPadOp, DynamicReshapeOp and RealDynamicSliceOp. Not all StableHLO operations have dynamic counterparts, e.g. there is no DynamicReduceWindowOp.
Per the feedback in F9, this RFC proposes to support shape-dependent uses of these operations, which are refinable and can be used by the entire StableHLO ecosystem. In PyTorch and JAX, only shape-dependent uses of these operations exist, meaning that the dynamic operand value which specifies the result shape is computed from the shape of another operation. With this principle, all StableHLO programs with dynamism are refinable to static programs by providing concrete input types. A refinement verification pass will be offered so that frameworks can ensure that a generated program is shape-dependent / refinable.
// The StableHLO dialect resolves the conundrum mentioned in the previous
// example by providing a dynamic version of BroadcastInDimOp which takes
// the shape of the result as an operand.
%0 = stablehlo.dynamic_broadcast_in_dim %arg0, %arg1, dims = [0] :
(tensor<1xf32>, tensor<2xi64>) -> tensor<1x?xf32>
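To make the shape-dependent property concrete, here is a minimal sketch (the %other value and the surrounding names are hypothetical) in which the output shape operand is itself computed from the shape of another value, so the program can be refined to a static program once input types are known:
// Compute the target shape from the shape of %other.
%d0 = stablehlo.get_dimension_size %other, dim = 0 : (tensor<?x?xf32>) -> tensor<i32>
%d1 = stablehlo.get_dimension_size %other, dim = 1 : (tensor<?x?xf32>) -> tensor<i32>
%s0 = stablehlo.reshape %d0 : (tensor<i32>) -> tensor<1xi32>
%s1 = stablehlo.reshape %d1 : (tensor<i32>) -> tensor<1xi32>
%shape = stablehlo.concatenate %s0, %s1, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
// Shape-dependent use: the result shape comes from a shape computation.
%1 = stablehlo.dynamic_broadcast_in_dim %arg0, %shape, dims = [0] :
  (tensor<1xf32>, tensor<2xi32>) -> tensor<?x?xf32>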
Overall, as mentioned in F3 and further elaborated above in this section, dynamism within operations is modelled inconsistently. Some ops have two versions (e.g. PadOp and DynamicPadOp), some ops have three versions (e.g. SliceOp, DynamicSliceOp and RealDynamicSliceOp), and some ops don't have dynamic versions even though there are use cases for them (e.g. ReduceWindowOp).
Proposal: This RFC proposes to address the inconsistencies in the existing support for dynamism in StableHLO summarized in "Status of dynamic versions" by establishing a concise design principle:
- If a program element is related to sizes, then it should be possible to express it both statically and dynamically.
- If a program element is related to axes, then it should only be possible to express it statically.
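For illustration, a brief sketch reusing the pad and broadcast examples from earlier in this document: size-related elements such as paddings can appear either as attributes or as operands, whereas an axis-related element such as broadcast_dimensions stays a static attribute even in the dynamic op.
// Size-related elements (paddings) expressed statically, via attributes:
%0 = stablehlo.pad %arg0, %arg1, low = [1, 1], high = [1, 1], interior = [0, 0]
  : (tensor<16x16xf32>, tensor<f32>) -> tensor<20x20xf32>
// ...or dynamically, via operands:
%1 = stablehlo.dynamic_pad %arg0, %arg1, %arg2, %arg3, %arg4
  : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
// Axis-related elements (here, dims) remain static attributes,
// even when the sizes involved are dynamic:
%2 = stablehlo.dynamic_broadcast_in_dim %arg0, %arg2, dims = [0]
  : (tensor<1xf32>, tensor<2xindex>) -> tensor<?x?xf32>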
Discussion: A) This proposal builds on a broad consensus to treat dimension sizes dynamically and axes statically, which has come up both in informal conversations and at the OpenXLA Dev Summit.
For 1), there's already a significant precedent to provide both static and dynamic representations - which started in HLO and continued in MHLO (plus, we have several feature requests from JAX: for ConvolutionOp, for FftOp, for ReduceWindowOp and for RngBitGeneratorOp). This RFC proposes to extend this precedent to the entire opset, with the rationale that the burden of adding a few extensions that were not previously supported or requested (for ConstantOp, for InfeedOp and for RecvOp) is smaller than the burden of having to define and maintain carveouts in the specification.
For 2), the proposal is to keep them static because: I) there doesn't seem to be precedent or feature requests to make them dynamic, II) when this topic came up with several StableHLO consumers, there was considerable pushback since that would considerably complicate code generation.
B) From the implementation standpoint, this proposal would involve updating
the "Types" section of the specification to extend
the definition of TensorType
to express dynamism within types and
accommodate both unbounded and bounded dynamism (but not unranked dynamism -
see P5 for details).
However, this section doesn't provide an opinion on what representations should be used to express dynamism within operations. As mentioned in F2, the existing design where static ops like PadOp are accompanied by dynamic ops like DynamicPadOp has some drawbacks, which are denoted in O7 as out of scope for this RFC.
As mentioned in F1, the only user of unranked dynamism in StableHLO/MHLO
appears to be TensorFlow, and TensorFlow can encapsulate handling unranked
dynamism in a different layer, so overall this proposal looks like an
improvement - it does require a non-trivial refactoring in one user, but in
return it simplifies the StableHLO specification and implementation for
everyone. The ops of concern in TF are StridedSlice, Reshape, and Squeeze, and it appears that interop between these ops and StableHLO is not required.
The StableHLO dialect represents shapes as 1-dimensional tensors, with static shapes expressed as DenseI64ArrayAttr and dynamic shapes expressed as 1DTensorOf<[HLO_DimensionValue]> (HLO_DimensionTensor).
class PadOp ... {
// This is how static sizes are represented in TableGen:
// DenseI64ArrayAttr:$edge_padding_low
// DenseI64ArrayAttr:$edge_padding_high
// DenseI64ArrayAttr:$interior_padding
DenseI64ArrayAttr getEdgePaddingLow();
DenseI64ArrayAttr getEdgePaddingHigh();
DenseI64ArrayAttr getInteriorPadding();
}
class DynamicPadOp ... {
// This is how dynamic sizes are represented in TableGen:
// HLO_DimensionTensor:$edge_padding_low
// HLO_DimensionTensor:$edge_padding_high
// HLO_DimensionTensor:$interior_padding
TypedValue<RankedTensorType> getEdgePaddingLow();
TypedValue<RankedTensorType> getEdgePaddingHigh();
TypedValue<RankedTensorType> getInteriorPadding();
}
As mentioned in F5, shape computations in StableHLO programs produce these 1-dimensional tensors using one of several approaches, including ops from Arith, Shape, StableHLO as well as other dialects. E.g. addition of dimension sizes can be represented via stablehlo::AddOp, arith::AddIOp, shape::AddOp and in a few other ways.
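For example (an illustrative sketch; the %dim_* values stand in for dimension sizes and are hypothetical), the same addition can look like any of the following, depending on the producer:
// StableHLO: dimension sizes held in 0-dimensional tensors.
%sum0 = stablehlo.add %dim_a, %dim_b : tensor<i32>
// Arith: dimension sizes held in index scalars.
%sum1 = arith.addi %dim_c, %dim_d : index
// Shape: dimension sizes held in index (or !shape.size) values.
%sum2 = shape.add %dim_e, %dim_f : index, index -> index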
As mentioned in F4, using 1-dimensional tensors introduces complexity
when shape computations operate in terms of scalars, which requires the
0-dimensional tensors to be reshaped to 1-dimensional tensors and then
concatenated together (example).
However, in most cases shapes are created directly using ops like shape.shape_of, and some shape computations operate on 1-dimensional tensors, e.g. shape.broadcast.
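For example (an illustrative sketch), shapes can be created and combined directly as 1-dimensional extent tensors:
%s0 = shape.shape_of %arg0 : tensor<?x?xf32> -> tensor<2xindex>
%s1 = shape.shape_of %arg1 : tensor<?x?xf32> -> tensor<2xindex>
%s = shape.broadcast %s0, %s1 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>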
An earlier version of the RFC proposed standardizing on representing shapes in StableHLO as a series of 0-dimensional tensors, changing from using a single operand of tensor<Nxi64> to using N operands of tensor<i64>, i.e. to replace 1DTensorOf<[HLO_DimensionValue]> with Variadic<0DTensorOf<[HLO_DimensionValue]>>. However, while attempting to integrate this change we observed that the returns are not as high as expected and the costs are high. Indeed:
- Both 0-dimensional tensors (which would benefit from this change) and 1-dimensional tensors (which would suffer from this change) are used to represent shapes in the wild.
- 1-dimensional tensors can be created and manipulated using the shape dialect.
- 1-dimensional tensors make pattern matching easier, especially in DRR.
- Changing from a single operand to a variadic operand breaks C++, MLIR and Python code, which would cause a lot of churn in the ecosystem.
Note that only shape computations that use StableHLO operations are supported
in StableHLO portable artifacts (this can be changed in future RFCs, but is out
of scope for this one). StableHLO provides a pass to legalize many shape
operations to StableHLO operations (including e.g. shape_of and broadcast).
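As a rough sketch of the kind of rewrite such a pass performs (details, such as the exact integer types used, may differ):
// Before: shape computed via the shape dialect.
%s = shape.shape_of %arg0 : tensor<?xf32> -> tensor<1xindex>
// After: the same shape computed via StableHLO operations.
%d = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
%s_new = stablehlo.reshape %d : (tensor<i32>) -> tensor<1xi32>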
As discussed in O2, the community sometimes uses other dialects, including arith and shape, for shape computations. These dialects can be used with StableHLO programs, but there are interoperability issues which could be improved if: I) StableHLO used scalars instead of 0-dimensional tensors, II) StableHLO used index instead of allowing integer and index types.
However, both of these changes need a significant amount of work, so they are not included in this RFC. I) requires extending the StableHLO opset with scalar operations, which involves defining semantics for these operations within StableHLO programs and getting buy-in from producers to support them. This is a fairly novel notion for the StableHLO ecosystem, so some further design exploration is needed here. II) is conceptually straightforward but would require updating a lot of code.
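For example (an illustrative sketch of the friction; all names are hypothetical), bridging a 1-dimensional index shape tensor produced by shape/arith-style code into the integer tensors used by StableHLO today requires glue code like the following:
%c0 = arith.constant 0 : index
// Extract one dimension size as an index scalar...
%dim = tensor.extract %shape[%c0] : tensor<2xindex>
// ...convert it to an integer and wrap it back into a tensor for StableHLO.
%dim_i64 = arith.index_cast %dim : index to i64
%dim_tensor = tensor.from_elements %dim_i64 : tensor<1xi64>
%one = stablehlo.constant dense<1> : tensor<1xi64>
%next = stablehlo.add %dim_tensor, %one : tensor<1xi64>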