Skip to content

Commit

Permalink
Merge branch 'google' into main-to-google
Browse files Browse the repository at this point in the history
  • Loading branch information
KoolJBlack committed Mar 31, 2021
2 parents 1c59bd1 + 7a8867c commit 431ede6
Show file tree
Hide file tree
Showing 27 changed files with 125 additions and 99 deletions.
8 changes: 4 additions & 4 deletions SUBMODULE_VERSIONS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
4fb0ff7069bd88ee85902f4d0bb62794e5f6d021 third_party/flatcc
b1fbd33c06cdb0024c67733c6fdec2009d17b384 third_party/googletest
88b845dee001723c4a0db1fe5477de735b6d3bb0 third_party/liburing
189e771009a640214e08e855830ae6f15a83c655 third_party/llvm-bazel
1f6a57c1a0fad922e04a2b1f414b092d4b0cd8b0 third_party/llvm-project
fad5434701aa52c920404c81532aa3ebf44bc3b7 third_party/llvm-bazel
c06a8f9caa51c7ea71dac716e0a35f5e343e4546 third_party/llvm-project
dde739ffd00a6fa99175cf3c0f28e4b763dc6f5f third_party/mlir-emitc
cbef26c6a8f1e4be3f4cfb902db992c45e93b7a6 third_party/mlir-hlo
e78c59d9277935f1d4d3b40d08e447be91be832a third_party/mlir-hlo
2b2bd45bbf9be04fd22ece5cc1f54679202e9257 third_party/pffft
d8c7ee00a687ac369e62e2032514a93a9b413502 third_party/pybind11
2887692065c38ef6617f423feafc6b69dd0a0681 third_party/ruy
685f86471e9d26b3eb7676695a2e2cefb4551ae9 third_party/spirv_cross
f8bf11a0253a32375c32cad92c841237b96696c0 third_party/spirv_headers
da3da1e8a81a9866d98bcfe54eb21ec27cab7000 third_party/tensorflow
8bd49272bc4d80523516d0aa00986114a48d8700 third_party/tensorflow
8732f0e94e4e41049a43029202bda94d7b4e85da third_party/tracy
9bd3f561bcee3f01d22912de10bb07ce4e23d378 third_party/vulkan_headers
3528e2aed3e8808f33e1e7d63eeb1560456a605a third_party/vulkan_memory_allocator
Expand Down
6 changes: 2 additions & 4 deletions experimental/ModelBuilder/ModelRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,8 @@ static void addVulkanLoweringPass(mlir::PassManager& manager) {
modulePM.addPass(mlir::spirv::createLowerABIAttributesPass());
modulePM.addPass(mlir::spirv::createUpdateVersionCapabilityExtensionPass());
manager.addPass(mlir::createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
mlir::LowerToLLVMOptions llvmOptions = {
/*useBarePtrCallConv =*/false,
/*emitCWrappers = */ true,
/*indexBitwidth =*/mlir::kDeriveIndexBitwidthFromDataLayout};
mlir::LowerToLLVMOptions llvmOptions(manager.getContext());
llvmOptions.emitCWrappers = true;
manager.addPass(createLowerToLLVMPass(llvmOptions));
manager.addPass(mlir::createConvertVulkanLaunchFuncToVulkanCallsPass());
}
Expand Down
2 changes: 1 addition & 1 deletion experimental/ModelBuilder/VulkanWrapperPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void AddVulkanLaunchWrapper::runOnOperation() {
}

LogicalResult AddVulkanLaunchWrapper::declareVulkanLaunchFunc(Location loc) {
OpBuilder builder(getOperation().getBody()->getTerminator());
auto builder = OpBuilder::atBlockEnd(getOperation().getBody());

SmallVector<Type, 8> vulkanLaunchTypes(3, builder.getIndexType());
vulkanLaunchTypes.insert(vulkanLaunchTypes.end(), args.begin(), args.end());
Expand Down
6 changes: 2 additions & 4 deletions experimental/ModelBuilder/test/BenchMatMulVectorGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,8 @@ static void addLoweringPasses(mlir::PassManager &pm,
mlir::spirv::createUpdateVersionCapabilityExtensionPass());

pm.addPass(mlir::createAddVulkanLaunchWrapperPass(numWorkgroups, args));
mlir::LowerToLLVMOptions llvmOptions = {
/*useBarePtrCallConv=*/false,
/*emitCWrappers=*/true,
/*indexBitwidth=*/mlir::kDeriveIndexBitwidthFromDataLayout};
mlir::LowerToLLVMOptions llvmOptions(pm.getContext());
llvmOptions.emitCWrappers = true;
pm.addPass(createLowerToLLVMPass(llvmOptions));
pm.addPass(mlir::createConvertVulkanLaunchFuncToVulkanCallsPass());
}
Expand Down
6 changes: 2 additions & 4 deletions experimental/ModelBuilder/test/TestVectorToGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,8 @@ static void addLoweringPasses(mlir::PassManager &pm,
mlir::spirv::createUpdateVersionCapabilityExtensionPass());

pm.addPass(mlir::createAddVulkanLaunchWrapperPass(workgroupSize, args));
mlir::LowerToLLVMOptions llvmOptions = {
/*useBarePtrCallConv=*/false,
/*emitCWrappers=*/true,
/*indexBitwidth=*/mlir::kDeriveIndexBitwidthFromDataLayout};
mlir::LowerToLLVMOptions llvmOptions(pm.getContext());
llvmOptions.emitCWrappers = true;
pm.addPass(createLowerToLLVMPass(llvmOptions));
pm.addPass(mlir::createConvertVulkanLaunchFuncToVulkanCallsPass());
}
Expand Down
5 changes: 5 additions & 0 deletions integrations/tensorflow/iree_tf_compiler/TF/ConvertToMHLO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ class ConvertToMHLOPass : public PassWrapper<ConvertToMHLOPass, FunctionPass> {
target.addLegalOp<mlir::CallOp>();
target.addLegalOp<mlir::tensor::CastOp>();

// TODO(suderman): Enable logistic op for lowering once the op is
// supported in IREE. Also, remove the numerically unstable ConvertSigmoidOp
// pattern in the legalize-tf pass.
target.addIllegalOp<mhlo::LogisticOp>();

DenseSet<Operation *> prevUnconvertedOps;
DenseSet<Operation *> unconvertedOps;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,35 @@ func @f(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<3xf32>) {
// CHECK: return [[VAL8]]
return %29 : tensor<3xf32>
}

// Ranked f32 case: the directives below expect the sigmoid op to be rewritten
// into mhlo ops computing tanh(x * 0.5) * 0.5 + 0.5, with the 0.5 constant
// materialized directly at the operand's static shape (tensor<2xf32>).
// CHECK-LABEL: @sigmoid
func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-DAG: [[HALF:%.+]] = mhlo.constant dense<5.000000e-01> : tensor<2xf32>
// CHECK-DAG: [[R1:%.+]] = mhlo.multiply %arg0, [[HALF]] : tensor<2xf32>
// CHECK-DAG: [[R2:%.+]] = "mhlo.tanh"([[R1]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK-DAG: [[R3:%.+]] = mhlo.multiply [[R2]], [[HALF]] : tensor<2xf32>
// CHECK-DAG: [[R4:%.+]] = mhlo.add [[R3]], [[HALF]] : tensor<2xf32>
%0 = "tf.Sigmoid"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

// Complex element type case: only requires that a scalar complex constant
// (0.5, 0.0) is emitted and that the original TF sigmoid op no longer appears
// in the output; the exact arithmetic sequence is not pinned down here.
// CHECK-LABEL: @sigmoid_complex
func @sigmoid_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
// CHECK: [[R0:%.+]] = mhlo.constant dense<(5.000000e-01,0.000000e+00)> : tensor<complex<f32>>
// CHECK-NOT: tf.Sigmoid
%0 = "tf.Sigmoid"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
return %0 : tensor<2xcomplex<f32>>
}

