Skip to content

Commit

Permalink
[mlir][IR] Rename "update root" to "modify op" in rewriter API (#78260)
Browse files Browse the repository at this point in the history
This commit renames 4 pattern rewriter API functions:
* `updateRootInPlace` -> `modifyOpInPlace`
* `startRootUpdate` -> `startOpModification`
* `finalizeRootUpdate` -> `finalizeOpModification`
* `cancelRootUpdate` -> `cancelOpModification`

The term "root" is a misnomer. The root is the op that a rewrite pattern
matches against
(https://mlir.llvm.org/docs/PatternRewriter/#root-operation-name-optional).
A rewriter must be notified of all in-place op modifications, not just
in-place modifications of the root
(https://mlir.llvm.org/docs/PatternRewriter/#pattern-rewriter). The old
function names were confusing and have contributed to various broken
rewrite patterns.

Note: The new function names use the term "modify" instead of "update"
for consistency with the `RewriterBase::Listener` terminology
(`notifyOperationModified`).
  • Loading branch information
matthias-springer committed Jan 17, 2024
1 parent 57b50ef commit 5fcf907
Show file tree
Hide file tree
Showing 78 changed files with 246 additions and 246 deletions.
20 changes: 10 additions & 10 deletions flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,14 @@ class BoxedProcedurePass
rewriter.replaceOpWithNewOp<ConvertOp>(
addr, typeConverter.convertType(addr.getType()), addr.getVal());
} else if (typeConverter.needsConversion(resTy)) {
rewriter.startRootUpdate(op);
rewriter.startOpModification(op);
op->getResult(0).setType(typeConverter.convertType(resTy));
rewriter.finalizeRootUpdate(op);
rewriter.finalizeOpModification(op);
}
} else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
mlir::FunctionType ty = func.getFunctionType();
if (typeConverter.needsConversion(ty)) {
rewriter.startRootUpdate(func);
rewriter.startOpModification(func);
auto toTy =
typeConverter.convertType(ty).cast<mlir::FunctionType>();
if (!func.empty())
Expand All @@ -235,7 +235,7 @@ class BoxedProcedurePass
block.eraseArgument(i + 1);
}
func.setType(toTy);
rewriter.finalizeRootUpdate(func);
rewriter.finalizeOpModification(func);
}
} else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
// Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
Expand Down Expand Up @@ -273,10 +273,10 @@ class BoxedProcedurePass
} else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
auto ty = global.getType();
if (typeConverter.needsConversion(ty)) {
rewriter.startRootUpdate(global);
rewriter.startOpModification(global);
auto toTy = typeConverter.convertType(ty);
global.setType(toTy);
rewriter.finalizeRootUpdate(global);
rewriter.finalizeOpModification(global);
}
} else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
auto ty = mem.getType();
Expand Down Expand Up @@ -339,17 +339,17 @@ class BoxedProcedurePass
mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
}
} else if (op->getDialect() == firDialect) {
rewriter.startRootUpdate(op);
rewriter.startOpModification(op);
for (auto i : llvm::enumerate(op->getResultTypes()))
if (typeConverter.needsConversion(i.value())) {
auto toTy = typeConverter.convertType(i.value());
op->getResult(i.index()).setType(toTy);
}
rewriter.finalizeRootUpdate(op);
rewriter.finalizeOpModification(op);
}
// Ensure block arguments are updated if needed.
if (op->getNumRegions() != 0) {
rewriter.startRootUpdate(op);
rewriter.startOpModification(op);
for (mlir::Region &region : op->getRegions())
for (mlir::Block &block : region.getBlocks())
for (mlir::BlockArgument blockArg : block.getArguments())
Expand All @@ -358,7 +358,7 @@ class BoxedProcedurePass
typeConverter.convertType(blockArg.getType());
blockArg.setType(toTy);
}
rewriter.finalizeRootUpdate(op);
rewriter.finalizeOpModification(op);
}
});
}
Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3763,13 +3763,13 @@ class RenameMSVCLibmCallees
mlir::LogicalResult
matchAndRewrite(mlir::LLVM::CallOp op,
mlir::PatternRewriter &rewriter) const override {
rewriter.startRootUpdate(op);
rewriter.startOpModification(op);
auto callee = op.getCallee();
if (callee)
if (callee->equals("hypotf"))
op.setCalleeAttr(mlir::SymbolRefAttr::get(op.getContext(), "_hypotf"));

rewriter.finalizeRootUpdate(op);
rewriter.finalizeOpModification(op);
return mlir::success();
}
};
Expand All @@ -3782,10 +3782,10 @@ class RenameMSVCLibmFuncs
mlir::LogicalResult
matchAndRewrite(mlir::LLVM::LLVMFuncOp op,
mlir::PatternRewriter &rewriter) const override {
rewriter.startRootUpdate(op);
rewriter.startOpModification(op);
if (op.getSymName().equals("hypotf"))
op.setSymNameAttr(rewriter.getStringAttr("_hypotf"));
rewriter.finalizeRootUpdate(op);
rewriter.finalizeOpModification(op);
return mlir::success();
}
};
Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,9 @@ struct AssignOpConversion : public mlir::OpConversionPattern<hlfir::AssignOp> {
llvm::SmallVector<mlir::Value> newOperands;
for (mlir::Value operand : adaptor.getOperands())
newOperands.push_back(getBufferizedExprStorage(operand));
rewriter.startRootUpdate(assign);
rewriter.startOpModification(assign);
assign->setOperands(newOperands);
rewriter.finalizeRootUpdate(assign);
rewriter.finalizeOpModification(assign);
return mlir::success();
}
};
Expand Down Expand Up @@ -834,9 +834,9 @@ struct ElementalOpConversion
// Explicitly delete the body of the elemental to get rid
// of any users of hlfir.expr values inside the body as early
// as possible.
rewriter.startRootUpdate(elemental);
rewriter.startOpModification(elemental);
rewriter.eraseBlock(elemental.getBody());
rewriter.finalizeRootUpdate(elemental);
rewriter.finalizeOpModification(elemental);
rewriter.replaceOp(elemental, bufferizedExpr);
return mlir::success();
}
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/Transforms/AffineDemotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> {
op.getValue());
return success();
}
rewriter.startRootUpdate(op->getParentOp());
rewriter.startOpModification(op->getParentOp());
op.getResult().replaceAllUsesWith(op.getValue());
rewriter.finalizeRootUpdate(op->getParentOp());
rewriter.finalizeOpModification(op->getParentOp());
rewriter.eraseOp(op);
}
return success();
Expand Down
12 changes: 6 additions & 6 deletions flang/lib/Optimizer/Transforms/AffinePromotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,15 +464,15 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
auto affineFor = loopAndIndex.first;
auto inductionVar = loopAndIndex.second;

