Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[WIP] Sparse Tensor #5800

Merged
merged 29 commits into from
Jun 26, 2017
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
f98912b
squash
eric-haibin-lin Apr 29, 2017
a880bc7
draft for sgd rsp rsp (#75)
eric-haibin-lin Jun 10, 2017
6d329cd
fix lint (#78)
eric-haibin-lin Jun 10, 2017
c8d3742
fix lint (#79)
eric-haibin-lin Jun 10, 2017
16a6d7f
serial elemwise sum impl (#80)
eric-haibin-lin Jun 10, 2017
87bb1f7
bug fix for initializing module with row_sparse weight (#81)
eric-haibin-lin Jun 11, 2017
3d2d1c0
Sparse ndarray serialization and deserialization (#77)
reminisce Jun 12, 2017
1e86dbf
Fix lint (#84)
reminisce Jun 12, 2017
86f896a
Sgd with row_sparse weight, dns gradient (#83)
eric-haibin-lin Jun 13, 2017
df49954
update mshadow version (#88)
eric-haibin-lin Jun 13, 2017
3fc1703
csr slice bug fix (#90)
eric-haibin-lin Jun 13, 2017
972647d
benchmark dot code refactor (#87)
jiajiechen Jun 13, 2017
31a0233
Add unit test (#91)
jiajiechen Jun 14, 2017
0b021d5
kvstore push row sparse (#93)
reminisce Jun 14, 2017
073d8f2
Fix windows openmp build failure (#94)
reminisce Jun 14, 2017
1a129e4
update mshadow submoduel (#95)
eric-haibin-lin Jun 14, 2017
b2b3af2
Revert "update mshadow submoduel (#95)" (#96)
eric-haibin-lin Jun 14, 2017
24105cf
Refactor sparse tensor code (#99)
reminisce Jun 19, 2017
ddbe565
fix pylint (#100)
eric-haibin-lin Jun 19, 2017
905304c
Fix refactor sparse gpu test (#104)
reminisce Jun 21, 2017
a9612e3
change idx types from int32 to int64 (#101)
eric-haibin-lin Jun 21, 2017
c134687
revert LOG(DEBUG) change (#105)
eric-haibin-lin Jun 21, 2017
1914471
fix undefined zeros in optimizer.py (#106)
eric-haibin-lin Jun 21, 2017
ad8e74c
move init dns zeros to init_op.h for kvstore to use (#107)
eric-haibin-lin Jun 22, 2017
ac554ac
Refactor cast storage (#109)
reminisce Jun 23, 2017
2727528
Rowsparse kv (#111)
eric-haibin-lin Jun 25, 2017
4d4fbc7
fix lint (#113)
eric-haibin-lin Jun 26, 2017
f8556cc
rename some python funciton (#114)
eric-haibin-lin Jun 26, 2017
cb5deca
fix lint (#115)
eric-haibin-lin Jun 26, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -215,17 +215,17 @@ del /Q *.7z
// Python unittest for CPU
def python_ut(docker_type) {
timeout(time: max_time, unit: 'MINUTES') {
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/unittest"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/unittest"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/unittest"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/train"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/train"
}
}

// GPU test has two parts. 1) run unittest on GPU, 2) compare the results on
// both CPU and GPU
def python_gpu_ut(docker_type) {
timeout(time: max_time, unit: 'MINUTES') {
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests --with-timer --verbose tests/python/gpu"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/gpu"
sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/gpu"
}
}
Expand Down
191 changes: 191 additions & 0 deletions benchmark/python/sparse_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import ctypes

from mxnet.test_utils import *
import scipy.sparse as sp
import os
import time
import argparse

from mxnet.base import check_call, _LIB

parser = argparse.ArgumentParser(description="Benchmark sparse operators",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--num-omp-threads', type=int, default=1, help='number of omp threads to set in MXNet')
args = parser.parse_args()


def get_avazu(data_dir):
if not os.path.isdir(data_dir):
os.system("mkdir " + data_dir)
os.chdir(data_dir)
if (not os.path.exists('avazu-app.t')):
import urllib
zippath = os.path.join(data_dir, "avazu-app.t.bz2")
url = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2"
urllib.urlretrieve(url, zippath)
# decompress
os.system("bzip2 -d avazu-app.t.bz2")
os.chdir("..")


def test_dot_real():
def get_iter(path, data_shape, batch_size):
data_train = mx.io.LibSVMIter(data_libsvm=path,
data_shape=data_shape,
batch_size=batch_size)
data_iter = iter(data_train)
return data_iter
data_dir = os.path.join(os.getcwd(), 'data')
get_avazu(data_dir)
path = os.path.join(data_dir, 'avazu-app.t')
# TODO(haibin) get file size automatically
size = 336490781 >> 20

# model
batch_size = 512
feature_dim = 1000000
data_shape = (feature_dim, )
train_iter = get_iter(path, data_shape, batch_size)

k = 500
weight = mx.nd.random_uniform(low=0, high=1, shape=(feature_dim, k))
weight.wait_to_read()

# start workload
start = time.time()
results = []
num_batch = 0
for batch in train_iter:
data = train_iter.getdata()
results.append(mx.nd.dot(data, weight))
num_batch += 1
for result in results:
result.wait_to_read()

end = time.time()
cost = end - start
print(size / cost, cost, num_batch, num_batch / cost)


def test_dot_synthetic():
"""benchmark mx.nd.dot(sparse_ndarray, dense_ndarray) with given density.
`t_sparse` is the time cost of dot(csr, dns), while `t_dense` is the time cost
of dot(dns, dns), with the same matrix except that it is in default storage type.
"""
def measure_cost_forward_baseline(repeat, dot, lhs, rhs):
start = time.time()
for i in range(repeat):
dot(lhs, rhs)
end = time.time()
diff = end - start
return diff / repeat

def measure_cost_backward_baseline(repeat, dot, transpose, lhs, rhs):
start = time.time()
for i in range(repeat):
dot(transpose(lhs), rhs)
end = time.time()
diff = end -start
return diff / repeat

def measure_cost(repeat, f, *args, **kwargs):
# start bench
start = time.time()
results = []
for i in range(repeat):
results.append(f(*args, **kwargs))
for result in results:
result.wait_to_read()
end = time.time()
diff = end - start
return diff / repeat

def bench_dot_forward(m, k, n, density, ctx, repeat):
set_default_context(ctx)
dns = mx.nd.random_uniform(shape=(k, n)).copyto(ctx)
data_shape = (m, k)
csr_data = rand_ndarray(data_shape, 'csr', density)
dns_data = csr_data.to_dense()
rhs_dns_np = dns.asnumpy()
lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy()) # csr in scipy
lhs_dns_np = lhs_csr_sp.todense()

data = [dns_data, csr_data]
costs = []
for d in data:
dns.wait_to_read()
d.wait_to_read()
cost = measure_cost(repeat, mx.nd.dot, d, dns)
costs.append(cost / repeat)
ratio = costs[1] / costs[0]

costs_baseline = []
cost = measure_cost_forward_baseline(repeat, np.dot, lhs_dns_np, rhs_dns_np)
costs_baseline.append(cost)
cost = measure_cost_forward_baseline(repeat, sp.spmatrix.dot, lhs_csr_sp, rhs_dns_np)
costs_baseline.append(cost)
ratio_baseline = costs_baseline[1] / costs_baseline[0]
fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.6f\t%0.5f\t%0.2f\t\t\t%0.6f\t%0.5f\t\t%0.2f"
print(fmt % (density * 100, str(ctx), n, m, k, costs[1], costs[0], ratio,
costs_baseline[1], costs_baseline[0], ratio_baseline))

def bench_dot_backward(m, k, n, density, ctx, repeat):
set_default_context(ctx)
dns = mx.nd.random_uniform(shape=(m, n)).copyto(ctx)
data_shape = (m, k)
csr_data = rand_ndarray(data_shape, 'csr', density)
dns_data = csr_data.to_dense()
rhs_dns_np = dns.asnumpy()
lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy())
lhs_dns_np = lhs_csr_sp.todense()

data = [dns_data, csr_data]
costs = []
for d in data:
dns.wait_to_read()
d.wait_to_read()
cost = measure_cost(repeat, mx.nd.dot, d, dns, transpose_a=True)
costs.append(cost)
ratio = costs[1] / costs[0]

costs_baseline = []
cost = measure_cost_backward_baseline(repeat, np.dot, np.transpose, lhs_dns_np, rhs_dns_np)
costs_baseline.append(cost)
cost = measure_cost_backward_baseline(repeat, sp.spmatrix.dot, sp.spmatrix.transpose, lhs_csr_sp, rhs_dns_np)
costs_baseline.append(cost)
ratio_baseline = costs_baseline[1] / costs_baseline[0]
fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.6f\t%0.5f\t%0.2f\t\t\t%0.6f\t%0.5f\t\t%0.2f"
print(fmt % (density * 100, str(ctx), n, m, k, costs[1], costs[0], ratio,
costs_baseline[1], costs_baseline[0], ratio_baseline))

print("A = sparse NDArray of shape(m, k)")
print("B = dense NDArray of shape(k, n)")
print("dot_forward\tdot(csr, dns)")
print('density(%)\tcontext\tn\tm\tk\tt_sparse\tt_dense\tt_sparse/t_dense'
'\tt_scipy_sparse\tt_scipy_dense\tt_scipy_sparse/t_scipy_dense')

check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(args.num_omp_threads)))
# TODO(haibin) make these runtime options
m = 512
k = [50000, 100000]
n = [50, 100]
density = [0.05, 0.02, 0.01, 0.005, 0.001]
num_repeat = 10
# contexts = [mx.cpu(), mx.gpu(0)]
contexts = [mx.cpu()]
for i in range(2):
for ctx in contexts:
for den in density:
bench_dot_forward(m, k[i], n[i], den, ctx, num_repeat)

print("dot_backward\tdot(csr.T, dns)")
print('density(%)\tcontext\tn\tm\tk\tt_sparse\tt_dense\tt_sparse/t_dense'
'\tt_scipy_sparse\tt_scipy_dense\tt_scipy_sparse/t_scipy_dense')
for i in range(2):
for ctx in contexts:
for den in density:
bench_dot_backward(m, k[i], n[i], den, ctx, num_repeat)

if __name__ == "__main__":
test_dot_real()
test_dot_synthetic()
94 changes: 94 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,38 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape,
int delay_alloc,
int dtype,
NDArrayHandle *out);


/*!
* \brief create an empty sparse NDArray with specified shape and data type
* \param storage_type the storage type of the ndarray
* \param shape the pointer to the shape
* \param ndim the dimension of the shape
* \param dev_type device type, specify device we want to take
* \param dev_id the device id of the specific device
* \param delay_alloc whether to delay allocation until
* the narray is first mutated
* \param dtype data type of created array
* \param num_aux the number of aux data to support this ndarray
* \param aux_type data type of the aux data for the created array
* \param aux_ndims the dimension of the shapes of aux data
* \param aux_shape the shapes of aux data
* \param out the returning handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type,
const mx_uint *shape,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

identation

mx_uint ndim,
int dev_type,
int dev_id,
int delay_alloc,
int dtype,
mx_uint num_aux,
int *aux_type,
mx_uint *aux_ndims,
const mx_uint *aux_shape,
NDArrayHandle *out);

/*!
* \brief create a NDArray handle that is loaded from raw bytes.
* \param buf the head of the raw bytes
Expand Down Expand Up @@ -358,6 +390,19 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle,
mx_uint slice_begin,
mx_uint slice_end,
NDArrayHandle *out);

/*!
* \brief Slice the NDArray with non-default storage along axis 0.
* \param handle the handle to the NDArray
* \param slice_begin The beginning index of slice
* \param slice_end The ending index of slice
* \param out The NDArrayHandle of sliced NDArray
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArraySliceEx(NDArrayHandle handle,
mx_uint slice_begin,
mx_uint slice_end,
NDArrayHandle out);
/*!
* \brief Index the NDArray along axis 0.
* \param handle the handle to the NDArray
Expand All @@ -368,6 +413,13 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle,
MXNET_DLL int MXNDArrayAt(NDArrayHandle handle,
mx_uint idx,
NDArrayHandle *out);

/*!
* \brief get the storage type of the array
*/
MXNET_DLL int MXNDArrayGetStorageType(NDArrayHandle handle,
int *out_storage_type);

/*!
* \brief Reshape the NDArray.
* \param handle the handle to the narray
Expand Down Expand Up @@ -406,6 +458,26 @@ MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle,
*/
MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle,
int *out_dtype);

/*!
* \brief get the type of the ith aux data in NDArray
* \param handle the handle to the narray
* \param i the index of the aux data
* \param out_type pointer holder to get type of aux data
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle,
mx_uint i,
int *out_type);

// Get the ith aux data blob wrapped in an NDArray
MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle,
mx_uint i,
NDArrayHandle *out);

// Get the data blob wrapped in an NDArray
MXNET_DLL int MXNDArrayGetDataNDArray(NDArrayHandle handle,
NDArrayHandle *out);
/*!
* \brief get the context of the NDArray
* \param handle the handle to the narray
Expand Down Expand Up @@ -1003,6 +1075,25 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym,
mx_uint *aux_type_size,
const int **aux_type_data,
int *complete);




/*!
* \brief infer storage type of unknown input types given the known one.
*/
MXNET_DLL int MXSymbolInferStorageType(SymbolHandle sym,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This API could probably be removed

mx_uint num_args,
const char** keys,
const int *arg_storage_type_data,
mx_uint *in_storage_type_size,
const int **in_storage_type_data,
mx_uint *out_storage_type_size,
const int **out_storage_type_data,
mx_uint *aux_storage_type_size,
const int **aux_storage_type_data,
int *complete);

//--------------------------------------------
// Part 4: Executor interface
//--------------------------------------------
Expand Down Expand Up @@ -1167,6 +1258,9 @@ MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const mx_uint num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const mx_uint num_shared_arg_names,
const char** shared_arg_name_list,
int* shared_buffer_len,
Expand Down
1 change: 1 addition & 0 deletions include/mxnet/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class Executor {
const std::vector<Context>& aux_state_ctxes,
const std::unordered_map<std::string, TShape>& arg_shape_map,
const std::unordered_map<std::string, int>& arg_dtype_map,
const std::unordered_map<std::string, int>& arg_stype_map,
const std::vector<OpReqType>& grad_req_types,
const std::unordered_set<std::string>& param_names,
std::vector<NDArray>* in_args,
Expand Down
13 changes: 13 additions & 0 deletions include/mxnet/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ class IIterator : public dmlc::DataIter<DType> {
}
}; // class IIterator

/*!
* \brief iterator type
* \param DType data type
*/
template<typename DType>
class SparseIIterator : public IIterator<DType> {
Copy link
Contributor

@piiswrong piiswrong Jun 11, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put this undr src/

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

public:
/*! \brief storage type of the data or label */
virtual const NDArrayStorageType GetStorageType(bool is_data) const = 0;
/*! \brief shape of the data or label */
virtual const TShape GetShape(bool is_data) const = 0;
}; // class SparseIIterator

/*! \brief a single data instance */
struct DataInst {
/*! \brief unique id for instance */
Expand Down
Loading