construct row_sparse ndarray for dist-async
fix bug in rsp add

rsp sync push

race condition for push

fix bug in rsp pull. refactor test

cleanup comments

refactor dist server

fix lint

fix storage shape issue with the new ndarray constructor

data sharding draft

fix lint. add comment

add support for zeros gradients

use std::upper_bound/lower_bound

remove special init function for rowsparse dist kvstore

temporary support for inplace operators for sparse

add test. fix return type

store kRowSparseNDArray in kv server

remove fcomp_ex sgd with dns weight and rsp gradient

bug fix in sparse retain

sparse pull c_api

revise rowsparse pull api

use engine to compute unique to ensure thread safety

add rowsparse pull to dist-kv

fix lint

add example for rsp_pull

remove name2idx

add sparse_pull_dict param to module

fix unit test and c rowid conversion

support str key type in kvstore (apache#6765)

* update kvstore unit test

* update model/module.py

* fix lint

* remove int keys in kvstore

* update cast to str function

* remove _cast_to_str_keys

* fix lint

* always cast to str

Conflicts:
	include/mxnet/c_api.h
	include/mxnet/kvstore.h
	python/mxnet/kvstore.py
	python/mxnet/model.py
	python/mxnet/module/module.py
	src/c_api/c_api.cc
	src/kvstore/kvstore_local.h
	tests/python/unittest/test_kvstore.py

update module API for other submodules

update stypes in kvstore after refactoring

change type of size from size_t to int64_t

add sparse linear regression example

remove sparse_pull_dict from module

fix init_optim for seq_module. update sparse example

resolve conflict for binary add rsp rsp
eric-haibin-lin committed Jul 26, 2017
1 parent 6644d22 commit c2ed043
Showing 20 changed files with 1,027 additions and 375 deletions.
15 changes: 15 additions & 0 deletions example/sparse/get_data.py
@@ -0,0 +1,15 @@
# pylint: skip-file
import os, gzip
import pickle as pickle
import sys

def get_libsvm_data(data_dir, data_name, url, data_origin_name):
    if not os.path.isdir(data_dir):
        os.system("mkdir " + data_dir)
    os.chdir(data_dir)
    if not os.path.exists(data_name):
        import urllib
        zippath = os.path.join(data_dir, data_origin_name)
        urllib.urlretrieve(url, zippath)
        os.system("bzip2 -d %r" % data_origin_name)
    os.chdir("..")
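For reference, a hypothetical invocation of the helper above, reusing the avazu constants from linear_regression.py below (passing an absolute data_dir sidesteps the relative-path join that follows os.chdir):

url = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2"
get_libsvm_data(os.path.join(os.getcwd(), 'data'), 'avazu-app.t', url, 'avazu-app.t.bz2')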
178 changes: 178 additions & 0 deletions example/sparse/linear_regression.py
@@ -0,0 +1,178 @@
import mxnet as mx
from mxnet.test_utils import *
from get_data import get_libsvm_data
import time
import argparse
import os

parser = argparse.ArgumentParser(description="Run sparse linear regression " \
                                 "with distributed kvstore",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--profiler', type=int, default=0,
                    help='whether to use profiler')
parser.add_argument('--num-epoch', type=int, default=1,
                    help='number of epochs to train')
parser.add_argument('--batch-size', type=int, default=512,
                    help='number of examples per batch')
parser.add_argument('--num-batch', type=int, default=99999999,
                    help='number of batches per epoch')
parser.add_argument('--dummy-iter', type=int, default=0,
                    help='whether to use dummy iterator to exclude io cost')
parser.add_argument('--kvstore', type=str, default='dist_sync',
                    help='what kvstore to use [local, dist_sync, etc]')
parser.add_argument('--log-level', type=str, default='debug',
                    help='logging level [debug, info, error]')
parser.add_argument('--dataset', type=str, default='avazu',
                    help='what test dataset to use')

class DummyIter(mx.io.DataIter):
    "A dummy iterator that always returns the same batch, used for speed testing"
    def __init__(self, real_iter):
        super(DummyIter, self).__init__()
        self.real_iter = real_iter
        self.provide_data = real_iter.provide_data
        self.provide_label = real_iter.provide_label
        self.batch_size = real_iter.batch_size

        for batch in real_iter:
            self.the_batch = batch
            break

    def __iter__(self):
        return self

    def next(self):
        return self.the_batch

# testing dataset sources
avazu = {
    'data_name': 'avazu-app.t',
    'data_origin_name': 'avazu-app.t.bz2',
    'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2",
    'feature_dim': 1000000,
}

kdda = {
    'data_name': 'kdda.t',
    'data_origin_name': 'kdda.t.bz2',
    'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2",
    'feature_dim': 20216830,
}

datasets = {'kdda': kdda, 'avazu': avazu}

def regression_model(feature_dim):
    x = mx.symbol.Variable("data", stype='csr')
    norm_init = mx.initializer.Normal(sigma=0.01)
    v = mx.symbol.Variable("v", shape=(feature_dim, 1), init=norm_init, stype='row_sparse')
    embed = mx.symbol.dot(x, v)
    y = mx.symbol.Variable("softmax_label")
    model = mx.symbol.LinearRegressionOutput(data=embed, label=y, name="out")
    return model

if __name__ == '__main__':

    # arg parser
    args = parser.parse_args()
    num_epoch = args.num_epoch
    num_batch = args.num_batch
    kvstore = args.kvstore
    profiler = args.profiler > 0
    batch_size = args.batch_size
    dummy_iter = args.dummy_iter
    dataset = args.dataset
    log_level = args.log_level

    # create kvstore
    kv = mx.kvstore.create(kvstore)
    rank = kv.rank
    num_worker = kv.num_workers

    # only print log for rank 0 worker
    import logging
    if rank != 0:
        log_level = logging.ERROR
    elif log_level.upper() == 'DEBUG':
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=log_level, format=head)

    # dataset
    assert dataset in datasets, "unknown dataset " + dataset
    metadata = datasets[dataset]
    feature_dim = metadata['feature_dim']
    logging.debug('preparing data ... ')
    data_dir = os.path.join(os.getcwd(), 'data')
    path = os.path.join(data_dir, metadata['data_name'])
    if not os.path.exists(path):
        get_libsvm_data(data_dir, metadata['data_name'], metadata['url'],
                        metadata['data_origin_name'])
    assert os.path.exists(path)

    # data iterator
    train_data = mx.io.LibSVMIter(data_libsvm=path, data_shape=(feature_dim,),
                                  batch_size=batch_size, num_parts=num_worker,
                                  part_index=rank)
    if dummy_iter:
        train_data = DummyIter(train_data)

    # model
    model = regression_model(feature_dim)

    # module
    mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['softmax_label'])
    mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
    mod.init_params(initializer=mx.init.Uniform(scale=.1))
    sgd = mx.optimizer.SGD(momentum=0.0, clip_gradient=5.0,
                           learning_rate=0.1, rescale_grad=1.0/batch_size/num_worker)
    mod.init_optimizer(optimizer=sgd, kvstore=kv)
    # use MSE as the metric
    metric = mx.metric.create('MSE')

    # start profiler
    if profiler:
        name = 'profile_output_' + str(num_worker) + '.json'
        mx.profiler.profiler_set_config(mode='all', filename=name)
        mx.profiler.profiler_set_state('run')

    logging.debug('start training ...')
    start = time.time()
    data_iter = iter(train_data)
    for epoch in range(num_epoch):
        nbatch = 0
        end_of_batch = False
        data_iter.reset()
        metric.reset()
        next_batch = next(data_iter)
        while not end_of_batch:
            nbatch += 1
            batch = next_batch
            # TODO(haibin) remove extra copy after Jun's change
            row_ids = batch.data[0].indices.copyto(mx.cpu())
            # pull sparse weight
            index = mod._exec_group.param_names.index('v')
            kv.row_sparse_pull('v', mod._exec_group.param_arrays[index],
                               priority=-index, row_ids=[row_ids])
            mod.forward_backward(batch)
            # update parameters
            mod.update()
            try:
                # pre-fetch the next batch
                next_batch = next(data_iter)
                if nbatch == num_batch:
                    raise StopIteration
            except StopIteration:
                end_of_batch = True
            # accumulate the prediction error (MSE)
            mod.update_metric(metric, batch.label)
        logging.info('epoch %d, %s' % (epoch, metric.get()))
    if profiler:
        mx.profiler.profiler_set_state('stop')
    end = time.time()
    time_cost = end - start
    logging.info('num_worker = ' + str(num_worker) + ', time cost = ' + str(time_cost))
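The pull-before-forward pattern in the loop above is easier to see in isolation. Below is a minimal sketch against a local kvstore, using the sparse API names as they later shipped in MXNet 1.x (an assumption; the in-flight branch may differ):

import mxnet as mx

kv = mx.kvstore.create('local')
shape = (8, 2)
kv.init('w', mx.nd.ones(shape).tostype('row_sparse'))

# buffer for the pulled rows; rows not listed in row_ids remain zero
out = mx.nd.zeros(shape, stype='row_sparse')
kv.row_sparse_pull('w', out=out, row_ids=mx.nd.array([0, 3, 7]))
print(out.indices.asnumpy())  # the row ids actually stored, e.g. [0 3 7]
print(out.data.shape)         # (3, 2): only the requested rows are materialized

Pulling only the rows named by the batch's indices keeps the communication cost proportional to the number of active features rather than to feature_dim.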
20 changes: 20 additions & 0 deletions include/mxnet/c_api.h
@@ -1505,6 +1505,26 @@ MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle,
                              const char** keys,
                              NDArrayHandle* vals,
                              int priority);

/*!
 * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string.
 *        The NDArray pulled back will be in row_sparse storage, holding only the rows
 *        specified by row_ids (all other rows are zeros).
 * \param handle handle to the kvstore
 * \param num the number of key-value pairs
 * \param keys the list of keys
 * \param vals the list of values
 * \param row_ids the list of row_id NDArrays
 * \param priority the priority of the action
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXKVStorePullRowSparse(KVStoreHandle handle,
                                     mx_uint num,
                                     const char** keys,
                                     NDArrayHandle* vals,
                                     const NDArrayHandle* row_ids,
                                     int priority);

/*!
 * \brief user-defined updater for the kvstore
 * It's this updater's responsibility to delete \a recv and \a local
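To make the calling convention concrete, here is a hypothetical ctypes-level wrapper for MXKVStorePullRowSparse; the helper names (_LIB, check_call, c_str, mx_uint) follow mxnet.base conventions, but this sketch is illustrative, not the frontend code added by this commit:

import ctypes
from mxnet.base import _LIB, check_call, c_str, mx_uint

def pull_row_sparse(kv, key, out, row_ids, priority=0):
    # a single (key, value, row_id) triple; handles are ctypes void pointers
    keys = (ctypes.c_char_p * 1)(c_str(key))
    vals = (ctypes.c_void_p * 1)(out.handle)
    rids = (ctypes.c_void_p * 1)(row_ids.handle)
    check_call(_LIB.MXKVStorePullRowSparse(
        kv.handle, mx_uint(1), keys, vals, rids, ctypes.c_int(priority)))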
24 changes: 24 additions & 0 deletions include/mxnet/kvstore.h
@@ -7,6 +7,7 @@
#define MXNET_KVSTORE_H_
#include <dmlc/io.h>
#include <vector>
#include <utility>
#include <unordered_map>
#include <string>
#include <functional>
@@ -155,6 +156,29 @@ class KVStore {
                    const std::vector<NDArray*>& values,
                    int priority = 0) = 0;

  /*!
   * \brief pull a list of key-value pairs from the store.
   *        The NDArray pulled back will be in row_sparse storage, holding only the
   *        rows specified by row_ids (all other rows are zeros).
   * \param keys the list of keys
   * \param val_rowids the list of (buffer, row_id) pairs
   * \param priority the priority of the action.
   */
  virtual void PullRowSparse(const std::vector<int>& keys,
                             const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
                             int priority = 0) = 0;

  /*!
   * \brief pull a list of key-value pairs from the store, where each key is a string.
   *        The NDArray pulled back will be in row_sparse storage, holding only the
   *        rows specified by row_ids (all other rows are zeros).
   * \param str_keys the list of keys in string format
   * \param val_rowids the list of (buffer, row_id) pairs
   * \param priority the priority of the action.
   */
  virtual void PullRowSparse(const std::vector<std::string>& str_keys,
                             const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
                             int priority = 0) = 0;

  /**
   * \brief the prototype of user-defined updater
11 changes: 8 additions & 3 deletions include/mxnet/ndarray.h
@@ -182,12 +182,13 @@ class NDArray {
     return shape_;
   }
   /*!
-   * \return the shape of underlying chunk which stores the NDArray values.
-   * For default storage, it is the same as shape(). For row-sparse storage, it is the shape of
+   * \return the shape of underlying chunk which stores the NDArray data/value.
+   * It is only intended for non-default storage. For row-sparse storage, it is the shape of
    * the tensor which stores the non-zero values.
    */
   inline const TShape &storage_shape() const {
     CHECK(ptr_ != nullptr);
+    CHECK_NE(storage_type(), kDefaultStorage);
     return ptr_->storage_shape;
   }

@@ -271,7 +272,11 @@
   if (is_none()) return false;
   auto stype = storage_type();
   CHECK_NE(stype, kDefaultStorage);
-  if (stype == kRowSparseStorage || stype == kCSRStorage) {
+  if (stype == kRowSparseStorage) {
+    CHECK_EQ(aux_shape(rowsparse::kIdx)[0], storage_shape()[0]);
+    return aux_shape(0).Size() != 0;
+  } else if (stype == kCSRStorage) {
+    CHECK_EQ(aux_shape(csr::kIdx)[0], storage_shape()[0]);
     return aux_shape(0).Size() != 0;
   } else {
     LOG(FATAL) << "Unknown storage type";
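The shape/storage_shape/aux_shape relationship these checks enforce can also be seen from Python; a sketch using the released sparse API (not part of this diff):

import mxnet as mx

a = mx.nd.array([[1, 2], [0, 0], [3, 4]]).tostype('row_sparse')
print(a.shape)              # (3, 2): the logical shape
print(a.data.shape)         # (2, 2): the storage shape, only non-zero rows kept
print(a.indices.asnumpy())  # [0 2]: the aux data holding the stored row ids
# storage_shape[0] equals aux_shape(kIdx)[0], matching the CHECK_EQ above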