rewriter.startRootUpdate(affineFor.getOperation());
rewriter.startOpModification(affineFor.getOperation());
affineFor.getBody()->getOperations().splice(
std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
std::prev(loopOps.end()));
rewriter.finalizeRootUpdate(affineFor.getOperation());
rewriter.finalizeOpModification(affineFor.getOperation());

rewriter.startRootUpdate(loop.getOperation());
rewriter.startOpModification(loop.getOperation());
loop.getInductionVar().replaceAllUsesWith(inductionVar);
rewriter.finalizeRootUpdate(loop.getOperation());
rewriter.finalizeOpModification(loop.getOperation());

rewriteMemoryOps(affineFor.getBody(), rewriter);

Expand Down Expand Up @@ -561,7 +561,7 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
auto affineIf = rewriter.create<affine::AffineIfOp>(
op.getLoc(), affineCondition.getIntegerSet(),
affineCondition.getAffineArgs(), !op.getElseRegion().empty());
rewriter.startRootUpdate(affineIf);
rewriter.startOpModification(affineIf);
affineIf.getThenBlock()->getOperations().splice(
std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(),
std::prev(ifOps.end()));
Expand All @@ -571,7 +571,7 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(),
std::prev(otherOps.end()));
}
rewriter.finalizeRootUpdate(affineIf);
rewriter.finalizeOpModification(affineIf);
rewriteMemoryOps(affineIf.getBody(), rewriter);

LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n";
Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ struct MangleNameOnFuncOp : public mlir::OpRewritePattern<mlir::func::FuncOp> {
matchAndRewrite(mlir::func::FuncOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::LogicalResult ret = success();
rewriter.startRootUpdate(op);
rewriter.startOpModification(op);
llvm::StringRef oldName = op.getSymName();
auto result = fir::NameUniquer::deconstruct(oldName);
if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
Expand All @@ -95,7 +95,7 @@ struct MangleNameOnFuncOp : public mlir::OpRewritePattern<mlir::func::FuncOp> {
}

updateEarlyOutliningParentName(op, appendUnderscore);
rewriter.finalizeRootUpdate(op);
rewriter.finalizeOpModification(op);
return ret;
}

Expand All @@ -114,15 +114,15 @@ struct MangleNameForCommonBlock : public mlir::OpRewritePattern<fir::GlobalOp> {
mlir::LogicalResult
matchAndRewrite(fir::GlobalOp op,
mlir::PatternRewriter &rewriter) const override {
rewriter.startRootUpdate(op);
rewriter.startOpModification(op);
auto result = fir::NameUniquer::deconstruct(
op.getSymref().getRootReference().getValue());
if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
auto newName = mangleExternalName(result, appendUnderscore);
op.setSymrefAttr(mlir::SymbolRefAttr::get(op.getContext(), newName));
SymbolTable::setSymbolName(op, newName);
}
rewriter.finalizeRootUpdate(op);
rewriter.finalizeOpModification(op);
return success();
}

Expand Down
10 changes: 5 additions & 5 deletions mlir/docs/PatternRewriter.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,15 @@ user is determined by the specific pattern driver.
This method replaces an operation's results with a set of provided values, and
erases the operation.
* Update an Operation in-place : `(start|cancel|finalize)RootUpdate`
* Update an Operation in-place : `(start|cancel|finalize)OpModification`
This is a collection of methods that provide a transaction-like API for updating
the attributes, location, operands, or successors of an operation in-place
within a pattern. An in-place update transaction is started with
`startRootUpdate`, and may either be canceled or finalized with
`cancelRootUpdate` and `finalizeRootUpdate` respectively. A convenience wrapper,
`updateRootInPlace`, is provided that wraps a `start` and `finalize` around a
callback.
`startOpModification`, and may either be canceled or finalized with
`cancelOpModification` and `finalizeOpModification` respectively. A convenience
wrapper, `modifyOpInPlace`, is provided that wraps a `start` and `finalize`
around a callback.
* OpBuilder API
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class StandaloneSwitchBarFooRewriter : public OpRewritePattern<func::FuncOp> {
LogicalResult matchAndRewrite(func::FuncOp op,
PatternRewriter &rewriter) const final {
if (op.getSymName() == "bar") {
rewriter.updateRootInPlace(op, [&op]() { op.setSymName("foo"); });
rewriter.modifyOpInPlace(op, [&op]() { op.setSymName("foo"); });
return success();
}
return failure();
Expand Down
4 changes: 2 additions & 2 deletions mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
ConversionPatternRewriter &rewriter) const final {
// We don't lower "toy.print" in this pass, but we need to update its
// operands.
rewriter.updateRootInPlace(op,
[&] { op->setOperands(adaptor.getOperands()); });
rewriter.modifyOpInPlace(op,
[&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
Expand Down
4 changes: 2 additions & 2 deletions mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
ConversionPatternRewriter &rewriter) const final {
// We don't lower "toy.print" in this pass, but we need to update its
// operands.
rewriter.updateRootInPlace(op,
[&] { op->setOperands(adaptor.getOperands()); });
rewriter.modifyOpInPlace(op,
[&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
Expand Down
4 changes: 2 additions & 2 deletions mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
ConversionPatternRewriter &rewriter) const final {
// We don't lower "toy.print" in this pass, but we need to update its
// operands.
rewriter.updateRootInPlace(op,
[&] { op->setOperands(adaptor.getOperands()); });
rewriter.modifyOpInPlace(op,
[&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
Expand Down
44 changes: 23 additions & 21 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,28 +585,30 @@ class RewriterBase : public OpBuilder {

/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
/// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
/// This is a minor efficiency win (it avoids creating a new operation and
/// removing the old one) but also often allows simpler code in the client.
virtual void startRootUpdate(Operation *op) {}

/// This method is used to signal the end of a root update on the given
/// operation. This can only be called on operations that were provided to a
/// call to `startRootUpdate`.
virtual void finalizeRootUpdate(Operation *op);

/// This method cancels a pending root update. This can only be called on
/// operations that were provided to a call to `startRootUpdate`.
virtual void cancelRootUpdate(Operation *op) {}

/// This method is a utility wrapper around a root update of an operation. It
/// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
/// callable.
/// followed by a call to either `finalizeOpModification` or
/// `cancelOpModification`. This is a minor efficiency win (it avoids creating
/// a new operation and removing the old one) but also often allows simpler
/// code in the client.
virtual void startOpModification(Operation *op) {}

/// This method is used to signal the end of an in-place modification of the
/// given operation. This can only be called on operations that were provided
/// to a call to `startOpModification`.
virtual void finalizeOpModification(Operation *op);

/// This method cancels a pending in-place modification. This can only be
/// called on operations that were provided to a call to
/// `startOpModification`.
virtual void cancelOpModification(Operation *op) {}

/// This method is a utility wrapper around an in-place modification of an
/// operation. It wraps calls to `startOpModification` and
/// `finalizeOpModification` around the given callable.
template <typename CallableT>
void updateRootInPlace(Operation *root, CallableT &&callable) {
startRootUpdate(root);
void modifyOpInPlace(Operation *root, CallableT &&callable) {
startOpModification(root);
callable();
finalizeRootUpdate(root);
finalizeOpModification(root);
}

/// Find uses of `from` and replace them with `to`. It also marks every
Expand All @@ -619,7 +621,7 @@ class RewriterBase : public OpBuilder {
void replaceAllUsesWith(IRObjectWithUseList<OperandType> *from, ValueT &&to) {
for (OperandType &operand : llvm::make_early_inc_range(from->getUses())) {
Operation *op = operand.getOwner();
updateRootInPlace(op, [&]() { operand.set(to); });
modifyOpInPlace(op, [&]() { operand.set(to); });
}
}
void replaceAllUsesWith(ValueRange from, ValueRange to) {
Expand Down
14 changes: 7 additions & 7 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -739,17 +739,17 @@ class ConversionPatternRewriter final : public PatternRewriter,
/// PatternRewriter hook for inserting a new operation.
void notifyOperationInserted(Operation *op) override;

/// PatternRewriter hook for updating the root operation in-place.
/// Note: These methods only track updates to the top-level operation itself,
/// PatternRewriter hook for updating the given operation in-place.
/// Note: These methods only track updates to the given operation itself,
/// and not nested regions. Updates to regions will still require notification
/// through other more specific hooks above.
void startRootUpdate(Operation *op) override;
void startOpModification(Operation *op) override;

/// PatternRewriter hook for updating the root operation in-place.
void finalizeRootUpdate(Operation *op) override;
/// PatternRewriter hook for updating the given operation in-place.
void finalizeOpModification(Operation *op) override;

/// PatternRewriter hook for updating the root operation in-place.
void cancelRootUpdate(Operation *op) override;
/// PatternRewriter hook for updating the given operation in-place.
void cancelOpModification(Operation *op) override;

/// PatternRewriter hook for notifying match failure reasons.
LogicalResult
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
// Step 2. Assign the op a real tile ID.
// For simplicity, we always use tile 0 (which always exists).
auto zeroTileId = rewriter.getI32IntegerAttr(0);
rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });

VectorType tileVectorType = tileOp.getTileType();
auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,8 +918,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
for (auto stream : streams)
streamDestroyCallBuilder.create(loc, rewriter, {stream});

rewriter.updateRootInPlace(yieldOp,
[&] { yieldOp->setOperands(newOperands); });
rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
return success();
}

Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,13 @@ class ExpandIfCondition : public OpRewritePattern<OpTy> {
if (!matchPattern(op.getIfCond(), m_Constant(&constAttr))) {
auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
op.getIfCond(), false);
rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
auto thenBodyBuilder = ifOp.getThenBodyBuilder(rewriter.getListener());
thenBodyBuilder.clone(*op.getOperation());
rewriter.eraseOp(op);
} else {
if (constAttr.getInt())
rewriter.updateRootInPlace(op,
[&]() { op.getIfCondMutable().erase(0); });
rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
else
rewriter.eraseOp(op);
}
Expand Down
Loading

0 comments on commit 5fcf907

Please sign in to comment.