Support generic reduction and scan cases. (#14)
Signed-off-by: Ilya Enkovich <[email protected]>
ienkovich authored and minjang committed Jun 24, 2024
1 parent 74f111f commit 0f9a0cf
Showing 8 changed files with 519 additions and 27 deletions.
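For context, the "generic" scan cases in the title are tl.associative_scan calls that carry a user-written combine region (such as the linear_recurrence and roll configurations exercised in test_scan2d below), as opposed to the built-in tl.cumsum/tl.cumprod. The sketch below is illustrative only and not part of this commit; the kernel name, combine function, and host-side launch are hypothetical, and the launch assumes the CPU backend, where plain torch CPU tensors can be passed to the kernel.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def combine_add(a, b):
    # User-written combine region; any associative operation works here.
    return a + b


@triton.jit
def generic_cumsum_kernel(X, Y, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    # Generic scan: tl.associative_scan with a custom combine function
    # rather than the built-in tl.cumsum.
    y = tl.associative_scan(x, 0, combine_add)
    tl.store(Y + offs, y)


x = torch.arange(16, dtype=torch.float32)
y = torch.empty_like(x)
generic_cumsum_kernel[(1,)](x, y, BLOCK=16)
```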
15 changes: 11 additions & 4 deletions python/test/unit/language/test_core.py
@@ -2012,6 +2012,7 @@ def deserialize_fp8(np_data, in_dtype):
# ---------------


@pytest.mark.cpu
@pytest.mark.interpreter
def test_max_returns_zero(device):
# Simple test with a tl.max call that returns 0. The interpreter had a bug
@@ -2038,6 +2039,7 @@ def get_reduced_dtype(dtype_str, op):
return dtype_str


@pytest.mark.cpu
@pytest.mark.interpreter
@pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [
'min',
@@ -2156,9 +2158,6 @@ def kernel(X, Z, BLOCK: tl.constexpr):
def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device):
check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested

if is_cpu() and op in ('argmin', 'argmax'):
pytest.skip(f"Not yet implemented on CPU: {op}")

@triton.jit
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr,
AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr):
@@ -2286,17 +2285,24 @@ def roll(a1, b1_last, b1_cur, a2, b2_last, b2_cur):
return a1 + a2, tl.where(a2 == 1, b1_cur, 0) + b2_last, b2_cur


@pytest.mark.cpu
@pytest.mark.interpreter
@pytest.mark.parametrize("op, dtype_str, shape, axis, reverse, num_warps", scan_configs + negative_config)
def test_scan2d(op, dtype_str, shape, axis, reverse, num_warps, device):
check_type_supported(dtype_str, device)
if dtype_str == 'bfloat16':
if op == 'cummax':
if is_cuda() and op == 'cummax':
pytest.skip("bfloat16 compare not suppoted before sm90")
if op == 'linear_recurrence':
pytest.skip("Skipping linear_recurrence scan on bfloat16 due to accuracy issues")
numpy_dtype_str = 'float32' if dtype_str == 'bfloat16' else dtype_str

# bf16 vector cast is broken in LLVM for large vectors:
# https://github.com/llvm/llvm-project/issues/92471
# TODO: Remove the change after the bug is fixed.
if is_cpu() and dtype_str == 'bfloat16':
shape = (min(shape[0], 128), min(shape[1], 128))

# triton kernel
@triton.jit
def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
@@ -2909,6 +2915,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)


@pytest.mark.cpu
@pytest.mark.interpreter
def test_generic_reduction(device):

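test_generic_reduction above covers tl.reduce with a user-defined combine region and multiple operands, which is the "generic reduction" path this commit enables on CPU; simple single-operand reductions keep mapping to vector.multi_dim_reduction, as seen in ConvertReductionOp.cpp below. A hedged sketch of that kind of reduction follows; the argmax-style kernel and its names are illustrative and not the test's actual code.

```python
import triton
import triton.language as tl


@triton.jit
def combine_argmax(val_a, idx_a, val_b, idx_b):
    # Combine region over two operands per side: keep the larger value
    # together with its index.
    keep_a = val_a > val_b
    return tl.where(keep_a, val_a, val_b), tl.where(keep_a, idx_a, idx_b)


@triton.jit
def argmax_kernel(X, OutVal, OutIdx, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs)
    # Generic reduction: a tuple of inputs plus a jitted combine function.
    val, idx = tl.reduce((x, offs), 0, combine_argmax)
    tl.store(OutVal, val)
    tl.store(OutIdx, idx)
```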
1 change: 1 addition & 0 deletions third_party/cpu/include/TritonToTritonCPU/Passes.h
@@ -25,6 +25,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertDotOp();
std::unique_ptr<OperationPass<ModuleOp>> createConvertControlFlowOps();
std::unique_ptr<OperationPass<ModuleOp>> createConvertHistogramOp();
std::unique_ptr<OperationPass<ModuleOp>> createConvertReductionOp();
std::unique_ptr<OperationPass<ModuleOp>> createConvertScanOp();

void tritonToTritonCPUPipelineBuilder(OpPassManager &pm);
void registerTritonToTritonCPUPipeline();
14 changes: 14 additions & 0 deletions third_party/cpu/include/TritonToTritonCPU/Passes.td
@@ -100,4 +100,18 @@ def ConvertReductionOp : Pass<"triton-cpu-convert-reduction", "mlir::ModuleOp">
"mlir::triton::cpu::TritonCPUDialect"];
}

def ConvertScanOp : Pass<"triton-cpu-convert-scan", "mlir::ModuleOp"> {
let summary = "Convert Triton ScanOp.";
let description = [{

}];
let constructor = "mlir::triton::cpu::createConvertScanOp()";

let dependentDialects = ["mlir::arith::ArithDialect",
"mlir::vector::VectorDialect",
"mlir::scf::SCFDialect",
"mlir::triton::TritonDialect",
"mlir::triton::cpu::TritonCPUDialect"];
}

#endif
3 changes: 2 additions & 1 deletion third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt
@@ -4,8 +4,9 @@ add_triton_library(TritonToTritonCPU
ConvertElementwiseOps.cpp
ConvertHistogramOp.cpp
ConvertMemoryOps.cpp
ConvertReductionOp.cpp
ConvertPtrOps.cpp
ConvertReductionOp.cpp
ConvertScanOp.cpp
Pipeline.cpp
TypeConverter.cpp

112 changes: 90 additions & 22 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp
@@ -1,22 +1,17 @@
#include "ReduceScanCommon.h"
#include "TypeConverter.h"

#include "cpu/include/TritonToTritonCPU/Passes.h"

#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Membar.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonCPU/IR/Dialect.h"

#include <numeric>

namespace mlir {
namespace triton {
#define GEN_PASS_DEF_CONVERTREDUCTIONOP
@@ -44,28 +39,91 @@ class ReductionConversionTarget : public ConversionTarget {
}
};

struct ReduceOpConversion : public OpConversionPattern<triton::ReduceOp> {
using OpConversionPattern::OpConversionPattern;
struct ReduceOpConversion
: public ReduceScanOpConversionBase<triton::ReduceOp,
triton::ReduceReturnOp> {
using ReduceScanOpConversionBase::ReduceScanOpConversionBase;

LogicalResult
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *ctx = op.getContext();
    // Currently, only simple reductions with a single input argument are
// supported.
// TODO: support generic case.
// More simple cases with a single input and a single combine
// operation can utilize target-specific reduction operations like
    // horizontal vector operations. We detect such cases here and map
// them to the vector::MultiDimReductionOp.
if (succeeded(mapToMultiDimReductionOp(op, rewriter)))
return success();

return ReduceScanOpConversionBase::matchAndRewrite(op, adaptor, rewriter);
}

SmallVector<Value>
lower1DInput(ValueRange inputs, ReduceOp op,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Region &combineOp = op.getRegion();
int64_t vecSize = cast<VectorType>(inputs[0].getType()).getShape()[0];
SmallVector<int64_t> range(vecSize);
std::iota(range.begin(), range.end(), 0);

ArrayRef<Value> dummies = createShuffleDummies(loc, inputs, rewriter);
SmallVector<Value> res = inputs;
for (int64_t stride = vecSize / 2; stride > 0; stride = stride / 2) {
SmallVector<int64_t> shuffleIndices = range;
for (int64_t i = 0; i < stride; ++i) {
std::swap(shuffleIndices[i], shuffleIndices[i + stride]);
}
SmallVector<Value> shuffledInput;
for (auto [val, dummy] : llvm::zip(res, dummies)) {
shuffledInput.push_back(rewriter.create<vector::ShuffleOp>(
loc, val, dummy, shuffleIndices));
}

res = accumulate(shuffledInput, res, combineOp, rewriter);
}

// The results are in the first element of each produced vector.
Value zero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
for (size_t i = 0; i < res.size(); ++i) {
res[i] = rewriter.create<vector::ExtractElementOp>(loc, res[i], zero);
}
return res;
}

SmallVector<Value>
lowerLeadingDimension(ValueRange inputs, ReduceOp op,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Region &combineOp = op.getRegion();
auto shape = cast<VectorType>(inputs[0].getType()).getShape();
SmallVector<Value> res;
for (int64_t idx = 0; idx < shape[0]; ++idx) {
SmallVector<Value> subInputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), subInputs.begin(),
[&](auto val) {
return rewriter.create<vector::ExtractOp>(loc, val, idx);
});

res = accumulate(subInputs, res, combineOp, rewriter);
}
return res;
}

LogicalResult
mapToMultiDimReductionOp(triton::ReduceOp op,
ConversionPatternRewriter &rewriter) const {
if (op.getNumOperands() != 1 || op.getNumResults() != 1)
return failure();

Value src = rewriter.getRemappedValue(op.getOperand(0));
VectorType srcTy = dyn_cast<VectorType>(src.getType());
assert(srcTy);
VectorType srcTy = cast<VectorType>(src.getType());

Block *block = op.getBody();
if (block->getNumArguments() != 2)
return failure();
Value itArg = block->getArgument(0);
Value accArg = block->getArgument(1);
Value accArg = block->getArgument(0);
Value itArg = block->getArgument(1);

auto &blockOps = block->getOperations();
if (blockOps.size() != 2)
@@ -155,7 +213,18 @@ struct ReduceOpConversion : public OpConversionPattern<triton::ReduceOp> {
elemTy, static_cast<int64_t>(
(1UL << (elemTy.getIntOrFloatBitWidth() - 1)) - 1));
else if (kind == vector::CombiningKind::MINIMUMF ||
kind == vector::CombiningKind::MINNUMF) {
kind == vector::CombiningKind::MAXIMUMF) {
if (elemTy.isF32())
initVal =
rewriter.getF32FloatAttr(std::numeric_limits<float>::quiet_NaN());
else if (elemTy.isF64())
initVal =
rewriter.getF64FloatAttr(std::numeric_limits<double>::quiet_NaN());
else
llvm_unreachable("Unsupported type for acc init value.");
}

else if (kind == vector::CombiningKind::MINNUMF) {
if (elemTy.isF32())
initVal =
rewriter.getF32FloatAttr(std::numeric_limits<float>::infinity());
@@ -164,8 +233,7 @@ struct ReduceOpConversion : public OpConversionPattern<triton::ReduceOp> {
rewriter.getF64FloatAttr(std::numeric_limits<double>::infinity());
else
llvm_unreachable("Unsupported type for acc init value.");
} else if (kind == vector::CombiningKind::MAXIMUMF ||
kind == vector::CombiningKind::MAXNUMF) {
} else if (kind == vector::CombiningKind::MAXNUMF) {
if (elemTy.isF32())
initVal =
rewriter.getF32FloatAttr(-std::numeric_limits<float>::infinity());
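For readers skimming the lower1DInput code above: it performs a tree reduction in log2(N) shuffle steps. At each step the lanes in the current stride's upper half are shuffled onto the lower half (the dummy vector only satisfies vector.shuffle's two-operand form) and the combine region is applied elementwise, so after the final step lane 0 holds the reduced value. Below is a rough NumPy model of that stride-halving scheme, assuming a power-of-two vector length and addition as the combine, purely for illustration.

```python
import numpy as np


def tree_reduce_1d(vec, combine=np.add):
    # Mirrors the stride-halving shuffle loop in lower1DInput.
    res = vec.copy()
    n = res.shape[0]  # assumed to be a power of two
    stride = n // 2
    while stride > 0:
        idx = np.arange(n)
        # Swap lanes [0, stride) with lanes [stride, 2 * stride), as the
        # shuffleIndices computation does.
        idx[:stride] = np.arange(stride, 2 * stride)
        idx[stride:2 * stride] = np.arange(stride)
        shuffled = res[idx]           # stands in for vector.shuffle
        res = combine(shuffled, res)  # stands in for the combine region
        stride //= 2
    # The result ends up in the first lane, extracted by
    # vector.extractelement in the real lowering.
    return res[0]


x = np.arange(16, dtype=np.float32)
assert tree_reduce_1d(x) == x.sum()
```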
