diff --git a/nnvm/include/nnvm/graph.h b/nnvm/include/nnvm/graph.h index 19abfa406c92..60028113cc1e 100644 --- a/nnvm/include/nnvm/graph.h +++ b/nnvm/include/nnvm/graph.h @@ -171,6 +171,8 @@ class IndexedGraph { inline const std::vector& outputs() const { return outputs_; } + // disalllow copy assign + IndexedGraph(const IndexedGraph&) = delete; private: friend class Graph; diff --git a/nnvm/python/nnvm/libinfo.py b/nnvm/python/nnvm/libinfo.py index e15d4d8ae1bf..241960f129f1 100644 --- a/nnvm/python/nnvm/libinfo.py +++ b/nnvm/python/nnvm/libinfo.py @@ -1,9 +1,15 @@ # coding: utf-8 """Information about nnvm.""" from __future__ import absolute_import +import sys import os import platform +if sys.version_info[0] == 3: + import builtins as __builtin__ +else: + import __builtin__ + def find_lib_path(): """Find NNNet dynamic library files. @@ -12,10 +18,19 @@ def find_lib_path(): lib_path : list(string) List of all found path to the libraries """ - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - api_path = os.path.join(curr_path, '../../lib/') - cmake_build_path = os.path.join(curr_path, '../../build/Release/') - dll_path = [curr_path, api_path, cmake_build_path] + if hasattr(__builtin__, "NNVM_BASE_PATH"): + base_path = __builtin__.NNVM_BASE_PATH + else: + base_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + + if hasattr(__builtin__, "NNVM_LIBRARY_NAME"): + lib_name = __builtin__.NNVM_LIBRARY_NAME + else: + lib_name = "libnnvm_example" + + api_path = os.path.join(base_path, '../../lib/') + cmake_build_path = os.path.join(base_path, '../../build/Release/') + dll_path = [base_path, api_path, cmake_build_path] if os.name == 'nt': vs_configuration = 'Release' if platform.architecture()[0] == '64bit': @@ -27,9 +42,9 @@ def find_lib_path(): elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None): dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")]) if os.name == 'nt': - dll_path = [os.path.join(p, 'libnnvm_example.dll') for p in dll_path] + dll_path = [os.path.join(p, '%s.dll' % lib_name) for p in dll_path] else: - dll_path = [os.path.join(p, 'libnnvm_example.so') for p in dll_path] + dll_path = [os.path.join(p, '%s.so' % lib_name) for p in dll_path] lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)] if len(lib_path) == 0: raise RuntimeError('Cannot find the files.\n' + diff --git a/nnvm/python/nnvm/symbol.py b/nnvm/python/nnvm/symbol.py index df5bad28266a..3b11ad9711fd 100644 --- a/nnvm/python/nnvm/symbol.py +++ b/nnvm/python/nnvm/symbol.py @@ -57,7 +57,7 @@ def __rsub__(self, other): def __mul__(self, other): if isinstance(other, Symbol): return _internal.__mul_symbol__(self, other) - if isinstance(other, Number): + if isinstance(other, _Number): return _internal.__mul_scalar__(self, scalar=other) else: raise TypeError('type %s not supported' % str(type(other))) diff --git a/nnvm/src/pass/infer_shape_type.cc b/nnvm/src/pass/infer_shape_type.cc index 3f9eca0a8c4b..a18528b5b465 100644 --- a/nnvm/src/pass/infer_shape_type.cc +++ b/nnvm/src/pass/infer_shape_type.cc @@ -23,11 +23,16 @@ Graph InferAttr(Graph &&ret, using AttrVector = std::vector; const IndexedGraph& idx = ret.indexed_graph(); static auto& finfer_shape = - Op::GetAttr>(infer_name); + Op::GetAttr >(infer_name); static auto& backward_map = Op::GetAttr("FBackwardOutToInIndex"); // reshape shape vector - AttrVector rshape(idx.num_node_entries(), default_val); + AttrVector rshape; + if (ret.attrs.count(attr_name) != 0) { + rshape = ret.MoveCopyAttr(attr_name); + } else { + rshape.resize(idx.num_node_entries(), default_val); + } if (ret.attrs.count(input_name) != 0) { const AttrVector& shape_args = ret.GetAttr(input_name); @@ -39,6 +44,7 @@ Graph InferAttr(Graph &&ret, // erase the provided arguments ret.attrs.erase(input_name); } + std::string shape_attr_key; if (ret.attrs.count(attr_key_name) != 0) { shape_attr_key = ret.GetAttr(attr_key_name); diff --git a/nnvm/tests/python/test_graph.py b/nnvm/tests/python/test_graph.py index b958cee10be0..f30acbf5c075 100644 --- a/nnvm/tests/python/test_graph.py +++ b/nnvm/tests/python/test_graph.py @@ -59,6 +59,22 @@ def test_infer_shape(): assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4] assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2] +def test_infer_shape_known_partial(): + x = sym.Variable('x', shape=(4, 2)) + y = sym.add(x, x, name='add1') + y = sym.reshape(y, target=(2, 4), name="reshape1") + g = graph.create(y) + jgraph = json.loads(g.apply('SaveJSON').json_attr('json')) + shape = [[4, 2], [] , []] + g._set_json_attr("shape", shape, 'list_shape') + g = g.apply("InferShape") + jnodes = jgraph['nodes'] + jnode_row_ptr = jgraph['node_row_ptr'] + nindex = {n['name']: i for i, n in enumerate(jnodes)} + assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4] + assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2] + + def test_infer_type(): x = sym.Variable('x') y = sym.add(x, x, name='add1') @@ -116,6 +132,7 @@ def test_plan_memory(): test_graph_json_attr() test_json_pass() test_infer_shape() + test_infer_shape_known_partial() test_infer_type() test_place_device() test_plan_memory()