[transform] make the transformed payload ir suitable for RAL #784

Merged: 1 commit merged on Nov 24, 2022
22 changes: 22 additions & 0 deletions tao_compiler/mlir/disc/tools/disc-transform/BUILD
@@ -183,6 +183,27 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "rewrite-payload-ir-for-ral",
srcs = ["transforms/rewrite_payload_ir_for_ral.cc"],
deps = [
":pass_details",
"//tensorflow/compiler/mlir/disc:placement_utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)

cc_library(
name = "all_passes",
hdrs = [
@@ -194,6 +215,7 @@ cc_library(
],
deps = [
":legalize_lmhlo_fusion_to_linalg",
":rewrite-payload-ir-for-ral",
":transform_dialect_interpreter",
"@llvm-project//mlir:Pass",
],
@@ -42,6 +42,10 @@ std::unique_ptr<OperationPass<ModuleOp>>
createDiscTransformDialectInterpreterPass(const std::string& fileName = "",
bool enableExpensiveChecks = false);

// Converts the transformed payload IR to be suitable for RAL.
std::unique_ptr<OperationPass<ModuleOp>> createDiscRewritePayloadIRForRALPass(
bool gpuEnabled = false);

} // namespace disc_ral
} // namespace mlir
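Not part of this diff: a minimal sketch of how the new factory could be plugged into a pass pipeline, assuming the standard mlir::PassManager API. The function name addDiscTransformPostProcessing and the surrounding pipeline code are hypothetical; the actual BladeDISC pipeline wiring lives elsewhere.

#include "mlir/Pass/PassManager.h"
// ... plus the DISC header that declares createDiscRewritePayloadIRForRALPass.

void addDiscTransformPostProcessing(mlir::PassManager& pm, bool gpuEnabled) {
  // Schedule the module-level pass right after the transform-dialect
  // interpreter, so the already-transformed payload IR is rewritten for RAL.
  pm.addPass(mlir::disc_ral::createDiscRewritePayloadIRForRALPass(gpuEnabled));
}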

196 changes: 196 additions & 0 deletions tao_compiler/mlir/disc/tools/disc-transform/transforms/rewrite_payload_ir_for_ral.cc
@@ -0,0 +1,196 @@
// Copyright 2022 The BladeDISC Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/disc/tools/disc-transform/transforms/PassDetail.h"
#include "tensorflow/compiler/mlir/disc/transforms/placement_utils.h"

#define DEBUG_TYPE "disc-rewrite-payload-ir-for-ral"

// This file implements the logic to convert the transformed payload IR to be
// suitable for RAL.

namespace mlir {
namespace disc_ral {
namespace {

using func::FuncOp;
using placement_utils::copyWithMemorySpace;
using scf::ForeachThreadOp;
using scf::ParallelOp;

struct DiscRewritePayloadIRForRALPass
: public DiscRewritePayloadIRForRALPassBase<
DiscRewritePayloadIRForRALPass> {
explicit DiscRewritePayloadIRForRALPass(bool gpuEnabled)
: DiscRewritePayloadIRForRALPassBase<DiscRewritePayloadIRForRALPass>::
DiscRewritePayloadIRForRALPassBase() {
this->gpuEnabled_ = gpuEnabled;
}
void runOnOperation() override;

// Replace scf::foreach_thread ops with scf::parallel ops.
LogicalResult convertForeachThreadToParallelOp();
LogicalResult funcLevelConvertForeachThreadToParallelOp(FuncOp funcOp);

// Assign placement info for each memref value, e.g. memref<f32> ->
// memref<f32, "cpu">.
LogicalResult assignPlacement();
LogicalResult assignPlacementForFuncOp(FuncOp funcOp);
};
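// Illustrative sketch (not IR emitted verbatim by this file) of the rewrite
// performed by funcLevelConvertForeachThreadToParallelOp below:
//
//   scf.foreach_thread (%i, %j) in (%n0, %n1) { ... }
//
// becomes
//
//   scf.parallel (%i, %j) = (%c0, %c0) to (%n0, %n1) step (%c1, %c1) { ... }
//
// i.e. the lower bounds are 0, the upper bounds are the foreach_thread thread
// counts, the steps are 1, and the thread indices are remapped to the
// parallel op's induction variables.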

LogicalResult
DiscRewritePayloadIRForRALPass::funcLevelConvertForeachThreadToParallelOp(
FuncOp funcOp) {
SmallVector<ForeachThreadOp> forOps;
funcOp.walk([&](ForeachThreadOp op) { forOps.push_back(op); });

OpBuilder b(funcOp);
for (ForeachThreadOp foreachThreadOp : forOps) {
if (!foreachThreadOp.getOutputs().empty())
return foreachThreadOp->emitError()
<< "ForeachThreadOp with outputs is not supported a.t.m.\n";

b.setInsertionPoint(foreachThreadOp);
Location loc = foreachThreadOp.getLoc();
int64_t rank = foreachThreadOp.getRank();
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
SmallVector<Value> lowerBounds(rank, zero);
SmallVector<Value> upperBounds = foreachThreadOp.getNumThreads();
SmallVector<Value> steps(rank, one);

auto parallelOp =
b.create<ParallelOp>(loc, lowerBounds, upperBounds, steps);
BlockAndValueMapping mapping;
for (const auto& z : llvm::zip(foreachThreadOp.getThreadIndices(),
parallelOp.getInductionVars()))
mapping.map(std::get<0>(z), std::get<1>(z));
b.setInsertionPointToStart(parallelOp.getBody());
for (auto& nestedOp : foreachThreadOp.getBody()->without_terminator()) {
b.clone(nestedOp, mapping);
}
foreachThreadOp->erase();
}
return success();
}

LogicalResult
DiscRewritePayloadIRForRALPass::convertForeachThreadToParallelOp() {
for (auto funcOp : getOperation().getOps<FuncOp>()) {
if (failed(funcLevelConvertForeachThreadToParallelOp(funcOp)))
return failure();
}
return success();
}
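// Illustrative sketch of the placement assignment implemented below (CPU-only
// a.t.m.): every memref type without an explicit memory space gets one, so
//
//   func.func @f(%arg0: memref<?x?xf32>) -> memref<?x?xf32>
//
// is rewritten to
//
//   func.func @f(%arg0: memref<?x?xf32, "cpu">) -> memref<?x?xf32, "cpu">
//
// and the same conversion is applied to op results and block arguments.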

LogicalResult DiscRewritePayloadIRForRALPass::assignPlacementForFuncOp(
FuncOp funcOp) {
auto maybeConvertType = [&](Type ty) -> Type {
auto memrefTy = ty.dyn_cast<MemRefType>();
if (!memrefTy || memrefTy.getMemorySpace()) return ty;
return copyWithMemorySpace(
memrefTy.getContext(), memrefTy,
this->gpuEnabled_ ? placement_utils::kGpu : placement_utils::kCpu);
};

auto convertValue = [&](Value v) {
auto newTy = maybeConvertType(v.getType());
if (newTy != v.getType()) v.setType(newTy);
return success();
};

// update types of results of operations
if (funcOp
->walk([&](Operation* op) {
for (Value value : llvm::to_vector(op->getResults())) {
if (failed(convertValue(value))) return WalkResult::interrupt();
}
return WalkResult::advance();
})
.wasInterrupted()) {
return failure();
}

// update types of block arguments
if (funcOp
->walk([&](Block* block) {
for (Value value : llvm::to_vector((block->getArguments()))) {
if (failed(convertValue(value))) return WalkResult::interrupt();
}
return WalkResult::advance();
})
.wasInterrupted()) {
return failure();
}

// update the type of func op
SmallVector<Type, 4> refinedInputTypes;
for (Type ty : funcOp.getArgumentTypes()) {
refinedInputTypes.push_back(maybeConvertType(ty));
}
SmallVector<Type, 4> refinedOutputTypes;
for (Type ty : funcOp.getResultTypes()) {
refinedOutputTypes.push_back(maybeConvertType(ty));
}
auto newFuncTy = FunctionType::get(funcOp.getContext(), refinedInputTypes,
refinedOutputTypes);
funcOp.setType(newFuncTy);
return success();
}

LogicalResult DiscRewritePayloadIRForRALPass::assignPlacement() {
if (gpuEnabled_)
return getOperation()->emitError()
<< "not support assign placement info for gpu a.t.m.\n";

for (FuncOp funcOp :
llvm::to_vector<4>(getOperation().getOps<func::FuncOp>())) {
if (failed(assignPlacementForFuncOp(funcOp))) return failure();
}

return success();
}

void DiscRewritePayloadIRForRALPass::runOnOperation() {
// 1. Rewrite scf.foreach_thread ops to scf.parallel ops.
if (failed(convertForeachThreadToParallelOp())) {
return signalPassFailure();
}
LLVM_DEBUG(llvm::dbgs() << "After ForeachThreadOp -> ParallelOp:\n"
<< getOperation() << "\n");

// 2. Assign placement info for each memref value.
if (failed(assignPlacement())) {
return signalPassFailure();
}
LLVM_DEBUG(llvm::dbgs() << "After assign placement:\n"
<< getOperation() << "\n");
}

} // namespace

std::unique_ptr<OperationPass<ModuleOp>> createDiscRewritePayloadIRForRALPass(
bool gpuEnabled) {
return std::make_unique<DiscRewritePayloadIRForRALPass>(gpuEnabled);
}

} // namespace disc_ral
} // namespace mlir
@@ -0,0 +1,36 @@
// RUN: disc-opt --disc-rewrite-payload-ir-for-ral -split-input-file %s | FileCheck %s

#map = affine_map<()[s0] -> (s0 ceildiv 6)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 16)>
#map2 = affine_map<(d0)[s0] -> (d0 * -6 + s0, 6)>
#map3 = affine_map<(d0)[s0] -> (d0 * -16 + s0, 16)>
#map4 = affine_map<(d0) -> (d0 * 6)>
#map5 = affine_map<(d0) -> (d0 * 16)>
module {
// CHECK-LABEL: @matmul_nn
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, "cpu">, %[[ARG1:.*]]: memref<?x?xf32, "cpu">, %[[ARG2:.*]]: memref<?x?xf32, "cpu">)
func.func @matmul_nn(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) -> memref<?x?xf32> attributes {test = true} {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim = memref.dim %arg0, %c0 : memref<?x?xf32>
%dim_0 = memref.dim %arg1, %c1 : memref<?x?xf32>
%0 = affine.apply #map()[%dim]
%1 = affine.apply #map1()[%dim_0]
%dim_1 = memref.dim %arg0, %c1 : memref<?x?xf32>
// CHECK-NOT: scf.foreach_thread
// CHECK: scf.parallel
scf.foreach_thread (%arg3, %arg4) in (%0, %1) {
%2 = affine.min #map2(%arg3)[%dim]
%3 = affine.min #map3(%arg4)[%dim_0]
%4 = affine.apply #map4(%arg3)
%5 = affine.apply #map5(%arg4)
%subview = memref.subview %arg0[%4, 0] [%2, %dim_1] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
%subview_2 = memref.subview %arg1[0, %5] [%dim_1, %3] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
%subview_3 = memref.subview %arg2[%4, %5] [%2, %3] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
linalg.fill ins(%cst : f32) outs(%subview_3 : memref<?x?xf32, strided<[?, 1], offset: ?>>)
linalg.matmul ins(%subview, %subview_2 : memref<?x?xf32, strided<[?, 1], offset: ?>>, memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%subview_3 : memref<?x?xf32, strided<[?, 1], offset: ?>>)
} {thread_dim_mapping = []}
return %arg2 : memref<?x?xf32>
}
}
@@ -30,3 +30,12 @@ def DiscTransformDialectInterpreterPass : Pass<"disc-transform-dialect-interpret
/*default=*/"false", "perform expensive checks to better report errors in the transform IR.">,
];
}

def DiscRewritePayloadIRForRALPass : Pass<"disc-rewrite-payload-ir-for-ral", "ModuleOp"> {
let summary = "Rewrite the payload IR (after it has been transformed by the transform IR) to make it suitable for RAL.";
let constructor = "createDiscRewritePayloadIRForRALPass()";
let options = [
Option<"gpuEnabled_", "gpu-enabled", "bool",
/*default=*/"false", "whether gpu is available.">,
];
}
@@ -44,13 +44,7 @@ namespace disc_ral {

namespace {

// Returns a new memref type with provided memory space
MemRefType copyWithMemorySpace(MLIRContext* ctx, MemRefType type,
StringRef memory_space) {
Attribute memSpace = StringAttr::get(ctx, memory_space);
return MemRefType::get(type.getShape(), type.getElementType(),
type.getLayout(), memSpace);
}
using placement_utils::copyWithMemorySpace;

// Returns a new memref type with the provided memory space if the input type is a
// memref type; otherwise returns the original type.
8 changes: 8 additions & 0 deletions tao_compiler/mlir/disc/transforms/placement_utils.cc
@@ -232,5 +232,13 @@ LogicalResult parseEntryFunctionOutputPlacements(
return success();
}

// Returns a new memref type with provided memory space
MemRefType copyWithMemorySpace(MLIRContext* ctx, MemRefType type,
StringRef memory_space) {
Attribute memSpace = StringAttr::get(ctx, memory_space);
return MemRefType::get(type.getShape(), type.getElementType(),
type.getLayout(), memSpace);
}
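// For illustration (not a call site added in this change): given the
// MemRefType for memref<?x?xf32> and memory_space = "cpu", this returns the
// MemRefType for memref<?x?xf32, "cpu">; shape, element type and layout are
// preserved.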

} // namespace placement_utils
} // namespace mlir
4 changes: 4 additions & 0 deletions tao_compiler/mlir/disc/transforms/placement_utils.h
@@ -15,6 +15,7 @@
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/disc/IR/hlo_disc_ops.h"
@@ -97,6 +98,9 @@ inline bool isMarkShapeCalcTargetOp(Operation* op) {
return isTensorDialect(op) || isMhloDialect(op) || isStdOnTensor(op);
}

// Returns a new memref type with provided memory space
MemRefType copyWithMemorySpace(MLIRContext* ctx, MemRefType type,
StringRef memory_space);
} // namespace placement_utils
} // namespace mlir
