From f27f181e25cc6079b08efb983b305a89ee67e052 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Sun, 13 Mar 2022 15:03:48 -0700 Subject: [PATCH 01/23] Add Schema class Co-authored-by: nssalian --- python/src/iceberg/table/schema.py | 182 ++++++++++++++++++ python/tests/table/test_schema.py | 292 +++++++++++++++++++++++++++++ 2 files changed, 474 insertions(+) create mode 100644 python/src/iceberg/table/schema.py create mode 100644 python/tests/table/test_schema.py diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py new file mode 100644 index 000000000000..b233a23e160c --- /dev/null +++ b/python/src/iceberg/table/schema.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from enum import Enum, auto +from typing import Iterable + +try: + from typing import Literal +except ImportError: # pragma: no cover + from typing_extensions import Literal # type: ignore + +from iceberg.types import NestedField, StructType + + +class FIELD_IDENTIFIER_TYPES(Enum): + NAME = auto() + ALIAS = auto() + + +class Schema: + """Schema of a table + + Example: + >>> from iceberg.table.schema import Schema + >>> from iceberg.types import BooleanType, IntegerType, NestedField, StringType + >>> fields = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + >>> table_schema = Schema(fields=fields, schema_id=1, aliases={"qux": 3}) + >>> print(table_schema) + 1: name=foo, type=string, required=True + 2: name=bar, type=int, required=False + 3: name=baz, type=boolean, required=True + """ + + def __init__(self, fields: Iterable[NestedField], schema_id: int, aliases: dict = {}): + self._struct = StructType(*fields) + self._schema_id = schema_id + self._aliases = aliases + + def __str__(self): + schema_str = "" + for field in self.fields: + schema_str += f"{field.field_id}: name={field.name}, type={field.type}, required={field.is_required}\n" + return schema_str.rstrip() + + def __repr__(self): + return f"Schema(fields={repr(self.fields)}, schema_id={self.schema_id})" + + @property + def fields(self): + return self._struct.fields + + @property + def schema_id(self): + return self._schema_id + + @property + def struct(self): + return self._struct + + def _get_field_identifier_by_name(self, field_identifier: str, case_sensitive: bool): + """Get a field ID for a given field name + + Args: + field_identifier (str): A field name + case_sensitive (bool): If False, case will not be considered when retrieving the field from the schema + + Returns: + int: The field ID for the field name + """ + if case_sensitive: + for field in self.fields: + if field.name == field_identifier: + return field.field_id + else: + for field in self.fields: + if field.name.lower() == field_identifier.lower(): + return field.field_id + raise ValueError(f"Cannot get field ID, name not found: {field_identifier}") + + def _get_field_identifier_by_alias(self, field_identifier: str, case_sensitive: bool): + """Get a field ID for a given field alias + + Args: + field_identifier (str): A field alias + case_sensitive (bool): If False, case will not be considered when retrieving the field from the schema + + Returns: + int: The field ID for the field alias + """ + if case_sensitive: + try: + # For case-sensitive, just try looking up the alias and raise if it's not found + return self._aliases[field_identifier] + except KeyError: + raise ValueError(f"Cannot get field ID, alias not found: {field_identifier}") + if not case_sensitive: + matching_fields = [value for key, value in self._aliases.items() if key.lower() == field_identifier.lower()] + if len(matching_fields) == 1: # If one matching alias found, return the corresponding ID + return matching_fields[0] + elif not matching_fields: # If no matching fields, raise + raise ValueError(f"Cannot get field ID, alias not found: {field_identifier}") + + # If multiple IDs are returned for a case-insensitive alias lookup, raise + raise ValueError(f"Cannot get field ID, case-insensitive alias returns multiple results: {field_identifier}") + + def get_field_id( + self, + field_identifier: str, + field_identifier_type: Literal[FIELD_IDENTIFIER_TYPES.NAME, FIELD_IDENTIFIER_TYPES.ALIAS], + case_sensitive: bool = True, + ) -> int: + """Get a field ID for a given NAME or ALIAS + + This calls either the `_get_field_identifier_by_name` method or the `_get_field_identifier_by_alias` method depending on the value of the `field_identifier_type` argument. + + Args: + field_identifier (str): The unique identifier for the field + field_identifier_type (FIELD_IDENTIFIER_TYPES): An FIELD_IDENTIFIER_TYPES value of either NAME or ALIAS + case_sensitive (bool): If False, case will not be considered when retrieving the field from the schema. Default is True + + Raises: + ValueError: If the field identifier does not exist in the schema + + Returns + int: The field ID for the given field identifier + """ + if field_identifier_type == FIELD_IDENTIFIER_TYPES.NAME: + return self._get_field_identifier_by_name(field_identifier=field_identifier, case_sensitive=case_sensitive) + elif field_identifier_type == FIELD_IDENTIFIER_TYPES.ALIAS: + return self._get_field_identifier_by_alias(field_identifier=field_identifier, case_sensitive=case_sensitive) + + def get_field(self, field_id: int) -> NestedField: + """Get a field object for a given field ID + + This returns the NestedField instance for the given `field_id`. + + Args: + field_id (int): The ID of the field + + Raises: + ValueError: If the field ID does not exist in this schema + + Returns + NestedField: The field object for the given field ID + """ + for field in self.fields: + if field.field_id == field_id: + return field + raise ValueError(f"Cannot get field, ID does not exist: {field_id}") + + def get_type(self, field_id: int): + """Get the type of a field by field ID + + Args: + field_id (int): The ID of the field + + Raises: + ValueError: If the field ID does not exist in this schema + + Returns + IcebergType: The type of the field with ID of `field_id` + """ + field = self.get_field(field_id) + return field.type diff --git a/python/tests/table/test_schema.py b/python/tests/table/test_schema.py new file mode 100644 index 000000000000..f99e3f86ce07 --- /dev/null +++ b/python/tests/table/test_schema.py @@ -0,0 +1,292 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +from numpy import False_ + +from iceberg.table import schema +from iceberg.types import BooleanType, IntegerType, NestedField, StringType, StructType + + +def test_schema_init(): + """Test initializing a schema from a list of fields""" + fields = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + table_schema = schema.Schema(fields=fields, schema_id=1) + schema_struct = table_schema.struct + + assert table_schema.fields[0] == fields[0] + assert table_schema.fields[1] == fields[1] + assert table_schema.fields[2] == fields[2] + assert table_schema.schema_id == 1 + assert isinstance(schema_struct, StructType) + + +def test_schema_str(): + """Test casting a schema to a string""" + fields = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + table_schema = schema.Schema(fields=fields, schema_id=1) + assert """1: name=foo, type=string, required=True +2: name=bar, type=int, required=False +3: name=baz, type=boolean, required=True""" == str( + table_schema + ) + + +@pytest.mark.parametrize( + "schema, expected_repr", + [ + ( + schema.Schema(fields=[NestedField(1, "foo", StringType())], schema_id=1), + "Schema(fields=(NestedField(field_id=1, name='foo', field_type=StringType(), is_optional=True),), schema_id=1)", + ), + ( + schema.Schema( + fields=[NestedField(1, "foo", StringType()), NestedField(2, "bar", IntegerType(), is_optional=False)], schema_id=2 + ), + "Schema(fields=(NestedField(field_id=1, name='foo', field_type=StringType(), is_optional=True), NestedField(field_id=2, name='bar', field_type=IntegerType(), is_optional=False)), schema_id=2)", + ), + ], +) +def test_schema_repr(schema, expected_repr): + """Test schema representation""" + assert repr(schema) == expected_repr + + +def test_schema_get_field_id_case_sensitive(): + """Test case-sensitive retrieval of a field ID using the `get_field_id` method""" + fields = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "foobar": 2}) + assert ( + table_schema.get_field_id( + field_identifier="foo", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME, case_sensitive=True + ) + == 1 + ) + assert ( + table_schema.get_field_id( + field_identifier="bar", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME, case_sensitive=True + ) + == 2 + ) + assert ( + table_schema.get_field_id( + field_identifier="baz", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME, case_sensitive=True + ) + == 3 + ) + assert ( + table_schema.get_field_id( + field_identifier="qux", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS, case_sensitive=True + ) + == 1 + ) + assert ( + table_schema.get_field_id( + field_identifier="foobar", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS, case_sensitive=True + ) + == 2 + ) + + +def test_schema_get_field_id_case_insensitive(): + """Test case-insensitive retrieval of a field ID using the `get_field_id` method""" + fields = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "foobar": 2}) + assert ( + table_schema.get_field_id( + field_identifier="fOO", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME, case_sensitive=False + ) + == 1 + ) + assert ( + table_schema.get_field_id( + field_identifier="BAr", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME, case_sensitive=False + ) + == 2 + ) + assert ( + table_schema.get_field_id( + field_identifier="BaZ", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME, case_sensitive=False + ) + == 3 + ) + assert ( + table_schema.get_field_id( + field_identifier="qUx", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS, case_sensitive=False + ) + == 1 + ) + assert ( + table_schema.get_field_id( + field_identifier="fooBAR", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS, case_sensitive=False + ) + == 2 + ) + + +def test_schema_get_field_id_raise_on_not_found(): + """Test raising when the field ID for a given name or alias cannot be found""" + fields = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "foobar": 2}) + + with pytest.raises(ValueError) as exc_info: + table_schema.get_field_id(field_identifier="name1", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME) + + assert "Cannot get field ID, name not found: name1" in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + table_schema.get_field_id( + field_identifier="name2", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME, case_sensitive=False_ + ) + + assert "Cannot get field ID, name not found: name2" in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + table_schema.get_field_id( + field_identifier="case_insensitive_name1", + field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME, + case_sensitive=False, + ) + + assert "Cannot get field ID, name not found: case_insensitive_name1" in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + table_schema.get_field_id( + field_identifier="case_insensitive_name2", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME + ) + + assert "Cannot get field ID, name not found: case_insensitive_name2" in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + table_schema.get_field_id(field_identifier="alias1", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS) + + assert "Cannot get field ID, alias not found: alias1" in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + table_schema.get_field_id(field_identifier="alias2", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS) + + assert "Cannot get field ID, alias not found: alias2" in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + table_schema.get_field_id( + field_identifier="case_insensitive_alias1", + field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS, + case_sensitive=False, + ) + + assert "Cannot get field ID, alias not found: case_insensitive_alias1" in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + table_schema.get_field_id( + field_identifier="case_insensitive_alias2", + field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS, + case_sensitive=False, + ) + + assert "Cannot get field ID, alias not found: case_insensitive_alias2" in str(exc_info.value) + + +def test_schema_get_field_id_raise_on_multiple_case_insensitive_alias_match(): + """Test raising when a case-insensitive alias search returns multiple aliases""" + fields = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "QUX": 2}) + + with pytest.raises(ValueError) as exc_info: + table_schema.get_field_id( + field_identifier="qux", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS, case_sensitive=False + ) + + assert "Cannot get field ID, case-insensitive alias returns multiple results: qux" in str(exc_info.value) + + +def test_schema_get_field(): + """Test retrieving a field using the field's ID""" + fields = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + table_schema = schema.Schema(fields=fields, schema_id=1) + field1 = table_schema.get_field(field_id=1) + field2 = table_schema.get_field(field_id=2) + field3 = table_schema.get_field(field_id=3) + + assert isinstance(field1, NestedField) + assert field1.field_id == 1 + assert field1.type == StringType() + assert field1.is_optional == False + assert isinstance(field2, NestedField) + assert field2.field_id == 2 + assert field2.type == IntegerType() + assert field2.is_optional == True + assert isinstance(field3, NestedField) + assert field3.field_id == 3 + assert field3.type == BooleanType() + assert field3.is_optional == False + + +def test_schema_get_field_raise_on_unknown_field(): + """Test raising when the field ID is not found""" + fields = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + table_schema = schema.Schema(fields=fields, schema_id=1) + + with pytest.raises(ValueError) as exc_info: + table_schema.get_field(field_id=4) + + assert "Cannot get field, ID does not exist: 4" in str(exc_info.value) + + +def test_schema_get_type(): + """Test retrieving a field's type using the field's ID""" + fields = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + table_schema = schema.Schema(fields=fields, schema_id=1) + + assert table_schema.get_type(field_id=1) == StringType() + assert table_schema.get_type(field_id=2) == IntegerType() + assert table_schema.get_type(field_id=3) == BooleanType() From 7cd8b960b9aa777165ea08ece93ecf9ac330c277 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Sun, 13 Mar 2022 16:21:03 -0700 Subject: [PATCH 02/23] Use Enum for field_identifier_type typehint and set default value to NAME --- python/src/iceberg/table/schema.py | 13 +++----- python/tests/table/test_schema.py | 52 ++++++------------------------ 2 files changed, 14 insertions(+), 51 deletions(-) diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index b233a23e160c..06fc06a0d9be 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -18,11 +18,6 @@ from enum import Enum, auto from typing import Iterable -try: - from typing import Literal -except ImportError: # pragma: no cover - from typing_extensions import Literal # type: ignore - from iceberg.types import NestedField, StructType @@ -124,7 +119,7 @@ def _get_field_identifier_by_alias(self, field_identifier: str, case_sensitive: def get_field_id( self, field_identifier: str, - field_identifier_type: Literal[FIELD_IDENTIFIER_TYPES.NAME, FIELD_IDENTIFIER_TYPES.ALIAS], + field_identifier_type: FIELD_IDENTIFIER_TYPES = FIELD_IDENTIFIER_TYPES.NAME, case_sensitive: bool = True, ) -> int: """Get a field ID for a given NAME or ALIAS @@ -143,9 +138,11 @@ def get_field_id( int: The field ID for the given field identifier """ if field_identifier_type == FIELD_IDENTIFIER_TYPES.NAME: - return self._get_field_identifier_by_name(field_identifier=field_identifier, case_sensitive=case_sensitive) + field_id = self._get_field_identifier_by_name(field_identifier=field_identifier, case_sensitive=case_sensitive) elif field_identifier_type == FIELD_IDENTIFIER_TYPES.ALIAS: - return self._get_field_identifier_by_alias(field_identifier=field_identifier, case_sensitive=case_sensitive) + field_id = self._get_field_identifier_by_alias(field_identifier=field_identifier, case_sensitive=case_sensitive) + + return field_id def get_field(self, field_id: int) -> NestedField: """Get a field object for a given field ID diff --git a/python/tests/table/test_schema.py b/python/tests/table/test_schema.py index f99e3f86ce07..87728925a26e 100644 --- a/python/tests/table/test_schema.py +++ b/python/tests/table/test_schema.py @@ -82,24 +82,9 @@ def test_schema_get_field_id_case_sensitive(): NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), ] table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "foobar": 2}) - assert ( - table_schema.get_field_id( - field_identifier="foo", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME, case_sensitive=True - ) - == 1 - ) - assert ( - table_schema.get_field_id( - field_identifier="bar", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME, case_sensitive=True - ) - == 2 - ) - assert ( - table_schema.get_field_id( - field_identifier="baz", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME, case_sensitive=True - ) - == 3 - ) + assert table_schema.get_field_id(field_identifier="foo", case_sensitive=True) == 1 + assert table_schema.get_field_id(field_identifier="bar", case_sensitive=True) == 2 + assert table_schema.get_field_id(field_identifier="baz", case_sensitive=True) == 3 assert ( table_schema.get_field_id( field_identifier="qux", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS, case_sensitive=True @@ -122,24 +107,9 @@ def test_schema_get_field_id_case_insensitive(): NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), ] table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "foobar": 2}) - assert ( - table_schema.get_field_id( - field_identifier="fOO", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME, case_sensitive=False - ) - == 1 - ) - assert ( - table_schema.get_field_id( - field_identifier="BAr", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME, case_sensitive=False - ) - == 2 - ) - assert ( - table_schema.get_field_id( - field_identifier="BaZ", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME, case_sensitive=False - ) - == 3 - ) + assert table_schema.get_field_id(field_identifier="fOO", case_sensitive=False) == 1 + assert table_schema.get_field_id(field_identifier="BAr", case_sensitive=False) == 2 + assert table_schema.get_field_id(field_identifier="BaZ", case_sensitive=False) == 3 assert ( table_schema.get_field_id( field_identifier="qUx", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS, case_sensitive=False @@ -164,14 +134,12 @@ def test_schema_get_field_id_raise_on_not_found(): table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "foobar": 2}) with pytest.raises(ValueError) as exc_info: - table_schema.get_field_id(field_identifier="name1", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME) + table_schema.get_field_id(field_identifier="name1") assert "Cannot get field ID, name not found: name1" in str(exc_info.value) with pytest.raises(ValueError) as exc_info: - table_schema.get_field_id( - field_identifier="name2", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME, case_sensitive=False_ - ) + table_schema.get_field_id(field_identifier="name2", case_sensitive=False_) assert "Cannot get field ID, name not found: name2" in str(exc_info.value) @@ -185,9 +153,7 @@ def test_schema_get_field_id_raise_on_not_found(): assert "Cannot get field ID, name not found: case_insensitive_name1" in str(exc_info.value) with pytest.raises(ValueError) as exc_info: - table_schema.get_field_id( - field_identifier="case_insensitive_name2", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME - ) + table_schema.get_field_id(field_identifier="case_insensitive_name2") assert "Cannot get field ID, name not found: case_insensitive_name2" in str(exc_info.value) From d8b79a781ec32a01cf65b186370e44810ca6f4b0 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Sun, 13 Mar 2022 16:25:55 -0700 Subject: [PATCH 03/23] Add return typehints --- python/src/iceberg/table/schema.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index 06fc06a0d9be..091564e0197b 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -18,7 +18,7 @@ from enum import Enum, auto from typing import Iterable -from iceberg.types import NestedField, StructType +from iceberg.types import IcebergType, NestedField, StructType class FIELD_IDENTIFIER_TYPES(Enum): @@ -70,7 +70,7 @@ def schema_id(self): def struct(self): return self._struct - def _get_field_identifier_by_name(self, field_identifier: str, case_sensitive: bool): + def _get_field_identifier_by_name(self, field_identifier: str, case_sensitive: bool) -> int: """Get a field ID for a given field name Args: @@ -90,7 +90,7 @@ def _get_field_identifier_by_name(self, field_identifier: str, case_sensitive: b return field.field_id raise ValueError(f"Cannot get field ID, name not found: {field_identifier}") - def _get_field_identifier_by_alias(self, field_identifier: str, case_sensitive: bool): + def _get_field_identifier_by_alias(self, field_identifier: str, case_sensitive: bool) -> int: """Get a field ID for a given field alias Args: @@ -163,7 +163,7 @@ def get_field(self, field_id: int) -> NestedField: return field raise ValueError(f"Cannot get field, ID does not exist: {field_id}") - def get_type(self, field_id: int): + def get_type(self, field_id: int) -> IcebergType: """Get the type of a field by field ID Args: From 72e7ab724bdc51690dc37c327caddc332b1ebef6 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Sun, 13 Mar 2022 23:12:18 -0700 Subject: [PATCH 04/23] Add IndexById and IndexByName schema visitors --- python/src/iceberg/table/schema.py | 283 +++++++++++++++++++++++------ python/tests/table/test_schema.py | 211 ++++++++++++++------- 2 files changed, 375 insertions(+), 119 deletions(-) diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index 091564e0197b..46de6610ee3d 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -15,10 +15,21 @@ # specific language governing permissions and limitations # under the License. +import sys from enum import Enum, auto from typing import Iterable -from iceberg.types import IcebergType, NestedField, StructType +if sys.version_info >= (3, 8): # pragma: no cover + from typing import Protocol +else: # pragma: no cover + from typing_extensions import Protocol # type: ignore + +if sys.version_info >= (3, 8): # pragma: no cover + from functools import singledispatchmethod +else: # pragma: no cover + from singledispatch import singledispatchmethod # type: ignore + +from iceberg.types import IcebergType, ListType, MapType, NestedField, StructType class FIELD_IDENTIFIER_TYPES(Enum): @@ -26,6 +37,11 @@ class FIELD_IDENTIFIER_TYPES(Enum): ALIAS = auto() +class SchemaVisitor(Protocol): + def visit(self, node) -> None: # pragma: no cover + ... + + class Schema: """Schema of a table @@ -49,6 +65,14 @@ def __init__(self, fields: Iterable[NestedField], schema_id: int, aliases: dict self._schema_id = schema_id self._aliases = aliases + index_by_id_visitor = IndexById() + index_by_id_visitor.visit(self._struct) + self._index_by_id = index_by_id_visitor.result + + index_by_name_visitor = IndexByName() + index_by_name_visitor.visit(self._struct) + self._index_by_name = index_by_name_visitor.result + def __str__(self): schema_str = "" for field in self.fields: @@ -70,101 +94,92 @@ def schema_id(self): def struct(self): return self._struct - def _get_field_identifier_by_name(self, field_identifier: str, case_sensitive: bool) -> int: + def find_field_id_by_name(self, field_name: str, case_sensitive: bool = True) -> int: """Get a field ID for a given field name Args: - field_identifier (str): A field name - case_sensitive (bool): If False, case will not be considered when retrieving the field from the schema + field_name (str): A field name + case_sensitive (bool): If False, case will not be considered when retrieving the field from the schema--default is True Returns: int: The field ID for the field name """ if case_sensitive: - for field in self.fields: - if field.name == field_identifier: - return field.field_id - else: - for field in self.fields: - if field.name.lower() == field_identifier.lower(): - return field.field_id - raise ValueError(f"Cannot get field ID, name not found: {field_identifier}") - - def _get_field_identifier_by_alias(self, field_identifier: str, case_sensitive: bool) -> int: + for indexed_field_name, indexed_field_id in self._index_by_name.items(): + if indexed_field_name == field_name: + return indexed_field_id + if not case_sensitive: + for indexed_field_name, indexed_field_id in self._index_by_name.items(): + if indexed_field_name.lower() == field_name.lower(): + return indexed_field_id + raise ValueError(f"Cannot get field ID, name not found: {field_name}") + + def find_field_id_by_alias(self, field_alias: str, case_sensitive: bool = True) -> int: """Get a field ID for a given field alias Args: - field_identifier (str): A field alias - case_sensitive (bool): If False, case will not be considered when retrieving the field from the schema + field_alias (str): A field alias + case_sensitive (bool): If False, case will not be considered when retrieving the field from the schema--default is True Returns: int: The field ID for the field alias + + raises: + ValueError: If the field ID cannot be retrieved either because the alias was not found or a case-insensitive + match returned multiple results """ if case_sensitive: try: # For case-sensitive, just try looking up the alias and raise if it's not found - return self._aliases[field_identifier] + return self._aliases[field_alias] except KeyError: - raise ValueError(f"Cannot get field ID, alias not found: {field_identifier}") + raise ValueError(f"Cannot get field ID, alias not found: {field_alias}") if not case_sensitive: - matching_fields = [value for key, value in self._aliases.items() if key.lower() == field_identifier.lower()] + matching_fields = [value for key, value in self._aliases.items() if key.lower() == field_alias.lower()] if len(matching_fields) == 1: # If one matching alias found, return the corresponding ID return matching_fields[0] elif not matching_fields: # If no matching fields, raise - raise ValueError(f"Cannot get field ID, alias not found: {field_identifier}") + raise ValueError(f"Cannot get field ID, alias not found: {field_alias}") # If multiple IDs are returned for a case-insensitive alias lookup, raise - raise ValueError(f"Cannot get field ID, case-insensitive alias returns multiple results: {field_identifier}") + raise ValueError(f"Cannot get field ID, case-insensitive alias returns multiple results: {field_alias}") - def get_field_id( - self, - field_identifier: str, - field_identifier_type: FIELD_IDENTIFIER_TYPES = FIELD_IDENTIFIER_TYPES.NAME, - case_sensitive: bool = True, - ) -> int: - """Get a field ID for a given NAME or ALIAS - - This calls either the `_get_field_identifier_by_name` method or the `_get_field_identifier_by_alias` method depending on the value of the `field_identifier_type` argument. + def find_field_name_by_field_id(self, field_id: int): + """Find a field name for a given field ID Args: - field_identifier (str): The unique identifier for the field - field_identifier_type (FIELD_IDENTIFIER_TYPES): An FIELD_IDENTIFIER_TYPES value of either NAME or ALIAS - case_sensitive (bool): If False, case will not be considered when retrieving the field from the schema. Default is True + field_id (int): A field ID - Raises: - ValueError: If the field identifier does not exist in the schema + Returns: + str: The field name for the field ID - Returns - int: The field ID for the given field identifier + Raises: + ValueError: If the field ID does not exist """ - if field_identifier_type == FIELD_IDENTIFIER_TYPES.NAME: - field_id = self._get_field_identifier_by_name(field_identifier=field_identifier, case_sensitive=case_sensitive) - elif field_identifier_type == FIELD_IDENTIFIER_TYPES.ALIAS: - field_id = self._get_field_identifier_by_alias(field_identifier=field_identifier, case_sensitive=case_sensitive) + for indexed_field_name, indexed_field_id in self._index_by_name.items(): + if indexed_field_id == field_id: + return indexed_field_name + raise ValueError(f"Cannot get field name, field ID not found: {field_id}") - return field_id - - def get_field(self, field_id: int) -> NestedField: - """Get a field object for a given field ID - - This returns the NestedField instance for the given `field_id`. + def find_field_by_id(self, field_id: int): + """Find a field by it's field ID Args: - field_id (int): The ID of the field + field_id (int): A field ID - Raises: - ValueError: If the field ID does not exist in this schema - - Returns + Returns: NestedField: The field object for the given field ID + + Raise: + ValueError: If the field ID does not exist """ for field in self.fields: if field.field_id == field_id: return field raise ValueError(f"Cannot get field, ID does not exist: {field_id}") - def get_type(self, field_id: int) -> IcebergType: - """Get the type of a field by field ID + def find_field_type(self, field_id: int) -> IcebergType: + """Find the type of a field by field ID Args: field_id (int): The ID of the field @@ -175,5 +190,163 @@ def get_type(self, field_id: int) -> IcebergType: Returns IcebergType: The type of the field with ID of `field_id` """ - field = self.get_field(field_id) + field = self.find_field_by_id(field_id) return field.type + + +class IndexById(SchemaVisitor): + """Index a Schema by IDs + + This visitor provides a field ID to field name map for a given Schema instance. The result is stored in the `result` instance attribute. + + Example: + >>> from iceberg.table.schema import IndexById, Schema + >>> from iceberg.types import BooleanType, IntegerType, NestedField, StringType + >>> fields = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + >>> table_schema = Schema(fields=fields, schema_id=1) + >>> visitor = IndexById() + >>> visitor.visit(table_schema) + >>> print(visitor.result) + {1: 'foo', 2: 'bar', 3: 'baz'} + """ + + def __init__(self): + self.result = {} + + @singledispatchmethod + def visit(self, node): + """A generic single dispatch visit method + + Raises: + NotImplementedError: If no concrete method has been registered for type of the `node` argument + """ + raise NotImplementedError(f"Cannot visit node, no IndexById operation implemented for node type: {type(node)}") + + @visit.register + def _(self, node: Schema): + self.visit(node.struct) + + @visit.register + def _(self, node: StructType): + for field in node.fields: + self.visit(field) + + @visit.register + def _(self, node: NestedField): + self.result[node.field_id] = node.name + if isinstance(node.type, (ListType, MapType)): + self.visit(node.type, nested_field_name=node.name) + + @visit.register + def _(self, node: ListType, nested_field_name: str): + """Index a ListType node + + ListType nodes add one item to the index: + 1. The ID for the ListType element and the name . + + is always "element". + Args: + node (ListType): The ListType instance for the NestedField instance + nested_field_name (str): The name of the NestedField instance containing the ListType + """ + self.result[node.element.field_id] = f"{nested_field_name}.{node.element.name}" + + @visit.register + def _(self, node: MapType, nested_field_name: str) -> None: + """Index a MapType node + + MapType nodes add two items to the index: + 1. The MapType key ID and the name . + 2. The MapType value ID and the name . + + and are always "key" and "value", respectively. + + Args: + node (MapType): The MapType instance for the NestedField instance + nested_field_name (str): The name of the NestedField instance containing the MapType + """ + self.result[node.key.field_id] = f"{nested_field_name}.key" + self.result[node.value.field_id] = f"{nested_field_name}.value" + + +class IndexByName(SchemaVisitor): + """Index a Schema by names + + This visitor provides a field name to field ID map for a given Schema instance. The result is stored in the `result` instance attribute. + + Example: + >>> from iceberg.table.schema import IndexByName, Schema + >>> from iceberg.types import BooleanType, IntegerType, NestedField, StringType + >>> fields = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + >>> table_schema = Schema(fields=fields, schema_id=1) + >>> visitor = IndexByName() + >>> visitor.visit(table_schema) + >>> print(visitor.result) + {'foo': 1, 'bar': 2, 'baz': 3} + """ + + def __init__(self): + self.result = {} + + @singledispatchmethod + def visit(self, node) -> None: + """A generic single dispatch visit method + + Raises: + NotImplementedError: If no concrete method has been registered for type of the `node` argument + """ + raise NotImplementedError(f"Cannot visit node, no IndexByName operation implemented for node type: {type(node)}") + + @visit.register + def _(self, node: Schema) -> None: + self.visit(node.struct) + + @visit.register + def _(self, node: StructType) -> None: + for field in node.fields: + self.visit(field) + + @visit.register + def _(self, node: NestedField) -> None: + self.result[node.name] = node.field_id + if isinstance(node.type, (ListType, MapType)): + self.visit(node.type, nested_field_name=node.name) + + @visit.register + def _(self, node: ListType, nested_field_name: str): + """Index a ListType node + + ListType nodes add one item to the index: + 1. The name . and the ID for the ListType element + + is always "element". + Args: + node (ListType): The ListType instance for the NestedField instance + nested_field_name (str): The name of the NestedField instance containing the ListType + """ + self.result[f"{nested_field_name}.{node.element.name}"] = node.element.field_id + + @visit.register + def _(self, node: MapType, nested_field_name: str) -> None: + """Index a MapType node + + MapType nodes add two items to the index: + 1. The name . and the MapType key ID + 2. The name . and the MapType value ID + + and are always "key" and "value", respectively. + + Args: + node (MapType): The MapType instance for the NestedField instance + nested_field_name (str): The name of the NestedField instance containing the MapType + """ + self.result[f"{nested_field_name}.{node.key.name}"] = node.key.field_id + self.result[f"{nested_field_name}.{node.value.name}"] = node.value.field_id diff --git a/python/tests/table/test_schema.py b/python/tests/table/test_schema.py index 87728925a26e..fb52b35246f8 100644 --- a/python/tests/table/test_schema.py +++ b/python/tests/table/test_schema.py @@ -19,7 +19,15 @@ from numpy import False_ from iceberg.table import schema -from iceberg.types import BooleanType, IntegerType, NestedField, StringType, StructType +from iceberg.types import ( + BooleanType, + IntegerType, + ListType, + MapType, + NestedField, + StringType, + StructType, +) def test_schema_init(): @@ -74,57 +82,50 @@ def test_schema_repr(schema, expected_repr): assert repr(schema) == expected_repr -def test_schema_get_field_id_case_sensitive(): - """Test case-sensitive retrieval of a field ID using the `get_field_id` method""" +def test_schema_find_field_id_case_sensitive(): + """Test case-sensitive retrieval of a field ID using the `find_field_id` method""" fields = [ NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), ] table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "foobar": 2}) - assert table_schema.get_field_id(field_identifier="foo", case_sensitive=True) == 1 - assert table_schema.get_field_id(field_identifier="bar", case_sensitive=True) == 2 - assert table_schema.get_field_id(field_identifier="baz", case_sensitive=True) == 3 - assert ( - table_schema.get_field_id( - field_identifier="qux", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS, case_sensitive=True - ) - == 1 - ) - assert ( - table_schema.get_field_id( - field_identifier="foobar", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS, case_sensitive=True - ) - == 2 - ) + assert table_schema.find_field_id_by_name(field_name="foo", case_sensitive=True) == 1 + assert table_schema.find_field_id_by_name(field_name="bar", case_sensitive=True) == 2 + assert table_schema.find_field_id_by_name(field_name="baz", case_sensitive=True) == 3 + assert table_schema.find_field_id_by_alias(field_alias="qux", case_sensitive=True) == 1 + assert table_schema.find_field_id_by_alias(field_alias="foobar", case_sensitive=True) == 2 -def test_schema_get_field_id_case_insensitive(): - """Test case-insensitive retrieval of a field ID using the `get_field_id` method""" +def test_schema_find_field_id_case_insensitive(): + """Test case-insensitive retrieval of a field ID using the `find_field_id_by_name` and `find_field_id_by_alias methods""" fields = [ NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), ] table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "foobar": 2}) - assert table_schema.get_field_id(field_identifier="fOO", case_sensitive=False) == 1 - assert table_schema.get_field_id(field_identifier="BAr", case_sensitive=False) == 2 - assert table_schema.get_field_id(field_identifier="BaZ", case_sensitive=False) == 3 - assert ( - table_schema.get_field_id( - field_identifier="qUx", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS, case_sensitive=False - ) - == 1 - ) - assert ( - table_schema.get_field_id( - field_identifier="fooBAR", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS, case_sensitive=False - ) - == 2 - ) + assert table_schema.find_field_id_by_name(field_name="fOO", case_sensitive=False) == 1 + assert table_schema.find_field_id_by_name(field_name="BAr", case_sensitive=False) == 2 + assert table_schema.find_field_id_by_name(field_name="BaZ", case_sensitive=False) == 3 + assert table_schema.find_field_id_by_alias(field_alias="qUx", case_sensitive=False) == 1 + assert table_schema.find_field_id_by_alias(field_alias="fooBAR", case_sensitive=False) == 2 + + +def test_schema_find_field_name_by_field_id(): + """Test finding a field name using a field ID""" + fields = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "foobar": 2}) + assert table_schema.find_field_name_by_field_id(field_id=1) == "foo" + assert table_schema.find_field_name_by_field_id(field_id=2) == "bar" + assert table_schema.find_field_name_by_field_id(field_id=3) == "baz" -def test_schema_get_field_id_raise_on_not_found(): +def test_schema_find_field_id_raise_on_not_found(): """Test raising when the field ID for a given name or alias cannot be found""" fields = [ NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), @@ -134,59 +135,56 @@ def test_schema_get_field_id_raise_on_not_found(): table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "foobar": 2}) with pytest.raises(ValueError) as exc_info: - table_schema.get_field_id(field_identifier="name1") + table_schema.find_field_id_by_name(field_name="name1") assert "Cannot get field ID, name not found: name1" in str(exc_info.value) with pytest.raises(ValueError) as exc_info: - table_schema.get_field_id(field_identifier="name2", case_sensitive=False_) + table_schema.find_field_id_by_name(field_name="name2", case_sensitive=False_) assert "Cannot get field ID, name not found: name2" in str(exc_info.value) with pytest.raises(ValueError) as exc_info: - table_schema.get_field_id( - field_identifier="case_insensitive_name1", - field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.NAME, + table_schema.find_field_id_by_name( + field_name="case_insensitive_name1", case_sensitive=False, ) assert "Cannot get field ID, name not found: case_insensitive_name1" in str(exc_info.value) with pytest.raises(ValueError) as exc_info: - table_schema.get_field_id(field_identifier="case_insensitive_name2") + table_schema.find_field_id_by_name(field_name="case_insensitive_name2") assert "Cannot get field ID, name not found: case_insensitive_name2" in str(exc_info.value) with pytest.raises(ValueError) as exc_info: - table_schema.get_field_id(field_identifier="alias1", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS) + table_schema.find_field_id_by_alias(field_alias="alias1") assert "Cannot get field ID, alias not found: alias1" in str(exc_info.value) with pytest.raises(ValueError) as exc_info: - table_schema.get_field_id(field_identifier="alias2", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS) + table_schema.find_field_id_by_alias(field_alias="alias2") assert "Cannot get field ID, alias not found: alias2" in str(exc_info.value) with pytest.raises(ValueError) as exc_info: - table_schema.get_field_id( - field_identifier="case_insensitive_alias1", - field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS, + table_schema.find_field_id_by_alias( + field_alias="case_insensitive_alias1", case_sensitive=False, ) assert "Cannot get field ID, alias not found: case_insensitive_alias1" in str(exc_info.value) with pytest.raises(ValueError) as exc_info: - table_schema.get_field_id( - field_identifier="case_insensitive_alias2", - field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS, + table_schema.find_field_id_by_alias( + field_alias="case_insensitive_alias2", case_sensitive=False, ) assert "Cannot get field ID, alias not found: case_insensitive_alias2" in str(exc_info.value) -def test_schema_get_field_id_raise_on_multiple_case_insensitive_alias_match(): +def test_schema_find_field_id_raise_on_multiple_case_insensitive_alias_match(): """Test raising when a case-insensitive alias search returns multiple aliases""" fields = [ NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), @@ -196,14 +194,27 @@ def test_schema_get_field_id_raise_on_multiple_case_insensitive_alias_match(): table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "QUX": 2}) with pytest.raises(ValueError) as exc_info: - table_schema.get_field_id( - field_identifier="qux", field_identifier_type=schema.FIELD_IDENTIFIER_TYPES.ALIAS, case_sensitive=False - ) + table_schema.find_field_id_by_alias(field_alias="qux", case_sensitive=False) assert "Cannot get field ID, case-insensitive alias returns multiple results: qux" in str(exc_info.value) -def test_schema_get_field(): +def test_schema_find_field_name_by_field_id_raise_on_unknown_field_id(): + """Test raising when the the field ID cannot be found while finding a field name""" + fields = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "foobar": 2}) + + with pytest.raises(ValueError) as exc_info: + table_schema.find_field_name_by_field_id(field_id=4) + + assert "Cannot get field name, field ID not found: 4" in str(exc_info.value) + + +def test_schema_find_field_by_id(): """Test retrieving a field using the field's ID""" fields = [ NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), @@ -211,9 +222,9 @@ def test_schema_get_field(): NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), ] table_schema = schema.Schema(fields=fields, schema_id=1) - field1 = table_schema.get_field(field_id=1) - field2 = table_schema.get_field(field_id=2) - field3 = table_schema.get_field(field_id=3) + field1 = table_schema.find_field_by_id(field_id=1) + field2 = table_schema.find_field_by_id(field_id=2) + field3 = table_schema.find_field_by_id(field_id=3) assert isinstance(field1, NestedField) assert field1.field_id == 1 @@ -229,7 +240,7 @@ def test_schema_get_field(): assert field3.is_optional == False -def test_schema_get_field_raise_on_unknown_field(): +def test_schema_find_field_by_id_raise_on_unknown_field(): """Test raising when the field ID is not found""" fields = [ NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), @@ -239,12 +250,12 @@ def test_schema_get_field_raise_on_unknown_field(): table_schema = schema.Schema(fields=fields, schema_id=1) with pytest.raises(ValueError) as exc_info: - table_schema.get_field(field_id=4) + table_schema.find_field_by_id(field_id=4) assert "Cannot get field, ID does not exist: 4" in str(exc_info.value) -def test_schema_get_type(): +def test_schema_find_field_type(): """Test retrieving a field's type using the field's ID""" fields = [ NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), @@ -253,6 +264,78 @@ def test_schema_get_type(): ] table_schema = schema.Schema(fields=fields, schema_id=1) - assert table_schema.get_type(field_id=1) == StringType() - assert table_schema.get_type(field_id=2) == IntegerType() - assert table_schema.get_type(field_id=3) == BooleanType() + assert table_schema.find_field_type(field_id=1) == StringType() + assert table_schema.find_field_type(field_id=2) == IntegerType() + assert table_schema.find_field_type(field_id=3) == BooleanType() + + +def test_index_by_id_schema_visitor(): + """Test retrieving a field id to field name map using an IndexById schema visitor""" + fields = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + NestedField( + field_id=4, + name="qux", + field_type=ListType(element_id=5, element_type=StringType(), element_is_optional=True), + is_optional=False, + ), + NestedField( + field_id=6, + name="quux", + field_type=MapType(key_id=7, key_type=StringType(), value_id=8, value_type=IntegerType(), value_is_optional=True), + is_optional=False, + ), + ] + table_schema = schema.Schema(fields=fields, schema_id=1) + visitor = schema.IndexById() + visitor.visit(table_schema) + + assert visitor.result == {1: "foo", 2: "bar", 3: "baz", 4: "qux", 5: "qux.element", 6: "quux", 7: "quux.key", 8: "quux.value"} + + +def test_index_by_name_schema_visitor(): + """Test retrieving a field name to field id map using an IndexByName schema visitor""" + fields = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + NestedField( + field_id=4, + name="qux", + field_type=ListType(element_id=5, element_type=StringType(), element_is_optional=True), + is_optional=False, + ), + NestedField( + field_id=6, + name="quux", + field_type=MapType(key_id=7, key_type=StringType(), value_id=8, value_type=IntegerType(), value_is_optional=True), + is_optional=False, + ), + ] + table_schema = schema.Schema(fields=fields, schema_id=1) + visitor = schema.IndexByName() + visitor.visit(table_schema) + + assert visitor.result == {"bar": 2, "baz": 3, "foo": 1, "qux": 4, "qux.element": 5, "quux": 6, "quux.key": 7, "quux.value": 8} + + +def test_index_by_id_schema_visitor_raise_on_unregistered_type(): + """Test raising a NotImplementedError when a type with no registered visit operation is passed to an IndexById visitor""" + + visitor = schema.IndexById() + with pytest.raises(NotImplementedError) as exc_info: + visitor.visit("foo") + + assert "Cannot visit node, no IndexById operation implemented for node type: " in str(exc_info.value) + + +def test_index_by_name_schema_visitor_raise_on_unregistered_type(): + """Test raising a NotImplementedError when a type with no registered visit operation is passed to an IndexByName visitor""" + + visitor = schema.IndexByName() + with pytest.raises(NotImplementedError) as exc_info: + visitor.visit("foo") + + assert "Cannot visit node, no IndexByName operation implemented for node type: " in str(exc_info.value) From 933fdf86d34c6de7e142022fc7e50c0cf795554c Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Mon, 14 Mar 2022 03:08:20 -0700 Subject: [PATCH 05/23] Add return types, remove enum no longer used --- python/src/iceberg/table/schema.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index 46de6610ee3d..f1d3a1fc7b9c 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -16,8 +16,7 @@ # under the License. import sys -from enum import Enum, auto -from typing import Iterable +from typing import Iterable, Tuple if sys.version_info >= (3, 8): # pragma: no cover from typing import Protocol @@ -32,11 +31,6 @@ from iceberg.types import IcebergType, ListType, MapType, NestedField, StructType -class FIELD_IDENTIFIER_TYPES(Enum): - NAME = auto() - ALIAS = auto() - - class SchemaVisitor(Protocol): def visit(self, node) -> None: # pragma: no cover ... @@ -83,15 +77,15 @@ def __repr__(self): return f"Schema(fields={repr(self.fields)}, schema_id={self.schema_id})" @property - def fields(self): + def fields(self) -> Tuple: return self._struct.fields @property - def schema_id(self): + def schema_id(self) -> int: return self._schema_id @property - def struct(self): + def struct(self) -> StructType: return self._struct def find_field_id_by_name(self, field_name: str, case_sensitive: bool = True) -> int: @@ -144,7 +138,7 @@ def find_field_id_by_alias(self, field_alias: str, case_sensitive: bool = True) # If multiple IDs are returned for a case-insensitive alias lookup, raise raise ValueError(f"Cannot get field ID, case-insensitive alias returns multiple results: {field_alias}") - def find_field_name_by_field_id(self, field_id: int): + def find_field_name_by_field_id(self, field_id: int) -> str: """Find a field name for a given field ID Args: @@ -161,7 +155,7 @@ def find_field_name_by_field_id(self, field_id: int): return indexed_field_name raise ValueError(f"Cannot get field name, field ID not found: {field_id}") - def find_field_by_id(self, field_id: int): + def find_field_by_id(self, field_id: int) -> NestedField: """Find a field by it's field ID Args: @@ -218,7 +212,7 @@ def __init__(self): self.result = {} @singledispatchmethod - def visit(self, node): + def visit(self, node) -> None: """A generic single dispatch visit method Raises: @@ -227,22 +221,22 @@ def visit(self, node): raise NotImplementedError(f"Cannot visit node, no IndexById operation implemented for node type: {type(node)}") @visit.register - def _(self, node: Schema): + def _(self, node: Schema) -> None: self.visit(node.struct) @visit.register - def _(self, node: StructType): + def _(self, node: StructType) -> None: for field in node.fields: self.visit(field) @visit.register - def _(self, node: NestedField): + def _(self, node: NestedField) -> None: self.result[node.field_id] = node.name if isinstance(node.type, (ListType, MapType)): self.visit(node.type, nested_field_name=node.name) @visit.register - def _(self, node: ListType, nested_field_name: str): + def _(self, node: ListType, nested_field_name: str) -> None: """Index a ListType node ListType nodes add one item to the index: @@ -321,7 +315,7 @@ def _(self, node: NestedField) -> None: self.visit(node.type, nested_field_name=node.name) @visit.register - def _(self, node: ListType, nested_field_name: str): + def _(self, node: ListType, nested_field_name: str) -> None: """Index a ListType node ListType nodes add one item to the index: From 4d8123824bd907147dc71df494915450fd1a829e Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Mon, 14 Mar 2022 21:43:41 -0700 Subject: [PATCH 06/23] Consolidate import if statements --- python/src/iceberg/table/schema.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index f1d3a1fc7b9c..e9a8dce74dc2 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -19,13 +19,10 @@ from typing import Iterable, Tuple if sys.version_info >= (3, 8): # pragma: no cover + from functools import singledispatchmethod from typing import Protocol else: # pragma: no cover from typing_extensions import Protocol # type: ignore - -if sys.version_info >= (3, 8): # pragma: no cover - from functools import singledispatchmethod -else: # pragma: no cover from singledispatch import singledispatchmethod # type: ignore from iceberg.types import IcebergType, ListType, MapType, NestedField, StructType From 7007ea16c5779a58f101e70c13f3e5a132b5d720 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Mon, 14 Mar 2022 21:44:13 -0700 Subject: [PATCH 07/23] Rename find_field_type to find_field_type_by_id --- python/src/iceberg/table/schema.py | 2 +- python/tests/table/test_schema.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index e9a8dce74dc2..32339e7ad252 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -169,7 +169,7 @@ def find_field_by_id(self, field_id: int) -> NestedField: return field raise ValueError(f"Cannot get field, ID does not exist: {field_id}") - def find_field_type(self, field_id: int) -> IcebergType: + def find_field_type_by_id(self, field_id: int) -> IcebergType: """Find the type of a field by field ID Args: diff --git a/python/tests/table/test_schema.py b/python/tests/table/test_schema.py index fb52b35246f8..98234746c21a 100644 --- a/python/tests/table/test_schema.py +++ b/python/tests/table/test_schema.py @@ -255,7 +255,7 @@ def test_schema_find_field_by_id_raise_on_unknown_field(): assert "Cannot get field, ID does not exist: 4" in str(exc_info.value) -def test_schema_find_field_type(): +def test_schema_find_field_type_by_id(): """Test retrieving a field's type using the field's ID""" fields = [ NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), @@ -264,9 +264,9 @@ def test_schema_find_field_type(): ] table_schema = schema.Schema(fields=fields, schema_id=1) - assert table_schema.find_field_type(field_id=1) == StringType() - assert table_schema.find_field_type(field_id=2) == IntegerType() - assert table_schema.find_field_type(field_id=3) == BooleanType() + assert table_schema.find_field_type_by_id(field_id=1) == StringType() + assert table_schema.find_field_type_by_id(field_id=2) == IntegerType() + assert table_schema.find_field_type_by_id(field_id=3) == BooleanType() def test_index_by_id_schema_visitor(): From f89a73061215df3909debe21671d46bf8f3c5084 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Mon, 14 Mar 2022 21:48:14 -0700 Subject: [PATCH 08/23] Add a __len__ method for Schema --- python/src/iceberg/table/schema.py | 3 +++ python/tests/table/test_schema.py | 1 + 2 files changed, 4 insertions(+) diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index 32339e7ad252..0dd0ab586bbf 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -64,6 +64,9 @@ def __init__(self, fields: Iterable[NestedField], schema_id: int, aliases: dict index_by_name_visitor.visit(self._struct) self._index_by_name = index_by_name_visitor.result + def __len__(self): + return len(self.struct.fields) + def __str__(self): schema_str = "" for field in self.fields: diff --git a/python/tests/table/test_schema.py b/python/tests/table/test_schema.py index 98234746c21a..ec8f587e5a76 100644 --- a/python/tests/table/test_schema.py +++ b/python/tests/table/test_schema.py @@ -44,6 +44,7 @@ def test_schema_init(): assert table_schema.fields[1] == fields[1] assert table_schema.fields[2] == fields[2] assert table_schema.schema_id == 1 + assert len(table_schema) == 3 assert isinstance(schema_struct, StructType) From 2b6454c352ae572518f547e05503cd83cd339c40 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Fri, 25 Mar 2022 15:28:31 -0700 Subject: [PATCH 09/23] Fixing docstring examples to pass doctests --- python/src/iceberg/table/schema.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index 0dd0ab586bbf..6afc961fe20d 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -39,10 +39,10 @@ class Schema: Example: >>> from iceberg.table.schema import Schema >>> from iceberg.types import BooleanType, IntegerType, NestedField, StringType - >>> fields = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + >>> fields = [ \ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), \ + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), \ + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), \ ] >>> table_schema = Schema(fields=fields, schema_id=1, aliases={"qux": 3}) >>> print(table_schema) @@ -196,10 +196,10 @@ class IndexById(SchemaVisitor): Example: >>> from iceberg.table.schema import IndexById, Schema >>> from iceberg.types import BooleanType, IntegerType, NestedField, StringType - >>> fields = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + >>> fields = [ \ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), \ + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), \ + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), \ ] >>> table_schema = Schema(fields=fields, schema_id=1) >>> visitor = IndexById() @@ -275,10 +275,10 @@ class IndexByName(SchemaVisitor): Example: >>> from iceberg.table.schema import IndexByName, Schema >>> from iceberg.types import BooleanType, IntegerType, NestedField, StringType - >>> fields = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + >>> fields = [ \ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), \ + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), \ + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), \ ] >>> table_schema = Schema(fields=fields, schema_id=1) >>> visitor = IndexByName() From 1333815b1855fa59002446414c16c9f799daaa05 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Mon, 28 Mar 2022 09:56:51 -0700 Subject: [PATCH 10/23] Port schema and visitor from partition spec PR Co-authored-by: Ryan Blue --- python/src/iceberg/table/schema.py | 438 +++++++++-------------------- python/tests/table/test_schema.py | 283 +++++-------------- 2 files changed, 206 insertions(+), 515 deletions(-) diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index 6afc961fe20d..187a371367a9 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -16,7 +16,7 @@ # under the License. import sys -from typing import Iterable, Tuple +from typing import Dict, Generic, Iterable, List, Optional, TypeVar if sys.version_info >= (3, 8): # pragma: no cover from functools import singledispatchmethod @@ -25,322 +25,152 @@ from typing_extensions import Protocol # type: ignore from singledispatch import singledispatchmethod # type: ignore -from iceberg.types import IcebergType, ListType, MapType, NestedField, StructType +from iceberg.types import ListType, MapType, NestedField, PrimitiveType, StructType -class SchemaVisitor(Protocol): - def visit(self, node) -> None: # pragma: no cover - ... +class Schema(object): + """A table Schema""" + def __init__(self, *columns: Iterable[NestedField]): + self._struct = StructType(*columns) # type: ignore -class Schema: - """Schema of a table - - Example: - >>> from iceberg.table.schema import Schema - >>> from iceberg.types import BooleanType, IntegerType, NestedField, StringType - >>> fields = [ \ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), \ - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), \ - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), \ + def __str__(self): + column_strings = [ + f"{column.field_id}: name={column.name}, type={column.type}, required={column.is_required}" for column in self.columns ] - >>> table_schema = Schema(fields=fields, schema_id=1, aliases={"qux": 3}) - >>> print(table_schema) - 1: name=foo, type=string, required=True - 2: name=bar, type=int, required=False - 3: name=baz, type=boolean, required=True - """ + return "\n".join(column_strings) - def __init__(self, fields: Iterable[NestedField], schema_id: int, aliases: dict = {}): - self._struct = StructType(*fields) - self._schema_id = schema_id - self._aliases = aliases + def __repr__(self): + return f"Schema(fields={repr(self.columns)})" - index_by_id_visitor = IndexById() - index_by_id_visitor.visit(self._struct) - self._index_by_id = index_by_id_visitor.result + @property + def columns(self): + return self._struct.fields - index_by_name_visitor = IndexByName() - index_by_name_visitor.visit(self._struct) - self._index_by_name = index_by_name_visitor.result + def as_struct(self): + return self._struct - def __len__(self): - return len(self.struct.fields) - def __str__(self): - schema_str = "" - for field in self.fields: - schema_str += f"{field.field_id}: name={field.name}, type={field.type}, required={field.is_required}\n" - return schema_str.rstrip() +T = TypeVar("T") - def __repr__(self): - return f"Schema(fields={repr(self.fields)}, schema_id={self.schema_id})" - @property - def fields(self) -> Tuple: - return self._struct.fields +class SchemaVisitor(Generic[T]): + def before_field(self, field: NestedField) -> None: + pass - @property - def schema_id(self) -> int: - return self._schema_id + def after_field(self, field: NestedField) -> None: + pass - @property - def struct(self) -> StructType: - return self._struct + def before_list_element(self, element: NestedField) -> None: + self.before_field(element) + + def after_list_element(self, element: NestedField) -> None: + self.after_field(element) + + def before_map_key(self, key: NestedField) -> None: + self.before_field(key) + + def after_map_key(self, key: NestedField) -> None: + self.after_field(key) + + def before_map_value(self, value: NestedField) -> None: + self.before_field(value) + + def after_map_value(self, value: NestedField) -> None: + self.after_field(value) + + def schema(self, schema: Schema, struct_result: T) -> Optional[T]: + return None + + def struct(self, struct: StructType, field_results: List[T]) -> Optional[T]: + return None + + def field(self, field: NestedField, field_result: T) -> Optional[T]: + return None + + def list(self, list_type: ListType, element_result: T) -> Optional[T]: + return None + + def map(self, map_type: MapType, key_result: T, value_result: T) -> Optional[T]: + return None - def find_field_id_by_name(self, field_name: str, case_sensitive: bool = True) -> int: - """Get a field ID for a given field name - - Args: - field_name (str): A field name - case_sensitive (bool): If False, case will not be considered when retrieving the field from the schema--default is True - - Returns: - int: The field ID for the field name - """ - if case_sensitive: - for indexed_field_name, indexed_field_id in self._index_by_name.items(): - if indexed_field_name == field_name: - return indexed_field_id - if not case_sensitive: - for indexed_field_name, indexed_field_id in self._index_by_name.items(): - if indexed_field_name.lower() == field_name.lower(): - return indexed_field_id - raise ValueError(f"Cannot get field ID, name not found: {field_name}") - - def find_field_id_by_alias(self, field_alias: str, case_sensitive: bool = True) -> int: - """Get a field ID for a given field alias - - Args: - field_alias (str): A field alias - case_sensitive (bool): If False, case will not be considered when retrieving the field from the schema--default is True - - Returns: - int: The field ID for the field alias - - raises: - ValueError: If the field ID cannot be retrieved either because the alias was not found or a case-insensitive - match returned multiple results - """ - if case_sensitive: + def primitive(self, primitive: PrimitiveType) -> Optional[T]: + return None + + +def visit(obj, visitor: SchemaVisitor[Optional[T]]) -> Optional[T]: + if isinstance(obj, Schema): + return visitor.schema(obj, visit(obj.as_struct(), visitor)) + + elif isinstance(obj, StructType): + results = [] + for field in obj.fields: + visitor.before_field(field) try: - # For case-sensitive, just try looking up the alias and raise if it's not found - return self._aliases[field_alias] - except KeyError: - raise ValueError(f"Cannot get field ID, alias not found: {field_alias}") - if not case_sensitive: - matching_fields = [value for key, value in self._aliases.items() if key.lower() == field_alias.lower()] - if len(matching_fields) == 1: # If one matching alias found, return the corresponding ID - return matching_fields[0] - elif not matching_fields: # If no matching fields, raise - raise ValueError(f"Cannot get field ID, alias not found: {field_alias}") - - # If multiple IDs are returned for a case-insensitive alias lookup, raise - raise ValueError(f"Cannot get field ID, case-insensitive alias returns multiple results: {field_alias}") - - def find_field_name_by_field_id(self, field_id: int) -> str: - """Find a field name for a given field ID - - Args: - field_id (int): A field ID - - Returns: - str: The field name for the field ID - - Raises: - ValueError: If the field ID does not exist - """ - for indexed_field_name, indexed_field_id in self._index_by_name.items(): - if indexed_field_id == field_id: - return indexed_field_name - raise ValueError(f"Cannot get field name, field ID not found: {field_id}") - - def find_field_by_id(self, field_id: int) -> NestedField: - """Find a field by it's field ID - - Args: - field_id (int): A field ID - - Returns: - NestedField: The field object for the given field ID - - Raise: - ValueError: If the field ID does not exist - """ - for field in self.fields: - if field.field_id == field_id: - return field - raise ValueError(f"Cannot get field, ID does not exist: {field_id}") - - def find_field_type_by_id(self, field_id: int) -> IcebergType: - """Find the type of a field by field ID - - Args: - field_id (int): The ID of the field - - Raises: - ValueError: If the field ID does not exist in this schema - - Returns - IcebergType: The type of the field with ID of `field_id` - """ - field = self.find_field_by_id(field_id) - return field.type - - -class IndexById(SchemaVisitor): - """Index a Schema by IDs - - This visitor provides a field ID to field name map for a given Schema instance. The result is stored in the `result` instance attribute. - - Example: - >>> from iceberg.table.schema import IndexById, Schema - >>> from iceberg.types import BooleanType, IntegerType, NestedField, StringType - >>> fields = [ \ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), \ - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), \ - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), \ - ] - >>> table_schema = Schema(fields=fields, schema_id=1) - >>> visitor = IndexById() - >>> visitor.visit(table_schema) - >>> print(visitor.result) - {1: 'foo', 2: 'bar', 3: 'baz'} - """ - - def __init__(self): - self.result = {} - - @singledispatchmethod - def visit(self, node) -> None: - """A generic single dispatch visit method - - Raises: - NotImplementedError: If no concrete method has been registered for type of the `node` argument - """ - raise NotImplementedError(f"Cannot visit node, no IndexById operation implemented for node type: {type(node)}") - - @visit.register - def _(self, node: Schema) -> None: - self.visit(node.struct) - - @visit.register - def _(self, node: StructType) -> None: - for field in node.fields: - self.visit(field) - - @visit.register - def _(self, node: NestedField) -> None: - self.result[node.field_id] = node.name - if isinstance(node.type, (ListType, MapType)): - self.visit(node.type, nested_field_name=node.name) - - @visit.register - def _(self, node: ListType, nested_field_name: str) -> None: - """Index a ListType node - - ListType nodes add one item to the index: - 1. The ID for the ListType element and the name . - - is always "element". - Args: - node (ListType): The ListType instance for the NestedField instance - nested_field_name (str): The name of the NestedField instance containing the ListType - """ - self.result[node.element.field_id] = f"{nested_field_name}.{node.element.name}" - - @visit.register - def _(self, node: MapType, nested_field_name: str) -> None: - """Index a MapType node - - MapType nodes add two items to the index: - 1. The MapType key ID and the name . - 2. The MapType value ID and the name . - - and are always "key" and "value", respectively. - - Args: - node (MapType): The MapType instance for the NestedField instance - nested_field_name (str): The name of the NestedField instance containing the MapType - """ - self.result[node.key.field_id] = f"{nested_field_name}.key" - self.result[node.value.field_id] = f"{nested_field_name}.value" - - -class IndexByName(SchemaVisitor): - """Index a Schema by names - - This visitor provides a field name to field ID map for a given Schema instance. The result is stored in the `result` instance attribute. - - Example: - >>> from iceberg.table.schema import IndexByName, Schema - >>> from iceberg.types import BooleanType, IntegerType, NestedField, StringType - >>> fields = [ \ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), \ - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), \ - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), \ - ] - >>> table_schema = Schema(fields=fields, schema_id=1) - >>> visitor = IndexByName() - >>> visitor.visit(table_schema) - >>> print(visitor.result) - {'foo': 1, 'bar': 2, 'baz': 3} - """ - - def __init__(self): - self.result = {} - - @singledispatchmethod - def visit(self, node) -> None: - """A generic single dispatch visit method - - Raises: - NotImplementedError: If no concrete method has been registered for type of the `node` argument - """ - raise NotImplementedError(f"Cannot visit node, no IndexByName operation implemented for node type: {type(node)}") - - @visit.register - def _(self, node: Schema) -> None: - self.visit(node.struct) - - @visit.register - def _(self, node: StructType) -> None: - for field in node.fields: - self.visit(field) - - @visit.register - def _(self, node: NestedField) -> None: - self.result[node.name] = node.field_id - if isinstance(node.type, (ListType, MapType)): - self.visit(node.type, nested_field_name=node.name) - - @visit.register - def _(self, node: ListType, nested_field_name: str) -> None: - """Index a ListType node - - ListType nodes add one item to the index: - 1. The name . and the ID for the ListType element - - is always "element". - Args: - node (ListType): The ListType instance for the NestedField instance - nested_field_name (str): The name of the NestedField instance containing the ListType - """ - self.result[f"{nested_field_name}.{node.element.name}"] = node.element.field_id - - @visit.register - def _(self, node: MapType, nested_field_name: str) -> None: - """Index a MapType node - - MapType nodes add two items to the index: - 1. The name . and the MapType key ID - 2. The name . and the MapType value ID - - and are always "key" and "value", respectively. - - Args: - node (MapType): The MapType instance for the NestedField instance - nested_field_name (str): The name of the NestedField instance containing the MapType - """ - self.result[f"{nested_field_name}.{node.key.name}"] = node.key.field_id - self.result[f"{nested_field_name}.{node.value.name}"] = node.value.field_id + result = visit(field.type, visitor) + finally: + visitor.after_field(field) + + results.append(visitor.field(field, result)) + + return visitor.struct(obj, results) + + elif isinstance(obj, ListType): + visitor.before_list_element(obj.element) + try: + result = visit(obj.element.type, visitor) + finally: + visitor.after_list_element(obj.element) + + return visitor.list(obj, result) + + elif isinstance(obj, MapType): + visitor.before_map_key(obj.key) + try: + key_result = visit(obj.key.type, visitor) + finally: + visitor.after_map_key(obj.key) + + visitor.before_map_value(obj.value) + try: + value_result = visit(obj.value.type, visitor) + finally: + visitor.after_list_element(obj.value) + + return visitor.map(obj, key_result, value_result) + + elif isinstance(obj, PrimitiveType): + return visitor.primitive(obj) + + else: + raise NotImplementedError("Cannot visit non-type: %s" % obj) + + +def index_by_id(schema_or_type) -> Optional[Dict[int, NestedField]]: + class IndexById(SchemaVisitor[Optional[Dict[int, NestedField]]]): + def __init__(self): + self._index: Dict[int, NestedField] = {} + + def schema(self, schema, result): + return self._index + + def struct(self, struct, results): + return self._index + + def field(self, field, result): + self._index[field.field_id] = field + return self._index + + def list(self, list_type, result): + self._index[list_type.element.field_id] = list_type.element + return self._index + + def map(self, map_type, key_result, value_result): + self._index[map_type.key.field_id] = map_type.key + self._index[map_type.value.field_id] = map_type.value + return self._index + + def primitive(self, primitive): + return self._index + + return visit(schema_or_type, IndexById()) diff --git a/python/tests/table/test_schema.py b/python/tests/table/test_schema.py index ec8f587e5a76..2c4175a219e5 100644 --- a/python/tests/table/test_schema.py +++ b/python/tests/table/test_schema.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. +from typing import Dict + import pytest -from numpy import False_ from iceberg.table import schema +from iceberg.table.schema import SchemaVisitor from iceberg.types import ( BooleanType, IntegerType, @@ -26,36 +28,17 @@ MapType, NestedField, StringType, - StructType, ) -def test_schema_init(): - """Test initializing a schema from a list of fields""" - fields = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - ] - table_schema = schema.Schema(fields=fields, schema_id=1) - schema_struct = table_schema.struct - - assert table_schema.fields[0] == fields[0] - assert table_schema.fields[1] == fields[1] - assert table_schema.fields[2] == fields[2] - assert table_schema.schema_id == 1 - assert len(table_schema) == 3 - assert isinstance(schema_struct, StructType) - - def test_schema_str(): """Test casting a schema to a string""" - fields = [ + columns = [ NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), ] - table_schema = schema.Schema(fields=fields, schema_id=1) + table_schema = schema.Schema(*columns) assert """1: name=foo, type=string, required=True 2: name=bar, type=int, required=False 3: name=baz, type=boolean, required=True""" == str( @@ -67,14 +50,12 @@ def test_schema_str(): "schema, expected_repr", [ ( - schema.Schema(fields=[NestedField(1, "foo", StringType())], schema_id=1), - "Schema(fields=(NestedField(field_id=1, name='foo', field_type=StringType(), is_optional=True),), schema_id=1)", + schema.Schema(NestedField(1, "foo", StringType())), + "Schema(fields=(NestedField(field_id=1, name='foo', field_type=StringType(), is_optional=True),))", ), ( - schema.Schema( - fields=[NestedField(1, "foo", StringType()), NestedField(2, "bar", IntegerType(), is_optional=False)], schema_id=2 - ), - "Schema(fields=(NestedField(field_id=1, name='foo', field_type=StringType(), is_optional=True), NestedField(field_id=2, name='bar', field_type=IntegerType(), is_optional=False)), schema_id=2)", + schema.Schema(NestedField(1, "foo", StringType()), NestedField(2, "bar", IntegerType(), is_optional=False)), + "Schema(fields=(NestedField(field_id=1, name='foo', field_type=StringType(), is_optional=True), NestedField(field_id=2, name='bar', field_type=IntegerType(), is_optional=False)))", ), ], ) @@ -83,196 +64,84 @@ def test_schema_repr(schema, expected_repr): assert repr(schema) == expected_repr -def test_schema_find_field_id_case_sensitive(): - """Test case-sensitive retrieval of a field ID using the `find_field_id` method""" - fields = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - ] - table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "foobar": 2}) - assert table_schema.find_field_id_by_name(field_name="foo", case_sensitive=True) == 1 - assert table_schema.find_field_id_by_name(field_name="bar", case_sensitive=True) == 2 - assert table_schema.find_field_id_by_name(field_name="baz", case_sensitive=True) == 3 - assert table_schema.find_field_id_by_alias(field_alias="qux", case_sensitive=True) == 1 - assert table_schema.find_field_id_by_alias(field_alias="foobar", case_sensitive=True) == 2 - - -def test_schema_find_field_id_case_insensitive(): - """Test case-insensitive retrieval of a field ID using the `find_field_id_by_name` and `find_field_id_by_alias methods""" - fields = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - ] - table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "foobar": 2}) - assert table_schema.find_field_id_by_name(field_name="fOO", case_sensitive=False) == 1 - assert table_schema.find_field_id_by_name(field_name="BAr", case_sensitive=False) == 2 - assert table_schema.find_field_id_by_name(field_name="BaZ", case_sensitive=False) == 3 - assert table_schema.find_field_id_by_alias(field_alias="qUx", case_sensitive=False) == 1 - assert table_schema.find_field_id_by_alias(field_alias="fooBAR", case_sensitive=False) == 2 - - def test_schema_find_field_name_by_field_id(): - """Test finding a field name using a field ID""" - fields = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - ] - table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "foobar": 2}) - assert table_schema.find_field_name_by_field_id(field_id=1) == "foo" - assert table_schema.find_field_name_by_field_id(field_id=2) == "bar" - assert table_schema.find_field_name_by_field_id(field_id=3) == "baz" - - -def test_schema_find_field_id_raise_on_not_found(): - """Test raising when the field ID for a given name or alias cannot be found""" - fields = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - ] - table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "foobar": 2}) - - with pytest.raises(ValueError) as exc_info: - table_schema.find_field_id_by_name(field_name="name1") - - assert "Cannot get field ID, name not found: name1" in str(exc_info.value) - - with pytest.raises(ValueError) as exc_info: - table_schema.find_field_id_by_name(field_name="name2", case_sensitive=False_) - - assert "Cannot get field ID, name not found: name2" in str(exc_info.value) - - with pytest.raises(ValueError) as exc_info: - table_schema.find_field_id_by_name( - field_name="case_insensitive_name1", - case_sensitive=False, - ) - - assert "Cannot get field ID, name not found: case_insensitive_name1" in str(exc_info.value) - - with pytest.raises(ValueError) as exc_info: - table_schema.find_field_id_by_name(field_name="case_insensitive_name2") - - assert "Cannot get field ID, name not found: case_insensitive_name2" in str(exc_info.value) - - with pytest.raises(ValueError) as exc_info: - table_schema.find_field_id_by_alias(field_alias="alias1") - - assert "Cannot get field ID, alias not found: alias1" in str(exc_info.value) - - with pytest.raises(ValueError) as exc_info: - table_schema.find_field_id_by_alias(field_alias="alias2") - - assert "Cannot get field ID, alias not found: alias2" in str(exc_info.value) - - with pytest.raises(ValueError) as exc_info: - table_schema.find_field_id_by_alias( - field_alias="case_insensitive_alias1", - case_sensitive=False, - ) - - assert "Cannot get field ID, alias not found: case_insensitive_alias1" in str(exc_info.value) - - with pytest.raises(ValueError) as exc_info: - table_schema.find_field_id_by_alias( - field_alias="case_insensitive_alias2", - case_sensitive=False, - ) - - assert "Cannot get field ID, alias not found: case_insensitive_alias2" in str(exc_info.value) - - -def test_schema_find_field_id_raise_on_multiple_case_insensitive_alias_match(): - """Test raising when a case-insensitive alias search returns multiple aliases""" - fields = [ + """Test finding a column name using its field ID""" + columns = [ NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), ] - table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "QUX": 2}) + table_schema = schema.Schema(*columns) + index = schema.index_by_id(table_schema) - with pytest.raises(ValueError) as exc_info: - table_schema.find_field_id_by_alias(field_alias="qux", case_sensitive=False) + assert index[1].name == "foo" + assert index[2].name == "bar" + assert index[3].name == "baz" - assert "Cannot get field ID, case-insensitive alias returns multiple results: qux" in str(exc_info.value) - -def test_schema_find_field_name_by_field_id_raise_on_unknown_field_id(): - """Test raising when the the field ID cannot be found while finding a field name""" - fields = [ +def test_schema_find_field_by_id(): + """Test finding a column using its field ID""" + columns = [ NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), ] - table_schema = schema.Schema(fields=fields, schema_id=1, aliases={"qux": 1, "foobar": 2}) + table_schema = schema.Schema(*columns) + index = schema.index_by_id(table_schema) - with pytest.raises(ValueError) as exc_info: - table_schema.find_field_name_by_field_id(field_id=4) + column1 = index[1] + assert isinstance(column1, NestedField) + assert column1.field_id == 1 + assert column1.type == StringType() + assert column1.is_optional == False - assert "Cannot get field name, field ID not found: 4" in str(exc_info.value) + column2 = index[2] + assert isinstance(column2, NestedField) + assert column2.field_id == 2 + assert column2.type == IntegerType() + assert column2.is_optional == True - -def test_schema_find_field_by_id(): - """Test retrieving a field using the field's ID""" - fields = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - ] - table_schema = schema.Schema(fields=fields, schema_id=1) - field1 = table_schema.find_field_by_id(field_id=1) - field2 = table_schema.find_field_by_id(field_id=2) - field3 = table_schema.find_field_by_id(field_id=3) - - assert isinstance(field1, NestedField) - assert field1.field_id == 1 - assert field1.type == StringType() - assert field1.is_optional == False - assert isinstance(field2, NestedField) - assert field2.field_id == 2 - assert field2.type == IntegerType() - assert field2.is_optional == True - assert isinstance(field3, NestedField) - assert field3.field_id == 3 - assert field3.type == BooleanType() - assert field3.is_optional == False + column3 = index[3] + assert isinstance(column3, NestedField) + assert column3.field_id == 3 + assert column3.type == BooleanType() + assert column3.is_optional == False def test_schema_find_field_by_id_raise_on_unknown_field(): - """Test raising when the field ID is not found""" - fields = [ + """Test raising when the field ID is not found among columns""" + columns = [ NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), ] - table_schema = schema.Schema(fields=fields, schema_id=1) + table_schema = schema.Schema(*columns) + index = schema.index_by_id(table_schema) - with pytest.raises(ValueError) as exc_info: - table_schema.find_field_by_id(field_id=4) + with pytest.raises(Exception) as exc_info: + index[4] - assert "Cannot get field, ID does not exist: 4" in str(exc_info.value) + assert str(exc_info.value) == "4" def test_schema_find_field_type_by_id(): - """Test retrieving a field's type using the field's ID""" - fields = [ + """Test retrieving a columns's type using its field ID""" + columns = [ NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), ] - table_schema = schema.Schema(fields=fields, schema_id=1) + table_schema = schema.Schema(*columns) + index = schema.index_by_id(table_schema) - assert table_schema.find_field_type_by_id(field_id=1) == StringType() - assert table_schema.find_field_type_by_id(field_id=2) == IntegerType() - assert table_schema.find_field_type_by_id(field_id=3) == BooleanType() + assert index[1] == columns[0] + assert index[2] == columns[1] + assert index[3] == columns[2] def test_index_by_id_schema_visitor(): - """Test retrieving a field id to field name map using an IndexById schema visitor""" - fields = [ + """Test the index_by_id function that uses the IndexById schema visitor""" + columns = [ NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), @@ -289,54 +158,46 @@ def test_index_by_id_schema_visitor(): is_optional=False, ), ] - table_schema = schema.Schema(fields=fields, schema_id=1) - visitor = schema.IndexById() - visitor.visit(table_schema) - - assert visitor.result == {1: "foo", 2: "bar", 3: "baz", 4: "qux", 5: "qux.element", 6: "quux", 7: "quux.key", 8: "quux.value"} - + table_schema = schema.Schema(*columns) -def test_index_by_name_schema_visitor(): - """Test retrieving a field name to field id map using an IndexByName schema visitor""" - fields = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - NestedField( + assert schema.index_by_id(table_schema) == { + 1: NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + 2: NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + 3: NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + 4: NestedField( field_id=4, name="qux", field_type=ListType(element_id=5, element_type=StringType(), element_is_optional=True), is_optional=False, ), - NestedField( + 5: NestedField(field_id=5, name="element", field_type=StringType(), is_optional=True), + 6: NestedField( field_id=6, name="quux", field_type=MapType(key_id=7, key_type=StringType(), value_id=8, value_type=IntegerType(), value_is_optional=True), is_optional=False, ), - ] - table_schema = schema.Schema(fields=fields, schema_id=1) - visitor = schema.IndexByName() - visitor.visit(table_schema) - - assert visitor.result == {"bar": 2, "baz": 3, "foo": 1, "qux": 4, "qux.element": 5, "quux": 6, "quux.key": 7, "quux.value": 8} + 7: NestedField(field_id=7, name="key", field_type=StringType(), is_optional=False), + 8: NestedField(field_id=8, name="value", field_type=IntegerType(), is_optional=True), + } def test_index_by_id_schema_visitor_raise_on_unregistered_type(): - """Test raising a NotImplementedError when a type with no registered visit operation is passed to an IndexById visitor""" + """Test raising a NotImplementedError when an invalid type is provided to the index_by_id function""" - visitor = schema.IndexById() with pytest.raises(NotImplementedError) as exc_info: - visitor.visit("foo") + schema.index_by_id("foo") - assert "Cannot visit node, no IndexById operation implemented for node type: " in str(exc_info.value) + assert "Cannot visit non-type: foo" in str(exc_info.value) def test_index_by_name_schema_visitor_raise_on_unregistered_type(): - """Test raising a NotImplementedError when a type with no registered visit operation is passed to an IndexByName visitor""" + """Test raising a NotImplementedError when an invalid type is provided to the visit function""" + + class FooVisitor(SchemaVisitor[Dict[int, NestedField]]): + pass - visitor = schema.IndexByName() with pytest.raises(NotImplementedError) as exc_info: - visitor.visit("foo") + schema.visit("foo", FooVisitor()) - assert "Cannot visit node, no IndexByName operation implemented for node type: " in str(exc_info.value) + assert "Cannot visit non-type: foo" in str(exc_info.value) From b06783dd67c77aee6c92cec9a24a374ced2ce793 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Mon, 28 Mar 2022 10:13:29 -0700 Subject: [PATCH 11/23] Use __str__ from other implementation --- python/src/iceberg/table/schema.py | 5 +---- python/tests/table/test_schema.py | 8 +++++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index 187a371367a9..7e71bdd984d4 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -35,10 +35,7 @@ def __init__(self, *columns: Iterable[NestedField]): self._struct = StructType(*columns) # type: ignore def __str__(self): - column_strings = [ - f"{column.field_id}: name={column.name}, type={column.type}, required={column.is_required}" for column in self.columns - ] - return "\n".join(column_strings) + return "table { \n" + "\n".join([" " + str(field) for field in self.columns]) + "\n }" def __repr__(self): return f"Schema(fields={repr(self.columns)})" diff --git a/python/tests/table/test_schema.py b/python/tests/table/test_schema.py index 2c4175a219e5..36d0ca86b251 100644 --- a/python/tests/table/test_schema.py +++ b/python/tests/table/test_schema.py @@ -39,9 +39,11 @@ def test_schema_str(): NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), ] table_schema = schema.Schema(*columns) - assert """1: name=foo, type=string, required=True -2: name=bar, type=int, required=False -3: name=baz, type=boolean, required=True""" == str( + assert """table { + 1: foo: required string + 2: bar: optional int + 3: baz: required boolean + }""" == str( table_schema ) From 71a24f9f9b9ad24fb751d240280ff9a1eb813238 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Mon, 28 Mar 2022 10:17:44 -0700 Subject: [PATCH 12/23] Removing singledispatch imports, no longer used --- python/setup.cfg | 2 -- python/src/iceberg/table/schema.py | 8 -------- 2 files changed, 10 deletions(-) diff --git a/python/setup.cfg b/python/setup.cfg index 85b5909db973..a707fe123d77 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -42,8 +42,6 @@ package_dir = = src packages = find: python_requires = >=3.7 -install_requires = - singledispatch [options.extras_require] arrow = pyarrow diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index 7e71bdd984d4..00603b67e77e 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -15,16 +15,8 @@ # specific language governing permissions and limitations # under the License. -import sys from typing import Dict, Generic, Iterable, List, Optional, TypeVar -if sys.version_info >= (3, 8): # pragma: no cover - from functools import singledispatchmethod - from typing import Protocol -else: # pragma: no cover - from typing_extensions import Protocol # type: ignore - from singledispatch import singledispatchmethod # type: ignore - from iceberg.types import ListType, MapType, NestedField, PrimitiveType, StructType From 40019fc9ba1b3efb0613c7f6308d04b8b3a56989 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Mon, 28 Mar 2022 10:56:56 -0700 Subject: [PATCH 13/23] Adding Schema methods _find_field_by_name find_field find_type find_column_name select _case_sensitive_select _case_insensitive_select --- python/src/iceberg/table/schema.py | 80 ++++++++++++++++++++++++------ python/tests/table/test_schema.py | 57 ++++++++++++++++++--- 2 files changed, 115 insertions(+), 22 deletions(-) diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index 00603b67e77e..ee351bb68be9 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -15,9 +15,19 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict, Generic, Iterable, List, Optional, TypeVar +from abc import ABC, abstractmethod +from typing import Dict, Generic, Iterable, List, TypeVar, Union + +from iceberg.types import ( + IcebergType, + ListType, + MapType, + NestedField, + PrimitiveType, + StructType, +) -from iceberg.types import ListType, MapType, NestedField, PrimitiveType, StructType +T = TypeVar("T") class Schema(object): @@ -39,8 +49,42 @@ def columns(self): def as_struct(self): return self._struct + def _find_field_by_name(self, index: dict, name: str) -> NestedField: + matched_fields = [field for field_id, field in index.items() if field.name == name] + if not matched_fields: + raise ValueError("Cannot find field: {name_or_id}") + return matched_fields[0] -T = TypeVar("T") + def find_field(self, name_or_id: Union[str, int], case_sensitive: bool = True) -> NestedField: + index = index_by_id(self) + if isinstance(name_or_id, int): + return index[name_or_id] + return self._find_field_by_name(index, name_or_id) + + def find_type(self, name_or_id: Union[str, int], case_sensitive: bool = True) -> IcebergType: + index = index_by_id(self) + if isinstance(name_or_id, int): + return index[name_or_id].type + return self._find_field_by_name(index, name_or_id).type + + def find_column_name(self, id: int) -> str: + index = index_by_id(self) + return index[id].name + + def select(self, names: List[str], case_sensitive: bool = True) -> "Schema": + return ( + self._case_sensitive_select(schema=self, names=names) + if case_sensitive + else self._case_insensitive_select(schema=self, names=names) + ) + + @classmethod + def _case_sensitive_select(cls, schema: "Schema", names: List[str]): + return cls(*[column for column in schema.columns if column.name.lower() in names]) + + @classmethod + def _case_insensitive_select(cls, schema: "Schema", names: List[str]): + return cls(*[column for column in schema.columns if column.name.lower() in [name.lower() for name in names]]) class SchemaVisitor(Generic[T]): @@ -68,23 +112,29 @@ def before_map_value(self, value: NestedField) -> None: def after_map_value(self, value: NestedField) -> None: self.after_field(value) - def schema(self, schema: Schema, struct_result: T) -> Optional[T]: - return None + @abstractmethod + def schema(self, schema: Schema, struct_result: T) -> T: + ... - def struct(self, struct: StructType, field_results: List[T]) -> Optional[T]: - return None + @abstractmethod + def struct(self, struct: StructType, field_results: List[T]) -> T: + ... - def field(self, field: NestedField, field_result: T) -> Optional[T]: - return None + @abstractmethod + def field(self, field: NestedField, field_result: T) -> T: + ... - def list(self, list_type: ListType, element_result: T) -> Optional[T]: - return None + @abstractmethod + def list(self, list_type: ListType, element_result: T) -> T: + ... - def map(self, map_type: MapType, key_result: T, value_result: T) -> Optional[T]: - return None + @abstractmethod + def map(self, map_type: MapType, key_result: T, value_result: T) -> T: + ... - def primitive(self, primitive: PrimitiveType) -> Optional[T]: - return None + @abstractmethod + def primitive(self, primitive: PrimitiveType) -> T: + ... def visit(obj, visitor: SchemaVisitor[Optional[T]]) -> Optional[T]: diff --git a/python/tests/table/test_schema.py b/python/tests/table/test_schema.py index 36d0ca86b251..c018b43e8115 100644 --- a/python/tests/table/test_schema.py +++ b/python/tests/table/test_schema.py @@ -193,13 +193,56 @@ def test_index_by_id_schema_visitor_raise_on_unregistered_type(): assert "Cannot visit non-type: foo" in str(exc_info.value) -def test_index_by_name_schema_visitor_raise_on_unregistered_type(): - """Test raising a NotImplementedError when an invalid type is provided to the visit function""" +def test_schema_find_field(): + """Test finding a field in a schema""" + columns = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + table_schema = schema.Schema(*columns) - class FooVisitor(SchemaVisitor[Dict[int, NestedField]]): - pass + assert table_schema.find_field(1) == table_schema.find_field("foo") == columns[0] + assert table_schema.find_field(2) == table_schema.find_field("bar") == columns[1] + assert table_schema.find_field(3) == table_schema.find_field("baz") == columns[2] - with pytest.raises(NotImplementedError) as exc_info: - schema.visit("foo", FooVisitor()) - assert "Cannot visit non-type: foo" in str(exc_info.value) +def test_schema_find_type(): + """Test finding the type of a column given its field ID""" + columns = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + table_schema = schema.Schema(*columns) + + assert table_schema.find_type(1) == table_schema.find_type("foo") == StringType() + assert table_schema.find_type(2) == table_schema.find_type("bar") == IntegerType() + assert table_schema.find_type(3) == table_schema.find_type("baz") == BooleanType() + + +def test_schema_find_column_name(): + """Test finding a column name given its field ID""" + columns = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + table_schema = schema.Schema(*columns) + + assert table_schema.find_column_name(1) == "foo" + assert table_schema.find_column_name(2) == "bar" + assert table_schema.find_column_name(3) == "baz" + + +def test_schema_select(): + """Test selecting columns in a schema""" + columns = [ + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ] + table_schema = schema.Schema(*columns) + + projected_schema = table_schema.select(["foo", "bar"]) + len(projected_schema.columns) == 2 From 373c89ab1c143aad6503a30b49f08e03bd85c568 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Mon, 28 Mar 2022 11:00:41 -0700 Subject: [PATCH 14/23] Making Schema an ABC class (avoids confusing mypy by returning None) --- python/src/iceberg/table/schema.py | 10 +++++----- python/tests/table/test_schema.py | 2 -- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index ee351bb68be9..ad1e66d11f05 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -30,7 +30,7 @@ T = TypeVar("T") -class Schema(object): +class Schema: """A table Schema""" def __init__(self, *columns: Iterable[NestedField]): @@ -87,7 +87,7 @@ def _case_insensitive_select(cls, schema: "Schema", names: List[str]): return cls(*[column for column in schema.columns if column.name.lower() in [name.lower() for name in names]]) -class SchemaVisitor(Generic[T]): +class SchemaVisitor(Generic[T], ABC): def before_field(self, field: NestedField) -> None: pass @@ -137,7 +137,7 @@ def primitive(self, primitive: PrimitiveType) -> T: ... -def visit(obj, visitor: SchemaVisitor[Optional[T]]) -> Optional[T]: +def visit(obj, visitor: SchemaVisitor[T]) -> T: if isinstance(obj, Schema): return visitor.schema(obj, visit(obj.as_struct(), visitor)) @@ -185,8 +185,8 @@ def visit(obj, visitor: SchemaVisitor[Optional[T]]) -> Optional[T]: raise NotImplementedError("Cannot visit non-type: %s" % obj) -def index_by_id(schema_or_type) -> Optional[Dict[int, NestedField]]: - class IndexById(SchemaVisitor[Optional[Dict[int, NestedField]]]): +def index_by_id(schema_or_type) -> Dict[int, NestedField]: + class IndexById(SchemaVisitor[Dict[int, NestedField]]): def __init__(self): self._index: Dict[int, NestedField] = {} diff --git a/python/tests/table/test_schema.py b/python/tests/table/test_schema.py index c018b43e8115..964103bf8d0b 100644 --- a/python/tests/table/test_schema.py +++ b/python/tests/table/test_schema.py @@ -15,12 +15,10 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict import pytest from iceberg.table import schema -from iceberg.table.schema import SchemaVisitor from iceberg.types import ( BooleanType, IntegerType, From c163b76e187d9bf4391e0322ca2a31c714a70396 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Mon, 28 Mar 2022 11:22:20 -0700 Subject: [PATCH 15/23] Add temporary pin for click --- python/tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tox.ini b/python/tox.ini index 8e47a3720415..44e276ee2d87 100644 --- a/python/tox.ini +++ b/python/tox.ini @@ -49,6 +49,7 @@ commands = [testenv:format-check] deps = black + click == 8.0.4 # TODO: Remove this line: Temporarily pins click back to 8.0.4--It's a dependency of black and the release of click 8.1.0 has caused an import issue isort autoflake commands = From 90ea3bf1443df2abe9fa7647b128de2a47ad0819 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Tue, 29 Mar 2022 21:36:01 -0700 Subject: [PATCH 16/23] Add index_by_name and clean up tests --- python/src/iceberg/table/schema.py | 98 ++++++++--- python/tests/conftest.py | 49 ++++++ python/tests/table/test_schema.py | 266 +++++++++++++++-------------- python/tox.ini | 1 - 4 files changed, 259 insertions(+), 155 deletions(-) create mode 100644 python/tests/conftest.py diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index ad1e66d11f05..28bdfd35514a 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -15,8 +15,10 @@ # specific language governing permissions and limitations # under the License. +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Dict, Generic, Iterable, List, TypeVar, Union +from typing import Dict, Generic, Iterable, List, TypeVar from iceberg.types import ( IcebergType, @@ -49,42 +51,43 @@ def columns(self): def as_struct(self): return self._struct - def _find_field_by_name(self, index: dict, name: str) -> NestedField: - matched_fields = [field for field_id, field in index.items() if field.name == name] - if not matched_fields: - raise ValueError("Cannot find field: {name_or_id}") - return matched_fields[0] - - def find_field(self, name_or_id: Union[str, int], case_sensitive: bool = True) -> NestedField: - index = index_by_id(self) + def find_field(self, name_or_id: str | int, case_sensitive: bool = True) -> NestedField: + id_index = index_by_id(self) if isinstance(name_or_id, int): - return index[name_or_id] - return self._find_field_by_name(index, name_or_id) + return id_index[name_or_id] + name_index = index_by_name(self) + field_id = name_index[name_or_id] + return id_index[field_id] - def find_type(self, name_or_id: Union[str, int], case_sensitive: bool = True) -> IcebergType: - index = index_by_id(self) + def find_type(self, name_or_id: str | int, case_sensitive: bool = True) -> IcebergType: + id_index = index_by_id(self) if isinstance(name_or_id, int): - return index[name_or_id].type - return self._find_field_by_name(index, name_or_id).type - - def find_column_name(self, id: int) -> str: - index = index_by_id(self) - return index[id].name + return id_index[name_or_id].type + name_index = index_by_name(self) + field_id = name_index[name_or_id] + return id_index[field_id].type + + def find_column_name(self, column_id: int) -> str: + index = index_by_name(self) + matched_column = [name for name, field_id in index.items() if field_id == column_id] + if not matched_column: + raise ValueError(f"Cannot find column name: {column_id}") + return matched_column[0] def select(self, names: List[str], case_sensitive: bool = True) -> "Schema": - return ( - self._case_sensitive_select(schema=self, names=names) - if case_sensitive - else self._case_insensitive_select(schema=self, names=names) - ) + if case_sensitive: + return self._case_sensitive_select(schema=self, names=names) + return self._case_insensitive_select(schema=self, names=names) @classmethod def _case_sensitive_select(cls, schema: "Schema", names: List[str]): - return cls(*[column for column in schema.columns if column.name.lower() in names]) + # TODO: Add a PruneColumns schema visitor and use it here + raise NotImplementedError() @classmethod def _case_insensitive_select(cls, schema: "Schema", names: List[str]): - return cls(*[column for column in schema.columns if column.name.lower() in [name.lower() for name in names]]) + # TODO: Add a PruneColumns schema visitor and use it here + raise NotImplementedError() class SchemaVisitor(Generic[T], ABC): @@ -213,3 +216,46 @@ def primitive(self, primitive): return self._index return visit(schema_or_type, IndexById()) + + +def index_by_name(schema_or_type) -> Dict[str, int]: + class IndexByName(SchemaVisitor[Dict[str, int]]): + def __init__(self): + self._index: Dict[str, int] = {} + self._field_names = [] + + def before_field(self, field: NestedField) -> None: + self._field_names.append(field.name) + + def after_field(self, field: NestedField) -> None: + self._field_names.pop() + + def schema(self, schema, struct_result): + return self._index + + def struct(self, struct, field_results): + return self._index + + def field(self, field, field_result): + self._add_field(field.name, field.field_id) + + def list(self, list_type, result): + self._add_field(list_type.element.name, list_type.element.field_id) + + def map(self, map_type, key_result, value_result): + self._add_field(map_type.key.name, map_type.key.field_id) + self._add_field(map_type.value.name, map_type.value.field_id) + + def _add_field(self, name, field_id): + full_name = name + if self._field_names: + full_name = ".".join([".".join(self._field_names), name]) + + if full_name in self._index: + raise ValueError(f"Invalid schema, multiple fields for name {full_name}: {index[full_name]} and {field_id}") + self._index[full_name] = field_id + + def primitive(self, primitive): + return self._index + + return visit(schema_or_type, IndexByName()) diff --git a/python/tests/conftest.py b/python/tests/conftest.py new file mode 100644 index 000000000000..f226708ff976 --- /dev/null +++ b/python/tests/conftest.py @@ -0,0 +1,49 @@ +import pytest + +from iceberg.table import schema +from iceberg.types import ( + BooleanType, + IntegerType, + ListType, + MapType, + NestedField, + StringType, +) + + +@pytest.fixture(scope="session", autouse=True) +def table_schema_simple(): + return schema.Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + ) + + +@pytest.fixture(scope="session", autouse=True) +def table_schema_nested(): + return schema.Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + NestedField( + field_id=4, + name="qux", + field_type=ListType(element_id=5, element_type=StringType(), element_is_optional=True), + is_optional=True, + ), + NestedField( + field_id=6, + name="quux", + field_type=MapType( + key_id=7, + key_type=StringType(), + value_id=8, + value_type=MapType( + key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_is_optional=True + ), + value_is_optional=True, + ), + is_optional=True, + ), + ) diff --git a/python/tests/table/test_schema.py b/python/tests/table/test_schema.py index 964103bf8d0b..4d8e0ebd2269 100644 --- a/python/tests/table/test_schema.py +++ b/python/tests/table/test_schema.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +from textwrap import dedent import pytest @@ -29,20 +30,15 @@ ) -def test_schema_str(): +def test_schema_str(table_schema_simple): """Test casting a schema to a string""" - columns = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - ] - table_schema = schema.Schema(*columns) - assert """table { - 1: foo: required string - 2: bar: optional int - 3: baz: required boolean - }""" == str( - table_schema + assert str(table_schema_simple) == dedent( + """\ + table { + 1: foo: required string + 2: bar: optional int + 3: baz: required boolean + }""" ) @@ -64,30 +60,87 @@ def test_schema_repr(schema, expected_repr): assert repr(schema) == expected_repr -def test_schema_find_field_name_by_field_id(): - """Test finding a column name using its field ID""" - columns = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - ] - table_schema = schema.Schema(*columns) - index = schema.index_by_id(table_schema) +def test_schema_index_by_id_visitor(table_schema_nested): + """Test index_by_id visitor function""" + index = schema.index_by_id(table_schema_nested) + assert index == { + 1: NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), + 2: NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), + 3: NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + 4: NestedField( + field_id=4, + name="qux", + field_type=ListType(element_id=5, element_type=StringType(), element_is_optional=True), + is_optional=True, + ), + 5: NestedField(field_id=5, name="element", field_type=StringType(), is_optional=True), + 6: NestedField( + field_id=6, + name="quux", + field_type=MapType( + key_id=7, + key_type=StringType(), + value_id=8, + value_type=MapType( + key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_is_optional=True + ), + value_is_optional=True, + ), + is_optional=True, + ), + 7: NestedField(field_id=7, name="key", field_type=StringType(), is_optional=False), + 9: NestedField(field_id=9, name="key", field_type=StringType(), is_optional=False), + 8: NestedField( + field_id=8, + name="value", + field_type=MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_is_optional=True), + is_optional=True, + ), + 10: NestedField(field_id=10, name="value", field_type=IntegerType(), is_optional=True), + } - assert index[1].name == "foo" - assert index[2].name == "bar" - assert index[3].name == "baz" +def test_schema_index_by_name_visitor(table_schema_nested): + """Test index_by_name visitor function""" + index = schema.index_by_name(table_schema_nested) + assert index == { + "foo": 1, + "bar": 2, + "baz": 3, + "qux": 4, + "qux.element": 5, + "quux": 6, + "quux.key": 7, + "quux.value": 8, + "quux.value.key": 9, + "quux.value.value": 10, + } -def test_schema_find_field_by_id(): + +def test_schema_find_column_name(table_schema_nested): + """Test finding a column name using its field ID""" + assert table_schema_nested.find_column_name(1) == "foo" + assert table_schema_nested.find_column_name(2) == "bar" + assert table_schema_nested.find_column_name(3) == "baz" + assert table_schema_nested.find_column_name(4) == "qux" + assert table_schema_nested.find_column_name(5) == "qux.element" + assert table_schema_nested.find_column_name(6) == "quux" + assert table_schema_nested.find_column_name(7) == "quux.key" + assert table_schema_nested.find_column_name(8) == "quux.value" + assert table_schema_nested.find_column_name(9) == "quux.value.key" + assert table_schema_nested.find_column_name(10) == "quux.value.value" + + +def test_schema_find_column_name_raise_on_id_not_found(table_schema_nested): + """Test raising an error when a field ID cannot be found""" + with pytest.raises(ValueError) as exc_info: + table_schema_nested.find_column_name(99) + assert "Cannot find column name: 99" in str(exc_info.value) + + +def test_schema_find_field_by_id(table_schema_simple): """Test finding a column using its field ID""" - columns = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - ] - table_schema = schema.Schema(*columns) - index = schema.index_by_id(table_schema) + index = schema.index_by_id(table_schema_simple) column1 = index[1] assert isinstance(column1, NestedField) @@ -108,59 +161,25 @@ def test_schema_find_field_by_id(): assert column3.is_optional == False -def test_schema_find_field_by_id_raise_on_unknown_field(): +def test_schema_find_field_by_id_raise_on_unknown_field(table_schema_simple): """Test raising when the field ID is not found among columns""" - columns = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - ] - table_schema = schema.Schema(*columns) - index = schema.index_by_id(table_schema) - + index = schema.index_by_id(table_schema_simple) with pytest.raises(Exception) as exc_info: index[4] - assert str(exc_info.value) == "4" -def test_schema_find_field_type_by_id(): +def test_schema_find_field_type_by_id(table_schema_simple): """Test retrieving a columns's type using its field ID""" - columns = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - ] - table_schema = schema.Schema(*columns) - index = schema.index_by_id(table_schema) - - assert index[1] == columns[0] - assert index[2] == columns[1] - assert index[3] == columns[2] + index = schema.index_by_id(table_schema_simple) + assert index[1] == NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False) + assert index[2] == NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True) + assert index[3] == NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False) -def test_index_by_id_schema_visitor(): +def test_index_by_id_schema_visitor(table_schema_nested): """Test the index_by_id function that uses the IndexById schema visitor""" - columns = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - NestedField( - field_id=4, - name="qux", - field_type=ListType(element_id=5, element_type=StringType(), element_is_optional=True), - is_optional=False, - ), - NestedField( - field_id=6, - name="quux", - field_type=MapType(key_id=7, key_type=StringType(), value_id=8, value_type=IntegerType(), value_is_optional=True), - is_optional=False, - ), - ] - table_schema = schema.Schema(*columns) - - assert schema.index_by_id(table_schema) == { + assert schema.index_by_id(table_schema_nested) == { 1: NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), 2: NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), 3: NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), @@ -168,79 +187,70 @@ def test_index_by_id_schema_visitor(): field_id=4, name="qux", field_type=ListType(element_id=5, element_type=StringType(), element_is_optional=True), - is_optional=False, + is_optional=True, ), 5: NestedField(field_id=5, name="element", field_type=StringType(), is_optional=True), 6: NestedField( field_id=6, name="quux", - field_type=MapType(key_id=7, key_type=StringType(), value_id=8, value_type=IntegerType(), value_is_optional=True), - is_optional=False, + field_type=MapType( + key_id=7, + key_type=StringType(), + value_id=8, + value_type=MapType( + key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_is_optional=True + ), + value_is_optional=True, + ), + is_optional=True, ), 7: NestedField(field_id=7, name="key", field_type=StringType(), is_optional=False), - 8: NestedField(field_id=8, name="value", field_type=IntegerType(), is_optional=True), + 8: NestedField( + field_id=8, + name="value", + field_type=MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_is_optional=True), + is_optional=True, + ), + 9: NestedField(field_id=9, name="key", field_type=StringType(), is_optional=False), + 10: NestedField(field_id=10, name="value", field_type=IntegerType(), is_optional=True), } def test_index_by_id_schema_visitor_raise_on_unregistered_type(): """Test raising a NotImplementedError when an invalid type is provided to the index_by_id function""" - with pytest.raises(NotImplementedError) as exc_info: schema.index_by_id("foo") - assert "Cannot visit non-type: foo" in str(exc_info.value) -def test_schema_find_field(): +def test_schema_find_field(table_schema_simple): """Test finding a field in a schema""" - columns = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - ] - table_schema = schema.Schema(*columns) - - assert table_schema.find_field(1) == table_schema.find_field("foo") == columns[0] - assert table_schema.find_field(2) == table_schema.find_field("bar") == columns[1] - assert table_schema.find_field(3) == table_schema.find_field("baz") == columns[2] + assert ( + table_schema_simple.find_field(1) + == table_schema_simple.find_field("foo") + == NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False) + ) + assert ( + table_schema_simple.find_field(2) + == table_schema_simple.find_field("bar") + == NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True) + ) + assert ( + table_schema_simple.find_field(3) + == table_schema_simple.find_field("baz") + == NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False) + ) -def test_schema_find_type(): +def test_schema_find_type(table_schema_simple): """Test finding the type of a column given its field ID""" - columns = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - ] - table_schema = schema.Schema(*columns) - - assert table_schema.find_type(1) == table_schema.find_type("foo") == StringType() - assert table_schema.find_type(2) == table_schema.find_type("bar") == IntegerType() - assert table_schema.find_type(3) == table_schema.find_type("baz") == BooleanType() + assert table_schema_simple.find_type(1) == table_schema_simple.find_type("foo") == StringType() + assert table_schema_simple.find_type(2) == table_schema_simple.find_type("bar") == IntegerType() + assert table_schema_simple.find_type(3) == table_schema_simple.find_type("baz") == BooleanType() -def test_schema_find_column_name(): +def test_schema_find_column_name(table_schema_simple): """Test finding a column name given its field ID""" - columns = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - ] - table_schema = schema.Schema(*columns) - - assert table_schema.find_column_name(1) == "foo" - assert table_schema.find_column_name(2) == "bar" - assert table_schema.find_column_name(3) == "baz" - - -def test_schema_select(): - """Test selecting columns in a schema""" - columns = [ - NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), - ] - table_schema = schema.Schema(*columns) - - projected_schema = table_schema.select(["foo", "bar"]) - len(projected_schema.columns) == 2 + assert table_schema_simple.find_column_name(1) == "foo" + assert table_schema_simple.find_column_name(2) == "bar" + assert table_schema_simple.find_column_name(3) == "baz" diff --git a/python/tox.ini b/python/tox.ini index 44e276ee2d87..8e47a3720415 100644 --- a/python/tox.ini +++ b/python/tox.ini @@ -49,7 +49,6 @@ commands = [testenv:format-check] deps = black - click == 8.0.4 # TODO: Remove this line: Temporarily pins click back to 8.0.4--It's a dependency of black and the release of click 8.1.0 has caused an import issue isort autoflake commands = From bffd182f9d0b13e534d14e269f30c872dde731bd Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Tue, 29 Mar 2022 21:47:59 -0700 Subject: [PATCH 17/23] Adding ASF headers --- python/tests/conftest.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python/tests/conftest.py b/python/tests/conftest.py index f226708ff976..181181f5e534 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import pytest from iceberg.table import schema From d957c9b90dcfae6f695fbd2f7358f59604805799 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Tue, 29 Mar 2022 22:47:04 -0700 Subject: [PATCH 18/23] Add schema_id and identifier_field_ids Schema properties --- python/src/iceberg/table/schema.py | 14 ++++++++++++-- python/tests/conftest.py | 4 ++++ python/tests/table/test_schema.py | 6 ++++-- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index 28bdfd35514a..66796b3ddcb9 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -35,8 +35,10 @@ class Schema: """A table Schema""" - def __init__(self, *columns: Iterable[NestedField]): + def __init__(self, *columns: Iterable[NestedField], schema_id: int, identifier_field_ids: List[int] = []): self._struct = StructType(*columns) # type: ignore + self._schema_id = schema_id + self._identifier_field_ids = identifier_field_ids def __str__(self): return "table { \n" + "\n".join([" " + str(field) for field in self.columns]) + "\n }" @@ -45,9 +47,17 @@ def __repr__(self): return f"Schema(fields={repr(self.columns)})" @property - def columns(self): + def columns(self) -> Iterable[NestedField]: return self._struct.fields + @property + def id(self) -> int: + return self._schema_id + + @property + def identifier_field_ids(self) -> List[int]: + return self._identifier_field_ids + def as_struct(self): return self._struct diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 181181f5e534..72d02f36ae7a 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -34,6 +34,8 @@ def table_schema_simple(): NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True), NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False), + schema_id=1, + identifier_field_ids=[1], ) @@ -63,4 +65,6 @@ def table_schema_nested(): ), is_optional=True, ), + schema_id=1, + identifier_field_ids=[1], ) diff --git a/python/tests/table/test_schema.py b/python/tests/table/test_schema.py index 4d8e0ebd2269..a9daa0000ece 100644 --- a/python/tests/table/test_schema.py +++ b/python/tests/table/test_schema.py @@ -46,11 +46,13 @@ def test_schema_str(table_schema_simple): "schema, expected_repr", [ ( - schema.Schema(NestedField(1, "foo", StringType())), + schema.Schema(NestedField(1, "foo", StringType()), schema_id=1), "Schema(fields=(NestedField(field_id=1, name='foo', field_type=StringType(), is_optional=True),))", ), ( - schema.Schema(NestedField(1, "foo", StringType()), NestedField(2, "bar", IntegerType(), is_optional=False)), + schema.Schema( + NestedField(1, "foo", StringType()), NestedField(2, "bar", IntegerType(), is_optional=False), schema_id=1 + ), "Schema(fields=(NestedField(field_id=1, name='foo', field_type=StringType(), is_optional=True), NestedField(field_id=2, name='bar', field_type=IntegerType(), is_optional=False)))", ), ], From 474358867787254402e2bb44736fd0110db21679 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Wed, 30 Mar 2022 17:58:09 -0700 Subject: [PATCH 19/23] Schema: lazily generate id_index, generate name_index during init --- python/src/iceberg/table/schema.py | 39 +++++++++++++++++++++--------- python/tests/table/test_schema.py | 24 +++++++++++++++--- 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index 66796b3ddcb9..94b7f9b1f8bf 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -39,6 +39,8 @@ def __init__(self, *columns: Iterable[NestedField], schema_id: int, identifier_f self._struct = StructType(*columns) # type: ignore self._schema_id = schema_id self._identifier_field_ids = identifier_field_ids + self._name_index: Dict[str, int] = index_by_name(self) + self._id_index: Dict[int, NestedField] = {} # Will be lazily set when self.id_index property method is called def __str__(self): return "table { \n" + "\n".join([" " + str(field) for field in self.columns]) + "\n }" @@ -58,28 +60,41 @@ def id(self) -> int: def identifier_field_ids(self) -> List[int]: return self._identifier_field_ids + @property + def id_index(self) -> Dict[int, NestedField]: + if not self._id_index: + self._id_index = index_by_id(self) + return self._id_index + + @property + def name_index(self) -> Dict[str, int]: + return self._name_index + def as_struct(self): return self._struct def find_field(self, name_or_id: str | int, case_sensitive: bool = True) -> NestedField: - id_index = index_by_id(self) if isinstance(name_or_id, int): - return id_index[name_or_id] - name_index = index_by_name(self) - field_id = name_index[name_or_id] - return id_index[field_id] + return self.id_index[name_or_id] + if case_sensitive: + field_id = self.name_index[name_or_id] + else: + name_index_lower = {name.lower(): field_id for name, field_id in self.name_index.items()} + field_id = name_index_lower[name_or_id.lower()] + return self.id_index[field_id] def find_type(self, name_or_id: str | int, case_sensitive: bool = True) -> IcebergType: - id_index = index_by_id(self) if isinstance(name_or_id, int): - return id_index[name_or_id].type - name_index = index_by_name(self) - field_id = name_index[name_or_id] - return id_index[field_id].type + return self.id_index[name_or_id].type + if case_sensitive: + field_id = self.name_index[name_or_id] + else: + name_index_lower = {name.lower(): field_id for name, field_id in self.name_index.items()} + field_id = name_index_lower[name_or_id.lower()] + return self.id_index[field_id].type def find_column_name(self, column_id: int) -> str: - index = index_by_name(self) - matched_column = [name for name, field_id in index.items() if field_id == column_id] + matched_column = [name for name, field_id in self.name_index.items() if field_id == column_id] if not matched_column: raise ValueError(f"Cannot find column name: {column_id}") return matched_column[0] diff --git a/python/tests/table/test_schema.py b/python/tests/table/test_schema.py index a9daa0000ece..c381597dc050 100644 --- a/python/tests/table/test_schema.py +++ b/python/tests/table/test_schema.py @@ -230,25 +230,43 @@ def test_schema_find_field(table_schema_simple): assert ( table_schema_simple.find_field(1) == table_schema_simple.find_field("foo") + == table_schema_simple.find_field("FOO", case_sensitive=False) == NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False) ) assert ( table_schema_simple.find_field(2) == table_schema_simple.find_field("bar") + == table_schema_simple.find_field("BAR", case_sensitive=False) == NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True) ) assert ( table_schema_simple.find_field(3) == table_schema_simple.find_field("baz") + == table_schema_simple.find_field("BAZ", case_sensitive=False) == NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False) ) def test_schema_find_type(table_schema_simple): """Test finding the type of a column given its field ID""" - assert table_schema_simple.find_type(1) == table_schema_simple.find_type("foo") == StringType() - assert table_schema_simple.find_type(2) == table_schema_simple.find_type("bar") == IntegerType() - assert table_schema_simple.find_type(3) == table_schema_simple.find_type("baz") == BooleanType() + assert ( + table_schema_simple.find_type(1) + == table_schema_simple.find_type("foo") + == table_schema_simple.find_type("FOO", case_sensitive=False) + == StringType() + ) + assert ( + table_schema_simple.find_type(2) + == table_schema_simple.find_type("bar") + == table_schema_simple.find_type("BAR", case_sensitive=False) + == IntegerType() + ) + assert ( + table_schema_simple.find_type(3) + == table_schema_simple.find_type("baz") + == table_schema_simple.find_type("BAZ", case_sensitive=False) + == BooleanType() + ) def test_schema_find_column_name(table_schema_simple): From 338304abc66e33635f58992554d6ce85671005d4 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Wed, 30 Mar 2022 17:59:54 -0700 Subject: [PATCH 20/23] Adding a couple missing return typehints --- python/src/iceberg/table/schema.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py index 94b7f9b1f8bf..fb3c8e2b40bb 100644 --- a/python/src/iceberg/table/schema.py +++ b/python/src/iceberg/table/schema.py @@ -70,7 +70,7 @@ def id_index(self) -> Dict[int, NestedField]: def name_index(self) -> Dict[str, int]: return self._name_index - def as_struct(self): + def as_struct(self) -> StructType: return self._struct def find_field(self, name_or_id: str | int, case_sensitive: bool = True) -> NestedField: @@ -215,7 +215,7 @@ def visit(obj, visitor: SchemaVisitor[T]) -> T: def index_by_id(schema_or_type) -> Dict[int, NestedField]: class IndexById(SchemaVisitor[Dict[int, NestedField]]): - def __init__(self): + def __init__(self) -> None: self._index: Dict[int, NestedField] = {} def schema(self, schema, result): @@ -245,9 +245,9 @@ def primitive(self, primitive): def index_by_name(schema_or_type) -> Dict[str, int]: class IndexByName(SchemaVisitor[Dict[str, int]]): - def __init__(self): + def __init__(self) -> None: self._index: Dict[str, int] = {} - self._field_names = [] + self._field_names: List[str] = [] def before_field(self, field: NestedField) -> None: self._field_names.append(field.name) From 48cf967a01e7f183b441896fa5d102cf9f2183cd Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Wed, 30 Mar 2022 18:10:12 -0700 Subject: [PATCH 21/23] Adding back singledispatch to install_requires --- python/setup.cfg | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/setup.cfg b/python/setup.cfg index a707fe123d77..85b5909db973 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -42,6 +42,8 @@ package_dir = = src packages = find: python_requires = >=3.7 +install_requires = + singledispatch [options.extras_require] arrow = pyarrow From 46f486b45a734ca4ebc581b382baec6d4601230b Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Mon, 4 Apr 2022 13:20:34 -0700 Subject: [PATCH 22/23] Incorporate PR feedback --- python/src/iceberg/schema.py | 448 ++++++++++++++++++++++++ python/src/iceberg/table/schema.py | 286 --------------- python/tests/conftest.py | 17 +- python/tests/{table => }/test_schema.py | 78 ++++- 4 files changed, 530 insertions(+), 299 deletions(-) create mode 100644 python/src/iceberg/schema.py delete mode 100644 python/src/iceberg/table/schema.py rename python/tests/{table => }/test_schema.py (78%) diff --git a/python/src/iceberg/schema.py b/python/src/iceberg/schema.py new file mode 100644 index 000000000000..ade89c5e4127 --- /dev/null +++ b/python/src/iceberg/schema.py @@ -0,0 +1,448 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import sys +from abc import ABC, abstractmethod +from typing import Dict, Generic, Iterable, List, TypeVar + +if sys.version_info >= (3, 8): + from functools import singledispatch # pragma: no cover +else: + from singledispatch import singledispatch # pragma: no cover + +from iceberg.types import ( + IcebergType, + ListType, + MapType, + NestedField, + PrimitiveType, + StructType, +) + +T = TypeVar("T") + + +class Schema: + """A table Schema + + Example: + >>> from iceberg import schema + >>> from iceberg import types + """ + + def __init__(self, *columns: Iterable[NestedField], schema_id: int, identifier_field_ids: List[int] = []): + self._struct = StructType(*columns) # type: ignore + self._schema_id = schema_id + self._identifier_field_ids = identifier_field_ids + self._name_to_id: Dict[str, int] = index_by_name(self) + self._name_to_id_lower: Dict[str, int] = {} # Should be accessed through self._lazy_name_to_id_lower() + self._id_to_field: Dict[int, NestedField] = {} # Should be accessed through self._lazy_id_to_field() + + def __str__(self): + return "table {\n" + "\n".join([" " + str(field) for field in self.columns]) + "\n}" + + def __repr__(self): + return ( + f"Schema(fields={repr(self.columns)}, schema_id={self.schema_id}, identifier_field_ids={self.identifier_field_ids})" + ) + + @property + def columns(self) -> Iterable[NestedField]: + """A list of the top-level fields in the underlying struct""" + return self._struct.fields + + @property + def schema_id(self) -> int: + """The ID of this Schema""" + return self._schema_id + + @property + def identifier_field_ids(self) -> List[int]: + return self._identifier_field_ids + + def _lazy_id_to_field(self) -> Dict[int, NestedField]: + """Returns an index of field ID to NestedField instance + + This property is calculated once when called for the first time. Subsequent calls to this property will use a cached index. + """ + if not self._id_to_field: + self._id_to_field = index_by_id(self) + return self._id_to_field + + def _lazy_name_to_id_lower(self) -> Dict[str, int]: + """Returns an index of lower-case field names to field IDs + + This property is calculated once when called for the first time. Subsequent calls to this property will use a cached index. + """ + if not self._name_to_id_lower: + self._name_to_id_lower = {name.lower(): field_id for name, field_id in self._name_to_id.items()} + return self._name_to_id_lower + + def as_struct(self) -> StructType: + """Returns the underlying struct""" + return self._struct + + def find_field(self, name_or_id: str | int, case_sensitive: bool = True) -> NestedField: + """Find a field using a field name or field ID + + Args: + name_or_id (str | int): Either a field name or a field ID + case_sensitive (bool, optional): Whether to peform a case-sensitive lookup using a field name. Defaults to True. + + Returns: + NestedField: The matched NestedField + """ + if isinstance(name_or_id, int): + field = self._lazy_id_to_field().get(name_or_id) + return field # type: ignore + if case_sensitive: + field_id = self._name_to_id.get(name_or_id) + else: + field_id = self._lazy_name_to_id_lower().get(name_or_id.lower()) + return self._lazy_id_to_field().get(field_id) # type: ignore + + def find_type(self, name_or_id: str | int, case_sensitive: bool = True) -> IcebergType: + """Find a field type using a field name or field ID + + Args: + name_or_id (str | int): Either a field name or a field ID + case_sensitive (bool, optional): Whether to peform a case-sensitive lookup using a field name. Defaults to True. + + Returns: + NestedField: The type of the matched NestedField + """ + field = self.find_field(name_or_id=name_or_id, case_sensitive=case_sensitive) + return field.type # type: ignore + + def find_column_name(self, column_id: int) -> str: + """Find a column name given a column ID + + Args: + column_id (int): The ID of the column + + Raises: + ValueError: If no column name can be found for the given column ID + + Returns: + str: The column name + """ + column = self._lazy_id_to_field().get(column_id) + return None if column is None else column.name # type: ignore + + def select(self, names: List[str], case_sensitive: bool = True) -> "Schema": + """Return a new schema instance pruned to a subset of columns + + Args: + names (List[str]): A list of column names + case_sensitive (bool, optional): Whether to peform a case-sensitive lookup for each column name. Defaults to True. + + Returns: + Schema: A new schema with pruned columns + """ + if case_sensitive: + return self._case_sensitive_select(schema=self, names=names) + return self._case_insensitive_select(schema=self, names=names) + + @classmethod + def _case_sensitive_select(cls, schema: "Schema", names: List[str]): + # TODO: Add a PruneColumns schema visitor and use it here + raise NotImplementedError() + + @classmethod + def _case_insensitive_select(cls, schema: "Schema", names: List[str]): + # TODO: Add a PruneColumns schema visitor and use it here + raise NotImplementedError() + + +class SchemaVisitor(Generic[T], ABC): + def before_field(self, field: NestedField) -> None: + """Override this method to perform an action immediately before visiting a field""" + + def after_field(self, field: NestedField) -> None: + """Override this method to perform an action immediately after visiting a field""" + + def before_list_element(self, element: NestedField) -> None: + """Override this method to perform an action immediately before visiting an element within a ListType""" + self.before_field(element) + + def after_list_element(self, element: NestedField) -> None: + """Override this method to perform an action immediately after visiting an element within a ListType""" + self.after_field(element) + + def before_map_key(self, key: NestedField) -> None: + """Override this method to perform an action immediately before visiting a key within a MapType""" + self.before_field(key) + + def after_map_key(self, key: NestedField) -> None: + """Override this method to perform an action immediately after visiting a key within a MapType""" + self.after_field(key) + + def before_map_value(self, value: NestedField) -> None: + """Override this method to perform an action immediately before visiting a value within a MapType""" + self.before_field(value) + + def after_map_value(self, value: NestedField) -> None: + """Override this method to perform an action immediately after visiting a value within a MapType""" + self.after_field(value) + + @abstractmethod + def schema(self, schema: Schema, struct_result: T) -> T: + """Visit a Schema""" + ... + + @abstractmethod + def struct(self, struct: StructType, field_results: List[T]) -> T: + """Visit a StructType""" + ... + + @abstractmethod + def field(self, field: NestedField, field_result: T) -> T: + """Visit a NestedField""" + ... + + @abstractmethod + def list(self, list_type: ListType, element_result: T) -> T: + """Visit a ListType""" + ... + + @abstractmethod + def map(self, map_type: MapType, key_result: T, value_result: T) -> T: + """Visit a MapType""" + ... + + @abstractmethod + def primitive(self, primitive: PrimitiveType) -> T: + """Visit a PrimitiveType""" + ... + + +@singledispatch +def visit(obj, visitor: SchemaVisitor[T]) -> T: + """A generic function for applying a schema visitor to any point within a schema + + Args: + obj(Schema | IcebergType): An instance of a Schema or an IcebergType + visitor (SchemaVisitor[T]): An instance of an implementation of the generic SchemaVisitor base class + + Raises: + NotImplementedError: If attempting to visit an unrecognized object type + """ + raise NotImplementedError("Cannot visit non-type: %s" % obj) + + +@visit.register(Schema) +def _(obj: Schema, visitor: SchemaVisitor[T]) -> T: + """Visit a Schema with a concrete SchemaVisitor""" + return visitor.schema(obj, visit(obj.as_struct(), visitor)) + + +@visit.register(StructType) +def _(obj: StructType, visitor: SchemaVisitor[T]) -> T: + """Visit a StructType with a concrete SchemaVisitor""" + results = [] + for field in obj.fields: + visitor.before_field(field) + try: + result = visit(field.type, visitor) + finally: + visitor.after_field(field) + + results.append(visitor.field(field, result)) + + return visitor.struct(obj, results) + + +@visit.register(ListType) +def _(obj: ListType, visitor: SchemaVisitor[T]) -> T: + """Visit a ListType with a concrete SchemaVisitor""" + visitor.before_list_element(obj.element) + try: + result = visit(obj.element.type, visitor) + finally: + visitor.after_list_element(obj.element) + + return visitor.list(obj, result) + + +@visit.register(MapType) +def _(obj: MapType, visitor: SchemaVisitor[T]) -> T: + """Visit a MapType with a concrete SchemaVisitor""" + visitor.before_map_key(obj.key) + try: + key_result = visit(obj.key.type, visitor) + finally: + visitor.after_map_key(obj.key) + + visitor.before_map_value(obj.value) + try: + value_result = visit(obj.value.type, visitor) + finally: + visitor.after_list_element(obj.value) + + return visitor.map(obj, key_result, value_result) + + +@visit.register(PrimitiveType) +def _(obj: PrimitiveType, visitor: SchemaVisitor[T]) -> T: + """Visit a PrimitiveType with a concrete SchemaVisitor""" + return visitor.primitive(obj) + + +class _IndexById(SchemaVisitor[Dict[int, NestedField]]): + """A schema visitor for generating a field ID to NestedField index""" + + def __init__(self) -> None: + self._index: Dict[int, NestedField] = {} + + def schema(self, schema, result): + return self._index + + def struct(self, struct, results): + return self._index + + def field(self, field, result): + """Add the field ID to the index""" + self._index[field.field_id] = field + return self._index + + def list(self, list_type, result): + """Add the list element ID to the index""" + self._index[list_type.element.field_id] = list_type.element + return self._index + + def map(self, map_type, key_result, value_result): + """Add the key ID and value ID as individual items in the index""" + self._index[map_type.key.field_id] = map_type.key + self._index[map_type.value.field_id] = map_type.value + return self._index + + def primitive(self, primitive): + return self._index + + +def index_by_id(schema_or_type) -> Dict[int, NestedField]: + """Generate an index of field IDs to NestedField instances + + Args: + schema_or_type (Schema | IcebergType): A schema or type to index + + Returns: + Dict[int, NestedField]: An index of field IDs to NestedField instances + """ + return visit(schema_or_type, _IndexById()) + + +class _IndexByName(SchemaVisitor[Dict[str, int]]): + """A schema visitor for generating a field name to field ID index""" + + def __init__(self) -> None: + self._index: Dict[str, int] = {} + self._short_name_to_id: Dict[str, int] = {} + self._combined_index: Dict[str, int] = {} + self._field_names: List[str] = [] + self._short_field_names: List[str] = [] + + def before_list_element(self, element: NestedField) -> None: + """Short field names omit element when the element is a StructType""" + if not isinstance(element.type, StructType): + self._short_field_names.append(element.name) + self._field_names.append(element.name) + + def after_list_element(self, element: NestedField) -> None: + if not isinstance(element.type, StructType): + self._short_field_names.pop() + self._field_names.pop() + + def before_field(self, field: NestedField) -> None: + """Store the field name""" + self._field_names.append(field.name) + self._short_field_names.append(field.name) + + def after_field(self, field: NestedField) -> None: + """Remove the last field name stored""" + self._field_names.pop() + self._short_field_names.pop() + + def schema(self, schema, struct_result): + return self._index + + def struct(self, struct, field_results): + return self._index + + def field(self, field, field_result): + """Add the field name to the index""" + self._add_field(field.name, field.field_id) + + def list(self, list_type, result): + """Add the list element name to the index""" + self._add_field(list_type.element.name, list_type.element.field_id) + + def map(self, map_type, key_result, value_result): + """Add the key name and value name as individual items in the index""" + self._add_field(map_type.key.name, map_type.key.field_id) + self._add_field(map_type.value.name, map_type.value.field_id) + + def _add_field(self, name: str, field_id: int): + """Add a field name to the index, mapping its full name to its field ID + + Args: + name (str): The field name + field_id (int): The field ID + + Raises: + ValueError: If the field name is already contained in the index + """ + full_name = name + + if self._field_names: + full_name = ".".join([".".join(self._field_names), name]) + + if full_name in self._index: + raise ValueError(f"Invalid schema, multiple fields for name {full_name}: {self._index[full_name]} and {field_id}") + self._index[full_name] = field_id + + if self._short_field_names: + short_name = ".".join([".".join(self._short_field_names), name]) + self._short_name_to_id[short_name] = field_id + + def primitive(self, primitive): + return self._index + + def by_name(self): + """Returns an index of combined full and short names + + Note: Only short names that do not conflict with full names are included. + """ + combined_index = self._short_name_to_id.copy() + combined_index.update(self._index) + return combined_index + + +def index_by_name(schema_or_type) -> Dict[str, int]: + """Generate an index of field names to field IDs + + Args: + schema_or_type (Schema | IcebergType): A schema or type to index + + Returns: + Dict[str, int]: An index of field names to field IDs + """ + indexer = _IndexByName() + visit(schema_or_type, indexer) + return indexer.by_name() diff --git a/python/src/iceberg/table/schema.py b/python/src/iceberg/table/schema.py deleted file mode 100644 index fb3c8e2b40bb..000000000000 --- a/python/src/iceberg/table/schema.py +++ /dev/null @@ -1,286 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import Dict, Generic, Iterable, List, TypeVar - -from iceberg.types import ( - IcebergType, - ListType, - MapType, - NestedField, - PrimitiveType, - StructType, -) - -T = TypeVar("T") - - -class Schema: - """A table Schema""" - - def __init__(self, *columns: Iterable[NestedField], schema_id: int, identifier_field_ids: List[int] = []): - self._struct = StructType(*columns) # type: ignore - self._schema_id = schema_id - self._identifier_field_ids = identifier_field_ids - self._name_index: Dict[str, int] = index_by_name(self) - self._id_index: Dict[int, NestedField] = {} # Will be lazily set when self.id_index property method is called - - def __str__(self): - return "table { \n" + "\n".join([" " + str(field) for field in self.columns]) + "\n }" - - def __repr__(self): - return f"Schema(fields={repr(self.columns)})" - - @property - def columns(self) -> Iterable[NestedField]: - return self._struct.fields - - @property - def id(self) -> int: - return self._schema_id - - @property - def identifier_field_ids(self) -> List[int]: - return self._identifier_field_ids - - @property - def id_index(self) -> Dict[int, NestedField]: - if not self._id_index: - self._id_index = index_by_id(self) - return self._id_index - - @property - def name_index(self) -> Dict[str, int]: - return self._name_index - - def as_struct(self) -> StructType: - return self._struct - - def find_field(self, name_or_id: str | int, case_sensitive: bool = True) -> NestedField: - if isinstance(name_or_id, int): - return self.id_index[name_or_id] - if case_sensitive: - field_id = self.name_index[name_or_id] - else: - name_index_lower = {name.lower(): field_id for name, field_id in self.name_index.items()} - field_id = name_index_lower[name_or_id.lower()] - return self.id_index[field_id] - - def find_type(self, name_or_id: str | int, case_sensitive: bool = True) -> IcebergType: - if isinstance(name_or_id, int): - return self.id_index[name_or_id].type - if case_sensitive: - field_id = self.name_index[name_or_id] - else: - name_index_lower = {name.lower(): field_id for name, field_id in self.name_index.items()} - field_id = name_index_lower[name_or_id.lower()] - return self.id_index[field_id].type - - def find_column_name(self, column_id: int) -> str: - matched_column = [name for name, field_id in self.name_index.items() if field_id == column_id] - if not matched_column: - raise ValueError(f"Cannot find column name: {column_id}") - return matched_column[0] - - def select(self, names: List[str], case_sensitive: bool = True) -> "Schema": - if case_sensitive: - return self._case_sensitive_select(schema=self, names=names) - return self._case_insensitive_select(schema=self, names=names) - - @classmethod - def _case_sensitive_select(cls, schema: "Schema", names: List[str]): - # TODO: Add a PruneColumns schema visitor and use it here - raise NotImplementedError() - - @classmethod - def _case_insensitive_select(cls, schema: "Schema", names: List[str]): - # TODO: Add a PruneColumns schema visitor and use it here - raise NotImplementedError() - - -class SchemaVisitor(Generic[T], ABC): - def before_field(self, field: NestedField) -> None: - pass - - def after_field(self, field: NestedField) -> None: - pass - - def before_list_element(self, element: NestedField) -> None: - self.before_field(element) - - def after_list_element(self, element: NestedField) -> None: - self.after_field(element) - - def before_map_key(self, key: NestedField) -> None: - self.before_field(key) - - def after_map_key(self, key: NestedField) -> None: - self.after_field(key) - - def before_map_value(self, value: NestedField) -> None: - self.before_field(value) - - def after_map_value(self, value: NestedField) -> None: - self.after_field(value) - - @abstractmethod - def schema(self, schema: Schema, struct_result: T) -> T: - ... - - @abstractmethod - def struct(self, struct: StructType, field_results: List[T]) -> T: - ... - - @abstractmethod - def field(self, field: NestedField, field_result: T) -> T: - ... - - @abstractmethod - def list(self, list_type: ListType, element_result: T) -> T: - ... - - @abstractmethod - def map(self, map_type: MapType, key_result: T, value_result: T) -> T: - ... - - @abstractmethod - def primitive(self, primitive: PrimitiveType) -> T: - ... - - -def visit(obj, visitor: SchemaVisitor[T]) -> T: - if isinstance(obj, Schema): - return visitor.schema(obj, visit(obj.as_struct(), visitor)) - - elif isinstance(obj, StructType): - results = [] - for field in obj.fields: - visitor.before_field(field) - try: - result = visit(field.type, visitor) - finally: - visitor.after_field(field) - - results.append(visitor.field(field, result)) - - return visitor.struct(obj, results) - - elif isinstance(obj, ListType): - visitor.before_list_element(obj.element) - try: - result = visit(obj.element.type, visitor) - finally: - visitor.after_list_element(obj.element) - - return visitor.list(obj, result) - - elif isinstance(obj, MapType): - visitor.before_map_key(obj.key) - try: - key_result = visit(obj.key.type, visitor) - finally: - visitor.after_map_key(obj.key) - - visitor.before_map_value(obj.value) - try: - value_result = visit(obj.value.type, visitor) - finally: - visitor.after_list_element(obj.value) - - return visitor.map(obj, key_result, value_result) - - elif isinstance(obj, PrimitiveType): - return visitor.primitive(obj) - - else: - raise NotImplementedError("Cannot visit non-type: %s" % obj) - - -def index_by_id(schema_or_type) -> Dict[int, NestedField]: - class IndexById(SchemaVisitor[Dict[int, NestedField]]): - def __init__(self) -> None: - self._index: Dict[int, NestedField] = {} - - def schema(self, schema, result): - return self._index - - def struct(self, struct, results): - return self._index - - def field(self, field, result): - self._index[field.field_id] = field - return self._index - - def list(self, list_type, result): - self._index[list_type.element.field_id] = list_type.element - return self._index - - def map(self, map_type, key_result, value_result): - self._index[map_type.key.field_id] = map_type.key - self._index[map_type.value.field_id] = map_type.value - return self._index - - def primitive(self, primitive): - return self._index - - return visit(schema_or_type, IndexById()) - - -def index_by_name(schema_or_type) -> Dict[str, int]: - class IndexByName(SchemaVisitor[Dict[str, int]]): - def __init__(self) -> None: - self._index: Dict[str, int] = {} - self._field_names: List[str] = [] - - def before_field(self, field: NestedField) -> None: - self._field_names.append(field.name) - - def after_field(self, field: NestedField) -> None: - self._field_names.pop() - - def schema(self, schema, struct_result): - return self._index - - def struct(self, struct, field_results): - return self._index - - def field(self, field, field_result): - self._add_field(field.name, field.field_id) - - def list(self, list_type, result): - self._add_field(list_type.element.name, list_type.element.field_id) - - def map(self, map_type, key_result, value_result): - self._add_field(map_type.key.name, map_type.key.field_id) - self._add_field(map_type.value.name, map_type.value.field_id) - - def _add_field(self, name, field_id): - full_name = name - if self._field_names: - full_name = ".".join([".".join(self._field_names), name]) - - if full_name in self._index: - raise ValueError(f"Invalid schema, multiple fields for name {full_name}: {index[full_name]} and {field_id}") - self._index[full_name] = field_id - - def primitive(self, primitive): - return self._index - - return visit(schema_or_type, IndexByName()) diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 72d02f36ae7a..d8f5d35fba17 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -17,14 +17,16 @@ import pytest -from iceberg.table import schema +from iceberg import schema from iceberg.types import ( BooleanType, + FloatType, IntegerType, ListType, MapType, NestedField, StringType, + StructType, ) @@ -65,6 +67,19 @@ def table_schema_nested(): ), is_optional=True, ), + NestedField( + field_id=11, + name="location", + field_type=ListType( + element_id=12, + element_type=StructType( + NestedField(field_id=13, name="latitude", field_type=FloatType(), is_optional=False), + NestedField(field_id=14, name="longitude", field_type=FloatType(), is_optional=False), + ), + element_is_optional=True, + ), + is_optional=True, + ), schema_id=1, identifier_field_ids=[1], ) diff --git a/python/tests/table/test_schema.py b/python/tests/test_schema.py similarity index 78% rename from python/tests/table/test_schema.py rename to python/tests/test_schema.py index c381597dc050..a01059c4a8d2 100644 --- a/python/tests/table/test_schema.py +++ b/python/tests/test_schema.py @@ -19,14 +19,16 @@ import pytest -from iceberg.table import schema +from iceberg import schema from iceberg.types import ( BooleanType, + FloatType, IntegerType, ListType, MapType, NestedField, StringType, + StructType, ) @@ -34,11 +36,11 @@ def test_schema_str(table_schema_simple): """Test casting a schema to a string""" assert str(table_schema_simple) == dedent( """\ - table { - 1: foo: required string - 2: bar: optional int - 3: baz: required boolean - }""" + table { + 1: foo: required string + 2: bar: optional int + 3: baz: required boolean + }""" ) @@ -47,13 +49,13 @@ def test_schema_str(table_schema_simple): [ ( schema.Schema(NestedField(1, "foo", StringType()), schema_id=1), - "Schema(fields=(NestedField(field_id=1, name='foo', field_type=StringType(), is_optional=True),))", + "Schema(fields=(NestedField(field_id=1, name='foo', field_type=StringType(), is_optional=True),), schema_id=1, identifier_field_ids=[])", ), ( schema.Schema( NestedField(1, "foo", StringType()), NestedField(2, "bar", IntegerType(), is_optional=False), schema_id=1 ), - "Schema(fields=(NestedField(field_id=1, name='foo', field_type=StringType(), is_optional=True), NestedField(field_id=2, name='bar', field_type=IntegerType(), is_optional=False)))", + "Schema(fields=(NestedField(field_id=1, name='foo', field_type=StringType(), is_optional=True), NestedField(field_id=2, name='bar', field_type=IntegerType(), is_optional=False)), schema_id=1, identifier_field_ids=[])", ), ], ) @@ -99,6 +101,30 @@ def test_schema_index_by_id_visitor(table_schema_nested): is_optional=True, ), 10: NestedField(field_id=10, name="value", field_type=IntegerType(), is_optional=True), + 11: NestedField( + field_id=11, + name="location", + field_type=ListType( + element_id=12, + element_type=StructType( + NestedField(field_id=13, name="latitude", field_type=FloatType(), is_optional=False), + NestedField(field_id=14, name="longitude", field_type=FloatType(), is_optional=False), + ), + element_is_optional=True, + ), + is_optional=True, + ), + 12: NestedField( + field_id=12, + name="element", + field_type=StructType( + NestedField(field_id=13, name="latitude", field_type=FloatType(), is_optional=False), + NestedField(field_id=14, name="longitude", field_type=FloatType(), is_optional=False), + ), + is_optional=True, + ), + 13: NestedField(field_id=13, name="latitude", field_type=FloatType(), is_optional=False), + 14: NestedField(field_id=14, name="longitude", field_type=FloatType(), is_optional=False), } @@ -116,6 +142,12 @@ def test_schema_index_by_name_visitor(table_schema_nested): "quux.value": 8, "quux.value.key": 9, "quux.value.value": 10, + "location": 11, + "location.element": 12, + "location.element.latitude": 13, + "location.element.longitude": 14, + "location.latitude": 13, + "location.longitude": 14, } @@ -133,11 +165,9 @@ def test_schema_find_column_name(table_schema_nested): assert table_schema_nested.find_column_name(10) == "quux.value.value" -def test_schema_find_column_name_raise_on_id_not_found(table_schema_nested): +def test_schema_find_column_name_on_id_not_found(table_schema_nested): """Test raising an error when a field ID cannot be found""" - with pytest.raises(ValueError) as exc_info: - table_schema_nested.find_column_name(99) - assert "Cannot find column name: 99" in str(exc_info.value) + assert table_schema_nested.find_column_name(99) is None def test_schema_find_field_by_id(table_schema_simple): @@ -215,6 +245,30 @@ def test_index_by_id_schema_visitor(table_schema_nested): ), 9: NestedField(field_id=9, name="key", field_type=StringType(), is_optional=False), 10: NestedField(field_id=10, name="value", field_type=IntegerType(), is_optional=True), + 11: NestedField( + field_id=11, + name="location", + field_type=ListType( + element_id=12, + element_type=StructType( + NestedField(field_id=13, name="latitude", field_type=FloatType(), is_optional=False), + NestedField(field_id=14, name="longitude", field_type=FloatType(), is_optional=False), + ), + element_is_optional=True, + ), + is_optional=True, + ), + 12: NestedField( + field_id=12, + name="element", + field_type=StructType( + NestedField(field_id=13, name="latitude", field_type=FloatType(), is_optional=False), + NestedField(field_id=14, name="longitude", field_type=FloatType(), is_optional=False), + ), + is_optional=True, + ), + 13: NestedField(field_id=13, name="latitude", field_type=FloatType(), is_optional=False), + 14: NestedField(field_id=14, name="longitude", field_type=FloatType(), is_optional=False), } From 8ccfca4dfebbdc93fcc348c0e348681f36883fa5 Mon Sep 17 00:00:00 2001 From: samredai <43911210+samredai@users.noreply.github.com> Date: Mon, 4 Apr 2022 14:24:48 -0700 Subject: [PATCH 23/23] Raise NotImplementedError for find_column_name --- python/src/iceberg/schema.py | 5 ++-- python/tests/test_schema.py | 44 ++++++++++++++++++------------------ 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/python/src/iceberg/schema.py b/python/src/iceberg/schema.py index ade89c5e4127..ad60d870bbd5 100644 --- a/python/src/iceberg/schema.py +++ b/python/src/iceberg/schema.py @@ -130,7 +130,7 @@ def find_type(self, name_or_id: str | int, case_sensitive: bool = True) -> Icebe field = self.find_field(name_or_id=name_or_id, case_sensitive=case_sensitive) return field.type # type: ignore - def find_column_name(self, column_id: int) -> str: + def find_column_name(self, column_id: int): """Find a column name given a column ID Args: @@ -142,8 +142,7 @@ def find_column_name(self, column_id: int) -> str: Returns: str: The column name """ - column = self._lazy_id_to_field().get(column_id) - return None if column is None else column.name # type: ignore + raise NotImplementedError() def select(self, names: List[str], case_sensitive: bool = True) -> "Schema": """Return a new schema instance pruned to a subset of columns diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py index a01059c4a8d2..321ee470b79d 100644 --- a/python/tests/test_schema.py +++ b/python/tests/test_schema.py @@ -151,23 +151,30 @@ def test_schema_index_by_name_visitor(table_schema_nested): } -def test_schema_find_column_name(table_schema_nested): - """Test finding a column name using its field ID""" - assert table_schema_nested.find_column_name(1) == "foo" - assert table_schema_nested.find_column_name(2) == "bar" - assert table_schema_nested.find_column_name(3) == "baz" - assert table_schema_nested.find_column_name(4) == "qux" - assert table_schema_nested.find_column_name(5) == "qux.element" - assert table_schema_nested.find_column_name(6) == "quux" - assert table_schema_nested.find_column_name(7) == "quux.key" - assert table_schema_nested.find_column_name(8) == "quux.value" - assert table_schema_nested.find_column_name(9) == "quux.value.key" - assert table_schema_nested.find_column_name(10) == "quux.value.value" +# def test_schema_find_column_name(table_schema_nested): +# """Test finding a column name using its field ID""" +# assert table_schema_nested.find_column_name(1) == "foo" +# assert table_schema_nested.find_column_name(2) == "bar" +# assert table_schema_nested.find_column_name(3) == "baz" +# assert table_schema_nested.find_column_name(4) == "qux" +# assert table_schema_nested.find_column_name(5) == "qux.element" +# assert table_schema_nested.find_column_name(6) == "quux" +# assert table_schema_nested.find_column_name(7) == "quux.key" +# assert table_schema_nested.find_column_name(8) == "quux.value" +# assert table_schema_nested.find_column_name(9) == "quux.value.key" +# assert table_schema_nested.find_column_name(10) == "quux.value.value" -def test_schema_find_column_name_on_id_not_found(table_schema_nested): - """Test raising an error when a field ID cannot be found""" - assert table_schema_nested.find_column_name(99) is None +# def test_schema_find_column_name_on_id_not_found(table_schema_nested): +# """Test raising an error when a field ID cannot be found""" +# assert table_schema_nested.find_column_name(99) is None + + +# def test_schema_find_column_name(table_schema_simple): +# """Test finding a column name given its field ID""" +# assert table_schema_simple.find_column_name(1) == "foo" +# assert table_schema_simple.find_column_name(2) == "bar" +# assert table_schema_simple.find_column_name(3) == "baz" def test_schema_find_field_by_id(table_schema_simple): @@ -321,10 +328,3 @@ def test_schema_find_type(table_schema_simple): == table_schema_simple.find_type("BAZ", case_sensitive=False) == BooleanType() ) - - -def test_schema_find_column_name(table_schema_simple): - """Test finding a column name given its field ID""" - assert table_schema_simple.find_column_name(1) == "foo" - assert table_schema_simple.find_column_name(2) == "bar" - assert table_schema_simple.find_column_name(3) == "baz"