Skip to content

Commit

Permalink
enhance shape inference. allow in complete shape (apache#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong authored and tqchen committed Jan 7, 2017
1 parent 7b0ac8f commit ccb349d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
5 changes: 5 additions & 0 deletions include/nnvm/tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,11 @@ class TShape : public Tuple<index_t> {
return begin();
}
#ifdef MSHADOW_XINLINE
template<int dim>
inline TShape(const mshadow::Shape<dim> &s) {// NOLINT(*)
this->assign(s.shape_, s.shape_ + dim);
}

template<int dim>
inline TShape(mshadow::Shape<dim> &&s) {// NOLINT(*)
this->assign(s.shape_, s.shape_ + dim);
Expand Down
18 changes: 10 additions & 8 deletions src/pass/infer_shape_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,27 +155,29 @@ Graph InferAttr(Graph &&ret,
}
};

size_t num_unknown = 0;
const int kMaxStep = 3;
for (int i = 0; i < kMaxStep; ++i) {
size_t last_num_unknown;
size_t num_unknown = rshape.size();
int i = 0;
do {
if (i % 2 == 0) {
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
infer_step(nid, i + 1 == kMaxStep);
infer_step(nid, false);
}
} else {
// backward inference
for (uint32_t i = idx.num_nodes(); i != 0; --i) {
infer_step(i - 1, i + 1 == kMaxStep);
infer_step(i - 1, false);
}
}
last_num_unknown = num_unknown;
num_unknown = 0;
for (size_t j = 0; j < idx.num_node_entries(); ++j) {
if (fis_none(rshape[j])) {
++num_unknown;
}
}
if (num_unknown == 0) break;
}
++i;
} while (num_unknown > 0 && last_num_unknown > num_unknown);
// set the shapes
ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
// number of nodes who knows the shape.
Expand All @@ -190,7 +192,7 @@ NNVM_REGISTER_PASS(InferShape)
std::move(ret), TShape(),
"FInferShape", "shape_inputs", "shape_attr_key",
"shape", "shape_num_unknown_nodes",
[](const TShape& s) { return s.ndim() == 0; },
[](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
nullptr);
})
.set_change_graph(false)
Expand Down

0 comments on commit ccb349d

Please sign in to comment.