From 767e2881bd90aa3f8aa515bae7bf75602fc4f355 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com> Date: Fri, 6 Sep 2024 00:50:37 -0700 Subject: [PATCH] Undo revert of https://github.com/llvm/llvm-project/pull/104668 (#18451) Signed-off-by: MaheshRavishankar Co-authored-by: Matthias Springer --- .../test/stablehlo_to_linalg_ext.mlir | 2 - .../Codegen/Common/TypePropagationPass.cpp | 66 +++---------------- third_party/llvm-project | 2 +- 3 files changed, 10 insertions(+), 60 deletions(-) diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir index 09b2bd4d87bd..713fb05c61e4 100644 --- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir +++ b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir @@ -516,7 +516,6 @@ func.func @reverse_unsigned(%arg0: tensor<3x5xui32>) -> tensor<3x5xui32> { } // CHECK-LABEL: func.func @reverse_unsigned // CHECK-SAME: %[[IN:[a-zA-Z0-9]+]] -// CHECK: %[[BITCAST:.+]] = builtin.unrealized_conversion_cast %[[IN]] : tensor<3x5xui32> to tensor<3x5xi32> // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<3x5xui32> // CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<3x5xui32>) // CHECK: %[[SAME_DIM:.+]] = linalg.index 0 : index @@ -654,7 +653,6 @@ func.func @prefix(%arg0: tensor<7x5xi32>, %arg1: tensor) -> tensor<7x5xi32> }) {base_dilations = array, padding = dense<[[0, 0], [4, 0]]> : tensor<2x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor<7x5xi32>, tensor) -> tensor<7x5xi32> return %reduce : tensor<7x5xi32> } -// CHECK: %extracted = tensor.extract %[[ARG1]][] : tensor // CHECK: %[[OUT0:.+]] = tensor.empty() : tensor<7x5xi32> // CHECK: %[[OUT1:.+]] = tensor.empty() : tensor<7xi32> // CHECK: %[[FILL:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} ins(%[[ARG1]] : tensor) outs(%[[OUT1]] : tensor<7xi32>) diff --git a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp index f3fba9006430..fce31c71298b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TypePropagationPass.cpp @@ -229,29 +229,13 @@ struct GenericOpTypePropagation signatureConverter.addInputs(index, legalizedArgType.value()); } rewriter.applySignatureConversion(&modifiedOpRegion.front(), - signatureConverter); + signatureConverter, getTypeConverter()); // 6. Introduce scalar conversion operations to convert back to the // original scalar type. { OpBuilder::InsertionGuard g(rewriter); Block *entryBlock = modifiedOp.getBlock(); - for (auto modifiedOperandIndex : modifiedOperandIndex) { - OpOperand *modifiedOpOperand = - &modifiedOp->getOpOperand(modifiedOperandIndex); - BlockArgument source = - modifiedOp.getMatchingBlockArgument(modifiedOpOperand); - Type destType = getElementTypeOrSelf( - genericOp.getOperand(modifiedOperandIndex).getType()); - - // 6a. If the value of the argument is used the argument is in the - // legalized type. Convert it to a value that is in the original - // element type for replacement of all uses in the block. - rewriter.setInsertionPointToStart(entryBlock); - Value replacement = - convertElementType(rewriter, source.getLoc(), destType, source); - rewriter.replaceUsesOfBlockArgument(source, replacement); - } // 6b. If any of the operands modified were outputs, the yield values // need to be modified as well. @@ -372,27 +356,13 @@ struct IREELinalgExtScatterTypePropagation signatureConverter.addInputs(0, legalizedArgType.value()); signatureConverter.addInputs(1, legalizedArgType.value()); rewriter.applySignatureConversion(&modifiedOpRegion.front(), - signatureConverter); + signatureConverter, getTypeConverter()); { // Introduce scalar conversion operations to convert back to the original // scalar type. OpBuilder::InsertionGuard g(rewriter); Block *entryBlock = &modifiedOp->getRegion(0).getBlocks().front(); - BlockArgument inputArg = entryBlock->getArgument(0); - BlockArgument outputArg = entryBlock->getArgument(1); - - auto destType = getElementTypeOrSelf(inputType); - rewriter.setInsertionPointToStart(entryBlock); - - Value replacementInput = - convertElementType(rewriter, inputArg.getLoc(), destType, inputArg); - rewriter.replaceUsesOfBlockArgument(entryBlock->getArgument(0), - replacementInput); - Value replacementOutput = - convertElementType(rewriter, outputArg.getLoc(), destType, outputArg); - rewriter.replaceUsesOfBlockArgument(entryBlock->getArgument(1), - replacementOutput); // If the output is of an illegal type, the yield value needs to be // modified @@ -449,31 +419,7 @@ struct IREELinalgExtSortTypePropagation signatureConverter.addInputs(index, legalizedArgType.value()); } rewriter.applySignatureConversion(&modifiedOpRegion.front(), - signatureConverter); - - { - // Introduce scalar conversion operations to convert back to the original - // scalar type. - OpBuilder::InsertionGuard g(rewriter); - Block *entryBlock = &modifiedOp->getRegion(0).getBlocks().front(); - for (auto [index, operand] : llvm::enumerate(sortOp->getOpOperands())) { - BlockArgument firstInputArg = entryBlock->getArgument(index * 2); - BlockArgument secondInputArg = entryBlock->getArgument(index * 2 + 1); - - auto destType = getElementTypeOrSelf(operand.get().getType()); - rewriter.setInsertionPointToStart(entryBlock); - if (destType != getElementTypeOrSelf(legalizedResultTypes[index])) { - Value replacementFirstInput = convertElementType( - rewriter, firstInputArg.getLoc(), destType, firstInputArg); - rewriter.replaceUsesOfBlockArgument(firstInputArg, - replacementFirstInput); - Value replacementSecondInput = convertElementType( - rewriter, secondInputArg.getLoc(), destType, secondInputArg); - rewriter.replaceUsesOfBlockArgument(secondInputArg, - replacementSecondInput); - } - } - } + signatureConverter, getTypeConverter()); rewriter.replaceOp(sortOp, modifiedOp->getResults()); return success(); } @@ -580,6 +526,12 @@ struct TypePropagationPass final RewritePatternSet patterns(context); TypePropagationTypeConverter typeConverter; + typeConverter.addArgumentMaterialization( + [&](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { + assert(inputs.size() == 1 && "expected exactly one input"); + return convertElementType(builder, loc, type, inputs[0]); + }); + patterns.insert< ConstantOpTypeConversion, ForwardSourceType, ForwardSourceType, GenericOpTypePropagation, diff --git a/third_party/llvm-project b/third_party/llvm-project index bea0be37cfe8..ac0ef19c9953 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit bea0be37cfe820dbaacac8f22d7122352416d7b4 +Subproject commit ac0ef19c9953bdb947461bd7f90a85ce92b6b32f