Skip to content

Commit

Permalink
fix subset issues when subset size is larger than input
Browse files Browse the repository at this point in the history
  • Loading branch information
svirpioj committed Nov 22, 2023
1 parent f35ce1a commit 1d3b703
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
6 changes: 4 additions & 2 deletions opusfilter/opusfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,8 +447,10 @@ def get_subset(self, parameters, overwrite=False):
logger.info("Sampling subset of %s lines from total %s lines", size, total)
if total < size:
logger.warning("Number of lines (%s) is smaller than requested size (%s)", total, size)
shutil.copyfile(infiles[0], outfiles[0])
shutil.copyfile(infiles[1], outfiles[1])
for infname, outfname in zip(infiles, outfiles):
with file_open(infname) as inf, file_open(outfname, 'w') as outf:
for line in inf:
outf.write(line)
elif shuffle_subset:
sample = random.sample(range(total), size)
with file_open(infiles[0]) as inf, file_open(outfiles[0], 'w') as outf:
Expand Down
41 changes: 39 additions & 2 deletions tests/test_opusfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,9 +687,13 @@ def setUp(self):
self.opus_filter = OpusFilter(
{'common': {'output_directory': self.tempdir}, 'steps': []})
with open(os.path.join(self.tempdir, 'input_src'), 'w') as f:
f.write(''.join('sent_{}\n'.format(idx) for idx in range(100)))
f.write(''.join(f'sent_{idx}\n' for idx in range(100)))
with open(os.path.join(self.tempdir, 'input_tgt'), 'w') as f:
f.write(''.join('sent_{}\n'.format(idx) for idx in range(100)))
f.write(''.join(f'sent_{idx}\n' for idx in range(100)))
with file_open(os.path.join(self.tempdir, 'input_src.gz'), 'w') as f:
f.write(''.join(f'sent_{idx}\n' for idx in range(100)))
with file_open(os.path.join(self.tempdir, 'input_tgt.gz'), 'w') as f:
f.write(''.join(f'sent_{idx}\n' for idx in range(100)))

def tearDown(self):
shutil.rmtree(self.tempdir)
Expand Down Expand Up @@ -726,6 +730,39 @@ def test_subset_shuffle(self):
self.assertEqual(len(lines2), 20)
self.assertFalse(all(l1 == l2 for l1, l2 in zip(lines1, lines2)))

def test_subset_more_than_total(self):
parameters = {
'inputs': [os.path.join(self.tempdir, 'input_src'),
os.path.join(self.tempdir, 'input_tgt')],
'outputs': [os.path.join(self.tempdir, 'output_src'),
os.path.join(self.tempdir, 'output_tgt')],
'size': 200}
self.opus_filter.get_subset(parameters)
with open(os.path.join(self.tempdir, 'output_src')) as fobj1, \
open(os.path.join(self.tempdir, 'output_tgt')) as fobj2:
lines1 = fobj1.readlines()
lines2 = fobj2.readlines()
self.assertEqual(len(lines1), 100)
self.assertEqual(len(lines2), 100)
self.assertSequenceEqual(lines1, lines2)

def test_subset_more_than_total_gzip(self):
parameters = {
'inputs': [os.path.join(self.tempdir, 'input_src.gz'),
os.path.join(self.tempdir, 'input_tgt.gz')],
'outputs': [os.path.join(self.tempdir, 'output_src'),
os.path.join(self.tempdir, 'output_tgt')],
'size': 200}
self.opus_filter.get_subset(parameters)
with open(os.path.join(self.tempdir, 'output_src')) as fobj1, \
open(os.path.join(self.tempdir, 'output_tgt')) as fobj2:
lines1 = fobj1.readlines()
lines2 = fobj2.readlines()
self.assertEqual(len(lines1), 100)
self.assertEqual(len(lines2), 100)
self.assertSequenceEqual(lines1, lines2)



class TestSplit(unittest.TestCase):

Expand Down

0 comments on commit 1d3b703

Please sign in to comment.