// Unranked case: the 0.5 constant is expected as a scalar tensor<f32> that is
// broadcast to the operand's runtime shape (shape.shape_of feeding
// mhlo.dynamic_broadcast_in_dim) before the same
// tanh(x * 0.5) * 0.5 + 0.5 sequence is applied on tensor<*xf32>.
// CHECK-LABEL: @sigmoid_unranked
func @sigmoid_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-DAG: [[SCALAR:%.+]] = mhlo.constant dense<5.000000e-01> : tensor<f32>
// CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.shape_of %arg0 : tensor<*xf32> -> tensor<?xindex>
// CHECK-DAG: [[HALF:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-DAG: [[R1:%.+]] = mhlo.multiply %arg0, [[HALF]] : tensor<*xf32>
// CHECK-DAG: [[R2:%.+]] = "mhlo.tanh"([[R1]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK-DAG: [[R3:%.+]] = mhlo.multiply [[R2]], [[HALF]] : tensor<*xf32>
// CHECK-DAG: [[R4:%.+]] = mhlo.add [[R3]], [[HALF]] : tensor<*xf32>
%0 = "tf.Sigmoid"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ struct FoldReshapeIntoInterfaceTensorLoad

// Removes operations with Allocate MemoryEffects but no uses.
struct RemoveDeadMemAllocs : RewritePattern {
RemoveDeadMemAllocs(PatternBenefit benefit = 1)
: RewritePattern(benefit, MatchAnyOpTypeTag()) {}
RemoveDeadMemAllocs(MLIRContext *context, PatternBenefit benefit = 1)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}

LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
Expand All @@ -109,8 +109,8 @@ struct BufferAllocViewCleanUpPass
: public PassWrapper<BufferAllocViewCleanUpPass, FunctionPass> {
void runOnFunction() override {
OwningRewritePatternList patterns(&getContext());
patterns.insert<FoldReshapeIntoInterfaceTensorLoad>(&getContext());
patterns.insert<RemoveDeadMemAllocs>();
patterns.insert<FoldReshapeIntoInterfaceTensorLoad, RemoveDeadMemAllocs>(
&getContext());
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
Expand Down
10 changes: 6 additions & 4 deletions iree/compiler/Conversion/Common/LinalgBufferizePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,12 +329,14 @@ static LogicalResult convertTensorReshapeOp(
resultTensorType, {}, inputBufferType.getMemorySpaceAsInt());
Value bufferReshape = b.create<linalg::ReshapeOp>(
loc, reshapeResultType, reshapeSrc, op.reassociation());
auto allocationDynamicSizes = linalg::getReshapeOutputShapeFromInputShape(
b, loc, inputBuffer, resultTensorType.getShape(),
op.getReassociationMaps());
SmallVector<SmallVector<Value>> reshapeResultShape;
if (failed(op.reifyReturnTypeShapesPerResultDim(b, reshapeResultShape)) ||
reshapeResultShape.size() != 1) {
return op.emitError("failed to get shape of result");
}
return createAliasingBufferOrAllocationForResult(
b, loc, allocationFn, srcTensor, bufferReshape, resultTensor,
allocationDynamicSizes, bvm);
reshapeResultShape[0], bvm);
}

