Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][IR] Rename "update root" to "modify op" in rewriter API #78260

Conversation

matthias-springer
Copy link
Member

This commit renames 4 pattern rewriter API functions:

  • updateRootInPlace -> modifyOpInPlace
  • startRootUpdate -> startOpModification
  • finalizeRootUpdate -> finalizeOpModification
  • cancelRootUpdate -> cancelOpModification

The term "root" is a misnomer. The root is the op that a rewrite pattern matches against (https://mlir.llvm.org/docs/PatternRewriter/#root-operation-name-optional). A rewriter must be notified of all in-place op modifications, not just in-place modifications of the root (https://mlir.llvm.org/docs/PatternRewriter/#pattern-rewriter). The old function names were confusing and have contributed to various broken rewrite patterns.

Note: The new function names use the term "modify" instead of "update" for consistency with the RewriterBase::Listener terminology (notifyOperationModified).

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 16, 2024

@llvm/pr-subscribers-openacc
@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-flang-codegen
@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir-openacc
@llvm/pr-subscribers-mlir-sme
@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit renames 4 pattern rewriter API functions:

  • updateRootInPlace -> modifyOpInPlace
  • startRootUpdate -> startOpModification
  • finalizeRootUpdate -> finalizeOpModification
  • cancelRootUpdate -> cancelOpModification

The term "root" is a misnomer. The root is the op that a rewrite pattern matches against (https://mlir.llvm.org/docs/PatternRewriter/#root-operation-name-optional). A rewriter must be notified of all in-place op modifications, not just in-place modifications of the root (https://mlir.llvm.org/docs/PatternRewriter/#pattern-rewriter). The old function names were confusing and have contributed to various broken rewrite patterns.

Note: The new function names use the term "modify" instead of "update" for consistency with the RewriterBase::Listener terminology (notifyOperationModified).


Patch is 95.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78260.diff

77 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp (+10-10)
  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+4-4)
  • (modified) flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp (+4-4)
  • (modified) flang/lib/Optimizer/Transforms/AffineDemotion.cpp (+2-2)
  • (modified) flang/lib/Optimizer/Transforms/AffinePromotion.cpp (+6-6)
  • (modified) flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp (+4-4)
  • (modified) mlir/docs/PatternRewriter.md (+5-5)
  • (modified) mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp (+1-1)
  • (modified) mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp (+2-2)
  • (modified) mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp (+2-2)
  • (modified) mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp (+2-2)
  • (modified) mlir/include/mlir/IR/PatternMatch.h (+23-21)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+7-7)
  • (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+1-1)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+1-2)
  • (modified) mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp (+2-3)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp (+2-2)
  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp (+1-1)
  • (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+4-4)
  • (modified) mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp (+2-2)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+1-1)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+2-2)
  • (modified) mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp (+7-7)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp (+4-4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Split.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+1-1)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp (+2-2)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+4-4)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp (+1-1)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp (+1-1)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp (+2-2)
  • (modified) mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp (+1-1)
  • (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+2-2)
  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+14-14)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+2-2)
  • (modified) mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp (+2-2)
  • (modified) mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp (+2-2)
  • (modified) mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp (+5-5)
  • (modified) mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp (+2-2)
  • (modified) mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp (+2-2)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp (+10-10)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+3-3)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+7-7)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+7-7)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+1-1)
  • (modified) mlir/lib/IR/PatternMatch.cpp (+2-2)
  • (modified) mlir/lib/Transforms/Mem2Reg.cpp (+2-2)
  • (modified) mlir/lib/Transforms/Utils/CommutativityUtils.cpp (+1-1)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+5-5)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+15-15)
diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index 24cf2f39fc9a090..7d73af4d7103dcd 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -215,14 +215,14 @@ class BoxedProcedurePass
             rewriter.replaceOpWithNewOp<ConvertOp>(
                 addr, typeConverter.convertType(addr.getType()), addr.getVal());
           } else if (typeConverter.needsConversion(resTy)) {
-            rewriter.startRootUpdate(op);
+            rewriter.startOpModification(op);
             op->getResult(0).setType(typeConverter.convertType(resTy));
-            rewriter.finalizeRootUpdate(op);
+            rewriter.finalizeOpModification(op);
           }
         } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
           mlir::FunctionType ty = func.getFunctionType();
           if (typeConverter.needsConversion(ty)) {
-            rewriter.startRootUpdate(func);
+            rewriter.startOpModification(func);
             auto toTy =
                 typeConverter.convertType(ty).cast<mlir::FunctionType>();
             if (!func.empty())
@@ -235,7 +235,7 @@ class BoxedProcedurePass
                 block.eraseArgument(i + 1);
               }
             func.setType(toTy);
-            rewriter.finalizeRootUpdate(func);
+            rewriter.finalizeOpModification(func);
           }
         } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
           // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
@@ -273,10 +273,10 @@ class BoxedProcedurePass
         } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
           auto ty = global.getType();
           if (typeConverter.needsConversion(ty)) {
-            rewriter.startRootUpdate(global);
+            rewriter.startOpModification(global);
             auto toTy = typeConverter.convertType(ty);
             global.setType(toTy);
-            rewriter.finalizeRootUpdate(global);
+            rewriter.finalizeOpModification(global);
           }
         } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
           auto ty = mem.getType();
@@ -339,17 +339,17 @@ class BoxedProcedurePass
                 mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
           }
         } else if (op->getDialect() == firDialect) {
-          rewriter.startRootUpdate(op);
+          rewriter.startOpModification(op);
           for (auto i : llvm::enumerate(op->getResultTypes()))
             if (typeConverter.needsConversion(i.value())) {
               auto toTy = typeConverter.convertType(i.value());
               op->getResult(i.index()).setType(toTy);
             }
-          rewriter.finalizeRootUpdate(op);
+          rewriter.finalizeOpModification(op);
         }
         // Ensure block arguments are updated if needed.
         if (op->getNumRegions() != 0) {
-          rewriter.startRootUpdate(op);
+          rewriter.startOpModification(op);
           for (mlir::Region &region : op->getRegions())
             for (mlir::Block &block : region.getBlocks())
               for (mlir::BlockArgument blockArg : block.getArguments())
@@ -358,7 +358,7 @@ class BoxedProcedurePass
                       typeConverter.convertType(blockArg.getType());
                   blockArg.setType(toTy);
                 }
-          rewriter.finalizeRootUpdate(op);
+          rewriter.finalizeOpModification(op);
         }
       });
     }
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index e07732d57880c55..f2c731d47909a94 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3763,13 +3763,13 @@ class RenameMSVCLibmCallees
   mlir::LogicalResult
   matchAndRewrite(mlir::LLVM::CallOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    rewriter.startRootUpdate(op);
+    rewriter.startOpModification(op);
     auto callee = op.getCallee();
     if (callee)
       if (callee->equals("hypotf"))
         op.setCalleeAttr(mlir::SymbolRefAttr::get(op.getContext(), "_hypotf"));
 
-    rewriter.finalizeRootUpdate(op);
+    rewriter.finalizeOpModification(op);
     return mlir::success();
   }
 };
@@ -3782,10 +3782,10 @@ class RenameMSVCLibmFuncs
   mlir::LogicalResult
   matchAndRewrite(mlir::LLVM::LLVMFuncOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    rewriter.startRootUpdate(op);
+    rewriter.startOpModification(op);
     if (op.getSymName().equals("hypotf"))
       op.setSymNameAttr(rewriter.getStringAttr("_hypotf"));
-    rewriter.finalizeRootUpdate(op);
+    rewriter.finalizeOpModification(op);
     return mlir::success();
   }
 };
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index 97127f57cc3eb99..641854bd201f0b0 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -256,9 +256,9 @@ struct AssignOpConversion : public mlir::OpConversionPattern<hlfir::AssignOp> {
     llvm::SmallVector<mlir::Value> newOperands;
     for (mlir::Value operand : adaptor.getOperands())
       newOperands.push_back(getBufferizedExprStorage(operand));
-    rewriter.startRootUpdate(assign);
+    rewriter.startOpModification(assign);
     assign->setOperands(newOperands);
-    rewriter.finalizeRootUpdate(assign);
+    rewriter.finalizeOpModification(assign);
     return mlir::success();
   }
 };
