From 00cf708b6ba252c271a9dd31714f67fa0a2b113f Mon Sep 17 00:00:00 2001
From: can
Date: Fri, 21 Jun 2024 00:43:18 +0000
Subject: [PATCH 1/2] [data] correct api annotation for tfrecords_datasource

Signed-off-by: can
---
 doc/source/data/api/api.rst                   |   1 +
 doc/source/data/api/data_source.rst           |  12 ++
 .../datasource/tfrecords_datasource.py        | 124 ++++++++++++++++++
 python/ray/data/datasource/__init__.py        |   2 -
 .../data/datasource/tfrecords_datasource.py   | 111 +---------------
 python/ray/data/read_api.py                   |   2 +-
 6 files changed, 139 insertions(+), 113 deletions(-)
 create mode 100644 doc/source/data/api/data_source.rst
 create mode 100644 python/ray/data/_internal/datasource/tfrecords_datasource.py

diff --git a/doc/source/data/api/api.rst b/doc/source/data/api/api.rst
index e0b6df24ac30..dc87cf144f92 100644
--- a/doc/source/data/api/api.rst
+++ b/doc/source/data/api/api.rst
@@ -9,6 +9,7 @@ Ray Data API
     input_output.rst
     dataset.rst
     data_iterator.rst
+    data_source.rst
     execution_options.rst
     grouped_data.rst
     data_context.rst
diff --git a/doc/source/data/api/data_source.rst b/doc/source/data/api/data_source.rst
new file mode 100644
index 000000000000..570d4ffdae0d
--- /dev/null
+++ b/doc/source/data/api/data_source.rst
@@ -0,0 +1,12 @@
+.. _data-source-api:
+
+Data Source API
+===============
+
+.. currentmodule:: ray.data.datasource
+
+.. autosummary::
+    :nosignatures:
+    :toctree: doc/
+
+    tfrecords_datasource.TFXReadOptions
diff --git a/python/ray/data/_internal/datasource/tfrecords_datasource.py b/python/ray/data/_internal/datasource/tfrecords_datasource.py
new file mode 100644
index 000000000000..e91be7f99581
--- /dev/null
+++ b/python/ray/data/_internal/datasource/tfrecords_datasource.py
@@ -0,0 +1,124 @@
+import logging
+from typing import TYPE_CHECKING, Iterator, List, Optional, Union
+
+import pyarrow
+
+from ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict
+from ray.data.block import Block
+from ray.data.datasource.file_based_datasource import FileBasedDatasource
+from ray.data.datasource.tfrecords_datasource import (
+    TFXReadOptions,
+    _cast_large_list_to_list,
+    _convert_example_to_dict,
+    _read_records,
+)
+
+if TYPE_CHECKING:
+    from tensorflow_metadata.proto.v0 import schema_pb2
+
+logger = logging.getLogger(__name__)
+
+
+class TFRecordDatasource(FileBasedDatasource):
+    """TFRecord datasource, for reading and writing TFRecord files."""
+
+    _FILE_EXTENSIONS = ["tfrecords"]
+
+    def __init__(
+        self,
+        paths: Union[str, List[str]],
+        tf_schema: Optional["schema_pb2.Schema"] = None,
+        tfx_read_options: Optional[TFXReadOptions] = None,
+        **file_based_datasource_kwargs,
+    ):
+        """
+        Args:
+            tf_schema: Optional TensorFlow Schema which is used to explicitly set
+                the schema of the underlying Dataset.
+            tfx_read_options: Optional options for reading TFRecord files
+                using tfx-bsl.
+
+        """
+        super().__init__(paths, **file_based_datasource_kwargs)
+
+        self._tf_schema = tf_schema
+        self._tfx_read_options = tfx_read_options
+
+    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
+        if self._tfx_read_options:
+            yield from self._tfx_read_stream(f, path)
+        else:
+            yield from self._default_read_stream(f, path)
+
+    def _default_read_stream(
+        self, f: "pyarrow.NativeFile", path: str
+    ) -> Iterator[Block]:
+        import tensorflow as tf
+        from google.protobuf.message import DecodeError
+
+        for record in _read_records(f, path):
+            example = tf.train.Example()
+            try:
+                example.ParseFromString(record)
+            except DecodeError as e:
+                raise ValueError(
+                    "`TFRecordDatasource` failed to parse `tf.train.Example` "
+                    f"record in '{path}'. This error can occur if your TFRecord "
+                    f"file contains a message type other than `tf.train.Example`: {e}"
+                )
+
+            yield pyarrow_table_from_pydict(
+                _convert_example_to_dict(example, self._tf_schema)
+            )
+
+    def _tfx_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
+        import tensorflow as tf
+        from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder
+
+        full_path = self._resolve_full_path(path)
+
+        compression = (self._open_stream_args or {}).get("compression", None)
+
+        if compression:
+            compression = compression.upper()
+
+        tf_schema_string = (
+            self._tf_schema.SerializeToString() if self._tf_schema else None
+        )
+
+        decoder = ExamplesToRecordBatchDecoder(tf_schema_string)
+        exception_thrown = None
+        try:
+            for record in tf.data.TFRecordDataset(
+                full_path, compression_type=compression
+            ).batch(self._tfx_read_options.batch_size):
+                yield _cast_large_list_to_list(
+                    pyarrow.Table.from_batches([decoder.DecodeBatch(record.numpy())])
+                )
+        except Exception as error:
+            logger.exception(f"Failed to read TFRecord file {full_path}")
+            exception_thrown = error
+
+        # we need to do this hack where we raise the exception outside of the
+        # except block because the TensorFlow DataLossError is unpicklable, and
+        # even if we raise a RuntimeError, Ray keeps information about the
+        # original error, which makes it unpicklable still.
+        if exception_thrown:
+            raise RuntimeError(f"Failed to read TFRecord file {full_path}.")
+
+    def _resolve_full_path(self, relative_path):
+        if isinstance(self._filesystem, pyarrow.fs.S3FileSystem):
+            return f"s3://{relative_path}"
+        if isinstance(self._filesystem, pyarrow.fs.GcsFileSystem):
+            return f"gs://{relative_path}"
+        if isinstance(self._filesystem, pyarrow.fs.HadoopFileSystem):
+            return f"hdfs:///{relative_path}"
+        if isinstance(self._filesystem, pyarrow.fs.PyFileSystem):
+            protocol = self._filesystem.handler.fs.protocol
+            if isinstance(protocol, list) or isinstance(protocol, tuple):
+                protocol = protocol[0]
+            if protocol == "gcs":
+                protocol = "gs"
+            return f"{protocol}://{relative_path}"
+
+        return relative_path
diff --git a/python/ray/data/datasource/__init__.py b/python/ray/data/datasource/__init__.py
index 32a2f5c8fe5c..65cfdc7aa9fe 100644
--- a/python/ray/data/datasource/__init__.py
+++ b/python/ray/data/datasource/__init__.py
@@ -49,7 +49,6 @@
 from ray.data.datasource.sql_datasource import Connection, SQLDatasource
 from ray.data.datasource.text_datasource import TextDatasource
 from ray.data.datasource.tfrecords_datasink import _TFRecordDatasink
-from ray.data.datasource.tfrecords_datasource import TFRecordDatasource
 from ray.data.datasource.torch_datasource import TorchDatasource
 from ray.data.datasource.webdataset_datasink import _WebDatasetDatasink
 from ray.data.datasource.webdataset_datasource import WebDatasetDatasource
@@ -100,7 +99,6 @@
     "RowBasedFileDatasink",
     "TextDatasource",
     "_TFRecordDatasink",
-    "TFRecordDatasource",
     "TorchDatasource",
     "_WebDatasetDatasink",
     "WebDatasetDatasource",
diff --git a/python/ray/data/datasource/tfrecords_datasource.py b/python/ray/data/datasource/tfrecords_datasource.py
index be116c1344d6..bc5bd4d829be 100644
--- a/python/ray/data/datasource/tfrecords_datasource.py
+++ b/python/ray/data/datasource/tfrecords_datasource.py
@@ -1,15 +1,12 @@
 import logging
 import struct
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union
 
 import numpy as np
 import pyarrow
 
-from ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict
 from ray.data.aggregate import AggregateFn
-from ray.data.block import Block
-from ray.data.datasource.file_based_datasource import FileBasedDatasource
 from ray.util.annotations import PublicAPI
 
 if TYPE_CHECKING:
@@ -45,112 +42,6 @@ class TFXReadOptions:
     auto_infer_schema: bool = True
 
 
-@PublicAPI(stability="alpha")
-class TFRecordDatasource(FileBasedDatasource):
-    """TFRecord datasource, for reading and writing TFRecord files."""
-
-    _FILE_EXTENSIONS = ["tfrecords"]
-
-    def __init__(
-        self,
-        paths: Union[str, List[str]],
-        tf_schema: Optional["schema_pb2.Schema"] = None,
-        tfx_read_options: Optional[TFXReadOptions] = None,
-        **file_based_datasource_kwargs,
-    ):
-        """
-        Args:
-            tf_schema: Optional TensorFlow Schema which is used to explicitly set
-                the schema of the underlying Dataset.
-            tfx_read_options: Optional options for enabling reading tfrecords
-                using tfx-bsl.
-
-        """
-        super().__init__(paths, **file_based_datasource_kwargs)
-
-        self._tf_schema = tf_schema
-        self._tfx_read_options = tfx_read_options
-
-    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
-        if self._tfx_read_options:
-            yield from self._tfx_read_stream(f, path)
-        else:
-            yield from self._default_read_stream(f, path)
-
-    def _default_read_stream(
-        self, f: "pyarrow.NativeFile", path: str
-    ) -> Iterator[Block]:
-        import tensorflow as tf
-        from google.protobuf.message import DecodeError
-
-        for record in _read_records(f, path):
-            example = tf.train.Example()
-            try:
-                example.ParseFromString(record)
-            except DecodeError as e:
-                raise ValueError(
-                    "`TFRecordDatasource` failed to parse `tf.train.Example` "
-                    f"record in '{path}'. This error can occur if your TFRecord "
-                    f"file contains a message type other than `tf.train.Example`: {e}"
-                )
-
-            yield pyarrow_table_from_pydict(
-                _convert_example_to_dict(example, self._tf_schema)
-            )
-
-    def _tfx_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
-        import tensorflow as tf
-        from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder
-
-        full_path = self._resolve_full_path(path)
-
-        compression = (self._open_stream_args or {}).get("compression", None)
-
-        if compression:
-            compression = compression.upper()
-
-        tf_schema_string = (
-            self._tf_schema.SerializeToString() if self._tf_schema else None
-        )
-
-        decoder = ExamplesToRecordBatchDecoder(tf_schema_string)
-        exception_thrown = None
-        try:
-            for record in tf.data.TFRecordDataset(
-                full_path, compression_type=compression
-            ).batch(self._tfx_read_options.batch_size):
-                yield _cast_large_list_to_list(
-                    pyarrow.Table.from_batches([decoder.DecodeBatch(record.numpy())])
-                )
-        except Exception as error:
-            logger.exception(f"Failed to read TFRecord file {full_path}")
-            exception_thrown = error
-
-        # we need to do this hack were we raise an exception outside of the
-        # except block because tensorflow DataLossError is unpickable, and
-        # even if we raise a runtime error, ray keeps information about the
-        # original error, which makes it unpickable still.
-        if exception_thrown:
-            raise RuntimeError(f"Failed to read TFRecord file {full_path}.")
-
-    def _resolve_full_path(self, relative_path):
-        if isinstance(self._filesystem, pyarrow.fs.S3FileSystem):
-            return f"s3://{relative_path}"
-        if isinstance(self._filesystem, pyarrow.fs.GcsFileSystem):
-            return f"gs://{relative_path}"
-        if isinstance(self._filesystem, pyarrow.fs.HadoopFileSystem):
-            return f"hdfs:///{relative_path}"
-        if isinstance(self._filesystem, pyarrow.fs.PyFileSystem):
-            protocol = self._filesystem.handler.fs.protocol
-            if isinstance(protocol, list) or isinstance(protocol, tuple):
-                protocol = protocol[0]
-            if protocol == "gcs":
-                protocol = "gs"
-            return f"{protocol}://{relative_path}"
-
-        return relative_path
-
-
 def _convert_example_to_dict(
     example: "tf.train.Example",
     tf_schema: Optional["schema_pb2.Schema"],
diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py
index 844cefa0f209..141e14e093f6 100644
--- a/python/ray/data/read_api.py
+++ b/python/ray/data/read_api.py
@@ -22,6 +22,7 @@
 from ray._private.auto_init_hook import wrap_auto_init
 from ray.air.util.tensor_extensions.utils import _create_possibly_ragged_ndarray
 from ray.data._internal.datasource.avro_datasource import AvroDatasource
+from ray.data._internal.datasource.tfrecords_datasource import TFRecordDatasource
 from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
 from ray.data._internal.logical.operators.from_operators import (
     FromArrow,
@@ -65,7 +66,6 @@
     RangeDatasource,
     SQLDatasource,
     TextDatasource,
-    TFRecordDatasource,
     TorchDatasource,
     WebDatasetDatasource,
 )

From 2647635580f8215cc76215dfe648941535db4583 Mon Sep 17 00:00:00 2001
From: can
Date: Fri, 28 Jun 2024 16:35:17 +0000
Subject: [PATCH 2/2] put TFXReadOptions into input_output

Signed-off-by: can
---
 doc/source/data/api/api.rst          |  1 -
 doc/source/data/api/data_source.rst  | 12 ------------
 doc/source/data/api/input_output.rst |  2 +-
 3 files changed, 1 insertion(+), 14 deletions(-)
 delete mode 100644 doc/source/data/api/data_source.rst

diff --git a/doc/source/data/api/api.rst b/doc/source/data/api/api.rst
index dc87cf144f92..e0b6df24ac30 100644
--- a/doc/source/data/api/api.rst
+++ b/doc/source/data/api/api.rst
@@ -9,7 +9,6 @@ Ray Data API
     input_output.rst
     dataset.rst
     data_iterator.rst
-    data_source.rst
     execution_options.rst
     grouped_data.rst
     data_context.rst
diff --git a/doc/source/data/api/data_source.rst b/doc/source/data/api/data_source.rst
deleted file mode 100644
index 570d4ffdae0d..000000000000
--- a/doc/source/data/api/data_source.rst
+++ /dev/null
@@ -1,12 +0,0 @@
-.. _data-source-api:
-
-Data Source API
-===============
-
-.. currentmodule:: ray.data.datasource
-
-.. autosummary::
-    :nosignatures:
-    :toctree: doc/
-
-    tfrecords_datasource.TFXReadOptions
diff --git a/doc/source/data/api/input_output.rst b/doc/source/data/api/input_output.rst
index 751a87e4dcb3..b160c0d213ef 100644
--- a/doc/source/data/api/input_output.rst
+++ b/doc/source/data/api/input_output.rst
@@ -101,7 +101,7 @@ TFRecords
 
     read_tfrecords
     Dataset.write_tfrecords
-
+    datasource.tfrecords_datasource.TFXReadOptions
 
 Pandas
 ------
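
Reviewer note: a minimal usage sketch of the API surface this series settles on. After both patches, `TFRecordDatasource` is internal and is reached through `ray.data.read_tfrecords`, while `TFXReadOptions` stays public in `ray.data.datasource.tfrecords_datasource`. The sketch assumes `read_tfrecords` accepts a `tfx_read_options` argument, consistent with the input_output.rst entry added above; the bucket path and batch size are illustrative, not taken from the patch.

    # Minimal sketch; the path "s3://my-bucket/data.tfrecords" and batch_size
    # value are hypothetical examples.
    import ray
    from ray.data.datasource.tfrecords_datasource import TFXReadOptions

    # Default read path: the now-internal TFRecordDatasource is constructed
    # by read_tfrecords itself, so user code no longer imports it.
    ds = ray.data.read_tfrecords("s3://my-bucket/data.tfrecords")

    # tfx-bsl-accelerated read path, tuned via the public TFXReadOptions
    # dataclass (fields shown in the diff: batch_size, auto_infer_schema).
    ds_fast = ray.data.read_tfrecords(
        "s3://my-bucket/data.tfrecords",
        tfx_read_options=TFXReadOptions(batch_size=1024),
    )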