diff --git a/python/src/iceberg/schema.py b/python/src/iceberg/schema.py new file mode 100644 index 000000000000..ad60d870bbd5 --- /dev/null +++ b/python/src/iceberg/schema.py @@ -0,0 +1,447 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import sys +from abc import ABC, abstractmethod +from typing import Dict, Generic, Iterable, List, TypeVar + +if sys.version_info >= (3, 8): + from functools import singledispatch # pragma: no cover +else: + from singledispatch import singledispatch # pragma: no cover + +from iceberg.types import ( + IcebergType, + ListType, + MapType, + NestedField, + PrimitiveType, + StructType, +) + +T = TypeVar("T") + + +class Schema: + """A table Schema + + Example: + >>> from iceberg import schema + >>> from iceberg import types + """ + + def __init__(self, *columns: Iterable[NestedField], schema_id: int, identifier_field_ids: List[int] = []): + self._struct = StructType(*columns) # type: ignore + self._schema_id = schema_id + self._identifier_field_ids = identifier_field_ids + self._name_to_id: Dict[str, int] = index_by_name(self) + self._name_to_id_lower: Dict[str, int] = {} # Should be accessed through self._lazy_name_to_id_lower() + self._id_to_field: Dict[int, 
NestedField] = {} # Should be accessed through self._lazy_id_to_field() + + def __str__(self): + return "table {\n" + "\n".join([" " + str(field) for field in self.columns]) + "\n}" + + def __repr__(self): + return ( + f"Schema(fields={repr(self.columns)}, schema_id={self.schema_id}, identifier_field_ids={self.identifier_field_ids})" + ) + + @property + def columns(self) -> Iterable[NestedField]: + """A list of the top-level fields in the underlying struct""" + return self._struct.fields + + @property + def schema_id(self) -> int: + """The ID of this Schema""" + return self._schema_id + + @property + def identifier_field_ids(self) -> List[int]: + return self._identifier_field_ids + + def _lazy_id_to_field(self) -> Dict[int, NestedField]: + """Returns an index of field ID to NestedField instance + + This property is calculated once when called for the first time. Subsequent calls to this property will use a cached index. + """ + if not self._id_to_field: + self._id_to_field = index_by_id(self) + return self._id_to_field + + def _lazy_name_to_id_lower(self) -> Dict[str, int]: + """Returns an index of lower-case field names to field IDs + + This property is calculated once when called for the first time. Subsequent calls to this property will use a cached index. + """ + if not self._name_to_id_lower: + self._name_to_id_lower = {name.lower(): field_id for name, field_id in self._name_to_id.items()} + return self._name_to_id_lower + + def as_struct(self) -> StructType: + """Returns the underlying struct""" + return self._struct + + def find_field(self, name_or_id: str | int, case_sensitive: bool = True) -> NestedField: + """Find a field using a field name or field ID + + Args: + name_or_id (str | int): Either a field name or a field ID + case_sensitive (bool, optional): Whether to peform a case-sensitive lookup using a field name. Defaults to True. 
+ + Returns: + NestedField: The matched NestedField + """ + if isinstance(name_or_id, int): + field = self._lazy_id_to_field().get(name_or_id) + return field # type: ignore + if case_sensitive: + field_id = self._name_to_id.get(name_or_id) + else: + field_id = self._lazy_name_to_id_lower().get(name_or_id.lower()) + return self._lazy_id_to_field().get(field_id) # type: ignore + + def find_type(self, name_or_id: str | int, case_sensitive: bool = True) -> IcebergType: + """Find a field type using a field name or field ID + + Args: + name_or_id (str | int): Either a field name or a field ID + case_sensitive (bool, optional): Whether to peform a case-sensitive lookup using a field name. Defaults to True. + + Returns: + NestedField: The type of the matched NestedField + """ + field = self.find_field(name_or_id=name_or_id, case_sensitive=case_sensitive) + return field.type # type: ignore + + def find_column_name(self, column_id: int): + """Find a column name given a column ID + + Args: + column_id (int): The ID of the column + + Raises: + ValueError: If no column name can be found for the given column ID + + Returns: + str: The column name + """ + raise NotImplementedError() + + def select(self, names: List[str], case_sensitive: bool = True) -> "Schema": + """Return a new schema instance pruned to a subset of columns + + Args: + names (List[str]): A list of column names + case_sensitive (bool, optional): Whether to peform a case-sensitive lookup for each column name. Defaults to True. 
+ + Returns: + Schema: A new schema with pruned columns + """ + if case_sensitive: + return self._case_sensitive_select(schema=self, names=names) + return self._case_insensitive_select(schema=self, names=names) + + @classmethod + def _case_sensitive_select(cls, schema: "Schema", names: List[str]): + # TODO: Add a PruneColumns schema visitor and use it here + raise NotImplementedError() + + @classmethod + def _case_insensitive_select(cls, schema: "Schema", names: List[str]): + # TODO: Add a PruneColumns schema visitor and use it here + raise NotImplementedError() + + +class SchemaVisitor(Generic[T], ABC): + def before_field(self, field: NestedField) -> None: + """Override this method to perform an action immediately before visiting a field""" + + def after_field(self, field: NestedField) -> None: + """Override this method to perform an action immediately after visiting a field""" + + def before_list_element(self, element: NestedField) -> None: + """Override this method to perform an action immediately before visiting an element within a ListType""" + self.before_field(element) + + def after_list_element(self, element: NestedField) -> None: + """Override this method to perform an action immediately after visiting an element within a ListType""" + self.after_field(element) + + def before_map_key(self, key: NestedField) -> None: + """Override this method to perform an action immediately before visiting a key within a MapType""" + self.before_field(key) + + def after_map_key(self, key: NestedField) -> None: + """Override this method to perform an action immediately after visiting a key within a MapType""" + self.after_field(key) + + def before_map_value(self, value: NestedField) -> None: + """Override this method to perform an action immediately before visiting a value within a MapType""" + self.before_field(value) + + def after_map_value(self, value: NestedField) -> None: + """Override this method to perform an action immediately after visiting a value within a 
MapType""" + self.after_field(value) + + @abstractmethod + def schema(self, schema: Schema, struct_result: T) -> T: + """Visit a Schema""" + ... + + @abstractmethod + def struct(self, struct: StructType, field_results: List[T]) -> T: + """Visit a StructType""" + ... + + @abstractmethod + def field(self, field: NestedField, field_result: T) -> T: + """Visit a NestedField""" + ... + + @abstractmethod + def list(self, list_type: ListType, element_result: T) -> T: + """Visit a ListType""" + ... + + @abstractmethod + def map(self, map_type: MapType, key_result: T, value_result: T) -> T: + """Visit a MapType""" + ... + + @abstractmethod + def primitive(self, primitive: PrimitiveType) -> T: + """Visit a PrimitiveType""" + ... + + +@singledispatch +def visit(obj, visitor: SchemaVisitor[T]) -> T: + """A generic function for applying a schema visitor to any point within a schema + + Args: + obj(Schema | IcebergType): An instance of a Schema or an IcebergType + visitor (SchemaVisitor[T]): An instance of an implementation of the generic SchemaVisitor base class + + Raises: + NotImplementedError: If attempting to visit an unrecognized object type + """ + raise NotImplementedError("Cannot visit non-type: %s" % obj) + + +@visit.register(Schema) +def _(obj: Schema, visitor: SchemaVisitor[T]) -> T: + """Visit a Schema with a concrete SchemaVisitor""" + return visitor.schema(obj, visit(obj.as_struct(), visitor)) + + +@visit.register(StructType) +def _(obj: StructType, visitor: SchemaVisitor[T]) -> T: + """Visit a StructType with a concrete SchemaVisitor""" + results = [] + for field in obj.fields: + visitor.before_field(field) + try: + result = visit(field.type, visitor) + finally: + visitor.after_field(field) + + results.append(visitor.field(field, result)) + + return visitor.struct(obj, results) + + +@visit.register(ListType) +def _(obj: ListType, visitor: SchemaVisitor[T]) -> T: + """Visit a ListType with a concrete SchemaVisitor""" + visitor.before_list_element(obj.element) + 
try: + result = visit(obj.element.type, visitor) + finally: + visitor.after_list_element(obj.element) + + return visitor.list(obj, result) + + +@visit.register(MapType) +def _(obj: MapType, visitor: SchemaVisitor[T]) -> T: + """Visit a MapType with a concrete SchemaVisitor""" + visitor.before_map_key(obj.key) + try: + key_result = visit(obj.key.type, visitor) + finally: + visitor.after_map_key(obj.key) + + visitor.before_map_value(obj.value) + try: + value_result = visit(obj.value.type, visitor) + finally: + visitor.after_list_element(obj.value) + + return visitor.map(obj, key_result, value_result) + + +@visit.register(PrimitiveType) +def _(obj: PrimitiveType, visitor: SchemaVisitor[T]) -> T: + """Visit a PrimitiveType with a concrete SchemaVisitor""" + return visitor.primitive(obj) + + +class _IndexById(SchemaVisitor[Dict[int, NestedField]]): + """A schema visitor for generating a field ID to NestedField index""" + + def __init__(self) -> None: + self._index: Dict[int, NestedField] = {} + + def schema(self, schema, result): + return self._index + + def struct(self, struct, results): + return self._index + + def field(self, field, result): + """Add the field ID to the index""" + self._index[field.field_id] = field + return self._index + + def list(self, list_type, result): + """Add the list element ID to the index""" + self._index[list_type.element.field_id] = list_type.element + return self._index + + def map(self, map_type, key_result, value_result): + """Add the key ID and value ID as individual items in the index""" + self._index[map_type.key.field_id] = map_type.key + self._index[map_type.value.field_id] = map_type.value + return self._index + + def primitive(self, primitive): + return self._index + + +def index_by_id(schema_or_type) -> Dict[int, NestedField]: + """Generate an index of field IDs to NestedField instances + + Args: + schema_or_type (Schema | IcebergType): A schema or type to index + + Returns: + Dict[int, NestedField]: An index of field IDs to 
NestedField instances + """ + return visit(schema_or_type, _IndexById()) + + +class _IndexByName(SchemaVisitor[Dict[str, int]]): + """A schema visitor for generating a field name to field ID index""" + + def __init__(self) -> None: + self._index: Dict[str, int] = {} + self._short_name_to_id: Dict[str, int] = {} + self._combined_index: Dict[str, int] = {} + self._field_names: List[str] = [] + self._short_field_names: List[str] = [] + + def before_list_element(self, element: NestedField) -> None: + """Short field names omit element when the element is a StructType""" + if not isinstance(element.type, StructType): + self._short_field_names.append(element.name) + self._field_names.append(element.name) + + def after_list_element(self, element: NestedField) -> None: + if not isinstance(element.type, StructType): + self._short_field_names.pop() + self._field_names.pop() + + def before_field(self, field: NestedField) -> None: + """Store the field name""" + self._field_names.append(field.name) + self._short_field_names.append(field.name) + + def after_field(self, field: NestedField) -> None: + """Remove the last field name stored""" + self._field_names.pop() + self._short_field_names.pop() + + def schema(self, schema, struct_result): + return self._index + + def struct(self, struct, field_results): + return self._index + + def field(self, field, field_result): + """Add the field name to the index""" + self._add_field(field.name, field.field_id) + + def list(self, list_type, result): + """Add the list element name to the index""" + self._add_field(list_type.element.name, list_type.element.field_id) + + def map(self, map_type, key_result, value_result): + """Add the key name and value name as individual items in the index""" + self._add_field(map_type.key.name, map_type.key.field_id) + self._add_field(map_type.value.name, map_type.value.field_id) + + def _add_field(self, name: str, field_id: int): + """Add a field name to the index, mapping its full name to its field ID + + 
Args: + name (str): The field name + field_id (int): The field ID + + Raises: + ValueError: If the field name is already contained in the index + """ + full_name = name + + if self._field_names: + full_name = ".".join([".".join(self._field_names), name]) + + if full_name in self._index: + raise ValueError(f"Invalid schema, multiple fields for name {full_name}: {self._index[full_name]} and {field_id}") + self._index[full_name] = field_id + + if self._short_field_names: + short_name = ".".join([".".join(self._short_field_names), name]) + self._short_name_to_id[short_name] = field_id + + def primitive(self, primitive): + return self._index + + def by_name(self): + """Returns an index of combined full and short names + + Note: Only short names that do not conflict with full names are included. + """ + combined_index = self._short_name_to_id.copy() + combined_index.update(self._index) + return combined_index + + +def index_by_name(schema_or_type) -> Dict[str, int]: + """Generate an index of field names to field IDs + + Args: + schema_or_type (Schema | IcebergType): A schema or type to index + + Returns: + Dict[str, int]: An index of field names to field IDs + """ + indexer = _IndexByName() + visit(schema_or_type, indexer) + return indexer.by_name() diff --git a/python/tests/conftest.py b/python/tests/conftest.py new file mode 100644 index 000000000000..d8f5d35fba17 --- /dev/null +++ b/python/tests/conftest.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pytest

from iceberg import schema
from iceberg.types import (
    BooleanType,
    FloatType,
    IntegerType,
    ListType,
    MapType,
    NestedField,
    StringType,
    StructType,
)


@pytest.fixture(scope="session", autouse=True)
def table_schema_simple():
    """A flat three-column schema: required string, optional int, required boolean"""
    foo = NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False)
    bar = NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True)
    baz = NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False)
    return schema.Schema(foo, bar, baz, schema_id=1, identifier_field_ids=[1])


@pytest.fixture(scope="session", autouse=True)
def table_schema_nested():
    """The simple columns plus a list, a map of maps and a list of structs"""
    # qux: list<string>
    string_list = ListType(element_id=5, element_type=StringType(), element_is_optional=True)

    # quux: map<string, map<string, int>>
    inner_map = MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_is_optional=True)
    outer_map = MapType(
        key_id=7,
        key_type=StringType(),
        value_id=8,
        value_type=inner_map,
        value_is_optional=True,
    )

    # location: list<struct<latitude: float, longitude: float>>
    coordinates = StructType(
        NestedField(field_id=13, name="latitude", field_type=FloatType(), is_optional=False),
        NestedField(field_id=14, name="longitude", field_type=FloatType(), is_optional=False),
    )
    coordinates_list = ListType(element_id=12, element_type=coordinates, element_is_optional=True)

    columns = (
        NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False),
        NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True),
        NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False),
        NestedField(field_id=4, name="qux", field_type=string_list, is_optional=True),
        NestedField(field_id=6, name="quux", field_type=outer_map, is_optional=True),
        NestedField(field_id=11, name="location", field_type=coordinates_list, is_optional=True),
    )
    return schema.Schema(*columns, schema_id=1, identifier_field_ids=[1])


# python/tests/test_schema.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from textwrap import dedent

import pytest

from iceberg import schema
from iceberg.types import (
    BooleanType,
    FloatType,
    IntegerType,
    ListType,
    MapType,
    NestedField,
    StringType,
    StructType,
)


def _expected_nested_index_by_id():
    """Build the field ID -> NestedField index expected for the `table_schema_nested` fixture"""
    inner_map = MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_is_optional=True)
    coordinates = StructType(
        NestedField(field_id=13, name="latitude", field_type=FloatType(), is_optional=False),
        NestedField(field_id=14, name="longitude", field_type=FloatType(), is_optional=False),
    )
    return {
        1: NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False),
        2: NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True),
        3: NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False),
        4: NestedField(
            field_id=4,
            name="qux",
            field_type=ListType(element_id=5, element_type=StringType(), element_is_optional=True),
            is_optional=True,
        ),
        5: NestedField(field_id=5, name="element", field_type=StringType(), is_optional=True),
        6: NestedField(
            field_id=6,
            name="quux",
            field_type=MapType(
                key_id=7,
                key_type=StringType(),
                value_id=8,
                value_type=inner_map,
                value_is_optional=True,
            ),
            is_optional=True,
        ),
        7: NestedField(field_id=7, name="key", field_type=StringType(), is_optional=False),
        8: NestedField(field_id=8, name="value", field_type=inner_map, is_optional=True),
        9: NestedField(field_id=9, name="key", field_type=StringType(), is_optional=False),
        10: NestedField(field_id=10, name="value", field_type=IntegerType(), is_optional=True),
        11: NestedField(
            field_id=11,
            name="location",
            field_type=ListType(element_id=12, element_type=coordinates, element_is_optional=True),
            is_optional=True,
        ),
        12: NestedField(field_id=12, name="element", field_type=coordinates, is_optional=True),
        13: NestedField(field_id=13, name="latitude", field_type=FloatType(), is_optional=False),
        14: NestedField(field_id=14, name="longitude", field_type=FloatType(), is_optional=False),
    }


def test_schema_str(table_schema_simple):
    """Test casting a schema to a string"""
    # The per-field indent mirrors the prefix added by Schema.__str__
    assert str(table_schema_simple) == dedent(
        """\
        table {
         1: foo: required string
         2: bar: optional int
         3: baz: required boolean
        }"""
    )


@pytest.mark.parametrize(
    "table_schema, expected_repr",
    [
        (
            schema.Schema(NestedField(1, "foo", StringType()), schema_id=1),
            "Schema(fields=(NestedField(field_id=1, name='foo', field_type=StringType(), is_optional=True),), schema_id=1, identifier_field_ids=[])",
        ),
        (
            schema.Schema(
                NestedField(1, "foo", StringType()), NestedField(2, "bar", IntegerType(), is_optional=False), schema_id=1
            ),
            "Schema(fields=(NestedField(field_id=1, name='foo', field_type=StringType(), is_optional=True), NestedField(field_id=2, name='bar', field_type=IntegerType(), is_optional=False)), schema_id=1, identifier_field_ids=[])",
        ),
    ],
)
def test_schema_repr(table_schema, expected_repr):
    """Test schema representation"""
    # The argument is not called `schema` so that it does not shadow the imported module
    assert repr(table_schema) == expected_repr


def test_schema_index_by_id_visitor(table_schema_nested):
    """Test index_by_id visitor function"""
    index = schema.index_by_id(table_schema_nested)
    assert index == _expected_nested_index_by_id()


def test_schema_index_by_name_visitor(table_schema_nested):
    """Test index_by_name visitor function"""
    index = schema.index_by_name(table_schema_nested)
    assert index == {
        "foo": 1,
        "bar": 2,
        "baz": 3,
        "qux": 4,
        "qux.element": 5,
        "quux": 6,
        "quux.key": 7,
        "quux.value": 8,
        "quux.value.key": 9,
        "quux.value.value": 10,
        "location": 11,
        "location.element": 12,
        "location.element.latitude": 13,
        "location.element.longitude": 14,
        "location.latitude": 13,
        "location.longitude": 14,
    }


# TODO: The tests below are disabled because Schema.find_column_name still raises NotImplementedError.
# Re-enable (and de-duplicate the names) once the lookup is implemented.

# def test_schema_find_column_name(table_schema_nested):
#     """Test finding a column name using its field ID"""
#     assert table_schema_nested.find_column_name(1) == "foo"
#     assert table_schema_nested.find_column_name(2) == "bar"
#     assert table_schema_nested.find_column_name(3) == "baz"
#     assert table_schema_nested.find_column_name(4) == "qux"
#     assert table_schema_nested.find_column_name(5) == "qux.element"
#     assert table_schema_nested.find_column_name(6) == "quux"
#     assert table_schema_nested.find_column_name(7) == "quux.key"
#     assert table_schema_nested.find_column_name(8) == "quux.value"
#     assert table_schema_nested.find_column_name(9) == "quux.value.key"
#     assert table_schema_nested.find_column_name(10) == "quux.value.value"


# def test_schema_find_column_name_on_id_not_found(table_schema_nested):
#     """Test raising an error when a field ID cannot be found"""
#     assert table_schema_nested.find_column_name(99) is None


# def test_schema_find_column_name(table_schema_simple):
#     """Test finding a column name given its field ID"""
#     assert table_schema_simple.find_column_name(1) == "foo"
#     assert table_schema_simple.find_column_name(2) == "bar"
#     assert table_schema_simple.find_column_name(3) == "baz"


def test_schema_find_field_by_id(table_schema_simple):
    """Test finding a column using its field ID"""
    index = schema.index_by_id(table_schema_simple)

    column1 = index[1]
    assert isinstance(column1, NestedField)
    assert column1.field_id == 1
    assert column1.type == StringType()
    assert column1.is_optional is False

    column2 = index[2]
    assert isinstance(column2, NestedField)
    assert column2.field_id == 2
    assert column2.type == IntegerType()
    assert column2.is_optional is True

    column3 = index[3]
    assert isinstance(column3, NestedField)
    assert column3.field_id == 3
    assert column3.type == BooleanType()
    assert column3.is_optional is False


def test_schema_find_field_by_id_raise_on_unknown_field(table_schema_simple):
    """Test raising when the field ID is not found among columns"""
    index = schema.index_by_id(table_schema_simple)
    # The index is a plain dict, so a missing ID surfaces as a KeyError
    with pytest.raises(KeyError) as exc_info:
        index[4]
    assert str(exc_info.value) == "4"


def test_schema_find_field_type_by_id(table_schema_simple):
    """Test retrieving a column's type using its field ID"""
    index = schema.index_by_id(table_schema_simple)
    assert index[1] == NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False)
    assert index[2] == NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True)
    assert index[3] == NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False)


