Cast PyArrow schema to large_* types #807

Merged: 6 commits, Jun 14, 2024
Changes from 5 commits
56 changes: 46 additions & 10 deletions pyiceberg/io/pyarrow.py
@@ -504,7 +504,7 @@ def field(self, field: NestedField, field_result: pa.DataType) -> pa.Field:

def list(self, list_type: ListType, element_result: pa.DataType) -> pa.DataType:
element_field = self.field(list_type.element_field, element_result)
return pa.list_(value_type=element_field)
return pa.large_list(value_type=element_field)

def map(self, map_type: MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType:
key_field = self.field(map_type.key_field, key_result)
@@ -548,7 +548,7 @@ def visit_timestamptz(self, _: TimestamptzType) -> pa.DataType:
return pa.timestamp(unit="us", tz="UTC")

def visit_string(self, _: StringType) -> pa.DataType:
return pa.string()
return pa.large_string()

def visit_uuid(self, _: UUIDType) -> pa.DataType:
return pa.binary(16)
@@ -680,6 +680,10 @@ def _pyarrow_to_schema_without_ids(schema: pa.Schema) -> Schema:
return visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())


def _pyarrow_schema_ensure_large_types(schema: pa.Schema) -> pa.Schema:
return visit_pyarrow(schema, _ConvertToLargeTypes())


@singledispatch
def visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: PyArrowSchemaVisitor[T]) -> T:
"""Apply a pyarrow schema visitor to any point within a schema.
@@ -952,6 +956,30 @@ def after_map_value(self, element: pa.Field) -> None:
self._field_names.pop()


class _ConvertToLargeTypes(PyArrowSchemaVisitor[Union[pa.DataType, pa.Schema]]):
def schema(self, schema: pa.Schema, struct_result: pa.StructType) -> pa.Schema:
return pa.schema(struct_result)

def struct(self, struct: pa.StructType, field_results: List[pa.Field]) -> pa.StructType:
return pa.struct(field_results)

def field(self, field: pa.Field, field_result: pa.DataType) -> pa.Field:
return field.with_type(field_result)

def list(self, list_type: pa.ListType, element_result: pa.DataType) -> pa.DataType:
return pa.large_list(element_result)

def map(self, map_type: pa.MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType:
return pa.map_(key_result, value_result)

def primitive(self, primitive: pa.DataType) -> pa.DataType:
if primitive == pa.string():
return pa.large_string()
elif primitive == pa.binary():
return pa.large_binary()
return primitive

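For reference, a minimal sketch of what the new _pyarrow_schema_ensure_large_types helper is expected to produce when run over a schema built with the small variants. The example schema is illustrative, not taken from this PR; the import assumes the helper added above lands in pyiceberg/io/pyarrow.py as shown in this diff.

import pyarrow as pa
from pyiceberg.io.pyarrow import _pyarrow_schema_ensure_large_types  # helper added in this PR

schema = pa.schema([
    pa.field("name", pa.string()),
    pa.field("payload", pa.binary()),
    pa.field("tags", pa.list_(pa.string())),
])

# string -> large_string, binary -> large_binary, list -> large_list (element types converted too)
assert _pyarrow_schema_ensure_large_types(schema) == pa.schema([
    pa.field("name", pa.large_string()),
    pa.field("payload", pa.large_binary()),
    pa.field("tags", pa.large_list(pa.large_string())),
])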

class _ConvertToIcebergWithoutIDs(_ConvertToIceberg):
"""
Converts PyArrowSchema to Iceberg Schema with all -1 ids.
@@ -998,7 +1026,9 @@ def _task_to_table(

fragment_scanner = ds.Scanner.from_fragment(
fragment=fragment,
schema=physical_schema,
# We always use large types in memory: their 64-bit offsets
# allow more row values to fit into a single buffer
schema=_pyarrow_schema_ensure_large_types(physical_schema),
# This will push down the query to Arrow.
# But in case there are positional deletes, we have to apply them first
filter=pyarrow_filter if not positional_deletes else None,
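As an aside on the comment above: pa.string(), pa.binary(), and pa.list_() index their value buffers with 32-bit offsets, which caps a single array at roughly 2 GiB of values, while the large_* variants use 64-bit offsets. A quick illustration, independent of this diff:

import pyarrow as pa

small = pa.array(["a", "bb", None], type=pa.string())  # 32-bit offsets into the data buffer
large = small.cast(pa.large_string())                  # 64-bit offsets, same logical values
assert large.type == pa.large_string()
assert small.to_pylist() == large.to_pylist()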
@@ -1167,8 +1197,14 @@ def __init__(self, file_schema: Schema):

def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
file_field = self.file_schema.find_field(field.field_id)
if field.field_type.is_primitive and field.field_type != file_field.field_type:
return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False))
if field.field_type.is_primitive:
if field.field_type != file_field.field_type:
return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False))
elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=False)) != values.type:
# if file_field and field type (e.g. String) are the same
# but the pyarrow type of the array is different from the expected type
# (e.g. string vs large_string), we want to cast the array to the larger type
return values.cast(target_type)
return values

def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
@@ -1207,13 +1243,13 @@ def field(self, field: NestedField, _: Optional[pa.Array], field_array: Optional
return field_array

def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array: Optional[pa.Array]) -> Optional[pa.Array]:
if isinstance(list_array, pa.ListArray) and value_array is not None:
if isinstance(list_array, (pa.ListArray, pa.LargeListArray, pa.FixedSizeListArray)) and value_array is not None:
if isinstance(value_array, pa.StructArray):
# This can be removed once this has been fixed:
# https://github.com/apache/arrow/issues/38809
list_array = pa.ListArray.from_arrays(list_array.offsets, value_array)
list_array = pa.LargeListArray.from_arrays(list_array.offsets, value_array)

arrow_field = pa.list_(self._construct_field(list_type.element_field, value_array.type))
arrow_field = pa.large_list(self._construct_field(list_type.element_field, value_array.type))
return list_array.cast(arrow_field)
else:
return None
@@ -1263,7 +1299,7 @@ def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: st
return None

def list_element_partner(self, partner_list: Optional[pa.Array]) -> Optional[pa.Array]:
return partner_list.values if isinstance(partner_list, pa.ListArray) else None
return partner_list.values if isinstance(partner_list, (pa.ListArray, pa.LargeListArray, pa.FixedSizeListArray)) else None

def map_key_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array]:
return partner_map.keys if isinstance(partner_map, pa.MapArray) else None
@@ -1800,10 +1836,10 @@ def write_parquet(task: WriteTask) -> DataFile:
# otherwise use the original schema
if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
file_schema = sanitized_schema
arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table)
else:
file_schema = table_schema

arrow_table = to_requested_schema(requested_schema=file_schema, file_schema=table_schema, table=arrow_table)
file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
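To summarize the second branch added to _cast_if_needed above: even when the Iceberg types already match, the array coming off the file reader may still use the small Arrow variant, so it is cast up to the type schema_to_pyarrow now returns. A rough sketch under those assumptions; the array and target type below are made up for illustration:

import pyarrow as pa

values = pa.array(["a", "b", "c"], type=pa.string())  # as materialized by the reader
target_type = pa.large_string()                       # what schema_to_pyarrow maps StringType to in this PR
if values.type != target_type:
    values = values.cast(target_type)                 # upcast to the 64-bit-offset variant
assert values.type == pa.large_string()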
10 changes: 5 additions & 5 deletions tests/catalog/test_sql.py
@@ -288,7 +288,7 @@ def test_write_pyarrow_schema(catalog: SqlCatalog, table_identifier: Identifier)
pa.array([None, "A", "B", "C"]), # 'large' column
],
schema=pa.schema([
pa.field("foo", pa.string(), nullable=True),
pa.field("foo", pa.large_string(), nullable=True),
pa.field("bar", pa.int32(), nullable=False),
pa.field("baz", pa.bool_(), nullable=True),
pa.field("large", pa.large_string(), nullable=True),
@@ -1325,7 +1325,7 @@ def test_write_and_evolve(catalog: SqlCatalog, format_version: int) -> None:
{
"foo": ["a", None, "z"],
},
schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
schema=pa.schema([pa.field("foo", pa.large_string(), nullable=True)]),
)

tbl = catalog.create_table(identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)})
@@ -1336,7 +1336,7 @@ def test_write_and_evolve(catalog: SqlCatalog, format_version: int) -> None:
"bar": [19, None, 25],
},
schema=pa.schema([
pa.field("foo", pa.string(), nullable=True),
pa.field("foo", pa.large_string(), nullable=True),
pa.field("bar", pa.int32(), nullable=True),
]),
)
@@ -1375,7 +1375,7 @@ def test_create_table_transaction(catalog: SqlCatalog, format_version: int) -> N
{
"foo": ["a", None, "z"],
},
schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
schema=pa.schema([pa.field("foo", pa.large_string(), nullable=True)]),
)

pa_table_with_column = pa.Table.from_pydict(
@@ -1384,7 +1384,7 @@ def test_create_table_transaction(catalog: SqlCatalog, format_version: int) -> N
"bar": [19, None, 25],
},
schema=pa.schema([
pa.field("foo", pa.string(), nullable=True),
pa.field("foo", pa.large_string(), nullable=True),
pa.field("bar", pa.int32(), nullable=True),
]),
)
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -2116,8 +2116,8 @@ def pa_schema() -> "pa.Schema":

return pa.schema([
("bool", pa.bool_()),
("string", pa.string()),
("string_long", pa.string()),
("string", pa.large_string()),
("string_long", pa.large_string()),
("int", pa.int32()),
("long", pa.int64()),
("float", pa.float32()),
54 changes: 54 additions & 0 deletions tests/integration/test_writes/test_writes.py
@@ -340,6 +340,60 @@ def test_python_writes_dictionary_encoded_column_with_spark_reads(
assert spark_df.equals(pyiceberg_df)


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_python_writes_with_small_and_large_types_spark_reads(
spark: SparkSession, session_catalog: Catalog, format_version: int
) -> None:
identifier = "default.python_writes_with_small_and_large_types_spark_reads"
TEST_DATA = {
"foo": ["a", None, "z"],
"id": [1, 2, 3],
"name": ["AB", "CD", "EF"],
"address": [
{"street": "123", "city": "SFO", "zip": 12345, "bar": "a"},
{"street": "456", "city": "SW", "zip": 67890, "bar": "b"},
{"street": "789", "city": "Random", "zip": 10112, "bar": "c"},
],
}
pa_schema = pa.schema([
pa.field("foo", pa.large_string()),
pa.field("id", pa.int32()),
pa.field("name", pa.string()),
pa.field(
"address",
pa.struct([
pa.field("street", pa.string()),
pa.field("city", pa.string()),
pa.field("zip", pa.int32()),
pa.field("bar", pa.large_string()),
]),
),
])
arrow_table = pa.Table.from_pydict(TEST_DATA, schema=pa_schema)
tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=pa_schema)

tbl.overwrite(arrow_table)
spark_df = spark.sql(f"SELECT * FROM {identifier}").toPandas()
pyiceberg_df = tbl.scan().to_pandas()
assert spark_df.equals(pyiceberg_df)
arrow_table_on_read = tbl.scan().to_arrow()
assert arrow_table_on_read.schema == pa.schema([
pa.field("foo", pa.large_string()),
pa.field("id", pa.int32()),
pa.field("name", pa.large_string()),
pa.field(
"address",
pa.struct([
pa.field("street", pa.large_string()),
pa.field("city", pa.large_string()),
pa.field("zip", pa.int32()),
pa.field("bar", pa.large_string()),
]),
),
])


@pytest.mark.integration
def test_write_bin_pack_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.write_bin_pack_data_files"