Skip to content

Commit

Permalink
[Target] Add __launch_bounds__ directive as part of the CUDA code gen…
Browse files Browse the repository at this point in the history
…eration
  • Loading branch information
ArmageddonKnight committed Aug 6, 2021
1 parent 338940d commit 28e4a55
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

this->PrintFuncPrefix();
this->PrintExtraAttrs(f);
this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";

for (size_t i = 0; i < f->params.size(); ++i) {
Expand Down Expand Up @@ -125,6 +126,8 @@ void CodeGenC::AddFunction(const PrimFunc& f) {

void CodeGenC::PrintFuncPrefix() { stream << "void"; }

void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}

void CodeGenC::PrintFinalReturn() {}

std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); }
Expand Down
6 changes: 6 additions & 0 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
* Example: stream << "void";
*/
virtual void PrintFuncPrefix(); // NOLINT(*)
/*!
* \brief Print extra function attributes
*
* Example: __launch_bounds__(256) for CUDA functions
*/
virtual void PrintExtraAttrs(const PrimFunc& f);
/*!
* \brief Print the final return at the end the function.
*/
Expand Down
38 changes: 38 additions & 0 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@

#include "codegen_cuda.h"

#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt_functor.h>

#include <cmath>
#include <string>
Expand All @@ -46,6 +48,42 @@ void CodeGenCUDA::Init(bool output_ssa) {

void CodeGenCUDA::PrintFuncPrefix() { stream << "extern \"C\" __global__ void"; }

class threadIdxExtractor : public tir::StmtVisitor {
private:
void VisitStmt_(const AttrStmtNode* op) override final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
LOG(INFO) << iv->var->name_hint;

if (iv->var->name_hint == "threadIdx.x") {
threadIdx_x_ext = op->value;
}
if (iv->var->name_hint == "threadIdx.y") {
threadIdx_y_ext = op->value;
}
if (iv->var->name_hint == "threadIdx.z") {
threadIdx_z_ext = op->value;
}
}
StmtVisitor::VisitStmt_(op);
}
public:
PrimExpr threadIdx_x_ext = Integer(1), threadIdx_y_ext = Integer(1), threadIdx_z_ext = Integer(1);
};


void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) {
threadIdxExtractor extractor;
extractor(f->body);
arith::Analyzer analyzer;
PrimExpr threadIdx_ext =
analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
extractor.threadIdx_z_ext);
if (const IntImmNode* const threadIdx_ext_int = threadIdx_ext.as<IntImmNode>()) {
stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
}
}

std::string CodeGenCUDA::Finish() {
if (enable_fp16_) {
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n";
Expand Down
1 change: 1 addition & 0 deletions src/target/source/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class CodeGenCUDA final : public CodeGenC {
}
// override behavior
void PrintFuncPrefix() final;
void PrintExtraAttrs(const PrimFunc& f) final;
void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
Expand Down

0 comments on commit 28e4a55

Please sign in to comment.