
Rewrite ReduceOp to support arbitrary reduce operations #1305

Merged
merged 23 commits into triton-lang:main on Apr 13, 2023

Conversation

peterbell10
Contributor

@peterbell10 peterbell10 commented Mar 8, 2023

Fixes #1285

This changes `tt.reduce` to replace `redOp` with a region containing arbitrary code. For example, `tl.sum` is now lowered as:

%res = "tt.reduce"(%arg0) ({
^bb0(%arg1: f32, %arg2: f32):
  %add = arith.addf %arg1, %arg2 : f32
  tt.reduce.return %add : f32
}) {axis = 1 : i32} : (tensor<128x128xf32>) -> tensor<128xf32>
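
Purely for orientation (not part of this PR's diff): a minimal Python-level sketch of the combine function this region corresponds to, assuming a frontend entry point like `tl.reduce(x, axis, combine_fn)` that takes a `@triton.jit` combine function, which later comments in this thread describe as the end goal.

```python
import triton
import triton.language as tl

@triton.jit
def combine_add(a, b):
    # This body is what ends up inside the tt.reduce region:
    # one arith.addf followed by tt.reduce.return.
    return a + b

@triton.jit
def row_sum_kernel(x_ptr, out_ptr, N: tl.constexpr, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + row * N + cols, mask=cols < N, other=0.0)
    # Hypothetical call shape; tl.sum itself would lower to the same IR.
    s = tl.reduce(x, 0, combine_add)
    tl.store(out_ptr + row, s)
```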

Support for index reductions at the MLIR level is also dropped in favor of simultaneous reductions over multiple tensors, which generalizes the code without loss of performance. For example, `argmin` is lowered as:

  %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
  %8 = tt.view %7 : (tensor<256xi32>) -> tensor<1x256xi32>
  %9:2 = "tt.reduce"(%6, %8) ({
  ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
    %14 = arith.cmpf olt, %arg4, %arg6 : f32
    %15 = arith.cmpf ogt, %arg4, %arg6 : f32
    %16 = arith.cmpi slt, %arg5, %arg7 : i32
    %17 = arith.select %16, %arg5, %arg7 : i32
    %18 = arith.select %15, %arg7, %17 : i32
    %19 = arith.select %14, %arg5, %18 : i32
    %20 = arith.cmpf olt, %arg4, %arg6 : f32
    %21 = arith.select %20, %arg4, %arg6 : f32
    tt.reduce.return %21, %19 : f32, i32
  }) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>)
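
Again a sketch rather than code from this PR: at the Python level the same argmin pattern is just a combine function over (value, index) pairs; the tuple-accepting `tl.reduce` call in the trailing comment is an assumption.

```python
import triton
import triton.language as tl

@triton.jit
def argmin_combine(value_a, index_a, value_b, index_b):
    # Mirrors the cmpf/select chain in the region above: keep the smaller
    # value, and on ties keep the smaller index.
    lt = value_a < value_b
    gt = value_a > value_b
    tie_index = tl.minimum(index_a, index_b)
    index = tl.where(lt, index_a, tl.where(gt, index_b, tie_index))
    value = tl.minimum(value_a, value_b)
    return value, index

# Hypothetical usage inside a kernel:
#   values, indices = tl.reduce((x, idx), axis=1, combine_fn=argmin_combine)
```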

Comment on lines 34 to 58
// Create a new copy of the reduce block, and inline it
Block *currentBlock = rewriter.getBlock();
Region &parent = *currentBlock->getParent();
rewriter.cloneRegionBefore(reduceOp, &parent.front());
auto &newReduce = parent.front();
auto returnOp = dyn_cast<triton::GenericReduceReturnOp>(newReduce.getTerminator());
rewriter.mergeBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), {acc, cur});
acc = returnOp.getResult();
// Delete the terminator, which is no longer used
rewriter.eraseOp(returnOp);
Contributor Author
This is the main change compared to ReduceOpToLLVM.cpp.
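
To spell out what the inlining achieves, a language-agnostic sketch in Python: each reduction step conceptually applies the region body to the running accumulator and the next element, which is what substituting the block arguments with `{acc, cur}` does. The names here are illustrative only.

```python
def reduce_with_region(elements, combine):
    # combine plays the role of the cloned tt.reduce region: its two block
    # arguments are bound to (acc, cur) at every step.
    it = iter(elements)
    acc = next(it)
    for cur in it:
        acc = combine(acc, cur)
    return acc

# e.g. reduce_with_region([1.0, 2.0, 3.0], lambda a, b: a + b) == 6.0
```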

Comment on lines 1208 to 1218
def prod(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:

    def make_mul(reduce_op):
        ir_scalar_ty = input.type.scalar.to_ir(builder)
        region = reduce_op.get_region(0)
        with insertion_guard(builder):
            block = builder.create_block_with_parent(region, [ir_scalar_ty] * 2)
            fmul = builder.create_fmul(block.arg(0), block.arg(1))
            builder.create_reduce_ret(fmul)

    return reduction(input, axis, make_mul, builder)
Contributor Author
I've been using this for testing, but the end goal would be to have the compiler build the inner function from a lambda, or something like that. I might need some help with that though.

Collaborator
Haha, yeah, it's not entirely trivial. I think it means the ASTVisitor should be modified to create MLIR functions out of lambdas, and then the reduce op could merge in the basic block from that function.

def TT_GenericReduceOp: TT_Op<"generic_reduce",
    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Reduction using generic combination algorithm";
  let arguments = (ins TT_Tensor:$operand, I32Attr:$axis);
Contributor Author
@ptillet assuming I can get index reductions working, do you think it would be reasonable to replace ReduceOp entirely?

Collaborator

Yes, if index reductions can work, then I think we could replace ReduceOp with the new op. We'll have to do some heavier testing to make sure that the performance hasn't decreased.

Comment on lines 1177 to 1186
axis = _constexpr_to_value(axis)
n = input.shape[axis]
index = arange(0, n, _builder=_builder)
new_shape = [constexpr(1)] * len(input.shape)
new_shape[axis] = constexpr(n)
index = view(index, new_shape, _builder=_builder)
index = broadcast_to(index, input.shape, _builder=_builder)

values, indices = semantic.min_with_index(input, index, axis, _builder)
return indices
Contributor Author
This is my strategy for argmin/argmax. Instead of special-casing them, I just lower them as a reduction over two tensors:

  %7 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
  %8 = tt.view %7 : (tensor<256xi32>) -> tensor<1x256xi32>
  %9:2 = "tt.generic_reduce"(%6, %8) ({
  ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
    %15 = arith.cmpf olt, %arg4, %arg6 : f32
    %16 = arith.cmpf ogt, %arg4, %arg6 : f32
    %17 = arith.minsi %arg5, %arg7 : i32
    %18 = arith.select %16, %arg7, %17 : i32
    %19 = arith.select %15, %arg5, %18 : i32
    %20 = arith.minf %arg4, %arg6 : f32
    tt.generic_reduce.return %20, %19 : f32, i32
  }) {axis = 1 : i32} : (tensor<1x256xf32>, tensor<1x256xi32>) -> (tensor<1xf32>, tensor<1xi32>)

This has some really nice properties:

  1. The reduction code is the same whether you discard the min/max value or not.
  2. It generalizes perfectly to higher numbers of tensors, e.g. the three needed for aten.var_mean (see the sketch below).
  3. The argmin/argmax-specific logic is defined entirely at the Python level.
  4. In my limited testing so far, it performs identically.
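
For point 2, a rough sketch of what a three-tensor combine function could look like for a single-pass var_mean (a Welford-style merge; the function and the tuple-reduce usage are illustrative assumptions, not code from this PR; counts are kept as floats so the division stays elementwise):

```python
import triton
import triton.language as tl

@triton.jit
def var_mean_combine(mean_a, m2_a, n_a, mean_b, m2_b, n_b):
    # Merge two partial (mean, M2, count) states (Chan et al. parallel update).
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return mean, m2, n

# Hypothetical usage: reduce the three tensors at once, then
#   variance = m2 / n, and mean is returned directly.
```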

include/triton/Analysis/Utility.h
@@ -80,6 +80,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
  // Some ops from SCF are illegal
  addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
               scf::ReduceReturnOp>();
  // We have custom versions of some arith operators
  addIllegalOp<arith::CmpIOp, arith::CmpFOp, arith::SelectOp>();
Contributor Author
I did start running into edge cases in some of the dialect conversion code, where these were slipping through despite there being a conversion rule for them. It's possible that nested regions are handled differently by MLIR; I'm not sure.

  barrier();
  for (unsigned i = 0; i < op.getNumOperands(); ++i) {
    store(acc[i], writePtrs[i]);
  }
Contributor Author
@peterbell10 peterbell10 Mar 13, 2023
The new changes here basically just change

foo(acc)
if (withIndex)
    foo(accIndex)

into equivalent for loops.

@peterbell10 changed the title from "POC: Add generic reduction operator to mlir dialect" to "Rewrite ReduceOp to support arbitrary reduce operations" on Mar 14, 2023
@peterbell10 peterbell10 marked this pull request as ready for review March 14, 2023 21:12
param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
block = _builder.create_block_with_parent(region, param_types)
args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
results = _generator.call_JitFunction(combine_fn, args, kwargs={})
Contributor Author
I'm in two minds about whether this is hacky or elegant, but it works. I pass the CodeGenerator in via the `_generator` argument, much like the `_builder` argument, then call this function, which I factored out of `visit_Call`, to generate the appropriate function definition and call it.

peterbell10 added a commit to peterbell10/triton that referenced this pull request Mar 14, 2023
This is cherry-picked from triton-lang#1305

If you call a `JITFunction` twice in the same kernel, first with
`int32` then with `uint32`, the second call will treat the unsigned
value as signed. This passes through MLIR without error because MLIR
uses the same types for both, but different operation calls will be
generated.
ptillet pushed a commit that referenced this pull request Mar 15, 2023
This is cherry-picked from #1305

If you call a `JITFunction` twice in the same kernel, first with `int32`
then with `uint32`, the second call will treat the unsigned value as
signed. This passes through MLIR without error because MLIR uses the
same types for both, but different operation calls will be generated so
you may silently get the wrong result.
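
A hedged sketch of the failure mode described in that commit message; the helper and kernel here are hypothetical. Both `int32` and `uint32` map to the MLIR type `i32`, so a specialization cache keyed only on MLIR types would silently reuse the signed version for the second call.

```python
import triton
import triton.language as tl

@triton.jit
def halve(x):
    # Signed and unsigned integer division lower to different operations,
    # so this helper must be specialized separately for int32 and uint32.
    return x // 2

@triton.jit
def kernel(out_ptr, BLOCK: tl.constexpr):
    i = tl.arange(0, BLOCK)                # int32
    u = tl.arange(0, BLOCK).to(tl.uint32)  # uint32, same underlying MLIR i32
    r = halve(i) + halve(u).to(tl.int32)   # second call must not reuse the signed version
    tl.store(out_ptr + tl.arange(0, BLOCK), r)
```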
@ptillet
Collaborator

ptillet commented Mar 16, 2023

Thanks for the PR. Things are busy right now, but we will review it next week!

@ptillet
Collaborator

ptillet commented Mar 23, 2023

(sorry, things have been busy and haven't had time to review this yet!)

@peterbell10
Contributor Author

@ptillet do you have any idea when you might have time to review this?

lib/Analysis/Membar.cpp
include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
include/triton/Dialect/Triton/IR/TritonOps.td
lib/Analysis/Utility.cpp
}

// TODO: This always takes layout from the first argument which
// is fine for argmin/argmax but may not be optimal generally
Contributor
I think you limit all arguments of reduce to have the same encoding. So this is just fine?

  if (t.getShape() != srcShape) {
    rop.emitError() << "shape mismatch";
  }

Contributor Author
The concern is that the first argument might be cheap to convert but the second argument slow to convert. In that case this will remove the cheap layout conversion and add a more expensive one.

Also, I don't think there's ever a case where shape mismatch can happen.

test/Conversion/triton_ops.mlir
lib/Dialect/TritonGPU/Transforms/Utility.cpp
test/TritonGPU/combine.mlir
@peterbell10
Contributor Author

@Jokeren I've fixed the merge conflicts with #1497 and #1514. Tests are passing for me on an A100.

@ptillet ptillet enabled auto-merge (squash) April 12, 2023 16:55
@ptillet
Collaborator

ptillet commented Apr 12, 2023

Benchmark-related stuff was merged in yesterday, so it's possible the tests got flaky. I'll investigate later today.

@ptillet ptillet merged commit e152183 into triton-lang:main Apr 13, 2023
@ptillet
Collaborator

ptillet commented Apr 13, 2023

Thanks again for the PR @peterbell10 . And thanks @Jokeren for the review.

peterbell10 added a commit to peterbell10/triton that referenced this pull request Apr 13, 2023
A small oversight in triton-lang#1305: since `view` can rearrange elements, it should be avoided here. Instead I use indexing with `None` to create new dimensions.
ptillet added a commit that referenced this pull request Apr 13, 2023
A small oversight in #1305: since `view` can rearrange elements, it should be avoided here. Instead I use indexing with `None` to create new dimensions.

Co-authored-by: Philippe Tillet <[email protected]>
@peterbell10 peterbell10 deleted the generic-reduction branch August 18, 2023 14:32
pingzhuu pushed three commits to siliconflow/triton that referenced this pull request Apr 2, 2024 (their commit messages duplicate the ones shown above).
ZzEeKkAa pushed a commit to ZzEeKkAa/triton that referenced this pull request Aug 5, 2024
Successfully merging this pull request may close these issues.

[frontend] allow var_mean to be implemented in one pass