[CodeGen][CUDA] Vectorization for intrinsics (apache#5101)
- This allows emitting vectorized loads/stores
  for CUDA math intrinsics.

- A few intrinsics (exp2, cosh, sinh) should be lowered as CUDAMath,
  not CUDAFastMath ones, since CUDA has no fast-math equivalents for them.

- Fixed the code block indentation.
wpan11nv authored and zhiics committed Apr 17, 2020
1 parent 31df622 commit 356115b
Showing 4 changed files with 238 additions and 53 deletions.
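To make the change concrete before the diffs: below is a hand-written sketch (not taken from this commit; the kernel and variable names are hypothetical) of the kind of CUDA the codegen can now emit for a math intrinsic over a float4-typed buffer, a vectorized load and store wrapped around per-lane scalar intrinsic calls.

    __global__ void unary_erf(const float* A, float* B) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      float4 v = ((const float4*)A)[i];  // vectorized 128-bit load
      float4 _1;
      _1.x = erff(v.x);                  // scalar math intrinsic per lane
      _1.y = erff(v.y);
      _1.z = erff(v.z);
      _1.w = erff(v.w);
      ((float4*)B)[i] = _1;              // vectorized 128-bit store
    }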
23 changes: 23 additions & 0 deletions src/target/source/codegen_c.h
@@ -257,6 +257,29 @@ class CodeGenC :
   /*! \brief the data type of allocated buffers */
   std::unordered_map<const VarNode*, DataType> handle_data_type_;

+  /*!
+   * \brief A RAII utility class for emitting code in a scoped region.
+   */
+  class EnterScopeRAII {
+    // The codegen context.
+    CodeGenC* cg;
+
+    // The new scope level.
+    int scope;
+
+   public:
+    explicit EnterScopeRAII(CodeGenC* cg) : cg(cg) {
+      cg->PrintIndent();
+      cg->stream << "{\n";
+      scope = cg->BeginScope();
+    }
+    ~EnterScopeRAII() {
+      cg->EndScope(scope);
+      cg->PrintIndent();
+      cg->stream << "}\n";
+    }
+  };
+
  private:
   /*! \brief whether to print in SSA form */
   bool print_ssa_form_{false};
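A self-contained sketch of how this guard is meant to be used; MiniCodeGen below is a simplified stand-in invented for illustration (TVM's real BeginScope/EndScope also manage SSA value scopes, not just indentation):

    #include <iostream>
    #include <sstream>
    #include <string>

    struct MiniCodeGen {
      std::ostringstream stream;
      int indent = 0;
      void PrintIndent() { stream << std::string(indent * 2, ' '); }
      int BeginScope() { return ++indent; }
      void EndScope(int scope) { indent = scope - 1; }

      struct EnterScopeRAII {
        MiniCodeGen* cg;
        int scope;
        explicit EnterScopeRAII(MiniCodeGen* cg) : cg(cg) {
          cg->PrintIndent();
          cg->stream << "{\n";
          scope = cg->BeginScope();
        }
        ~EnterScopeRAII() {
          cg->EndScope(scope);
          cg->PrintIndent();
          cg->stream << "}\n";
        }
      };
    };

    int main() {
      MiniCodeGen cg;
      {
        MiniCodeGen::EnterScopeRAII scope(&cg);  // emits "{" and opens a scope
        cg.PrintIndent();
        cg.stream << "float4 _1;\n";
      }  // destructor emits the matching "}"
      std::cout << cg.stream.str();  // prints the brace-enclosed block
    }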
118 changes: 80 additions & 38 deletions src/target/source/codegen_cuda.cc
@@ -24,6 +24,7 @@
 #include <tvm/runtime/registry.h>

 #include <cmath>
+#include <utility>
 #include <vector>
 #include <string>
 #include "literal/cuda_half_t.h"
@@ -235,25 +236,19 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
 void CodeGenCUDA::PrintVecBinaryOp(
     const std::string& op, DataType t,
     PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*)
-  // unpacking operations.
-  int lanes = t.lanes();
-
+  // Declare the result.
+  std::string sret = GetUniqueName("_");
+  this->PrintIndent();
+  this->PrintType(t, stream);
+  stream << ' ' << sret << ";\n";
   {
-    // The assignment below introduces side-effect, and the resulting value cannot
-    // be reused across multiple expression, thus a new scope is needed
-    int vec_scope = BeginScope();
+    EnterScopeRAII scope(this);

-    // default: unpack into individual ops.
+    // Unpack into individual ops.
     std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
     std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());
-    std::string sret = GetUniqueName("_");
-    {
-      // delcare type.
-      this->PrintIndent();
-      this->PrintType(t, stream);
-      stream << ' ' << sret << ";\n";
-    }
-    for (int i = 0; i < lanes; ++i) {
+
+    for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
       std::ostringstream value_temp;
       if (isalpha(op[0])) {
         value_temp << op << "(";
@@ -270,9 +265,8 @@ void CodeGenCUDA::PrintVecBinaryOp(
       }
       PrintVecElemStore(sret, t, i, value_temp.str());
     }
-    os << sret;
-    EndScope(vec_scope);
   }
+  os << sret;
 }

void CodeGenCUDA::PrintVecElemLoad(
@@ -418,6 +412,54 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
       this->PrintExpr(op->args[i * 2 + 1], os);
       os << "]" << ((i < 3) ? ", ": ")");
     }
+  } else if (op->call_type == CallNode::PureExtern && op->dtype.is_vector()) {
+    //
+    // Emit an unsupported vector call
+    //
+    // v = intrin_f((float4*)A[0], (float4*)B[0])
+    //
+    // as
+    //
+    // float4 __ret;
+    // {
+    //   float4 __arg0 = ((float4*)A)[0];
+    //   float4 __arg1 = ((float4*)B)[0];
+    //   __ret.x = intrin_f(__arg0.x, __arg1.x);
+    //   __ret.y = intrin_f(__arg0.y, __arg1.y);
+    //   __ret.z = intrin_f(__arg0.z, __arg1.z);
+    //   __ret.w = intrin_f(__arg0.w, __arg1.w);
+    // }
+    // v = __ret;
+    //
+    // Declare the result vector.
+    std::string sret = GetUniqueName("_");
+    this->PrintIndent();
+    this->PrintType(op->dtype, stream);
+    stream << ' ' << sret << ";\n";
+    {
+      EnterScopeRAII scope(this);
+
+      // Load arguments.
+      std::vector<std::string> sargs;
+      for (size_t i = 0; i < op->args.size(); ++i) {
+        std::string val = SSAGetID(PrintExpr(op->args[i]), op->args[i].dtype());
+        sargs.push_back(std::move(val));
+      }
+
+      // Emit a scalar call for each lane.
+      for (int i = 0; i < op->dtype.lanes(); ++i) {
+        std::ostringstream scall;
+        scall << op->name << "(";
+        for (size_t j = 0; j < op->args.size(); ++j) {
+          if (j > 0)
+            scall << ", ";
+          PrintVecElemLoad(sargs[j], op->args[j].dtype(), i, scall);
+        }
+        scall << ")";
+        PrintVecElemStore(sret, op->dtype, i, scall.str());
+      }
+    }
+    os << sret;
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
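For example, once a vectorized pow reaches this PureExtern path (pow lowers to the extern powf for float32 via the rules in intrin_rule_cuda.cc below), the emitted code is roughly equivalent to this hand-written sketch. vec_pow and the _N temporaries are illustrative; the real names come from GetUniqueName and SSAGetID.

    __device__ float2 vec_pow(float2 a, float2 b) {
      float2 _1;                  // result, declared before the scope
      {
        float2 _2 = a;            // each vector argument is loaded once
        float2 _3 = b;
        _1.x = powf(_2.x, _3.x);  // scalar call per lane
        _1.y = powf(_2.y, _3.y);
      }
      return _1;
    }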
@@ -580,34 +622,34 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) {
         op->true_value->dtype == op->dtype &&
         op->dtype.lanes() == op->condition.dtype().lanes());

-  int lanes = op->dtype.lanes();
-  int scope = BeginScope();
-
-  std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype);
-  std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype);
-  std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype);
   std::string r_var = GetUniqueName("_");

   this->PrintIndent();
   this->PrintType(op->dtype, stream);
   stream << ' ' << r_var << ";\n";
+  {
+    EnterScopeRAII scope(this);
+
+    std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype);
+    std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype);
+    std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype);

-  // The condition is stored as an ushort vector.
-  DataType memory_ty(DataType::TypeCode::kUInt, 16, lanes);
-
-  for (int i = 0; i < lanes; ++i) {
-    std::ostringstream item;
-    item << "(bool(";
-    PrintVecElemLoad(c_var, memory_ty, i, item);
-    item << ")?";
-    PrintVecElemLoad(t_var, op->dtype, i, item);
-    item << ':';
-    PrintVecElemLoad(f_var, op->dtype, i, item);
-    item << ')';
-    PrintVecElemStore(r_var, op->dtype, i, item.str());
+    // The condition is stored as an ushort vector.
+    int lanes = op->dtype.lanes();
+    DataType memory_ty(DataType::TypeCode::kUInt, 16, lanes);
+
+    for (int i = 0; i < lanes; ++i) {
+      std::ostringstream item;
+      item << "(bool(";
+      PrintVecElemLoad(c_var, memory_ty, i, item);
+      item << ")?";
+      PrintVecElemLoad(t_var, op->dtype, i, item);
+      item << ':';
+      PrintVecElemLoad(f_var, op->dtype, i, item);
+      item << ')';
+      PrintVecElemStore(r_var, op->dtype, i, item.str());
+    }
   }
   os << r_var;
-  EndScope(scope);
 }

inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*)
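Before moving to the next file, a sketch of what this rewritten SelectNode path emits for a float4 select. vec_select and the _N names are illustrative; the ushort4 parameter mirrors the 16-bit kUInt memory_ty used for the condition vector above.

    __device__ float4 vec_select(ushort4 c, float4 t, float4 f) {
      float4 _1;
      {
        ushort4 _2 = c;  // lane-wise condition, one ushort per lane
        float4 _3 = t;
        float4 _4 = f;
        _1.x = (bool(_2.x) ? _3.x : _4.x);
        _1.y = (bool(_2.y) ? _3.y : _4.y);
        _1.z = (bool(_2.z) ? _3.z : _4.z);
        _1.w = (bool(_2.w) ? _3.w : _4.w);
      }
      return _1;
    }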
26 changes: 12 additions & 14 deletions src/target/source/intrin_rule_cuda.cc
@@ -29,14 +29,12 @@ namespace intrin {
 // Add float suffix to the intrinsics, CUDA fast math.
 struct CUDAMath {
   std::string operator()(DataType t, std::string name) const {
-    if (t.lanes() == 1) {
-      if (t.is_float()) {
-        switch (t.bits()) {
-          case 64: return name;
-          case 32: return name + 'f';
-          case 16: return 'h' + name;
-          default: return "";
-        }
+    if (t.is_float()) {
+      switch (t.bits()) {
+        case 64: return name;
+        case 32: return name + 'f';
+        case 16: return 'h' + name;
+        default: return "";
       }
     }
     return "";
@@ -45,7 +43,7 @@ struct CUDAMath {

 struct CUDAFastMath : public CUDAMath {
   std::string operator()(DataType t, std::string name) const {
-    if (t.lanes() == 1 && t.is_float() && t.bits() == 32) {
+    if (t.is_float() && t.bits() == 32) {
       return "__" + name + 'f';
     } else {
       return CUDAMath::operator()(t, name);
@@ -56,7 +54,7 @@ struct CUDAFastMath : public CUDAMath {

 struct CUDAFastMathTan : public CUDAMath {
   std::string operator()(DataType t, std::string name) const {
-    if (t.lanes() == 1 && t.is_float()) {
+    if (t.is_float()) {
       switch (t.bits()) {
         case 64: return name;
         // `__tanf` seems to produce some values too deviant from numpy tan version.
@@ -72,7 +70,7 @@ struct CUDAFastMathTan : public CUDAMath {

 struct CUDAPopcount {
   std::string operator()(DataType t, std::string name) const {
-    if (t.lanes() == 1 && t.is_uint()) {
+    if (t.is_uint()) {
       switch (t.bits()) {
         case 32: return "__popc";
         case 64: return "__popcll";
@@ -108,7 +106,7 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
 .set_body(DispatchExtern<CUDAFastMath>);

 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2")
-.set_body(DispatchExtern<CUDAFastMath>);
+.set_body(DispatchExtern<CUDAMath>);

 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10")
 .set_body(DispatchExtern<CUDAFastMath>);
@@ -132,13 +130,13 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos")
 .set_body(DispatchExtern<CUDAFastMath>);

 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh")
-.set_body(DispatchExtern<CUDAFastMath>);
+.set_body(DispatchExtern<CUDAMath>);

 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin")
 .set_body(DispatchExtern<CUDAFastMath>);

 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh")
-.set_body(DispatchExtern<CUDAFastMath>);
+.set_body(DispatchExtern<CUDAMath>);

 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan")
 .set_body(DispatchExtern<CUDAMath>);
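Two things happen in this file. First, dropping the t.lanes() == 1 guards lets vector types reach these rules, so the dispatcher returns the scalar intrinsic name and the new CallNode path in codegen_cuda.cc applies it lane by lane. Second, exp2, cosh, and sinh move from CUDAFastMath to CUDAMath because CUDA provides no __exp2f/__coshf/__sinhf fast-math intrinsics. A standalone sketch of the suffix rule itself (FakeType stands in for TVM's DataType and is not TVM API):

    #include <cassert>
    #include <string>

    struct FakeType { bool is_float; int bits; };

    // Mirrors CUDAMath::operator(): pick the CUDA math function name by width.
    std::string cuda_math_name(FakeType t, std::string name) {
      if (t.is_float) {
        switch (t.bits) {
          case 64: return name;        // double: exp
          case 32: return name + 'f';  // float:  expf
          case 16: return 'h' + name;  // half:   hexp
          default: return "";
        }
      }
      return "";
    }

    int main() {
      assert(cuda_math_name({true, 64}, "exp") == "exp");
      assert(cuda_math_name({true, 32}, "exp") == "expf");
      assert(cuda_math_name({true, 16}, "exp") == "hexp");
    }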
[Diff for the fourth changed file was not rendered.]
