Skip to content

Commit

Permalink
[CodeGen] Cleanup generated code (apache#5424)
Browse files Browse the repository at this point in the history
- remove unnecessary white spaces from storage kind
- do not start a new scope for vectorization as temporary
  variables are alll uniquely generated.

The above two changes make vectorized code much cleaner.

Signed-off-by: Wei Pan <[email protected]>
  • Loading branch information
wpan11nv authored and Trevor Morris committed Jun 18, 2020
1 parent 0d65f5c commit 924e464
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 48 deletions.
10 changes: 1 addition & 9 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
stream << ' ';
}

PrintType(GetType(v), stream);
Expand Down Expand Up @@ -179,7 +178,6 @@ std::string CodeGenC::GetBufferRef(
if (!scope.empty() && IsScopePartOfType()) {
PrintStorageScope(scope, os);
}
os << ' ';
PrintType(t, os);
os << "*)" << vid << ')';
} else {
Expand Down Expand Up @@ -213,15 +211,13 @@ std::string CodeGenC::GetBufferRef(
if (!scope.empty() && IsScopePartOfType()) {
PrintStorageScope(scope, os);
}
os << ' ';
PrintType(t, os);
os << "*)(";
if (!HandleTypeMatch(buffer, t.element_of())) {
os << '(';
if (!scope.empty() && IsScopePartOfType()) {
PrintStorageScope(scope, os);
}
os << ' ';
PrintType(t.element_of(), os);
os << "*)";
}
Expand Down Expand Up @@ -681,7 +677,6 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
auto it = alloc_storage_scope_.find(op->buffer_var.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, value_temp);
value_temp << ' ';
}
}
PrintType(elem_type, value_temp);
Expand Down Expand Up @@ -731,7 +726,6 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
auto it = alloc_storage_scope_.find(op->buffer_var.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
stream << ' ';
}
}
PrintType(elem_type, stream);
Expand Down Expand Up @@ -823,10 +817,8 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) {
const VarNode* buffer = op->buffer_var.as<VarNode>();
std::string scope = alloc_storage_scope_.at(buffer);
PrintStorageScope(scope, stream);
stream << ' ';
PrintType(op->dtype, stream);
stream << ' '<< vid << '['
<< constant_size << "];\n";
stream << ' ' << vid << '[' << constant_size << "];\n";

RegisterHandleType(op->buffer_var.get(), op->dtype);
this->PrintStmt(op->body);
Expand Down
23 changes: 0 additions & 23 deletions src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,29 +257,6 @@ class CodeGenC :
/*! \brief the data type of allocated buffers */
std::unordered_map<const VarNode*, DataType> handle_data_type_;

/*!
* \brief A RAII utility class for emitting code in a scoped region.
*/
class EnterScopeRAII {
// The codegen context.
CodeGenC* cg;

// The new scope level.
int scope;

public:
explicit EnterScopeRAII(CodeGenC* cg) : cg(cg) {
cg->PrintIndent();
cg->stream << "{\n";
scope = cg->BeginScope();
}
~EnterScopeRAII() {
cg->EndScope(scope);
cg->PrintIndent();
cg->stream << "}\n";
}
};

private:
/*! \brief whether to print in SSA form */
bool print_ssa_form_{false};
Expand Down
10 changes: 1 addition & 9 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,6 @@ void CodeGenCUDA::PrintVecBinaryOp(
this->PrintType(t, stream);
stream << ' ' << sret << ";\n";
{
EnterScopeRAII scope(this);

// Unpack into individual ops.
std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());
Expand Down Expand Up @@ -350,7 +348,7 @@ void CodeGenCUDA::PrintStorageScope(
const std::string& scope, std::ostream& os) { // NOLINT(*)
CHECK_NE(scope, "global");
if (scope == "shared") {
os << "__shared__";
os << "__shared__ ";
}
}

Expand All @@ -370,7 +368,6 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
this->PrintType(target_ty, stream);
stream << ' ' << sret << ";\n";
{
EnterScopeRAII scope(this);
std::string src = SSAGetID(PrintExpr(op->value), from_ty);
for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
std::ostringstream val;
Expand Down Expand Up @@ -470,8 +467,6 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
this->PrintType(op->dtype, stream);
stream << ' ' << sret << ";\n";
{
EnterScopeRAII scope(this);

// Load arguments.
std::vector<std::string> sargs;
for (size_t i = 0; i < op->args.size(); ++i) {
Expand Down Expand Up @@ -541,7 +536,6 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
PrintWmmaScope(scope, op->dtype, buffer, stream);
} else {
PrintStorageScope(scope, stream);
stream << ' ';
PrintType(op->dtype, stream);
}
if ((op->dtype == DataType::Int(4) ||
Expand Down Expand Up @@ -657,8 +651,6 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) {
this->PrintType(op->dtype, stream);
stream << ' ' << r_var << ";\n";
{
EnterScopeRAII scope(this);

std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype);
std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype);
std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype);
Expand Down
7 changes: 3 additions & 4 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
}
stream << ' ';
PrintType(GetType(v), stream);
// Register handle data type
// TODO(tvm-team): consider simply keep type info in the
Expand Down Expand Up @@ -236,11 +235,11 @@ void CodeGenMetal::PrintVecElemStore(const std::string& vec,
void CodeGenMetal::PrintStorageScope(
const std::string& scope, std::ostream& os) { // NOLINT(*)
if (scope == "global") {
os << "device";
os << "device ";
} else if (scope == "shared") {
os << "threadgroup";
os << "threadgroup ";
} else {
os << "thread";
os << "thread ";
}
}

Expand Down
5 changes: 2 additions & 3 deletions src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t,
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, os);
}
os << ' ';
PrintType(t.element_of(), os);
os << "*)";
}
Expand Down Expand Up @@ -191,9 +190,9 @@ void CodeGenOpenCL::PrintStorageSync(const CallNode* op) {
void CodeGenOpenCL::PrintStorageScope(
const std::string& scope, std::ostream& os) { // NOLINT(*)
if (scope == "global") {
os << "__global";
os << "__global ";
} else if (scope == "shared") {
os << "__local";
os << "__local ";
}
}

Expand Down

0 comments on commit 924e464

Please sign in to comment.