diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 4c63d3393fd8..11d8330ec8fe 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -61,24 +61,35 @@ class BufferAllocateOrderCollector : public StmtExprVisitor { } private: + bool find(const Buffer& buf) { + return std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), buf) != + buffer_alloc_recorder_.end(); + } + void VisitStmt_(const BlockNode* op) final { for (const Buffer& buffer : op->alloc_buffers) { buffer_alloc_recorder_.push_back(buffer); } + // Also visit match_buffers to collect constant buffers associated with AllocateConst nodes. + // These buffers only appear in read and match_buffer regions. + for (const auto& region : op->match_buffers) { + if (!find(region->source->buffer)) { + buffer_alloc_recorder_.push_back(region->source->buffer); + } + } + StmtExprVisitor::VisitStmt_(op); } void VisitExpr_(const BufferLoadNode* op) final { - if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) == - buffer_alloc_recorder_.end()) { + if (!find(op->buffer)) { buffer_alloc_recorder_.push_back(op->buffer); } StmtExprVisitor::VisitExpr_(op); } void VisitStmt_(const BufferStoreNode* op) final { - if (std::find(buffer_alloc_recorder_.begin(), buffer_alloc_recorder_.end(), op->buffer) == - buffer_alloc_recorder_.end()) { + if (!find(op->buffer)) { buffer_alloc_recorder_.push_back(op->buffer); } StmtExprVisitor::VisitStmt_(op);