Skip to content

Commit

Permalink
Move writes to the transaction class (#571)
Browse files Browse the repository at this point in the history
  • Loading branch information
sungwy authored Apr 4, 2024
1 parent a892309 commit d69407c
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 61 deletions.
160 changes: 99 additions & 61 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,100 @@ def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> U
"""
return UpdateSnapshot(self, io=self._table.io, snapshot_properties=snapshot_properties)

def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
"""
Shorthand API for appending a PyArrow table to a table transaction.
Args:
df: The Arrow dataframe that will be appended to overwrite the table
snapshot_properties: Custom properties to be added to the snapshot summary
"""
try:
import pyarrow as pa
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

if len(self._table.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

_check_schema_compatible(self._table.schema(), other_schema=df.schema)
# cast if the two schemas are compatible but not equal
table_arrow_schema = self._table.schema().as_arrow()
if table_arrow_schema != df.schema:
df = df.cast(table_arrow_schema)

with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
data_files = _dataframe_to_data_files(
table_metadata=self._table.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io
)
for data_file in data_files:
update_snapshot.append_data_file(data_file)

def overwrite(
self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
) -> None:
"""
Shorthand for adding a table overwrite with a PyArrow table to the transaction.
Args:
df: The Arrow dataframe that will be used to overwrite the table
overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
or a boolean expression in case of a partial overwrite
snapshot_properties: Custom properties to be added to the snapshot summary
"""
try:
import pyarrow as pa
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

if overwrite_filter != AlwaysTrue():
raise NotImplementedError("Cannot overwrite a subset of a table")

if len(self._table.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

_check_schema_compatible(self._table.schema(), other_schema=df.schema)
# cast if the two schemas are compatible but not equal
table_arrow_schema = self._table.schema().as_arrow()
if table_arrow_schema != df.schema:
df = df.cast(table_arrow_schema)

with self.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot:
# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
data_files = _dataframe_to_data_files(
table_metadata=self._table.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io
)
for data_file in data_files:
update_snapshot.append_data_file(data_file)

def add_files(self, file_paths: List[str]) -> None:
"""
Shorthand API for adding files as data files to the table transaction.
Args:
file_paths: The list of full file paths to be added as data files to the table
Raises:
FileNotFoundError: If the file does not exist.
"""
if self._table.name_mapping() is None:
self.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: self._table.schema().name_mapping.model_dump_json()})
with self.update_snapshot().fast_append() as update_snapshot:
data_files = _parquet_files_to_data_files(
table_metadata=self._table.metadata, file_paths=file_paths, io=self._table.io
)
for data_file in data_files:
update_snapshot.append_data_file(data_file)

def update_spec(self) -> UpdateSpec:
"""Create a new UpdateSpec to update the partitioning of the table.
Expand Down Expand Up @@ -1219,32 +1313,8 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
df: The Arrow dataframe that will be appended to overwrite the table
snapshot_properties: Custom properties to be added to the snapshot summary
"""
try:
import pyarrow as pa
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

_check_schema_compatible(self.schema(), other_schema=df.schema)
# cast if the two schemas are compatible but not equal
table_arrow_schema = self.schema().as_arrow()
if table_arrow_schema != df.schema:
df = df.cast(table_arrow_schema)

with self.transaction() as txn:
with txn.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
data_files = _dataframe_to_data_files(
table_metadata=self.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self.io
)
for data_file in data_files:
update_snapshot.append_data_file(data_file)
with self.transaction() as tx:
tx.append(df=df, snapshot_properties=snapshot_properties)

def overwrite(
self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
Expand All @@ -1258,35 +1328,8 @@ def overwrite(
or a boolean expression in case of a partial overwrite
snapshot_properties: Custom properties to be added to the snapshot summary
"""
try:
import pyarrow as pa
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

if overwrite_filter != AlwaysTrue():
raise NotImplementedError("Cannot overwrite a subset of a table")

if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

_check_schema_compatible(self.schema(), other_schema=df.schema)
# cast if the two schemas are compatible but not equal
table_arrow_schema = self.schema().as_arrow()
if table_arrow_schema != df.schema:
df = df.cast(table_arrow_schema)

with self.transaction() as txn:
with txn.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot:
# skip writing data files if the dataframe is empty
if df.shape[0] > 0:
data_files = _dataframe_to_data_files(
table_metadata=self.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self.io
)
for data_file in data_files:
update_snapshot.append_data_file(data_file)
with self.transaction() as tx:
tx.overwrite(df=df, overwrite_filter=overwrite_filter, snapshot_properties=snapshot_properties)

def add_files(self, file_paths: List[str]) -> None:
"""
Expand All @@ -1299,12 +1342,7 @@ def add_files(self, file_paths: List[str]) -> None:
FileNotFoundError: If the file does not exist.
"""
with self.transaction() as tx:
if self.name_mapping() is None:
tx.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: self.schema().name_mapping.model_dump_json()})
with tx.update_snapshot().fast_append() as update_snapshot:
data_files = _parquet_files_to_data_files(table_metadata=self.metadata, file_paths=file_paths, io=self.io)
for data_file in data_files:
update_snapshot.append_data_file(data_file)
tx.add_files(file_paths=file_paths)

def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
return UpdateSpec(Transaction(self, autocommit=True), case_sensitive=case_sensitive)
Expand Down
28 changes: 28 additions & 0 deletions tests/integration/test_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,3 +832,31 @@ def test_inspect_snapshots(
continue

assert left == right, f"Difference in column {column}: {left} != {right}"


@pytest.mark.integration
def test_write_within_transaction(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.write_in_open_transaction"
tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [])

def get_metadata_entries_count(identifier: str) -> int:
return spark.sql(
f"""
SELECT *
FROM {identifier}.metadata_log_entries
"""
).count()

# one metadata entry from table creation
assert get_metadata_entries_count(identifier) == 1

# one more metadata entry from transaction
with tbl.transaction() as tx:
tx.set_properties({"test": "1"})
tx.append(arrow_table_with_null)
assert get_metadata_entries_count(identifier) == 2

# two more metadata entries added from two separate transactions
tbl.transaction().set_properties({"test": "2"}).commit_transaction()
tbl.append(arrow_table_with_null)
assert get_metadata_entries_count(identifier) == 4

0 comments on commit d69407c

Please sign in to comment.