Skip to content

Commit

Permalink
[AIR] Inline AIR level ray.data imports (#29517)
Browse files Browse the repository at this point in the history
This is a quick and relatively safe attempt to address #29324.

In #28418 we attempted to unify ray.air utils with shared utility functions, but this triggered expensive ray.data imports.

The longer-term, more robust solution is tracked in #27658.
  • Loading branch information
jiaodong authored Oct 22, 2022
1 parent a736831 commit 3f0c294
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 20 deletions.
12 changes: 6 additions & 6 deletions python/ray/air/util/data_batch_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,8 @@
from ray.air.data_batch_type import DataBatchType
from ray.air.constants import TENSOR_COLUMN_NAME
from ray.util.annotations import DeveloperAPI
from ray.air.util.tensor_extensions.arrow import ArrowTensorType

# TODO: Consolidate data conversion edges for arrow bug workaround.
from ray.air.util.transform_pyarrow import (
_is_column_extension_type,
_concatenate_extension_column,
)

try:
import pyarrow
except ImportError:
Expand Down Expand Up @@ -139,6 +133,12 @@ def _convert_batch_type_to_numpy(
)
return data
elif pyarrow is not None and isinstance(data, pyarrow.Table):
from ray.air.util.tensor_extensions.arrow import ArrowTensorType
from ray.air.util.transform_pyarrow import (
_is_column_extension_type,
_concatenate_extension_column,
)

if data.column_names == [TENSOR_COLUMN_NAME] and (
isinstance(data.schema.types[0], ArrowTensorType)
):
Expand Down
13 changes: 9 additions & 4 deletions python/ray/data/_internal/arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@

import numpy as np

from ray.air.util.transform_pyarrow import (
_concatenate_extension_column,
_is_column_extension_type,
)
from ray.data._internal.arrow_ops import transform_polars, transform_pyarrow
from ray.data._internal.table_block import (
VALUE_COL_NAME,
Expand Down Expand Up @@ -193,6 +189,11 @@ def to_pandas(self) -> "pandas.DataFrame":
def to_numpy(
self, columns: Optional[Union[str, List[str]]] = None
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
from ray.air.util.transform_pyarrow import (
_concatenate_extension_column,
_is_column_extension_type,
)

if columns is None:
columns = self._table.column_names
if not isinstance(columns, list):
Expand Down Expand Up @@ -597,6 +598,10 @@ def gen():
def _copy_table(table: "pyarrow.Table") -> "pyarrow.Table":
"""Copy the provided Arrow table."""
import pyarrow as pa
from ray.air.util.transform_pyarrow import (
_concatenate_extension_column,
_is_column_extension_type,
)

# Copy the table by copying each column and constructing a new table with
# the same schema.
Expand Down
10 changes: 5 additions & 5 deletions python/ray/data/_internal/arrow_ops/transform_pyarrow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
from typing import TYPE_CHECKING, List, Union

from ray.air.util.transform_pyarrow import (
_is_column_extension_type,
_concatenate_extension_column,
)

try:
import pyarrow
except ImportError:
Expand All @@ -31,6 +26,11 @@ def take_table(
extension arrays. This is exposed as a static method for easier use on
intermediate tables, not underlying an ArrowBlockAccessor.
"""
from ray.air.util.transform_pyarrow import (
_is_column_extension_type,
_concatenate_extension_column,
)

if any(_is_column_extension_type(col) for col in table.columns):
new_cols = []
for col in table.columns:
Expand Down
9 changes: 4 additions & 5 deletions python/ray/data/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
from enum import Enum
from typing import TYPE_CHECKING, Optional, Union, Dict

import numpy as np

from ray.data import Dataset
from ray.util.annotations import DeveloperAPI, PublicAPI

if TYPE_CHECKING:
import pandas as pd

import numpy as np
from ray.air.data_batch_type import DataBatchType


Expand Down Expand Up @@ -262,6 +260,7 @@ def _transform(self, dataset: Dataset) -> Dataset:
def _transform_batch(self, data: "DataBatchType") -> "DataBatchType":
# For minimal install to locally import air modules
import pandas as pd
import numpy as np
from ray.air.util.data_batch_conversion import (
convert_batch_type_to_pandas,
_convert_batch_type_to_numpy,
Expand Down Expand Up @@ -299,7 +298,7 @@ def _transform_pandas(self, df: "pd.DataFrame") -> "pd.DataFrame":

@DeveloperAPI
def _transform_numpy(
self, np_data: Union[np.ndarray, Dict[str, np.ndarray]]
) -> Union[np.ndarray, Dict[str, np.ndarray]]:
self, np_data: Union["np.ndarray", Dict[str, "np.ndarray"]]
) -> Union["np.ndarray", Dict[str, "np.ndarray"]]:
"""Run the transformation on a data batch in a NumPy ndarray format."""
raise NotImplementedError()

0 comments on commit 3f0c294

Please sign in to comment.