Skip to content

Commit

Permalink
Dataclass with PartitionSpec
Browse files Browse the repository at this point in the history
  • Loading branch information
Steve Zhang committed May 19, 2022
1 parent 9f52a6f commit 136228a
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 72 deletions.
1 change: 1 addition & 0 deletions python/spellcheck-dictionary.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ io
NativeFile
NestedField
nullability
PartitionField
pragma
PrimitiveType
pyarrow
Expand Down
106 changes: 35 additions & 71 deletions python/src/iceberg/table/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Dict, Iterable, List, Tuple
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

from iceberg.schema import Schema
from iceberg.transforms import Transform

_PARTITION_DATA_ID_START: int = 1000


@dataclass(frozen=True)
class PartitionField:
"""
PartitionField is a single element with name and unique id,
Expand All @@ -33,43 +36,16 @@ class PartitionField:
name(str): The name of this partition field
"""

def __init__(self, source_id: int, field_id: int, transform: Transform, name: str):
self._source_id = source_id
self._field_id = field_id
self._transform = transform
self._name = name

@property
def source_id(self) -> int:
return self._source_id

@property
def field_id(self) -> int:
return self._field_id

@property
def name(self) -> str:
return self._name

@property
def transform(self) -> Transform:
return self._transform

def __eq__(self, other):
return (
self.field_id == other.field_id
and self.source_id == other.source_id
and self.name == other.name
and self.transform == other.transform
)
source_id: int
field_id: int
transform: Transform
name: str

def __str__(self):
return f"{self.field_id}: {self.name}: {self.transform}({self.source_id})"

def __repr__(self):
return f"PartitionField(field_id={self.field_id}, name={self.name}, transform={repr(self.transform)}, source_id={self.source_id})"


@dataclass(eq=False)
class PartitionSpec:
"""
PartitionSpec capture the transformation from table data to partition values
Expand All @@ -81,34 +57,32 @@ class PartitionSpec:
last_assigned_field_id(int): auto-increment partition field id starting from PARTITION_DATA_ID_START
"""

def __init__(self, schema: Schema, spec_id: int, fields: Iterable[PartitionField], last_assigned_field_id: int):
self._schema = schema
self._spec_id = spec_id
self._fields = tuple(fields)
self._last_assigned_field_id = last_assigned_field_id
# derived
self._fields_by_source_id: Dict[int, List[PartitionField]] = {}

@property
def schema(self) -> Schema:
return self._schema
schema: Schema
spec_id: int
fields: Tuple[PartitionField]
last_assigned_field_id: int
source_id_to_fields_map: Dict[int, List[PartitionField]] = field(init=False)

@property
def spec_id(self) -> int:
return self._spec_id

@property
def fields(self) -> Tuple[PartitionField, ...]:
return self._fields

@property
def last_assigned_field_id(self) -> int:
return self._last_assigned_field_id
def __post_init__(self):
self.source_id_to_fields_map = dict()
for partition_field in self.fields:
source_column = self.schema.find_column_name(partition_field.source_id)
if not source_column:
raise ValueError(f"Cannot find source column: {partition_field.source_id}")
existing = self.source_id_to_fields_map.get(partition_field.source_id, [])
existing.append(partition_field)
self.source_id_to_fields_map[partition_field.source_id] = existing

def __eq__(self, other):
"""
Equality check on spec_id and partition fields only
"""
return self.spec_id == other.spec_id and self.fields == other.fields

def __str__(self):
"""
PartitionSpec str method highlight the partition field only
"""
result_str = "["
for partition_field in self.fields:
result_str += f"\n {str(partition_field)}"
Expand All @@ -117,30 +91,20 @@ def __str__(self):
result_str += "]"
return result_str

def __repr__(self):
return f"PartitionSpec: {str(self)}"

def is_unpartitioned(self) -> bool:
return len(self.fields) < 1

def fields_by_source_id(self, field_id: int) -> List[PartitionField]:
if not self._fields_by_source_id:
for partition_field in self.fields:
source_column = self.schema.find_column_name(partition_field.source_id)
if not source_column:
raise ValueError(f"Cannot find source column: {partition_field.source_id}")
existing = self._fields_by_source_id.get(partition_field.source_id, [])
existing.append(partition_field)
self._fields_by_source_id[partition_field.source_id] = existing
return self._fields_by_source_id[field_id]
return self.source_id_to_fields_map[field_id]

def compatible_with(self, other: "PartitionSpec") -> bool:
"""
Returns true if this partition spec is equivalent to the other, with partition field_id ignored.
That is, if both specs have the same number of fields, field order, field name, source column ids, and transforms.
"""
return all(
this_field.source_id == that_field.source_id and this_field.transform == that_field.transform and this_field.name == that_field.name
for this_field, that_field
in zip(self.fields, other.fields)
)
this_field.source_id == that_field.source_id
and this_field.transform == that_field.transform
and this_field.name == that_field.name
for this_field, that_field in zip(self.fields, other.fields)
)
5 changes: 4 additions & 1 deletion python/tests/table/test_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ def test_partition_field_init():
assert partition_field.transform == bucket_transform
assert partition_field.name == "id"
assert partition_field == partition_field
print(str(partition_field))
print("repr")
print(repr(partition_field))
assert str(partition_field) == "1000: id: bucket[100](3)"
assert (
repr(partition_field)
== "PartitionField(field_id=1000, name=id, transform=transforms.bucket(source_type=IntegerType(), num_buckets=100), source_id=3)"
== "PartitionField(source_id=3, field_id=1000, transform=transforms.bucket(source_type=IntegerType(), num_buckets=100), name='id')"
)


Expand Down

0 comments on commit 136228a

Please sign in to comment.