Skip to content

Commit

Permalink
[Target] Add __launch_bounds__ directive as part of the CUDA code gen…
Browse files Browse the repository at this point in the history
…eration
  • Loading branch information
ArmageddonKnight committed Aug 7, 2021
1 parent 49756a5 commit 9819ed1
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

this->PrintFuncPrefix();
this->PrintExtraAttrs(f);
this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";

for (size_t i = 0; i < f->params.size(); ++i) {
Expand Down Expand Up @@ -125,6 +126,8 @@ void CodeGenC::AddFunction(const PrimFunc& f) {

void CodeGenC::PrintFuncPrefix() { stream << "void"; }

void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}

void CodeGenC::PrintFinalReturn() {}

std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); }
Expand Down
6 changes: 6 additions & 0 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
* Example: stream << "void";
*/
virtual void PrintFuncPrefix(); // NOLINT(*)
/*!
* \brief Print extra function attributes
*
* Example: __launch_bounds__(256) for CUDA functions
*/
virtual void PrintExtraAttrs(const PrimFunc& f);
/*!
* \brief Print the final return at the end the function.
*/
Expand Down
44 changes: 44 additions & 0 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@

#include "codegen_cuda.h"

#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt_functor.h>

#include <cmath>
#include <string>
Expand All @@ -46,6 +48,48 @@ void CodeGenCUDA::Init(bool output_ssa) {

void CodeGenCUDA::PrintFuncPrefix() { stream << "extern \"C\" __global__ void"; }

class ThreadIdxExtractor : public tir::StmtVisitor {
private:
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->var->name_hint == "threadIdx.x" ||
iv->thread_tag == "threadIdx.x") {
threadIdx_x_ext = op->value;
}
if (iv->var->name_hint == "threadIdx.y" ||
iv->thread_tag == "threadIdx.y") {
threadIdx_y_ext = op->value;
}
if (iv->var->name_hint == "threadIdx.z" ||
iv->thread_tag == "threadIdx.z") {
threadIdx_z_ext = op->value;
}
}
StmtVisitor::VisitStmt_(op);
}

public:
PrimExpr threadIdx_x_ext = Integer(1);
PrimExpr threadIdx_y_ext = Integer(1);
PrimExpr threadIdx_z_ext = Integer(1);
};

void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) {
ThreadIdxExtractor extractor;
extractor(f->body);
arith::Analyzer analyzer;
PrimExpr threadIdx_ext = analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
extractor.threadIdx_z_ext);
if (const IntImmNode* const threadIdx_ext_int = threadIdx_ext.as<IntImmNode>()) {
if (threadIdx_ext_int->value == 1) {
// unable to extract the number of threads per block, hence directly return
return;
}
stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
}
}

std::string CodeGenCUDA::Finish() {
if (enable_fp16_) {
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n";
Expand Down
1 change: 1 addition & 0 deletions src/target/source/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class CodeGenCUDA final : public CodeGenC {
}
// override behavior
void PrintFuncPrefix() final;
void PrintExtraAttrs(const PrimFunc& f) final;
void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
Expand Down

0 comments on commit 9819ed1

Please sign in to comment.