Skip to content

Commit

Permalink
LibsvmIter Doc Updates (apache#8111)
Browse files Browse the repository at this point in the history
* add clarification to libsvm iter

* add reset notes

* add reset notes

* also update csv iter
  • Loading branch information
eric-haibin-lin authored and crazy-cat committed Oct 26, 2017
1 parent a344924 commit 0cfdfa8
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 32 deletions.
2 changes: 1 addition & 1 deletion benchmark/python/sparse/sparse_end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ def row_sparse_pull(kv, key, data, slices, weight_array, priority):
start_time_epoch = time.time()
nbatch = 0
end_of_batch = False
data_iter.reset()
metric.reset()
next_batch = next(data_iter)
if kv is not None:
Expand Down Expand Up @@ -300,6 +299,7 @@ def row_sparse_pull(kv, key, data, slices, weight_array, priority):
logging.info('num_worker = {}, time cost per epoch = {}'.format(str(num_worker), str(time_cost_epoch)))
if args.num_gpu < 1:
logging.info('|cpu/{} cores| {} | {} | {} |'.format(str(num_cores), str(num_worker), str(average_cost_epoch), rank))
data_iter.reset()
if profiler:
mx.profiler.profiler_set_state('stop')
end = time.time()
Expand Down
2 changes: 2 additions & 0 deletions src/io/iter_csv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ to set `round_batch` to False.
If ``data_csv = 'data/'`` is set, then all the files in this directory will be read.
``reset()`` is expected to be called only after a complete pass of data.
Examples::
// Contents of CSV file ``data/data.csv``.
Expand Down
44 changes: 29 additions & 15 deletions src/io/iter_libsvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,31 +198,35 @@ class LibSVMIter: public SparseIIterator<DataInst> {
DMLC_REGISTER_PARAMETER(LibSVMIterParam);

MXNET_REGISTER_IO_ITER(LibSVMIter)
.describe(R"code(Returns the libsvm file iterator which returns sparse data with `csr`
.describe(R"code(Returns the LibSVM iterator which returns data with `csr`
storage type. This iterator is experimental and should be used with care.
The input data is stored in a format similar to libsvm file format, except that the **indices
The input data is stored in a format similar to LibSVM file format, except that the **indices
are expected to be zero-based instead of one-based, and the column indices for each row are
expected to be sorted in ascending order**. Details of the libsvm format are available
at `https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/`
expected to be sorted in ascending order**. Details of the LibSVM format are available
`here <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/>`_.
In this function, the `data_shape` parameter is used to set the shape of each line of the data.
The `data_shape` parameter is used to set the shape of each line of the data.
The dimension of both `data_shape` and `label_shape` are expected to be 1.
When `label_libsvm` is set to ``NULL``, both data and label are read from the same file specified
The `data_libsvm` parameter is used to set the path to the input LibSVM file.
When it is set to a directory, all the files in the directory will be read.
When `label_libsvm` is set to ``NULL``, both data and label are read from the file specified
by `data_libsvm`. In this case, the data is stored in `csr` storage type, while the label is a 1D
dense array. Otherwise, data is read from `data_libsvm` and label from `label_libsvm`,
in this case, both data and label are stored in csr storage type. If `data_libsvm` contains label,
it will ignored.
dense array.
The `LibSVMIter` only support `round_batch` parameter set to ``True``. Therefore, if `batch_size`
is 3 and there are 4 total rows in libsvm file, 2 more examples are consumed at the first round.
The `LibSVMIter` only supports the `round_batch` parameter set to ``True`` for now. So, if `batch_size`
is 3 and there are 4 total rows in the libsvm file, 2 more examples
are consumed at the first round. If `reset` function is called after first round,
the call is ignored and remaining examples are returned in the second round.
When `num_parts` and `part_index` are provided, the data is split into `num_parts` partitions,
and the iterator only reads the `part_index`-th partition. However, the partitions are not
guaranteed to be even.
If ``data_libsvm = 'data/'`` is set, then all the files in this directory will be read.
``reset()`` is expected to be called only after a complete pass of data.
Examples::
Example::
# Contents of libsvm file ``data.t``.
1.0 0:0.5 2:1.2
Expand Down Expand Up @@ -256,6 +260,16 @@ Examples::
>>> second_batch.label[0].asnumpy()
[ 4. 1. -2.]
>>> data_iter.reset()
# To restart the iterator for the second pass of the data
When `label_libsvm` is set to the path to another LibSVM file,
data is read from `data_libsvm` and label from `label_libsvm`.
In this case, both data and label are stored in the csr format.
The label column in the `data_libsvm` file is ignored.
Example::
# Contents of libsvm file ``label.t``
1.0
-2.0 0:0.125
Expand Down
29 changes: 13 additions & 16 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ def test_MNISTIter():
assert(sum(label_0 - label_1) == 0)

def test_Cifar10Rec():
# skip-this test for saving time
return
get_data.GetCifar10()
dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
Expand Down Expand Up @@ -215,25 +213,24 @@ def check_libSVMIter_news_data():
'num_classes': 20,
'num_examples': 3993,
}
num_parts = 3
batch_size = 128
batch_size = 33
num_examples = news_metadata['num_examples']
data_dir = os.path.join(os.getcwd(), 'data')
get_bz2_data(data_dir, news_metadata['name'], news_metadata['url'],
news_metadata['origin_name'])
news_metadata['origin_name'])
path = os.path.join(data_dir, news_metadata['name'])
data_train = mx.io.LibSVMIter(data_libsvm=path, data_shape=(news_metadata['feature_dim'],),
batch_size=batch_size, num_parts=num_parts, part_index=0)
num_batches = 0
iterator = iter(data_train)
for batch in iterator:
# check the range of labels
assert(np.sum(batch.label[0].asnumpy() > 20) == 0)
assert(np.sum(batch.label[0].asnumpy() <= 0) == 0)
num_batches += 1
import math
expected_num_batches = math.ceil(num_examples * 1.0 / batch_size / num_parts)
assert(num_batches == int(expected_num_batches)), (num_batches, expected_num_batches)
batch_size=batch_size)
for epoch in range(2):
num_batches = 0
for batch in data_train:
# check the range of labels
assert(np.sum(batch.label[0].asnumpy() > 20) == 0)
assert(np.sum(batch.label[0].asnumpy() <= 0) == 0)
num_batches += 1
expected_num_batches = num_examples / batch_size
assert(num_batches == int(expected_num_batches)), num_batches
data_train.reset()

check_libSVMIter_synthetic()
check_libSVMIter_news_data()
Expand Down

0 comments on commit 0cfdfa8

Please sign in to comment.