-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SCHEDULE] New Reduction Mode for Tensorize #727
Changes from 1 commit
e386176
1151fd6
784a0fd
5bbc941
ab49e9f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -208,5 +208,31 @@ Stmt Substitute(Stmt s, | |
return ir::Substitute(s, init); | ||
} | ||
|
||
Stmt TransformUpdate(const Stage& stage, | ||
const std::unordered_map<IterVar, Range>& dom_map, | ||
Stmt body, | ||
Stmt update) | ||
{ | ||
Expr condition = (make_zero(Int(32)) == make_zero(Int(32))); //will try Bool(1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. build a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done, please check |
||
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { | ||
IterVar iv = stage->leaf_iter_vars[i]; | ||
auto iit = stage->iter_var_attrs.find(iv); | ||
if (iit != stage->iter_var_attrs.end()) { | ||
const IterVarAttr& attr = (*iit).second; | ||
if (attr->iter_type == kTensorized) { | ||
break; | ||
} | ||
} | ||
if (iv->iter_type == kCommReduce) { | ||
auto vit = dom_map.find(iv); | ||
CHECK(vit != dom_map.end()); | ||
const Range& vrange = vit->second; | ||
Expr newcond = ( iv->var == vrange->min); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use condition
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
condition = condition && newcond; | ||
} | ||
} | ||
return IfThenElse::make(condition, body, update); | ||
} | ||
|
||
} // namespace op | ||
} // namespace tvm |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -416,32 +416,48 @@ Stmt MakeTensorize(const ComputeOpNode* self, | |
return MergeNest(nest, body); | ||
} else { | ||
// Need to split reduction | ||
CHECK(intrin->reduce_init.defined()) | ||
<< "Reduction init op for intrin " << intrin << " is not defined"; | ||
// Comment out the following check for the case when there is no init func | ||
//CHECK(intrin->reduce_init.defined()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove lgtm, simply remove them There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
// << "Reduction init op for intrin " << intrin << " is not defined"; | ||
CHECK(intrin->reduce_update.defined()) | ||
<< "Reduction update op for intrin " << intrin << " is not defined"; | ||
// Need init and update steps | ||
CHECK_NE(self->reduce_axis.size(), 0U); | ||
std::vector<std::vector<Stmt> > common( | ||
n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); | ||
// init nest | ||
std::vector<std::vector<Stmt> > init_nest( | ||
n.init_nest.begin(), n.init_nest.begin() + tloc + 1); | ||
init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); | ||
Stmt init = MergeNest(output_bind_nest, intrin->reduce_init); | ||
init = Substitute(init, n.init_vmap); | ||
init = MergeNest(init_nest, init); | ||
// The update | ||
std::vector<std::vector<Stmt> > update_nest( | ||
n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1); | ||
update_nest.emplace_back(op::MakeIfNest(n.main_predicates)); | ||
Stmt update = MergeNest(output_bind_nest, intrin->reduce_update); | ||
update = MergeNest(input_bind_nest, update); | ||
update = Substitute(update, vmap); | ||
update = MergeNest(binder.asserts(), update); | ||
update = Substitute(update, n.main_vmap); | ||
update = MergeNest(update_nest, update); | ||
return MergeNest(common, Block::make(init, update)); | ||
|
||
if (intrin->reduce_init.defined()) { | ||
// init nest | ||
std::vector<std::vector<Stmt> > init_nest( | ||
n.init_nest.begin(), n.init_nest.begin() + tloc + 1); | ||
init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); | ||
Stmt init = MergeNest(output_bind_nest, intrin->reduce_init); | ||
init = Substitute(init, n.init_vmap); | ||
init = MergeNest(init_nest, init); | ||
// The update | ||
Stmt update = MergeNest(output_bind_nest, intrin->reduce_update); | ||
update = MergeNest(input_bind_nest, update); | ||
update = Substitute(update, vmap); | ||
update = MergeNest(binder.asserts(), update); | ||
update = Substitute(update, n.main_vmap); | ||
update = MergeNest(update_nest, update); | ||
return MergeNest(common, Block::make(init, update)); | ||
} else { | ||
// The update | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a comment about what this logic is about. Something like
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
Stmt update = TransformUpdate (stage, dom_map, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add CHECK(body.defined()); There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
intrin->body, | ||
intrin->reduce_update); | ||
update = MergeNest(output_bind_nest, update); | ||
update = MergeNest(input_bind_nest, update); | ||
update = Substitute(update, vmap); | ||
update = MergeNest(binder.asserts(), update); | ||
update = Substitute(update, n.main_vmap); | ||
update = MergeNest(update_nest, update); | ||
return MergeNest(common, update); | ||
} | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This function is specific to Tensorize, move it to tensorize.cc, no need to keep comment in header file, move comment document to tensorize.cc as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, move the function into tensorize.cc