Skip to content

Commit

Permalink
Support tl.histogram for CPU. (#12)
Browse files Browse the repository at this point in the history
Signed-off-by: Ilya Enkovich <[email protected]>
Co-authored-by: Minjang Kim <[email protected]>
  • Loading branch information
2 people authored and int3 committed Aug 28, 2024
1 parent c0d4775 commit 725ec53
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2496,6 +2496,7 @@ def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.const
# ---------------


@pytest.mark.cpu
@pytest.mark.interpreter
@pytest.mark.parametrize("M, N", [[2048, 2], [1024, 8], [1024, 128], [256, 512], [32, 512], [8, 512], [8, 2]])
def test_histogram(M, N, device):
Expand Down
1 change: 1 addition & 0 deletions third_party/cpu/include/TritonToTritonCPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertMemoryOps();
std::unique_ptr<OperationPass<ModuleOp>> createConvertPtrOps();
std::unique_ptr<OperationPass<ModuleOp>> createConvertDotOp();
std::unique_ptr<OperationPass<ModuleOp>> createConvertControlFlowOps();
std::unique_ptr<OperationPass<ModuleOp>> createConvertHistogramOp();
std::unique_ptr<OperationPass<ModuleOp>> createConvertReductionOp();

void tritonToTritonCPUPipelineBuilder(OpPassManager &pm);
Expand Down
11 changes: 11 additions & 0 deletions third_party/cpu/include/TritonToTritonCPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,17 @@ def ConvertControlFlowOps : Pass<"triton-cpu-convert-control-flow-op", "mlir::Mo
"mlir::triton::cpu::TritonCPUDialect"];
}

