Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move BroadcastOp from StableHLO to CHLO. #51

Closed
wants to merge 1 commit into from

Conversation

subhankarshah
Copy link
Member

@subhankarshah subhankarshah commented Aug 28, 2022

Part 1 of Many

  • Create a copy of stablehlo.broadcast in CHLO
  • Create tests for chlo.broadcast

@subhankarshah subhankarshah changed the title [mhlo] Move BroadcastOp from HLO to CHLO. [mhlo] Move BroadcastOp from StableHLO to CHLO. Aug 28, 2022
@burmako burmako added the Spec label Aug 30, 2022
Copy link
Contributor

@burmako burmako left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM except for a few minor changes. Please also adjust the PR title to: 1) drop "[mhlo]" because there's no MHLO in this repository, 2) change "move" to "copy" because we aren't removing the corresponding MHLO op.

stablehlo/dialect/CMakeLists.txt Outdated Show resolved Hide resolved
stablehlo/dialect/ChloOps.cpp Outdated Show resolved Hide resolved

return success();
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy/pasting implementations between dialects is fine for now. In the near future, we'll come up with ways to share code - both between CHLO and StableHLO (for ops introduced during #3) and between StableHLO and MHLO.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spurious line 529

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The is only shared code for the time being, because eventually we will remove the op from StableHLO

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"for the time being". Given that we're already providing compatibility guarantees for StableHLO (we haven't documented them in docs/ yet, but we've been talking about at least 6 months of backward compatibility to various stakeholders), the StableHLO op is here to stay for a long time. We'll need a solution for code sharing, but that doesn't have to block this PR.

@burmako
Copy link
Contributor

burmako commented Aug 30, 2022

Let's also expand the PR description to explain why it's doing what it's doing, perhaps with a reference to the corresponding issue.

@burmako
Copy link
Contributor

burmako commented Aug 30, 2022

Having thought about this further (thank you, @subhankarshah, for your prototype of the Linalg lowering!), I believe that we should have more discussion before merging this PR.

The motivation behind this PR (and, more generally, #3) is the desire to trim the StableHLO opset by moving decomposable ops into CHLO. But in what cases are these decompositions worth it? Even for something as simple as mhlo.broadcast, evaluating this is non-obvious.

E.g. if we were to decompose stablehlo.broadcast to other MHLO ops, then the most general case of decomposition would involve stablehlo.broadcast_in_dim and three shape ops (shape.const with BroadcastOp::broadcast_sizes, shape.shape_of with the BroadcastOp::operand and shape.concat). This isn't extremely easy to eyeball or analyze, so I think that a larger-scale discussion is needed.

As recently mentioned on Discord, we currently don't have a formal process for discussing new developments. While in some cases, e.g. for evaluating the proposed design of the interpreter, it feels like we can make forward progress less formally, it sounds like in this case cutting corners could be inadvisable. I'll follow up soon on the next steps.

Part 1-of-many
 - Create a copy of stablehlo.broadcast in CHLO, including changes in .td and .cc files.
 - Create tests for chlo.broadcast (refer tests for stablehlo.broadcast).
@burmako burmako changed the title [mhlo] Move BroadcastOp from StableHLO to CHLO. Move BroadcastOp from StableHLO to CHLO. Sep 2, 2022
@burmako
Copy link
Contributor

burmako commented Sep 6, 2022

As @subhankarshah pointed out, unranked dynamism presents a problem here.

More specifically, stablehlo.broadcast with an unranked operand and an unranked result cannot be lowered to other StableHLO ops. stablehlo.broadcast_in_dim could be a potential target, but its broadcast_dimensions don't work for this use case. Since stablehlo.broadcast prepends the new dimensions, broadcast_dimensions would need to represent a suffix of the result shape, which it cannot do for an unknown rank. If it supported negative axes, then this could work, but it doesn't at the moment.

Given that, we'll be closing this PR (and postponing the work on some other potentially unranked ops in #3) until we reach a conclusion on whether or not StableHLO should support unranked dynamism, which is part of #8.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants