[Upstream] Decouple the Triton GPU lowering utils. #992
Very large PR. We should split it.
This is WIP. Just for review comments.
What do you still plan to improve? Use common conversion files as upstream?
Yes. The steps of decoupling code are:
As the file was moved, I am not sure what changes are actually new. I just reviewed the whole of it.
/// Create a predicated block, using \p cond as the condition and \p ops for the
/// values supplied by the conditional branch to the exit block. The \p
/// thenOpsFn function is used to inject operations in the 'then' branch:
///   cf.cond_br %cond, ^br1, ^br2(%ops)
///   ^br1:
///     %then_ops = `thenOpsFn()`
///     cf.br ^br2(%then_ops)
///   ^br2(%block_ops):
Use markdown for formatting
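As a reading aid for the doc comment above, the branch structure can be modelled in plain C++. This is only a sketch of the semantics, not MLIR code: `predicatedBlockModel` is a hypothetical name, and `int` stands in for an SSA value. The exit block receives `ops` when the condition is false and the results of `thenOpsFn()` when it is true.

```cpp
#include <cassert>
#include <vector>

// Plain-C++ model of the predicated block (hypothetical, for illustration).
// Returns the values that would arrive as the exit block arguments.
template <typename ThenOpsFn>
std::vector<int> predicatedBlockModel(bool cond, const std::vector<int> &ops,
                                      ThenOpsFn &&thenOpsFn) {
  // False edge of cf.cond_br: jump straight to ^br2 carrying %ops.
  if (!cond)
    return ops;
  // ^br1: run the injected operations, then branch to ^br2 with their results.
  std::vector<int> thenOps = thenOpsFn();
  assert(thenOps.size() == ops.size() && "arity mismatch");
  return thenOps;
}
```

Note that `thenOpsFn` is only invoked on the true path, which mirrors the fact that its operations live in the `then` block.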
assert(llvm::all_of(llvm::enumerate(ops, thenOps),
                    [](const auto &enumerator) {
                      auto [index, op, thenOp] = enumerator;
                      return op.getType() == thenOp.getType();
                    }) &&
       "type mismatch found");
Suggested change:
- assert(llvm::all_of(llvm::enumerate(ops, thenOps),
-                     [](const auto &enumerator) {
-                       auto [index, op, thenOp] = enumerator;
-                       return op.getType() == thenOp.getType();
-                     }) &&
-        "type mismatch found");
+ assert(llvm::equal(ops, thenOps, ...) && "type mismatch found");
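To illustrate the shape of the suggested check without depending on LLVM, here is a minimal sketch that uses the standard library's `std::equal` with a binary predicate. `FakeValue` and `haveSameTypes` are hypothetical stand-ins (not part of the PR), and the predicate shown is only one plausible way to fill in the elided argument.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a value handle: only its type matters here.
struct FakeValue {
  std::string type;
  const std::string &getType() const { return type; }
};

// True when both ranges have the same length and pairwise-equal types.
inline bool haveSameTypes(const std::vector<FakeValue> &ops,
                          const std::vector<FakeValue> &thenOps) {
  return std::equal(ops.begin(), ops.end(), thenOps.begin(), thenOps.end(),
                    [](const FakeValue &op, const FakeValue &thenOp) {
                      return op.getType() == thenOp.getType();
                    });
}
```

The four-iterator overload of `std::equal` also returns false when the lengths differ, so a size mismatch is reported by the same assertion instead of needing a separate check.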
template <typename ThenOpsFn>
Block &createPredicatedBlock(ConversionPatternRewriter &rewriter, Location loc,
                             Value cond, ArrayRef<Value> ops,
                             ThenOpsFn &&thenOpsFn) {
Suggested change:
- template <typename ThenOpsFn>
- Block &createPredicatedBlock(ConversionPatternRewriter &rewriter, Location loc,
-                              Value cond, ArrayRef<Value> ops,
-                              ThenOpsFn &&thenOpsFn) {
+ template <typename F>
+ Block &createPredicatedBlock(RewriterBase &rewriter, Location loc,
+                              Value cond, ValueRange ops,
+                              F thenOpsFn) {
Can we also have more specific names for `ops` and `thenOpsFn`? Maybe `branchBlockArgs` and `endBlockArgsGen`?
Would it be more readable to generate `scf` operations here? If not, why not go down and use `llvm.br` right away?
static SmallVector<Value>
emitBaseIndexForDpasLayout(Location loc, RewriterBase &rewriter,
                           const DpasEncodingAttr &dpasLayout,
Suggested change:
- const DpasEncodingAttr &dpasLayout,
+ DpasEncodingAttr dpasLayout,
Same for other `Attribute` args.
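The reasoning behind this suggestion is that MLIR attributes are value-typed handles wrapping a pointer to uniqued storage, so copying one is as cheap as copying a pointer. The sketch below shows the idea with hypothetical stand-ins (`FakeStorage`, `FakeAttr`, `readByValue`); it is not the actual `DpasEncodingAttr` definition.

```cpp
#include <cassert>

// Hypothetical storage object; in MLIR the context owns and uniques it.
struct FakeStorage {
  unsigned payload;
};

// Hypothetical stand-in for an attribute class: a thin handle around a pointer.
class FakeAttr {
public:
  explicit FakeAttr(const FakeStorage *impl) : impl(impl) {}
  unsigned getPayload() const { return impl->payload; }
  // Two handles are equal when they point at the same storage.
  bool operator==(FakeAttr other) const { return impl == other.impl; }

private:
  const FakeStorage *impl;
};

static_assert(sizeof(FakeAttr) == sizeof(void *),
              "the handle is exactly one pointer wide");

// Taking the handle by value copies a single pointer, so a const reference
// adds an indirection without saving any work.
inline unsigned readByValue(FakeAttr attr) { return attr.getPayload(); }
```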
75e76b7 to d01bdca
You can use this commit to review the changes in Utility.h.
Utility.h changes LGTM
… and `emitBaseIndexForLayout`. The Triton GPU lowering utils support dispatching those utils to the layout interface.
InterfaceMethod<"emit base index",
                "SmallVector<Value>",
                "emitBaseIndexWithinCTAForLayout",
                (ins "Location":$loc,
                     "RewriterBase&":$rewriter,
                     "RankedTensorType":$type)>,

InterfaceMethod<"emit offset",
                "SmallVector<SmallVector<unsigned>>",
                "emitOffsetForLayout",
                (ins "RankedTensorType":$type)>,
Sidenote: you can provide a default implementation changing this to:
- InterfaceMethod<"emit base index",
-                 "SmallVector<Value>",
-                 "emitBaseIndexWithinCTAForLayout",
-                 (ins "Location":$loc,
-                      "RewriterBase&":$rewriter,
-                      "RankedTensorType":$type)>,
- InterfaceMethod<"emit offset",
-                 "SmallVector<SmallVector<unsigned>>",
-                 "emitOffsetForLayout",
-                 (ins "RankedTensorType":$type)>,
+ InterfaceMethod<"emit base index",
+                 "SmallVector<Value>",
+                 "emitBaseIndexWithinCTAForLayout",
+                 (ins "Location":$loc,
+                      "RewriterBase&":$rewriter,
+                      "RankedTensorType":$type),
+                 /*methodBody=*/[{}],
+                 /*defaultImplementation=*/[{
+                   // IMPL
+                 }]>,
+ InterfaceMethod<"emit offset",
+                 "SmallVector<SmallVector<unsigned>>",
+                 "emitOffsetForLayout",
+                 (ins "RankedTensorType":$type),
+                 /*methodBody=*/"",
+                 /*defaultImplementation=*/[{
+                   // IMPL
+                 }]>,
Then, in attributes implementing this interface, instead of inheriting like:
  [InterfaceFoo]
do:
  [DeclareAttrInterfaceMethods<InterfaceFoo>]
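For readers less familiar with interface default implementations, the effect is roughly comparable to a C++ base class that ships a fallback body which implementers may override. The sketch below is only an analogy under that assumption: MLIR generates its interface plumbing differently, and every name here (`LayoutInterfaceSketch`, `DefaultLayout`, `ReversedLayout`, the `numElems` parameter) is hypothetical.

```cpp
#include <cassert>
#include <vector>

using Offsets = std::vector<std::vector<unsigned>>;

// Hypothetical interface with a fallback body, playing the role of the
// defaultImplementation block in the TableGen snippet.
struct LayoutInterfaceSketch {
  virtual ~LayoutInterfaceSketch() = default;
  virtual Offsets emitOffsetForLayout(unsigned numElems) const {
    Offsets offsets;
    for (unsigned i = 0; i < numElems; ++i)
      offsets.push_back({i}); // one single-coordinate offset per element
    return offsets;
  }
};

// Relies on the default: nothing to write.
struct DefaultLayout : LayoutInterfaceSketch {};

// Declares and defines its own version, overriding the default.
struct ReversedLayout : LayoutInterfaceSketch {
  Offsets emitOffsetForLayout(unsigned numElems) const override {
    Offsets offsets;
    for (unsigned i = numElems; i > 0; --i)
      offsets.push_back({i - 1});
    return offsets;
  }
};
```

The design benefit is the same in both settings: layouts that share the common behaviour carry no extra code, and only the layouts that differ need their own definition.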
Can you please explain the motivation to move Utility.h from lib to include? It is in lib for the other backends.
I have moved it back in 2faf5ed, please take a look to see if you are ok with it.
) Address review comments: #992 (comment), #992 (comment), #992 (comment), #992 (comment). --------- Signed-off-by: Whitney Tsang <[email protected]>
Other than adding default implementation as suggested in #992 (comment), this PR is ready to upstream IMO.
Looks perfect to me.
Current state LGTM. If the default definition is to be added in a different PR, we can keep as is. 👍
The Triton GPU lowering utils support dispatching those utils to the layout interface.
Decouple the basic util functions `emitIndices`, `emitOffsetForLayout` and `emitBaseIndexForLayout`.
The layout-specific difference is abstracted by the layout interface.
The target-specific difference is abstracted by the `TargetInfo` class.
Then we can reuse the Triton GPU lowering.
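The split described above can be pictured with a small standalone sketch: one shared `emitIndices` routine that asks a layout object for the base index and offsets, and a target object for anything hardware-specific. All types and method names below (`LayoutSketch`, `TargetInfoSketch`, `getThreadId`, and so on) are hypothetical simplifications, and combining base and offset by addition is an assumption made for illustration.

```cpp
#include <cassert>
#include <vector>

// Hypothetical layout interface: hides layout-specific differences.
struct LayoutSketch {
  virtual ~LayoutSketch() = default;
  virtual unsigned emitBaseIndex(unsigned threadId) const = 0;
  virtual std::vector<unsigned> emitOffsets() const = 0;
};

// Hypothetical target interface: hides target-specific differences.
struct TargetInfoSketch {
  virtual ~TargetInfoSketch() = default;
  virtual unsigned getThreadId() const = 0;
};

// Shared lowering util: knows neither the concrete layout nor the target.
inline std::vector<unsigned> emitIndices(const LayoutSketch &layout,
                                         const TargetInfoSketch &target) {
  unsigned base = layout.emitBaseIndex(target.getThreadId());
  std::vector<unsigned> indices;
  for (unsigned offset : layout.emitOffsets())
    indices.push_back(base + offset);
  return indices;
}

// Example layout: each thread owns two consecutive elements.
struct TwoPerThreadLayout : LayoutSketch {
  unsigned emitBaseIndex(unsigned threadId) const override {
    return threadId * 2;
  }
  std::vector<unsigned> emitOffsets() const override { return {0, 1}; }
};

// Example target that reports a fixed thread id.
struct FixedThreadTarget : TargetInfoSketch {
  explicit FixedThreadTarget(unsigned tid) : tid(tid) {}
  unsigned getThreadId() const override { return tid; }
  unsigned tid;
};
```

With this shape, adding a new layout or a new backend means supplying another implementation of one interface while `emitIndices` stays unchanged.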