diff --git a/python/ray/air/util/data_batch_conversion.py b/python/ray/air/util/data_batch_conversion.py
index c640faab69e3..3af016f77cc7 100644
--- a/python/ray/air/util/data_batch_conversion.py
+++ b/python/ray/air/util/data_batch_conversion.py
@@ -7,14 +7,8 @@
 from ray.air.data_batch_type import DataBatchType
 from ray.air.constants import TENSOR_COLUMN_NAME
 from ray.util.annotations import DeveloperAPI
-from ray.air.util.tensor_extensions.arrow import ArrowTensorType
 
 # TODO: Consolidate data conversion edges for arrow bug workaround.
-from ray.air.util.transform_pyarrow import (
-    _is_column_extension_type,
-    _concatenate_extension_column,
-)
-
 try:
     import pyarrow
 except ImportError:
@@ -139,6 +133,12 @@ def _convert_batch_type_to_numpy(
         )
         return data
     elif pyarrow is not None and isinstance(data, pyarrow.Table):
+        from ray.air.util.tensor_extensions.arrow import ArrowTensorType
+        from ray.air.util.transform_pyarrow import (
+            _is_column_extension_type,
+            _concatenate_extension_column,
+        )
+
         if data.column_names == [TENSOR_COLUMN_NAME] and (
             isinstance(data.schema.types[0], ArrowTensorType)
         ):
diff --git a/python/ray/data/_internal/arrow_block.py b/python/ray/data/_internal/arrow_block.py
index 7b467477ce98..f1b656459661 100644
--- a/python/ray/data/_internal/arrow_block.py
+++ b/python/ray/data/_internal/arrow_block.py
@@ -16,10 +16,6 @@
 
 import numpy as np
 
-from ray.air.util.transform_pyarrow import (
-    _concatenate_extension_column,
-    _is_column_extension_type,
-)
 from ray.data._internal.arrow_ops import transform_polars, transform_pyarrow
 from ray.data._internal.table_block import (
     VALUE_COL_NAME,
@@ -193,6 +189,11 @@ def to_pandas(self) -> "pandas.DataFrame":
     def to_numpy(
         self, columns: Optional[Union[str, List[str]]] = None
     ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+        from ray.air.util.transform_pyarrow import (
+            _concatenate_extension_column,
+            _is_column_extension_type,
+        )
+
         if columns is None:
             columns = self._table.column_names
         if not isinstance(columns, list):
@@ -597,6 +598,10 @@ def gen():
 def _copy_table(table: "pyarrow.Table") -> "pyarrow.Table":
     """Copy the provided Arrow table."""
     import pyarrow as pa
+    from ray.air.util.transform_pyarrow import (
+        _concatenate_extension_column,
+        _is_column_extension_type,
+    )
 
     # Copy the table by copying each column and constructing a new table with
     # the same schema.
diff --git a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py
index 0262b89984c0..1e93298369cd 100644
--- a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py
+++ b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py
@@ -1,10 +1,5 @@
 from typing import TYPE_CHECKING, List, Union
 
-from ray.air.util.transform_pyarrow import (
-    _is_column_extension_type,
-    _concatenate_extension_column,
-)
-
 try:
     import pyarrow
 except ImportError:
@@ -31,6 +26,11 @@ def take_table(
     extension arrays. This is exposed as a static method for easier use on
     intermediate tables, not underlying an ArrowBlockAccessor.
     """
+    from ray.air.util.transform_pyarrow import (
+        _is_column_extension_type,
+        _concatenate_extension_column,
+    )
+
     if any(_is_column_extension_type(col) for col in table.columns):
         new_cols = []
         for col in table.columns:
diff --git a/python/ray/data/preprocessor.py b/python/ray/data/preprocessor.py
index cc5f59bc2c20..1ece54642276 100644
--- a/python/ray/data/preprocessor.py
+++ b/python/ray/data/preprocessor.py
@@ -3,14 +3,12 @@
 from enum import Enum
 from typing import TYPE_CHECKING, Optional, Union, Dict
 
-import numpy as np
-
 from ray.data import Dataset
 from ray.util.annotations import DeveloperAPI, PublicAPI
 
 if TYPE_CHECKING:
     import pandas as pd
-
+    import numpy as np
     from ray.air.data_batch_type import DataBatchType
 
 
@@ -262,6 +260,7 @@ def _transform(self, dataset: Dataset) -> Dataset:
     def _transform_batch(self, data: "DataBatchType") -> "DataBatchType":
         # For minimal install to locally import air modules
         import pandas as pd
+        import numpy as np
         from ray.air.util.data_batch_conversion import (
             convert_batch_type_to_pandas,
             _convert_batch_type_to_numpy,
@@ -299,7 +298,7 @@ def _transform_pandas(self, df: "pd.DataFrame") -> "pd.DataFrame":
 
     @DeveloperAPI
     def _transform_numpy(
-        self, np_data: Union[np.ndarray, Dict[str, np.ndarray]]
-    ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+        self, np_data: Union["np.ndarray", Dict[str, "np.ndarray"]]
+    ) -> Union["np.ndarray", Dict[str, "np.ndarray"]]:
         """Run the transformation on a data batch in a NumPy ndarray format."""
         raise NotImplementedError()