This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

LibsvmIter Doc Updates #8111

Merged
merged 4 commits on Oct 1, 2017
2 changes: 1 addition & 1 deletion benchmark/python/sparse/sparse_end2end.py
@@ -253,7 +253,6 @@ def row_sparse_pull(kv, key, data, slices, weight_array, priority):
start_time_epoch = time.time()
nbatch = 0
end_of_batch = False
data_iter.reset()
metric.reset()
next_batch = next(data_iter)
if kv is not None:
@@ -300,6 +299,7 @@ def row_sparse_pull(kv, key, data, slices, weight_array, priority):
logging.info('num_worker = {}, time cost per epoch = {}'.format(str(num_worker), str(time_cost_epoch)))
if args.num_gpu < 1:
logging.info('|cpu/{} cores| {} | {} | {} |'.format(str(num_cores), str(num_worker), str(average_cost_epoch), rank))
data_iter.reset()
if profiler:
mx.profiler.profiler_set_state('stop')
end = time.time()
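This hunk moves the `data_iter.reset()` call from the start of the epoch to the end, so that `reset()` is only invoked after a complete pass of the data. A minimal pure-Python sketch of that loop shape (illustrative only — `DummyIter` is a stand-in, not an MXNet class):

```python
# Illustrative stand-in for a data iterator whose reset() is expected
# only after the iterator has been exhausted (a complete pass).
class DummyIter:
    def __init__(self, n):
        self.n = n
        self.pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos >= self.n:
            raise StopIteration
        self.pos += 1
        return self.pos - 1

    def reset(self):
        self.pos = 0

data_iter = DummyIter(4)
seen = []
for epoch in range(2):
    for batch in data_iter:   # complete pass over the data
        seen.append(batch)
    data_iter.reset()         # reset at the END of the epoch
print(seen)                   # [0, 1, 2, 3, 0, 1, 2, 3]
```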
2 changes: 2 additions & 0 deletions src/io/iter_csv.cc
@@ -164,6 +164,8 @@ to set `round_batch` to False.

If ``data_csv = 'data/'`` is set, then all the files in this directory will be read.

``reset()`` is expected to be called only after a complete pass of data.
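The `round_batch` behaviour mentioned in the surrounding context can be modelled in a few lines of plain Python. This is a sketch of the semantics only (not the C++ implementation): when the row count is not a multiple of `batch_size`, the final batch wraps around to the start of the data.

```python
# Sketch of round_batch=True semantics: pad the final batch by wrapping
# around to the start of the data, so every batch has batch_size rows.
def round_batches(rows, batch_size):
    batches = []
    for start in range(0, len(rows), batch_size):
        batch = rows[start:start + batch_size]
        if len(batch) < batch_size:               # wrap around to pad
            batch = batch + rows[:batch_size - len(batch)]
        batches.append(batch)
    return batches

print(round_batches([10, 20, 30, 40], 3))
# [[10, 20, 30], [40, 10, 20]]
```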

Examples::

// Contents of CSV file ``data/data.csv``.
44 changes: 29 additions & 15 deletions src/io/iter_libsvm.cc
@@ -198,31 +198,35 @@ class LibSVMIter: public SparseIIterator<DataInst> {
DMLC_REGISTER_PARAMETER(LibSVMIterParam);

MXNET_REGISTER_IO_ITER(LibSVMIter)
.describe(R"code(Returns the LibSVM iterator which returns data with `csr`
storage type. This iterator is experimental and should be used with care.

The input data is stored in a format similar to LibSVM file format, except that the **indices
are expected to be zero-based instead of one-based, and the column indices for each row are
expected to be sorted in ascending order**. Details of the LibSVM format are available
`here <https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/>`_.
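The zero-based, ascending-index requirement described above is easy to check when reading a line. The following is a hedged sketch of a parser for this variant of the format (plain Python for illustration, not MXNet's C++ tokenizer):

```python
# Parse one line of the zero-based LibSVM variant described above:
# "<label> <index>:<value> ..." with column indices sorted ascending.
def parse_libsvm_line(line):
    parts = line.split()
    label = float(parts[0])
    indices, values = [], []
    for tok in parts[1:]:
        idx, val = tok.split(':')
        indices.append(int(idx))
        values.append(float(val))
    # enforce the ascending-order requirement of this format variant
    assert indices == sorted(indices), "column indices must be ascending"
    return label, indices, values

print(parse_libsvm_line("1.0 0:0.5 2:1.2"))
# (1.0, [0, 2], [0.5, 1.2])
```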

The `data_shape` parameter is used to set the shape of each line of the data.
The dimension of both `data_shape` and `label_shape` is expected to be 1.

The `data_libsvm` parameter is used to set the path to the input LibSVM file.
When it is set to a directory, all the files in the directory will be read.

When `label_libsvm` is set to ``NULL``, both data and label are read from the file specified
by `data_libsvm`. In this case, the data is stored in `csr` storage type, while the label is a 1D
dense array.

The `LibSVMIter` only supports the `round_batch` parameter set to ``True``. Therefore, if `batch_size`
is 3 and there are 4 total rows in the libsvm file, 2 more examples are consumed at the first round.

When `num_parts` and `part_index` are provided, the data is split into `num_parts` partitions,
and the iterator only reads the `part_index`-th partition. However, the partitions are not
guaranteed to be even.
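A rough model of the `num_parts`/`part_index` split is sketched below (illustrative only; the actual partitioning in the underlying input split may divide rows differently, which is why partitions are not guaranteed to be even):

```python
# Rough model: split n_rows into num_parts contiguous partitions; sizes
# may differ by one row, i.e. partitions are not guaranteed to be even.
def partition(n_rows, num_parts, part_index):
    base, extra = divmod(n_rows, num_parts)
    sizes = [base + (1 if p < extra else 0) for p in range(num_parts)]
    start = sum(sizes[:part_index])
    return start, start + sizes[part_index]

# 10 rows over 3 workers -> row ranges (0, 4), (4, 7), (7, 10)
print([partition(10, 3, p) for p in range(3)])
```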

``reset()`` is expected to be called only after a complete pass of data.

Example::

# Contents of libsvm file ``data.t``.
1.0 0:0.5 2:1.2
@@ -256,6 +260,16 @@ Examples::
>>> second_batch.label[0].asnumpy()
[ 4. 1. -2.]

# To restart the iterator for the second pass of the data
>>> data_iter.reset()

When `label_libsvm` is set to the path of another LibSVM file,
data is read from `data_libsvm` and label from `label_libsvm`.
In this case, both data and label are stored in the csr format.
The label column in the `data_libsvm` file is ignored.

Example::

# Contents of libsvm file ``label.t``
1.0
-2.0 0:0.125
29 changes: 13 additions & 16 deletions tests/python/unittest/test_io.py
@@ -61,8 +61,6 @@ def test_MNISTIter():
assert(sum(label_0 - label_1) == 0)

def test_Cifar10Rec():
get_data.GetCifar10()
dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
@@ -215,25 +213,24 @@ def check_libSVMIter_news_data():
'num_classes': 20,
'num_examples': 3993,
}
batch_size = 33
num_examples = news_metadata['num_examples']
data_dir = os.path.join(os.getcwd(), 'data')
get_bz2_data(data_dir, news_metadata['name'], news_metadata['url'],
             news_metadata['origin_name'])
path = os.path.join(data_dir, news_metadata['name'])
data_train = mx.io.LibSVMIter(data_libsvm=path, data_shape=(news_metadata['feature_dim'],),
                              batch_size=batch_size)
for epoch in range(2):
    num_batches = 0
    for batch in data_train:
        # check the range of labels
        assert(np.sum(batch.label[0].asnumpy() > 20) == 0)
        assert(np.sum(batch.label[0].asnumpy() <= 0) == 0)
        num_batches += 1
    expected_num_batches = num_examples / batch_size
    assert(num_batches == int(expected_num_batches)), num_batches
    data_train.reset()

check_libSVMIter_synthetic()
check_libSVMIter_news_data()
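The updated test drops the `num_parts` batch-count math because 3993 examples divide evenly into batches of 33, so each epoch yields a whole number of batches. A quick sanity check of that arithmetic:

```python
num_examples = 3993   # size of the news dataset used by the test
batch_size = 33       # batch size chosen in the updated test

# 3993 / 33 == 121 exactly, so no partial batch is left over per epoch
assert num_examples % batch_size == 0
print(num_examples // batch_size)  # 121
```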