diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index be3a006e3b57..414e68e9256f 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -270,6 +270,21 @@ void SetDependency(std::vector *p_read_vars, Engine::Get()->DeduplicateVarHandle(&read_vars, &write_vars); } +inline void SetWriteInplaceReq(const std::vector &ndinputs, + const std::vector &ndoutputs, + std::vector *req) { + std::unordered_set in_vars; + for (auto &nd : ndinputs) { + in_vars.insert(nd.var()); + } + for (size_t i = 0; i < ndoutputs.size(); i++) { + // output NDArray shares the memory with the input NDArray + if (in_vars.find(ndoutputs[i].var()) != in_vars.end()) { + req->at(i) = kWriteInplace; + } + } +} + void PushFCompute(const FCompute& fn, const nnvm::Op* op, const nnvm::NodeAttrs& attrs, @@ -332,6 +347,7 @@ void PushFComputeEx(const FComputeEx& fn, engine::CallbackOnComplete(), requested}; std::vector req(ndoutputs.size(), kWriteTo); + SetWriteInplaceReq(ndinputs, ndoutputs, &req); fn(attrs, opctx, ndinputs, req, ndoutputs); if (ctx.dev_mask() == gpu::kDevMask) { rctx.get_stream()->Wait(); @@ -406,6 +422,7 @@ void PushOperator(const OpStatePtr& state, engine::CallbackOnComplete on_complete) { OpContext opctx{is_train, rctx, on_complete, requested}; std::vector req(ndoutputs.size(), kWriteTo); + SetWriteInplaceReq(ndinputs, ndoutputs, &req); fcompute_ex(state, opctx, ndinputs, req, ndoutputs); if (exec_type == ExecType::kSync) { if (rctx.get_ctx().dev_mask() == gpu::kDevMask) { diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index e5da182216d6..14673904798a 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -93,13 +93,13 @@ struct SGDDnsRspKernel { // IType is row sparse idx type // i is the ith row in row sparse gradient template - MSHADOW_XINLINE static void Map(int i, size_t row_length, DType* out, const DType* weight, + MSHADOW_XINLINE static void Map(int i, const index_t row_length, 
DType* out, const DType* weight, const IType* grad_idx, const DType *grad_val, const DType clip_gradient, const DType lr, const DType wd, const DType rescale_grad) { - for (size_t j = 0; j < row_length; j++) { - uint64_t data_i = grad_idx[i] * row_length + j; - uint64_t grad_i = i * row_length + j; + for (index_t j = 0; j < row_length; j++) { + index_t data_i = grad_idx[i] * row_length + j; + index_t grad_i = i * row_length + j; if (clip_gradient >= 0.0f) { KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] - (lr) * mshadow_op::clip::Map(rescale_grad * grad_val[grad_i], clip_gradient)); @@ -126,6 +126,7 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param, CHECK_EQ(grad.storage_type(), kRowSparseStorage); // if gradients are zeros, no weights are updated if (!grad.storage_initialized() || req == kNullOp) return; + CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_update"; CHECK_GT(weight.shape_.Size(), 0); MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, { @@ -151,7 +152,7 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param, template struct SGDRspDnsKernel { template - MSHADOW_XINLINE static void Map(int i, size_t num_cols, DType* out, const DType* weight, + MSHADOW_XINLINE static void Map(int i, const index_t num_cols, DType* out, const DType* weight, const DType *grad, const DType clip_gradient, const DType lr, const DType wd, const DType rescale_grad) { bool contains_non_zeros = false; @@ -191,6 +192,7 @@ inline void SGDUpdateRspDnsImpl(const SGDParam& param, CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDUpdate", "weights"); CHECK_EQ(weight.storage_type(), kRowSparseStorage); if (req == kNullOp) return; + CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_update"; CHECK(weight.storage_initialized()); Stream* s = ctx.get_stream(); MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, { @@ -216,14 +218,9 @@ inline void SGDUpdateRspRspImpl(const SGDParam& param, const OpReqType& req, NDArray *out) { 
CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDUpdate", "weights"); - // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only - // feed in kWriteTo as req for all operators. - // For sgd we don't want to assign zeros to the output values when req == kWriteTo - auto out_req = req; - if (out_req == kWriteTo) out_req = kWriteInplace; // reuse dns rsp implementation when storage_shape == shape TBlob out_blob = out->data(); - SGDUpdateDnsRspImpl(param, ctx, weight.data(), grad, out_req, &out_blob); + SGDUpdateDnsRspImpl(param, ctx, weight.data(), grad, req, &out_blob); } template @@ -425,14 +422,14 @@ inline void MP_SGDMomUpdate(const nnvm::NodeAttrs& attrs, template struct SGDMomDnsRspDnsKernel { template - MSHADOW_XINLINE static void Map(int i, size_t row_length, DType* out_data, + MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data, DType* mom_data, const DType* weight_data, const IType* grad_idx, const DType* grad_data, const DType clip_gradient, const DType momentum, const DType lr, const DType wd, const DType rescale_grad) { const DType rate = lr * wd; - for (size_t j = 0; j < row_length; j++) { - uint64_t data_i = grad_idx[i] * row_length + j; - uint64_t grad_i = i * row_length + j; + for (index_t j = 0; j < row_length; j++) { + index_t data_i = grad_idx[i] * row_length + j; + index_t grad_i = i * row_length + j; if (clip_gradient >= 0.0f) { mom_data[data_i] = momentum * mom_data[data_i] - rate * weight_data[data_i] @@ -461,6 +458,7 @@ inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param, using namespace rowsparse; Stream* s = ctx.get_stream(); if (!grad.storage_initialized() || req == kNullOp) return; + CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update"; CHECK_GT(weight.shape_.Size(), 0); CHECK_GT(mom.shape_.Size(), 0); @@ -487,7 +485,7 @@ inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param, template struct SGDMomRspDnsKernel { template - MSHADOW_XINLINE 
static void Map(int i, size_t num_cols, DType* out, DType* mom, + MSHADOW_XINLINE static void Map(int i, index_t num_cols, DType* out, DType* mom, const DType* weight, const DType *grad, const DType clip_gradient, const DType momentum, const DType lr, const DType wd, const DType rescale_grad) { @@ -531,19 +529,15 @@ inline void SGDMomUpdateRspDnsImpl(const SGDMomParam& param, Stream* s = ctx.get_stream(); CHECK_EQ(weight.storage_type(), kRowSparseStorage); if (req == kNullOp) return; + CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update"; CHECK(weight.storage_initialized()); // fill mom with zero values if not initialized yet if (!mom.storage_initialized()) { NDArray mom_zeros = mom; FillDnsZerosRspImpl(s, &mom_zeros); } - // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only - // feed in kWriteTo as req for all operators. - // For sgd we don't want to assign zeros to the output values when req == kWriteTo - auto out_req = req; - if (out_req == kWriteTo) out_req = kWriteInplace; MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, { - MXNET_ASSIGN_REQ_SWITCH(out_req, req_type, { + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { DType* weight_data = weight.data().dptr(); DType* grad_data = grad.dptr(); DType* mom_data = mom.data().dptr(); @@ -578,15 +572,10 @@ inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param, NDArray mom_zeros = mom; FillDnsZerosRspImpl(s, &mom_zeros); } - // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only - // feed in kWriteTo as req for all operators. 
- // For sgd we don't want to assign zeros to the output values when req == kWriteTo - auto out_req = req; - if (out_req == kWriteTo) out_req = kWriteInplace; TBlob out_blob = out->data(); // reuse dns rsp implementation when storage_shape == shape SGDMomUpdateDnsRspDnsImpl(param, ctx, weight.data(), grad, - mom.data(), out_req, &out_blob); + mom.data(), req, &out_blob); } template diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index 76e121e462dc..a7557809e225 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -81,45 +81,14 @@ def check_sparse_nd_copy(from_stype, to_stype, shape): def test_sparse_nd_basic(): - def check_rsp_creation(values, indices, shape): - rsp = mx.nd.row_sparse(values, indices, shape) - dns = mx.nd.zeros(shape) - dns[1] = mx.nd.array(values[0]) - dns[3] = mx.nd.array(values[1]) - indices_np = mx.nd.array(indices, dtype='int64').asnumpy() - assert_almost_equal(rsp.indices.asnumpy(), indices_np) - - def check_csr_creation(shape): - csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr') - assert_almost_equal(csr.indptr.asnumpy(), indptr) - assert_almost_equal(csr.indices.asnumpy(), indices) - assert_almost_equal(csr.data.asnumpy(), values) - - def check_sparse_nd_rsp_aux(): + def check_sparse_nd_basic_rsp(): storage_type = 'row_sparse' shape = rand_shape_2d() nd, (v, idx) = rand_sparse_ndarray(shape, storage_type) assert(nd._num_aux == 1) assert(nd.indices.dtype == np.int64) assert(nd.stype == 'row_sparse') - assert_almost_equal(nd.indices.asnumpy(), idx) - assert_almost_equal(nd.data.asnumpy(), v) - - shape = (4,2) - values = np.random.rand(2,2) - indices = np.array([1,3], dtype='int64') - check_rsp_creation(values, indices, shape) - - values = mx.nd.array(np.random.rand(2,2)) - indices = mx.nd.array([1,3], dtype='int64') - check_rsp_creation(values, indices, shape) - - values = [[0.1, 0.2], [0.3, 0.4]] - indices = 
[1,3] - check_rsp_creation(values, indices, shape) - - check_csr_creation(shape) - check_sparse_nd_rsp_aux() + check_sparse_nd_basic_rsp() def test_sparse_nd_setitem():