Commit

[data] correct api annotation for tfrecords_datasource
Signed-off-by: can <[email protected]>
can-anyscale committed Jun 21, 2024
1 parent 223233c commit 00cf708
Showing 6 changed files with 139 additions and 113 deletions.
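
This change moves the TFRecordDatasource implementation into ray.data._internal.datasource and keeps the user-facing TFXReadOptions dataclass in ray.data.datasource.tfrecords_datasource, which is now listed in the API reference. A minimal sketch of the resulting import surface, based only on the paths in the diff below (an assumption, not a guaranteed public contract):

from ray.data.datasource.tfrecords_datasource import TFXReadOptions  # documented, user-facing options
from ray.data._internal.datasource.tfrecords_datasource import (
    TFRecordDatasource,  # implementation detail; no longer exported from ray.data.datasource
)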
1 change: 1 addition & 0 deletions doc/source/data/api/api.rst
@@ -9,6 +9,7 @@ Ray Data API
input_output.rst
dataset.rst
data_iterator.rst
data_source.rst
execution_options.rst
grouped_data.rst
data_context.rst
12 changes: 12 additions & 0 deletions doc/source/data/api/data_source.rst
@@ -0,0 +1,12 @@
.. _data-source-api:

Datasource API
==============

.. currentmodule:: ray.data.datasource

.. autosummary::
:nosignatures:
:toctree: doc/

tfrecords_datasource.TFXReadOptions
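
The new reference page documents TFXReadOptions under ray.data.datasource. A minimal sketch of constructing it, using the two fields that appear elsewhere in this diff (batch_size and auto_infer_schema); the values shown are illustrative, not defaults:

from ray.data.datasource.tfrecords_datasource import TFXReadOptions

options = TFXReadOptions(
    batch_size=1024,         # records decoded per tfx-bsl batch (illustrative value)
    auto_infer_schema=True,  # infer the resulting schema from the decoded examples
)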
124 changes: 124 additions & 0 deletions python/ray/data/_internal/datasource/tfrecords_datasource.py
@@ -0,0 +1,124 @@
import logging
from typing import TYPE_CHECKING, Iterator, List, Optional, Union

import pyarrow

from ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict
from ray.data.block import Block
from ray.data.datasource.file_based_datasource import FileBasedDatasource
from ray.data.datasource.tfrecords_datasource import (
TFXReadOptions,
_cast_large_list_to_list,
_convert_example_to_dict,
_read_records,
)

if TYPE_CHECKING:
from tensorflow_metadata.proto.v0 import schema_pb2

logger = logging.getLogger(__name__)


class TFRecordDatasource(FileBasedDatasource):
"""TFRecord datasource, for reading and writing TFRecord files."""

_FILE_EXTENSIONS = ["tfrecords"]

def __init__(
self,
paths: Union[str, List[str]],
tf_schema: Optional["schema_pb2.Schema"] = None,
tfx_read_options: Optional[TFXReadOptions] = None,
**file_based_datasource_kwargs,
):
"""
Args:
tf_schema: Optional TensorFlow Schema which is used to explicitly set
the schema of the underlying Dataset.
tfx_read_options: Optional options for enabling reading tfrecords
using tfx-bsl.
"""
super().__init__(paths, **file_based_datasource_kwargs)

self._tf_schema = tf_schema
self._tfx_read_options = tfx_read_options

def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
if self._tfx_read_options:
yield from self._tfx_read_stream(f, path)
else:
yield from self._default_read_stream(f, path)

def _default_read_stream(
self, f: "pyarrow.NativeFile", path: str
) -> Iterator[Block]:
import tensorflow as tf
from google.protobuf.message import DecodeError

for record in _read_records(f, path):
example = tf.train.Example()
try:
example.ParseFromString(record)
except DecodeError as e:
raise ValueError(
"`TFRecordDatasource` failed to parse `tf.train.Example` "
f"record in '{path}'. This error can occur if your TFRecord "
f"file contains a message type other than `tf.train.Example`: {e}"
)

yield pyarrow_table_from_pydict(
_convert_example_to_dict(example, self._tf_schema)
)

def _tfx_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
import tensorflow as tf
from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder

full_path = self._resolve_full_path(path)

compression = (self._open_stream_args or {}).get("compression", None)

if compression:
compression = compression.upper()

tf_schema_string = (
self._tf_schema.SerializeToString() if self._tf_schema else None
)

decoder = ExamplesToRecordBatchDecoder(tf_schema_string)
exception_thrown = None
try:
for record in tf.data.TFRecordDataset(
full_path, compression_type=compression
).batch(self._tfx_read_options.batch_size):
yield _cast_large_list_to_list(
pyarrow.Table.from_batches([decoder.DecodeBatch(record.numpy())])
)
except Exception as error:
logger.exception(f"Failed to read TFRecord file {full_path}")
exception_thrown = error

# We need this hack, where the exception is raised outside of the
# except block, because TensorFlow's DataLossError is unpicklable, and
# even if we raise a RuntimeError, Ray keeps information about the
# original error, which still makes it unpicklable.
if exception_thrown:
raise RuntimeError(f"Failed to read TFRecord file {full_path}.")

def _resolve_full_path(self, relative_path):
if isinstance(self._filesystem, pyarrow.fs.S3FileSystem):
return f"s3://{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.GcsFileSystem):
return f"gs://{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.HadoopFileSystem):
return f"hdfs:///{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.PyFileSystem):
protocol = self._filesystem.handler.fs.protocol
if isinstance(protocol, list) or isinstance(protocol, tuple):
protocol = protocol[0]
if protocol == "gcs":
protocol = "gs"
return f"{protocol}://{relative_path}"

return relative_path
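
The datasource dispatches between the default protobuf-parsing path and the tfx-bsl path based on whether tfx_read_options is set. A hedged sketch of constructing it directly; the class is internal after this commit, so this is for illustration only and the input path is hypothetical:

from ray.data._internal.datasource.tfrecords_datasource import TFRecordDatasource
from ray.data.datasource.tfrecords_datasource import TFXReadOptions

datasource = TFRecordDatasource(
    paths="s3://my-bucket/tfrecords/",                 # hypothetical input location
    tfx_read_options=TFXReadOptions(batch_size=1024),  # opt into the tfx-bsl read path
)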
2 changes: 0 additions & 2 deletions python/ray/data/datasource/__init__.py
@@ -49,7 +49,6 @@
from ray.data.datasource.sql_datasource import Connection, SQLDatasource
from ray.data.datasource.text_datasource import TextDatasource
from ray.data.datasource.tfrecords_datasink import _TFRecordDatasink
from ray.data.datasource.tfrecords_datasource import TFRecordDatasource
from ray.data.datasource.torch_datasource import TorchDatasource
from ray.data.datasource.webdataset_datasink import _WebDatasetDatasink
from ray.data.datasource.webdataset_datasource import WebDatasetDatasource
@@ -100,7 +99,6 @@
"RowBasedFileDatasink",
"TextDatasource",
"_TFRecordDatasink",
"TFRecordDatasource",
"TorchDatasource",
"_WebDatasetDatasink",
"WebDatasetDatasource",
111 changes: 1 addition & 110 deletions python/ray/data/datasource/tfrecords_datasource.py
@@ -1,15 +1,12 @@
import logging
import struct
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union

import numpy as np
import pyarrow

from ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict
from ray.data.aggregate import AggregateFn
from ray.data.block import Block
from ray.data.datasource.file_based_datasource import FileBasedDatasource
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
@@ -45,112 +42,6 @@ class TFXReadOptions:
auto_infer_schema: bool = True


@PublicAPI(stability="alpha")
class TFRecordDatasource(FileBasedDatasource):
"""TFRecord datasource, for reading and writing TFRecord files."""

_FILE_EXTENSIONS = ["tfrecords"]

def __init__(
self,
paths: Union[str, List[str]],
tf_schema: Optional["schema_pb2.Schema"] = None,
tfx_read_options: Optional[TFXReadOptions] = None,
**file_based_datasource_kwargs,
):
"""
Args:
tf_schema: Optional TensorFlow Schema which is used to explicitly set
the schema of the underlying Dataset.
tfx_read_options: Optional options for enabling reading tfrecords
using tfx-bsl.
"""
super().__init__(paths, **file_based_datasource_kwargs)

self._tf_schema = tf_schema
self._tfx_read_options = tfx_read_options

def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
if self._tfx_read_options:
yield from self._tfx_read_stream(f, path)
else:
yield from self._default_read_stream(f, path)

def _default_read_stream(
self, f: "pyarrow.NativeFile", path: str
) -> Iterator[Block]:
import tensorflow as tf
from google.protobuf.message import DecodeError

for record in _read_records(f, path):
example = tf.train.Example()
try:
example.ParseFromString(record)
except DecodeError as e:
raise ValueError(
"`TFRecordDatasource` failed to parse `tf.train.Example` "
f"record in '{path}'. This error can occur if your TFRecord "
f"file contains a message type other than `tf.train.Example`: {e}"
)

yield pyarrow_table_from_pydict(
_convert_example_to_dict(example, self._tf_schema)
)

def _tfx_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
import tensorflow as tf
from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder

full_path = self._resolve_full_path(path)

compression = (self._open_stream_args or {}).get("compression", None)

if compression:
compression = compression.upper()

tf_schema_string = (
self._tf_schema.SerializeToString() if self._tf_schema else None
)

decoder = ExamplesToRecordBatchDecoder(tf_schema_string)
exception_thrown = None
try:
for record in tf.data.TFRecordDataset(
full_path, compression_type=compression
).batch(self._tfx_read_options.batch_size):
yield _cast_large_list_to_list(
pyarrow.Table.from_batches([decoder.DecodeBatch(record.numpy())])
)
except Exception as error:
logger.exception(f"Failed to read TFRecord file {full_path}")
exception_thrown = error

# We need this hack, where the exception is raised outside of the
# except block, because TensorFlow's DataLossError is unpicklable, and
# even if we raise a RuntimeError, Ray keeps information about the
# original error, which still makes it unpicklable.
if exception_thrown:
raise RuntimeError(f"Failed to read TFRecord file {full_path}.")

def _resolve_full_path(self, relative_path):
if isinstance(self._filesystem, pyarrow.fs.S3FileSystem):
return f"s3://{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.GcsFileSystem):
return f"gs://{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.HadoopFileSystem):
return f"hdfs:///{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.PyFileSystem):
protocol = self._filesystem.handler.fs.protocol
if isinstance(protocol, list) or isinstance(protocol, tuple):
protocol = protocol[0]
if protocol == "gcs":
protocol = "gs"
return f"{protocol}://{relative_path}"

return relative_path


def _convert_example_to_dict(
example: "tf.train.Example",
tf_schema: Optional["schema_pb2.Schema"],
2 changes: 1 addition & 1 deletion python/ray/data/read_api.py
@@ -22,6 +22,7 @@
from ray._private.auto_init_hook import wrap_auto_init
from ray.air.util.tensor_extensions.utils import _create_possibly_ragged_ndarray
from ray.data._internal.datasource.avro_datasource import AvroDatasource
from ray.data._internal.datasource.tfrecords_datasource import TFRecordDatasource
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.logical.operators.from_operators import (
FromArrow,
@@ -65,7 +66,6 @@
RangeDatasource,
SQLDatasource,
TextDatasource,
TFRecordDatasource,
TorchDatasource,
WebDatasetDatasource,
)
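
For end users the public entry point is unchanged: ray.data.read_tfrecords constructs the internal datasource on their behalf. A hedged usage sketch; the tfx_read_options parameter name is assumed from this datasource's constructor rather than shown in this commit:

import ray
from ray.data.datasource.tfrecords_datasource import TFXReadOptions

ds = ray.data.read_tfrecords(
    "s3://my-bucket/tfrecords/",                        # hypothetical dataset location
    tfx_read_options=TFXReadOptions(batch_size=1024),   # assumed parameter; enables tfx-bsl reads
)
ds.show(3)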
