Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][IR] Rename "update root" to "modify op" in rewriter API #78260

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,14 @@ class BoxedProcedurePass
rewriter.replaceOpWithNewOp<ConvertOp>(
addr, typeConverter.convertType(addr.getType()), addr.getVal());
} else if (typeConverter.needsConversion(resTy)) {
rewriter.startRootUpdate(op);
rewriter.startOpModification(op);
op->getResult(0).setType(typeConverter.convertType(resTy));
rewriter.finalizeRootUpdate(op);
rewriter.finalizeOpModification(op);
}
} else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
mlir::FunctionType ty = func.getFunctionType();
if (typeConverter.needsConversion(ty)) {
rewriter.startRootUpdate(func);
rewriter.startOpModification(func);
auto toTy =
typeConverter.convertType(ty).cast<mlir::FunctionType>();
if (!func.empty())
Expand All @@ -235,7 +235,7 @@ class BoxedProcedurePass
block.eraseArgument(i + 1);
}
func.setType(toTy);
rewriter.finalizeRootUpdate(func);
rewriter.finalizeOpModification(func);
}
} else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
// Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
Expand Down Expand Up @@ -273,10 +273,10 @@ class BoxedProcedurePass
} else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
auto ty = global.getType();
if (typeConverter.needsConversion(ty)) {
rewriter.startRootUpdate(global);
rewriter.startOpModification(global);
auto toTy = typeConverter.convertType(ty);
global.setType(toTy);
rewriter.finalizeRootUpdate(global);
rewriter.finalizeOpModification(global);
}
} else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
auto ty = mem.getType();
Expand Down Expand Up @@ -339,17 +339,17 @@ class BoxedProcedurePass
mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
}
} else if (op->getDialect() == firDialect) {
rewriter.startRootUpdate(op);
rewriter.startOpModification(op);
for (auto i : llvm::enumerate(op->getResultTypes()))
if (typeConverter.needsConversion(i.value())) {
auto toTy = typeConverter.convertType(i.value());
op->getResult(i.index()).setType(toTy);
}
rewriter.finalizeRootUpdate(op);
rewriter.finalizeOpModification(op);
}
// Ensure block arguments are updated if needed.
if (op->getNumRegions() != 0) {
rewriter.startRootUpdate(op);
rewriter.startOpModification(op);
for (mlir::Region &region : op->getRegions())
for (mlir::Block &block : region.getBlocks())
for (mlir::BlockArgument blockArg : block.getArguments())
Expand All @@ -358,7 +358,7 @@ class BoxedProcedurePass
typeConverter.convertType(blockArg.getType());
blockArg.setType(toTy);
}
rewriter.finalizeRootUpdate(op);
rewriter.finalizeOpModification(op);
}
});
}
Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3763,13 +3763,13 @@ class RenameMSVCLibmCallees
mlir::LogicalResult
matchAndRewrite(mlir::LLVM::CallOp op,
mlir::PatternRewriter &rewriter) const override {
rewriter.startRootUpdate(op);
rewriter.startOpModification(op);
auto callee = op.getCallee();
if (callee)
if (callee->equals("hypotf"))
op.setCalleeAttr(mlir::SymbolRefAttr::get(op.getContext(), "_hypotf"));

rewriter.finalizeRootUpdate(op);
rewriter.finalizeOpModification(op);
return mlir::success();
}
};
Expand All @@ -3782,10 +3782,10 @@ class RenameMSVCLibmFuncs
mlir::LogicalResult
matchAndRewrite(mlir::LLVM::LLVMFuncOp op,
mlir::PatternRewriter &rewriter) const override {
rewriter.startRootUpdate(op);
rewriter.startOpModification(op);
if (op.getSymName().equals("hypotf"))
op.setSymNameAttr(rewriter.getStringAttr("_hypotf"));
rewriter.finalizeRootUpdate(op);
rewriter.finalizeOpModification(op);
return mlir::success();
}
};
Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,9 @@ struct AssignOpConversion : public mlir::OpConversionPattern<hlfir::AssignOp> {
llvm::SmallVector<mlir::Value> newOperands;
for (mlir::Value operand : adaptor.getOperands())
newOperands.push_back(getBufferizedExprStorage(operand));
rewriter.startRootUpdate(assign);
rewriter.startOpModification(assign);
assign->setOperands(newOperands);
rewriter.finalizeRootUpdate(assign);
rewriter.finalizeOpModification(assign);
return mlir::success();
}
};
Expand Down Expand Up @@ -834,9 +834,9 @@ struct ElementalOpConversion
// Explicitly delete the body of the elemental to get rid
// of any users of hlfir.expr values inside the body as early
// as possible.
rewriter.startRootUpdate(elemental);
rewriter.startOpModification(elemental);
rewriter.eraseBlock(elemental.getBody());
rewriter.finalizeRootUpdate(elemental);
rewriter.finalizeOpModification(elemental);
rewriter.replaceOp(elemental, bufferizedExpr);
return mlir::success();
}
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/Transforms/AffineDemotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> {
op.getValue());
return success();
}
rewriter.startRootUpdate(op->getParentOp());
rewriter.startOpModification(op->getParentOp());
op.getResult().replaceAllUsesWith(op.getValue());
rewriter.finalizeRootUpdate(op->getParentOp());
rewriter.finalizeOpModification(op->getParentOp());
rewriter.eraseOp(op);
}
return success();
Expand Down
12 changes: 6 additions & 6 deletions flang/lib/Optimizer/Transforms/AffinePromotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,15 +464,15 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
auto affineFor = loopAndIndex.first;
auto inductionVar = loopAndIndex.second;

