[data] correct api annotation for tfrecords_datasource #46171

Merged
merged 2 commits on Jun 28, 2024
1 change: 1 addition & 0 deletions doc/source/data/api/api.rst
@@ -9,6 +9,7 @@ Ray Data API
    input_output.rst
    dataset.rst
    data_iterator.rst
    data_source.rst
    execution_options.rst
    grouped_data.rst
    data_context.rst
12 changes: 12 additions & 0 deletions doc/source/data/api/data_source.rst
@@ -0,0 +1,12 @@
.. _data-source-api:

Global configuration
====================

[Review thread on this heading]

Member: To minimize clutter in the sidebar and improve the discoverability of TFXReadOptions, could we co-locate this reference with read_tfrecords in the Input/Output page? I don't think we should place TFXReadOptions on a page by itself if it's only used with read_tfrecords.

Collaborator (author): Totally.

.. currentmodule:: ray.data.datasource

.. autosummary::
   :nosignatures:
   :toctree: doc/

   tfrecords_datasource.TFXReadOptions
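
A minimal usage sketch for the option documented above, assuming that ray.data.read_tfrecords accepts a tfx_read_options argument (as the datasource plumbing in this PR suggests); the path and batch size are illustrative:

    import ray
    from ray.data.datasource.tfrecords_datasource import TFXReadOptions

    # Read TFRecords via the tfx-bsl path, decoding 1024 records per batch.
    ds = ray.data.read_tfrecords(
        "s3://my-bucket/data.tfrecords",  # hypothetical path
        tfx_read_options=TFXReadOptions(batch_size=1024),
    )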
124 changes: 124 additions & 0 deletions python/ray/data/_internal/datasource/tfrecords_datasource.py
@@ -0,0 +1,124 @@
import logging
from typing import TYPE_CHECKING, Iterator, List, Optional, Union

import pyarrow

from ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict
from ray.data.block import Block
from ray.data.datasource.file_based_datasource import FileBasedDatasource
from ray.data.datasource.tfrecords_datasource import (
    TFXReadOptions,
    _cast_large_list_to_list,
    _convert_example_to_dict,
    _read_records,
)

if TYPE_CHECKING:
    from tensorflow_metadata.proto.v0 import schema_pb2

logger = logging.getLogger(__name__)


class TFRecordDatasource(FileBasedDatasource):
    """TFRecord datasource, for reading and writing TFRecord files."""

    _FILE_EXTENSIONS = ["tfrecords"]

    def __init__(
        self,
        paths: Union[str, List[str]],
        tf_schema: Optional["schema_pb2.Schema"] = None,
        tfx_read_options: Optional[TFXReadOptions] = None,
        **file_based_datasource_kwargs,
    ):
        """
        Args:
            tf_schema: Optional TensorFlow Schema which is used to explicitly set
                the schema of the underlying Dataset.
            tfx_read_options: Optional options for enabling reading tfrecords
                using tfx-bsl.
        """
        super().__init__(paths, **file_based_datasource_kwargs)

        self._tf_schema = tf_schema
        self._tfx_read_options = tfx_read_options

    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        if self._tfx_read_options:
            yield from self._tfx_read_stream(f, path)
        else:
            yield from self._default_read_stream(f, path)

    def _default_read_stream(
        self, f: "pyarrow.NativeFile", path: str
    ) -> Iterator[Block]:
        import tensorflow as tf
        from google.protobuf.message import DecodeError

        for record in _read_records(f, path):
            example = tf.train.Example()
            try:
                example.ParseFromString(record)
            except DecodeError as e:
                raise ValueError(
                    "`TFRecordDatasource` failed to parse `tf.train.Example` "
                    f"record in '{path}'. This error can occur if your TFRecord "
                    f"file contains a message type other than `tf.train.Example`: {e}"
                )

            yield pyarrow_table_from_pydict(
                _convert_example_to_dict(example, self._tf_schema)
            )

    def _tfx_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        import tensorflow as tf
        from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder

        full_path = self._resolve_full_path(path)

        compression = (self._open_stream_args or {}).get("compression", None)

        if compression:
            compression = compression.upper()

        tf_schema_string = (
            self._tf_schema.SerializeToString() if self._tf_schema else None
        )

        decoder = ExamplesToRecordBatchDecoder(tf_schema_string)
        exception_thrown = None
        try:
            for record in tf.data.TFRecordDataset(
                full_path, compression_type=compression
            ).batch(self._tfx_read_options.batch_size):
                yield _cast_large_list_to_list(
                    pyarrow.Table.from_batches([decoder.DecodeBatch(record.numpy())])
                )
        except Exception as error:
            logger.exception(f"Failed to read TFRecord file {full_path}")
            exception_thrown = error

        # We raise the exception outside of the `except` block because
        # TensorFlow's DataLossError is unpicklable, and even a RuntimeError
        # raised inside the block would be chained to the original error,
        # which keeps the information Ray carries about it unpicklable still.
        if exception_thrown:
            raise RuntimeError(f"Failed to read TFRecord file {full_path}.")

    def _resolve_full_path(self, relative_path):
        if isinstance(self._filesystem, pyarrow.fs.S3FileSystem):
            return f"s3://{relative_path}"
        if isinstance(self._filesystem, pyarrow.fs.GcsFileSystem):
            return f"gs://{relative_path}"
        if isinstance(self._filesystem, pyarrow.fs.HadoopFileSystem):
            return f"hdfs:///{relative_path}"
        if isinstance(self._filesystem, pyarrow.fs.PyFileSystem):
            protocol = self._filesystem.handler.fs.protocol
            if isinstance(protocol, (list, tuple)):
                protocol = protocol[0]
            if protocol == "gcs":
                protocol = "gs"
            return f"{protocol}://{relative_path}"

        return relative_path
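
The comment above describes a real Python pitfall: re-raising inside an except block chains the new exception to the original via __context__, so a picklable wrapper still drags along the unpicklable error. A minimal standalone sketch (not Ray code) of why the raise happens outside the block:

    def raise_inside():
        try:
            raise ValueError("original, possibly unpicklable")
        except ValueError:
            raise RuntimeError("wrapped")  # __context__ points at the ValueError

    def raise_outside():
        caught = None
        try:
            raise ValueError("original, possibly unpicklable")
        except ValueError as e:
            caught = e
        if caught:
            raise RuntimeError("wrapped")  # raised after the handler: no __context__

    try:
        raise_inside()
    except RuntimeError as e:
        print(e.__context__)  # the ValueError is still attached
    try:
        raise_outside()
    except RuntimeError as e:
        print(e.__context__)  # None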
2 changes: 0 additions & 2 deletions python/ray/data/datasource/__init__.py
@@ -49,7 +49,6 @@
from ray.data.datasource.sql_datasource import Connection, SQLDatasource
from ray.data.datasource.text_datasource import TextDatasource
from ray.data.datasource.tfrecords_datasink import _TFRecordDatasink
from ray.data.datasource.tfrecords_datasource import TFRecordDatasource
from ray.data.datasource.torch_datasource import TorchDatasource
from ray.data.datasource.webdataset_datasink import _WebDatasetDatasink
from ray.data.datasource.webdataset_datasource import WebDatasetDatasource
@@ -100,7 +99,6 @@
    "RowBasedFileDatasink",
    "TextDatasource",
    "_TFRecordDatasink",
    "TFRecordDatasource",
    "TorchDatasource",
    "_WebDatasetDatasink",
    "WebDatasetDatasource",
111 changes: 1 addition & 110 deletions python/ray/data/datasource/tfrecords_datasource.py
@@ -1,15 +1,12 @@
import logging
import struct
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Union

import numpy as np
import pyarrow

from ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict
from ray.data.aggregate import AggregateFn
from ray.data.block import Block
from ray.data.datasource.file_based_datasource import FileBasedDatasource
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
@@ -45,112 +42,6 @@ class TFXReadOptions:
    auto_infer_schema: bool = True
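
For reference, a hedged reconstruction of the TFXReadOptions dataclass from what this diff shows (the auto_infer_schema field above and the batch_size attribute read in _tfx_read_stream); the batch_size default is an assumption, not confirmed by the diff:

    from dataclasses import dataclass

    @dataclass
    class TFXReadOptions:
        # Number of serialized tf.train.Example records decoded per batch
        # (assumed default).
        batch_size: int = 2048
        # Whether to infer the schema of the underlying dataset automatically.
        auto_infer_schema: bool = True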


@PublicAPI(stability="alpha")
class TFRecordDatasource(FileBasedDatasource):
    """TFRecord datasource, for reading and writing TFRecord files."""

    _FILE_EXTENSIONS = ["tfrecords"]

    def __init__(
        self,
        paths: Union[str, List[str]],
        tf_schema: Optional["schema_pb2.Schema"] = None,
        tfx_read_options: Optional[TFXReadOptions] = None,
        **file_based_datasource_kwargs,
    ):
        """
        Args:
            tf_schema: Optional TensorFlow Schema which is used to explicitly set
                the schema of the underlying Dataset.
            tfx_read_options: Optional options for enabling reading tfrecords
                using tfx-bsl.
        """
        super().__init__(paths, **file_based_datasource_kwargs)

        self._tf_schema = tf_schema
        self._tfx_read_options = tfx_read_options

    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        if self._tfx_read_options:
            yield from self._tfx_read_stream(f, path)
        else:
            yield from self._default_read_stream(f, path)

    def _default_read_stream(
        self, f: "pyarrow.NativeFile", path: str
    ) -> Iterator[Block]:
        import tensorflow as tf
        from google.protobuf.message import DecodeError

        for record in _read_records(f, path):
            example = tf.train.Example()
            try:
                example.ParseFromString(record)
            except DecodeError as e:
                raise ValueError(
                    "`TFRecordDatasource` failed to parse `tf.train.Example` "
                    f"record in '{path}'. This error can occur if your TFRecord "
                    f"file contains a message type other than `tf.train.Example`: {e}"
                )

            yield pyarrow_table_from_pydict(
                _convert_example_to_dict(example, self._tf_schema)
            )

    def _tfx_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        import tensorflow as tf
        from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder

        full_path = self._resolve_full_path(path)

        compression = (self._open_stream_args or {}).get("compression", None)

        if compression:
            compression = compression.upper()

        tf_schema_string = (
            self._tf_schema.SerializeToString() if self._tf_schema else None
        )

        decoder = ExamplesToRecordBatchDecoder(tf_schema_string)
        exception_thrown = None
        try:
            for record in tf.data.TFRecordDataset(
                full_path, compression_type=compression
            ).batch(self._tfx_read_options.batch_size):
                yield _cast_large_list_to_list(
                    pyarrow.Table.from_batches([decoder.DecodeBatch(record.numpy())])
                )
        except Exception as error:
            logger.exception(f"Failed to read TFRecord file {full_path}")
            exception_thrown = error

        # We raise the exception outside of the `except` block because
        # TensorFlow's DataLossError is unpicklable, and even a RuntimeError
        # raised inside the block would be chained to the original error,
        # which keeps the information Ray carries about it unpicklable still.
        if exception_thrown:
            raise RuntimeError(f"Failed to read TFRecord file {full_path}.")

    def _resolve_full_path(self, relative_path):
        if isinstance(self._filesystem, pyarrow.fs.S3FileSystem):
            return f"s3://{relative_path}"
        if isinstance(self._filesystem, pyarrow.fs.GcsFileSystem):
            return f"gs://{relative_path}"
        if isinstance(self._filesystem, pyarrow.fs.HadoopFileSystem):
            return f"hdfs:///{relative_path}"
        if isinstance(self._filesystem, pyarrow.fs.PyFileSystem):
            protocol = self._filesystem.handler.fs.protocol
            if isinstance(protocol, (list, tuple)):
                protocol = protocol[0]
            if protocol == "gcs":
                protocol = "gs"
            return f"{protocol}://{relative_path}"

        return relative_path


def _convert_example_to_dict(
    example: "tf.train.Example",
    tf_schema: Optional["schema_pb2.Schema"],
2 changes: 1 addition & 1 deletion python/ray/data/read_api.py
@@ -22,6 +22,7 @@
from ray._private.auto_init_hook import wrap_auto_init
from ray.air.util.tensor_extensions.utils import _create_possibly_ragged_ndarray
from ray.data._internal.datasource.avro_datasource import AvroDatasource
from ray.data._internal.datasource.tfrecords_datasource import TFRecordDatasource
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.logical.operators.from_operators import (
    FromArrow,
@@ -65,7 +66,6 @@
    RangeDatasource,
    SQLDatasource,
    TextDatasource,
    TFRecordDatasource,
    TorchDatasource,
    WebDatasetDatasource,
)
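
A hedged note on the import change above: with this PR, TFRecordDatasource moves out of the public ray.data.datasource package into an internal module, so user code should go through the public reader; direct class access, if ever needed, would use the internal path from the diff (not a stable API):

    import ray

    # Public entry point; the path is illustrative.
    ds = ray.data.read_tfrecords("example.tfrecords")

    # Internal location after this PR (subject to change without notice):
    from ray.data._internal.datasource.tfrecords_datasource import TFRecordDatasource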