From 1d3b7038ea4aa562a9360d538e40290855d254e2 Mon Sep 17 00:00:00 2001 From: Sami Virpioja Date: Wed, 22 Nov 2023 16:03:06 +0200 Subject: [PATCH] fix subset issues when subset size is larger than input --- opusfilter/opusfilter.py | 6 ++++-- tests/test_opusfilter.py | 41 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/opusfilter/opusfilter.py b/opusfilter/opusfilter.py index 52312dd..a99a5f2 100644 --- a/opusfilter/opusfilter.py +++ b/opusfilter/opusfilter.py @@ -447,8 +447,10 @@ def get_subset(self, parameters, overwrite=False): logger.info("Sampling subset of %s lines from total %s lines", size, total) if total < size: logger.warning("Number of lines (%s) is smaller than requested size (%s)", total, size) - shutil.copyfile(infiles[0], outfiles[0]) - shutil.copyfile(infiles[1], outfiles[1]) + for infname, outfname in zip(infiles, outfiles): + with file_open(infname) as inf, file_open(outfname, 'w') as outf: + for line in inf: + outf.write(line) elif shuffle_subset: sample = random.sample(range(total), size) with file_open(infiles[0]) as inf, file_open(outfiles[0], 'w') as outf: diff --git a/tests/test_opusfilter.py b/tests/test_opusfilter.py index ca8c57f..200d2bf 100644 --- a/tests/test_opusfilter.py +++ b/tests/test_opusfilter.py @@ -687,9 +687,13 @@ def setUp(self): self.opus_filter = OpusFilter( {'common': {'output_directory': self.tempdir}, 'steps': []}) with open(os.path.join(self.tempdir, 'input_src'), 'w') as f: - f.write(''.join('sent_{}\n'.format(idx) for idx in range(100))) + f.write(''.join(f'sent_{idx}\n' for idx in range(100))) with open(os.path.join(self.tempdir, 'input_tgt'), 'w') as f: - f.write(''.join('sent_{}\n'.format(idx) for idx in range(100))) + f.write(''.join(f'sent_{idx}\n' for idx in range(100))) + with file_open(os.path.join(self.tempdir, 'input_src.gz'), 'w') as f: + f.write(''.join(f'sent_{idx}\n' for idx in range(100))) + with file_open(os.path.join(self.tempdir, 'input_tgt.gz'), 'w') as f: + f.write(''.join(f'sent_{idx}\n' for idx in range(100))) def tearDown(self): shutil.rmtree(self.tempdir) @@ -726,6 +730,39 @@ def test_subset_shuffle(self): self.assertEqual(len(lines2), 20) self.assertFalse(all(l1 == l2 for l1, l2 in zip(lines1, lines2))) + def test_subset_more_than_total(self): + parameters = { + 'inputs': [os.path.join(self.tempdir, 'input_src'), + os.path.join(self.tempdir, 'input_tgt')], + 'outputs': [os.path.join(self.tempdir, 'output_src'), + os.path.join(self.tempdir, 'output_tgt')], + 'size': 200} + self.opus_filter.get_subset(parameters) + with open(os.path.join(self.tempdir, 'output_src')) as fobj1, \ + open(os.path.join(self.tempdir, 'output_tgt')) as fobj2: + lines1 = fobj1.readlines() + lines2 = fobj2.readlines() + self.assertEqual(len(lines1), 100) + self.assertEqual(len(lines2), 100) + self.assertSequenceEqual(lines1, lines2) + + def test_subset_more_than_total_gzip(self): + parameters = { + 'inputs': [os.path.join(self.tempdir, 'input_src.gz'), + os.path.join(self.tempdir, 'input_tgt.gz')], + 'outputs': [os.path.join(self.tempdir, 'output_src'), + os.path.join(self.tempdir, 'output_tgt')], + 'size': 200} + self.opus_filter.get_subset(parameters) + with open(os.path.join(self.tempdir, 'output_src')) as fobj1, \ + open(os.path.join(self.tempdir, 'output_tgt')) as fobj2: + lines1 = fobj1.readlines() + lines2 = fobj2.readlines() + self.assertEqual(len(lines1), 100) + self.assertEqual(len(lines2), 100) + self.assertSequenceEqual(lines1, lines2) + + class TestSplit(unittest.TestCase):