Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: 2772 pyarrow doesn't permit selective reading with ExtensionArray #3127

Merged
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
04968a6
POC for 2772, outbound table conversion
tcawlfield May 8, 2024
97b3d13
Changing table-wide metadata
tcawlfield May 9, 2024
ca9ee04
Adding convert_native_arrow_table_to_awkward
tcawlfield May 10, 2024
8993d7f
Various fixes to table conversions
tcawlfield May 16, 2024
b5fc20e
Ruff formatting
tcawlfield May 16, 2024
0218837
Improvements to pyarrow_table_conv but still issues
tcawlfield May 19, 2024
35e7579
Fixing bug in array_with_replacement_type
tcawlfield May 20, 2024
cab2128
Adding unit testing, fixing a couple bugs
tcawlfield May 21, 2024
91a70f2
Adding hooks to parquet read & write
tcawlfield May 22, 2024
70506c0
Ruff-fmt fixes
tcawlfield May 22, 2024
094c70d
Making progress
tcawlfield May 22, 2024
43d8af5
Fixing another bug: convert each row group when writing
tcawlfield May 23, 2024
6269993
pyarrow_table_conv -- change our new table metadata key name
tcawlfield May 23, 2024
9f4a74a
Some stylistic improvements
tcawlfield May 23, 2024
8266b34
Commented-out a messy assertion in test_2772
tcawlfield May 23, 2024
f2cdcc6
style: pre-commit fixes
pre-commit-ci[bot] May 23, 2024
17a050b
Moving awkward._connect.pyarrow into a package
tcawlfield May 23, 2024
c2e1275
Restructuring ._connect.pyarrow package
tcawlfield May 23, 2024
bd9b16e
Fixing unused imports and other Ruffage
tcawlfield May 23, 2024
7047979
Fixing Ruffage, this time for sure
tcawlfield May 23, 2024
683e025
Fixes for old versions of pyarrow
tcawlfield May 24, 2024
a0a6ae4
Small fixes
tcawlfield May 24, 2024
4f474fc
Adding BSD licenses, moving a commented test
tcawlfield May 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions src/awkward/_connect/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,27 @@ def __arrow_ext_class__(self):
return AwkwardArrowArray

def __arrow_ext_serialize__(self):
return json.dumps(
{
"mask_type": self._mask_type,
"node_type": self._node_type,
"mask_parameters": self._mask_parameters,
"node_parameters": self._node_parameters,
"record_is_tuple": self._record_is_tuple,
"record_is_scalar": self._record_is_scalar,
"is_nonnullable_nulltype": self._is_nonnullable_nulltype,
}
).encode(errors="surrogatescape")
return json.dumps(self._metadata_as_dict()).encode(errors="surrogatescape")

def _metadata_as_dict(self):
return {
"mask_type": self._mask_type,
"node_type": self._node_type,
"mask_parameters": self._mask_parameters,
"node_parameters": self._node_parameters,
"record_is_tuple": self._record_is_tuple,
"record_is_scalar": self._record_is_scalar,
"is_nonnullable_nulltype": self._is_nonnullable_nulltype,
}

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized):
# pyarrow calls this internally
metadata = json.loads(serialized.decode(errors="surrogatescape"))
return cls._from_metadata_object(storage_type, metadata)

