Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Snapshot Retention Properties #913

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2198,6 +2198,70 @@ def create_branch(
self._requirements += requirement
return self

def set_min_snapshots_to_keep(self, branch_name: str, min_snapshots_to_keep: int) -> ManageSnapshots:
    """
    Set the minimum number of snapshots to keep on the given branch.

    The branch keeps pointing at its current snapshot, and the other
    retention settings already stored on the ref (max snapshot age and
    max ref age) are carried over unchanged.

    Args:
        branch_name (str): name of the branch
        min_snapshots_to_keep (int): minimum number of snapshots to keep

    Returns:
        This for method chaining

    Raises:
        ValidationError: if the ref does not exist or is not a branch
    """
    # NOTE(review): presumably flushes pending ref updates so that a ref
    # created earlier in the same chain is visible in the lookup below.
    self._commit_if_ref_updates_exist()
    refs = self._transaction.table_metadata.refs
    if branch_name not in refs:
        raise ValidationError(f"ref {branch_name} not found")
    ref = refs[branch_name]
    # Rewriting a tag with the BRANCH type would silently convert it into a
    # branch, so refuse anything that is not already a branch.
    if str(ref.snapshot_ref_type) != str(SnapshotRefType.BRANCH):
        raise ValidationError(f"ref {branch_name} is not a branch")
    return self._set_ref_snapshot(
        snapshot_id=ref.snapshot_id,
        ref_name=branch_name,
        type=str(SnapshotRefType.BRANCH),
        min_snapshots_to_keep=min_snapshots_to_keep,
        # Pass the existing values through; otherwise replacing the ref
        # would reset them.
        max_snapshot_age_ms=ref.max_snapshot_age_ms,
        max_ref_age_ms=ref.max_ref_age_ms,
    )

def set_max_snapshot_age_ms(self, branch_name: str, max_snapshot_age_ms: int) -> ManageSnapshots:
    """
    Set the maximum snapshot age, in milliseconds, on the given branch.

    The branch keeps pointing at its current snapshot, and the other
    retention settings already stored on the ref (minimum snapshots to keep
    and max ref age) are carried over unchanged.

    Args:
        branch_name (str): name of the branch
        max_snapshot_age_ms (int): maximum snapshot age in milliseconds

    Returns:
        This for method chaining

    Raises:
        ValidationError: if the ref does not exist or is not a branch
    """
    # NOTE(review): presumably flushes pending ref updates so that a ref
    # created earlier in the same chain is visible in the lookup below.
    self._commit_if_ref_updates_exist()
    refs = self._transaction.table_metadata.refs
    if branch_name not in refs:
        raise ValidationError(f"ref {branch_name} not found")
    ref = refs[branch_name]
    # Rewriting a tag with the BRANCH type would silently convert it into a
    # branch, so refuse anything that is not already a branch.
    if str(ref.snapshot_ref_type) != str(SnapshotRefType.BRANCH):
        raise ValidationError(f"ref {branch_name} is not a branch")
    return self._set_ref_snapshot(
        snapshot_id=ref.snapshot_id,
        ref_name=branch_name,
        type=str(SnapshotRefType.BRANCH),
        max_snapshot_age_ms=max_snapshot_age_ms,
        # Pass the existing values through; otherwise replacing the ref
        # would reset them.
        min_snapshots_to_keep=ref.min_snapshots_to_keep,
        max_ref_age_ms=ref.max_ref_age_ms,
    )

def set_max_ref_age_ms(self, ref_name: str, max_ref_age_ms: int) -> ManageSnapshots:
    """
    Set the maximum ref age, in milliseconds, on the given branch / tag.

    The ref keeps its current snapshot and type, and the other retention
    settings already stored on it (minimum snapshots to keep and max
    snapshot age) are carried over unchanged.

    Args:
        ref_name (str): name of the branch / tag
        max_ref_age_ms (int): maximum ref age in milliseconds

    Returns:
        This for method chaining

    Raises:
        ValidationError: if the ref does not exist
    """
    # NOTE(review): presumably flushes pending ref updates so that a ref
    # created earlier in the same chain is visible in the lookup below.
    self._commit_if_ref_updates_exist()
    refs = self._transaction.table_metadata.refs
    if ref_name not in refs:
        raise ValidationError(f"ref {ref_name} not found")
    ref = refs[ref_name]
    return self._set_ref_snapshot(
        snapshot_id=ref.snapshot_id,
        ref_name=ref_name,
        # Reuse the stored type so branches stay branches and tags stay tags.
        type=str(ref.snapshot_ref_type),
        max_ref_age_ms=max_ref_age_ms,
        # Pass the existing values through; otherwise replacing the ref
        # would reset them.
        min_snapshots_to_keep=ref.min_snapshots_to_keep,
        max_snapshot_age_ms=ref.max_snapshot_age_ms,
    )


class UpdateSchema(UpdateTableMetadata["UpdateSchema"]):
_schema: Schema
Expand Down
53 changes: 53 additions & 0 deletions tests/integration/test_snapshot_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,56 @@ def test_create_branch(catalog: Catalog) -> None:
branch_snapshot_id = tbl.history()[-2].snapshot_id
tbl.manage_snapshots().create_branch(snapshot_id=branch_snapshot_id, branch_name="branch123").commit()
assert tbl.metadata.refs["branch123"] == SnapshotRef(snapshot_id=branch_snapshot_id, snapshot_ref_type="branch")


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_set_min_snapshots_to_keep(catalog: Catalog) -> None:
    """Create a branch, set its min_snapshots_to_keep in the same update, and check the stored ref."""
    table = catalog.load_table("default.test_table_snapshot_operations")
    history = table.history()
    assert len(history) > 2

    # Branch off the second-to-last entry of the table history.
    target_snapshot_id = history[-2].snapshot_id
    branch_name = "test_branch_min_snapshots_to_keep"
    min_snapshots_to_keep = 2

    with table.manage_snapshots() as ms:
        ms.create_branch(branch_name=branch_name, snapshot_id=target_snapshot_id)
        ms.set_min_snapshots_to_keep(branch_name=branch_name, min_snapshots_to_keep=min_snapshots_to_keep)

    expected_ref = SnapshotRef(
        snapshot_id=target_snapshot_id,
        snapshot_ref_type=str(SnapshotRefType.BRANCH),
        min_snapshots_to_keep=min_snapshots_to_keep,
    )
    assert table.metadata.refs[branch_name] == expected_ref


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_set_max_snapshot_age_ms(catalog: Catalog) -> None:
    """Create a branch, set its max_snapshot_age_ms in the same update, and check the stored ref."""
    table = catalog.load_table("default.test_table_snapshot_operations")
    history = table.history()
    assert len(history) > 3

    # Branch off the third-to-last entry of the table history.
    target_snapshot_id = history[-3].snapshot_id
    branch_name = "test_branch_max_snapshot_age_ms"
    max_snapshot_age_ms = 3600000

    with table.manage_snapshots() as ms:
        ms.create_branch(branch_name=branch_name, snapshot_id=target_snapshot_id)
        ms.set_max_snapshot_age_ms(branch_name=branch_name, max_snapshot_age_ms=max_snapshot_age_ms)

    expected_ref = SnapshotRef(
        snapshot_id=target_snapshot_id,
        snapshot_ref_type=str(SnapshotRefType.BRANCH),
        max_snapshot_age_ms=max_snapshot_age_ms,
    )
    assert table.metadata.refs[branch_name] == expected_ref


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_set_max_ref_age_ms(catalog: Catalog) -> None:
    """Create a branch and a tag, set max_ref_age_ms on both in one update, and check the stored refs."""
    table = catalog.load_table("default.test_table_snapshot_operations")
    history = table.history()
    assert len(history) > 4

    # The branch and the tag point at two different entries of the history.
    branch_snapshot_id = history[-2].snapshot_id
    tag_snapshot_id = history[-3].snapshot_id
    branch_name = "test_branch_max_ref_age_ms"
    tag_name = "test_tag_max_ref_age_ms"
    max_ref_age_ms = 604800000

    with table.manage_snapshots() as ms:
        ms.create_branch(branch_name=branch_name, snapshot_id=branch_snapshot_id)
        ms.set_max_ref_age_ms(ref_name=branch_name, max_ref_age_ms=max_ref_age_ms)
        ms.create_tag(tag_name=tag_name, snapshot_id=tag_snapshot_id)
        ms.set_max_ref_age_ms(ref_name=tag_name, max_ref_age_ms=max_ref_age_ms)

    expected_branch_ref = SnapshotRef(
        snapshot_id=branch_snapshot_id,
        snapshot_ref_type=str(SnapshotRefType.BRANCH),
        max_ref_age_ms=max_ref_age_ms,
    )
    assert table.metadata.refs[branch_name] == expected_branch_ref

    expected_tag_ref = SnapshotRef(
        snapshot_id=tag_snapshot_id,
        snapshot_ref_type=str(SnapshotRefType.TAG),
        max_ref_age_ms=max_ref_age_ms,
    )
    assert table.metadata.refs[tag_name] == expected_tag_ref