diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 207003ae029d..c684c7ad6057 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -483,7 +483,13 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                        std::vector<int>* out_attrs) {
   CHECK_EQ(in_attrs->size(), 2U);
   CHECK_EQ(out_attrs->size(), 1U);
-  out_attrs->at(0) = kDefaultStorage;
+  const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
+  if (param.transpose_a && kCSRStorage == (*in_attrs)[0]
+      && kDefaultStorage == (*in_attrs)[1]) {
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage);
+  } else {
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage);
+  }
   return true;
 }
@@ -493,8 +499,14 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                         std::vector<int>* out_attrs) {
   CHECK_EQ(in_attrs->size(), 3U);
   CHECK_EQ(out_attrs->size(), 2U);
-  out_attrs->at(0) = kDefaultStorage;
-  out_attrs->at(1) = kDefaultStorage;
+  const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
+  STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage);
+  if (!param.transpose_a && kDefaultStorage == (*in_attrs)[0]
+      && kCSRStorage == (*in_attrs)[1] && kDefaultStorage == (*in_attrs)[2]) {
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kRowSparseStorage);
+  } else {
+    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kDefaultStorage);
+  }
   return true;
 }
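With the two inference hooks above, `dot` advertises a row-sparse output whenever the lhs is CSR, the rhs is dense, and `transpose_a` is set, and the backward pass does the same for the rhs gradient. A minimal sketch of the resulting front-end behavior — `rand_ndarray` is the helper the unit tests below rely on; importing it from `mxnet.test_utils` and the exact shapes are assumptions here, not part of this patch:

```python
import mxnet as mx
from mxnet.test_utils import rand_ndarray  # assumed import path for the test helper

lhs = rand_ndarray((3, 4), 'csr', density=1)   # CSR lhs

# dot(csr.T, dns) now infers a row_sparse output ...
out = mx.nd.dot(lhs, rand_ndarray((3, 2), 'default'), transpose_a=True)
assert out.storage_type == 'row_sparse'

# ... while dot(csr, dns) still infers a dense output.
out = mx.nd.dot(lhs, rand_ndarray((4, 2), 'default'))
assert out.storage_type == 'default'
```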
@@ -642,6 +654,45 @@ struct DotCsrTransDnsDnsByRowBlocks {
   }
 };
 
+/*!
+ * \brief Kernel of dot(csr.T(), dns) = rsp
+ * Parallelization by row blocks.
+ * This kernel fills up the row_idx array of the rsp with 1 for nonzero rows
+ * and 0 for zero rows. The matrix will be compacted after this kernel call.
+ */
+struct DotCsrTransDnsRspByRowBlocks {
+  /*!
+   * \brief
+   * \param i the i-th thread
+   */
+  template<typename DType, typename RType, typename IType, typename CType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, RType* row_idx, const DType* data_l,
+                                  const IType* indptr_l, const CType* col_idx_l,
+                                  const DType* data_r, const size_t seg_len,
+                                  const size_t num_rows_l, const size_t num_rows,
+                                  const size_t num_cols) {
+    const size_t seg_start = i * seg_len;
+    if (seg_start >= num_rows) return;
+    const size_t seg_end = (i + 1) * seg_len;
+    for (size_t j = 0; j < num_rows_l; ++j) {
+      if (indptr_l[j] == indptr_l[j+1]) continue;
+      const size_t offset_r = j * num_cols;
+      for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) {
+        const auto col_idx = col_idx_l[k];
+        if (col_idx < seg_start || col_idx >= seg_end) continue;
+        const size_t offset_out = col_idx * num_cols;
+        row_idx[col_idx] = 1;
+        const auto val = data_l[k];
+        for (size_t l = 0; l < num_cols; ++l) {
+          out[offset_out+l] += data_r[offset_r+l] * val;
+        }
+      }
+    }
+  }
+};
+
 template<typename xpu>
 void DotCsrDnsDnsImpl(const OpContext& ctx,
                       const NDArray& lhs,
@@ -702,6 +753,75 @@ void DotCsrDnsDnsImpl(const OpContext& ctx,
   });
 }
 
+template<typename xpu>
+void DotCsrDnsRspImpl(const OpContext& ctx,
+                      const NDArray& lhs,
+                      const TBlob& rhs,
+                      const OpReqType req,
+                      const bool trans_lhs,
+                      NDArray* ret) {
+  if (kNullOp == req) return;
+  CHECK_EQ(lhs.storage_type(), kCSRStorage);
+  CHECK_EQ(ret->storage_type(), kRowSparseStorage);
+  if (!lhs.storage_initialized()) return;
+
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+  const TBlob data_l = lhs.data();
+  const TBlob indptr_l = lhs.aux_data(csr::kIndPtr);
+  const TBlob col_idx_l = lhs.aux_data(csr::kIdx);
+  const TBlob& data_r = rhs;
+
+  // pre-allocate spaces for ret using the dense dimension size
+  ret->CheckAndAlloc({mshadow::Shape1(lhs.shape()[1])});
+  const TBlob data_out = ret->data();
+  const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx);
+
+  MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
+    MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
+      MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
+        MSHADOW_IDX_TYPE_SWITCH(row_idx_out.type_flag_, RType, {  // row idx type
+          if (std::is_same<xpu, cpu>::value) {  // cpu parallelization by row blocks
+            if (kWriteTo == req) {
+              mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
+                  s, data_out.Size(), data_out.dptr<DType>());
+            }
+            RType* row_idx = row_idx_out.dptr<RType>();
+            mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
+                s, row_idx_out.Size(), row_idx);
+            int num_threads = mxnet_op::get_num_threads<xpu>(data_out.shape_[0]);
+            size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
+            if (trans_lhs) {
+              mxnet_op::Kernel<DotCsrTransDnsRspByRowBlocks, xpu>::Launch(s, num_threads,
+                  data_out.dptr<DType>(), row_idx, data_l.dptr<DType>(),
+                  indptr_l.dptr<IType>(), col_idx_l.dptr<CType>(), data_r.dptr<DType>(),
+                  seg_len, lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]);
+              index_t nnr = 0;
+              nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr);
+              ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr));
+              ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1]));
+              if (0 == nnr) return;
+              mshadow::Tensor<xpu, 2, DType> rsp_data = data_out.FlatTo2D<xpu, DType>(s);
+              size_t idx = 0;
+              for (index_t i = 0; i < ret->shape()[0]; ++i) {
+                if (row_idx[i] > 0) {
+                  row_idx[idx] = i;
+                  mshadow::Copy(rsp_data[idx], rsp_data[i], s);
+                  ++idx;
+                }
+              }
+            } else {
+              LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns)=rsp yet."
+                            " Only the cpu version of dot(csr.T, dns)=rsp is supported now";
+            }
+          } else {
+            LOG(FATAL) << "DotCsrDnsRspImpl has not implemented GPU version yet.";
+          }
+        });
+      });
+    });
+  });
+}
+
 template<typename xpu>
 void DotCsrRspDnsImpl(const OpContext& ctx,
                       const NDArray& lhs,
@@ -803,10 +923,12 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs,
              out_stype == kDefaultStorage) {
     TBlob ret = outputs[0].data();
     DotCsrRspDnsImpl<xpu>(ctx, inputs[0], inputs[1], req[0], param.transpose_a, &ret);
-  } else {  // TODO(junwu): add fallback
-    LOG(FATAL) << "Not supported dot operation for lhs.storage_type = "
-               << inputs[0].storage_type() << ", rhs.storage_type = " << inputs[1].storage_type()
-               << ", out.storage_type = " << outputs[0].storage_type();
+  } else if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage
+             && out_stype == kRowSparseStorage) {
+    NDArray out = outputs[0];
+    DotCsrDnsRspImpl<xpu>(ctx, inputs[0], inputs[1].data(), req[0], param.transpose_a, &out);
+  } else {
+    FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, DotForward_<xpu>, "DotForward_");
   }
 }
@@ -823,7 +945,6 @@ void DotBackwardEx(const nnvm::NodeAttrs& attrs,
       << "sparse dot does not support computing the gradient of the csr/lhs";
   CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace";
-  // TODO(junwu): check whether this CHECK is reasonable
   const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
   CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)";
   auto ograd_stype = inputs[0].storage_type();
@@ -836,11 +957,15 @@ void DotBackwardEx(const nnvm::NodeAttrs& attrs,
     // dns, csr, dns => *, dns
     DotBackwardCsrDnsDns<xpu>(attrs, ctx, inputs, req, outputs);
   } else if (ograd_stype == kDefaultStorage && lhs_stype == kCSRStorage &&
-      rhs_stype == kRowSparseStorage && outputs[1].storage_type() == kDefaultStorage) {
+             rhs_stype == kRowSparseStorage && outputs[1].storage_type() == kDefaultStorage) {
     // dns, csr, rsp => *, dns
     DotBackwardCsrRspDns<xpu>(attrs, ctx, inputs, req, outputs);
+  } else if (ograd_stype == kDefaultStorage && lhs_stype == kCSRStorage &&
+             rhs_stype == kDefaultStorage && outputs[1].storage_type() == kRowSparseStorage) {
+    NDArray grad_rhs = outputs[1];
+    DotCsrDnsRspImpl<xpu>(ctx, inputs[1], inputs[2].data(), req[1], !param.transpose_a, &grad_rhs);
   } else {
-    LOG(FATAL) << "Not supported dot backward for sparse input(s) with sparse gradients";
+    FCompExFallback<xpu>(attrs, ctx, inputs, req, outputs, DotBackward_<xpu>, "DotBackward_");
   }
 }
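The `trans_lhs` branch of `DotCsrDnsRspImpl` above works in two passes: the kernel accumulates into an output buffer with the full dense row count and marks `row_idx[c] = 1` for every row that receives a nonzero contribution, then the scalar loop counts the nonzero rows (`nnr`), shrinks the storage shape, and compacts rows forward in place, rewriting `row_idx` to hold the actual row indices. A numpy sketch of that compaction step, for illustration only (not the library's API):

```python
import numpy as np

def compact_rsp(out, row_marker):
    """Numpy analogue of the in-place compaction after the kernel call:
    `out` has one row per column of the CSR lhs, and `row_marker[i]` is 1
    if row i received any nonzero contribution, 0 otherwise."""
    nnr = int(row_marker.sum())            # number of non-zero rows
    row_idx = np.zeros(nnr, dtype=np.int64)
    idx = 0
    for i in range(out.shape[0]):
        if row_marker[i] > 0:
            row_idx[idx] = i               # record the original row index
            out[idx] = out[i]              # copy forward; idx <= i always holds
            idx += 1
    return out[:nnr], row_idx              # compacted data + row indices

# e.g. rows 0 and 2 are nonzero:
data, idx = compact_rsp(np.array([[1., 1.], [0., 0.], [3., 3.]]),
                        np.array([1, 0, 1]))
print(idx)    # -> [0 2]
```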
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index dcc0f38b208a..5801bb1829d3 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -383,7 +383,7 @@ def fm_model(k, feature_dim):
     x = mx.symbol.Variable("data", storage_type='csr')
     v = mx.symbol.Variable("v", shape=(feature_dim, k), init=norm,
                            storage_type='row_sparse')
-    w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=norm)
+    w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=norm, storage_type='row_sparse')
     w1 = mx.symbol.dot(x, w1_weight)
     v_s = mx.symbol.sum(data=mx.symbol.square(data=v), axis=1)
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index ba10ad830f23..4d2debe5f9d2 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -100,6 +100,7 @@ def test_dns_to_csr(dns_in):
     test_csr_to_dns((4, 4))
     test_dns_to_csr([[0, 1, 0], [0, 2, 0], [3, 0, 0], [0, 0, 4], [5, 6, 0], [0, 0, 7]])
 
+
 def test_sparse_dot():
     def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs):
         lhs_dns = rand_ndarray(lhs_shape, 'default')
@@ -107,7 +108,10 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs):
         rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=1)
         rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.todense()
         out = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs)
-        assert out.storage_type == 'default'
+        if trans_lhs:
+            assert out.storage_type == 'row_sparse'
+        else:
+            assert out.storage_type == 'default'
         out_expected = mx.nd.dot(lhs_dns, rhs_dns, transpose_a=trans_lhs)
         out_np = out_expected.asnumpy()
         backward_trans = not trans_lhs
@@ -132,6 +136,7 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs):
         test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False)
         test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True)
 
+
 def test_sparse_embedding():
     in_dim = 10
     out_dim = 4
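The test_module.py change depends on the new backward branch: for `out = dot(csr, dns)`, the gradient of the dense rhs is `dot(csr.T, ograd)`, which `DotBackwardEx` now emits as row_sparse, so `w1_weight` can be declared with `storage_type='row_sparse'`. A small numpy illustration of why that gradient is naturally row-sparse (the matrices here are made up for the example):

```python
import numpy as np

# grad_rhs = dot(lhs.T, ograd); every all-zero *column* of the CSR lhs
# yields an all-zero *row* in the gradient, which row_sparse skips storing.
lhs = np.array([[0., 2., 0.],
                [0., 0., 0.],
                [1., 0., 0.]], dtype='float32')   # column 2 is empty
ograd = np.ones((3, 2), dtype='float32')
grad_rhs = lhs.T.dot(ograd)
nonzero_rows = np.flatnonzero((grad_rhs != 0).any(axis=1))
print(nonzero_rows)   # -> [0 1]; row 2 need not be stored
```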