def test_index_by_id_schema_visitor(table_schema_nested):
    """Test the index_by_id function that uses the IndexById schema visitor"""
    assert schema.index_by_id(table_schema_nested) == _expected_nested_index_by_id()


def test_index_by_id_schema_visitor_raise_on_unregistered_type():
    """Test raising a NotImplementedError when an invalid type is provided to the index_by_id function"""
    with pytest.raises(NotImplementedError) as exc_info:
        schema.index_by_id("foo")
    assert "Cannot visit non-type: foo" in str(exc_info.value)


def test_schema_find_field(table_schema_simple):
    """Test finding a field in a schema"""
    assert (
        table_schema_simple.find_field(1)
        == table_schema_simple.find_field("foo")
        == table_schema_simple.find_field("FOO", case_sensitive=False)
        == NestedField(field_id=1, name="foo", field_type=StringType(), is_optional=False)
    )
    assert (
        table_schema_simple.find_field(2)
        == table_schema_simple.find_field("bar")
        == table_schema_simple.find_field("BAR", case_sensitive=False)
        == NestedField(field_id=2, name="bar", field_type=IntegerType(), is_optional=True)
    )
    assert (
        table_schema_simple.find_field(3)
        == table_schema_simple.find_field("baz")
        == table_schema_simple.find_field("BAZ", case_sensitive=False)
        == NestedField(field_id=3, name="baz", field_type=BooleanType(), is_optional=False)
    )


def test_schema_find_type(table_schema_simple):
    """Test finding the type of a column given its field ID"""
    assert (
        table_schema_simple.find_type(1)
        == table_schema_simple.find_type("foo")
        == table_schema_simple.find_type("FOO", case_sensitive=False)
        == StringType()
    )
    assert (
        table_schema_simple.find_type(2)
        == table_schema_simple.find_type("bar")
        == table_schema_simple.find_type("BAR", case_sensitive=False)
        == IntegerType()
    )
    assert (
        table_schema_simple.find_type(3)
        == table_schema_simple.find_type("baz")
        == table_schema_simple.find_type("BAZ", case_sensitive=False)
        == BooleanType()
    )