def ConvertHistogramOp : Pass<"triton-cpu-convert-histogram-op", "mlir::ModuleOp"> {
let summary = "Convert Triton HistogramOp.";
let description = [{

}];
let constructor = "mlir::triton::cpu::createConvertHistogramOp()";

let dependentDialects = ["mlir::arith::ArithDialect",
"mlir::memref::MemRefDialect",
"mlir::vector::VectorDialect",

def ConvertReductionOp : Pass<"triton-cpu-convert-reduction", "mlir::ModuleOp"> {
let summary = "Convert Triton ReduceOp.";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_triton_library(TritonToTritonCPU
ConvertControlFlowOps.cpp
ConvertDotOp.cpp
ConvertElementwiseOps.cpp
ConvertHistogramOp.cpp
ConvertMemoryOps.cpp
ConvertReductionOp.cpp
ConvertPtrOps.cpp
Expand Down
134 changes: 134 additions & 0 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
#include "TypeConverter.h"

#include "cpu/include/TritonToTritonCPU/Passes.h"

#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Pass/Pass.h"

#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonCPU/IR/Dialect.h"

namespace mlir {
namespace triton {
#define GEN_PASS_DEF_CONVERTHISTOGRAMOP
#include "cpu/include/TritonToTritonCPU/Passes.h.inc"
} // namespace triton
} // namespace mlir

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::cpu;

namespace {

/// Conversion target used by the histogram lowering pass.
///
/// The Builtin, Vector, Arith, Math, Triton and TritonCPU dialects are all
/// marked legal; `triton::HistogramOp` is the single operation marked illegal,
/// so it is the only thing the partial conversion is required to rewrite.
class HistogramConversionTarget : public ConversionTarget {
public:
  /// \param ctx        Context the target is created for.
  /// \param converter  Accepted for signature compatibility; not consulted
  ///                   when setting up legality.
  explicit HistogramConversionTarget(MLIRContext &ctx, TypeConverter &converter)
      : ConversionTarget(ctx) {
    // Whole dialects that may remain in (or be introduced into) the module.
    addLegalDialect<mlir::BuiltinDialect, vector::VectorDialect,
                    arith::ArithDialect, math::MathDialect, TritonDialect,
                    TritonCPUDialect>();

    // HistogramOp is illegal even though the Triton dialect as a whole is
    // legal.
    addIllegalOp<triton::HistogramOp>();
  }
};

/// Rewrites `triton::HistogramOp` on a 1-D integer vector into a fully
/// unrolled sequence of vector/arith operations.
///
/// For each source element the pattern broadcasts the element to the number
/// of bins, compares it for equality against the constant vector
/// [0, 1, ..., numBins - 1], turns the comparison mask into a vector holding
/// one in the matching lane and zero elsewhere, and accumulates that vector
/// into the running result.
struct HistogramOpConversion : public OpConversionPattern<triton::HistogramOp> {
  using OpConversionPattern::OpConversionPattern;

  /// Emits the unrolled histogram computation and replaces `op` with it.
  ///
  /// Returns a match failure (instead of aborting) when the operand or the
  /// result does not convert to a vector, when either is not rank 1, or when
  /// the source element type is neither i32 nor i64. All checks run before
  /// any IR is created.
  LogicalResult
  matchAndRewrite(triton::HistogramOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    Value src = rewriter.getRemappedValue(op.getSrc());
    if (!src)
      return rewriter.notifyMatchFailure(
          op, "failed to remap histogram source operand");

    auto srcTy = dyn_cast<VectorType>(src.getType());
    // convertType may yield a null type, hence the null-tolerant cast.
    auto resTy = dyn_cast_or_null<VectorType>(
        getTypeConverter()->convertType(op.getType()));
    if (!srcTy || !resTy)
      return rewriter.notifyMatchFailure(
          op, "expected histogram operand and result to convert to vectors");

    if (srcTy.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "unsupported input for histogram op (rank != 1)");
    if (resTy.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "unsupported result for histogram op (rank != 1)");

    // makeRangeAttr only knows how to build i32 and i64 ranges; reject other
    // element types here so it never reaches its unreachable branch.
    Type srcElemTy = srcTy.getElementType();
    if (!srcElemTy.isInteger(32) && !srcElemTy.isInteger(64))
      return rewriter.notifyMatchFailure(
          op, "unsupported src elem type for histogram (expected i32 or i64)");

    Value zero = rewriter.create<arith::ConstantOp>(
        loc, resTy, rewriter.getZeroAttr(resTy));
    Value one = rewriter.create<arith::ConstantOp>(loc, resTy,
                                                   rewriter.getOneAttr(resTy));

    // Comparisons are done in the source element type, with one lane per bin.
    VectorType cmpVecTy =
        VectorType::get(resTy.getShape(), srcTy.getElementType());
    // The constant's type must be the type of its attribute (cmpVecTy), which
    // differs from resTy whenever source and result element types differ.
    Value rangeVec = rewriter.create<arith::ConstantOp>(
        loc, cmpVecTy, makeRangeAttr(cmpVecTy, rewriter));

    Value res = zero;
    const int64_t numElems = srcTy.getShape()[0];
    for (int64_t i = 0; i < numElems; ++i) {
      Value idx = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexType(), rewriter.getIndexAttr(i));
      Value elem = rewriter.create<vector::ExtractElementOp>(loc, src, idx);
      Value elemVec = rewriter.create<vector::BroadcastOp>(loc, cmpVecTy, elem);
      // Lane k of the mask is set when the element equals bin index k.
      Value mask = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                  elemVec, rangeVec);
      Value delta = vector::selectPassthru(rewriter, mask, one, zero);
      res = rewriter.create<arith::AddIOp>(loc, res, delta);
    }

    rewriter.replaceOp(op, res);

    return success();
  }

  /// Builds a dense attribute holding 0, 1, ..., N - 1, where N is the first
  /// dimension of `resTy` and the element width (i32 or i64) follows the
  /// element type of `resTy`. Any other element type is a programming error.
  TypedAttr makeRangeAttr(VectorType resTy,
                          ConversionPatternRewriter &rewriter) const {
    Type elemTy = resTy.getElementType();
    if (elemTy.isInteger(32)) {
      SmallVector<int32_t> range(resTy.getShape()[0]);
      std::iota(range.begin(), range.end(), 0);
      return rewriter.getI32VectorAttr(range);
    }
    if (elemTy.isInteger(64)) {
      SmallVector<int64_t> range(resTy.getShape()[0]);
      std::iota(range.begin(), range.end(), 0);
      return rewriter.getI64VectorAttr(range);
    }
    llvm_unreachable(
        "unsupported src elem type for histogram (expected i32 or i64)");
  }
};

/// Module pass that runs a partial conversion over the module with
/// `HistogramOpConversion` as its only pattern.
struct ConvertHistogramOp
    : public triton::impl::ConvertHistogramOpBase<ConvertHistogramOp> {
  using ConvertHistogramOpBase::ConvertHistogramOpBase;

  ConvertHistogramOp() : ConvertHistogramOpBase() {}

  /// Sets up the type converter, conversion target and pattern set, then
  /// applies the partial conversion; signals pass failure if it fails.
  void runOnOperation() override {
    MLIRContext &ctx = getContext();
    ModuleOp module = getOperation();

    TritonToTritonCPUTypeConverter converter;
    HistogramConversionTarget target(ctx, converter);

    RewritePatternSet patternSet(&ctx);
    patternSet.add<HistogramOpConversion>(converter, &ctx);

    LogicalResult status =
        applyPartialConversion(module, target, std::move(patternSet));
    if (failed(status))
      signalPassFailure();
  }
};

} // namespace

namespace mlir {
namespace triton {
namespace cpu {

/// Factory for the histogram conversion pass; returns a newly allocated
/// `ConvertHistogramOp` instance owned by the caller.
std::unique_ptr<OperationPass<ModuleOp>> createConvertHistogramOp() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ConvertHistogramOp>();
  return pass;
}

} // namespace cpu
} // namespace triton
} // namespace mlir
1 change: 1 addition & 0 deletions third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ void tritonToTritonCPUPipelineBuilder(OpPassManager &pm) {
pm.addPass(mlir::triton::cpu::createConvertPtrOps());
pm.addPass(mlir::triton::cpu::createConvertElementwiseOps());
pm.addPass(mlir::triton::cpu::createConvertDotOp());
pm.addPass(mlir::triton::cpu::createConvertHistogramOp());
pm.addPass(mlir::triton::cpu::createConvertReductionOp());
pm.addPass(mlir::triton::cpu::createConvertControlFlowOps());
// pm.addPass(mlir::createReconcileUnrealizedCastsPass());
Expand Down

0 comments on commit 725ec53

Please sign in to comment.