From 46d73044e8ee1a267258b39d46174ca4f1f99ef3 Mon Sep 17 00:00:00 2001
From: Wenyi Zhao
Date: Tue, 22 Nov 2022 17:11:06 +0800
Subject: [PATCH] [transform] add a pass to make the transformed payload ir
 suitable for RAL

---
 .../mlir/disc/tools/disc-transform/BUILD      |  22 ++
 .../tools/disc-transform/transforms/passes.h  |   4 +
 .../transforms/rewrite_payload_ir_for_ral.cc  | 196 ++++++++++++++++++
 .../tests/rewrite-payload-ir-for-ral.mlir     |  36 ++++
 .../transforms/transform_passes.td            |   9 +
 .../transforms/disc_assign_memory_space.cc    |   8 +-
 .../mlir/disc/transforms/placement_utils.cc   |   8 +
 .../mlir/disc/transforms/placement_utils.h    |   4 +
 8 files changed, 280 insertions(+), 7 deletions(-)
 create mode 100644 tao_compiler/mlir/disc/tools/disc-transform/transforms/rewrite_payload_ir_for_ral.cc
 create mode 100644 tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/rewrite-payload-ir-for-ral.mlir

diff --git a/tao_compiler/mlir/disc/tools/disc-transform/BUILD b/tao_compiler/mlir/disc/tools/disc-transform/BUILD
index f6c6a7b1a3e..ac2a66817a3 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/BUILD
+++ b/tao_compiler/mlir/disc/tools/disc-transform/BUILD
@@ -183,6 +183,27 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "rewrite-payload-ir-for-ral",
+    srcs = ["transforms/rewrite_payload_ir_for_ral.cc"],
+    deps = [
+        ":pass_details",
+        "//tensorflow/compiler/mlir/disc:placement_utils",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:ArithDialect",
+        "@llvm-project//mlir:LinalgDialect",
+        "@llvm-project//mlir:MemRefDialect",
+        "@llvm-project//mlir:SCFDialect",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:Transforms",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "all_passes",
     hdrs = [
@@ -194,6 +215,7 @@ cc_library(
     ],
     deps = [
         ":legalize_lmhlo_fusion_to_linalg",
+        ":rewrite-payload-ir-for-ral",
         ":transform_dialect_interpreter",
         "@llvm-project//mlir:Pass",
     ],
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/transforms/passes.h b/tao_compiler/mlir/disc/tools/disc-transform/transforms/passes.h
index 1f35b32ccbe..afd67f387f4 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/transforms/passes.h
+++ b/tao_compiler/mlir/disc/tools/disc-transform/transforms/passes.h
@@ -42,6 +42,10 @@ std::unique_ptr<OperationPass<ModuleOp>>
 createDiscTransformDialectInterpreterPass(const std::string& fileName = "",
                                           bool enableExpensiveChecks = false);
 
+// Converts the transformed payload IR to be suitable for RAL.
+std::unique_ptr<OperationPass<ModuleOp>> createDiscRewritePayloadIRForRALPass(
+    bool gpuEnabled = false);
+
 } // namespace disc_ral
 } // namespace mlir
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/transforms/rewrite_payload_ir_for_ral.cc b/tao_compiler/mlir/disc/tools/disc-transform/transforms/rewrite_payload_ir_for_ral.cc
new file mode 100644
index 00000000000..dd6e671961d
--- /dev/null
+++ b/tao_compiler/mlir/disc/tools/disc-transform/transforms/rewrite_payload_ir_for_ral.cc
@@ -0,0 +1,196 @@
+// Copyright 2022 The BladeDISC Authors. All rights reserved.
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "tensorflow/compiler/mlir/disc/tools/disc-transform/transforms/PassDetail.h"
+#include "tensorflow/compiler/mlir/disc/transforms/placement_utils.h"
+
+#define DEBUG_TYPE "disc-rewrite-payload-ir-for-ral"
+
+// This file implements the logic to convert the transformed payload IR to be
+// suitable for RAL.
+
+namespace mlir {
+namespace disc_ral {
+namespace {
+
+using func::FuncOp;
+using placement_utils::copyWithMemorySpace;
+using scf::ForeachThreadOp;
+using scf::ParallelOp;
+
+struct DiscRewritePayloadIRForRALPass
+    : public DiscRewritePayloadIRForRALPassBase<
+          DiscRewritePayloadIRForRALPass> {
+  explicit DiscRewritePayloadIRForRALPass(bool gpuEnabled)
+      : DiscRewritePayloadIRForRALPassBase<DiscRewritePayloadIRForRALPass>::
+            DiscRewritePayloadIRForRALPassBase() {
+    this->gpuEnabled_ = gpuEnabled;
+  }
+  void runOnOperation() override;
+
+  // replace scf::foreach_thread op with scf::parallel op
+  LogicalResult convertForeachThreadToParallelOp();
+  LogicalResult funcLevelConvertForeachThreadToParallelOp(FuncOp funcOp);
+
+  // assign placement info for each memref value, e.g. 
memref<?xf32> ->
+  // memref<?xf32, "cpu">
+  LogicalResult assignPlacement();
+  LogicalResult assignPlacementForFuncOp(FuncOp funcOp);
+};
+
+LogicalResult
+DiscRewritePayloadIRForRALPass::funcLevelConvertForeachThreadToParallelOp(
+    FuncOp funcOp) {
+  SmallVector<ForeachThreadOp> forOps;
+  funcOp.walk([&](ForeachThreadOp op) { forOps.push_back(op); });
+
+  OpBuilder b(funcOp);
+  for (ForeachThreadOp foreachThreadOp : forOps) {
+    if (foreachThreadOp.getOutputs().size() != 0)
+      return foreachThreadOp->emitError()
+             << "ForeachThreadOp with outputs is not supported a.t.m.\n";
+
+    b.setInsertionPoint(foreachThreadOp);
+    Location loc = foreachThreadOp.getLoc();
+    int64_t rank = foreachThreadOp.getRank();
+    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+    SmallVector<Value> lowerBounds(rank, zero);
+    SmallVector<Value> upperBounds = foreachThreadOp.getNumThreads();
+    SmallVector<Value> steps(rank, one);
+
+    auto parallelOp =
+        b.create<ParallelOp>(loc, lowerBounds, upperBounds, steps);
+    BlockAndValueMapping mapping;
+    for (const auto& z : llvm::zip(foreachThreadOp.getThreadIndices(),
+                                   parallelOp.getInductionVars()))
+      mapping.map(std::get<0>(z), std::get<1>(z));
+    b.setInsertionPointToStart(parallelOp.getBody());
+    for (auto& nestedOp : foreachThreadOp.getBody()->without_terminator()) {
+      // Clone at the current insertion point; the returned op is not needed.
+      b.clone(nestedOp, mapping);
+    }
+    foreachThreadOp->erase();
+  }
+  return success();
+}
+
+LogicalResult
+DiscRewritePayloadIRForRALPass::convertForeachThreadToParallelOp() {
+  for (auto funcOp : getOperation().getOps<FuncOp>()) {
+    if (failed(funcLevelConvertForeachThreadToParallelOp(funcOp)))
+      return failure();
+  }
+  return success();
+}
+
+LogicalResult DiscRewritePayloadIRForRALPass::assignPlacementForFuncOp(
+    FuncOp funcOp) {
+  auto maybeConvertType = [&](Type ty) -> Type {
+    auto memrefTy = ty.dyn_cast<MemRefType>();
+    if (!memrefTy || memrefTy.getMemorySpace()) return ty;
+    return copyWithMemorySpace(
+        memrefTy.getContext(), memrefTy,
+        this->gpuEnabled_ ? 
placement_utils::kGpu : placement_utils::kCpu);
+  };
+
+  auto convertValue = [&](Value v) {
+    auto newTy = maybeConvertType(v.getType());
+    if (newTy != v.getType()) v.setType(newTy);
+    return success();
+  };
+
+  // update types of results of operations
+  if (funcOp
+          ->walk([&](Operation* op) {
+            for (Value value : llvm::to_vector(op->getResults())) {
+              if (failed(convertValue(value))) return WalkResult::interrupt();
+            }
+            return WalkResult::advance();
+          })
+          .wasInterrupted()) {
+    return failure();
+  }
+
+  // update types of block arguments
+  if (funcOp
+          ->walk([&](Block* block) {
+            for (Value value : llvm::to_vector(block->getArguments())) {
+              if (failed(convertValue(value))) return WalkResult::interrupt();
+            }
+            return WalkResult::advance();
+          })
+          .wasInterrupted()) {
+    return failure();
+  }
+
+  // update the type of func op
+  SmallVector<Type> refinedInputTypes;
+  for (Type ty : funcOp.getArgumentTypes()) {
+    refinedInputTypes.push_back(maybeConvertType(ty));
+  }
+  SmallVector<Type> refinedOutputTypes;
+  for (Type ty : funcOp.getResultTypes()) {
+    refinedOutputTypes.push_back(maybeConvertType(ty));
+  }
+  auto newFuncTy = FunctionType::get(funcOp.getContext(), refinedInputTypes,
+                                     refinedOutputTypes);
+  funcOp.setType(newFuncTy);
+  return success();
+}
+
+LogicalResult DiscRewritePayloadIRForRALPass::assignPlacement() {
+  if (gpuEnabled_)
+    return getOperation()->emitError()
+           << "assigning placement info for GPU is not supported a.t.m.\n";
+
+  for (FuncOp funcOp : llvm::to_vector<4>(getOperation().getOps<FuncOp>())) {
+    if (failed(assignPlacementForFuncOp(funcOp))) return failure();
+  }
+
+  return success();
+}
+
+void DiscRewritePayloadIRForRALPass::runOnOperation() {
+  // 1, rewrite scf.foreach_thread to scf.parallel
+  if (failed(convertForeachThreadToParallelOp())) {
+    return signalPassFailure();
+  }
+  LLVM_DEBUG(llvm::dbgs() << "After ForeachThreadOp -> ParallelOp:\n"
+                          << getOperation() << "\n");
+
+  // 2, assign placement info for each memref value.
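+  // An illustrative example (actual shapes depend on the payload IR): a
+  // buffer without an explicit memory space, e.g. memref<?x?xf32>, is
+  // rewritten to memref<?x?xf32, "cpu"> on the CPU path.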
+  if (failed(assignPlacement())) {
+    return signalPassFailure();
+  }
+  LLVM_DEBUG(llvm::dbgs() << "After assign placement:\n"
+                          << getOperation() << "\n");
+}
+
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> createDiscRewritePayloadIRForRALPass(
+    bool gpuEnabled) {
+  return std::make_unique<DiscRewritePayloadIRForRALPass>(gpuEnabled);
+}
+
+} // namespace disc_ral
+} // namespace mlir
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/rewrite-payload-ir-for-ral.mlir b/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/rewrite-payload-ir-for-ral.mlir
new file mode 100644
index 00000000000..eee3fbeb357
--- /dev/null
+++ b/tao_compiler/mlir/disc/tools/disc-transform/transforms/tests/rewrite-payload-ir-for-ral.mlir
@@ -0,0 +1,36 @@
+// RUN: disc-opt --disc-rewrite-payload-ir-for-ral -split-input-file %s | FileCheck %s
+
+#map = affine_map<()[s0] -> (s0 ceildiv 6)>
+#map1 = affine_map<()[s0] -> (s0 ceildiv 16)>
+#map2 = affine_map<(d0)[s0] -> (d0 * -6 + s0, 6)>
+#map3 = affine_map<(d0)[s0] -> (d0 * -16 + s0, 16)>
+#map4 = affine_map<(d0) -> (d0 * 6)>
+#map5 = affine_map<(d0) -> (d0 * 16)>
+module {
+  // CHECK-LABEL: @matmul_nn
+  // CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, "cpu">, %[[ARG1:.*]]: memref<?x?xf32, "cpu">, %[[ARG2:.*]]: memref<?x?xf32, "cpu">)
+  func.func @matmul_nn(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) -> memref<?x?xf32> attributes {test = true} {
+    %cst = arith.constant 0.000000e+00 : f32
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %dim = memref.dim %arg0, %c0 : memref<?x?xf32>
+    %dim_0 = memref.dim %arg1, %c1 : memref<?x?xf32>
+    %0 = affine.apply #map()[%dim]
+    %1 = affine.apply #map1()[%dim_0]
+    %dim_1 = memref.dim %arg0, %c1 : memref<?x?xf32>
+    // CHECK-NOT: scf.foreach_thread
+    // CHECK: scf.parallel
+    scf.foreach_thread (%arg3, %arg4) in (%0, %1) {
+      %2 = affine.min #map2(%arg3)[%dim]
+      %3 = affine.min #map3(%arg4)[%dim_0]
+      %4 = affine.apply #map4(%arg3)
+      %5 = affine.apply #map5(%arg4)
+      %subview = memref.subview %arg0[%4, 0] [%2, %dim_1] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+      %subview_2 = memref.subview %arg1[0, %5] [%dim_1, %3] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+      %subview_3 = memref.subview %arg2[%4, %5] [%2, %3] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+      linalg.fill ins(%cst : f32) outs(%subview_3 : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+      linalg.matmul ins(%subview, %subview_2 : memref<?x?xf32, strided<[?, 1], offset: ?>>, memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%subview_3 : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+    } {thread_dim_mapping = []}
+    return %arg2 : memref<?x?xf32>
+  }
+}
\ No newline at end of file
diff --git a/tao_compiler/mlir/disc/tools/disc-transform/transforms/transform_passes.td b/tao_compiler/mlir/disc/tools/disc-transform/transforms/transform_passes.td
index 8acbd1b2090..e6ba060fd24 100644
--- a/tao_compiler/mlir/disc/tools/disc-transform/transforms/transform_passes.td
+++ b/tao_compiler/mlir/disc/tools/disc-transform/transforms/transform_passes.td
@@ -30,3 +30,12 @@ def DiscTransformDialectInterpreterPass : Pass<"disc-transform-dialect-interpret
        /*default=*/"false", "perform expensive checks to better report errors in the transform IR.">,
   ];
 }
+
+def DiscRewritePayloadIRForRALPass : Pass<"disc-rewrite-payload-ir-for-ral", "ModuleOp"> {
+  let summary = "Pass to rewrite the payload IR transformed by the transform IR to be suitable for RAL.";
+  let constructor = "createDiscRewritePayloadIRForRALPass()";
+  let options = [
+    Option<"gpuEnabled_", "gpu-enabled", "bool",
+           /*default=*/"false", "whether GPU is available.">,
+  ];
+}
diff --git a/tao_compiler/mlir/disc/transforms/disc_assign_memory_space.cc b/tao_compiler/mlir/disc/transforms/disc_assign_memory_space.cc
index 3eb988fee21..b7dd5c28750 100644
--- a/tao_compiler/mlir/disc/transforms/disc_assign_memory_space.cc
+++ 
b/tao_compiler/mlir/disc/transforms/disc_assign_memory_space.cc
@@ -44,13 +44,7 @@ namespace disc_ral {
 
 namespace {
 
-// Returns a new memref type with provided memory space
-MemRefType copyWithMemorySpace(MLIRContext* ctx, MemRefType type,
-                               StringRef memory_space) {
-  Attribute memSpace = StringAttr::get(ctx, memory_space);
-  return MemRefType::get(type.getShape(), type.getElementType(),
-                         type.getLayout(), memSpace);
-}
+using placement_utils::copyWithMemorySpace;
 
 // return a new memref type with the provided memory space if the input type
 // is a memref type, otherwise return the original type.
diff --git a/tao_compiler/mlir/disc/transforms/placement_utils.cc b/tao_compiler/mlir/disc/transforms/placement_utils.cc
index 6972691574c..329aebd205a 100644
--- a/tao_compiler/mlir/disc/transforms/placement_utils.cc
+++ b/tao_compiler/mlir/disc/transforms/placement_utils.cc
@@ -232,5 +232,13 @@ LogicalResult parseEntryFunctionOutputPlacements(
   return success();
 }
 
+// Returns a new memref type with the provided memory space
+MemRefType copyWithMemorySpace(MLIRContext* ctx, MemRefType type,
+                               StringRef memory_space) {
+  Attribute memSpace = StringAttr::get(ctx, memory_space);
+  return MemRefType::get(type.getShape(), type.getElementType(),
+                         type.getLayout(), memSpace);
+}
+
 } // namespace placement_utils
 } // namespace mlir
diff --git a/tao_compiler/mlir/disc/transforms/placement_utils.h b/tao_compiler/mlir/disc/transforms/placement_utils.h
index d3a06c95765..d73a237d7fe 100644
--- a/tao_compiler/mlir/disc/transforms/placement_utils.h
+++ b/tao_compiler/mlir/disc/transforms/placement_utils.h
@@ -15,6 +15,7 @@
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // TF:llvm-project
 #include "mlir/IR/MLIRContext.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/disc/IR/hlo_disc_ops.h"
@@ -97,6 +98,9 @@ inline bool isMarkShapeCalcTargetOp(Operation* op) {
   return isTensorDialect(op) || isMhloDialect(op) || isStdOnTensor(op);
 }
 
+// Returns a new memref type with the provided memory space
+MemRefType copyWithMemorySpace(MLIRContext* ctx, MemRefType type,
+                               StringRef memory_space);
 } // namespace placement_utils
 } // namespace mlir
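---
Note on the relocated helper (an illustrative sketch, not part of the patch):
`copyWithMemorySpace` is now shared between disc_assign_memory_space.cc and the
new pass, which both wrap it in the same "convert only space-less memrefs"
guard. A minimal standalone use looks like the following; the function name
`toCpuMemRef` is hypothetical, and the "cpu" literal mirrors the
placement_utils::kCpu constant used by the pass:

  #include "mlir/Dialect/MemRef/IR/MemRef.h"
  #include "tensorflow/compiler/mlir/disc/transforms/placement_utils.h"

  // Rewrite a space-less memref type into the "cpu" memory space; leave all
  // other types (and memrefs that already carry a space) untouched.
  mlir::Type toCpuMemRef(mlir::Type ty) {
    auto memrefTy = ty.dyn_cast<mlir::MemRefType>();
    if (!memrefTy || memrefTy.getMemorySpace()) return ty;
    return mlir::placement_utils::copyWithMemorySpace(
        memrefTy.getContext(), memrefTy, "cpu");
  }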