static SmallVector<int64_t, 4> extractFromI64ArrayAttr(ArrayAttr attr) {
Expand Down
2 changes: 1 addition & 1 deletion iree/compiler/Conversion/HLOToHLO/DemoteF32ToF16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class FloatTypeConverter : public TypeConverter {
class GenericTypeConvert : public ConversionPattern {
public:
GenericTypeConvert(MLIRContext *context, TypeConverter &converter)
: ConversionPattern(0, converter, MatchAnyOpTypeTag()) {}
: ConversionPattern(converter, MatchAnyOpTypeTag(), 0, context) {}
LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Expand Down
2 changes: 1 addition & 1 deletion iree/compiler/Conversion/LinalgToLLVM/ConvertToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ void ConvertToLLVMPass::runOnOperation() {
LLVMConversionTarget target(getContext());
// IREE::HAL::InterfaceOp will be removed after successful conversion of the
// rest of the IR.
target.addLegalOp<ModuleOp, ModuleTerminatorOp, IREE::HAL::InterfaceOp,
target.addLegalOp<ModuleOp, IREE::HAL::InterfaceOp,
IREE::HAL::InterfaceBindingOp, IREE::HAL::InterfaceEndOp>();
target.addIllegalDialect<ShapeDialect, StandardOpsDialect, IREEDialect,
IREE::HAL::HALDialect, math::MathDialect>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ namespace {
// that is always templated on an op.
struct TileWorkgroups : public linalg::LinalgBaseTilingPattern {
using Base = linalg::LinalgBaseTilingPattern;
TileWorkgroups(linalg::LinalgTilingOptions options,
TileWorkgroups(MLIRContext *context, linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter marker)
: LinalgBaseTilingPattern(options, marker) {}
: LinalgBaseTilingPattern(context, options, marker) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto contractionOp = dyn_cast<linalg::ContractionOpInterface>(op);
Expand Down Expand Up @@ -153,6 +153,7 @@ void TileAndVectorizeWorkgroups::runOnFunction() {
// First level of tiling patterns. (workgroups memory)
OwningRewritePatternList l1patterns(&getContext());
l1patterns.insert<TileWorkgroups>(
context,
linalg::LinalgTilingOptions().setTileSizeComputationFunction(
[](OpBuilder &builder,
Operation *operation) -> SmallVector<Value, 4> {
Expand All @@ -175,6 +176,7 @@ void TileAndVectorizeWorkgroups::runOnFunction() {
{
OwningRewritePatternList l2patterns(&getContext());
l2patterns.insert<TileWorkgroups>(
context,
linalg::LinalgTilingOptions().setTileSizeComputationFunction(
[](OpBuilder &builder,
Operation *operation) -> SmallVector<Value, 4> {
Expand Down
6 changes: 2 additions & 4 deletions iree/compiler/Conversion/LinalgToNVVM/ConvertToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,8 @@ struct ConvertToNVVMPass
ModuleOp m = getOperation();

/// Customize the bitwidth used for the device side index computations.
LowerToLLVMOptions options = {/*useBarePtrCallConv =*/false,
/*emitCWrappers =*/false,
/*indexBitwidth =*/64,
/*useAlignedAlloc =*/false};
LowerToLLVMOptions options(m.getContext(), DataLayout(m));
options.overrideIndexBitwidth(64);
LLVMTypeConverter converter(m.getContext(), options);
// Apply in-dialect lowering first. In-dialect lowering will replace ops
// which need to be lowered further, which is not supported by a single
Expand Down
2 changes: 1 addition & 1 deletion iree/compiler/Conversion/LinalgToSPIRV/VectorToGPUPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ class VectorContractLowering : public OpRewritePattern<vector::ContractionOp> {
class ElementwiseLowering : public RewritePattern {
public:
ElementwiseLowering(MLIRContext *context)
: RewritePattern(0, MatchAnyOpTypeTag()) {}
: RewritePattern(MatchAnyOpTypeTag(), 0, context) {}

LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ hal.executable @matmul_fusion attributes {sym_visibility = "private"} {
// CHECK: %[[LBX:.+]] = affine.apply #[[MAP3]]()[%[[BIDX]]]
// CHECK: %[[SV_ARG1:.+]] = memref.subview %[[ARG1]][0, %[[LBX]]]
// CHECK: %[[SV_RET0:.+]] = memref.subview %[[RET0]][%[[LBY]], %[[LBX]]]
// CHECK: linalg.fill(%[[SV_RET0]], %{{.+}})
// CHECK: %[[SV_RET0_1:.+]] = memref.subview %[[RET0]][%[[LBY]], %[[LBX]]]
// CHECK: linalg.fill(%[[SV_RET0_1]], %{{.+}})
// CHECK-SAME: "workgroup"
// CHECK: linalg.matmul
// CHECK-SAME: "workgroup"
Expand Down Expand Up @@ -328,7 +329,9 @@ hal.executable @conv_no_padding_fusion attributes {sym_visibility = "private"} {
// CHECK-SAME: [%[[BIDZ]], %[[LBY]], %[[LBX]], 0]
// CHECK: %[[SV_RET0:.+]] = memref.subview %[[RET0]]
// CHECK-SAME: [%[[BIDZ]], %[[LBY]], %[[LBX]], 0]
// CHECK: linalg.fill(%[[SV_RET0]], %{{.*}})
// CHECK: %[[SV_RET0_1:.+]] = memref.subview %[[RET0]]
// CHECK-SAME: [%[[BIDZ]], %[[LBY]], %[[LBX]], 0]
// CHECK: linalg.fill(%[[SV_RET0_1]], %{{.*}})
// CHECK-SAME: "workgroup"
// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: "workgroup"
Expand Down Expand Up @@ -397,10 +400,13 @@ hal.executable @three_op_fusion attributes {sym_visibility = "private"} {
}
}
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 8)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (8, s0 * -8 + 25)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 16)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0] -> (16, s0 * -16 + 75)>
// CHECK-DAG: #[[MAP1_2:.+]] = affine_map<()[s0] -> (s0 * -8 + 25, 8)>
// CHECK-DAG: #[[MAP3_2:.+]] = affine_map<()[s0] -> (s0 * -16 + 75, 16)>
// CHECK: hal.executable.entry_point @three_op_fusion
// CHECK-DAG: %[[C1:.+]] = constant 1
// CHECK-DAG: %[[C4:.+]] = constant 4
Expand All @@ -416,14 +422,16 @@ hal.executable @three_op_fusion attributes {sym_visibility = "private"} {
// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
// CHECK-NOT: scf.parallel
// CHECK-NOT: scf.for
// CHECK: %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP1]]()[%[[BIDY]]]
// CHECK: %[[LBX:.+]] = affine.apply #[[MAP2]]()[%[[BIDX]]]
// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP3]]()[%[[BIDX]]]
// CHECK: %[[SV_ARG2:.+]] = memref.subview %[[ARG2]][%[[LBX]]] [%[[TILE_N]]]
// CHECK: %[[SV_RET0:.+]] = memref.subview %[[RET0]][%[[LBY]], %[[LBX]]
// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]]
// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP3]]()[%[[BIDX]]]
// CHECK: %[[SV_ARG2:.+]] = memref.subview %[[ARG2]][%[[LBX]]] [%[[TILE_N_2]]]
// CHECK: %[[LBY:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP1]]()[%[[BIDY]]]
// CHECK: %[[SV_RET0:.+]] = memref.subview %[[RET0]][%[[LBY]], %[[LBX]]]
// CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]]
// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP1_2]]()[%[[BIDY]]]
// CHECK: %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[LBY]], 0] [%[[TILE_M]], 50]
// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP3_2]]()[%[[BIDX]]]
// CHECK: %[[SV_ARG1:.+]] = memref.subview %[[ARG1]][0, %[[LBX]]] [50, %[[TILE_N]]]
// CHECK: %[[SV_ALLOC:.+]] = memref.subview %[[ALLOC]][0, 0] [%[[TILE_M]], %[[TILE_N]]]
// CHECK: linalg.fill(%[[SV_ALLOC]], %{{.+}})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -669,10 +669,11 @@ namespace {
struct TileAndDistributeOnTensorsPattern
: public linalg::LinalgBaseTilingPattern {
using Base = linalg::LinalgBaseTilingPattern;
TileAndDistributeOnTensorsPattern(linalg::LinalgTilingOptions options,
TileAndDistributeOnTensorsPattern(MLIRContext *context,
linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter marker,
PatternBenefit benefit = 1)
: Base(options, marker, benefit) {}
: Base(context, options, marker, benefit) {}

LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -779,8 +780,8 @@ static Optional<SmallVector<Value, 4>> getResultShape(PatternRewriter &rewriter,
/// element-wise operations is not beneficial. These are handled appropriately
/// by the backends.
struct MakeDispatchWorkgroupsOp : public RewritePattern {
MakeDispatchWorkgroupsOp(PatternBenefit benefit = 1)
: RewritePattern(benefit, MatchAnyOpTypeTag()) {}
MakeDispatchWorkgroupsOp(MLIRContext *context, PatternBenefit benefit = 1)
: RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}

LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -1019,7 +1020,7 @@ void DispatchLinalgOnTensorsPass::runOnOperation() {
assert(linalgTilingOptions.distribution.hasValue());

patterns.insert<TileAndDistributeOnTensorsPattern>(
linalgTilingOptions,
context, linalgTilingOptions,
// TODO(nicolavasilache): use refactored `getWorkgroupMarker()`
linalg::LinalgTransformationFilter(
ArrayRef<Identifier>(), Identifier::get("workgroup", context)));
Expand All @@ -1042,8 +1043,8 @@ void DispatchLinalgOnTensorsPass::runOnOperation() {

// Move other operations into their own dispatch regions.
{
OwningRewritePatternList patterns(&getContext());
patterns.insert<MakeDispatchWorkgroupsOp>();
OwningRewritePatternList patterns(context);
patterns.insert<MakeDispatchWorkgroupsOp>(context);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ void populateStandardStructuralToHALPatterns(MLIRContext *context,
void setupStandardToHALLegality(MLIRContext *context,
ConversionTarget &conversionTarget,
TypeConverter &typeConverter) {
conversionTarget.addLegalOp<mlir::ModuleOp, mlir::ModuleTerminatorOp>();
conversionTarget.addLegalOp<mlir::ModuleOp>();

// We need to rewrite certain types on operands/results so use the default
// dynamic legality checker to force any ops using such types to run through
Expand Down
6 changes: 5 additions & 1 deletion iree/compiler/Dialect/HAL/Target/TargetBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ static void mergeModuleInto(Operation *sourceModuleOp,
}
targetSymbolMap[symbolInterface.getName()] = op;
}
op->moveBefore(&targetBlock.back());
if (!targetBlock.empty() &&
targetBlock.back().hasTrait<OpTrait::IsTerminator>())
op->moveBefore(&targetBlock.back());
else
op->moveBefore(&targetBlock, targetBlock.end());
}

// Now that we're done cloning its ops, delete the original target op.
Expand Down
Loading

0 comments on commit 431ede6

Please sign in to comment.