This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

CSVIter and LibSVMIter not returning correct number of batches per epoch #8062

Closed
eric-haibin-lin opened this issue Sep 27, 2017 · 3 comments

@eric-haibin-lin
Member


Environment info

Operating System: Deep Learning AMI

MXNet commit hash (git rev-parse HEAD): ae975e5

Error Message:

  File "./tests/python/unittest/test_io.py", line 301, in check_CSVIter_synthetic
    assert(nbatch == 100), nbatch
AssertionError: 185

Minimum reproducible example

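    import os
    import mxnet as mx

    def check_CSVIter_synthetic():
        # Write a 100-row CSV; each row holds 8*8 copies of the value i + 1
        cwd = os.getcwd()
        data_path = os.path.join(cwd, 'data.t')
        with open(data_path, 'w') as fout:
            for i in range(100):
                fout.write(','.join([str(i + 1) for _ in range(8*8)]) + '\n')
        batch_size = 1
        data_train = mx.io.CSVIter(data_csv=data_path, data_shape=(8,8),
                                   batch_size=batch_size)
        # With batch_size=1, every epoch should yield exactly 100 batches
        for epoch in range(10):
            data_train.reset()  # reset at the start of each epoch (this ordering miscounts)
            nbatch = 0
            for batch in iter(data_train):
                nbatch += 1
            assert nbatch == 100, nbatch

    check_CSVIter_synthetic()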
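As a sanity check on the expected count, here is a sketch that uses only the Python standard library (no MXNet) to regenerate the same synthetic file and count its rows. With batch_size=1, each of the 100 rows should become one batch, and each row carries 8*8 = 64 values to match data_shape=(8,8). The helper names write_synthetic_csv and expected_batches and the use of a temporary directory are hypothetical choices for illustration, not part of the original test.

```python
import csv
import os
import tempfile


def write_synthetic_csv(path, num_rows=100, num_cols=8 * 8):
    """Hypothetical helper: same file layout as check_CSVIter_synthetic."""
    with open(path, 'w') as fout:
        for i in range(num_rows):
            fout.write(','.join([str(i + 1) for _ in range(num_cols)]) + '\n')


def expected_batches(path, batch_size=1):
    """Count rows and derive batches per epoch for batch sizes that divide the row count."""
    with open(path, newline='') as fin:
        rows = list(csv.reader(fin))
    # Every row must hold exactly 64 values so it can be viewed as an 8x8 sample.
    assert all(len(row) == 8 * 8 for row in rows)
    # A partial final batch would make the count depend on the iterator's padding policy.
    assert len(rows) % batch_size == 0, "partial final batch"
    return len(rows) // batch_size


with tempfile.TemporaryDirectory() as tmp:
    data_path = os.path.join(tmp, 'data.t')
    write_synthetic_csv(data_path)
    print(expected_batches(data_path))  # prints 100
```

So any count other than 100 per epoch (such as the 185 in the error above) comes from the iterator, not from the data file.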

Steps to reproduce

  1. Run the above code. The LibSVM iterator (LibSVMIter) fails with the same kind of error.
@eric-haibin-lin
Member Author

@tqchen any idea where things could go wrong?

@eric-haibin-lin
Member Author

Based on #2248, moving data_train.reset() to the end of each epoch resolves this issue.

@eric-haibin-lin
Member Author

The number of batches will be correct if reset() is moved to the end of the epoch:

        for epoch in range(10):
            nbatch = 0
            for batch in iter(data_train):
                nbatch += 1
            assert nbatch == 100, nbatch
            data_train.reset()  # reset only after the epoch's batches are consumed

I've updated the documentation for this in #8111.
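The end-of-epoch reset ordering can be factored into a small helper that works with any object exposing reset() and the Python iterator protocol, which MXNet's DataIter classes do. The sketch below uses a hypothetical ToyBatchIter stand-in so it runs without MXNet; it only illustrates the loop structure and does not reproduce the CSVIter/LibSVMIter miscount itself.

```python
class ToyBatchIter:
    """Hypothetical stand-in for an MXNet DataIter: yields num_batches items,
    then raises StopIteration until reset() is called. It does not model the bug."""

    def __init__(self, num_batches):
        self.num_batches = num_batches
        self.cursor = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.cursor >= self.num_batches:
            raise StopIteration
        self.cursor += 1
        return self.cursor

    def reset(self):
        self.cursor = 0


def batches_per_epoch(data_iter, num_epochs):
    """Consume one full epoch at a time and reset only after it is exhausted,
    mirroring the loop ordering recommended in this thread."""
    counts = []
    for _ in range(num_epochs):
        nbatch = 0
        for _batch in data_iter:
            nbatch += 1
        counts.append(nbatch)
        data_iter.reset()  # reset at the end of the epoch, not the start
    return counts


print(batches_per_epoch(ToyBatchIter(100), num_epochs=3))  # prints [100, 100, 100]
```

With a real iterator, data_train from the reproducer would be passed in place of ToyBatchIter(100).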
