[transform] add a pass to make the transformed payload ir suitable for RAL
wyzero committed Nov 24, 2022
1 parent b9c8180 commit 33cd154
Showing 8 changed files with 280 additions and 7 deletions.
22 changes: 22 additions & 0 deletions tao_compiler/mlir/disc/tools/disc-transform/BUILD
@@ -183,6 +183,27 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "rewrite-payload-ir-for-ral",
srcs = ["transforms/rewrite_payload_ir_for_ral.cc"],
deps = [
":pass_details",
"//tensorflow/compiler/mlir/disc:placement_utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)

cc_library(
name = "all_passes",
hdrs = [
@@ -194,6 +215,7 @@ cc_library(
],
deps = [
":legalize_lmhlo_fusion_to_linalg",
":rewrite-payload-ir-for-ral",
":transform_dialect_interpreter",
"@llvm-project//mlir:Pass",
],
@@ -42,6 +42,10 @@ std::unique_ptr<OperationPass<ModuleOp>>
createDiscTransformDialectInterpreterPass(const std::string& fileName = "",
bool enableExpensiveChecks = false);

// Converts the transformed payload IR to be suitable for RAL.
std::unique_ptr<OperationPass<ModuleOp>> createDiscRewritePayloadIRForRALPass(
bool gpuEnabled = false);

} // namespace disc_ral
} // namespace mlir
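As a usage sketch, here is one way the two factory functions might be chained in a module-level pipeline; the wiring below is hypothetical (the helper name and the PassManager setup are not part of this commit):

#include "mlir/Pass/PassManager.h"

// Hypothetical pipeline wiring; `addDiscTransformPasses` is an illustrative
// helper and not part of the commit.
void addDiscTransformPasses(mlir::PassManager& pm) {
  using namespace mlir::disc_ral;
  // First apply the schedule described by the transform IR to the payload ...
  pm.addPass(createDiscTransformDialectInterpreterPass());
  // ... then rewrite the transformed payload IR so RAL can consume it.
  pm.addPass(createDiscRewritePayloadIRForRALPass(/*gpuEnabled=*/false));
}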

196 changes: 196 additions & 0 deletions tao_compiler/mlir/disc/tools/disc-transform/transforms/rewrite_payload_ir_for_ral.cc
@@ -0,0 +1,196 @@
// Copyright 2022 The BladeDISC Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/disc/tools/disc-transform/transforms/PassDetail.h"
#include "tensorflow/compiler/mlir/disc/transforms/placement_utils.h"

#define DEBUG_TYPE "disc-rewrite-payload-ir-for-ral"
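// The DEBUG_TYPE above gates the LLVM_DEBUG traces below; pass
// `-debug-only=disc-rewrite-payload-ir-for-ral` in debug builds to see them.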

// This file implements the logic to convert the transformed payload IR to be
// suitable for RAL.

namespace mlir {
namespace disc_ral {
namespace {

using func::FuncOp;
using placement_utils::copyWithMemorySpace;
using scf::ForeachThreadOp;
using scf::ParallelOp;

struct DiscRewritePayloadIRForRALPass
: public DiscRewritePayloadIRForRALPassBase<
DiscRewritePayloadIRForRALPass> {
explicit DiscRewritePayloadIRForRALPass(bool gpuEnabled)
: DiscRewritePayloadIRForRALPassBase<DiscRewritePayloadIRForRALPass>::
DiscRewritePayloadIRForRALPassBase() {
this->gpuEnabled_ = gpuEnabled;
}
void runOnOperation() override;

// replace scf::foreach_thread op with scf::parallel op
LogicalResult convertForeachThreadToParallelOp();
LogicalResult funcLevelConvertForeachThreadToParallelOp(FuncOp funcOp);

// assign placement info for each memref value, e.g. memref<f32> ->
// memref<f32, "cpu">
LogicalResult assignPlacement();
LogicalResult assignPlacementForFuncOp(FuncOp funcOp);
};

LogicalResult
DiscRewritePayloadIRForRALPass::funcLevelConvertForeachThreadToParallelOp(
FuncOp funcOp) {
SmallVector<ForeachThreadOp> forOps;
funcOp.walk([&](ForeachThreadOp op) { forOps.push_back(op); });

OpBuilder b(funcOp);
for (ForeachThreadOp foreachThreadOp : forOps) {
if (foreachThreadOp.getOutputs().size() != 0)
      return foreachThreadOp->emitError()
             << "ForeachThreadOp with outputs is not supported a.t.m.\n";

b.setInsertionPoint(foreachThreadOp);
Location loc = foreachThreadOp.getLoc();
int64_t rank = foreachThreadOp.getRank();
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
SmallVector<Value> lowerBounds(rank, zero);
SmallVector<Value> upperBounds = foreachThreadOp.getNumThreads();
SmallVector<Value> steps(rank, one);

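    // Build an scf.parallel spanning the same iteration space: a
    // foreach_thread over (%m, %n) becomes bounds (0, 0) to (%m, %n) with
    // unit steps.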
auto parallelOp =
b.create<ParallelOp>(loc, lowerBounds, upperBounds, steps);
BlockAndValueMapping mapping;
for (const auto& z : llvm::zip(foreachThreadOp.getThreadIndices(),
parallelOp.getInductionVars()))
mapping.map(std::get<0>(z), std::get<1>(z));
b.setInsertionPointToStart(parallelOp.getBody());
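    // Clone the loop body (minus the terminator), remapping each thread index
    // to the matching scf.parallel induction variable.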
    for (auto& nestedOp : foreachThreadOp.getBody()->without_terminator())
      b.clone(nestedOp, mapping);
foreachThreadOp->erase();
}
return success();
}

LogicalResult
DiscRewritePayloadIRForRALPass::convertForeachThreadToParallelOp() {
for (auto funcOp : getOperation().getOps<FuncOp>()) {
if (failed(funcLevelConvertForeachThreadToParallelOp(funcOp)))
return failure();
}
return success();
}

LogicalResult DiscRewritePayloadIRForRALPass::assignPlacementForFuncOp(
FuncOp funcOp) {
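  // Attach the target memory space to memref types that do not already carry
  // one, e.g. memref<?x?xf32> -> memref<?x?xf32, "cpu">; other types pass
  // through unchanged.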
auto maybeConvertType = [&](Type ty) -> Type {
auto memrefTy = ty.dyn_cast<MemRefType>();
if (!memrefTy || memrefTy.getMemorySpace()) return ty;
return copyWithMemorySpace(
memrefTy.getContext(), memrefTy,
this->gpuEnabled_ ? placement_utils::kGpu : placement_utils::kCpu);
};

auto convertValue = [&](Value v) {
auto newTy = maybeConvertType(v.getType());
if (newTy != v.getType()) v.setType(newTy);
return success();
};

// update types of results of operations
if (funcOp
->walk([&](Operation* op) {
for (Value value : llvm::to_vector(op->getResults())) {
if (failed(convertValue(value))) return WalkResult::interrupt();
}
return WalkResult::advance();
})
.wasInterrupted()) {
return failure();
}

// update types of block arguments
if (funcOp
->walk([&](Block* block) {
for (Value value : llvm::to_vector((block->getArguments()))) {
if (failed(convertValue(value))) return WalkResult::interrupt();
}
return WalkResult::advance();
})
.wasInterrupted()) {
return failure();
}

// update the type of func op
SmallVector<Type, 4> refinedInputTypes;
for (Type ty : funcOp.getArgumentTypes()) {
refinedInputTypes.push_back(maybeConvertType(ty));
}
SmallVector<Type, 4> refinedOutputTypes;
for (Type ty : funcOp.getResultTypes()) {
refinedOutputTypes.push_back(maybeConvertType(ty));
}
auto newFuncTy = FunctionType::get(funcOp.getContext(), refinedInputTypes,
refinedOutputTypes);
funcOp.setType(newFuncTy);
return success();
}

LogicalResult DiscRewritePayloadIRForRALPass::assignPlacement() {
if (gpuEnabled_)
    return getOperation()->emitError()
           << "assigning placement info for GPU is not supported a.t.m.\n";

for (FuncOp funcOp :
llvm::to_vector<4>(getOperation().getOps<func::FuncOp>())) {
if (failed(assignPlacementForFuncOp(funcOp))) return failure();
}

return success();
}

void DiscRewritePayloadIRForRALPass::runOnOperation() {
  // 1. Rewrite scf.foreach_thread to scf.parallel.
if (failed(convertForeachThreadToParallelOp())) {
return signalPassFailure();
}
LLVM_DEBUG(llvm::dbgs() << "After ForeachThreadOp -> ParallelOp:\n"
<< getOperation() << "\n");

  // 2. Assign placement info for each memref value.
if (failed(assignPlacement())) {
return signalPassFailure();
}
LLVM_DEBUG(llvm::dbgs() << "After assign placement:\n"
<< getOperation() << "\n");
}

} // namespace

std::unique_ptr<OperationPass<ModuleOp>> createDiscRewritePayloadIRForRALPass(
bool gpuEnabled) {
return std::make_unique<DiscRewritePayloadIRForRALPass>(gpuEnabled);
}

} // namespace disc_ral
} // namespace mlir
@@ -0,0 +1,36 @@
// RUN: disc-opt --disc-rewrite-payload-ir-for-ral -split-input-file %s | FileCheck %s

#map = affine_map<()[s0] -> (s0 ceildiv 6)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 16)>
#map2 = affine_map<(d0)[s0] -> (d0 * -6 + s0, 6)>
#map3 = affine_map<(d0)[s0] -> (d0 * -16 + s0, 16)>
#map4 = affine_map<(d0) -> (d0 * 6)>
#map5 = affine_map<(d0) -> (d0 * 16)>
module {
// CHECK-LABEL: @matmul_nn
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, "cpu">, %[[ARG1:.*]]: memref<?x?xf32, "cpu">, %[[ARG2:.*]]: memref<?x?xf32, "cpu">)
func.func @matmul_nn(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) -> memref<?x?xf32> attributes {test = true} {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = memref.dim %arg0, %c0 : memref<?x?xf32>
%dim_0 = memref.dim %arg1, %c1 : memref<?x?xf32>
%0 = affine.apply #map()[%dim]
%1 = affine.apply #map1()[%dim_0]
%dim_1 = memref.dim %arg0, %c1 : memref<?x?xf32>
// CHECK-NOT: scf.foreach_thread
// CHECK: scf.parallel
scf.foreach_thread (%arg3, %arg4) in (%0, %1) {
%2 = affine.min #map2(%arg3)[%dim]
%3 = affine.min #map3(%arg4)[%dim_0]
%4 = affine.apply #map4(%arg3)
%5 = affine.apply #map5(%arg4)
%subview = memref.subview %arg0[%4, 0] [%2, %dim_1] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
%subview_2 = memref.subview %arg1[0, %5] [%dim_1, %3] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
%subview_3 = memref.subview %arg2[%4, %5] [%2, %3] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
linalg.fill ins(%cst : f32) outs(%subview_3 : memref<?x?xf32, strided<[?, 1], offset: ?>>)
linalg.matmul ins(%subview, %subview_2 : memref<?x?xf32, strided<[?, 1], offset: ?>>, memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%subview_3 : memref<?x?xf32, strided<[?, 1], offset: ?>>)
} {thread_dim_mapping = []}
return %arg2 : memref<?x?xf32>
}
}
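The CHECK-SAME pattern above verifies the placement step (the function signature gains explicit "cpu" memory spaces), while the CHECK-NOT/CHECK pair verifies the loop conversion step (scf.foreach_thread is rewritten to scf.parallel).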
@@ -30,3 +30,12 @@ def DiscTransformDialectInterpreterPass : Pass<"disc-transform-dialect-interpret
/*default=*/"false", "perform expensive checks to better report errors in the transform IR.">,
];
}

def DiscRewritePayloadIRForRALPass : Pass<"disc-rewrite-payload-ir-for-ral", "ModuleOp"> {
let summary = "Pass to rewrite the payload IR transformed by transform IR to be suitable for RAL.";
let constructor = "createDiscRewritePayloadIRForRALPass()";
let options = [
Option<"gpuEnabled_", "gpu-enabled", "bool",
/*default=*/"false", "whether gpu is available.">,
];
}
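With the auto-generated pass registration above, the option should presumably be settable from the disc-opt command line as --disc-rewrite-payload-ir-for-ral=gpu-enabled=true; note, however, that the pass currently rejects the GPU path in assignPlacement().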
@@ -44,13 +44,7 @@ namespace disc_ral {

namespace {

-// Returns a new memref type with provided memory space
-MemRefType copyWithMemorySpace(MLIRContext* ctx, MemRefType type,
-                               StringRef memory_space) {
-  Attribute memSpace = StringAttr::get(ctx, memory_space);
-  return MemRefType::get(type.getShape(), type.getElementType(),
-                         type.getLayout(), memSpace);
-}
+using placement_utils::copyWithMemorySpace;

// Returns a new memref type with the provided memory space if the input type
// is a memref type; otherwise returns the original type.
8 changes: 8 additions & 0 deletions tao_compiler/mlir/disc/transforms/placement_utils.cc
@@ -232,5 +232,13 @@ LogicalResult parseEntryFunctionOutputPlacements(
return success();
}

// Returns a new memref type with provided memory space
MemRefType copyWithMemorySpace(MLIRContext* ctx, MemRefType type,
StringRef memory_space) {
Attribute memSpace = StringAttr::get(ctx, memory_space);
return MemRefType::get(type.getShape(), type.getElementType(),
type.getLayout(), memSpace);
}

} // namespace placement_utils
} // namespace mlir
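A minimal usage sketch of the now-shared helper; the MLIRContext* ctx is assumed, and the snippet is illustrative rather than part of the commit:

// Illustrative: attach the "cpu" memory space to a plain memref type.
mlir::MemRefType ty =
    mlir::MemRefType::get({4, 4}, mlir::FloatType::getF32(ctx));
mlir::MemRefType cpuTy = mlir::placement_utils::copyWithMemorySpace(
    ctx, ty, mlir::placement_utils::kCpu);
// cpuTy prints as memref<4x4xf32, "cpu">; shape, element type and layout
// are preserved.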
4 changes: 4 additions & 0 deletions tao_compiler/mlir/disc/transforms/placement_utils.h
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/disc/IR/hlo_disc_ops.h"
@@ -97,6 +98,9 @@ inline bool isMarkShapeCalcTargetOp(Operation* op) {
return isTensorDialect(op) || isMhloDialect(op) || isStdOnTensor(op);
}

// Returns a new memref type with provided memory space
MemRefType copyWithMemorySpace(MLIRContext* ctx, MemRefType type,
StringRef memory_space);
} // namespace placement_utils
} // namespace mlir
