diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index f5a553aa0598d..99d71ebe15bda 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -113,6 +113,16 @@ class BuiltinLower : public StmtExprMutator { op = stmt.as(); // Get constant allocation bound. int64_t nbytes = GetVectorBytes(op->dtype); + if (device_type_.defined()) { + if (const auto* dev_type = device_type_.as()) { + if (dev_type->value == kDLCPU) { + int32_t constant_size = op->constant_allocation_size(); + if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) { + return stmt; + } + } + } + } PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes); for (size_t i = 0; i < op->extents.size(); ++i) { total_bytes = total_bytes * op->extents[i];