Revert "Tile and distribute linalg.generic in DispatchLinalgOnTensors (
Browse files Browse the repository at this point in the history
…#5159)" (#5170)

This reverts commit 156f0bb.
ThomasRaoux authored Mar 19, 2021
1 parent f573559 commit 5582b5a
Showing 9 changed files with 20 additions and 187 deletions.
45 changes: 0 additions & 45 deletions iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp
@@ -126,41 +126,6 @@ class ConvertIREEBindingOp : public ConvertToLLVMPattern {
}
};

/// A pattern to convert hal.interface.workgroup.id/count/size into
/// corresponding NVVM ops.
template <typename InterfaceOpTy, typename XOp, typename YOp, typename ZOp>
struct HALInterfaceWorkgroupOpsConverter final
: public OpConversionPattern<InterfaceOpTy> {
using OpConversionPattern<InterfaceOpTy>::OpConversionPattern;

LogicalResult matchAndRewrite(
InterfaceOpTy op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type i32Type = rewriter.getI32Type();
Value newOp;
int32_t index = static_cast<int32_t>(op.dimension().getSExtValue());
switch (index) {
case 0:
newOp = rewriter.create<XOp>(loc, i32Type);
break;
case 1:
newOp = rewriter.create<YOp>(loc, i32Type);
break;
case 2:
newOp = rewriter.create<ZOp>(loc, i32Type);
break;
default:
return failure();
}

newOp =
rewriter.create<LLVM::SExtOp>(loc, rewriter.getIntegerType(64), newOp);
rewriter.replaceOp(op, {newOp});
return success();
}
};
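For context, the removed pattern mapped each HAL workgroup query to the NVVM intrinsic for the requested dimension and sign-extended the i32 result to i64. A minimal before/after sketch assembled from the pattern above, not verbatim compiler output (%id, %0, and %1 are illustrative names):

// Input, HAL level; dimension 0 selects XOp:
%id = hal.interface.workgroup.id[0] : index
// Output after the pattern ran, roughly:
%0 = nvvm.read.ptx.sreg.ctaid.x : i32
%1 = llvm.sext %0 : i32 to i64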

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
@@ -192,16 +157,6 @@ struct ConvertToNVVMPass
OwningRewritePatternList llvmPatterns;
llvmPatterns.insert<ConvertFunc, ConvertIREEBindingOp>(m.getContext(),
converter);
llvmPatterns
.insert<HALInterfaceWorkgroupOpsConverter<
IREE::HAL::InterfaceWorkgroupIDOp, NVVM::BlockIdXOp,
NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
HALInterfaceWorkgroupOpsConverter<
IREE::HAL::InterfaceWorkgroupCountOp, NVVM::GridDimXOp,
NVVM::GridDimYOp, NVVM::GridDimZOp>,
HALInterfaceWorkgroupOpsConverter<
IREE::HAL::InterfaceWorkgroupSizeOp, NVVM::BlockDimXOp,
NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(m.getContext());
populateStdToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
LLVMConversionTarget target(getContext());
1 change: 0 additions & 1 deletion iree/compiler/Conversion/LinalgToNVVM/Passes.cpp
@@ -32,7 +32,6 @@ static void addLinalgToNVVMPasses(OpPassManager &pm) {
//===--------------------------------------------------------------------===//
// Initial clean up.
//===--------------------------------------------------------------------===//
pm.addPass(createLowerAffinePass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());

66 changes: 0 additions & 66 deletions iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -312,51 +312,6 @@ static LogicalResult getConfigForCooperativeMatmul(
return success();
}

/// Launch config for element-wise linalg.generic.
template <>
LogicalResult getOpLaunchConfig(linalg::GenericOp op,
const spirv::TargetEnv &targetEnv,
const SPIRVCodegenOptions &options,
TileSizesListType &tileSizes,
LaunchConfigInfo &config) {
int64_t subgroupSize =
targetEnv.getResourceLimits().subgroup_size().getValue().getSExtValue();
config.workgroupSize[0] = subgroupSize;
config.workgroupSize[1] = 1;
config.workgroupSize[2] = 1;
ShapedType outputShape = op.getOutputShapedType(0);

SmallVector<int64_t, 4> sizes;
// When vectorization is not enabled we skip the second level of tiling and
// fall back to convertToGPU, which maps one element to one thread. To avoid
// a mismatch in the number of workgroups dispatched, we pick a tile size of
// one element per thread.
// TODO: Remove this once we switch to linalg on tensor path.
if (options.enableVectorization) {
sizes.append({4 * subgroupSize, 2 * subgroupSize});
}
sizes.push_back(subgroupSize);
// Use the first tile size that divides the shape evenly. If the shape is not
// aligned on any of the tile sizes, fall back to the smallest tile of one
// element per thread.
int64_t lowerTs = config.workgroupSize[0];
for (int64_t size : sizes) {
if (outputShape.getShape().back() % size != 0) continue;
lowerTs = size;
break;
}
SmallVector<int64_t, 4> ts;
size_t numLoops = getNumOuterParallelLoops(op);
ts.resize(numLoops, 1);
ts.back() = lowerTs;
tileSizes.emplace_back(ts);
tileSizes.emplace_back();
ts.back() = lowerTs / subgroupSize;
tileSizes.emplace_back(ts);
config.vectorize = options.enableVectorization;
return success();
}
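A concrete reading of the removed logic: with subgroup_size = 32 and vectorization enabled, the candidate sizes are {128, 64, 32}. For an output whose innermost static dimension is divisible by 128, lowerTs = 128, giving a workgroup-level tile of [1, ..., 1, 128], an empty second level, and a thread-level tile of [1, ..., 1, 4] (lowerTs / subgroupSize). If no candidate divides the innermost dimension, lowerTs stays at the workgroup size of 32, i.e. one element per thread.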

/// Launch configuration for different known GPU configuration.
static LogicalResult getTargetSpecificConfig(
linalg::MatmulOp op, const spirv::TargetEnv &targetEnv,
@@ -753,27 +708,6 @@ Optional<LaunchConfig> initGPULaunchConfig(
#undef DISPATCH
}

if (!rootOperation) {
for (linalg::LinalgOp linalgOp : linalgOps) {
if (auto op = dyn_cast<linalg::GenericOp>(linalgOp.getOperation())) {
if (getNumOuterParallelLoops(linalgOp) == 0 ||
llvm::any_of(linalgOp.getIndexingMaps(), [](AffineMap &map) {
return !map.isProjectedPermutation();
})) {
continue;
}
TileSizesListType tileSizesInfo;
if (failed(getOpLaunchConfig(op, targetEnv, options, tileSizesInfo,
config))) {
continue;
}
launchConfig.setTileSizes(op, tileSizesInfo);
launchConfig.setRootOperation(op);
break;
}
}
}

launchConfig.setWorkgroupSize(config.workgroupSize);
launchConfig.setNumSubgroups(config.numSubgroups);
launchConfig.setVectorize(config.vectorize);
36 changes: 11 additions & 25 deletions iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -612,13 +612,13 @@ struct TileAndDistributeOnTensorsPattern
SmallVector<Value, 4> count = llvm::to_vector<4>(
llvm::map_range(linalgOp.createLoopRanges(rewriter, loc),
[](Range r) { return r.size; }));
size_t numParallelLoops = getNumOuterParallelLoops(op);
// Flow currently allows only 3 levels of tiling. If there are more parallel
// dimensions, drop the outermost ones.
if (numParallelLoops > kNumMaxParallelDims) {
count.erase(
count.begin(),
std::next(count.begin(), numParallelLoops - kNumMaxParallelDims));
// NOTE: Special treatment for convolution ops, which have more than 3
// parallel dimensions. We want to ignore the batch dimension and tile along
// the next three.
// TODO(#5048): figure out a better way to avoid this special case.
if (isa<linalg::ConvInputNHWCFilterHWCFOp,
linalg::DepthwiseConvInputNHWCFilterHWCOp>(op)) {
count.erase(count.begin());
}
count.resize(getNumTilableLoops(op));
auto workload = convertToWorkload(rewriter, loc, count);
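Reading the restored branch: for the NHWC convolutions named above, whose outer loops are (N, OH, OW, OC, ...), erasing the first entry drops the batch size, and count.resize(getNumTilableLoops(op)) then keeps the next three parallel sizes as the workload handed to flow.dispatch.workgroups. This assumes the usual NHWC loop ordering; the commit itself does not spell it out.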
@@ -849,23 +849,6 @@ static void decideFusableLinalgOps(FuncOp funcOp) {
builder.getI64ArrayAttr(fusionGroups));
}
}

// As a second step, mark all the element-wise linalg ops that were not fused
// as roots so that they get tiled and distributed.
for (linalg::LinalgOp linalgOp : linalgOps) {
Operation *op = linalgOp.getOperation();
if (!isa<linalg::GenericOp>(op) ||
getNumOuterParallelLoops(linalgOp) == 0 ||
llvm::any_of(linalgOp.getIndexingMaps(), [](AffineMap &map) {
return !map.isProjectedPermutation();
})) {
continue;
}

if (op->hasAttr(kRootOpAttr) || op->hasAttr(kFusionGroupsAttr)) continue;
unsigned currGroupNum = numRootOps++;
op->setAttr(kRootOpAttr, builder.getI64IntegerAttr(currGroupNum));
}
}
}

@@ -923,8 +906,11 @@ void DispatchLinalgOnTensorsPass::runOnOperation() {
// parallel dimensions. We want to ignore the batch dimension and tile
// along the next three. That means setting the first position to zero.
// TODO(#5048): figure out a better way to avoid this special case.
bool isConvOp = isa<linalg::ConvInputNHWCFilterHWCFOp,
linalg::DepthwiseConvInputNHWCFilterHWCOp>(op);

for (size_t dim = 0; dim < numTiledLoops; ++dim) {
useTileSizes[numParallelDims - dim - 1] =
useTileSizes[(isConvOp ? numParallelDims : numTiledLoops) - dim - 1] =
buildFlowWorkgroupInfoOp<Flow::DispatchWorkgroupSizeOp>(builder, dim);
}
return useTileSizes;
@@ -66,7 +66,6 @@ func @generic_op(%A: tensor<?x?xf32>, %B: tensor<?xf32>) -> tensor<?x?xf32> {
} -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK: func @generic_op
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?xf32>
@@ -81,27 +80,13 @@ func @generic_op(%A: tensor<?x?xf32>, %B: tensor<?xf32>) -> tensor<?x?xf32> {
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<writeonly:?x?xf32>
// CHECK-DAG: %[[WGSIZE_X:.+]] = flow.dispatch.workgroup.size[0]
// CHECK-DAG: %[[WGSIZE_Y:.+]] = flow.dispatch.workgroup.size[1]
// CHECK-DAG: %[[WGID_X:.+]] = flow.dispatch.workgroup.id[0]
// CHECK-DAG: %[[WGID_Y:.+]] = flow.dispatch.workgroup.id[1]
// CHECK-DAG: %[[WGCOUNT_X:.+]] = flow.dispatch.workgroup.count[0]
// CHECK-DAG: %[[WGCOUNT_Y:.+]] = flow.dispatch.workgroup.count[1]
// CHECK: %[[OFFSET_Y:.+]] = affine.apply #[[MULMAP]]()[%[[WGID_Y]], %[[WGSIZE_Y]]]
// CHECK: %[[STEP_Y:.+]] = affine.apply #[[MULMAP]]()[%[[WGCOUNT_Y]], %[[WGSIZE_Y]]]
// CHECK: scf.for %[[ARG7:.+]] = %[[OFFSET_Y]]
// CHECK-SAME: to %{{.+}} step %[[STEP_Y]]
// CHECK: %[[OFFSET_X:.+]] = affine.apply #[[MULMAP]]()[%[[WGID_X]], %[[WGSIZE_X]]]
// CHECK: %[[STEP_X:.+]] = affine.apply #[[MULMAP]]()[%[[WGCOUNT_X]], %[[WGSIZE_X]]]
// CHECK: scf.for %[[ARG8:.+]] = %[[OFFSET_X]]
// CHECK-SAME: to %{{.+}} step %[[STEP_X]]
// CHECK-DAG: %[[LOAD2:.+]] = flow.dispatch.tensor.load %[[ARG2]]
// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor
// CHECK-DAG: %[[LOAD3:.+]] = flow.dispatch.tensor.load %[[ARG3]]
// CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME: ins(%[[LOAD2]], %[[LOAD3]] : tensor<?x?xf32>, tensor<?xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?xf32>)
// CHECK: flow.dispatch.tensor.store %[[RESULT]], %[[ARG6]]
// CHECK-DAG: %[[LOAD2:.+]] = flow.dispatch.tensor.load %[[ARG2]] : !flow.dispatch.tensor<readonly:?x?xf32>
// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG4]], %[[ARG5]]]
// CHECK-DAG: %[[LOAD3:.+]] = flow.dispatch.tensor.load %[[ARG3]] : !flow.dispatch.tensor<readonly:?xf32>
// CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME: ins(%[[LOAD2]], %[[LOAD3]] : tensor<?x?xf32>, tensor<?xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?xf32>)
// CHECK: flow.dispatch.tensor.store %[[RESULT]], %[[ARG6]]
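The deleted CHECK lines above describe the distributed loop nest that the reverted pattern used to produce inside the dispatch region. A sketch of that shape, reconstructed from those patterns rather than from verbatim output (%ub_y and the _y/_x suffixes are illustrative):

%wgsize_y = flow.dispatch.workgroup.size[1] : index
%wgid_y = flow.dispatch.workgroup.id[1] : index
%wgcount_y = flow.dispatch.workgroup.count[1] : index
%offset_y = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wgid_y, %wgsize_y]
%step_y = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%wgcount_y, %wgsize_y]
scf.for %iv_y = %offset_y to %ub_y step %step_y {
  // Same id/count/size pattern along x, then tiled
  // flow.dispatch.tensor.load, linalg.generic, and
  // flow.dispatch.tensor.store on each tile.
}

With the revert, the dispatch region instead loads the whole tensors and runs a single untiled linalg.generic, as the replacement CHECK lines show.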

// -----

@@ -310,8 +295,6 @@ func @generic_op_4D
} -> tensor<?x?x?x?xf32>
return %1 : tensor<?x?x?x?xf32>
}
// For ops of rank greater than 3 we serialize the higher dimensions. When
// flow supports larger ranks this can be changed.
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK: func @generic_op_4D
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?x?xf32>
@@ -323,7 +306,8 @@ func @generic_op_4D
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[D2:.+]] = memref.dim %[[ARG0]], %[[C2]]
// CHECK-DAG: %[[D3:.+]] = memref.dim %[[ARG0]], %[[C3]]
// CHECK: flow.dispatch.workgroups[%[[D3]], %[[D2]], %[[D1]]]
// CHECK: %[[WORKLOAD_Z:.+]] = affine.apply #[[MAP0]]()[%[[D0]], %[[D1]]]
// CHECK: flow.dispatch.workgroups[%[[D3]], %[[D2]], %[[WORKLOAD_Z]]]
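That is, the restored test expects the two outermost parallel dimensions to be folded into a single z workload entry (d0 * d1 via #MAP0), whereas the deleted CHECK line accepted a workload that dropped d0 outright.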

// -----

1 change: 0 additions & 1 deletion iree/compiler/Dialect/HAL/Target/CUDA/BUILD
@@ -49,7 +49,6 @@ cc_library(
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:ToLLVMIRTranslation",
1 change: 0 additions & 1 deletion iree/compiler/Dialect/HAL/Target/CUDA/CMakeLists.txt
@@ -29,7 +29,6 @@ iree_cc_library(
MLIRLLVMIR
MLIRLLVMToLLVMIRTranslation
MLIRNVVMIR
MLIRNVVMToLLVMIRTranslation
MLIRPass
MLIRSupport
MLIRTargetLLVMIRExport
2 changes: 0 additions & 2 deletions iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp
@@ -28,7 +28,6 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"

namespace mlir {
@@ -65,7 +64,6 @@ class CUDATargetBackend final : public TargetBackend {

void getDependentDialects(DialectRegistry &registry) const override {
mlir::registerLLVMDialectTranslation(registry);
mlir::registerNVVMDialectTranslation(registry);
}

void buildTranslationPassPipeline(OpPassManager &passManager) override {
21 changes: 0 additions & 21 deletions iree/test/e2e/xla_ops/add.mlir
@@ -5,24 +5,3 @@ func @tensor() attributes { iree.module.export } {
check.expect_almost_eq_const(%result, dense<[6.0, 8.0, 10.0, 12.0]> : tensor<4xf32>) : tensor<4xf32>
return
}

func @tensor_4d() attributes { iree.module.export } {
%0 = iree.unfoldable_constant dense<[[[[1.0, 2.0], [3.0, 4.0]],
[[5.0, 6.0], [7.0, 8.0]]],
[[[9.0, 10.0], [11.0, 12.0]],
[[13.0, 14.0], [15.0, 16.0]]]]> :
tensor<2x2x2x2xf32>
%1 = iree.unfoldable_constant dense<[[[[1.0, 2.0], [3.0, 4.0]],
[[5.0, 6.0], [7.0, 8.0]]],
[[[9.0, 10.0], [11.0, 12.0]],
[[13.0, 14.0], [15.0, 16.0]]]]> :
tensor<2x2x2x2xf32>
%result = "mhlo.add"(%0, %1) : (tensor<2x2x2x2xf32>, tensor<2x2x2x2xf32>)
-> tensor<2x2x2x2xf32>
check.expect_almost_eq_const(%result, dense<[[[[2.0, 4.0], [6.0, 8.0]],
[[10.0, 12.0], [14.0, 16.0]]],
[[[18.0, 20.0], [22.0, 24.0]],
[[26.0, 28.0], [30.0, 32.0]]]]> :
tensor<2x2x2x2xf32>) : tensor<2x2x2x2xf32>
return
}
