Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Steve Zhang committed May 10, 2022
1 parent 9a08f94 commit 75707bd
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 17 deletions.
29 changes: 13 additions & 16 deletions python/src/iceberg/table/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Dict, List, Tuple
from typing import Dict, Iterable, List, Tuple

from iceberg.schema import Schema
from iceberg.transforms import Transform

_PARTITION_DATA_ID_START: int = 1000


class PartitionField:
"""
Expand Down Expand Up @@ -83,15 +85,13 @@ class PartitionSpec:
last_assigned_field_id(int): auto-increment partition field id starting from PARTITION_DATA_ID_START
"""

PARTITION_DATA_ID_START: int = 1000

def __init__(self, schema: Schema, spec_id: int, fields: Tuple[PartitionField], last_assigned_field_id: int):
def __init__(self, schema: Schema, spec_id: int, fields: Iterable[PartitionField], last_assigned_field_id: int):
self._schema = schema
self._spec_id = spec_id
self._fields = fields
self._fields = tuple(fields)
self._last_assigned_field_id = last_assigned_field_id
# derived
self.fields_by_source_id: Dict[int, List[PartitionField]] = {}
self._fields_by_source_id: Dict[int, List[PartitionField]] = {}

@property
def schema(self) -> Schema:
Expand All @@ -102,7 +102,7 @@ def spec_id(self) -> int:
return self._spec_id

@property
def fields(self) -> Tuple[PartitionField]:
def fields(self) -> Tuple[PartitionField, ...]:
return self._fields

@property
Expand Down Expand Up @@ -131,28 +131,26 @@ def __hash__(self):
def is_unpartitioned(self) -> bool:
return len(self.fields) < 1

def get_fields_by_source_id(self, filed_id: int) -> List[PartitionField]:
if not self.fields_by_source_id:
def fields_by_source_id(self, field_id: int) -> List[PartitionField]:
if not self._fields_by_source_id:
for partition_field in self.fields:
source_column = self.schema.find_column_name(partition_field.source_id)
if not source_column:
raise ValueError(f"Cannot find source column: {partition_field.source_id}")
existing = self.fields_by_source_id.get(partition_field.source_id, [])
existing = self._fields_by_source_id.get(partition_field.source_id, [])
existing.append(partition_field)
self.fields_by_source_id[partition_field.source_id] = existing
return self.fields_by_source_id.get(filed_id) # type: ignore
self._fields_by_source_id[partition_field.source_id] = existing
return self._fields_by_source_id[field_id]

def compatible_with(self, other) -> bool:
"""
Returns true if this partition spec is equivalent to the other, with partition field_id ignored.
That is, if both specs have the same number of fields, field order, field name, source column ids, and transforms.
"""
if self.__eq__(other):
if self == other:
return True

if len(self.fields) != len(other.fields):
return False

for index in range(len(self.fields)):
this_field = self.fields[index]
that_field = other.fields[index]
Expand All @@ -162,5 +160,4 @@ def compatible_with(self, other) -> bool:
or this_field.name != that_field.name
):
return False

return True
2 changes: 1 addition & 1 deletion python/tests/table/test_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_partition_spec_init(table_schema_simple: Schema):
partition_spec2 = PartitionSpec(table_schema_simple, 0, (id_field2,), 1001)
assert hash(partition_spec1) != hash(partition_spec2)
assert partition_spec1.compatible_with(partition_spec2)
assert partition_spec1.get_fields_by_source_id(3) == [id_field1]
assert partition_spec1.fields_by_source_id(3) == [id_field1]


def test_unpartitioned(table_schema_simple: Schema):
Expand Down

0 comments on commit 75707bd

Please sign in to comment.