rewriter.startRootUpdate(affineFor.getOperation());
rewriter.startOpModification(affineFor.getOperation());
affineFor.getBody()->getOperations().splice(
std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
std::prev(loopOps.end()));
rewriter.finalizeRootUpdate(affineFor.getOperation());
rewriter.finalizeOpModification(affineFor.getOperation());

rewriter.startRootUpdate(loop.getOperation());
rewriter.startOpModification(loop.getOperation());
loop.getInductionVar().replaceAllUsesWith(inductionVar);
rewriter.finalizeRootUpdate(loop.getOperation());
rewriter.finalizeOpModification(loop.getOperation());

rewriteMemoryOps(affineFor.getBody(), rewriter);

Expand Down Expand Up @@ -561,7 +561,7 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
auto affineIf = rewriter.create<affine::AffineIfOp>(
op.getLoc(), affineCondition.getIntegerSet(),
affineCondition.getAffineArgs(), !op.getElseRegion().empty());
rewriter.startRootUpdate(affineIf);
rewriter.startOpModification(affineIf);
affineIf.getThenBlock()->getOperations().splice(
std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(),
std::prev(ifOps.end()));
Expand All @@ -571,7 +571,7 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(),
std::prev(otherOps.end()));
}
rewriter.finalizeRootUpdate(affineIf);
rewriter.finalizeOpModification(affineIf);
rewriteMemoryOps(affineIf.getBody(), rewriter);

LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n";
Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ struct MangleNameOnFuncOp : public mlir::OpRewritePattern<mlir::func::FuncOp> {
matchAndRewrite(mlir::func::FuncOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::LogicalResult ret = success();
rewriter.startRootUpdate(op);
rewriter.startOpModification(op);
llvm::StringRef oldName = op.getSymName();
auto result = fir::NameUniquer::deconstruct(oldName);
if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
Expand All @@ -95,7 +95,7 @@ struct MangleNameOnFuncOp : public mlir::OpRewritePattern<mlir::func::FuncOp> {
}

updateEarlyOutliningParentName(op, appendUnderscore);
rewriter.finalizeRootUpdate(op);
rewriter.finalizeOpModification(op);
return ret;
}

Expand All @@ -114,15 +114,15 @@ struct MangleNameForCommonBlock : public mlir::OpRewritePattern<fir::GlobalOp> {
mlir::LogicalResult
matchAndRewrite(fir::GlobalOp op,
mlir::PatternRewriter &rewriter) const override {
rewriter.startRootUpdate(op);
rewriter.startOpModification(op);
auto result = fir::NameUniquer::deconstruct(
op.getSymref().getRootReference().getValue());
if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
auto newName = mangleExternalName(result, appendUnderscore);
op.setSymrefAttr(mlir::SymbolRefAttr::get(op.getContext(), newName));
SymbolTable::setSymbolName(op, newName);
}
rewriter.finalizeRootUpdate(op);
rewriter.finalizeOpModification(op);
return success();
}

Expand Down
10 changes: 5 additions & 5 deletions mlir/docs/PatternRewriter.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,15 @@ user is determined by the specific pattern driver.
This method replaces an operation's results with a set of provided values, and
erases the operation.

* Update an Operation in-place : `(start|cancel|finalize)RootUpdate`
* Update an Operation in-place : `(start|cancel|finalize)OpModification`

This is a collection of methods that provide a transaction-like API for updating
the attributes, location, operands, or successors of an operation in-place
within a pattern. An in-place update transaction is started with
`startRootUpdate`, and may either be canceled or finalized with
`cancelRootUpdate` and `finalizeRootUpdate` respectively. A convenience wrapper,
`updateRootInPlace`, is provided that wraps a `start` and `finalize` around a
callback.
`startOpModification`, and may either be canceled or finalized with
`cancelOpModification` and `finalizeOpModification` respectively. A convenience
wrapper, `modifyOpInPlace`, is provided that wraps a `start` and `finalize`
around a callback.

* OpBuilder API

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class StandaloneSwitchBarFooRewriter : public OpRewritePattern<func::FuncOp> {
LogicalResult matchAndRewrite(func::FuncOp op,
PatternRewriter &rewriter) const final {
if (op.getSymName() == "bar") {
rewriter.updateRootInPlace(op, [&op]() { op.setSymName("foo"); });
rewriter.modifyOpInPlace(op, [&op]() { op.setSymName("foo"); });
return success();
}
return failure();
Expand Down
4 changes: 2 additions & 2 deletions mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
ConversionPatternRewriter &rewriter) const final {
// We don't lower "toy.print" in this pass, but we need to update its
// operands.
rewriter.updateRootInPlace(op,
[&] { op->setOperands(adaptor.getOperands()); });
rewriter.modifyOpInPlace(op,
[&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
Expand Down
4 changes: 2 additions & 2 deletions mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
ConversionPatternRewriter &rewriter) const final {
// We don't lower "toy.print" in this pass, but we need to update its
// operands.
rewriter.updateRootInPlace(op,
[&] { op->setOperands(adaptor.getOperands()); });
rewriter.modifyOpInPlace(op,
[&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
Expand Down
4 changes: 2 additions & 2 deletions mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
ConversionPatternRewriter &rewriter) const final {
// We don't lower "toy.print" in this pass, but we need to update its
// operands.
rewriter.updateRootInPlace(op,
[&] { op->setOperands(adaptor.getOperands()); });
rewriter.modifyOpInPlace(op,
[&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
Expand Down
44 changes: 23 additions & 21 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,28 +585,30 @@ class RewriterBase : public OpBuilder {

/// This method is used to notify the rewriter that an in-place operation
/// modification is about to happen. A call to this function *must* be
/// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
/// This is a minor efficiency win (it avoids creating a new operation and
/// removing the old one) but also often allows simpler code in the client.
virtual void startRootUpdate(Operation *op) {}

/// This method is used to signal the end of a root update on the given
/// operation. This can only be called on operations that were provided to a
/// call to `startRootUpdate`.
virtual void finalizeRootUpdate(Operation *op);

/// This method cancels a pending root update. This can only be called on
/// operations that were provided to a call to `startRootUpdate`.
virtual void cancelRootUpdate(Operation *op) {}

/// This method is a utility wrapper around a root update of an operation. It
/// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
/// callable.
/// followed by a call to either `finalizeOpModification` or
/// `cancelOpModification`. This is a minor efficiency win (it avoids creating
/// a new operation and removing the old one) but also often allows simpler
/// code in the client.
virtual void startOpModification(Operation *op) {}

/// This method is used to signal the end of an in-place modification of the
/// given operation. This can only be called on operations that were provided
/// to a call to `startOpModification`.
virtual void finalizeOpModification(Operation *op);

/// This method cancels a pending in-place modification. This can only be
/// called on operations that were provided to a call to
/// `startOpModification`.
virtual void cancelOpModification(Operation *op) {}

/// This method is a utility wrapper around an in-place modification of an
/// operation. It wraps calls to `startOpModification` and
/// `finalizeOpModification` around the given callable.
template <typename CallableT>
void updateRootInPlace(Operation *root, CallableT &&callable) {
startRootUpdate(root);
void modifyOpInPlace(Operation *root, CallableT &&callable) {
startOpModification(root);
callable();
finalizeRootUpdate(root);
finalizeOpModification(root);
}

/// Find uses of `from` and replace them with `to`. It also marks every
Expand All @@ -619,7 +621,7 @@ class RewriterBase : public OpBuilder {
void replaceAllUsesWith(IRObjectWithUseList<OperandType> *from, ValueT &&to) {
for (OperandType &operand : llvm::make_early_inc_range(from->getUses())) {
Operation *op = operand.getOwner();
updateRootInPlace(op, [&]() { operand.set(to); });
modifyOpInPlace(op, [&]() { operand.set(to); });
}
}
void replaceAllUsesWith(ValueRange from, ValueRange to) {
Expand Down
14 changes: 7 additions & 7 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -739,17 +739,17 @@ class ConversionPatternRewriter final : public PatternRewriter,
/// PatternRewriter hook for inserting a new operation.
void notifyOperationInserted(Operation *op) override;

/// PatternRewriter hook for updating the root operation in-place.
/// Note: These methods only track updates to the top-level operation itself,
/// PatternRewriter hook for updating the given operation in-place.
/// Note: These methods only track updates to the given operation itself,
/// and not nested regions. Updates to regions will still require notification
/// through other more specific hooks above.
void startRootUpdate(Operation *op) override;
void startOpModification(Operation *op) override;

/// PatternRewriter hook for updating the root operation in-place.
void finalizeRootUpdate(Operation *op) override;
/// PatternRewriter hook for updating the given operation in-place.
void finalizeOpModification(Operation *op) override;

/// PatternRewriter hook for updating the root operation in-place.
void cancelRootUpdate(Operation *op) override;
/// PatternRewriter hook for updating the given operation in-place.
void cancelOpModification(Operation *op) override;

/// PatternRewriter hook for notifying match failure reasons.
LogicalResult
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
// Step 2. Assign the op a real tile ID.
// For simplicity, we always use tile 0 (which always exists).
auto zeroTileId = rewriter.getI32IntegerAttr(0);
rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });

VectorType tileVectorType = tileOp.getTileType();
auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,8 +918,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
for (auto stream : streams)
streamDestroyCallBuilder.create(loc, rewriter, {stream});

rewriter.updateRootInPlace(yieldOp,
[&] { yieldOp->setOperands(newOperands); });
rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
return success();
}

Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,13 @@ class ExpandIfCondition : public OpRewritePattern<OpTy> {
if (!matchPattern(op.getIfCond(), m_Constant(&constAttr))) {
auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
op.getIfCond(), false);
rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
auto thenBodyBuilder = ifOp.getThenBodyBuilder(rewriter.getListener());
thenBodyBuilder.clone(*op.getOperation());
rewriter.eraseOp(op);
} else {
if (constAttr.getInt())
rewriter.updateRootInPlace(op,
[&]() { op.getIfCondMutable().erase(0); });
rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
else
rewriter.eraseOp(op);
}
Expand Down
Loading