@classmethod
def _from_metadata_object(cls, storage_type, metadata):
return cls(
storage_type,
metadata["mask_type"],
Expand All @@ -168,6 +174,9 @@ def num_buffers(self):
def num_fields(self):
return self.storage_type.num_fields

def field(self, i: int | str):
return self.storage_type.field(i)

pyarrow.register_extension_type(
AwkwardArrowType(pyarrow.null(), None, None, None, None, None, None)
)
Expand Down
250 changes: 250 additions & 0 deletions src/awkward/_connect/pyarrow_table_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
from __future__ import annotations

import json

from .pyarrow import (
AwkwardArrowArray,
AwkwardArrowType,
import_pyarrow,
to_awkwardarrow_storage_types,
)

pyarrow = import_pyarrow("pyarrow_table_conv")
AWKWARD_INFO_KEY = b"ak_extn_array_info" # metadata field in Table schema
tcawlfield marked this conversation as resolved.
Show resolved Hide resolved


def convert_awkward_arrow_table_to_native(aatable: pyarrow.Table) -> pyarrow.Table:
    """
    Convert a pyarrow Table built with extensionarray=True into one that uses
    only native (storage) Arrow types.

    aatable: a pyarrow Table created with extensionarray=True
    returns: a pyarrow Table without extension arrays, carrying an
        AWKWARD_INFO_KEY entry in its schema metadata from which the original
        extension-array table can later be reconstructed.
    """
    # Per-column extension-type metadata, keyed by column name.
    column_info = {
        column.name: collect_ak_arr_type_metadata(column) for column in aatable.schema
    }
    native_fields = [awkward_arrow_field_to_native(column) for column in aatable.schema]

    serialized_info = json.dumps(column_info).encode(errors="surrogatescape")
    existing_meta = aatable.schema.metadata
    combined_meta = {} if existing_meta is None else dict(existing_meta)
    combined_meta[AWKWARD_INFO_KEY] = serialized_info

    native_schema = pyarrow.schema(native_fields, metadata=combined_meta)
    return replace_schema(aatable, native_schema)


def convert_native_arrow_table_to_awkward(table: pyarrow.Table) -> pyarrow.Table:
    """
    table: A pyarrow Table previously converted with
        convert_awkward_arrow_table_to_native (or read back from storage).
    returns: A pyarrow Table whose columns use AwkwardArrowType extension
        arrays, reconstructed from the AWKWARD_INFO_KEY schema metadata; that
        metadata key is removed from the result. Tables lacking this metadata
        (e.g. written by prior versions) are returned unchanged.
    raises: ValueError if the metadata is present but missing a column's info.
    """
    # NOTE(fix): the previous docstring was copy-pasted from the inverse
    # function and described the opposite conversion direction.
    if table.schema.metadata is None or AWKWARD_INFO_KEY not in table.schema.metadata:
        return table  # Prior versions don't include metadata here
    new_fields = []
    metadata = json.loads(
        table.schema.metadata[AWKWARD_INFO_KEY].decode(errors="surrogatescape")
    )
    for aacol_field in table.schema:
        if aacol_field.name not in metadata:
            raise ValueError(
                f"Awkward metadata in Arrow table does not have info for column {aacol_field.name}"
            )
        new_fields.append(
            native_arrow_field_to_akarraytype(aacol_field, metadata[aacol_field.name])
        )
    new_metadata = table.schema.metadata.copy()
    del new_metadata[AWKWARD_INFO_KEY]  # Conversion is done; drop our marker.
    new_schema = pyarrow.schema(new_fields, metadata=new_metadata)
    return replace_schema(table, new_schema)


def collect_ak_arr_type_metadata(aafield: pyarrow.Field) -> dict | list | None:
    """
    Given a Field, collect ArrowExtensionArray metadata as an object.
    If that field holds more ArrowExtensionArray types, a "subfield_metadata"
    property is added holding a list of metadata objects for the sub-fields.
    Recurses down the whole type structure. Returns None for non-Awkward types.
    """
    field_type = aafield.type
    if not isinstance(field_type, AwkwardArrowType):
        return None  # Not expected to reach here

    metadata = field_type._metadata_as_dict()
    metadata["field_name"] = aafield.name

    children = _fields_of_strg_type(field_type.storage_type)
    if children:
        # Compound type: recurse into each sub-field.
        metadata["subfield_metadata"] = [
            collect_ak_arr_type_metadata(child) for child in children
        ]
    return metadata


def awkward_arrow_field_to_native(aafield: pyarrow.Field) -> pyarrow.Field:
    """
    Given a Field with ArrowExtensionArray type, return a corresponding field
    using only Arrow builtin (storage) types. Metadata is removed.
    """
    field_type = aafield.type
    if not isinstance(field_type, AwkwardArrowType):
        # Not expected to reach this. Maybe throw ValueError?
        return aafield

    children = _fields_of_strg_type(field_type.storage_type)
    if not children:
        # Simple type wrapped in AwkwardArrowType: unwrap to the storage type.
        return pyarrow.field(
            aafield.name, type=field_type.storage_type, nullable=aafield.nullable
        )

    # Nested type: convert each child, then rebuild the container type.
    converted_children = [awkward_arrow_field_to_native(child) for child in children]
    rebuilt_type = _make_pyarrow_type_like(field_type, converted_children)
    return pyarrow.field(aafield.name, type=rebuilt_type, nullable=aafield.nullable)


def native_arrow_field_to_akarraytype(
    ntv_field: pyarrow.Field, metadata: dict
) -> pyarrow.Field:
    """
    Inverse of awkward_arrow_field_to_native: rebuild a Field whose type is an
    AwkwardArrowType, from the metadata produced by collect_ak_arr_type_metadata.

    ntv_field: a Field holding only native Arrow (storage) types.
    metadata: this column's metadata dict (with "subfield_metadata" for
        compound types).
    raises: ValueError if the field is already an extension type.
    """
    # Bug fix: a pyarrow.Field is never itself a DataType, so the original
    # `isinstance(ntv_field, AwkwardArrowType)` guard could never fire.
    # The check must inspect the field's *type*.
    if isinstance(ntv_field.type, AwkwardArrowType):
        raise ValueError(f"field {ntv_field} is already an AwkwardArrowType")
    storage_type = ntv_field.type
    fields = _fields_of_strg_type(storage_type)
    if len(fields) > 0:
        # We need to replace storage_type with one that contains AwkwardArrowTypes.
        awkwardized_fields = [
            native_arrow_field_to_akarraytype(field, meta)  # Recurse
            for field, meta in zip(fields, metadata["subfield_metadata"])
        ]
        storage_type = _make_pyarrow_type_like(storage_type, awkwardized_fields)
    ak_type = AwkwardArrowType._from_metadata_object(storage_type, metadata)
    return pyarrow.field(ntv_field.name, type=ak_type, nullable=ntv_field.nullable)


def _fields_of_strg_type(typ: pyarrow.Type) -> list[pyarrow.Field]:
    """Return the child fields of a storage type; leaf types yield []."""
    if isinstance(typ, pyarrow.lib.DictionaryType):
        # DictionaryType exposes no child fields; wrap its value type in a
        # Field for consistency with the other compound types.
        return [pyarrow.field("value", typ.value_type)]
    return [typ.field(index) for index in range(typ.num_fields)]


def _make_pyarrow_type_like(
    typ: pyarrow.Type, fields: list[pyarrow.Field]
) -> pyarrow.Type:
    """
    Rebuild a pyarrow type shaped like ``typ`` but with its child fields
    replaced by ``fields``. ``typ`` may be a native type or an
    AwkwardArrowType (its storage type is used either way).
    """
    storage = to_awkwardarrow_storage_types(typ)[1]
    if isinstance(storage, pyarrow.lib.DictionaryType):
        return pyarrow.dictionary(storage.index_type, fields[0].type)
    if isinstance(storage, pyarrow.lib.FixedSizeListType):
        return pyarrow.list_(fields[0], storage.list_size)
    if isinstance(storage, pyarrow.lib.ListType):
        return pyarrow.list_(fields[0])
    if isinstance(storage, pyarrow.lib.LargeListType):
        return pyarrow.large_list(fields[0])
    if isinstance(storage, pyarrow.lib.MapType):
        # return pyarrow.map_(storage_type.index_type, fields[0])
        raise NotImplementedError("pyarrow MapType is not supported by Awkward")
    if isinstance(storage, pyarrow.lib.StructType):
        return pyarrow.struct(fields)
    if isinstance(storage, pyarrow.lib.UnionType):
        return pyarrow.union(fields, storage.mode, storage.type_codes)
    # This catch-all must come last: the specific types above are all DataType
    # subclasses too.
    if isinstance(storage, pyarrow.lib.DataType) and storage.num_fields == 0:
        # Catch-all for primitive types, nulltype, string types, FixedSizeBinaryType
        return storage
    raise NotImplementedError(f"Type {typ} is not handled for conversion.")
tcawlfield marked this conversation as resolved.
Show resolved Hide resolved


def replace_schema(table: pyarrow.Table, new_schema: pyarrow.Schema) -> pyarrow.Table:
    """
    Like `pyarrow.Table.cast()`, except it only works when the new schema uses
    the same storage types and storage geometries as the original. It
    explicitly will not convert one primitive type to another.
    """
    converted_batches = [
        pyarrow.RecordBatch.from_arrays(
            arrays=[
                array_with_replacement_type(column, target_field.type)
                for column, target_field in zip(batch.itercolumns(), new_schema)
            ],
            schema=new_schema,
        )
        for batch in table.to_batches()
    ]
    return pyarrow.Table.from_batches(converted_batches)


def array_with_replacement_type(
    orig_array: pyarrow.Array, new_type: pyarrow.Type
) -> pyarrow.Array:
    """
    Creates a new array with a different type, reusing the original array's
    buffers (zero-copy) and recursing through child arrays.
    Either pyarrow native -> ExtensionArray or vice-versa.

    Raises AssertionError if the new type's child count does not match the
    original array's child count (i.e. the geometries differ).
    """
    children_orig = _get_children(orig_array)
    # Unwrap an extension type to its storage type; a native type passes through.
    native_type = to_awkwardarrow_storage_types(new_type)[1]
    new_fields = _fields_of_strg_type(native_type)
    if len(new_fields) != len(children_orig):
        raise AssertionError(
            f"Number of children: {len(children_orig) =} != {len(new_fields) =}"
        )
    # Recursively retype each child array with its corresponding new field type.
    children_new = [
        array_with_replacement_type(child, new_child_type.type)
        for child, new_child_type in zip(children_orig, new_fields)
    ]
    # Array.buffers() includes child buffers; keep only this node's own buffers
    # (the first num_buffers of them) since children are passed separately.
    own_buffers = orig_array.buffers()[: orig_array.type.num_buffers]
    if isinstance(native_type, pyarrow.lib.DictionaryType):
        # Dictionary arrays need a dedicated constructor: the dictionary
        # (values) array is a separate argument, not a child.
        native_dict = pyarrow.DictionaryArray.from_buffers(
            type=native_type,
            length=len(orig_array),
            buffers=own_buffers,
            dictionary=children_new[0],
            null_count=orig_array.null_count,
            offset=orig_array.offset,
        )
        if isinstance(new_type, pyarrow.ExtensionType):
            # Re-wrap the rebuilt storage in the extension array type.
            return AwkwardArrowArray.from_storage(new_type, native_dict)
        else:
            return native_dict
    else:
        # from_buffers accepts an ExtensionType directly, so no separate
        # wrapping step is needed here.
        return pyarrow.Array.from_buffers(
            type=new_type,
            length=len(orig_array),
            buffers=own_buffers,
            null_count=orig_array.null_count,
            offset=orig_array.offset,
            children=children_new,
        )


def _get_children(array: pyarrow.Array) -> list[pyarrow.Array]:
    """
    Unify access to an Array's child arrays, which different pyarrow Array
    classes expose in different ways.
    """
    arrow_type = to_awkwardarrow_storage_types(array.type)[1]
    if isinstance(array, AwkwardArrowArray):
        # Look through the extension wrapper to the storage array.
        array = array.storage

    if isinstance(arrow_type, pyarrow.lib.DictionaryType):
        # The dictionary (values) array is this node's only "child".
        return [array.dictionary]
    if arrow_type.num_fields == 0:
        # Leaf type: no children.
        return []
    if hasattr(array, "field"):
        # Struct-like and union arrays expose children via .field(i).
        return [array.field(child_idx) for child_idx in range(arrow_type.num_fields)]
    if hasattr(array, "values"):
        # List-like arrays expose a single child via .values.
        return [array.values]
    raise NotImplementedError(f"Cannot get children of arrow type {arrow_type}")
2 changes: 2 additions & 0 deletions src/awkward/operations/ak_from_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import fsspec.parquet

import awkward as ak
from awkward._connect.pyarrow_table_conv import convert_native_arrow_table_to_awkward
from awkward._dispatch import high_level_function
from awkward._layout import wrap_layout
from awkward._regularize import is_integer
Expand Down Expand Up @@ -296,6 +297,7 @@ def _read_parquet_file(
else:
arrow_table = parquetfile.read_row_groups(row_groups, parquet_columns)

arrow_table = convert_native_arrow_table_to_awkward(arrow_table)
return ak.operations.ak_from_arrow._impl(
arrow_table,
generate_bitmasks,
Expand Down
4 changes: 4 additions & 0 deletions src/awkward/operations/ak_to_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import fsspec

import awkward as ak
from awkward._connect.pyarrow_table_conv import convert_awkward_arrow_table_to_native
from awkward._dispatch import high_level_function
from awkward._nplikes.numpy_like import NumpyMetadata

Expand Down Expand Up @@ -401,6 +402,9 @@ def parquet_columns(specifier, only=None):
)
parquet_byte_stream_split = [x for x, value in replacement.items() if value]

if extensionarray:
table = convert_awkward_arrow_table_to_native(table)

if parquet_extra_options is None:
parquet_extra_options = {}

Expand Down
Loading
Loading