Skip to content

Commit

Permalink
[FIRRTL] InferWidths: fix invalid frame reference
Browse files Browse the repository at this point in the history
The frame.indent reference can be invalidated when we push to the
worklist. This change caches the indentation level locally to fix the
issue.
  • Loading branch information
youngar committed Oct 13, 2024
1 parent 9dea002 commit 3e3345e
Showing 1 changed file with 16 additions and 18 deletions.
34 changes: 16 additions & 18 deletions lib/Dialect/FIRRTL/Transforms/InferWidths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,7 @@ static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl<Expr *> &seenVars,

while (!worklist.empty()) {
auto &frame = worklist.back();
auto indent = frame.indent;
auto setSolution = [&](ExprSolution solution) {
// Memoize the result.
if (solution.first && !solution.second)
Expand All @@ -897,11 +898,10 @@ static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl<Expr *> &seenVars,
LLVM_DEBUG({
if (!isa<KnownExpr>(frame.expr)) {
if (solution.first)
llvm::dbgs().indent(frame.indent * 2)
llvm::dbgs().indent(indent * 2)
<< "= Solved " << *frame.expr << " = " << *solution.first;
else
llvm::dbgs().indent(frame.indent * 2)
<< "= Skipped " << *frame.expr;
llvm::dbgs().indent(indent * 2) << "= Skipped " << *frame.expr;
llvm::dbgs() << " (" << (solution.second ? "cycle broken" : "unique")
<< ")\n";
}
Expand All @@ -914,9 +914,8 @@ static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl<Expr *> &seenVars,
if (frame.expr->getSolution()) {
LLVM_DEBUG({
if (!isa<KnownExpr>(frame.expr))
llvm::dbgs().indent(frame.indent * 2)
<< "- Cached " << *frame.expr << " = "
<< *frame.expr->getSolution() << "\n";
llvm::dbgs().indent(indent * 2) << "- Cached " << *frame.expr << " = "
<< *frame.expr->getSolution() << "\n";
});
setSolution(ExprSolution{*frame.expr->getSolution(), false});
continue;
Expand All @@ -925,8 +924,7 @@ static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl<Expr *> &seenVars,
// Otherwise compute the value of the expression.
LLVM_DEBUG({
if (!isa<KnownExpr>(frame.expr))
llvm::dbgs().indent(frame.indent * 2)
<< "- Solving " << *frame.expr << "\n";
llvm::dbgs().indent(indent * 2) << "- Solving " << *frame.expr << "\n";
});

TypeSwitch<Expr *>(frame.expr)
Expand Down Expand Up @@ -960,21 +958,21 @@ static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl<Expr *> &seenVars,
if (!seenVars.insert(expr).second)
return setSolution(ExprSolution{std::nullopt, true});

worklist.emplace_back(expr->constraint, frame.indent + 1);
worklist.emplace_back(expr->constraint, indent + 1);
if (expr->upperBound)
worklist.emplace_back(expr->upperBound, frame.indent + 1);
worklist.emplace_back(expr->upperBound, indent + 1);
})
.Case<IdExpr>([&](auto *expr) {
if (solvedExprs.contains(expr->arg))
return setSolution(solvedExprs[expr->arg]);
worklist.emplace_back(expr->arg, frame.indent + 1);
worklist.emplace_back(expr->arg, indent + 1);
})
.Case<PowExpr>([&](auto *expr) {
if (solvedExprs.contains(expr->arg))
return setSolution(computeUnary(
solvedExprs[expr->arg], [](int32_t arg) { return 1 << arg; }));

worklist.emplace_back(expr->arg, frame.indent + 1);
worklist.emplace_back(expr->arg, indent + 1);
})
.Case<AddExpr>([&](auto *expr) {
if (solvedExprs.contains(expr->lhs()) &&
Expand All @@ -983,8 +981,8 @@ static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl<Expr *> &seenVars,
solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
[](int32_t lhs, int32_t rhs) { return lhs + rhs; }));

worklist.emplace_back(expr->lhs(), frame.indent + 1);
worklist.emplace_back(expr->rhs(), frame.indent + 1);
worklist.emplace_back(expr->lhs(), indent + 1);
worklist.emplace_back(expr->rhs(), indent + 1);
})
.Case<MaxExpr>([&](auto *expr) {
if (solvedExprs.contains(expr->lhs()) &&
Expand All @@ -993,8 +991,8 @@ static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl<Expr *> &seenVars,
solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
[](int32_t lhs, int32_t rhs) { return std::max(lhs, rhs); }));

worklist.emplace_back(expr->lhs(), frame.indent + 1);
worklist.emplace_back(expr->rhs(), frame.indent + 1);
worklist.emplace_back(expr->lhs(), indent + 1);
worklist.emplace_back(expr->rhs(), indent + 1);
})
.Case<MinExpr>([&](auto *expr) {
if (solvedExprs.contains(expr->lhs()) &&
Expand All @@ -1003,8 +1001,8 @@ static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl<Expr *> &seenVars,
solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
[](int32_t lhs, int32_t rhs) { return std::min(lhs, rhs); }));

worklist.emplace_back(expr->lhs(), frame.indent + 1);
worklist.emplace_back(expr->rhs(), frame.indent + 1);
worklist.emplace_back(expr->lhs(), indent + 1);
worklist.emplace_back(expr->rhs(), indent + 1);
})
.Default([&](auto) {
setSolution(ExprSolution{std::nullopt, false});
Expand Down

0 comments on commit 3e3345e

Please sign in to comment.