diff --git a/nnvm/python/nnvm/compiler/build_module.py b/nnvm/python/nnvm/compiler/build_module.py index 04b1f0a2e96f..24d7a85dea99 100644 --- a/nnvm/python/nnvm/compiler/build_module.py +++ b/nnvm/python/nnvm/compiler/build_module.py @@ -8,7 +8,7 @@ from .. import runtime OPT_PASS_LEVEL = { - "SimplifyBatchNormInference": 2, + "SimplifyInference": 2, "PrecomputePrune": 2, "OpFusion": 1 } @@ -115,12 +115,9 @@ def optimize(graph, shape, dtype="float32"): """ # pylint: disable=unused-argument cfg = BuildConfig.current - graph = graph_attr.set_shape_inputs(graph, shape) - graph = graph.apply("InferShape") - if graph.json_attr("shape_num_unknown_nodes"): - raise ValueError("InferShape fails..") - if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]: - graph = graph.apply("SimplifyBatchNormInference") + if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyInference"]: + graph = graph_attr.set_shape_inputs(graph, shape) + graph = graph.apply(["InferShape", "SimplifyInference"]) return graph diff --git a/nnvm/python/nnvm/top/tensor.py b/nnvm/python/nnvm/top/tensor.py index 0259e99b6aee..05396d65abeb 100644 --- a/nnvm/python/nnvm/top/tensor.py +++ b/nnvm/python/nnvm/top/tensor.py @@ -44,6 +44,11 @@ def _compute(attrs, x, _): _fschedule_broadcast = tvm.convert(_schedule_broadcast) +# copy +reg.register_compute("copy", _compute_unary(topi.identity)) +reg.register_pattern("copy", OpPattern.ELEM_WISE) +reg.register_schedule("copy", _fschedule_broadcast) + # exp reg.register_compute("exp", _compute_unary(topi.exp)) reg.register_pattern("exp", OpPattern.ELEM_WISE) diff --git a/nnvm/src/compiler/simplify_batch_norm.cc b/nnvm/src/compiler/simplify_inference.cc similarity index 90% rename from nnvm/src/compiler/simplify_batch_norm.cc rename to nnvm/src/compiler/simplify_inference.cc index fdee4dca2352..ad27d885f68e 100644 --- a/nnvm/src/compiler/simplify_batch_norm.cc +++ b/nnvm/src/compiler/simplify_inference.cc @@ -22,6 +22,7 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs, nnvm::NodeEntry moving_mean, nnvm::NodeEntry moving_var, TShape dshape) { + CHECK_NE(dshape.ndim(), 0); CHECK(attrs.op); static const Op* bn_op = Op::Get("batch_norm"); CHECK(attrs.op == bn_op); @@ -76,13 +77,14 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs, return {out, undef, undef}; } -Graph SimplifyBatchNormInference(nnvm::Graph src) { +Graph SimplifyInference(nnvm::Graph src) { // Get attributes from the graph const IndexedGraph& idx = src.indexed_graph(); const ShapeVector& shape_vec = src.GetAttr("shape"); auto transform = [&](uint32_t nid, const Node* n, std::vector* ret) { if (n->is_variable()) return false; static const Op* bn_op = Op::Get("batch_norm"); + static const Op* dropout_op = Op::Get("dropout"); if (n->op() == bn_op) { *ret = BatchNormToInferUnpack( n->attrs, @@ -93,6 +95,10 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) { n->inputs[4], shape_vec[idx.entry_id(nid, 0)]); return true; + } else if (n->op() == dropout_op) { + NodeEntry undef = MakeNode("__undef__", "undef", {}); + *ret = {n->inputs[0], undef}; + return true; } else { return false; } @@ -100,8 +106,8 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) { return GraphTransform(src, transform); } -NNVM_REGISTER_PASS(SimplifyBatchNormInference) -.set_body(SimplifyBatchNormInference); +NNVM_REGISTER_PASS(SimplifyInference) +.set_body(SimplifyInference); } // namespace compiler } // namespace nnvm diff --git a/nnvm/tests/python/compiler/test_simplify_batchnorm.py b/nnvm/tests/python/compiler/test_simplify_inference.py similarity index 94% rename from nnvm/tests/python/compiler/test_simplify_batchnorm.py rename to nnvm/tests/python/compiler/test_simplify_inference.py index 307c1348fbba..eaff669a904b 100644 --- a/nnvm/tests/python/compiler/test_simplify_batchnorm.py +++ b/nnvm/tests/python/compiler/test_simplify_inference.py @@ -30,12 +30,13 @@ def check(dim, axis, nstep): for i in range(nstep): y1 = sym.batch_norm( y1 + 1, gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis) + y1 = sym.dropout(y1) y2 = simple_bn(y2 + 1, gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis, shape=ishape["x"]) g = nnvm.graph.create(y1) g2 = nnvm.graph.create(y2) graph_attr.set_shape_inputs(g, ishape) - g1 = g.apply("InferShape").apply("SimplifyBatchNormInference") + g1 = g.apply("InferShape").apply("SimplifyInference") # Some prints for debug # print(g1.ir()) # assert graph equals as expected