@@ -834,9 +834,9 @@ struct ElementalOpConversion
     // Explicitly delete the body of the elemental to get rid
     // of any users of hlfir.expr values inside the body as early
     // as possible.
-    rewriter.startRootUpdate(elemental);
+    rewriter.startOpModification(elemental);
     rewriter.eraseBlock(elemental.getBody());
-    rewriter.finalizeRootUpdate(elemental);
+    rewriter.finalizeOpModification(elemental);
     rewriter.replaceOp(elemental, bufferizedExpr);
     return mlir::success();
   }
diff --git a/flang/lib/Optimizer/Transforms/AffineDemotion.cpp b/flang/lib/Optimizer/Transforms/AffineDemotion.cpp
index 0c256deeca41617..da29ae880700e62 100644
--- a/flang/lib/Optimizer/Transforms/AffineDemotion.cpp
+++ b/flang/lib/Optimizer/Transforms/AffineDemotion.cpp
@@ -114,9 +114,9 @@ class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> {
                                                       op.getValue());
           return success();
         }
-      rewriter.startRootUpdate(op->getParentOp());
+      rewriter.startOpModification(op->getParentOp());
       op.getResult().replaceAllUsesWith(op.getValue());
-      rewriter.finalizeRootUpdate(op->getParentOp());
+      rewriter.finalizeOpModification(op->getParentOp());
       rewriter.eraseOp(op);
     }
     return success();
diff --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
index af2200f6a7b02dd..d1831cf1c200cc2 100644
--- a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
+++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
@@ -464,15 +464,15 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
     auto affineFor = loopAndIndex.first;
     auto inductionVar = loopAndIndex.second;
 
-    rewriter.startRootUpdate(affineFor.getOperation());
+    rewriter.startOpModification(affineFor.getOperation());
     affineFor.getBody()->getOperations().splice(
         std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
         std::prev(loopOps.end()));
-    rewriter.finalizeRootUpdate(affineFor.getOperation());
+    rewriter.finalizeOpModification(affineFor.getOperation());
 
-    rewriter.startRootUpdate(loop.getOperation());
+    rewriter.startOpModification(loop.getOperation());
     loop.getInductionVar().replaceAllUsesWith(inductionVar);
-    rewriter.finalizeRootUpdate(loop.getOperation());
+    rewriter.finalizeOpModification(loop.getOperation());
 
     rewriteMemoryOps(affineFor.getBody(), rewriter);
 
@@ -561,7 +561,7 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
     auto affineIf = rewriter.create<affine::AffineIfOp>(
         op.getLoc(), affineCondition.getIntegerSet(),
         affineCondition.getAffineArgs(), !op.getElseRegion().empty());
-    rewriter.startRootUpdate(affineIf);
+    rewriter.startOpModification(affineIf);
     affineIf.getThenBlock()->getOperations().splice(
         std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(),
         std::prev(ifOps.end()));
@@ -571,7 +571,7 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
           std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(),
           std::prev(otherOps.end()));
     }
-    rewriter.finalizeRootUpdate(affineIf);
+    rewriter.finalizeOpModification(affineIf);
     rewriteMemoryOps(affineIf.getBody(), rewriter);
 
     LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n";
