diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 51d136d5510e..6311b435f197 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -514,6 +514,34 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { builder_->StartLabel(merge_label); } +void CodeGenSPIRV::VisitStmt_(const WhileNode* op) { + spirv::Label head_label = builder_->NewLabel(); + spirv::Label body_label = builder_->NewLabel(); + spirv::Label continue_label = builder_->NewLabel(); + spirv::Label merge_label = builder_->NewLabel(); + builder_->MakeInst(spv::OpBranch, head_label); + + // Loop head + builder_->StartLabel(head_label); + spirv::Value loop_cond = MakeValue(op->condition); + uint32_t control = spv::LoopControlMaskNone; + builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control); + builder_->MakeInst(spv::OpBranchConditional, loop_cond, body_label, merge_label, + weight_likely_branch_, 1); + + // loop body + builder_->StartLabel(body_label); + this->VisitStmt(op->body); + builder_->MakeInst(spv::OpBranch, continue_label); + + // loop continue + builder_->StartLabel(continue_label); + builder_->MakeInst(spv::OpBranch, head_label); + + // loop merge + builder_->StartLabel(merge_label); +} + void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) { spirv::Value cond = MakeValue(op->condition); spirv::Label then_label = builder_->NewLabel(); diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index be755641c8a5..1e80fcc4a931 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -93,6 +93,7 @@ class CodeGenSPIRV : public ExprFunctor, // stmt void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const ForNode* op) override; + void VisitStmt_(const WhileNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const AllocateNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 46bc500fc503..8ad5cb63924e 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -405,6 +405,7 @@ def check_target(target, ir): check_target("llvm", mandel_ir_cpu) check_target("npvtx", mandel_ir_gpu) check_target("cuda", mandel_ir_gpu) + check_target("vulkan", mandel_ir_gpu) def test_while_binary_search(): @@ -493,6 +494,7 @@ def check_target(target, ir): check_target("llvm", searchsorted_ir_cpu) check_target("cuda", searchsorted_ir_gpu) check_target("nvptx", searchsorted_ir_gpu) + check_target("vulkan", searchsorted_ir_gpu) if __name__ == "__main__":