From 6708a6eaa76a2b4aab58601f10210f778b3f03a4 Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Thu, 29 Feb 2024 13:10:42 +0100 Subject: [PATCH] Update table metadata throughout transaction (#471) * Update table metadata throughout transaction This PR adds support for updating the table metadata throughout the transaction. This way, if a schema is first evolved and a snapshot is then created based on the latest schema, the snapshot will be able to find that schema. * Fix integration tests * Thanks Honah! * Include the partition evolution * Cleanup --- pyiceberg/io/pyarrow.py | 21 +- pyiceberg/table/__init__.py | 537 ++++++++++++-------------- pyiceberg/table/metadata.py | 43 +++ tests/catalog/test_sql.py | 52 +++ tests/integration/test_rest_schema.py | 23 +- tests/integration/test_writes.py | 2 +- tests/table/test_init.py | 15 +- tests/test_schema.py | 42 +- 8 files changed, 396 insertions(+), 339 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index d41f7a07a5..be944ffb36 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -125,6 +125,7 @@ visit_with_partner, ) from pyiceberg.table import PropertyUtil, TableProperties, WriteTask +from pyiceberg.table.metadata import TableMetadata from pyiceberg.table.name_mapping import NameMapping from pyiceberg.transforms import TruncateTransform from pyiceberg.typedef import EMPTY_DICT, Properties, Record @@ -1720,7 +1721,7 @@ def fill_parquet_file_metadata( data_file.split_offsets = split_offsets -def write_file(table: Table, tasks: Iterator[WriteTask], file_schema: Optional[Schema] = None) -> Iterator[DataFile]: +def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]: task = next(tasks) try: @@ -1730,15 +1731,15 @@ def write_file(table: Table, tasks: Iterator[WriteTask], file_schema: Optional[S except StopIteration: pass - parquet_writer_kwargs = _get_parquet_writer_kwargs(table.properties) + parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties) - file_path = f'{table.location()}/data/{task.generate_data_file_filename("parquet")}' - file_schema = file_schema or table.schema() - arrow_file_schema = schema_to_pyarrow(file_schema) + file_path = f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}' + schema = table_metadata.schema() + arrow_file_schema = schema_to_pyarrow(schema) - fo = table.io.new_output(file_path) + fo = io.new_output(file_path) row_group_size = PropertyUtil.property_as_int( - properties=table.properties, + properties=table_metadata.properties, property_name=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES, default=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT, ) @@ -1757,7 +1758,7 @@ def write_file(table: Table, tasks: Iterator[WriteTask], file_schema: Optional[S # sort_order_id=task.sort_order_id, sort_order_id=None, # Just copy these from the table for now - spec_id=table.spec().spec_id, + spec_id=table_metadata.default_spec_id, equality_ids=None, key_metadata=None, ) @@ -1765,8 +1766,8 @@ def write_file(table: Table, tasks: Iterator[WriteTask], file_schema: Optional[S fill_parquet_file_metadata( data_file=data_file, parquet_metadata=writer.writer.metadata, - stats_columns=compute_statistics_plan(file_schema, table.properties), - parquet_column_mapping=parquet_path_to_id_mapping(file_schema), + stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), ) return iter([data_file]) diff --git 
a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index e29369a26c..1a4183c914 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -31,6 +31,7 @@ Any, Callable, Dict, + Generic, Iterable, List, Literal, @@ -137,6 +138,7 @@ from pyiceberg.catalog import Catalog + ALWAYS_TRUE = AlwaysTrue() TABLE_ROOT_ID = -1 @@ -229,18 +231,23 @@ def property_as_int(properties: Dict[str, str], property_name: str, default: Opt class Transaction: _table: Table + table_metadata: TableMetadata + _autocommit: bool _updates: Tuple[TableUpdate, ...] _requirements: Tuple[TableRequirement, ...] - def __init__( - self, - table: Table, - actions: Optional[Tuple[TableUpdate, ...]] = None, - requirements: Optional[Tuple[TableRequirement, ...]] = None, - ): + def __init__(self, table: Table, autocommit: bool = False): + """Open a transaction to stage and commit changes to a table. + + Args: + table: The table that will be altered. + autocommit: Option to automatically commit the changes when they are staged. + """ + self.table_metadata = table.metadata self._table = table - self._updates = actions or () - self._requirements = requirements or () + self._autocommit = autocommit + self._updates = () + self._requirements = () def __enter__(self) -> Transaction: """Start a transaction to update the table.""" @@ -248,49 +255,23 @@ def __enter__(self) -> Transaction: def __exit__(self, _: Any, value: Any, traceback: Any) -> None: """Close and commit the transaction.""" - fresh_table = self.commit_transaction() - # Update the new data in place - self._table.metadata = fresh_table.metadata - self._table.metadata_location = fresh_table.metadata_location + self.commit_transaction() - def _append_updates(self, *new_updates: TableUpdate) -> Transaction: - """Append updates to the set of staged updates. + def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequirement, ...] = ()) -> Transaction: + """Check if the requirements are met, and applies the updates to the metadata.""" + for requirement in requirements: + requirement.validate(self.table_metadata) - Args: - *new_updates: Any new updates. + self._updates += updates + self._requirements += requirements - Raises: - ValueError: When the type of update is not unique. + self.table_metadata = update_table_metadata(self.table_metadata, updates) - Returns: - Transaction object with the new updates appended. - """ - for new_update in new_updates: - # explicitly get type of new_update as new_update is an instantiated class - type_new_update = type(new_update) - if any(isinstance(update, type_new_update) for update in self._updates): - raise ValueError(f"Updates in a single commit need to be unique, duplicate: {type_new_update}") - self._updates = self._updates + new_updates - return self + if self._autocommit: + self.commit_transaction() + self._updates = () + self._requirements = () - def _append_requirements(self, *new_requirements: TableRequirement) -> Transaction: - """Append requirements to the set of staged requirements. - - Args: - *new_requirements: Any new requirements. - - Raises: - ValueError: When the type of requirement is not unique. - - Returns: - Transaction object with the new requirements appended. 
- """ - for new_requirement in new_requirements: - # explicitly get type of new_update as requirement is an instantiated class - type_new_requirement = type(new_requirement) - if any(isinstance(requirement, type_new_requirement) for requirement in self._requirements): - raise ValueError(f"Requirements in a single commit need to be unique, duplicate: {type_new_requirement}") - self._requirements = self._requirements + new_requirements return self def upgrade_table_version(self, format_version: Literal[1, 2]) -> Transaction: @@ -307,10 +288,11 @@ def upgrade_table_version(self, format_version: Literal[1, 2]) -> Transaction: if format_version < self._table.metadata.format_version: raise ValueError(f"Cannot downgrade v{self._table.metadata.format_version} table to v{format_version}") + if format_version > self._table.metadata.format_version: - return self._append_updates(UpgradeFormatVersionUpdate(format_version=format_version)) - else: - return self + return self._apply((UpgradeFormatVersionUpdate(format_version=format_version),)) + + return self def set_properties(self, **updates: str) -> Transaction: """Set properties. @@ -323,56 +305,19 @@ def set_properties(self, **updates: str) -> Transaction: Returns: The alter table builder. """ - return self._append_updates(SetPropertiesUpdate(updates=updates)) - - def add_snapshot(self, snapshot: Snapshot) -> Transaction: - """Add a new snapshot to the table. - - Returns: - The transaction with the add-snapshot staged. - """ - self._append_updates(AddSnapshotUpdate(snapshot=snapshot)) - self._append_requirements(AssertTableUUID(uuid=self._table.metadata.table_uuid)) - - return self - - def set_ref_snapshot( - self, - snapshot_id: int, - parent_snapshot_id: Optional[int], - ref_name: str, - type: str, - max_age_ref_ms: Optional[int] = None, - max_snapshot_age_ms: Optional[int] = None, - min_snapshots_to_keep: Optional[int] = None, - ) -> Transaction: - """Update a ref to a snapshot. + return self._apply((SetPropertiesUpdate(updates=updates),)) - Returns: - The transaction with the set-snapshot-ref staged - """ - self._append_updates( - SetSnapshotRefUpdate( - snapshot_id=snapshot_id, - parent_snapshot_id=parent_snapshot_id, - ref_name=ref_name, - type=type, - max_age_ref_ms=max_age_ref_ms, - max_snapshot_age_ms=max_snapshot_age_ms, - min_snapshots_to_keep=min_snapshots_to_keep, - ) - ) - - self._append_requirements(AssertRefSnapshotId(snapshot_id=parent_snapshot_id, ref="main")) - return self - - def update_schema(self) -> UpdateSchema: + def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema: """Create a new UpdateSchema to alter the columns of this table. + Args: + allow_incompatible_changes: If changes are allowed that might break downstream consumers. + case_sensitive: If field names are case-sensitive. + Returns: A new UpdateSchema. """ - return UpdateSchema(self._table, self) + return UpdateSchema(self, allow_incompatible_changes=allow_incompatible_changes, case_sensitive=case_sensitive) def update_snapshot(self) -> UpdateSnapshot: """Create a new UpdateSnapshot to produce a new snapshot for the table. @@ -380,7 +325,7 @@ def update_snapshot(self) -> UpdateSnapshot: Returns: A new UpdateSnapshot """ - return UpdateSnapshot(self._table, self) + return UpdateSnapshot(self, io=self._table.io) def update_spec(self) -> UpdateSpec: """Create a new UpdateSpec to update the partitioning of the table. @@ -388,7 +333,7 @@ def update_spec(self) -> UpdateSpec: Returns: A new UpdateSpec. 
""" - return UpdateSpec(self._table, self) + return UpdateSpec(self) def remove_properties(self, *removals: str) -> Transaction: """Remove properties. @@ -399,7 +344,7 @@ def remove_properties(self, *removals: str) -> Transaction: Returns: The alter table builder. """ - return self._append_updates(RemovePropertiesUpdate(removals=removals)) + return self._apply((RemovePropertiesUpdate(removals=removals),)) def update_location(self, location: str) -> Transaction: """Set the new table location. @@ -412,19 +357,12 @@ def update_location(self, location: str) -> Transaction: """ raise NotImplementedError("Not yet implemented") - def schema(self) -> Schema: - try: - return next(update for update in self._updates if isinstance(update, AddSchemaUpdate)).schema_ - except StopIteration: - return self._table.schema() - def commit_transaction(self) -> Table: """Commit the changes to the catalog. Returns: The table with the updates applied. """ - # Strip the catalog name if len(self._updates) > 0: self._table._do_commit( # pylint: disable=W0212 updates=self._updates, @@ -913,7 +851,7 @@ class AssertLastAssignedPartitionId(TableRequirement): """The table's last assigned partition id must match the requirement's `last-assigned-partition-id`.""" type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id") - last_assigned_partition_id: int = Field(..., alias="last-assigned-partition-id") + last_assigned_partition_id: Optional[int] = Field(..., alias="last-assigned-partition-id") def validate(self, base_metadata: Optional[TableMetadata]) -> None: if base_metadata is None: @@ -954,6 +892,9 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: ) +UpdatesAndRequirements = Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]] + + class Namespace(IcebergRootModel[List[str]]): """Reference to one or more levels of a namespace.""" @@ -998,6 +939,11 @@ def __init__( self.catalog = catalog def transaction(self) -> Transaction: + """Create a new transaction object to first stage the changes, and then commit them to the catalog. + + Returns: + The transaction object + """ return Transaction(self) def refresh(self) -> Table: @@ -1080,17 +1026,6 @@ def location(self) -> str: def last_sequence_number(self) -> int: return self.metadata.last_sequence_number - def next_sequence_number(self) -> int: - return self.last_sequence_number + 1 if self.metadata.format_version > 1 else INITIAL_SEQUENCE_NUMBER - - def new_snapshot_id(self) -> int: - """Generate a new snapshot-id that's not in use.""" - snapshot_id = _generate_snapshot_id() - while self.snapshot_by_id(snapshot_id) is not None: - snapshot_id = _generate_snapshot_id() - - return snapshot_id - def current_snapshot(self) -> Optional[Snapshot]: """Get the current snapshot for this table, or None if there is no current snapshot.""" if self.metadata.current_snapshot_id is not None: @@ -1114,18 +1049,19 @@ def history(self) -> List[SnapshotLogEntry]: def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema: """Create a new UpdateSchema to alter the columns of this table. - Returns: - A new UpdateSchema. - """ - return UpdateSchema(self, allow_incompatible_changes=allow_incompatible_changes, case_sensitive=case_sensitive) - - def update_snapshot(self) -> UpdateSnapshot: - """Create a new UpdateSnapshot to produce a new snapshot for the table. + Args: + allow_incompatible_changes: If changes are allowed that might break downstream consumers. 
+ case_sensitive: If field names are case-sensitive. Returns: - A new UpdateSnapshot + A new UpdateSchema. """ - return UpdateSnapshot(self) + return UpdateSchema( + transaction=Transaction(self, autocommit=True), + allow_incompatible_changes=allow_incompatible_changes, + case_sensitive=case_sensitive, + name_mapping=self.name_mapping(), + ) def name_mapping(self) -> Optional[NameMapping]: """Return the table's field-id NameMapping.""" @@ -1154,12 +1090,15 @@ def append(self, df: pa.Table) -> None: _check_schema(self.schema(), other_schema=df.schema) - with self.update_snapshot().fast_append() as update_snapshot: - # skip writing data files if the dataframe is empty - if df.shape[0] > 0: - data_files = _dataframe_to_data_files(self, write_uuid=update_snapshot.commit_uuid, df=df) - for data_file in data_files: - update_snapshot.append_data_file(data_file) + with self.transaction() as txn: + with txn.update_snapshot().fast_append() as update_snapshot: + # skip writing data files if the dataframe is empty + if df.shape[0] > 0: + data_files = _dataframe_to_data_files( + table_metadata=self.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self.io + ) + for data_file in data_files: + update_snapshot.append_data_file(data_file) def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE) -> None: """ @@ -1186,15 +1125,18 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T _check_schema(self.schema(), other_schema=df.schema) - with self.update_snapshot().overwrite() as update_snapshot: - # skip writing data files if the dataframe is empty - if df.shape[0] > 0: - data_files = _dataframe_to_data_files(self, write_uuid=update_snapshot.commit_uuid, df=df) - for data_file in data_files: - update_snapshot.append_data_file(data_file) + with self.transaction() as txn: + with txn.update_snapshot().overwrite() as update_snapshot: + # skip writing data files if the dataframe is empty + if df.shape[0] > 0: + data_files = _dataframe_to_data_files( + table_metadata=self.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self.io + ) + for data_file in data_files: + update_snapshot.append_data_file(data_file) def update_spec(self, case_sensitive: bool = True) -> UpdateSpec: - return UpdateSpec(self, case_sensitive=case_sensitive) + return UpdateSpec(Transaction(self, autocommit=True), case_sensitive=case_sensitive) def refs(self) -> Dict[str, SnapshotRef]: """Return the snapshot references in the table.""" @@ -1613,8 +1555,31 @@ class Move: other_field_id: Optional[int] = None -class UpdateSchema: - _table: Optional[Table] +U = TypeVar('U') + + +class UpdateTableMetadata(ABC, Generic[U]): + _transaction: Transaction + + def __init__(self, transaction: Transaction) -> None: + self._transaction = transaction + + @abstractmethod + def _commit(self) -> UpdatesAndRequirements: ... 
+ + def commit(self) -> None: + self._transaction._apply(*self._commit()) + + def __exit__(self, _: Any, value: Any, traceback: Any) -> None: + """Close and commit the change.""" + self.commit() + + def __enter__(self) -> U: + """Update the table.""" + return self # type: ignore + + +class UpdateSchema(UpdateTableMetadata["UpdateSchema"]): _schema: Schema _last_column_id: itertools.count[int] _identifier_field_names: Set[str] @@ -1629,27 +1594,25 @@ class UpdateSchema: _id_to_parent: Dict[int, str] = {} _allow_incompatible_changes: bool _case_sensitive: bool - _transaction: Optional[Transaction] def __init__( self, - table: Optional[Table], - transaction: Optional[Transaction] = None, + transaction: Transaction, allow_incompatible_changes: bool = False, case_sensitive: bool = True, schema: Optional[Schema] = None, + name_mapping: Optional[NameMapping] = None, ) -> None: - self._table = table + super().__init__(transaction) if isinstance(schema, Schema): self._schema = schema self._last_column_id = itertools.count(1 + schema.highest_field_id) - elif table is not None: - self._schema = table.schema() - self._last_column_id = itertools.count(1 + table.metadata.last_column_id) else: - raise ValueError("Either provide a table or a schema") + self._schema = self._transaction.table_metadata.schema() + self._last_column_id = itertools.count(1 + self._transaction.table_metadata.last_column_id) + self._name_mapping = name_mapping self._identifier_field_names = self._schema.identifier_field_names() self._adds = {} @@ -1673,14 +1636,6 @@ def get_column_name(field_id: int) -> str: self._case_sensitive = case_sensitive self._transaction = transaction - def __exit__(self, _: Any, value: Any, traceback: Any) -> None: - """Close and commit the change.""" - return self.commit() - - def __enter__(self) -> UpdateSchema: - """Update the table.""" - return self - def case_sensitive(self, case_sensitive: bool) -> UpdateSchema: """Determine if the case of schema needs to be considered when comparing column names. @@ -2069,38 +2024,36 @@ def move_after(self, path: Union[str, Tuple[str, ...]], after_name: Union[str, T return self - def commit(self) -> None: + def _commit(self) -> UpdatesAndRequirements: """Apply the pending changes and commit.""" - if self._table is None: - raise ValueError("Requires a table to commit to") - new_schema = self._apply() - existing_schema_id = next((schema.schema_id for schema in self._table.metadata.schemas if schema == new_schema), None) + existing_schema_id = next( + (schema.schema_id for schema in self._transaction.table_metadata.schemas if schema == new_schema), None + ) + + requirements: Tuple[TableRequirement, ...] = () + updates: Tuple[TableUpdate, ...] 
= () # Check if it is different current schema ID - if existing_schema_id != self._table.schema().schema_id: - requirements = (AssertCurrentSchemaId(current_schema_id=self._schema.schema_id),) + if existing_schema_id != self._schema.schema_id: + requirements += (AssertCurrentSchemaId(current_schema_id=self._schema.schema_id),) if existing_schema_id is None: - last_column_id = max(self._table.metadata.last_column_id, new_schema.highest_field_id) - updates = ( + last_column_id = max(self._transaction.table_metadata.last_column_id, new_schema.highest_field_id) + updates += ( AddSchemaUpdate(schema=new_schema, last_column_id=last_column_id), SetCurrentSchemaUpdate(schema_id=-1), ) else: - updates = (SetCurrentSchemaUpdate(schema_id=existing_schema_id),) # type: ignore + updates += (SetCurrentSchemaUpdate(schema_id=existing_schema_id),) - if name_mapping := self._table.name_mapping(): + if name_mapping := self._name_mapping: updated_name_mapping = update_mapping(name_mapping, self._updates, self._adds) - updates += ( # type: ignore + updates += ( SetPropertiesUpdate(updates={TableProperties.DEFAULT_NAME_MAPPING: updated_name_mapping.model_dump_json()}), ) - if self._transaction is not None: - self._transaction._append_updates(*updates) # pylint: disable=W0212 - self._transaction._append_requirements(*requirements) # pylint: disable=W0212 - else: - self._table._do_commit(updates=updates, requirements=requirements) # pylint: disable=W0212 + return updates, requirements def _apply(self) -> Schema: """Apply the pending changes to the original schema and returns the result. @@ -2126,7 +2079,13 @@ def _apply(self) -> Schema: field_ids.add(field.field_id) - next_schema_id = 1 + (max(self._table.schemas().keys()) if self._table is not None else self._schema.schema_id) + if txn := self._transaction: + next_schema_id = 1 + ( + max(schema.schema_id for schema in txn.table_metadata.schemas) if txn.table_metadata is not None else 0 + ) + else: + next_schema_id = 0 + return Schema(*struct.fields, schema_id=next_schema_id, identifier_field_ids=field_ids) def assign_new_column_id(self) -> int: @@ -2456,20 +2415,6 @@ def _add_and_move_fields( return None if len(adds) == 0 else tuple(*fields, *adds) -def _generate_snapshot_id() -> int: - """Generate a new Snapshot ID from a UUID. - - Returns: An 64 bit long - """ - rnd_uuid = uuid.uuid4() - snapshot_id = int.from_bytes( - bytes(lhs ^ rhs for lhs, rhs in zip(rnd_uuid.bytes[0:8], rnd_uuid.bytes[8:16])), byteorder='little', signed=True - ) - snapshot_id = snapshot_id if snapshot_id >= 0 else snapshot_id * -1 - - return snapshot_id - - @dataclass(frozen=True) class WriteTask: write_uuid: uuid.UUID @@ -2496,7 +2441,7 @@ def _generate_manifest_list_path(location: str, snapshot_id: int, attempt: int, def _dataframe_to_data_files( - table: Table, df: pa.Table, write_uuid: Optional[uuid.UUID] = None, file_schema: Optional[Schema] = None + table_metadata: TableMetadata, df: pa.Table, io: FileIO, write_uuid: Optional[uuid.UUID] = None ) -> Iterable[DataFile]: """Convert a PyArrow table into a DataFile. 
@@ -2505,7 +2450,7 @@ def _dataframe_to_data_files( """ from pyiceberg.io.pyarrow import write_file - if len(table.spec().fields) > 0: + if len([spec for spec in table_metadata.partition_specs if spec.spec_id != 0]) > 0: raise ValueError("Cannot write to partitioned tables") counter = itertools.count(0) @@ -2513,41 +2458,33 @@ def _dataframe_to_data_files( # This is an iter, so we don't have to materialize everything every time # This will be more relevant when we start doing partitioned writes - yield from write_file(table, iter([WriteTask(write_uuid, next(counter), df)]), file_schema=file_schema) + yield from write_file(io=io, table_metadata=table_metadata, tasks=iter([WriteTask(write_uuid, next(counter), df)])) -class _MergingSnapshotProducer: +class _MergingSnapshotProducer(UpdateTableMetadata["_MergingSnapshotProducer"]): commit_uuid: uuid.UUID _operation: Operation - _table: Table _snapshot_id: int _parent_snapshot_id: Optional[int] _added_data_files: List[DataFile] - _transaction: Optional[Transaction] def __init__( self, operation: Operation, - table: Table, + transaction: Transaction, + io: FileIO, commit_uuid: Optional[uuid.UUID] = None, - transaction: Optional[Transaction] = None, ) -> None: + super().__init__(transaction) self.commit_uuid = commit_uuid or uuid.uuid4() + self._io = io self._operation = operation - self._table = table - self._snapshot_id = table.new_snapshot_id() + self._snapshot_id = self._transaction.table_metadata.new_snapshot_id() # Since we only support the main branch for now - self._parent_snapshot_id = snapshot.snapshot_id if (snapshot := self._table.current_snapshot()) else None + self._parent_snapshot_id = ( + snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.current_snapshot()) else None + ) self._added_data_files = [] - self._transaction = transaction - - def __enter__(self) -> _MergingSnapshotProducer: - """Start a transaction to update the table.""" - return self - - def __exit__(self, _: Any, value: Any, traceback: Any) -> None: - """Close and commit the transaction.""" - self.commit() def append_data_file(self, data_file: DataFile) -> _MergingSnapshotProducer: self._added_data_files.append(data_file) @@ -2562,12 +2499,14 @@ def _existing_manifests(self) -> List[ManifestFile]: ... 
def _manifests(self) -> List[ManifestFile]: def _write_added_manifest() -> List[ManifestFile]: if self._added_data_files: - output_file_location = _new_manifest_path(location=self._table.location(), num=0, commit_uuid=self.commit_uuid) + output_file_location = _new_manifest_path( + location=self._transaction.table_metadata.location, num=0, commit_uuid=self.commit_uuid + ) with write_manifest( - format_version=self._table.format_version, - spec=self._table.spec(), - schema=self._table.schema(), - output_file=self._table.io.new_output(output_file_location), + format_version=self._transaction.table_metadata.format_version, + spec=self._transaction.table_metadata.spec(), + schema=self._transaction.table_metadata.schema(), + output_file=self._io.new_output(output_file_location), snapshot_id=self._snapshot_id, ) as writer: for data_file in self._added_data_files: @@ -2588,13 +2527,15 @@ def _write_delete_manifest() -> List[ManifestFile]: # Check if we need to mark the files as deleted deleted_entries = self._deleted_entries() if len(deleted_entries) > 0: - output_file_location = _new_manifest_path(location=self._table.location(), num=1, commit_uuid=self.commit_uuid) + output_file_location = _new_manifest_path( + location=self._transaction.table_metadata.location, num=1, commit_uuid=self.commit_uuid + ) with write_manifest( - format_version=self._table.format_version, - spec=self._table.spec(), - schema=self._table.schema(), - output_file=self._table.io.new_output(output_file_location), + format_version=self._transaction.table_metadata.format_version, + spec=self._transaction.table_metadata.spec(), + schema=self._transaction.table_metadata.schema(), + output_file=self._io.new_output(output_file_location), snapshot_id=self._snapshot_id, ) as writer: for delete_entry in deleted_entries: @@ -2617,7 +2558,11 @@ def _summary(self) -> Summary: for data_file in self._added_data_files: ssc.add_file(data_file=data_file) - previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id) if self._parent_snapshot_id is not None else None + previous_snapshot = ( + self._transaction.table_metadata.snapshot_by_id(self._parent_snapshot_id) + if self._parent_snapshot_id is not None + else None + ) return update_snapshot_summaries( summary=Summary(operation=self._operation, **ssc.build()), @@ -2625,18 +2570,21 @@ def _summary(self) -> Summary: truncate_full_table=self._operation == Operation.OVERWRITE, ) - def commit(self) -> Snapshot: + def _commit(self) -> UpdatesAndRequirements: new_manifests = self._manifests() - next_sequence_number = self._table.next_sequence_number() + next_sequence_number = self._transaction.table_metadata.next_sequence_number() summary = self._summary() manifest_list_file_path = _generate_manifest_list_path( - location=self._table.location(), snapshot_id=self._snapshot_id, attempt=0, commit_uuid=self.commit_uuid + location=self._transaction.table_metadata.location, + snapshot_id=self._snapshot_id, + attempt=0, + commit_uuid=self.commit_uuid, ) with write_manifest_list( - format_version=self._table.metadata.format_version, - output_file=self._table.io.new_output(manifest_list_file_path), + format_version=self._transaction.table_metadata.format_version, + output_file=self._io.new_output(manifest_list_file_path), snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, sequence_number=next_sequence_number, @@ -2649,22 +2597,21 @@ def commit(self) -> Snapshot: manifest_list=manifest_list_file_path, sequence_number=next_sequence_number, summary=summary, - 
schema_id=self._table.schema().schema_id, + schema_id=self._transaction.table_metadata.current_schema_id, ) - if self._transaction is not None: - self._transaction.add_snapshot(snapshot=snapshot) - self._transaction.set_ref_snapshot( - snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch" - ) - else: - with self._table.transaction() as tx: - tx.add_snapshot(snapshot=snapshot) - tx.set_ref_snapshot( + return ( + ( + AddSnapshotUpdate(snapshot=snapshot), + SetSnapshotRefUpdate( snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch" - ) - - return snapshot + ), + ), + ( + AssertTableUUID(uuid=self._transaction.table_metadata.table_uuid), + AssertRefSnapshotId(snapshot_id=self._parent_snapshot_id, ref="main"), + ), + ) class FastAppendFiles(_MergingSnapshotProducer): @@ -2677,12 +2624,12 @@ def _existing_manifests(self) -> List[ManifestFile]: existing_manifests = [] if self._parent_snapshot_id is not None: - previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id) + previous_snapshot = self._transaction.table_metadata.snapshot_by_id(self._parent_snapshot_id) if previous_snapshot is None: raise ValueError(f"Snapshot could not be found: {self._parent_snapshot_id}") - for manifest in previous_snapshot.manifests(io=self._table.io): + for manifest in previous_snapshot.manifests(io=self._io): if manifest.has_added_files() or manifest.has_existing_files() or manifest.added_snapshot_id == self._snapshot_id: existing_manifests.append(manifest) @@ -2713,7 +2660,7 @@ def _deleted_entries(self) -> List[ManifestEntry]: which entries are affected. """ if self._parent_snapshot_id is not None: - previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id) + previous_snapshot = self._transaction.table_metadata.snapshot_by_id(self._parent_snapshot_id) if previous_snapshot is None: # This should never happen since you cannot overwrite an empty table raise ValueError(f"Could not find the previous snapshot: {self._parent_snapshot_id}") @@ -2729,37 +2676,39 @@ def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]: file_sequence_number=entry.file_sequence_number, data_file=entry.data_file, ) - for entry in manifest.fetch_manifest_entry(self._table.io, discard_deleted=True) + for entry in manifest.fetch_manifest_entry(self._io, discard_deleted=True) if entry.data_file.content == DataFileContent.DATA ] - list_of_entries = executor.map(_get_entries, previous_snapshot.manifests(self._table.io)) + list_of_entries = executor.map(_get_entries, previous_snapshot.manifests(self._io)) return list(chain(*list_of_entries)) else: return [] class UpdateSnapshot: - _table: Table - _transaction: Optional[Transaction] + _transaction: Transaction + _io: FileIO - def __init__(self, table: Table, transaction: Optional[Transaction] = None) -> None: - self._table = table + def __init__(self, transaction: Transaction, io: FileIO) -> None: self._transaction = transaction + self._io = io def fast_append(self) -> FastAppendFiles: - return FastAppendFiles(table=self._table, operation=Operation.APPEND, transaction=self._transaction) + return FastAppendFiles(operation=Operation.APPEND, transaction=self._transaction, io=self._io) def overwrite(self) -> OverwriteFiles: return OverwriteFiles( - table=self._table, - operation=Operation.OVERWRITE if self._table.current_snapshot() is not None else Operation.APPEND, + operation=Operation.OVERWRITE + if self._transaction.table_metadata.current_snapshot() is not 
None + else Operation.APPEND, transaction=self._transaction, + io=self._io, ) -class UpdateSpec: - _table: Table +class UpdateSpec(UpdateTableMetadata["UpdateSpec"]): + _transaction: Transaction _name_to_field: Dict[str, PartitionField] = {} _name_to_added_field: Dict[str, PartitionField] = {} _transform_to_field: Dict[Tuple[int, str], PartitionField] = {} @@ -2770,17 +2719,18 @@ class UpdateSpec: _adds: List[PartitionField] _deletes: Set[int] _last_assigned_partition_id: int - _transaction: Optional[Transaction] - def __init__(self, table: Table, transaction: Optional[Transaction] = None, case_sensitive: bool = True) -> None: - self._table = table - self._name_to_field = {field.name: field for field in table.spec().fields} + def __init__(self, transaction: Transaction, case_sensitive: bool = True) -> None: + super().__init__(transaction) + self._name_to_field = {field.name: field for field in transaction.table_metadata.spec().fields} self._name_to_added_field = {} - self._transform_to_field = {(field.source_id, repr(field.transform)): field for field in table.spec().fields} + self._transform_to_field = { + (field.source_id, repr(field.transform)): field for field in transaction.table_metadata.spec().fields + } self._transform_to_added_field = {} self._adds = [] self._deletes = set() - self._last_assigned_partition_id = table.last_partition_id() + self._last_assigned_partition_id = transaction.table_metadata.last_partition_id or PARTITION_FIELD_ID_START - 1 self._renames = {} self._transaction = transaction self._case_sensitive = case_sensitive @@ -2793,7 +2743,7 @@ def add_field( partition_field_name: Optional[str] = None, ) -> UpdateSpec: ref = Reference(source_column_name) - bound_ref = ref.bind(self._table.schema(), self._case_sensitive) + bound_ref = ref.bind(self._transaction.table_metadata.schema(), self._case_sensitive) # verify transform can actually bind it output_type = bound_ref.field.field_type if not transform.can_transform(output_type): @@ -2864,31 +2814,24 @@ def rename_field(self, name: str, new_name: str) -> UpdateSpec: self._renames[name] = new_name return self - def commit(self) -> None: + def _commit(self) -> UpdatesAndRequirements: new_spec = self._apply() - if self._table.metadata.default_spec_id != new_spec.spec_id: - if new_spec.spec_id not in self._table.specs(): - updates = [AddPartitionSpecUpdate(spec=new_spec), SetDefaultSpecUpdate(spec_id=-1)] - else: - updates = [SetDefaultSpecUpdate(spec_id=new_spec.spec_id)] - - required_last_assigned_partitioned_id = self._table.last_partition_id() - requirements = [AssertLastAssignedPartitionId(last_assigned_partition_id=required_last_assigned_partitioned_id)] + updates: Tuple[TableUpdate, ...] = () + requirements: Tuple[TableRequirement, ...] 
= () - if self._transaction is not None: - self._transaction._append_updates(*updates) # pylint: disable=W0212 - self._transaction._append_requirements(*requirements) # pylint: disable=W0212 + if self._transaction.table_metadata.default_spec_id != new_spec.spec_id: + if new_spec.spec_id not in self._transaction.table_metadata.specs(): + updates = ( + AddPartitionSpecUpdate(spec=new_spec), + SetDefaultSpecUpdate(spec_id=-1), + ) else: - requirements.append(AssertDefaultSpecId(default_spec_id=self._table.spec().spec_id)) - self._table._do_commit(updates=tuple(updates), requirements=tuple(requirements)) # pylint: disable=W0212 + updates = (SetDefaultSpecUpdate(spec_id=new_spec.spec_id),) - def __exit__(self, _: Any, value: Any, traceback: Any) -> None: - """Close and commit the change.""" - return self.commit() + required_last_assigned_partitioned_id = self._transaction.table_metadata.last_partition_id + requirements = (AssertLastAssignedPartitionId(last_assigned_partition_id=required_last_assigned_partitioned_id),) - def __enter__(self) -> UpdateSpec: - """Update the table.""" - return self + return updates, requirements def _apply(self) -> PartitionSpec: def _check_and_add_partition_name(schema: Schema, name: str, source_id: int, partition_names: Set[str]) -> None: @@ -2915,27 +2858,47 @@ def _add_new_field( partition_fields = [] partition_names: Set[str] = set() - for field in self._table.spec().fields: + for field in self._transaction.table_metadata.spec().fields: if field.field_id not in self._deletes: renamed = self._renames.get(field.name) if renamed: new_field = _add_new_field( - self._table.schema(), field.source_id, field.field_id, renamed, field.transform, partition_names + self._transaction.table_metadata.schema(), + field.source_id, + field.field_id, + renamed, + field.transform, + partition_names, ) else: new_field = _add_new_field( - self._table.schema(), field.source_id, field.field_id, field.name, field.transform, partition_names + self._transaction.table_metadata.schema(), + field.source_id, + field.field_id, + field.name, + field.transform, + partition_names, ) partition_fields.append(new_field) - elif self._table.format_version == 1: + elif self._transaction.table_metadata.format_version == 1: renamed = self._renames.get(field.name) if renamed: new_field = _add_new_field( - self._table.schema(), field.source_id, field.field_id, renamed, VoidTransform(), partition_names + self._transaction.table_metadata.schema(), + field.source_id, + field.field_id, + renamed, + VoidTransform(), + partition_names, ) else: new_field = _add_new_field( - self._table.schema(), field.source_id, field.field_id, field.name, VoidTransform(), partition_names + self._transaction.table_metadata.schema(), + field.source_id, + field.field_id, + field.name, + VoidTransform(), + partition_names, ) partition_fields.append(new_field) @@ -2952,7 +2915,7 @@ def _add_new_field( # Reuse spec id or create a new one. 
new_spec = PartitionSpec(*partition_fields) new_spec_id = INITIAL_PARTITION_SPEC_ID - for spec in self._table.specs().values(): + for spec in self._transaction.table_metadata.specs().values(): if new_spec.compatible_with(spec): new_spec_id = spec.spec_id break @@ -2961,10 +2924,10 @@ def _add_new_field( return PartitionSpec(*partition_fields, spec_id=new_spec_id) def _partition_field(self, transform_key: Tuple[int, Transform[Any, Any]], name: Optional[str]) -> PartitionField: - if self._table.metadata.format_version == 2: + if self._transaction.table_metadata.format_version == 2: source_id, transform = transform_key historical_fields = [] - for spec in self._table.specs().values(): + for spec in self._transaction.table_metadata.specs().values(): for field in spec.fields: historical_fields.append((field.source_id, field.field_id, repr(field.transform), field.name)) @@ -2976,7 +2939,7 @@ def _partition_field(self, transform_key: Tuple[int, Transform[Any, Any]], name: new_field_id = self._new_field_id() if name is None: tmp_field = PartitionField(transform_key[0], new_field_id, transform_key[1], 'unassigned_field_name') - name = _visit_partition_field(self._table.schema(), tmp_field, _PartitionNameGenerator()) + name = _visit_partition_field(self._transaction.table_metadata.schema(), tmp_field, _PartitionNameGenerator()) return PartitionField(transform_key[0], new_field_id, transform_key[1], name) def _new_field_id(self) -> int: diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index 2fd8b13af4..c716915192 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -226,11 +226,54 @@ def schema_by_id(self, schema_id: int) -> Optional[Schema]: """Get the schema by schema_id.""" return next((schema for schema in self.schemas if schema.schema_id == schema_id), None) + def schema(self) -> Schema: + """Return the schema for this table.""" + return next(schema for schema in self.schemas if schema.schema_id == self.current_schema_id) + + def spec(self) -> PartitionSpec: + """Return the partition spec of this table.""" + return next(spec for spec in self.partition_specs if spec.spec_id == self.default_spec_id) + + def specs(self) -> Dict[int, PartitionSpec]: + """Return a dict the partition specs this table.""" + return {spec.spec_id: spec for spec in self.partition_specs} + + def new_snapshot_id(self) -> int: + """Generate a new snapshot-id that's not in use.""" + snapshot_id = _generate_snapshot_id() + while self.snapshot_by_id(snapshot_id) is not None: + snapshot_id = _generate_snapshot_id() + + return snapshot_id + + def current_snapshot(self) -> Optional[Snapshot]: + """Get the current snapshot for this table, or None if there is no current snapshot.""" + if self.current_snapshot_id is not None: + return self.snapshot_by_id(self.current_snapshot_id) + return None + + def next_sequence_number(self) -> int: + return self.last_sequence_number + 1 if self.format_version > 1 else INITIAL_SEQUENCE_NUMBER + def sort_order_by_id(self, sort_order_id: int) -> Optional[SortOrder]: """Get the sort order by sort_order_id.""" return next((sort_order for sort_order in self.sort_orders if sort_order.order_id == sort_order_id), None) +def _generate_snapshot_id() -> int: + """Generate a new Snapshot ID from a UUID. 
+ + Returns: An 64 bit long + """ + rnd_uuid = uuid.uuid4() + snapshot_id = int.from_bytes( + bytes(lhs ^ rhs for lhs, rhs in zip(rnd_uuid.bytes[0:8], rnd_uuid.bytes[8:16])), byteorder='little', signed=True + ) + snapshot_id = snapshot_id if snapshot_id >= 0 else snapshot_id * -1 + + return snapshot_id + + class TableMetadataV1(TableMetadataCommonFields, IcebergBaseModel): """Represents version 1 of the Table Metadata. diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py index 2b1167f1b6..9f4d4af4c7 100644 --- a/tests/catalog/test_sql.py +++ b/tests/catalog/test_sql.py @@ -39,6 +39,7 @@ from pyiceberg.io.pyarrow import schema_to_pyarrow from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC from pyiceberg.schema import Schema +from pyiceberg.table import _dataframe_to_data_files from pyiceberg.table.snapshots import Operation from pyiceberg.table.sorting import ( NullOrder, @@ -863,3 +864,54 @@ def test_concurrent_commit_table(catalog: SqlCatalog, table_schema_simple: Schem # This one should fail since it already has been updated with table_b.update_schema() as update: update.add_column(path="c", field_type=IntegerType()) + + +@pytest.mark.parametrize( + 'catalog', + [ + lazy_fixture('catalog_memory'), + lazy_fixture('catalog_sqlite'), + lazy_fixture('catalog_sqlite_without_rowcount'), + ], +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_write_and_evolve(catalog: SqlCatalog, format_version: int) -> None: + identifier = f"default.arrow_write_data_and_evolve_schema_v{format_version}" + + try: + catalog.create_namespace("default") + except NamespaceAlreadyExistsError: + pass + + try: + catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + pa_table = pa.Table.from_pydict( + { + 'foo': ['a', None, 'z'], + }, + schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]), + ) + + tbl = catalog.create_table(identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)}) + + pa_table_with_column = pa.Table.from_pydict( + { + 'foo': ['a', None, 'z'], + 'bar': [19, None, 25], + }, + schema=pa.schema([ + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=True), + ]), + ) + + with tbl.transaction() as txn: + with txn.update_schema() as schema_txn: + schema_txn.union_by_name(pa_table_with_column.schema) + + with txn.update_snapshot().fast_append() as snapshot_update: + for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table_with_column, io=tbl.io): + snapshot_update.append_data_file(data_file) diff --git a/tests/integration/test_rest_schema.py b/tests/integration/test_rest_schema.py index aae07caba9..17fb338080 100644 --- a/tests/integration/test_rest_schema.py +++ b/tests/integration/test_rest_schema.py @@ -84,7 +84,7 @@ def _create_table_with_schema(catalog: Catalog, schema: Schema) -> Table: @pytest.mark.integration def test_add_already_exists(catalog: Catalog, table_schema_nested: Schema) -> None: table = _create_table_with_schema(catalog, table_schema_nested) - update = UpdateSchema(table) + update = table.update_schema() with pytest.raises(ValueError) as exc_info: update.add_column("foo", IntegerType()) @@ -98,7 +98,7 @@ def test_add_already_exists(catalog: Catalog, table_schema_nested: Schema) -> No @pytest.mark.integration def test_add_to_non_struct_type(catalog: Catalog, table_schema_simple: Schema) -> None: table = _create_table_with_schema(catalog, table_schema_simple) - update = UpdateSchema(table) + update = 
table.update_schema() with pytest.raises(ValueError) as exc_info: update.add_column(path=("foo", "lat"), field_type=IntegerType()) assert "Cannot add column 'lat' to non-struct type: foo" in str(exc_info.value) @@ -1066,13 +1066,13 @@ def test_add_nested_list_of_structs(catalog: Catalog) -> None: def test_add_required_column(catalog: Catalog) -> None: schema_ = Schema(NestedField(field_id=1, name="a", field_type=BooleanType(), required=False)) table = _create_table_with_schema(catalog, schema_) - update = UpdateSchema(table) + update = table.update_schema() with pytest.raises(ValueError) as exc_info: update.add_column(path="data", field_type=IntegerType(), required=True) assert "Incompatible change: cannot add required column: data" in str(exc_info.value) new_schema = ( - UpdateSchema(table, allow_incompatible_changes=True) # pylint: disable=W0212 + UpdateSchema(transaction=table.transaction(), allow_incompatible_changes=True) .add_column(path="data", field_type=IntegerType(), required=True) ._apply() ) @@ -1088,12 +1088,13 @@ def test_add_required_column_case_insensitive(catalog: Catalog) -> None: table = _create_table_with_schema(catalog, schema_) with pytest.raises(ValueError) as exc_info: - with UpdateSchema(table, allow_incompatible_changes=True) as update: - update.case_sensitive(False).add_column(path="ID", field_type=IntegerType(), required=True) + with table.transaction() as txn: + with txn.update_schema(allow_incompatible_changes=True) as update: + update.case_sensitive(False).add_column(path="ID", field_type=IntegerType(), required=True) assert "already exists: ID" in str(exc_info.value) new_schema = ( - UpdateSchema(table, allow_incompatible_changes=True) # pylint: disable=W0212 + UpdateSchema(transaction=table.transaction(), allow_incompatible_changes=True) .add_column(path="ID", field_type=IntegerType(), required=True) ._apply() ) @@ -1264,7 +1265,7 @@ def test_mixed_changes(catalog: Catalog) -> None: @pytest.mark.integration def test_ambiguous_column(catalog: Catalog, table_schema_nested: Schema) -> None: table = _create_table_with_schema(catalog, table_schema_nested) - update = UpdateSchema(table) + update = UpdateSchema(transaction=table.transaction()) with pytest.raises(ValueError) as exc_info: update.add_column(path="location.latitude", field_type=IntegerType()) @@ -2507,16 +2508,14 @@ def test_two_add_schemas_in_a_single_transaction(catalog: Catalog) -> None: ), ) - with pytest.raises(ValueError) as exc_info: + with pytest.raises(CommitFailedException) as exc_info: with tbl.transaction() as tr: with tr.update_schema() as update: update.add_column("bar", field_type=StringType()) with tr.update_schema() as update: update.add_column("baz", field_type=StringType()) - assert "Updates in a single commit need to be unique, duplicate: " in str( - exc_info.value - ) + assert "CommitFailedException: Requirement failed: current schema changed: expected id 1 != 0" in str(exc_info.value) @pytest.mark.integration diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py index a16ba48a27..388e566bce 100644 --- a/tests/integration/test_writes.py +++ b/tests/integration/test_writes.py @@ -652,5 +652,5 @@ def test_write_and_evolve(session_catalog: Catalog, format_version: int) -> None schema_txn.union_by_name(pa_table_with_column.schema) with txn.update_snapshot().fast_append() as snapshot_update: - for data_file in _dataframe_to_data_files(table=tbl, df=pa_table_with_column, file_schema=txn.schema()): + for data_file in 
_dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table_with_column, io=tbl.io): snapshot_update.append_data_file(data_file) diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 39aa72f8db..e6407b60cb 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -62,12 +62,11 @@ UpdateSchema, _apply_table_update, _check_schema, - _generate_snapshot_id, _match_deletes_to_data_file, _TableMetadataUpdateContext, update_table_metadata, ) -from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2 +from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2, _generate_snapshot_id from pyiceberg.table.snapshots import ( Operation, Snapshot, @@ -435,7 +434,7 @@ def test_serialize_set_properties_updates() -> None: def test_add_column(table_v2: Table) -> None: - update = UpdateSchema(table_v2) + update = UpdateSchema(transaction=table_v2.transaction()) update.add_column(path="b", field_type=IntegerType()) apply_schema: Schema = update._apply() # pylint: disable=W0212 assert len(apply_schema.fields) == 4 @@ -469,7 +468,7 @@ def test_add_primitive_type_column(table_v2: Table) -> None: for name, type_ in primitive_type.items(): field_name = f"new_column_{name}" - update = UpdateSchema(table_v2) + update = UpdateSchema(transaction=table_v2.transaction()) update.add_column(path=field_name, field_type=type_, doc=f"new_column_{name}") new_schema = update._apply() # pylint: disable=W0212 @@ -481,7 +480,7 @@ def test_add_primitive_type_column(table_v2: Table) -> None: def test_add_nested_type_column(table_v2: Table) -> None: # add struct type column field_name = "new_column_struct" - update = UpdateSchema(table_v2) + update = UpdateSchema(transaction=table_v2.transaction()) struct_ = StructType( NestedField(1, "lat", DoubleType()), NestedField(2, "long", DoubleType()), @@ -499,7 +498,7 @@ def test_add_nested_type_column(table_v2: Table) -> None: def test_add_nested_map_type_column(table_v2: Table) -> None: # add map type column field_name = "new_column_map" - update = UpdateSchema(table_v2) + update = UpdateSchema(transaction=table_v2.transaction()) map_ = MapType(1, StringType(), 2, IntegerType(), False) update.add_column(path=field_name, field_type=map_) new_schema = update._apply() # pylint: disable=W0212 @@ -511,7 +510,7 @@ def test_add_nested_map_type_column(table_v2: Table) -> None: def test_add_nested_list_type_column(table_v2: Table) -> None: # add list type column field_name = "new_column_list" - update = UpdateSchema(table_v2) + update = UpdateSchema(transaction=table_v2.transaction()) list_ = ListType( element_id=101, element_type=StructType( @@ -806,7 +805,7 @@ def test_metadata_isolation_from_illegal_updates(table_v1: Table) -> None: def test_generate_snapshot_id(table_v2: Table) -> None: assert isinstance(_generate_snapshot_id(), int) - assert isinstance(table_v2.new_snapshot_id(), int) + assert isinstance(table_v2.metadata.new_snapshot_id(), int) def test_assert_create(table_v2: Table) -> None: diff --git a/tests/test_schema.py b/tests/test_schema.py index cfee6e7f14..6394b72ba6 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -928,7 +928,7 @@ def primitive_fields() -> List[NestedField]: def test_add_top_level_primitives(primitive_fields: NestedField) -> None: for primitive_field in primitive_fields: new_schema = Schema(primitive_field) - applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + applied = UpdateSchema(transaction=None, 
schema=Schema()).union_by_name(new_schema)._apply() # type: ignore assert applied == new_schema @@ -942,7 +942,7 @@ def test_add_top_level_list_of_primitives(primitive_fields: NestedField) -> None required=False, ) ) - applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply() # type: ignore assert applied.as_struct() == new_schema.as_struct() @@ -958,7 +958,7 @@ def test_add_top_level_map_of_primitives(primitive_fields: NestedField) -> None: required=False, ) ) - applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply() # type: ignore assert applied.as_struct() == new_schema.as_struct() @@ -972,7 +972,7 @@ def test_add_top_struct_of_primitives(primitive_fields: NestedField) -> None: required=False, ) ) - applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply() # type: ignore assert applied.as_struct() == new_schema.as_struct() @@ -987,7 +987,7 @@ def test_add_nested_primitive(primitive_fields: NestedField) -> None: required=False, ) ) - applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply() + applied = UpdateSchema(None, None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore assert applied.as_struct() == new_schema.as_struct() @@ -1007,7 +1007,7 @@ def test_add_nested_primitives(primitive_fields: NestedField) -> None: field_id=1, name="aStruct", field_type=StructType(*_primitive_fields(TEST_PRIMITIVE_TYPES, 2)), required=False ) ) - applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply() + applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore assert applied.as_struct() == new_schema.as_struct() @@ -1048,7 +1048,7 @@ def test_add_nested_lists(primitive_fields: NestedField) -> None: required=False, ) ) - applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply() # type: ignore assert applied.as_struct() == new_schema.as_struct() @@ -1098,7 +1098,7 @@ def test_add_nested_struct(primitive_fields: NestedField) -> None: required=False, ) ) - applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply() # type: ignore assert applied.as_struct() == new_schema.as_struct() @@ -1141,7 +1141,7 @@ def test_add_nested_maps(primitive_fields: NestedField) -> None: required=False, ) ) - applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply() + applied = UpdateSchema(transaction=None, schema=Schema()).union_by_name(new_schema)._apply() # type: ignore assert applied.as_struct() == new_schema.as_struct() @@ -1164,7 +1164,7 @@ def test_detect_invalid_top_level_list() -> None: ) with pytest.raises(ValidationError, match="Cannot change column type: aList.element: string -> double"): - _ = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply() + _ = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore def test_detect_invalid_top_level_maps() -> None: @@ -1186,14 +1186,14 @@ def 
test_detect_invalid_top_level_maps() -> None: ) with pytest.raises(ValidationError, match="Cannot change column type: aMap.key: string -> uuid"): - _ = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply() + _ = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore def test_promote_float_to_double() -> None: current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=FloatType(), required=False)) new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DoubleType(), required=False)) - applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply() + applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore assert applied.as_struct() == new_schema.as_struct() assert len(applied.fields) == 1 @@ -1205,7 +1205,7 @@ def test_detect_invalid_promotion_double_to_float() -> None: new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=FloatType(), required=False)) with pytest.raises(ValidationError, match="Cannot change column type: aCol: double -> float"): - _ = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply() + _ = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore # decimal(P,S) Fixed-point decimal; precision P, scale S -> Scale is fixed [1], @@ -1214,7 +1214,7 @@ def test_type_promote_decimal_to_fixed_scale_with_wider_precision() -> None: current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DecimalType(precision=20, scale=1), required=False)) new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DecimalType(precision=22, scale=1), required=False)) - applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply() + applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore assert applied.as_struct() == new_schema.as_struct() assert len(applied.fields) == 1 @@ -1282,7 +1282,7 @@ def test_add_nested_structs(primitive_fields: NestedField) -> None: required=False, ) ) - applied = UpdateSchema(None, schema=schema).union_by_name(new_schema)._apply() + applied = UpdateSchema(transaction=None, schema=schema).union_by_name(new_schema)._apply() # type: ignore expected = Schema( NestedField( @@ -1322,7 +1322,7 @@ def test_replace_list_with_primitive() -> None: new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=StringType())) with pytest.raises(ValidationError, match="Cannot change column type: list is not a primitive"): - _ = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply() + _ = UpdateSchema(transaction=None, schema=current_schema).union_by_name(new_schema)._apply() # type: ignore def test_mirrored_schemas() -> None: @@ -1345,7 +1345,7 @@ def test_mirrored_schemas() -> None: NestedField(9, "string6", StringType(), required=False), ) - applied = UpdateSchema(None, schema=current_schema).union_by_name(mirrored_schema)._apply() + applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(mirrored_schema)._apply() # type: ignore assert applied.as_struct() == current_schema.as_struct() @@ -1397,7 +1397,7 @@ def test_add_new_top_level_struct() -> None: ), ) - applied = UpdateSchema(None, schema=current_schema).union_by_name(observed_schema)._apply() + applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(observed_schema)._apply() # type: ignore assert 
applied.as_struct() == observed_schema.as_struct() @@ -1476,7 +1476,7 @@ def test_append_nested_struct() -> None: ) ) - applied = UpdateSchema(None, schema=current_schema).union_by_name(observed_schema)._apply() + applied = UpdateSchema(transaction=None, schema=current_schema).union_by_name(observed_schema)._apply() # type: ignore assert applied.as_struct() == observed_schema.as_struct() @@ -1541,7 +1541,7 @@ def test_append_nested_lists() -> None: required=False, ) ) - union = UpdateSchema(None, schema=current_schema).union_by_name(observed_schema)._apply() + union = UpdateSchema(transaction=None, schema=current_schema).union_by_name(observed_schema)._apply() # type: ignore expected = Schema( NestedField( @@ -1591,7 +1591,7 @@ def test_union_with_pa_schema(primitive_fields: NestedField) -> None: pa.field("baz", pa.bool_(), nullable=True), ]) - new_schema = UpdateSchema(None, schema=base_schema).union_by_name(pa_schema)._apply() + new_schema = UpdateSchema(transaction=None, schema=base_schema).union_by_name(pa_schema)._apply() # type: ignore expected_schema = Schema( NestedField(field_id=1, name="foo", field_type=StringType(), required=True),