diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 8397044e8b939..f676f0f598d80 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -83,6 +83,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
   bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
 
   this->PrintFuncPrefix();
+  this->PrintExtraAttrs(f);
   this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";
 
   for (size_t i = 0; i < f->params.size(); ++i) {
@@ -125,6 +126,8 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
 
 void CodeGenC::PrintFuncPrefix() { stream << "void"; }
 
+void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}
+
 void CodeGenC::PrintFinalReturn() {}
 
 std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); }
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h
index 834c57ac10fd0..6ebade7191f27 100644
--- a/src/target/source/codegen_c.h
+++ b/src/target/source/codegen_c.h
@@ -103,6 +103,13 @@ class CodeGenC : public ExprFunctor,
    *  Example: stream << "void";
    */
   virtual void PrintFuncPrefix();  // NOLINT(*)
+  /*!
+   * \brief Print extra function attributes
+   *
+   *  Example: __launch_bounds__(256) for CUDA functions
+   * \param f The function being emitted.
+   */
+  virtual void PrintExtraAttrs(const PrimFunc& f);
   /*!
    * \brief Print the final return at the end the function.
    */
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 7897490730a3e..3abcbd545e676 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -23,7 +23,9 @@
 
 #include "codegen_cuda.h"
 
+#include <tvm/arith/analyzer.h>
 #include <tvm/runtime/registry.h>
+#include <tvm/tir/stmt_functor.h>
 
 #include <cmath>
 #include <string>
@@ -46,6 +48,53 @@ void CodeGenCUDA::Init(bool output_ssa) {
 
 void CodeGenCUDA::PrintFuncPrefix() { stream << "extern \"C\" __global__ void"; }
 
+/*!
+ * \brief Records the extents that thread_extent attribute statements assign to
+ *        threadIdx.x, threadIdx.y and threadIdx.z.
+ *
+ * An axis that never appears keeps its initial extent of 1.
+ */
+class ThreadIdxExtractor : public tir::StmtVisitor {
+ private:
+  void VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == tir::attr::thread_extent) {
+      IterVar iv = Downcast<IterVar>(op->node);
+      if (iv->var->name_hint == "threadIdx.x") {
+        threadIdx_x_ext = op->value;
+      }
+      if (iv->var->name_hint == "threadIdx.y") {
+        threadIdx_y_ext = op->value;
+      }
+      if (iv->var->name_hint == "threadIdx.z") {
+        threadIdx_z_ext = op->value;
+      }
+    }
+    // Delegate to the base visitor so the statement's children are visited as well.
+    StmtVisitor::VisitStmt_(op);
+  }
+
+ public:
+  PrimExpr threadIdx_x_ext = Integer(1);
+  PrimExpr threadIdx_y_ext = Integer(1);
+  PrimExpr threadIdx_z_ext = Integer(1);
+};
+
+/*!
+ * \brief Print __launch_bounds__(N) when the product of the threadIdx extents
+ *        found in the function body simplifies to an integer constant N.
+ */
+void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) {
+  ThreadIdxExtractor extractor;
+  extractor(f->body);
+  arith::Analyzer analyzer;
+  PrimExpr threadIdx_ext =
+      analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
+                        extractor.threadIdx_z_ext);
+  if (const IntImmNode* const threadIdx_ext_int = threadIdx_ext.as<IntImmNode>()) {
+    stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
+  }
+}
+
 std::string CodeGenCUDA::Finish() {
   if (enable_fp16_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n";
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 2098b8ac83448..385b7343c8fd0 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -46,6 +46,7 @@ class CodeGenCUDA final : public CodeGenC {
   }
   // override behavior
   void PrintFuncPrefix() final;
+  void PrintExtraAttrs(const PrimFunc& f) final;
   void VisitStmt_(const ForNode* op) final;
   void PrintStorageSync(const CallNode* op) final;
   void PrintStorageScope(const std::string& scope, std::ostream& os) final;  // NOLINT(*)