From c549613bc2c131b446e9bda7100714a75565ea93 Mon Sep 17 00:00:00 2001 From: jukejian Date: Fri, 8 Nov 2024 15:33:11 +0800 Subject: [PATCH] [data] return the output of read_webdataset in formats compatible with PyArrow and Pandas Signed-off-by: jukejian --- .../datasource/webdataset_datasource.py | 18 +++++++++++++++++- python/ray/data/read_api.py | 2 ++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/python/ray/data/_internal/datasource/webdataset_datasource.py b/python/ray/data/_internal/datasource/webdataset_datasource.py index d9f5f876d4ec..5d132619cc62 100644 --- a/python/ray/data/_internal/datasource/webdataset_datasource.py +++ b/python/ray/data/_internal/datasource/webdataset_datasource.py @@ -314,6 +314,7 @@ def __init__( filerename: Optional[Union[bool, callable, list]] = None, suffixes: Optional[Union[bool, callable, list]] = None, verbose_open: bool = False, + output_format: Optional[str] = "pandas", **file_based_datasource_kwargs, ): super().__init__(paths, **file_based_datasource_kwargs) @@ -323,6 +324,7 @@ def __init__( self.filerename = filerename self.suffixes = suffixes self.verbose_open = verbose_open + self.output_format = output_format def _read_stream(self, stream: "pyarrow.NativeFile", path: str): """Read and decode samples from a stream. @@ -343,6 +345,7 @@ def _read_stream(self, stream: "pyarrow.NativeFile", path: str): List[Dict[str, Any]]: List of sample (list of length 1). """ import pandas as pd + import pyarrow as pa def get_tar_file_iterator(): return _tar_file_iterator( @@ -362,4 +365,17 @@ def get_tar_file_iterator(): for sample in samples: if self.decoder is not None: sample = _apply_list(self.decoder, sample, default=_default_decoder) - yield pd.DataFrame({k: [v] for k, v in sample.items()}) + if self.output_format == "pandas": + yield pd.DataFrame( + { + k: v if isinstance(v, list) and len(v) == 1 else [v] + for k, v in sample.items() + } + ) + else: + column_names = list(sample.keys()) + arrays = [ + pa.array(v if isinstance(v, list) and len(v) == 1 else [v]) + for v in sample.values() + ] + yield pa.Table.from_arrays(arrays, names=column_names) diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py index 91d6ecaf3727..66099dd4cb00 100644 --- a/python/ray/data/read_api.py +++ b/python/ray/data/read_api.py @@ -1855,6 +1855,7 @@ def read_webdataset( file_extensions: Optional[List[str]] = None, concurrency: Optional[int] = None, override_num_blocks: Optional[int] = None, + output_format: Optional[str] = "pandas", ) -> Dataset: """Create a :class:`~ray.data.Dataset` from `WebDataset `_ files. @@ -1918,6 +1919,7 @@ def read_webdataset( shuffle=shuffle, include_paths=include_paths, file_extensions=file_extensions, + output_format=output_format, ) return read_datasource( datasource,