diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp index 24cf2f39fc9a09..7d73af4d7103dc 100644 --- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp +++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp @@ -215,14 +215,14 @@ class BoxedProcedurePass rewriter.replaceOpWithNewOp( addr, typeConverter.convertType(addr.getType()), addr.getVal()); } else if (typeConverter.needsConversion(resTy)) { - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); op->getResult(0).setType(typeConverter.convertType(resTy)); - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); } } else if (auto func = mlir::dyn_cast(op)) { mlir::FunctionType ty = func.getFunctionType(); if (typeConverter.needsConversion(ty)) { - rewriter.startRootUpdate(func); + rewriter.startOpModification(func); auto toTy = typeConverter.convertType(ty).cast(); if (!func.empty()) @@ -235,7 +235,7 @@ class BoxedProcedurePass block.eraseArgument(i + 1); } func.setType(toTy); - rewriter.finalizeRootUpdate(func); + rewriter.finalizeOpModification(func); } } else if (auto embox = mlir::dyn_cast(op)) { // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk @@ -273,10 +273,10 @@ class BoxedProcedurePass } else if (auto global = mlir::dyn_cast(op)) { auto ty = global.getType(); if (typeConverter.needsConversion(ty)) { - rewriter.startRootUpdate(global); + rewriter.startOpModification(global); auto toTy = typeConverter.convertType(ty); global.setType(toTy); - rewriter.finalizeRootUpdate(global); + rewriter.finalizeOpModification(global); } } else if (auto mem = mlir::dyn_cast(op)) { auto ty = mem.getType(); @@ -339,17 +339,17 @@ class BoxedProcedurePass mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); } } else if (op->getDialect() == firDialect) { - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); for (auto i : llvm::enumerate(op->getResultTypes())) if (typeConverter.needsConversion(i.value())) { auto toTy = typeConverter.convertType(i.value()); op->getResult(i.index()).setType(toTy); } - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); } // Ensure block arguments are updated if needed. if (op->getNumRegions() != 0) { - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); for (mlir::Region ®ion : op->getRegions()) for (mlir::Block &block : region.getBlocks()) for (mlir::BlockArgument blockArg : block.getArguments()) @@ -358,7 +358,7 @@ class BoxedProcedurePass typeConverter.convertType(blockArg.getType()); blockArg.setType(toTy); } - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); } }); } diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index e07732d57880c5..f2c731d47909a9 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -3763,13 +3763,13 @@ class RenameMSVCLibmCallees mlir::LogicalResult matchAndRewrite(mlir::LLVM::CallOp op, mlir::PatternRewriter &rewriter) const override { - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); auto callee = op.getCallee(); if (callee) if (callee->equals("hypotf")) op.setCalleeAttr(mlir::SymbolRefAttr::get(op.getContext(), "_hypotf")); - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); return mlir::success(); } }; @@ -3782,10 +3782,10 @@ class RenameMSVCLibmFuncs mlir::LogicalResult matchAndRewrite(mlir::LLVM::LLVMFuncOp op, mlir::PatternRewriter &rewriter) const override { - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); if (op.getSymName().equals("hypotf")) op.setSymNameAttr(rewriter.getStringAttr("_hypotf")); - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); return mlir::success(); } }; diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp index 97127f57cc3eb9..641854bd201f0b 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -256,9 +256,9 @@ struct AssignOpConversion : public mlir::OpConversionPattern { llvm::SmallVector newOperands; for (mlir::Value operand : adaptor.getOperands()) newOperands.push_back(getBufferizedExprStorage(operand)); - rewriter.startRootUpdate(assign); + rewriter.startOpModification(assign); assign->setOperands(newOperands); - rewriter.finalizeRootUpdate(assign); + rewriter.finalizeOpModification(assign); return mlir::success(); } }; @@ -834,9 +834,9 @@ struct ElementalOpConversion // Explicitly delete the body of the elemental to get rid // of any users of hlfir.expr values inside the body as early // as possible. - rewriter.startRootUpdate(elemental); + rewriter.startOpModification(elemental); rewriter.eraseBlock(elemental.getBody()); - rewriter.finalizeRootUpdate(elemental); + rewriter.finalizeOpModification(elemental); rewriter.replaceOp(elemental, bufferizedExpr); return mlir::success(); } diff --git a/flang/lib/Optimizer/Transforms/AffineDemotion.cpp b/flang/lib/Optimizer/Transforms/AffineDemotion.cpp index 0c256deeca4161..da29ae880700e6 100644 --- a/flang/lib/Optimizer/Transforms/AffineDemotion.cpp +++ b/flang/lib/Optimizer/Transforms/AffineDemotion.cpp @@ -114,9 +114,9 @@ class ConvertConversion : public mlir::OpRewritePattern { op.getValue()); return success(); } - rewriter.startRootUpdate(op->getParentOp()); + rewriter.startOpModification(op->getParentOp()); op.getResult().replaceAllUsesWith(op.getValue()); - rewriter.finalizeRootUpdate(op->getParentOp()); + rewriter.finalizeOpModification(op->getParentOp()); rewriter.eraseOp(op); } return success(); diff --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp index af2200f6a7b02d..d1831cf1c200cc 100644 --- a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp +++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp @@ -464,15 +464,15 @@ class AffineLoopConversion : public mlir::OpRewritePattern { auto affineFor = loopAndIndex.first; auto inductionVar = loopAndIndex.second; - rewriter.startRootUpdate(affineFor.getOperation()); + rewriter.startOpModification(affineFor.getOperation()); affineFor.getBody()->getOperations().splice( std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(), std::prev(loopOps.end())); - rewriter.finalizeRootUpdate(affineFor.getOperation()); + rewriter.finalizeOpModification(affineFor.getOperation()); - rewriter.startRootUpdate(loop.getOperation()); + rewriter.startOpModification(loop.getOperation()); loop.getInductionVar().replaceAllUsesWith(inductionVar); - rewriter.finalizeRootUpdate(loop.getOperation()); + rewriter.finalizeOpModification(loop.getOperation()); rewriteMemoryOps(affineFor.getBody(), rewriter); @@ -561,7 +561,7 @@ class AffineIfConversion : public mlir::OpRewritePattern { auto affineIf = rewriter.create( op.getLoc(), affineCondition.getIntegerSet(), affineCondition.getAffineArgs(), !op.getElseRegion().empty()); - rewriter.startRootUpdate(affineIf); + rewriter.startOpModification(affineIf); affineIf.getThenBlock()->getOperations().splice( std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(), std::prev(ifOps.end())); @@ -571,7 +571,7 @@ class AffineIfConversion : public mlir::OpRewritePattern { std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(), std::prev(otherOps.end())); } - rewriter.finalizeRootUpdate(affineIf); + rewriter.finalizeOpModification(affineIf); rewriteMemoryOps(affineIf.getBody(), rewriter); LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n"; diff --git a/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp b/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp index 221e93ff85e18e..bc5be3f196b81a 100644 --- a/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp +++ b/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp @@ -76,7 +76,7 @@ struct MangleNameOnFuncOp : public mlir::OpRewritePattern { matchAndRewrite(mlir::func::FuncOp op, mlir::PatternRewriter &rewriter) const override { mlir::LogicalResult ret = success(); - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); llvm::StringRef oldName = op.getSymName(); auto result = fir::NameUniquer::deconstruct(oldName); if (fir::NameUniquer::isExternalFacingUniquedName(result)) { @@ -95,7 +95,7 @@ struct MangleNameOnFuncOp : public mlir::OpRewritePattern { } updateEarlyOutliningParentName(op, appendUnderscore); - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); return ret; } @@ -114,7 +114,7 @@ struct MangleNameForCommonBlock : public mlir::OpRewritePattern { mlir::LogicalResult matchAndRewrite(fir::GlobalOp op, mlir::PatternRewriter &rewriter) const override { - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); auto result = fir::NameUniquer::deconstruct( op.getSymref().getRootReference().getValue()); if (fir::NameUniquer::isExternalFacingUniquedName(result)) { @@ -122,7 +122,7 @@ struct MangleNameForCommonBlock : public mlir::OpRewritePattern { op.setSymrefAttr(mlir::SymbolRefAttr::get(op.getContext(), newName)); SymbolTable::setSymbolName(op, newName); } - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); return success(); } diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md index 8fe5ef35a76039..011cd14175634b 100644 --- a/mlir/docs/PatternRewriter.md +++ b/mlir/docs/PatternRewriter.md @@ -213,15 +213,15 @@ user is determined by the specific pattern driver. This method replaces an operation's results with a set of provided values, and erases the operation. -* Update an Operation in-place : `(start|cancel|finalize)RootUpdate` +* Update an Operation in-place : `(start|cancel|finalize)OpModification` This is a collection of methods that provide a transaction-like API for updating the attributes, location, operands, or successors of an operation in-place within a pattern. An in-place update transaction is started with -`startRootUpdate`, and may either be canceled or finalized with -`cancelRootUpdate` and `finalizeRootUpdate` respectively. A convenience wrapper, -`updateRootInPlace`, is provided that wraps a `start` and `finalize` around a -callback. +`startOpModification`, and may either be canceled or finalized with +`cancelOpModification` and `finalizeOpModification` respectively. A convenience +wrapper, `modifyOpInPlace`, is provided that wraps a `start` and `finalize` +around a callback. * OpBuilder API diff --git a/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp b/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp index d438cb46ecdada..a23d0420f04350 100644 --- a/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp +++ b/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp @@ -24,7 +24,7 @@ class StandaloneSwitchBarFooRewriter : public OpRewritePattern { LogicalResult matchAndRewrite(func::FuncOp op, PatternRewriter &rewriter) const final { if (op.getSymName() == "bar") { - rewriter.updateRootInPlace(op, [&op]() { op.setSymName("foo"); }); + rewriter.modifyOpInPlace(op, [&op]() { op.setSymName("foo"); }); return success(); } return failure(); diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp index 240b9f9338665a..ae4bd980c34b53 100644 --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern { ConversionPatternRewriter &rewriter) const final { // We don't lower "toy.print" in this pass, but we need to update its // operands. - rewriter.updateRootInPlace(op, - [&] { op->setOperands(adaptor.getOperands()); }); + rewriter.modifyOpInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); return success(); } }; diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp index 240b9f9338665a..ae4bd980c34b53 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern { ConversionPatternRewriter &rewriter) const final { // We don't lower "toy.print" in this pass, but we need to update its // operands. - rewriter.updateRootInPlace(op, - [&] { op->setOperands(adaptor.getOperands()); }); + rewriter.modifyOpInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); return success(); } }; diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp index 240b9f9338665a..ae4bd980c34b53 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern { ConversionPatternRewriter &rewriter) const final { // We don't lower "toy.print" in this pass, but we need to update its // operands. - rewriter.updateRootInPlace(op, - [&] { op->setOperands(adaptor.getOperands()); }); + rewriter.modifyOpInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); return success(); } }; diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 9b4fa65bff49e1..b065d4e8d37689 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -585,28 +585,30 @@ class RewriterBase : public OpBuilder { /// This method is used to notify the rewriter that an in-place operation /// modification is about to happen. A call to this function *must* be - /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`. - /// This is a minor efficiency win (it avoids creating a new operation and - /// removing the old one) but also often allows simpler code in the client. - virtual void startRootUpdate(Operation *op) {} - - /// This method is used to signal the end of a root update on the given - /// operation. This can only be called on operations that were provided to a - /// call to `startRootUpdate`. - virtual void finalizeRootUpdate(Operation *op); - - /// This method cancels a pending root update. This can only be called on - /// operations that were provided to a call to `startRootUpdate`. - virtual void cancelRootUpdate(Operation *op) {} - - /// This method is a utility wrapper around a root update of an operation. It - /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given - /// callable. + /// followed by a call to either `finalizeOpModification` or + /// `cancelOpModification`. This is a minor efficiency win (it avoids creating + /// a new operation and removing the old one) but also often allows simpler + /// code in the client. + virtual void startOpModification(Operation *op) {} + + /// This method is used to signal the end of an in-place modification of the + /// given operation. This can only be called on operations that were provided + /// to a call to `startOpModification`. + virtual void finalizeOpModification(Operation *op); + + /// This method cancels a pending in-place modification. This can only be + /// called on operations that were provided to a call to + /// `startOpModification`. + virtual void cancelOpModification(Operation *op) {} + + /// This method is a utility wrapper around an in-place modification of an + /// operation. It wraps calls to `startOpModification` and + /// `finalizeOpModification` around the given callable. template - void updateRootInPlace(Operation *root, CallableT &&callable) { - startRootUpdate(root); + void modifyOpInPlace(Operation *root, CallableT &&callable) { + startOpModification(root); callable(); - finalizeRootUpdate(root); + finalizeOpModification(root); } /// Find uses of `from` and replace them with `to`. It also marks every @@ -619,7 +621,7 @@ class RewriterBase : public OpBuilder { void replaceAllUsesWith(IRObjectWithUseList *from, ValueT &&to) { for (OperandType &operand : llvm::make_early_inc_range(from->getUses())) { Operation *op = operand.getOwner(); - updateRootInPlace(op, [&]() { operand.set(to); }); + modifyOpInPlace(op, [&]() { operand.set(to); }); } } void replaceAllUsesWith(ValueRange from, ValueRange to) { diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index c5725e9c856256..9568540789df3f 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -739,17 +739,17 @@ class ConversionPatternRewriter final : public PatternRewriter, /// PatternRewriter hook for inserting a new operation. void notifyOperationInserted(Operation *op) override; - /// PatternRewriter hook for updating the root operation in-place. - /// Note: These methods only track updates to the top-level operation itself, + /// PatternRewriter hook for updating the given operation in-place. + /// Note: These methods only track updates to the given operation itself, /// and not nested regions. Updates to regions will still require notification /// through other more specific hooks above. - void startRootUpdate(Operation *op) override; + void startOpModification(Operation *op) override; - /// PatternRewriter hook for updating the root operation in-place. - void finalizeRootUpdate(Operation *op) override; + /// PatternRewriter hook for updating the given operation in-place. + void finalizeOpModification(Operation *op) override; - /// PatternRewriter hook for updating the root operation in-place. - void cancelRootUpdate(Operation *op) override; + /// PatternRewriter hook for updating the given operation in-place. + void cancelOpModification(Operation *op) override; /// PatternRewriter hook for notifying match failure reasons. LogicalResult diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 16214d72fcddc2..bbef3b996e40b8 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -255,7 +255,7 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { // Step 2. Assign the op a real tile ID. // For simplicity, we always use tile 0 (which always exists). auto zeroTileId = rewriter.getI32IntegerAttr(0); - rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); }); + rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); }); VectorType tileVectorType = tileOp.getTileType(); auto sliceType = VectorType::Builder(tileVectorType).dropDim(0); diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 94df3765a67e74..f853d5c47b623c 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -918,8 +918,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite( for (auto stream : streams) streamDestroyCallBuilder.create(loc, rewriter, {stream}); - rewriter.updateRootInPlace(yieldOp, - [&] { yieldOp->setOperands(newOperands); }); + rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); }); return success(); } diff --git a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp index 8c1a7d9c6b2a43..54e6bec12b897c 100644 --- a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp +++ b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp @@ -43,14 +43,13 @@ class ExpandIfCondition : public OpRewritePattern { if (!matchPattern(op.getIfCond(), m_Constant(&constAttr))) { auto ifOp = rewriter.create(op.getLoc(), TypeRange(), op.getIfCond(), false); - rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); + rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); auto thenBodyBuilder = ifOp.getThenBodyBuilder(rewriter.getListener()); thenBodyBuilder.clone(*op.getOperation()); rewriter.eraseOp(op); } else { if (constAttr.getInt()) - rewriter.updateRootInPlace(op, - [&]() { op.getIfCondMutable().erase(0); }); + rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); else rewriter.eraseOp(op); } diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 44fbac1935fed7..f8485e02a2208e 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -645,13 +645,13 @@ struct PrepareTransferWriteConversion rewriter.create(loc, xferOp.getVector(), buffers.dataBuffer); auto loadedVec = rewriter.create(loc, buffers.dataBuffer); - rewriter.updateRootInPlace(xferOp, [&]() { + rewriter.modifyOpInPlace(xferOp, [&]() { xferOp.getVectorMutable().assign(loadedVec); xferOp->setAttr(kPassLabel, rewriter.getUnitAttr()); }); if (xferOp.getMask()) { - rewriter.updateRootInPlace(xferOp, [&]() { + rewriter.modifyOpInPlace(xferOp, [&]() { xferOp.getMaskMutable().assign(buffers.maskBuffer); }); } @@ -966,7 +966,7 @@ struct TransferOpConversion : public VectorToSCFPattern { loadIndices, iv); auto mask = b.create(loc, castedMaskBuffer, loadIndices); - rewriter.updateRootInPlace(newXfer, [&]() { + rewriter.modifyOpInPlace(newXfer, [&]() { newXfer.getMaskMutable().assign(mask); }); } diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index d5be2e906989fa..c260e68d509e98 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2493,7 +2493,7 @@ FailureOr AffineForOp::replaceWithAdditionalYields( newYieldValuesFn(rewriter, getLoc(), newIterArgs); assert(newInitOperands.size() == newYieldedValues.size() && "expected as many new yield values as new iter operands"); - rewriter.updateRootInPlace(yieldOp, [&]() { + rewriter.modifyOpInPlace(yieldOp, [&]() { yieldOp.getOperandsMutable().append(newYieldedValues); }); } @@ -2686,9 +2686,9 @@ struct SimplifyDeadElse : public OpRewritePattern { !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults()) return failure(); - rewriter.startRootUpdate(ifOp); + rewriter.startOpModification(ifOp); rewriter.eraseBlock(ifOp.getElseBlock()); - rewriter.finalizeRootUpdate(ifOp); + rewriter.finalizeOpModification(ifOp); return success(); } }; diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp index e5501e848c1646..f28fb3acb7db7f 100644 --- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp @@ -71,10 +71,10 @@ void mlir::affine::reorderOperandsByHoistability(RewriterBase &rewriter, op->getContext()); canonicalizeMapAndOperands(&map, &operands); - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); op.setMap(map); op->setOperands(operands); - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); } /// Build an affine.apply that is a subexpression `expr` of `originalOp`s affine diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp index 4d49efecbe05c3..4acb2a8fb7b539 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp @@ -218,7 +218,7 @@ struct AssignTileIDsPattern return defaultVal; }; auto setDiscardableIntAttr = [&](StringRef name, auto value) { - rewriter.updateRootInPlace(tileOp, [&] { + rewriter.modifyOpInPlace(tileOp, [&] { func->setDiscardableAttr(name, rewriter.getI32IntegerAttr((unsigned)value)); }); @@ -274,10 +274,10 @@ struct AssignTileIDsPattern setDiscardableIntAttr(kTilesInUseAttr, tilesInUse); else setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1); - rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); }); + rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); }); for (auto *op : dependantOps) { if (auto dependantTileOp = llvm::dyn_cast(op)) { - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); }); } } diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp index 92278c0d74d574..32c87c1b824074 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -30,8 +30,8 @@ class ForwardOperands : public OpConversionPattern { if (adaptor.getOperands().getTypes() == op->getOperands().getTypes()) return rewriter.notifyMatchFailure(op, "operand types already match"); - rewriter.updateRootInPlace( - op, [&]() { op->setOperands(adaptor.getOperands()); }); + rewriter.modifyOpInPlace(op, + [&]() { op->setOperands(adaptor.getOperands()); }); return success(); } }; diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp index bf627d95ae5573..8b4bacd7227121 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp @@ -106,8 +106,8 @@ struct RelaxScalableVectorAllocaAlignment // Set alignment based on the defaults for SVE vectors and predicates. unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16; - rewriter.updateRootInPlace(allocaOp, - [&] { allocaOp.setAlignment(aligment); }); + rewriter.modifyOpInPlace(allocaOp, + [&] { allocaOp.setAlignment(aligment); }); return success(); } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index a0bb8715f2c561..4b1dfee4a2b926 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -253,7 +253,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( copiedOpOperands.contains(opOperand)); if (failed(copy)) return failure(); - rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); }); + rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); }); } // Insert copies of Values. @@ -274,7 +274,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( // dynamic extents. Do not update these either. if (isa(use->getOwner())) continue; - rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); }); + rewriter.modifyOpInPlace(use->getOwner(), [&]() { use->set(*copy); }); } } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index 94bc2bcea63be9..253fcf2525121b 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -895,7 +895,7 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, deallocOp.getConditions() == conditions) return failure(); - rewriter.updateRootInPlace(deallocOp, [&]() { + rewriter.modifyOpInPlace(deallocOp, [&]() { deallocOp.getMemrefsMutable().assign(memrefs); deallocOp.getConditionsMutable().assign(conditions); }); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp index 42653517249d66..75d65193809f10 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp @@ -42,7 +42,7 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, deallocOp.getConditions() == conditions) return failure(); - rewriter.updateRootInPlace(deallocOp, [&]() { + rewriter.modifyOpInPlace(deallocOp, [&]() { deallocOp.getMemrefsMutable().assign(memrefs); deallocOp.getConditionsMutable().assign(conditions); }); diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp index 999c04e48ee168..d242d75bd51fa7 100644 --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -403,8 +403,8 @@ struct CondBranchTruthPropagation : public OpRewritePattern { constantTrue = rewriter.create( condbr.getLoc(), ty, rewriter.getBoolAttr(true)); - rewriter.updateRootInPlace(use.getOwner(), - [&] { use.set(constantTrue); }); + rewriter.modifyOpInPlace(use.getOwner(), + [&] { use.set(constantTrue); }); } } } @@ -418,8 +418,8 @@ struct CondBranchTruthPropagation : public OpRewritePattern { constantFalse = rewriter.create( condbr.getLoc(), ty, rewriter.getBoolAttr(false)); - rewriter.updateRootInPlace(use.getOwner(), - [&] { use.set(constantFalse); }); + rewriter.modifyOpInPlace(use.getOwner(), + [&] { use.set(constantFalse); }); } } } diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp index 98ae826b6497fb..fa030cb18e035d 100644 --- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp +++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp @@ -86,7 +86,7 @@ struct DecomposeCallGraphTypesForFuncArgs if (failed(typeConverter->convertTypes(functionType.getResults(), newResultTypes))) return failure(); - rewriter.updateRootInPlace(op, [&] { + rewriter.modifyOpInPlace(op, [&] { op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), newResultTypes)); }); diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp index 742830ec722f17..d1f3b56dbed738 100644 --- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp @@ -84,7 +84,7 @@ class BranchOpInterfaceTypeConversion newOperands[idx] = operands[idx]; } } - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( op, [newOperands, op]() { op->setOperands(newOperands); }); return success(); } @@ -107,8 +107,8 @@ class ReturnOpTypeConversion : public OpConversionPattern { ConversionPatternRewriter &rewriter) const final { // For a return, all operands go to the results of the parent, so // rewrite them all. - rewriter.updateRootInPlace(op, - [&] { op->setOperands(adaptor.getOperands()); }); + rewriter.modifyOpInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); return success(); } }; diff --git a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp index 70056932411215..c04986cad84f9d 100644 --- a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp +++ b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp @@ -80,7 +80,7 @@ class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern { auto newType = FunctionType::get(rewriter.getContext(), argumentMapping.getConvertedTypes(), funcResultMapping.getConvertedTypes()); - rewriter.updateRootInPlace(op, [&] { op.setType(newType); }); + rewriter.modifyOpInPlace(op, [&] { op.setType(newType); }); // Update block signatures. if (!op.isExternal()) { @@ -105,7 +105,7 @@ class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern { return failure(); // Convert operands. - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( op, [&] { op->setOperands(adaptor.getFlatOperands()); }); return success(); diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 514b3e9a6e8a56..30b6cd74147e6f 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2030,7 +2030,7 @@ struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern { continue; validOperands.push_back(operand); } - rewriter.updateRootInPlace(op, [&]() { op->setOperands(validOperands); }); + rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); }); return success(); } }; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 96a0ef591c1cfe..bf24194d03ddb2 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -301,7 +301,7 @@ DeletionKind LLVM::DbgValueOp::removeBlockingUses( // the variable has been optimized out. auto undef = rewriter.create(getValue().getLoc(), getValue().getType()); - rewriter.updateRootInPlace(*this, [&] { getValueMutable().assign(undef); }); + rewriter.modifyOpInPlace(*this, [&] { getValueMutable().assign(undef); }); return DeletionKind::Keep; } @@ -394,7 +394,7 @@ DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot, return DeletionKind::Delete; } - rewriter.updateRootInPlace(*this, [&]() { + rewriter.modifyOpInPlace(*this, [&]() { // Rewire the indices by popping off the second index. // Start with a single zero, then add the indices beyond the second. SmallVector newIndices(1); diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp index cf900ac0be8fd2..72f9295749a66b 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp @@ -83,8 +83,8 @@ static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter, op->getLoc(), LLVM::LLVMPointerType::get(op.getContext()), elemType, op.getAddr(), firstTypeIndices); - rewriter.updateRootInPlace(op, - [&]() { op.getAddrMutable().assign(properPtr); }); + rewriter.modifyOpInPlace(op, + [&]() { op.getAddrMutable().assign(properPtr); }); } template <> @@ -111,8 +111,8 @@ LogicalResult AddFieldGetterToStructDirectUse::matchAndRewrite( rewriter.setInsertionPointAfterValue(load.getResult()); BitcastOp bitcast = rewriter.create( load->getLoc(), load.getResult().getType(), load.getResult()); - rewriter.updateRootInPlace(load, - [&]() { load.getResult().setType(firstType); }); + rewriter.modifyOpInPlace(load, + [&]() { load.getResult().setType(firstType); }); rewriter.replaceAllUsesExcept(load.getResult(), bitcast.getResult(), bitcast); } @@ -141,7 +141,7 @@ LogicalResult AddFieldGetterToStructDirectUse::matchAndRewrite( insertFieldIndirection(store, rewriter, inconsistentElementType); - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( store, [&]() { store.getValueMutable().assign(store.getValue()); }); return success(); @@ -630,8 +630,8 @@ LogicalResult BitcastStores::matchAndRewrite(StoreOp store, auto bitcastOp = rewriter.create(store.getLoc(), typeHint, store.getValue()); - rewriter.updateRootInPlace( - store, [&] { store.getValueMutable().assign(bitcastOp); }); + rewriter.modifyOpInPlace(store, + [&] { store.getValueMutable().assign(bitcastOp); }); return success(); } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 139566d350fe83..f7cfe8abddb2e8 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -785,7 +785,7 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); // Replace the use in containingOp. - rewriter.updateRootInPlace(containingOp, [&]() { + rewriter.modifyOpInPlace(containingOp, [&]() { containingOp->setOperand(pUse->getOperandNumber(), destinationTensors.front()); }); @@ -835,7 +835,7 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(use->getOwner()); fusedOp = rewriter.clone(*producerOp); - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); }); return fusedOp; diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index d8df5d82e28759..ff13aaf9b4abca 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -311,7 +311,7 @@ Value linalg::bufferizeToAllocation( auto toTensorOp = resultUse->get().getDefiningOp(); assert(toTensorOp && "expected to_tensor op"); - rewriter.updateRootInPlace(toTensorOp, [&]() { + rewriter.modifyOpInPlace(toTensorOp, [&]() { toTensorOp.setRestrict(true); toTensorOp.setWritable(true); }); @@ -559,11 +559,11 @@ Value linalg::bufferizeToAllocation( // tensor is uninitialized. createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options); } - rewriter.updateRootInPlace(op, [&]() { + rewriter.modifyOpInPlace(op, [&]() { auto toTensorOp = rewriter.create(op->getLoc(), alloc); operand->set(toTensorOp); if (options.bufferizeDestinationOnly) { - rewriter.updateRootInPlace(toTensorOp, [&]() { + rewriter.modifyOpInPlace(toTensorOp, [&]() { toTensorOp.setRestrict(true); toTensorOp.setWritable(true); }); @@ -584,7 +584,7 @@ Value linalg::bufferizeToAllocation( for (OpOperand *resultUse : resultUses) { auto toTensorOp = resultUse->get().getDefiningOp(); assert(toTensorOp && "expected to_tensor op"); - rewriter.updateRootInPlace(toTensorOp, [&]() { + rewriter.modifyOpInPlace(toTensorOp, [&]() { toTensorOp.setRestrict(true); toTensorOp.setWritable(true); }); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index bf91a708ae1589..98cd0444760ece 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -104,7 +104,7 @@ struct FunctionNonEntryBlockConversion LogicalResult matchAndRewrite(FunctionOpInterface op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); Region ®ion = op.getFunctionBody(); SmallVector conversions; @@ -125,11 +125,11 @@ struct FunctionNonEntryBlockConversion if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter, conversions))) { - rewriter.cancelRootUpdate(op); + rewriter.cancelOpModification(op); return failure(); } - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 031f5c7a5d4783..e4cb2f223f3c7e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1816,7 +1816,7 @@ struct RemoveOutsDependency : public OpRewritePattern { LogicalResult matchAndRewrite(GenericOp op, PatternRewriter &rewriter) const override { - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); bool modifiedOutput = false; Location loc = op.getLoc(); for (OpOperand &opOperand : op.getDpsInitsMutable()) { @@ -1843,10 +1843,10 @@ struct RemoveOutsDependency : public OpRewritePattern { } } if (!modifiedOutput) { - rewriter.cancelRootUpdate(op); + rewriter.cancelOpModification(op); return failure(); } - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); return success(); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp index f28f8f0d34a4da..81669a1807796c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp @@ -87,7 +87,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep( } // Turn the "in" into an "out". - rewriter.updateRootInPlace(op, [&]() { + rewriter.modifyOpInPlace(op, [&]() { out->set(in->get()); // The original "in" could be removed entirely here (because it will no // longer have any uses in the payload), but we delegate this to diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp index 3378eda2bd6734..16ab45ea8bee63 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp @@ -354,7 +354,7 @@ struct RemoveUnusedCycleInGenericOp : public OpRewritePattern { // Directly replace the cycle with the blockArg such that // Deduplicate pattern can eliminate it along with unused yield. rewriter.replaceOp(cycleOp, outputArg); - rewriter.updateRootInPlace(genericOp, [] {}); + rewriter.modifyOpInPlace(genericOp, [] {}); hasRemovedCycles = true; } @@ -404,7 +404,7 @@ struct FoldDuplicateInputBbArgs : public OpRewritePattern { return failure(); // Rewrite the op. - rewriter.updateRootInPlace(genericOp, [&]() { + rewriter.modifyOpInPlace(genericOp, [&]() { for (auto [before, after] : replacements) { BlockArgument bbArg = genericOp.getBody()->getArgument(before); BlockArgument replacement = genericOp.getBody()->getArgument(after); diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index 805c9d4ed3b79f..b32ea8eebaecb9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -854,10 +854,10 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting, LLVM_DEBUG(DBGS() << "with result #" << numOriginalForOpResults + iterArgNumber << " of forOp, giving us: " << extracted << "\n"); - rewriter.startRootUpdate(extracted); + rewriter.startOpModification(extracted); extracted.getSourceMutable().assign( newForOp.getResult(numOriginalForOpResults + iterArgNumber)); - rewriter.finalizeRootUpdate(extracted); + rewriter.finalizeOpModification(extracted); LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting << "\n"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp index f46ba71599b3fd..a0faeb524c57db 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -60,9 +60,9 @@ mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, assert(permutationMap && "unexpected null map"); // Start a guarded inplace update. - rewriter.startRootUpdate(genericOp); - auto guard = - llvm::make_scope_exit([&]() { rewriter.finalizeRootUpdate(genericOp); }); + rewriter.startOpModification(genericOp); + auto guard = llvm::make_scope_exit( + [&]() { rewriter.finalizeOpModification(genericOp); }); // 2. Compute the interchanged indexing maps. SmallVector newIndexingMaps; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp index bbe3a542f66b88..0174db45a83db2 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp @@ -113,7 +113,7 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, // Need to pretend that the original op now takes as operands firstResults, // otherwise tiling interface implementation will take the wrong value to // produce data tiles. - rewriter.updateRootInPlace(op, [&]() { + rewriter.modifyOpInPlace(op, [&]() { unsigned numTotalOperands = op->getNumOperands(); unsigned numOutputOperands = firstResults.size(); op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 7f3ab1f1a24b2f..ebf80e3c5dc685 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -722,7 +722,7 @@ FailureOr linalg::tileReductionUsingForall( // We cannot use a IRMapping here because it can replace // different OpOperands with the same value. Operation *clonedOp = b.clone(*op.getOperation()); - b.updateRootInPlace(clonedOp, [&]() { + b.modifyOpInPlace(clonedOp, [&]() { for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal( cast(clonedOp).getDpsInitsMutable(), tiledDpsInitOperands)) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index dc348ea827cde1..0610f24ddaf471 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1952,7 +1952,7 @@ struct PadOpVectorizationWithTransferReadPattern if (xferOp.hasOutOfBoundsDim() || xferOp.getMask()) return failure(); - rewriter.updateRootInPlace(xferOp, [&]() { + rewriter.modifyOpInPlace(xferOp, [&]() { SmallVector inBounds(xferOp.getVectorType().getRank(), false); xferOp->setAttr(xferOp.getInBoundsAttrName(), rewriter.getBoolArrayAttr(inBounds)); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp index be301c191d5139..561b8619032cce 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp @@ -227,7 +227,7 @@ DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot, Attribute index = getAttributeIndexFromIndexOperands( getContext(), getIndices(), getMemRefType()); const MemorySlot &memorySlot = subslots.at(index); - rewriter.updateRootInPlace(*this, [&]() { + rewriter.modifyOpInPlace(*this, [&]() { setMemRef(memorySlot.ptr); getIndicesMutable().clear(); }); @@ -280,7 +280,7 @@ DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot, Attribute index = getAttributeIndexFromIndexOperands( getContext(), getIndices(), getMemRefType()); const MemorySlot &memorySlot = subslots.at(index); - rewriter.updateRootInPlace(*this, [&]() { + rewriter.modifyOpInPlace(*this, [&]() { setMemRef(memorySlot.ptr); getIndicesMutable().clear(); }); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 394640f9ebac89..b79ab8f3d671e0 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -792,7 +792,7 @@ struct FoldCopyOfCast : public OpRewritePattern { if (fromType && toType) { if (fromType.getShape() == toType.getShape() && fromType.getElementType() == toType.getElementType()) { - rewriter.updateRootInPlace(copyOp, [&] { + rewriter.modifyOpInPlace(copyOp, [&] { copyOp.getSourceMutable().assign(castOp.getSource()); }); modified = true; @@ -808,7 +808,7 @@ struct FoldCopyOfCast : public OpRewritePattern { if (fromType && toType) { if (fromType.getShape() == toType.getShape() && fromType.getElementType() == toType.getElementType()) { - rewriter.updateRootInPlace(copyOp, [&] { + rewriter.modifyOpInPlace(copyOp, [&] { copyOp.getTargetMutable().assign(castOp.getSource()); }); modified = true; @@ -1366,7 +1366,7 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, loc, llvm::cast(maybeConstant.template get()) .getInt()); for (Operation *op : llvm::make_early_inc_range(result.getUsers())) { - // updateRootInplace: lambda cannot capture structured bindings in C++17 + // modifyOpInPlace: lambda cannot capture structured bindings in C++17 // yet. op->replaceUsesOfWith(result, constantVal); atLeastOneReplacement = true; @@ -2436,7 +2436,7 @@ struct CollapseShapeOpMemRefCastFolder op.getReassociationIndices()); if (newResultType == op.getResultType()) { - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( op, [&]() { op.getSrcMutable().assign(cast.getSource()); }); } else { Value newOp = rewriter.create( diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index 101e099d2b644c..8047c60187b2fd 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -797,7 +797,7 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp extractOp.getSource().getDefiningOp(); if (!viewLikeOp) return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source"); - rewriter.updateRootInPlace(extractOp, [&]() { + rewriter.modifyOpInPlace(extractOp, [&]() { extractOp.getSourceMutable().assign(viewLikeOp.getViewSource()); }); return success(); diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp index 03765e95b01e7a..10ba508265e7b9 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp @@ -154,7 +154,7 @@ static void replaceAndPropagateMemRefType(RewriterBase &rewriter, for (OpOperand &operand : user->getOpOperands()) { if ([[maybe_unused]] auto castOp = operand.get().getDefiningOp()) { - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( user, [&]() { operand.set(conversion->getOperand(0)); }); } } diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp index 397bd5856bcb07..bc0dd034f63851 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -79,9 +79,9 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter, // TODO: can we use an early_inc iterator? for (OpOperand *operand : operandsToReplace) { Operation *op = operand->getOwner(); - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); operand->set(val); - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); } // Perform late op erasure. diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp index 8bfb4be5225f4a..8163f428683d8d 100644 --- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp +++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp @@ -54,7 +54,7 @@ struct MmaSyncF32ToTF32Pattern : public OpRewritePattern { "for nvgpu.mma.sync on f32 datatype"); if (precision == MmaSyncF32Lowering::TF32) { - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( op, [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); }); } diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index bf3264b5da9802..8698c00d1cb728 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -359,7 +359,7 @@ struct RemoveConstantIfCondition : public OpRewritePattern { if (!matchPattern(ifCond, m_Constant(&constAttr))) return failure(); if (constAttr.getInt()) - rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); + rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); else rewriter.eraseOp(op); @@ -398,7 +398,7 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern { if (!matchPattern(ifCond, m_Constant(&constAttr))) return failure(); if (constAttr.getInt()) - rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); + rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); else replaceOpWithRegion(rewriter, op, op.getRegion()); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index cdc0b6f1696ae9..45cc7479f209b5 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -552,7 +552,7 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter, newYieldValuesFn(rewriter, getLoc(), newIterArgs); assert(newInitOperands.size() == newYieldedValues.size() && "expected as many new yield values as new iter operands"); - rewriter.updateRootInPlace(yieldOp, [&]() { + rewriter.modifyOpInPlace(yieldOp, [&]() { yieldOp.getResultsMutable().append(newYieldedValues); }); } @@ -1444,7 +1444,7 @@ struct DimOfForallOp : public OpRewritePattern { Value sharedOut = forallOp.getTiedOpOperand(llvm::cast(dimOp.getSource())) ->get(); - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); }); return success(); } @@ -1464,7 +1464,7 @@ class ForallOpControlOperandsFolder : public OpRewritePattern { failed(foldDynamicIndexList(mixedStep))) return failure(); - rewriter.updateRootInPlace(op, [&]() { + rewriter.modifyOpInPlace(op, [&]() { SmallVector dynamicLowerBound, dynamicUpperBound, dynamicStep; SmallVector staticLowerBound, staticUpperBound, staticStep; dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound, @@ -1556,7 +1556,7 @@ struct ForallOpSingleOrZeroIterationDimsFolder for (const auto &namedAttr : op->getAttrs()) { if (llvm::is_contained(elidedAttrs, namedAttr.getName())) continue; - rewriter.updateRootInPlace(newOp, [&]() { + rewriter.modifyOpInPlace(newOp, [&]() { newOp->setAttr(namedAttr.getName(), namedAttr.getValue()); }); } @@ -2023,8 +2023,8 @@ struct RemoveUnusedResults : public OpRewritePattern { [&](OpResult result) { return yieldOp.getOperand(result.getResultNumber()); }); - rewriter.updateRootInPlace(yieldOp, - [&]() { yieldOp->setOperands(usedOperands); }); + rewriter.modifyOpInPlace(yieldOp, + [&]() { yieldOp->setOperands(usedOperands); }); } LogicalResult matchAndRewrite(IfOp op, @@ -2189,8 +2189,8 @@ struct ConditionPropagation : public OpRewritePattern { constantTrue = rewriter.create( op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)); - rewriter.updateRootInPlace(use.getOwner(), - [&]() { use.set(constantTrue); }); + rewriter.modifyOpInPlace(use.getOwner(), + [&]() { use.set(constantTrue); }); } else if (op.getElseRegion().isAncestor( use.getOwner()->getParentRegion())) { changed = true; @@ -2199,8 +2199,8 @@ struct ConditionPropagation : public OpRewritePattern { constantFalse = rewriter.create( op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)); - rewriter.updateRootInPlace(use.getOwner(), - [&]() { use.set(constantFalse); }); + rewriter.modifyOpInPlace(use.getOwner(), + [&]() { use.set(constantFalse); }); } } @@ -2383,14 +2383,14 @@ struct CombineIfs : public OpRewritePattern { llvm::make_early_inc_range(std::get<0>(it).getUses())) { if (nextThen && nextThen->getParent()->isAncestor( use.getOwner()->getParentRegion())) { - rewriter.startRootUpdate(use.getOwner()); + rewriter.startOpModification(use.getOwner()); use.set(std::get<1>(it)); - rewriter.finalizeRootUpdate(use.getOwner()); + rewriter.finalizeOpModification(use.getOwner()); } else if (nextElse && nextElse->getParent()->isAncestor( use.getOwner()->getParentRegion())) { - rewriter.startRootUpdate(use.getOwner()); + rewriter.startOpModification(use.getOwner()); use.set(std::get<2>(it)); - rewriter.finalizeRootUpdate(use.getOwner()); + rewriter.finalizeOpModification(use.getOwner()); } } diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index dc3c46bf896a9c..90f935d71c2fe9 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -688,7 +688,7 @@ struct ForOpInterface yieldValues.push_back(*alloc); } - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); }); return success(); } @@ -928,7 +928,7 @@ struct WhileOpInterface return failure(); beforeYieldValues.push_back(*alloc); } - rewriter.updateRootInPlace(conditionOp, [&]() { + rewriter.modifyOpInPlace(conditionOp, [&]() { conditionOp.getArgsMutable().assign(beforeYieldValues); }); diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp index 7b6b07eabf6c48..cda561b1d1054d 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -89,8 +89,8 @@ struct ForLoopLoweringPattern : public OpRewritePattern { for (auto yieldOp : afterBlock->getOps()) { SmallVector yieldOperands = yieldOp.getOperands(); yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult()); - rewriter.updateRootInPlace( - yieldOp, [&]() { yieldOp->setOperands(yieldOperands); }); + rewriter.modifyOpInPlace(yieldOp, + [&]() { yieldOp->setOperands(yieldOperands); }); } // We cannot do a direct replacement of the forOp since the while op returns diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp index eee0791b397ae6..c6d024c462e837 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -99,7 +99,7 @@ struct DimOfIterArgFolder : public OpRewritePattern { return failure(); Value initArg = forOp.getTiedLoopInit(blockArg)->get(); - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); }); return success(); @@ -141,7 +141,7 @@ struct DimOfLoopResultFolder : public OpRewritePattern { unsigned resultNumber = opResult.getResultNumber(); if (!isShapePreserving(forOp, resultNumber)) return failure(); - rewriter.updateRootInPlace(dimOp, [&]() { + rewriter.modifyOpInPlace(dimOp, [&]() { dimOp.getSourceMutable().assign(forOp.getInitArgs()[resultNumber]); }); return success(); diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp index 342213507486af..a5bff0a892c3df 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -160,8 +160,8 @@ static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp, partialIteration.getInitArgsMutable().assign(forOp->getResults()); // Set new upper loop bound. - b.updateRootInPlace( - forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); }); + b.modifyOpInPlace(forOp, + [&]() { forOp.getUpperBoundMutable().assign(splitBound); }); return success(); } @@ -239,7 +239,7 @@ LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp, firstIteration = cast(b.clone(*forOp.getOperation(), map)); // Update main loop with new lower bound. - b.updateRootInPlace(forOp, [&]() { + b.modifyOpInPlace(forOp, [&]() { forOp.getInitArgsMutable().assign(firstIteration->getResults()); forOp.getLowerBoundMutable().assign(splitBound); }); @@ -286,11 +286,11 @@ struct ForLoopPeelingPattern : public OpRewritePattern { } // Apply label, so that the same loop is not rewritten a second time. - rewriter.updateRootInPlace(partialIteration, [&]() { + rewriter.modifyOpInPlace(partialIteration, [&]() { partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr()); partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr()); }); - rewriter.updateRootInPlace(forOp, [&]() { + rewriter.modifyOpInPlace(forOp, [&]() { forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr()); }); return success(); diff --git a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp index 8c2c544a89f7de..5aa35e79babfce 100644 --- a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp @@ -111,7 +111,7 @@ class ConvertTypesInSCFYieldOp : public OneToNOpConversionPattern { return failure(); // Convert operands. - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( op, [&] { op->setOperands(adaptor.getFlatOperands()); }); return success(); @@ -131,7 +131,7 @@ class ConvertTypesInSCFConditionOp return failure(); // Convert operands. - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( op, [&] { op->setOperands(adaptor.getFlatOperands()); }); return success(); diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index 7932c38a3e8d8b..e2cc5b4c5ff49b 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -241,7 +241,7 @@ class ConvertConditionOpTypes : public OpConversionPattern { for (Value operand : adaptor.getOperands()) unpackUnrealizedConversionCast(operand, unpackedYield); - rewriter.updateRootInPlace(op, [&]() { op->setOperands(unpackedYield); }); + rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); }); return success(); } }; diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 38e0625d7ce093..5c9b5281468fc7 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -692,7 +692,7 @@ void mlir::scf::yieldReplacementForFusedProducer( sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); unsigned resultNumber = fusableProducer.getResultNumber(); - rewriter.updateRootInPlace(tiledDestStyleOp, [&]() { + rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); }); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp index c22cb6710a7e5d..354db6467a582b 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp @@ -91,8 +91,8 @@ class SPIRVPassThroughConversion : public OpConversionPattern { LogicalResult matchAndRewrite(OpT op, typename OpT::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.updateRootInPlace(op, - [&] { op->setOperands(adaptor.getOperands()); }); + rewriter.modifyOpInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); return success(); } }; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 9f2755da092293..6150b5ee17851d 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -261,7 +261,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite( return failure(); // Creates a new function with the update signature. - rewriter.updateRootInPlace(funcOp, [&] { + rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(rewriter.getFunctionType( signatureConverter.getConvertedTypes(), std::nullopt)); }); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp index c8e77f7de48300..d33eb9d2877ae3 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp @@ -29,7 +29,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op, // Clones the original operation but changing the output to an unordered COO. Operation *cloned = rewriter.clone(*op.getOperation()); - rewriter.updateRootInPlace(cloned, [cloned, srcCOOTp]() { + rewriter.modifyOpInPlace(cloned, [cloned, srcCOOTp]() { cloned->getOpResult(0).setType(srcCOOTp); }); Value srcCOO = cloned->getOpResult(0); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp index 50713be8296fa8..a0f7b55ce4446f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp @@ -389,14 +389,14 @@ struct GenericOpReinterpretMap auto stt = tryGetSparseTensorType(res); auto [idxMap, itTp] = *transMap; - rewriter.startRootUpdate(linalgOp); + rewriter.startOpModification(linalgOp); linalgOp.setIndexingMapsAttr(idxMap); linalgOp.setIteratorTypesAttr(itTp); // Use demapped arguments. linalgOp.getInputsMutable().assign(adaptor.getInputs()); linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs()); res.setType(adaptor.getOutputs()[0].getType()); - rewriter.finalizeRootUpdate(linalgOp); + rewriter.finalizeOpModification(linalgOp); rewriter.setInsertionPointAfter(linalgOp); if (stt && stt->hasEncoding()) { @@ -458,7 +458,7 @@ struct GenericOpScheduler : public OpRewritePattern { } // Marks the GenericOp to avoid recursive matching. - rewriter.updateRootInPlace(linalgOp, [&]() { + rewriter.modifyOpInPlace(linalgOp, [&]() { linalgOp->setAttr(sorted, rewriter.getBoolAttr(true)); }); @@ -482,10 +482,10 @@ struct GenericOpScheduler : public OpRewritePattern { for (AffineMap &idxMap : idxMaps) idxMap = idxMap.compose(order); // sorted loop -> lvl map - rewriter.startRootUpdate(linalgOp); + rewriter.startOpModification(linalgOp); linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps)); linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes)); - rewriter.finalizeRootUpdate(linalgOp); + rewriter.finalizeOpModification(linalgOp); return success(); } @@ -570,7 +570,7 @@ struct GenericOpScheduler : public OpRewritePattern { rewriter.setInsertionPoint(linalgOp); RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType(); Value dst = rewriter.create(tval.getLoc(), dstTp, tval); - rewriter.updateRootInPlace(linalgOp, [&]() { + rewriter.modifyOpInPlace(linalgOp, [&]() { linalgOp->setOperand(t->getOperandNumber(), dst); }); return success(); @@ -623,10 +623,10 @@ struct TensorAllocDemapper : public OpRewritePattern { } assert(dynSz.empty()); // should have consumed all. - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); op->setOperands(dynLvlSzs); op.getResult().setType(stt.getDemappedType()); - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); rewriter.setInsertionPointAfter(op); Value t = genRemap(rewriter, stt.getEncoding(), op.getResult()); @@ -676,7 +676,7 @@ struct ForeachOpDemapper auto srcStt = getSparseTensorType(op.getTensor()); SmallVector prevRetTps(op.getResultTypes()); - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); op.getTensorMutable().assign(adaptor.getTensor()); op.getInitArgsMutable().assign(adaptor.getInitArgs()); // Update results' types. @@ -731,7 +731,7 @@ struct ForeachOpDemapper rewriter.eraseOp(yield); } } - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); rewriter.setInsertionPointAfter(op); SmallVector outs = diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index fa97e405584791..b1b8b762d164d5 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -329,7 +329,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern { .getCopy(); AllocTensorOp a = op.getDpsInitOperand(0)->get().getDefiningOp(); - rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(init); }); + rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); }); } // Replace consumer with fused operation. Old producer // and consumer ops will be removed by DCE. @@ -366,7 +366,7 @@ struct FuseTensorCast : public OpRewritePattern { if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) { if (Operation *def = op.getSource().getDefiningOp()) { if (def->hasOneUse() && isa(def)) { - rewriter.updateRootInPlace(def, [&]() { + rewriter.modifyOpInPlace(def, [&]() { def->getResult(0).setType(op->getResultTypes()[0]); }); rewriter.replaceOp(op, def->getResult(0)); @@ -804,7 +804,7 @@ struct ReshapeRewriter : public OpRewritePattern { auto denseTp = RankedTensorType::get(rtp.getShape(), rtp.getElementType()); auto convert = rewriter.create(loc, denseTp, op.getSrc()); - rewriter.updateRootInPlace(op, [&]() { op->setOperand(0, convert); }); + rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); }); return success(); } if (encDst) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index 7710a44a7ca052..3a487a3bd6a069 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -545,7 +545,7 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl, forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName())); rewriter.setInsertionPointToStart(forOpNew.getBody()); } else { - rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); }); + rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); }); rewriter.setInsertionPoint(yield); } vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(), diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 5834426cae2f41..fec23d2a72347f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -583,7 +583,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block, if (def->getBlock() == block) { rewriter.setInsertionPoint(def); for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) { - rewriter.updateRootInPlace(def, [&]() { + rewriter.modifyOpInPlace(def, [&]() { def->setOperand( i, relinkBranch(env, rewriter, block, def->getOperand(i))); }); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp index 80dad064676220..3d8cc5222b828b 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp @@ -1416,7 +1416,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc, Operation *newRed = rewriter.clone(*redExp); // Replaces arguments of the reduction expression by using the block // arguments from scf.reduce. - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( newRed, [&]() { newRed->setOperands(redBlock->getArguments()); }); // Erases the out-dated reduction expression. rewriter.eraseOp(redExp); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 816e6ba8fed94e..b2fe58099b2fb3 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -819,7 +819,7 @@ struct DimOfDestStyleOp : public OpRewritePattern { auto resultIndex = source.cast().getResultNumber(); auto initOperand = destOp.getDpsInitOperand(resultIndex); - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); }); return success(); } @@ -1752,7 +1752,7 @@ struct FoldCollapseOfCastOp : public OpRewritePattern { srcType, collapseShapeOp.getReassociationMaps()); if (newResultType == collapseShapeOp.getResultType()) { - rewriter.updateRootInPlace(collapseShapeOp, [&]() { + rewriter.modifyOpInPlace(collapseShapeOp, [&]() { collapseShapeOp.getSrcMutable().assign(castOp.getSource()); }); } else { @@ -2930,7 +2930,7 @@ struct FoldSourceTensorCast : public OpRewritePattern { padTensorOp.getResultType().getShape()); if (newResultType == padTensorOp.getResultType()) { - rewriter.updateRootInPlace(padTensorOp, [&]() { + rewriter.modifyOpInPlace(padTensorOp, [&]() { padTensorOp.getSourceMutable().assign(castOp.getSource()); }); } else { @@ -3994,9 +3994,9 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { // Fold optional PaddingValue operand away if padding is not needed. if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) { - rewriter.startRootUpdate(packOp); + rewriter.startOpModification(packOp); packOp.getPaddingValueMutable().clear(); - rewriter.finalizeRootUpdate(packOp); + rewriter.finalizeOpModification(packOp); return success(); } return failure(); @@ -4166,8 +4166,8 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp, unPackOp.getDest().getDefiningOp()) { auto destValue = unPackOp.getDest().cast(); Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()]; - rewriter.updateRootInPlace( - unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); }); + rewriter.modifyOpInPlace(unPackOp, + [&]() { unPackOp.setDpsInitOperand(0, newDest); }); return success(); } return failure(); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 26c39ff3523434..744ab4154fe8a9 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -66,7 +66,7 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) { auto notOp = op.getPred().getDefiningOp(); if (!notOp) return failure(); - rewriter.updateRootInPlace(op, [&]() { + rewriter.modifyOpInPlace(op, [&]() { op.getOperation()->setOperands( {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()}); }); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index f257728a7b947c..749eb56b3d3bec 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4416,7 +4416,7 @@ class FoldWaw final : public OpRewritePattern { writeOp.getSource().getDefiningOp(); while (defWrite) { if (checkSameValueWAW(writeOp, defWrite)) { - rewriter.updateRootInPlace(writeToModify, [&]() { + rewriter.modifyOpInPlace(writeToModify, [&]() { writeToModify.getSourceMutable().assign(defWrite.getSource()); }); return success(); @@ -4533,7 +4533,7 @@ struct SwapExtractSliceOfTransferWrite transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(), transferOp.getIndices(), transferOp.getPermutationMapAttr(), rewriter.getBoolArrayAttr(newInBounds)); - rewriter.updateRootInPlace(insertOp, [&]() { + rewriter.modifyOpInPlace(insertOp, [&]() { insertOp.getSourceMutable().assign(newTransferWriteOp.getResult()); }); return success(); diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp index 5782ee1d58cf53..1caec5bb8644f3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -225,7 +225,7 @@ struct MaskOpInterface newReturnValues[it.index()] = it.value(); } } - rewriter.updateRootInPlace(yieldOp, [&]() { + rewriter.modifyOpInPlace(yieldOp, [&]() { yieldOp.getOperandsMutable().assign(newYieldedValues); }); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 9d5ad20d4715b1..620ceee48b196d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -182,7 +182,7 @@ static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns( auto yield = cast(newOpBody.getBlocks().begin()->getTerminator()); - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); }); return newWarpOp; } @@ -724,7 +724,7 @@ struct WarpOpConstant : public OpRewritePattern { return failure(); // Notify the rewriter that the warp op is changing (see the comment on // the WarpOpTransferRead pattern). - rewriter.startRootUpdate(warpOp); + rewriter.startOpModification(warpOp); unsigned operandIndex = yieldOperand->getOperandNumber(); Attribute scalarAttr = dense.getSplatValue(); auto newAttr = DenseElementsAttr::get( @@ -733,7 +733,7 @@ struct WarpOpConstant : public OpRewritePattern { rewriter.setInsertionPointAfter(warpOp); Value distConstant = rewriter.create(loc, newAttr); rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant); - rewriter.finalizeRootUpdate(warpOp); + rewriter.finalizeOpModification(warpOp); return success(); } }; @@ -1017,9 +1017,9 @@ struct WarpOpForwardOperand : public OpRewritePattern { return failure(); // Notify the rewriter that the warp op is changing (see the comment on // the WarpOpTransferRead pattern). - rewriter.startRootUpdate(warpOp); + rewriter.startOpModification(warpOp); rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded); - rewriter.finalizeRootUpdate(warpOp); + rewriter.finalizeOpModification(warpOp); return success(); } }; @@ -1159,7 +1159,7 @@ struct WarpOpCreateMask : public OpRewritePattern { // Notify the rewriter that the warp op is changing (see the comment on // the WarpOpTransferRead pattern). - rewriter.startRootUpdate(warpOp); + rewriter.startOpModification(warpOp); AffineExpr s0, s1; bindSymbols(rewriter.getContext(), s0, s1); @@ -1179,7 +1179,7 @@ struct WarpOpCreateMask : public OpRewritePattern { auto newMask = rewriter.create(loc, distType, newOperands); rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask); - rewriter.finalizeRootUpdate(warpOp); + rewriter.finalizeOpModification(warpOp); return success(); } }; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp index ea33453e7215e3..f1a27168bd4e54 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -525,7 +525,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer( SmallVector bools(xferOp.getTransferRank(), true); auto inBoundsAttr = b.getBoolArrayAttr(bools); if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) { - b.updateRootInPlace(xferOp, [&]() { + b.modifyOpInPlace(xferOp, [&]() { xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr); }); return success(); @@ -598,7 +598,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer( for (unsigned i = 0, e = returnTypes.size(); i != e; ++i) xferReadOp.setOperand(i, fullPartialIfOp.getResult(i)); - b.updateRootInPlace(xferOp, [&]() { + b.modifyOpInPlace(xferOp, [&]() { xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr); }); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 661674dd74c0cd..bd02c07981466d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1050,7 +1050,7 @@ struct MaterializeTransferMask : public OpRewritePattern { mask = rewriter.create(loc, mask, xferOp.getMask()); } - rewriter.updateRootInPlace(xferOp, [&]() { + rewriter.modifyOpInPlace(xferOp, [&]() { xferOp.getMaskMutable().assign(mask); xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); }); diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 5e788cdb4897d3..73f232fd0de01a 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -263,7 +263,7 @@ void RewriterBase::eraseBlock(Block *block) { block->erase(); } -void RewriterBase::finalizeRootUpdate(Operation *op) { +void RewriterBase::finalizeOpModification(Operation *op) { // Notify the listener that the operation was modified. if (auto *rewriteListener = dyn_cast_if_present(listener)) rewriteListener->notifyOperationModified(op); @@ -276,7 +276,7 @@ void RewriterBase::replaceUsesWithIf(Value from, Value to, function_ref functor) { for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) { if (functor(operand)) - updateRootInPlace(operand.getOwner(), [&]() { operand.set(to); }); + modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); }); } } diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp index 26a7ea5d5e219e..f3a973d9994083 100644 --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -506,7 +506,7 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region, if (info.mergePoints.contains(blockOperand.get())) { if (!job.reachingDef) job.reachingDef = getLazyDefaultValue(); - rewriter.updateRootInPlace(terminator, [&]() { + rewriter.modifyOpInPlace(terminator, [&]() { terminator.getSuccessorOperands(blockOperand.getOperandNumber()) .append(job.reachingDef); }); @@ -596,7 +596,7 @@ void MemorySlotPromoter::promoteSlot() { assert(succOperands.size() == mergePoint->getNumArguments() || succOperands.size() + 1 == mergePoint->getNumArguments()); if (succOperands.size() + 1 == mergePoint->getNumArguments()) - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( user, [&]() { succOperands.append(getLazyDefaultValue()); }); } } diff --git a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp index 6034366631d10f..5ba6e4747cb57f 100644 --- a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp +++ b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp @@ -304,7 +304,7 @@ class SortCommutativeOperands : public RewritePattern { sortedOperands.push_back(commOperand->operand); if (sortedOperands == operands) return failure(); - rewriter.updateRootInPlace(op, [&] { op->setOperands(sortedOperands); }); + rewriter.modifyOpInPlace(op, [&] { op->setOperands(sortedOperands); }); return success(); } }; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 85433d088dcbf0..ef6a49455d1860 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1614,15 +1614,15 @@ void ConversionPatternRewriter::notifyOperationInserted(Operation *op) { impl->createdOps.push_back(op); } -void ConversionPatternRewriter::startRootUpdate(Operation *op) { +void ConversionPatternRewriter::startOpModification(Operation *op) { #ifndef NDEBUG impl->pendingRootUpdates.insert(op); #endif impl->rootUpdates.emplace_back(op); } -void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) { - PatternRewriter::finalizeRootUpdate(op); +void ConversionPatternRewriter::finalizeOpModification(Operation *op) { + PatternRewriter::finalizeOpModification(op); // There is nothing to do here, we only need to track the operation at the // start of the update. #ifndef NDEBUG @@ -1631,7 +1631,7 @@ void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) { #endif } -void ConversionPatternRewriter::cancelRootUpdate(Operation *op) { +void ConversionPatternRewriter::cancelOpModification(Operation *op) { #ifndef NDEBUG assert(impl->pendingRootUpdates.erase(op) && "operation did not have a pending in-place update"); @@ -3115,7 +3115,7 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, auto newType = FunctionType::get(rewriter.getContext(), result.getConvertedTypes(), newResults); - rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(newType); }); + rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); }); return success(); } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 2e3bc76009ca20..d1ac5e81e75a69 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -160,7 +160,7 @@ struct IncrementIntAttribute : public OpRewritePattern { int64_t val = intAttr.getInt(); if (val >= MaxVal) return failure(); - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); }); return success(); } @@ -175,7 +175,7 @@ struct MakeOpEligible : public RewritePattern { PatternRewriter &rewriter) const override { if (op->hasAttr("eligible")) return failure(); - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); }); return success(); } @@ -195,7 +195,7 @@ struct HoistEligibleOps : public OpRewritePattern { return failure(); // Hoisting means removing an op from the enclosing op. I.e., the enclosing // op is modified. - rewriter.updateRootInPlace(op, [&]() { toBeHoisted->moveBefore(op); }); + rewriter.modifyOpInPlace(op, [&]() { toBeHoisted->moveBefore(op); }); return success(); } }; @@ -327,7 +327,7 @@ struct TestStrictPatternDriver Operation *newOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), op->getOperands(), op->getResultTypes()); - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); }); newOp->setAttr("skip", rewriter.getBoolAttr(true)); @@ -415,8 +415,8 @@ struct TestStrictPatternDriver PatternRewriter &rewriter) const override { if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock()) return failure(); - rewriter.updateRootInPlace( - op, [&]() { op->setSuccessor(op->getBlock(), 0); }); + rewriter.modifyOpInPlace(op, + [&]() { op->setSuccessor(op->getBlock(), 0); }); return success(); } }; @@ -650,7 +650,7 @@ struct TestUndoBlockArgReplace : public ConversionPattern { rewriter.create(op->getLoc(), rewriter.getF32Type()); rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), illegalOp->getResult(0)); - rewriter.updateRootInPlace(op, [] {}); + rewriter.modifyOpInPlace(op, [] {}); return success(); } }; @@ -667,7 +667,7 @@ struct TestUndoBlockErase : public ConversionPattern { rewriter.setInsertionPointToStart(secondBlock); rewriter.create(op->getLoc(), rewriter.getF32Type()); rewriter.eraseBlock(secondBlock); - rewriter.updateRootInPlace(op, [] {}); + rewriter.modifyOpInPlace(op, [] {}); return success(); } }; @@ -827,7 +827,7 @@ struct TestBoundedRecursiveRewrite LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, PatternRewriter &rewriter) const final { // Decrement the depth of the op in-place. - rewriter.updateRootInPlace(op, [&] { + rewriter.modifyOpInPlace(op, [&] { op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); }); return success(); @@ -1333,7 +1333,7 @@ struct TestTestSignatureConversionNoConverter if (failed( converter.convertSignatureArgs(entry->getArgumentTypes(), result))) return failure(); - rewriter.updateRootInPlace( + rewriter.modifyOpInPlace( op, [&] { rewriter.applySignatureConversion(®ion, result); }); return success(); } @@ -1350,8 +1350,8 @@ struct TestTypeConsumerForward LogicalResult matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - rewriter.updateRootInPlace(op, - [&] { op->setOperands(adaptor.getOperands()); }); + rewriter.modifyOpInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); return success(); } }; @@ -1567,7 +1567,7 @@ struct TestMergeBlock : public OpConversionPattern { SmallVector replacements(succOperands); rewriter.eraseOp(branchOp); rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); - rewriter.updateRootInPlace(op, [] {}); + rewriter.modifyOpInPlace(op, [] {}); return success(); } }; @@ -1588,7 +1588,7 @@ struct TestUndoBlocksMerge : public ConversionPattern { SmallVector replacements(succOperands); rewriter.eraseOp(branchOp); rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); - rewriter.updateRootInPlace(op, [] {}); + rewriter.modifyOpInPlace(op, [] {}); return success(); } }; @@ -1613,7 +1613,7 @@ struct TestMergeSingleBlockOps rewriter.inlineBlockBefore(&innerBlock, op); rewriter.eraseOp(innerTerminator); rewriter.eraseOp(op); - rewriter.updateRootInPlace(op, [] {}); + rewriter.modifyOpInPlace(op, [] {}); return success(); } };