From 57e4548ff78b7a3e052650f14e9e45eba5e58e31 Mon Sep 17 00:00:00 2001
From: jialin
Date: Fri, 11 Feb 2022 14:49:45 -0800
Subject: [PATCH] add optional empty lines filter in read_text (#22298)

ray.data.read_text() currently does not handle empty lines; this PR adds a
flag to enable an empty-line filter. With this change, read_text returns
only non-empty lines by default, unless drop_empty_lines is set to False.

Co-authored-by: Eric Liang
Co-authored-by: Jialin Liu
---
 python/ray/data/read_api.py           | 9 ++++++++-
 python/ray/data/tests/test_dataset.py | 6 +++++-
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py
index ab7528dccac5..4c3cf9c24cc4 100644
--- a/python/ray/data/read_api.py
+++ b/python/ray/data/read_api.py
@@ -468,6 +468,7 @@ def read_text(
     *,
     encoding: str = "utf-8",
     errors: str = "ignore",
+    drop_empty_lines: bool = True,
     filesystem: Optional["pyarrow.fs.FileSystem"] = None,
     parallelism: int = 200,
     arrow_open_stream_args: Optional[Dict[str, Any]] = None,
@@ -496,12 +497,18 @@ def read_text(
         Dataset holding lines of text read from the specified paths.
     """

+    def to_text(s):
+        lines = s.decode(encoding, errors=errors).split("\n")
+        if drop_empty_lines:
+            lines = [line for line in lines if line.strip() != ""]
+        return lines
+
     return read_binary_files(
         paths,
         filesystem=filesystem,
         parallelism=parallelism,
         arrow_open_stream_args=arrow_open_stream_args,
-    ).flat_map(lambda x: x.decode(encoding, errors=errors).split("\n"))
+    ).flat_map(to_text)


 @PublicAPI(stability="beta")
diff --git a/python/ray/data/tests/test_dataset.py b/python/ray/data/tests/test_dataset.py
index e54ecdf82755..2aa511d234a2 100644
--- a/python/ray/data/tests/test_dataset.py
+++ b/python/ray/data/tests/test_dataset.py
@@ -1003,8 +1003,12 @@ def test_read_text(ray_start_regular_shared, tmp_path):
         f.write("world")
     with open(os.path.join(path, "file2.txt"), "w") as f:
         f.write("goodbye")
+    with open(os.path.join(path, "file3.txt"), "w") as f:
+        f.write("ray\n")
     ds = ray.data.read_text(path)
-    assert sorted(ds.take()) == ["goodbye", "hello", "world"]
+    assert sorted(ds.take()) == ["goodbye", "hello", "ray", "world"]
+    ds = ray.data.read_text(path, drop_empty_lines=False)
+    assert ds.count() == 5


 @pytest.mark.parametrize("pipelined", [False, True])
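
A minimal usage sketch of the new flag, assuming a running Ray cluster and a
hypothetical local directory of newline-delimited text files:

    import ray

    # Default behavior: empty lines are dropped from the resulting dataset.
    ds = ray.data.read_text("/tmp/some_text_dir")

    # Disable the filter to keep empty lines as empty strings.
    ds_all = ray.data.read_text("/tmp/some_text_dir", drop_empty_lines=False)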