Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Target] Add __launch_bounds__ directive as part of the CUDA code generation #8678

Merged
merged 1 commit into from
Aug 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);

this->PrintFuncPrefix();
this->PrintExtraAttrs(f);
this->stream << " " << static_cast<std::string>(global_symbol.value()) << "(";

for (size_t i = 0; i < f->params.size(); ++i) {
Expand Down Expand Up @@ -125,6 +126,8 @@ void CodeGenC::AddFunction(const PrimFunc& f) {

void CodeGenC::PrintFuncPrefix() { stream << "void"; }

void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}

void CodeGenC::PrintFinalReturn() {}

std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); }
Expand Down
6 changes: 6 additions & 0 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
* Example: stream << "void";
*/
virtual void PrintFuncPrefix(); // NOLINT(*)
/*!
* \brief Print extra function attributes
*
* Example: __launch_bounds__(256) for CUDA functions
*/
virtual void PrintExtraAttrs(const PrimFunc& f);
/*!
* \brief Print the final return at the end the function.
*/
Expand Down
41 changes: 41 additions & 0 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@

#include "codegen_cuda.h"

#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt_functor.h>

#include <cmath>
#include <string>
Expand All @@ -46,6 +48,45 @@ void CodeGenCUDA::Init(bool output_ssa) {

void CodeGenCUDA::PrintFuncPrefix() { stream << "extern \"C\" __global__ void"; }

class ThreadIdxExtractor : public tir::StmtVisitor {
private:
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->var->name_hint == "threadIdx.x" || iv->thread_tag == "threadIdx.x") {
threadIdx_x_ext = op->value;
}
if (iv->var->name_hint == "threadIdx.y" || iv->thread_tag == "threadIdx.y") {
threadIdx_y_ext = op->value;
}
if (iv->var->name_hint == "threadIdx.z" || iv->thread_tag == "threadIdx.z") {
threadIdx_z_ext = op->value;
}
}
StmtVisitor::VisitStmt_(op);
}

public:
PrimExpr threadIdx_x_ext = Integer(1);
PrimExpr threadIdx_y_ext = Integer(1);
PrimExpr threadIdx_z_ext = Integer(1);
};

void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) {
ThreadIdxExtractor extractor;
extractor(f->body);
arith::Analyzer analyzer;
PrimExpr threadIdx_ext = analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext *
extractor.threadIdx_z_ext);
if (const IntImmNode* const threadIdx_ext_int = threadIdx_ext.as<IntImmNode>()) {
if (threadIdx_ext_int->value == 1) {
// unable to extract the number of threads per block, hence directly return
return;
}
stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
ArmageddonKnight marked this conversation as resolved.
Show resolved Hide resolved
}
}

std::string CodeGenCUDA::Finish() {
if (enable_fp16_) {
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n";
Expand Down
1 change: 1 addition & 0 deletions src/target/source/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class CodeGenCUDA final : public CodeGenC {
}
// override behavior
void PrintFuncPrefix() final;
void PrintExtraAttrs(const PrimFunc& f) final;
void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
Expand Down