Fix shuffling and sorting of tensor columns, fix random access dataset.
clarkzinzow committed May 19, 2022
1 parent 4bf27a0 commit f3a2ac4
Showing 5 changed files with 164 additions and 25 deletions.
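Context for the fix: pyarrow.Table.take() concatenates a column's chunks internally, and that concatenation currently fails for columns backed by Arrow extension types, the representation Ray Datasets uses for tensor columns (see ARROW-16503, linked in the diff below). The following is a minimal, hedged sketch of the failure mode and of the workaround this commit adopts; Int64WrapperType is a toy extension type invented for illustration, not Ray's actual ArrowTensorType.

```python
import pyarrow as pa


# Toy extension type, for illustration only (not Ray's ArrowTensorType).
class Int64WrapperType(pa.PyExtensionType):
    def __init__(self):
        super().__init__(pa.int64())

    def __reduce__(self):
        return Int64WrapperType, ()


storage = pa.array([1, 2, 3], pa.int64())
chunked = pa.chunked_array(
    [pa.ExtensionArray.from_storage(Int64WrapperType(), storage)] * 2
)
table = pa.Table.from_arrays([chunked], names=["col"])

# On pyarrow versions affected by ARROW-16503, this raises because take()
# concatenates the column's chunks internally:
#   table.take([0, 3])

# The workaround adopted below: manually concatenate the storage arrays,
# rebuild a contiguous extension array, then call take() on that.
chunk = chunked.chunk(0)
contiguous = type(chunk).from_storage(
    chunk.type, pa.concat_arrays([c.storage for c in chunked.chunks])
)
print(contiguous.take(pa.array([0, 3])))
```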
100 changes: 79 additions & 21 deletions python/ray/data/impl/arrow_block.py
@@ -121,6 +121,11 @@ def from_numpy(cls, data: Union[np.ndarray, List[np.ndarray]]):
)
return cls(table)

@staticmethod
def _build_tensor_row(row: ArrowRow) -> np.ndarray:
# Getting an item in a tensor column automatically does a NumPy conversion.
return row[VALUE_COL_NAME][0]

def slice(self, start: int, end: int, copy: bool) -> "pyarrow.Table":
view = self._table.slice(start, end - start)
if copy:
@@ -129,7 +134,7 @@ def slice(self, start: int, end: int, copy: bool) -> "pyarrow.Table":

def random_shuffle(self, random_seed: Optional[int]) -> "pyarrow.Table":
random = np.random.RandomState(random_seed)
-        return self._table.take(random.permutation(self.num_rows()))
+        return self.take(random.permutation(self.num_rows()))

def schema(self) -> "pyarrow.lib.Schema":
return self._table.schema
@@ -155,16 +160,10 @@ def to_numpy(
array = self._table[column]
if array.num_chunks == 0:
array = pyarrow.array([], type=array.type)
-            elif array.num_chunks == 1:
-                array = array.chunk(0)
-            elif isinstance(array.chunk(0), pyarrow.ExtensionArray):
-                # If an extension array, we manually concatenate the underlying storage
-                # arrays.
-                chunk = array.chunk(0)
-                array = type(chunk).from_storage(
-                    chunk.type,
-                    pyarrow.concat_arrays([chunk.storage for chunk in array.chunks]),
-                )
+            elif _is_column_extension_type(array):
+                array = _concatenate_extension_column(array)
else:
array = array.combine_chunks()
arrays.append(array.to_numpy(zero_copy_only=False))
if len(arrays) == 1:
arrays = arrays[0]
@@ -207,9 +206,45 @@ def builder() -> ArrowBlockBuilder[T]:
def _empty_table() -> "pyarrow.Table":
return ArrowBlockBuilder._empty_table()

@staticmethod
def take_table(
table: "pyarrow.Table",
indices: Union[List[int], "pyarrow.Array", "pyarrow.ChunkedArray"],
) -> "pyarrow.Table":
"""Select rows from the table.
This method is an alternative to pyarrow.Table.take(), which breaks for
extension arrays. This is exposed as a static method for easier use on
intermediate tables, not underlying an ArrowBlockAccessor.
"""
if any(_is_column_extension_type(col) for col in table.columns):
new_cols = []
for col in table.columns:
if _is_column_extension_type(col):
# .take() will concatenate internally, which currently breaks for
# extension arrays.
col = _concatenate_extension_column(col)
new_cols.append(col.take(indices))
table = pyarrow.Table.from_arrays(new_cols, schema=table.schema)
else:
table = table.take(indices)
return table

def take(
self,
indices: Union[List[int], "pyarrow.Array", "pyarrow.ChunkedArray"],
) -> "pyarrow.Table":
"""Select rows from the underlying table.
This method is an alternative to pyarrow.Table.take(), which breaks for
extension arrays.
"""
return self.take_table(self._table, indices)

def _sample(self, n_samples: int, key: "SortKeyT") -> "pyarrow.Table":
indices = random.sample(range(self._table.num_rows), n_samples)
-        return self._table.select([k[0] for k in key]).take(indices)
+        table = self._table.select([k[0] for k in key])
+        return self.take_table(table, indices)

def count(self, on: KeyFn) -> Optional[U]:
"""Count the number of non-null values in the provided column."""
@@ -306,7 +341,7 @@ def sort_and_partition(
import pyarrow.compute as pac

indices = pac.sort_indices(self._table, sort_keys=key)
-        table = self._table.take(indices)
+        table = self.take(indices)
if len(boundaries) == 0:
return [table]

@@ -431,7 +466,7 @@ def merge_sorted_blocks(
else:
ret = pyarrow.concat_tables(blocks, promote=True)
indices = pyarrow.compute.sort_indices(ret, sort_keys=key)
-            ret = ret.take(indices)
+            ret = ArrowBlockAccessor.take_table(ret, indices)
return ret, ArrowBlockAccessor(ret).get_metadata(None, exec_stats=stats.build())

@staticmethod
@@ -527,6 +562,33 @@ def gen():
return ret, ArrowBlockAccessor(ret).get_metadata(None, exec_stats=stats.build())


def _is_column_extension_type(ca: "pyarrow.ChunkedArray") -> bool:
"""Whether the provided Arrow Table column is an extension array, using an Arrow
extension type.
"""
return isinstance(ca.type, pyarrow.ExtensionType)


def _concatenate_extension_column(ca: "pyarrow.ChunkedArray") -> "pyarrow.Array":
"""Concatenate chunks of an extension column into a contiguous array.
This concatenation is required for creating copies and for .take() to work on
extension arrays.
See https://issues.apache.org/jira/browse/ARROW-16503.
"""
if not _is_column_extension_type(ca):
        raise ValueError(f"Chunked array isn't an extension array: {ca}")

if ca.num_chunks == 0:
# No-op for no-chunk chunked arrays, since there's nothing to concatenate.
return ca

chunk = ca.chunk(0)
return type(chunk).from_storage(
chunk.type, pyarrow.concat_arrays([c.storage for c in ca.chunks])
)


def _copy_table(table: "pyarrow.Table") -> "pyarrow.Table":
"""Copy the provided Arrow table."""
import pyarrow as pa
@@ -536,14 +598,10 @@ def _copy_table(table: "pyarrow.Table") -> "pyarrow.Table":
cols = table.columns
new_cols = []
for col in cols:
-        if col.num_chunks > 0 and isinstance(col.chunk(0), pa.ExtensionArray):
-            # If an extension array, we copy the underlying storage arrays.
-            chunk = col.chunk(0)
-            arr = type(chunk).from_storage(
-                chunk.type, pa.concat_arrays([c.storage for c in col.chunks])
-            )
+        if _is_column_extension_type(col):
+            # Extension arrays don't support concatenation.
+            arr = _concatenate_extension_column(col)
else:
# Otherwise, we copy the top-level chunk arrays.
arr = col.combine_chunks()
new_cols.append(arr)
return pa.Table.from_arrays(new_cols, schema=table.schema)
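Taken together, the arrow_block.py changes route every row selection through the extension-aware path above. A hedged, user-level sketch of what this enables, mirroring the new tests further down (assumes a local ray.init(); shuffle output is nondeterministic):

```python
import ray

ray.init()

# Tensor datasets are backed by Arrow extension columns, so shuffling and
# sorting previously failed inside pyarrow.Table.take().
ds = ray.data.range_tensor(6, shape=(3, 5))
print(ds.random_shuffle().take(2))

# The same path also covers the Pandas block representation.
pandas_ds = ds.map_batches(lambda df: df, batch_format="pandas")
print(pandas_ds.random_shuffle().take(2))
```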
6 changes: 6 additions & 0 deletions python/ray/data/impl/pandas_block.py
@@ -108,6 +108,12 @@ def __init__(self, table: "pandas.DataFrame"):
def column_names(self) -> List[str]:
return self._table.columns.tolist()

@staticmethod
def _build_tensor_row(row: PandasRow) -> np.ndarray:
# Getting an item in a Pandas tensor column returns a TensorArrayElement, which
# we have to convert to an ndarray.
return row[VALUE_COL_NAME].iloc[0].to_numpy()

def slice(self, start: int, end: int, copy: bool) -> "pandas.DataFrame":
view = self._table[start:end]
if copy:
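Background on the Pandas-specific conversion above: indexing into a TensorArray-backed column yields a TensorArrayElement wrapper rather than an ndarray, hence the explicit to_numpy() call. A small sketch, assuming the public ray.data.extensions.TensorArray and an illustrative column name "a":

```python
import numpy as np
import pandas as pd

from ray.data.extensions import TensorArray

df = pd.DataFrame({"a": TensorArray(np.arange(8).reshape((2, 2, 2)))})
elem = df["a"].iloc[0]  # A TensorArrayElement wrapper, not an ndarray.
arr = elem.to_numpy()   # The conversion _build_tensor_row() performs.
assert isinstance(arr, np.ndarray)
```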
6 changes: 5 additions & 1 deletion python/ray/data/impl/table_block.py
@@ -117,11 +117,15 @@ def __init__(self, table: Any):
def _get_row(self, index: int, copy: bool = False) -> Union[TableRow, np.ndarray]:
row = self.slice(index, index + 1, copy=copy)
if self.is_tensor_wrapper():
-            row = row[VALUE_COL_NAME][0]
+            row = self._build_tensor_row(row)
else:
row = self.ROW_TYPE(row)
return row

@staticmethod
def _build_tensor_row(row: TableRow) -> np.ndarray:
raise NotImplementedError

def to_native(self) -> Block:
if self.is_tensor_wrapper():
native = self.to_numpy()
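The base-class hook above follows a template-method pattern: `_get_row()` stays format-agnostic and defers tensor-row materialization to each accessor subclass. A schematic sketch of that contract, using toy classes and an illustrative "value" column name rather than the real accessors:

```python
import numpy as np


class ToyBlockAccessor:
    @staticmethod
    def _build_tensor_row(row) -> np.ndarray:
        raise NotImplementedError


class ToyArrowAccessor(ToyBlockAccessor):
    @staticmethod
    def _build_tensor_row(row) -> np.ndarray:
        # Arrow item access on a tensor column already yields an ndarray.
        return row["value"][0]


class ToyPandasAccessor(ToyBlockAccessor):
    @staticmethod
    def _build_tensor_row(row) -> np.ndarray:
        # Pandas yields a TensorArrayElement that needs explicit conversion.
        return row["value"].iloc[0].to_numpy()
```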
4 changes: 1 addition & 3 deletions python/ray/data/random_access_dataset.py
@@ -223,9 +223,7 @@ def multiget(self, block_indices, keys):
col = block[self.key_field]
indices = np.searchsorted(col, keys)
acc = BlockAccessor.for_block(block)
-            result = [
-                acc._create_table_row(acc.slice(i, i + 1, copy=True)) for i in indices
-            ]
+            result = [acc._get_row(i, copy=True) for i in indices]
# assert result == [self._get(i, k) for i, k in zip(block_indices, keys)]
else:
result = [self._get(i, k) for i, k in zip(block_indices, keys)]
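The multiget() change routes batched lookups through `_get_row()`, so tensor rows come back as ndarrays rather than raw one-row table slices. A hedged usage sketch of the path this exercises (API names as of this Ray version; assumes a local ray.init()):

```python
import ray

ray.init()

# range_table() yields a single sorted "value" column, which random access
# requires as the key.
ds = ray.data.range_table(50)
rad = ds.to_random_access_dataset(key="value")
# multiget() resolves several keys per block in one round trip.
print(rad.multiget([3, 7, 11]))
```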
73 changes: 73 additions & 0 deletions python/ray/data/tests/test_dataset.py
@@ -479,6 +479,79 @@ def pd_mapper(df):
np.testing.assert_equal(res, [np.array([2]), np.array([3])])


def test_tensors_shuffle(ray_start_regular_shared):
# Test Arrow table representation.
tensor_shape = (3, 5)
ds = ray.data.range_tensor(6, shape=tensor_shape)
shuffled_ds = ds.random_shuffle()
shuffled = shuffled_ds.take()
base = ds.take()
np.testing.assert_raises(
AssertionError,
np.testing.assert_equal,
shuffled,
base,
)
np.testing.assert_equal(
sorted(shuffled, key=lambda arr: arr.min()),
sorted(base, key=lambda arr: arr.min()),
)

# Test Pandas table representation.
tensor_shape = (3, 5)
ds = ray.data.range_tensor(6, shape=tensor_shape)
ds = ds.map_batches(lambda df: df, batch_format="pandas")
shuffled_ds = ds.random_shuffle()
shuffled = shuffled_ds.take()
base = ds.take()
np.testing.assert_raises(
AssertionError,
np.testing.assert_equal,
shuffled,
base,
)
np.testing.assert_equal(
sorted(shuffled, key=lambda arr: arr.min()),
sorted(base, key=lambda arr: arr.min()),
)


def test_tensors_sort(ray_start_regular_shared):
# Test Arrow table representation.
t = pa.table({"a": TensorArray(np.arange(32).reshape((2, 4, 4))), "b": [1, 2]})
ds = ray.data.from_arrow(t)
sorted_ds = ds.sort(key="b", descending=True)
sorted_arrs = [row["a"] for row in sorted_ds.take()]
base = [row["a"] for row in ds.take()]
np.testing.assert_raises(
AssertionError,
np.testing.assert_equal,
sorted_arrs,
base,
)
np.testing.assert_equal(
sorted_arrs,
sorted(base, key=lambda arr: -arr.min()),
)

# Test Pandas table representation.
df = pd.DataFrame({"a": TensorArray(np.arange(32).reshape((2, 4, 4))), "b": [1, 2]})
ds = ray.data.from_pandas(df)
sorted_ds = ds.sort(key="b", descending=True)
sorted_arrs = [np.asarray(row["a"]) for row in sorted_ds.take()]
base = [np.asarray(row["a"]) for row in ds.take()]
np.testing.assert_raises(
AssertionError,
np.testing.assert_equal,
sorted_arrs,
base,
)
np.testing.assert_equal(
sorted_arrs,
sorted(base, key=lambda arr: -arr.min()),
)


def test_tensors_inferred_from_map(ray_start_regular_shared):
# Test map.
ds = ray.data.range(10).map(lambda _: np.ones((4, 4)))
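A note on the test idiom above: `np.testing.assert_raises(AssertionError, np.testing.assert_equal, ...)` serves as a "not equal" check (the shuffled or sorted output must differ from the original order), while the second assertion checks that both sides hold the same multiset of arrays. A minimal standalone illustration:

```python
import numpy as np

a = [np.array([1]), np.array([2])]
b = [np.array([2]), np.array([1])]

# Passes: b is not element-wise equal to a, since the order changed...
np.testing.assert_raises(AssertionError, np.testing.assert_equal, a, b)
# ...while both contain the same arrays once canonically sorted.
np.testing.assert_equal(
    sorted(a, key=lambda arr: arr.min()),
    sorted(b, key=lambda arr: arr.min()),
)
```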
