diff --git a/python/ray/data/__init__.py b/python/ray/data/__init__.py
index 2d96b211eed4..972253ed9943 100644
--- a/python/ray/data/__init__.py
+++ b/python/ray/data/__init__.py
@@ -44,6 +44,7 @@
     read_text,
     read_mongo,
     read_tfrecords,
+    read_webdataset,
 )
@@ -88,6 +89,7 @@
     "read_parquet_bulk",
     "read_sql",
     "read_tfrecords",
+    "read_webdataset",
     "set_progress_bars",
     "Preprocessor",
 ]
diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py
index dbd96616b138..aef3dd146c9e 100644
--- a/python/ray/data/dataset.py
+++ b/python/ray/data/dataset.py
@@ -2663,6 +2663,76 @@ def write_tfrecords(
             tf_schema=tf_schema,
         )
 
+    @PublicAPI(stability="alpha")
+    @ConsumptionAPI
+    def write_webdataset(
+        self,
+        path: str,
+        *,
+        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
+        try_create_dir: bool = True,
+        arrow_open_stream_args: Optional[Dict[str, Any]] = None,
+        block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(),
+        ray_remote_args: Dict[str, Any] = None,
+        encoder: Optional[Union[bool, str, callable, list]] = True,
+    ) -> None:
+        """Write the dataset to WebDataset files.
+
+        The `WebDataset <https://github.com/webdataset/webdataset>`_ format
+        stores samples in POSIX tar archives: each row of the dataset becomes
+        a group of files sharing a common ``__key__`` prefix, with one file
+        per record field.
+
+        .. warning::
+            Unless a custom ``encoder`` is given, the default encoder is
+            applied to each row, and every encoded value must be ``bytes``
+            or ``str``; rows containing values that cannot be encoded will
+            raise an error.
+
+        This is only supported for datasets convertible to Arrow records.
+        To control the number of files, use ``.repartition()``.
+
+        Unless a custom block path provider is given, the format of the output
+        files will be {uuid}_{block_idx}.tar, where ``uuid`` is a unique id
+        for the dataset.
+
+        Examples:
+            >>> import ray
+            >>> ds = ray.data.from_items([
+            ...     { "name": "foo", "score": 42 },
+            ...     { "name": "bar", "score": 43 },
+            ... ])
+            >>> ds.write_webdataset("s3://bucket/path") # doctest: +SKIP
+
+        Time complexity: O(dataset size / parallelism)
+
+        Args:
+            path: The path to the destination root directory, where the tar
+                files will be written to.
+            filesystem: The filesystem implementation to write to.
+            try_create_dir: Try to create all directories in destination path
+                if True. Does nothing if all directories already exist.
+            arrow_open_stream_args: kwargs passed to
+                pyarrow.fs.FileSystem.open_output_stream
+            block_path_provider: BlockWritePathProvider implementation to
+                write each dataset block to a custom output path.
+            ray_remote_args: Kwargs passed to ray.remote in the write tasks.
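+            encoder: A function or list of functions used to encode each row
+                before writing, or ``True`` to apply the default encoder, which
+                serializes each field based on its extension (for example
+                ``.jpg``, ``.png``, ``.json``, ``.npy``, ``.mp``, ``.pt``,
+                ``.pickle``); ``bytes`` and ``str`` values are written as-is.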
+ + """ + + from ray.data.datasource.webdataset_datasource import WebDatasetDatasource + + self.write_datasource( + WebDatasetDatasource(), + ray_remote_args=ray_remote_args, + path=path, + dataset_uuid=self._uuid, + filesystem=filesystem, + try_create_dir=try_create_dir, + open_stream_args=arrow_open_stream_args, + block_path_provider=block_path_provider, + encoder=encoder, + ) + @ConsumptionAPI def write_numpy( self, diff --git a/python/ray/data/datasource/__init__.py b/python/ray/data/datasource/__init__.py index 48440594e79a..54bd1b86df18 100644 --- a/python/ray/data/datasource/__init__.py +++ b/python/ray/data/datasource/__init__.py @@ -40,6 +40,7 @@ ) from ray.data.datasource.sql_datasource import Connection, SQLDatasource from ray.data.datasource.tfrecords_datasource import TFRecordDatasource +from ray.data.datasource.webdataset_datasource import WebDatasetDatasource from ray.data.datasource.text_datasource import TextDatasource __all__ = [ @@ -76,6 +77,7 @@ "Reader", "TextDatasource", "TFRecordDatasource", + "WebDatasetDatasource", "WriteResult", "_S3FileSystemWrapper", ] diff --git a/python/ray/data/datasource/webdataset_datasource.py b/python/ray/data/datasource/webdataset_datasource.py new file mode 100644 index 000000000000..6ca88f640d6d --- /dev/null +++ b/python/ray/data/datasource/webdataset_datasource.py @@ -0,0 +1,391 @@ +# Copyright NVIDIA Corporation 2023 +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, Optional, Union, List, TYPE_CHECKING +import tarfile +import io +import time +import re +import uuid +import fnmatch +from functools import partial + +from ray.util.annotations import PublicAPI +from ray.data.block import BlockAccessor +from ray.data.datasource.file_based_datasource import FileBasedDatasource + + +if TYPE_CHECKING: + import pyarrow + + +def base_plus_ext(path: str): + """Split off all file extensions. + + Returns base, allext. + + Args: + path: path with extensions + + Returns: + str: path with all extensions removed + """ + match = re.match(r"^((?:.*/|)[^.]+)[.]([^/]*)$", path) + if not match: + return None, None + return match.group(1), match.group(2) + + +def _valid_sample(sample: Dict[str, Any]): + """Check whether a sample is valid. + + Args: + sample: sample to be checked + """ + return ( + sample is not None + and isinstance(sample, dict) + and len(list(sample.keys())) > 0 + and not sample.get("__bad__", False) + ) + + +def _apply_list( + f: Union[Callable, List[Callable]], sample: Dict[str, Any], default: Callable = None +): + """Apply a list of functions to a sample. + + Args: + f: function or list of functions + sample: sample to be modified + default: default function to be applied to all keys. + Defaults to None. + + Returns: + modified sample + """ + if f is None: + return sample + if not isinstance(f, list): + f = [f] + for g in f: + if default is not None and not callable(g): + g = partial(default, format=g) + sample = g(sample) + return sample + + +def _check_suffix(suffix: str, suffixes: Union[list, callable]): + """Check whether a suffix is valid. + + Suffixes can be either None (=accept everything), a callable, + or a list of patterns. If the pattern contains */? it is treated + as a glob pattern, otherwise it is treated as a literal. + + Args: + suffix: suffix to be checked + suffixes: list of valid suffixes + """ + if suffixes is None: + return True + if callable(suffixes): + return suffixes(suffix) + for pattern in suffixes: + if "*" in pattern or "?" in pattern: + if fnmatch.fnmatch("." 
+ suffix, pattern): + return True + elif suffix == pattern or "." + suffix == pattern: + return True + return False + + +def _tar_file_iterator( + fileobj: Any, + fileselect: Optional[Union[bool, callable, list]] = None, + filerename: Optional[Union[bool, callable, list]] = None, + verbose_open: bool = False, + meta: dict = None, +): + """Iterate over tar file, yielding filename, content pairs for the given tar stream. + + Args: + fileobj: file object + fileselect: patterns or function selecting + files to be selected + meta: metadata to be added to each sample + """ + meta = meta or {} + stream = tarfile.open(fileobj=fileobj, mode="r|*") + if verbose_open: + print(f"start {meta}") + for tarinfo in stream: + fname = tarinfo.name + if not tarinfo.isreg() or fname is None: + continue + data = stream.extractfile(tarinfo).read() + fname = _apply_list(filerename, fname) + assert isinstance(fname, str) + if not _check_suffix(fname, fileselect): + continue + result = dict(fname=fname, data=data) + yield result + if verbose_open: + print(f"done {meta}") + + +def _group_by_keys( + data: List[Dict[str, Any]], + keys: callable = base_plus_ext, + suffixes: Optional[Union[list, callable]] = None, + meta: dict = None, +): + """Return function over iterator that groups key, value pairs into samples. + + Args: + data: iterator over key, value pairs + keys: function that returns key, suffix for a given key + suffixes: list of suffixes to be included in the sample + meta: metadata to be added to each sample + """ + meta = meta or {} + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if prefix is None: + continue + if current_sample is None or prefix != current_sample["__key__"]: + if _valid_sample(current_sample): + current_sample.update(meta) + yield current_sample + current_sample = dict(__key__=prefix) + if "__url__" in filesample: + current_sample["__url__"] = filesample["__url__"] + if suffix in current_sample: + raise ValueError( + f"{fname}: duplicate file name in tar file " + + f"{suffix} {current_sample.keys()}" + ) + if suffixes is None or _check_suffix(suffix, suffixes): + current_sample[suffix] = value + if _valid_sample(current_sample): + current_sample.update(meta) + yield current_sample + + +def _default_decoder(sample: Dict[str, Any], format: Optional[Union[bool, str]] = True): + """A default decoder for webdataset. + + This handles common file extensions: .txt, .cls, .cls2, + .jpg, .png, .json, .npy, .mp, .pt, .pth, .pickle, .pkl. + These are the most common extensions used in webdataset. + For other extensions, users can provide their own decoder. 
+ + Args: + sample: sample, modified in place + """ + sample = dict(sample) + for key, value in sample.items(): + extension = key.split(".")[-1] + if key.startswith("__"): + continue + elif extension in ["txt", "text"]: + sample[key] = value.decode("utf-8") + elif extension in ["cls", "cls2"]: + sample[key] = int(value.decode("utf-8")) + elif extension in ["jpg", "png", "ppm", "pgm", "pbm", "pnm"]: + import PIL.Image + import numpy as np + + if format == "PIL": + sample[key] = PIL.Image.open(io.BytesIO(value)) + else: + sample[key] = np.asarray(PIL.Image.open(io.BytesIO(value))) + elif extension == "json": + import json + + sample[key] = json.loads(value) + elif extension == "npy": + import numpy as np + + sample[key] = np.load(io.BytesIO(value)) + elif extension == "mp": + import msgpack + + sample[key] = msgpack.unpackb(value, raw=False) + elif extension in ["pt", "pth"]: + import torch + + sample[key] = torch.load(io.BytesIO(value)) + elif extension in ["pickle", "pkl"]: + import pickle + + sample[key] = pickle.loads(value) + return sample + + +extension_to_format = {"jpg": "jpeg"} + + +def _default_encoder(sample: Dict[str, Any], format: Optional[Union[str, bool]] = True): + """A default encoder for webdataset. + + This handles common file extensions: .txt, .cls, .cls2, .jpg, + .png, .json, .npy, .mp, .pt, .pth, .pickle, .pkl + These are the most common extensions used in webdataset. + For other extensions, users can provide their own encoder. + + Args: + sample (Dict[str, Any]): sample + """ + sample = dict(sample) + for key, value in sample.items(): + extension = key.split(".")[-1] + if key.startswith("__"): + continue + elif extension in ["txt"]: + sample[key] = value.encode("utf-8") + elif extension in ["cls", "cls2"]: + sample[key] = str(value).encode("utf-8") + elif extension in ["jpg", "jpeg", "png", "ppm", "pgm", "pbm", "pnm"]: + import PIL.Image + import numpy as np + + if isinstance(value, np.ndarray): + value = PIL.Image.fromarray(value) + assert isinstance(value, PIL.Image.Image) + stream = io.BytesIO() + value.save( + stream, format=extension_to_format.get(extension.lower(), extension) + ) + sample[key] = stream.getvalue() + elif extension == "json": + import json + + sample[key] = json.dumps(value).encode("utf-8") + elif extension == "npy": + import numpy as np + + stream = io.BytesIO() + np.save(stream, value) + sample[key] = stream.getvalue() + elif extension == "mp": + import msgpack + + sample[key] = msgpack.dumps(value) + elif extension in ["pt", "pth"]: + import torch + + stream = io.BytesIO() + torch.save(value, stream) + sample[key] = stream.getvalue() + elif extension in ["pickle", "pkl"]: + import pickle + + stream = io.BytesIO() + pickle.dump(value, stream) + sample[key] = stream.getvalue() + return sample + + +def _make_iterable(block: BlockAccessor): + """Make a block iterable. + + This is a placeholder for dealing with more complex blocks. 
+
+    Args:
+        block: Ray Dataset block
+
+    Returns:
+        Iterable[Dict[str, Any]]: Iterable of samples
+    """
+    return block.iter_rows()
+
+
+@PublicAPI(stability="alpha")
+class WebDatasetDatasource(FileBasedDatasource):
+    """A Datasource for WebDataset datasets (tar format with naming conventions)."""
+
+    _FILE_EXTENSION = "tar"
+
+    def _read_stream(
+        self,
+        stream: "pyarrow.NativeFile",
+        path: str,
+        decoder: Optional[Union[bool, str, callable, list]] = True,
+        fileselect: Optional[Union[bool, callable, list]] = None,
+        filerename: Optional[Union[bool, callable, list]] = None,
+        suffixes: Optional[Union[bool, callable, list]] = None,
+        verbose_open: bool = False,
+        **kw,
+    ):
+        """Read and decode samples from a stream.
+
+        Note that fileselect selects files during reading, while suffixes
+        selects files during the grouping step.
+
+        Args:
+            stream: File descriptor to read from.
+            path: Path to the dataset.
+            decoder: Decoder or list of decoders to be applied to each sample;
+                ``True`` applies the default decoder. Defaults to True.
+            fileselect: Callable or list of glob patterns selecting which files
+                are read from each tar file. Defaults to None (read all files).
+            filerename: Function or list of functions used to rename files
+                before grouping. Defaults to None.
+            suffixes: List of suffixes (or callable) selecting which files are
+                kept when grouping samples. Defaults to None (keep all files).
+            verbose_open: Print message when opening files. Defaults to False.
+
+        Yields:
+            List[Dict[str, Any]]: Lists of samples, each of length 1.
+        """
+
+        files = _tar_file_iterator(
+            stream,
+            fileselect=fileselect,
+            filerename=filerename,
+            verbose_open=verbose_open,
+        )
+        samples = _group_by_keys(files, meta=dict(__url__=path), suffixes=suffixes)
+        for sample in samples:
+            if decoder is not None:
+                sample = _apply_list(decoder, sample, default=_default_decoder)
+            yield [sample]
+
+    def _write_block(
+        self,
+        f: "pyarrow.NativeFile",
+        block: BlockAccessor,
+        writer_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
+        encoder: Optional[Union[bool, str, callable, list]] = True,
+        **kw,
+    ):
+        """Encode and write samples to a stream.
+
+        Args:
+            f: File descriptor to write to.
+            block: Data to be written.
+            writer_args_fn: Ignored. Defaults to lambda: {}.
+            encoder: (List of) encoder(s) to be applied to samples. Defaults to True.
+        """
+
+        stream = tarfile.open(fileobj=f, mode="w|")
+        samples = _make_iterable(block)
+        for sample in samples:
+            if not isinstance(sample, dict):
+                sample = sample.as_pydict()
+            if encoder is not None:
+                sample = _apply_list(encoder, sample, default=_default_encoder)
+            if "__key__" not in sample:
+                sample["__key__"] = uuid.uuid4().hex
+            key = sample["__key__"]
+            for k, v in sample.items():
+                if v is None or k.startswith("__"):
+                    continue
+                assert isinstance(v, bytes) or isinstance(v, str)
+                if not isinstance(v, bytes):
+                    v = v.encode("utf-8")
+                ti = tarfile.TarInfo(f"{key}.{k}")
+                ti.size = len(v)
+                ti.mtime = time.time()
+                ti.mode, ti.uname, ti.gname = 0o644, "data", "data"
+                stream.addfile(ti, io.BytesIO(v))
+        stream.close()
diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py
index f952563fcebf..e93d101ffd05 100644
--- a/python/ray/data/read_api.py
+++ b/python/ray/data/read_api.py
@@ -56,6 +56,7 @@
     ReadTask,
     TextDatasource,
     TFRecordDatasource,
+    WebDatasetDatasource,
 )
 from ray.data.datasource.file_based_datasource import (
     _unwrap_arrow_serialization_workaround,
@@ -1184,6 +1185,65 @@ def read_tfrecords(
     )
+
+@PublicAPI(stability="alpha")
+def read_webdataset(
+    paths: Union[str, List[str]],
+    *,
+    filesystem: Optional["pyarrow.fs.FileSystem"] = None,
+    parallelism: int = -1,
+    arrow_open_stream_args: Optional[Dict[str, Any]] = None,
+    meta_provider: BaseFileMetadataProvider = DefaultFileMetadataProvider(),
+    partition_filter: Optional[PathPartitionFilter] = None,
+    decoder: Optional[Union[bool, str, callable, list]] = True,
+    fileselect: Optional[Union[list, callable]] = None,
+    filerename: Optional[Union[list, callable]] = None,
+    suffixes: Optional[Union[list, callable]] = None,
+    verbose_open: bool = False,
+) -> Dataset[PandasRow]:
+    """Create a dataset from WebDataset files.
+
+    Args:
+        paths: A single file/directory path or a list of file/directory paths.
+            A list of paths can contain both files and directories.
+        filesystem: The filesystem implementation to read from.
+        parallelism: The requested parallelism of the read. Parallelism may be
+            limited by the number of files in the dataset.
+        arrow_open_stream_args: Keyword arguments passed to
+            ``pyarrow.fs.FileSystem.open_input_stream``. To read a compressed tar file,
+            pass the corresponding compression type (e.g. for ``GZIP`` or ``ZLIB``, use
+            ``arrow_open_stream_args={'compression_type': 'gzip'}``).
+        meta_provider: File metadata provider. Custom metadata providers may
+            be able to resolve file metadata more quickly and/or accurately.
+        partition_filter: Path-based partition filter, if any. Can be used
+            with a custom callback to read only selected partitions of a dataset.
+        decoder: A function or list of functions to decode each sample, or
+            ``True`` to apply the default decoder.
+        fileselect: A callable or list of glob patterns to select files.
+        filerename: A function or list of functions to rename files prior to grouping.
+        suffixes: A function or list of suffixes to select for creating samples.
+        verbose_open: Whether to print the file names as they are opened.
+
+    Returns:
+        A :class:`~ray.data.Dataset` of dict records, one per WebDataset sample.
+
+    Raises:
+        ValueError: If a tar file contains duplicate file names for the same
+            sample (i.e., the same key and suffix appear more than once).
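+
+    Examples:
+        A minimal sketch of reading a directory of WebDataset ``.tar`` shards
+        (the bucket path below is illustrative only):
+
+        >>> import ray
+        >>> ds = ray.data.read_webdataset("s3://bucket/path") # doctest: +SKIP
+        >>> ds.take(1) # doctest: +SKIP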
+ """ # noqa: E501 + return read_datasource( + WebDatasetDatasource(), + parallelism=parallelism, + paths=paths, + filesystem=filesystem, + open_stream_args=arrow_open_stream_args, + meta_provider=meta_provider, + partition_filter=partition_filter, + decoder=decoder, + fileselect=fileselect, + filerename=filerename, + suffixes=suffixes, + verbose_open=verbose_open, + ) + + @PublicAPI def read_binary_files( paths: Union[str, List[str]], diff --git a/python/ray/data/tests/test_dataset_webdataset.py b/python/ray/data/tests/test_dataset_webdataset.py new file mode 100644 index 000000000000..9771ee36d042 --- /dev/null +++ b/python/ray/data/tests/test_dataset_webdataset.py @@ -0,0 +1,205 @@ +# Copyright NVIDIA Corporation 2023 +# SPDX-License-Identifier: Apache-2.0 + +import os +import io + +import pytest +import tarfile +import glob + +import ray + +from ray.tests.conftest import * # noqa + + +class TarWriter: + def __init__(self, path): + self.path = path + self.tar = tarfile.open(path, "w") + + def __enter__(self): + return self + + def __exit__(self, *args): + self.tar.close() + + def write(self, name, data): + f = self.tar.tarinfo() + f.name = name + f.size = len(data) + self.tar.addfile(f, io.BytesIO(data)) + + +def test_webdataset_read(ray_start_2_cpus, tmp_path): + path = os.path.join(tmp_path, "bar_000000.tar") + with TarWriter(path) as tf: + for i in range(100): + tf.write(f"{i}.a", str(i).encode("utf-8")) + tf.write(f"{i}.b", str(i**2).encode("utf-8")) + assert os.path.exists(path) + assert len(glob.glob(f"{tmp_path}/*.tar")) == 1 + ds = ray.data.read_webdataset(paths=[str(tmp_path)], parallelism=1) + samples = ds.take(100) + assert len(samples) == 100 + for i, sample in enumerate(samples): + assert isinstance(sample, dict), sample + assert sample["__key__"] == str(i) + assert sample["a"].decode("utf-8") == str(i) + assert sample["b"].decode("utf-8") == str(i**2) + + +def test_webdataset_suffixes(ray_start_2_cpus, tmp_path): + path = os.path.join(tmp_path, "bar_000000.tar") + with TarWriter(path) as tf: + for i in range(100): + tf.write(f"{i}.txt", str(i).encode("utf-8")) + tf.write(f"{i}.test.txt", str(i**2).encode("utf-8")) + tf.write(f"{i}.cls", str(i**2).encode("utf-8")) + tf.write(f"{i}.test.cls2", str(i**2).encode("utf-8")) + assert os.path.exists(path) + assert len(glob.glob(f"{tmp_path}/*.tar")) == 1 + + # test simple suffixes + ds = ray.data.read_webdataset( + paths=[str(tmp_path)], parallelism=1, suffixes=["txt", "cls"] + ) + samples = ds.take(100) + assert len(samples) == 100 + for i, sample in enumerate(samples): + assert set(sample.keys()) == {"__url__", "__key__", "txt", "cls"} + + # test fnmatch patterns for suffixes + ds = ray.data.read_webdataset( + paths=[str(tmp_path)], parallelism=1, suffixes=["*.txt", "*.cls"] + ) + samples = ds.take(100) + assert len(samples) == 100 + for i, sample in enumerate(samples): + assert set(sample.keys()) == {"__url__", "__key__", "txt", "cls", "test.txt"} + + # test selection function + def select(name): + return name.endswith("txt") + + ds = ray.data.read_webdataset(paths=[str(tmp_path)], parallelism=1, suffixes=select) + samples = ds.take(100) + assert len(samples) == 100 + for i, sample in enumerate(samples): + assert set(sample.keys()) == {"__url__", "__key__", "txt", "test.txt"} + + # test filerename + def renamer(name): + result = name.replace("txt", "text") + print("***", name, result) + return result + + ds = ray.data.read_webdataset( + paths=[str(tmp_path)], parallelism=1, filerename=renamer + ) + samples = ds.take(100) + 
assert len(samples) == 100 + for i, sample in enumerate(samples): + assert set(sample.keys()) == { + "__url__", + "__key__", + "text", + "cls", + "test.text", + "test.cls2", + } + + +def test_webdataset_write(ray_start_2_cpus, tmp_path): + print(ray.available_resources()) + data = [dict(__key__=str(i), a=str(i), b=str(i**2)) for i in range(100)] + ds = ray.data.from_items(data).repartition(1) + ds.write_webdataset(path=tmp_path, try_create_dir=True) + paths = glob.glob(f"{tmp_path}/*.tar") + assert len(paths) == 1 + with open(paths[0], "rb") as stream: + tf = tarfile.open(fileobj=stream) + for i in range(100): + assert tf.extractfile(f"{i}.a").read().decode("utf-8") == str(i) + assert tf.extractfile(f"{i}.b").read().decode("utf-8") == str(i**2) + + +def custom_decoder(sample): + for key, value in sample.items(): + if key == "png": + # check that images have already been decoded + assert not isinstance(value, bytes) + elif key.endswith("custom"): + sample[key] = "custom-value" + return sample + + +def test_webdataset_coding(ray_start_2_cpus, tmp_path): + import numpy as np + import torch + import PIL.Image + + image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + gray = np.random.randint(0, 255, (100, 100), dtype=np.uint8) + dstruct = dict(a=[1], b=dict(c=2), d="hello") + ttensor = torch.tensor([1, 2, 3]) + + sample = { + "__key__": "foo", + "jpg": image, + "gray.png": gray, + "mp": dstruct, + "json": dstruct, + "pt": ttensor, + "und": b"undecoded", + "custom": b"nothing", + } + + # write the encoded data using the default encoder + data = [sample] + ds = ray.data.from_items(data).repartition(1) + ds.write_webdataset(path=tmp_path, try_create_dir=True) + + # read the encoded data using the default decoder + paths = glob.glob(f"{tmp_path}/*.tar") + assert len(paths) == 1 + path = paths[0] + assert os.path.exists(path) + ds = ray.data.read_webdataset(paths=[str(tmp_path)], parallelism=1) + samples = ds.take(1) + assert len(samples) == 1 + for sample in samples: + assert isinstance(sample, dict), sample + assert sample["__key__"] == "foo" + assert isinstance(sample["jpg"], np.ndarray) + assert sample["jpg"].shape == (100, 100, 3) + assert isinstance(sample["gray.png"], np.ndarray) + assert sample["gray.png"].shape == (100, 100) + assert isinstance(sample["mp"], dict) + assert sample["mp"]["a"] == [1] + assert sample["mp"]["b"]["c"] == 2 + assert isinstance(sample["json"], dict) + assert sample["json"]["a"] == [1] + assert isinstance(sample["pt"], torch.Tensor) + assert sample["pt"].tolist() == [1, 2, 3] + + # test the format argument to the default decoder and multiple decoders + ds = ray.data.read_webdataset( + paths=[str(tmp_path)], parallelism=1, decoder=["PIL", custom_decoder] + ) + samples = ds.take(1) + assert len(samples) == 1 + for sample in samples: + assert isinstance(sample, dict), sample + assert sample["__key__"] == "foo" + assert isinstance(sample["jpg"], PIL.Image.Image) + assert isinstance(sample["gray.png"], PIL.Image.Image) + assert isinstance(sample["und"], bytes) + assert sample["und"] == b"undecoded" + assert sample["custom"] == "custom-value" + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__]))