diff --git a/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp b/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp
index 221e93ff85e18e5..bc5be3f196b81ac 100644
--- a/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp
@@ -76,7 +76,7 @@ struct MangleNameOnFuncOp : public mlir::OpRewritePattern<mlir::func::FuncOp> {
   matchAndRewrite(mlir::func::FuncOp op,
                   mlir::PatternRewriter &rewriter) const override {
     mlir::LogicalResult ret = success();
-    rewriter.startRootUpdate(op);
+    rewriter.startOpModification(op);
     llvm::StringRef oldName = op.getSymName();
     auto result = fir::NameUniquer::deconstruct(oldName);
     if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
@@ -95,7 +95,7 @@ struct MangleNameOnFuncOp : public mlir::OpRewritePattern<mlir::func::FuncOp> {
     }
 
     updateEarlyOutliningParentName(op, appendUnderscore);
-    rewriter.finalizeRootUpdate(op);
+    rewriter.finalizeOpModification(op);
     return ret;
   }
 
@@ -114,7 +114,7 @@ struct MangleNameForCommonBlock : public mlir::OpRewritePattern<fir::GlobalOp> {
   mlir::LogicalResult
   matchAndRewrite(fir::GlobalOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    rewriter.startRootUpdate(op);
+    rewriter.startOpModification(op);
     auto result = fir::NameUniquer::deconstruct(
         op.getSymref().getRootReference().getValue());
     if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
@@ -122,7 +122,7 @@ struct MangleNameForCommonBlock : public mlir::OpRewritePattern<fir::GlobalOp> {
       op.setSymrefAttr(mlir::SymbolRefAttr::get(op.getContext(), newName));
       SymbolTable::setSymbolName(op, newName);
     }
-    rewriter.finalizeRootUpdate(op);
+    rewriter.finalizeOpModification(op);
     return success();
   }
 
diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 8fe5ef35a760392..011cd14175634be 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -213,15 +213,15 @@ user is determined by the specific pattern driver.
 This method replaces an operation's results with a set of provided values, and
 erases the operation.
 
-*   Update an Operation in-place : `(start|cancel|finalize)RootUpdate`
+*   Update an Operation in-place : `(start|cancel|finalize)OpModification`
 
 This is a collection of methods that provide a transaction-like API for updating
 the attributes, location, operands, or successors of an operation in-place
 within a pattern. An in-place update transaction is started with
-`startRootUpdate`, and may either be canceled or finalized with
-`cancelRootUpdate` and `finalizeRootUpdate` respectively. A convenience wrapper,
-`updateRootInPlace`, is provided that wraps a `start` and `finalize` around a
-callback.
+`startOpModification`, and may either be canceled or finalized with
+`cancelOpModification` and `finalizeOpModification` respectively. A convenience
+wrapper, `modifyOpInPlace`, is provided that wraps a `start` and `finalize`
+around a callback.
 
 *   OpBuilder API
 
diff --git a/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp b/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp
index d438cb46ecdada6..a23d0420f04350e 100644
--- a/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp
+++ b/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp
@@ -24,7 +24,7 @@ class StandaloneSwitchBarFooRewriter : public OpRewritePattern<func::FuncOp> {
   LogicalResult matchAndRewrite(func::FuncOp op,
                                 PatternRewriter &rewriter) const final {
     if (op.getSymName() == "bar") {
-      rewriter.updateRootInPlace(op, [&op]() { op.setSymName("foo"); });
+      rewriter.modifyOpInPlace(op, [&op]() { op.setSymName("foo"); });
       return success();
     }
     return failure();
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index 240b9f9338665a2..ae4bd980c34b53b 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
                   ConversionPatternRewriter &rewriter) const final {
     // We don't lower "toy.print" in this pass, but we need to update its
     // operands.
-    rewriter.updateRootInPlace(op,
-                               [&] { op->setOperands(adaptor.getOperands()); });
+    rewriter.modifyOpInPlace(op,
+                             [&] { op->setOperands(adaptor.getOperands()); });
     return success();
   }
 };
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index 240b9f9338665a2..ae4bd980c34b53b 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
                   ConversionPatternRewriter &rewriter) const final {
     // We don't lower "toy.print" in this pass, but we need to update its
     // operands.
-    rewriter.updateRootInPlace(op,
-                               [&] { op->setOperands(adaptor.getOperands()); });
+    rewriter.modifyOpInPlace(op,
+                             [&] { op->setOperands(adaptor.getOperands()); });
     return success();
   }
 };
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index 240b9f9338665a2..ae4bd980c34b53b 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
                   ConversionPatternRewriter &rewriter) const final {
     // We don't lower "toy.print" in this pass, but we need to update its
     // operands.
-    rewriter.updateRootInPlace(op,
-                               [&] { op->setOperands(adaptor.getOperands()); });
+    rewriter.modifyOpInPlace(op,
+                             [&] { op->setOperands(adaptor.getOperands()); });
     return success();
   }
 };
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 9b4fa65bff49e12..b065d4e8d37689e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -585,28 +585,30 @@ class RewriterBase : public OpBuilder {
 
   /// This method is used to notify the rewriter that an in-place operation
   /// modification is about to happen. A call to this function *must* be
-  /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
-  /// This is a minor efficiency win (it avoids creating a new operation and
-  /// removing the old one) but also often allows simpler code in the client.
-  virtual void startRootUpdate(Operation *op) {}
-
-  /// This method is used to signal the end of a root update on the given
-  /// operation. This can only be called on operations that were provided to a
-  /// call to `startRootUpdate`.
-  virtual void finalizeRootUpdate(Operation *op);
-
-  /// This method cancels a pending root update. This can only be called on
-  /// operations that were provided to a call to `startRootUpdate`.
-  virtual void cancelRootUpdate(Operation *op) {}
-
-  /// This method is a utility wrapper around a root update of an operation. It
-  /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
-  /// callable.
+  /// followed by a call to either `finalizeOpModification` or
+  /// `cancelOpModification`. This is a minor efficiency win (it avoids creating
+  /// a new operation and removing the old one) but also often allows simpler
+  /// code in the client.
+  virtual void startOpModification(Operation *op) {}
+
+  /// This method is used to signal the end of an in-place modification of the
+  /// given operation. This can only be called on operations that were provided
+  /// to a call to `startOpModification`.
+  virtual void finalizeOpModification(Operation *op);
+
+  /// This method cancels a pending in-place modification. This can only be
+  /// called on operations that were provided to a call to
+  /// `startOpModification`.
+  virtual void cancelOpModification(Operation *op) {}
+
+  /// This method is a utility wrapper around an in-place modification of an
+  /// operation. It wraps calls to `startOpModification` and
+  /// `finalizeOpModification` around the given callable.
   template <typename CallableT>
-  void updateRootInPlace(Operation *root, CallableT &&callable) {
-    startRootUpdate(root);
+  void modifyOpInPlace(Operation *root, CallableT &&callable) {
+    startOpModification(root);
     callable();
-    finalizeRootUpdate(root);
+    finalizeOpModification(root);
   }
 
   /// Find uses of `from` and replace them with `to`. It also marks every
@@ -619,7 +621,7 @@ class RewriterBase : public OpBuilder {
   void replaceAllUsesWith(IRObjectWithUseList<OperandType> *from, ValueT &&to) {
     for (OperandType &operand : llvm::make_early_inc_range(from->getUses())) {
       Operation *op = operand.getOwner();
-      updateRootInPlace(op, [&]() { operand.set(to); });
+      modifyOpInPlace(op, [&]() { operand.set(to); });
     }
   }
   void replaceAllUsesWith(ValueRange from, ValueRange to) {
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index c5725e9c856256b..9568540789df3f6 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -739,17 +739,17 @@ class ConversionPatternRewriter final : public PatternRewriter,
   /// PatternRewriter hook for inserting a new operation.
   void notifyOperationInserted(Operation *op) override;
 
-  /// PatternRewriter hook for updating the root operation in-place.
-  /// Note: These methods only track updates to the top-level operation itself,
+  /// PatternRewriter hook for updating the given operation in-place.
+  /// Note: These methods only track updates to the given operation itself,
   /// and not nested regions. Updates to regions will still require notification
   /// through other more specific hooks above.
-  void startRootUpdate(Operation *op) override;
+  void startOpModification(Operation *op) override;
 
-  /// PatternRewriter hook for updating the root operation in-place.
-  void finalizeRootUpdate(Operation *op) override;
+  /// PatternRewriter hook for updating the given operation in-place.
+  void finalizeOpModification(Operation *op) override;
 
-  /// PatternRewriter hook for updating the root operation in-place.
-  void cancelRootUpdate(Operation *op) override;
+  /// PatternRewriter hook for updating the given operation in-place.
+  void cancelOpModification(Operation *op) override;
 
   /// PatternRewriter hook for notifying match failure reasons.
   LogicalResult
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 16214d72fcddc22..bbef3b996e40b88 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -255,7 +255,7 @@ struct ConvertArmSMESpillsAndFillsToLLVM : publ...
[truncated]

This commit renames 4 pattern rewriter API functions:
* `updateRootInPlace` -> `modifyOpInPlace`
* `startRootUpdate` -> `startOpModification`
* `finalizeRootUpdate` -> `finalizeOpModification`
* `cancelRootUpdate` -> `cancelOpModification`

The term "root" is a misnomer. The root is the op that a rewrite pattern matches against (https://mlir.llvm.org/docs/PatternRewriter/#root-operation-name-optional). Rewriter must be notified of all in-place op modifications, not just in-place modifications of the root (https://mlir.llvm.org/docs/PatternRewriter/#pattern-rewriter). The old function names were confusing and have contributed to various broken rewrite patterns.

Note: The new function names use the term "modify" instead of "update" for consistency with the `RewriterBase::Listener` terminology (`notifyOperationModified`).
@matthias-springer matthias-springer force-pushed the pattern_rewriter_update_root_in_place branch from e525c83 to 885a022 Compare January 16, 2024 12:33
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes much more sense, thank you!

LGTM

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this was confusing.

@matthias-springer
Copy link
Member Author

I've been working on additional listener notifications for blocks: notifyBlockRemoved (#78306) and notifyBlockModified. The goal is to have as good support for blocks as we have for operations. A good example of this is IRMapping.

To trigger the latter notification, we would need a RewriterBase::modifyBlockInPlace. This API would used when adding arguments to an existing block, etc. (But not when adding/removing operations from a block, because we have dedicated notifications for that.)

Ideally, I'd have a single RewriterBase::modifyInPlace function (same for start, finalize, cancel) function that can take blocks or operations. (To be precise, two overloads, same as IRMapping.)

I'd like to avoid changing this API twice in a row, so I wanted to bring it up before merging this commit. Any thoughts? If it needs more discussion, I can also merge this commit now and post on Discourse.

@matthias-springer
Copy link
Member Author

I am merging this now, let's talk about the block notifications separately. That is probably a longer discussion.

@matthias-springer matthias-springer merged commit 5fcf907 into llvm:main Jan 17, 2024
4 checks passed
mwachs5 added a commit to mwachs5/circt that referenced this pull request Jan 17, 2024
Lewuathe added a commit to Lewuathe/mlir-hello that referenced this pull request Jan 23, 2024
stellaraccident pushed a commit to llvm/torch-mlir that referenced this pull request Jan 31, 2024
With the recent LLVM integrate and changes from
llvm/llvm-project#78260, we hit this build error
in Stablehlo (which is quite old).
```
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter'
    rewriter.startRootUpdate(op);
    ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter'
      rewriter.finalizeRootUpdate(op);
      ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter'
      rewriter.cancelRootUpdate(op);
      ~~~~~~~~ ^
external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter'
    rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; });
    ~~~~~~~~ ^
4 errors generated.
Target @torch-mlir//:torch-mlir-opt failed to build
```

I'm still puzzled as to how this didn't fail with the CMake merge gating
CI (do we not test Stablehlo builds/tests?). In any case, bumping our
submodule to openxla/stablehlo#1918 fixes it.

It exposes a new failing lit test in TorchToStablehlo though, that I
have looped stablehlo developers into
([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)).
```
bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test 

...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir
within split at <stdin>:1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference                                                                               
  %0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64>             
       ^                                                                                                                                                                                            
LLVM ERROR: Failed to infer result type(s).               
```

Bazel CI:
https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants