Skip to content

Commit

Permalink
fix type inference in index_copy. (apache#12890)
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da authored and Jose Luis Contreras committed Nov 13, 2018
1 parent 4224fc5 commit 87d0b73
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
12 changes: 11 additions & 1 deletion src/operator/contrib/index_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@
namespace mxnet {
namespace op {

static bool IndexCopyType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
return out_attrs->at(0) != -1;
}

NNVM_REGISTER_OP(_contrib_index_copy)
.describe(R"code(Copies the elements of a `new_tensor` into the `old_tensor` by
selecting the indices in the order given in `index`. The output will be a new tensor
Expand Down Expand Up @@ -56,7 +66,7 @@ mx.nd.contrib.index_copy(x, index, t)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr<nnvm::FInferShape>("FInferShape", IndexCopyShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", IndexCopyType)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_contrib_backward_index_copy"})
.set_attr<FCompute>("FCompute<cpu>", IndexCopyForward<cpu>)
.add_argument("old_tensor", "NDArray-or-Symbol", "Old tensor")
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4766,7 +4766,7 @@ def test_quantization_op():
def test_index_copy():
x = mx.nd.zeros((5,3))
t = mx.nd.array([[1,2,3],[4,5,6],[7,8,9]])
index = mx.nd.array([0,4,2])
index = mx.nd.array([0,4,2], dtype=np.int64)

x.attach_grad()
t.attach_grad()
Expand Down

0 comments on commit 87d0b73

Please sign in to comment.