WebDataset support for Ray Data (ray-project#33336)
Support for the WebDataset format. WebDataset is a highly scalable, simple open format for deep learning. It is supported by torchdata and native libraries for Python, Julia, Golang, and C++.

WebDataset reads datasets stored as tar files, with the simple convention that files that belong together and make up a training sample share the same basename. WebDataset can read files from a local disk or from any pipe, which allows it to access files using common cloud object stores.

The WebDataset representation allows writing purely sequential I/O pipelines for large-scale deep learning. This is important for achieving high I/O rates from local storage (3x-10x higher throughput on local drives compared to random access) and for training directly from object stores and cloud storage.

The WebDataset format represents images, movies, audio, etc. in their native file formats, making the creation of WebDataset data as easy as creating a tar archive. Because data is aligned on predictable boundaries, WebDataset also works well with block deduplication.
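To illustrate the convention described above, here is a minimal sketch using only the Python standard library: files that make up one sample share a basename and differ only in extension, so a reader can reconstruct samples by grouping tar members on the part of the name before the first dot. The helper functions (`write_samples`, `read_samples`) are hypothetical illustrations, not part of the WebDataset API, and real WebDataset keys may themselves contain dots, which this sketch does not handle.

```python
import io
import os
import tarfile
import tempfile


def write_samples(path, samples):
    """Write (key, {ext: bytes}) samples to a WebDataset-style tar.

    Files belonging to the same sample share the basename (the key);
    the extension distinguishes the fields, e.g. 0001.txt / 0001.cls.
    """
    with tarfile.open(path, "w") as tar:
        for key, fields in samples:
            for ext, data in fields.items():
                info = tarfile.TarInfo(name=f"{key}.{ext}")
                info.size = len(data)
                tar.addfile(info, io.BytesIO(data))


def read_samples(path):
    """Group tar members back into samples by their shared basename."""
    samples = {}
    with tarfile.open(path, "r") as tar:
        for member in tar.getmembers():
            key, ext = member.name.split(".", 1)
            samples.setdefault(key, {})[ext] = tar.extractfile(member).read()
    return samples


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "shard-000000.tar")
    write_samples(path, [
        ("0001", {"txt": b"a cat", "cls": b"0"}),
        ("0002", {"txt": b"a dog", "cls": b"1"}),
    ])
    samples = read_samples(path)
    print(samples["0001"]["txt"])  # b'a cat'
```

Because the archive is written and read strictly front to back, this access pattern is purely sequential, which is what enables the streaming I/O rates the format is designed for.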

Signed-off-by: elliottower <[email protected]>
tmbdev authored and elliottower committed Apr 22, 2023
1 parent a26b9fc commit 09f3c94
Showing 6 changed files with 730 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/ray/data/__init__.py
@@ -44,6 +44,7 @@
read_text,
read_mongo,
read_tfrecords,
read_webdataset,
)


@@ -88,6 +89,7 @@
"read_parquet_bulk",
"read_sql",
"read_tfrecords",
"read_webdataset",
"set_progress_bars",
"Preprocessor",
]
70 changes: 70 additions & 0 deletions python/ray/data/dataset.py
@@ -2663,6 +2663,76 @@ def write_tfrecords(
tf_schema=tf_schema,
)

    @PublicAPI(stability="alpha")
    @ConsumptionAPI
    def write_webdataset(
        self,
        path: str,
        *,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        try_create_dir: bool = True,
        arrow_open_stream_args: Optional[Dict[str, Any]] = None,
        block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(),
        ray_remote_args: Optional[Dict[str, Any]] = None,
        encoder: Optional[Union[bool, str, Callable, List]] = True,
    ) -> None:
        """Write the dataset to
        `WebDataset <https://webdataset.github.io/webdataset/>`_ files.

        Each output file is a tar archive containing one group of files per
        dataset row; files belonging to the same sample share a basename and
        are distinguished by their extensions.

        This is only supported for datasets convertible to Arrow records.
        To control the number of files, use ``.repartition()``.

        Unless a custom block path provider is given, the format of the output
        files will be {uuid}_{block_idx}.tar, where ``uuid`` is a unique id
        for the dataset.

        Examples:
            >>> import ray
            >>> ds = ray.data.from_items([
            ...     {"name": "foo", "score": 42},
            ...     {"name": "bar", "score": 43},
            ... ])
            >>> ds.write_webdataset("s3://bucket/path")  # doctest: +SKIP

        Time complexity: O(dataset size / parallelism)

        Args:
            path: The path to the destination root directory, where the tar
                files will be written to.
            filesystem: The filesystem implementation to write to.
            try_create_dir: Try to create all directories in destination path
                if True. Does nothing if all directories already exist.
            arrow_open_stream_args: kwargs passed to
                pyarrow.fs.FileSystem.open_output_stream
            block_path_provider: BlockWritePathProvider implementation to
                write each dataset block to a custom output path.
            ray_remote_args: Kwargs passed to ray.remote in the write tasks.
            encoder: Controls how sample fields are encoded before writing:
                ``True`` applies the default encoders, ``False`` disables
                encoding, and a string, callable, or list selects custom
                encoders.
        """
        from ray.data.datasource.webdataset_datasource import WebDatasetDatasource

        self.write_datasource(
            WebDatasetDatasource(),
            ray_remote_args=ray_remote_args,
            path=path,
            dataset_uuid=self._uuid,
            filesystem=filesystem,
            try_create_dir=try_create_dir,
            open_stream_args=arrow_open_stream_args,
            block_path_provider=block_path_provider,
            encoder=encoder,
        )

@ConsumptionAPI
def write_numpy(
self,
2 changes: 2 additions & 0 deletions python/ray/data/datasource/__init__.py
@@ -40,6 +40,7 @@
)
from ray.data.datasource.sql_datasource import Connection, SQLDatasource
from ray.data.datasource.tfrecords_datasource import TFRecordDatasource
from ray.data.datasource.webdataset_datasource import WebDatasetDatasource
from ray.data.datasource.text_datasource import TextDatasource

__all__ = [
@@ -76,6 +77,7 @@
"Reader",
"TextDatasource",
"TFRecordDatasource",
"WebDatasetDatasource",
"WriteResult",
"_S3FileSystemWrapper",
]