From bb3f99f129920b1b67541e652c763a9f408a2f8e Mon Sep 17 00:00:00 2001 From: Antoine Beyeler Date: Fri, 11 Oct 2024 15:01:53 +0200 Subject: [PATCH] Implement fix in codegen + add test --- .../src/codegen/python/mod.rs | 4 ++- .../definitions/rerun/datatypes/utf8.fbs | 2 +- .../rerun_sdk/rerun/datatypes/entity_path.py | 6 ++++- rerun_py/rerun_sdk/rerun/datatypes/utf8.py | 18 ++++++++++--- .../rerun_sdk/rerun/datatypes/utf8_ext.py | 24 ----------------- .../test_types/components/affix_fuzzer10.py | 6 ++++- .../test_types/components/affix_fuzzer9.py | 6 ++++- .../test_types/datatypes/string_component.py | 6 ++++- rerun_py/tests/unit/test_utf8.py | 27 +++++++++++++++++++ 9 files changed, 65 insertions(+), 34 deletions(-) delete mode 100644 rerun_py/rerun_sdk/rerun/datatypes/utf8_ext.py create mode 100644 rerun_py/tests/unit/test_utf8.py diff --git a/crates/build/re_types_builder/src/codegen/python/mod.rs b/crates/build/re_types_builder/src/codegen/python/mod.rs index d0e21f345df2..4a925afa4d6b 100644 --- a/crates/build/re_types_builder/src/codegen/python/mod.rs +++ b/crates/build/re_types_builder/src/codegen/python/mod.rs @@ -1995,9 +1995,11 @@ fn quote_arrow_serialization( return Ok(unindent( r##" if isinstance(data, str): - array = [data] + array: Union[list[str], npt.ArrayLike] = [data] elif isinstance(data, Sequence): array = [str(datum) for datum in data] + elif isinstance(data, np.ndarray): + array = data else: array = [str(data)] diff --git a/crates/store/re_types/definitions/rerun/datatypes/utf8.fbs b/crates/store/re_types/definitions/rerun/datatypes/utf8.fbs index 3370cb03715f..01569cf24671 100644 --- a/crates/store/re_types/definitions/rerun/datatypes/utf8.fbs +++ b/crates/store/re_types/definitions/rerun/datatypes/utf8.fbs @@ -8,7 +8,7 @@ namespace rerun.datatypes; table Utf8 ( "attr.arrow.transparent", "attr.python.aliases": "str", - "attr.python.array_aliases": "str, Sequence[str]", + "attr.python.array_aliases": "str, Sequence[str], npt.ArrayLike", "attr.rust.derive": "Default, PartialEq, Eq, PartialOrd, Ord, Hash", "attr.rust.override_crate": "re_types_core", "attr.rust.repr": "transparent", diff --git a/rerun_py/rerun_sdk/rerun/datatypes/entity_path.py b/rerun_py/rerun_sdk/rerun/datatypes/entity_path.py index 1eac29289de9..234290dc9f58 100644 --- a/rerun_py/rerun_sdk/rerun/datatypes/entity_path.py +++ b/rerun_py/rerun_sdk/rerun/datatypes/entity_path.py @@ -7,6 +7,8 @@ from typing import TYPE_CHECKING, Any, Sequence, Union +import numpy as np +import numpy.typing as npt import pyarrow as pa from attrs import define, field @@ -58,9 +60,11 @@ class EntityPathBatch(BaseBatch[EntityPathArrayLike]): @staticmethod def _native_to_pa_array(data: EntityPathArrayLike, data_type: pa.DataType) -> pa.Array: if isinstance(data, str): - array = [data] + array: Union[list[str], npt.ArrayLike] = [data] elif isinstance(data, Sequence): array = [str(datum) for datum in data] + elif isinstance(data, np.ndarray): + array = data else: array = [str(data)] diff --git a/rerun_py/rerun_sdk/rerun/datatypes/utf8.py b/rerun_py/rerun_sdk/rerun/datatypes/utf8.py index 91e73e5d4a07..c3c2ef43027c 100644 --- a/rerun_py/rerun_sdk/rerun/datatypes/utf8.py +++ b/rerun_py/rerun_sdk/rerun/datatypes/utf8.py @@ -7,6 +7,8 @@ from typing import TYPE_CHECKING, Any, Sequence, Union +import numpy as np +import numpy.typing as npt import pyarrow as pa from attrs import define, field @@ -14,13 +16,12 @@ BaseBatch, BaseExtensionType, ) -from .utf8_ext import Utf8Ext __all__ = ["Utf8", "Utf8ArrayLike", "Utf8Batch", "Utf8Like", "Utf8Type"] @define(init=False) -class Utf8(Utf8Ext): +class Utf8: """**Datatype**: A string of text, encoded as UTF-8.""" def __init__(self: Any, value: Utf8Like): @@ -43,7 +44,7 @@ def __hash__(self) -> int: else: Utf8Like = Any -Utf8ArrayLike = Union[Utf8, Sequence[Utf8Like], str, Sequence[str]] +Utf8ArrayLike = Union[Utf8, Sequence[Utf8Like], str, Sequence[str], npt.ArrayLike] class Utf8Type(BaseExtensionType): @@ -58,4 +59,13 @@ class Utf8Batch(BaseBatch[Utf8ArrayLike]): @staticmethod def _native_to_pa_array(data: Utf8ArrayLike, data_type: pa.DataType) -> pa.Array: - return Utf8Ext.native_to_pa_array_override(data, data_type) + if isinstance(data, str): + array: Union[list[str], npt.ArrayLike] = [data] + elif isinstance(data, Sequence): + array = [str(datum) for datum in data] + elif isinstance(data, np.ndarray): + array = data + else: + array = [str(data)] + + return pa.array(array, type=data_type) diff --git a/rerun_py/rerun_sdk/rerun/datatypes/utf8_ext.py b/rerun_py/rerun_sdk/rerun/datatypes/utf8_ext.py deleted file mode 100644 index 7972a0bab097..000000000000 --- a/rerun_py/rerun_sdk/rerun/datatypes/utf8_ext.py +++ /dev/null @@ -1,24 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Sequence - -import numpy as np -import pyarrow as pa - -if TYPE_CHECKING: - from . import Utf8ArrayLike - - -class Utf8Ext: - @staticmethod - def native_to_pa_array_override(data: Utf8ArrayLike, data_type: pa.DataType) -> pa.Array: - if isinstance(data, str): - array = [data] - elif isinstance(data, Sequence): - array = [str(datum) for datum in data] - elif isinstance(data, np.ndarray): - array = data - else: - array = [str(data)] - - return pa.array(array, type=data_type) diff --git a/rerun_py/tests/test_types/components/affix_fuzzer10.py b/rerun_py/tests/test_types/components/affix_fuzzer10.py index 061b5ad1645d..1b74db055cc2 100644 --- a/rerun_py/tests/test_types/components/affix_fuzzer10.py +++ b/rerun_py/tests/test_types/components/affix_fuzzer10.py @@ -7,6 +7,8 @@ from typing import Any, Sequence, Union +import numpy as np +import numpy.typing as npt import pyarrow as pa from attrs import define, field from rerun._baseclasses import ( @@ -55,9 +57,11 @@ class AffixFuzzer10Batch(BaseBatch[AffixFuzzer10ArrayLike], ComponentBatchMixin) @staticmethod def _native_to_pa_array(data: AffixFuzzer10ArrayLike, data_type: pa.DataType) -> pa.Array: if isinstance(data, str): - array = [data] + array: Union[list[str], npt.ArrayLike] = [data] elif isinstance(data, Sequence): array = [str(datum) for datum in data] + elif isinstance(data, np.ndarray): + array = data else: array = [str(data)] diff --git a/rerun_py/tests/test_types/components/affix_fuzzer9.py b/rerun_py/tests/test_types/components/affix_fuzzer9.py index b3c991c0df34..fde70f4c1d5e 100644 --- a/rerun_py/tests/test_types/components/affix_fuzzer9.py +++ b/rerun_py/tests/test_types/components/affix_fuzzer9.py @@ -7,6 +7,8 @@ from typing import Any, Sequence, Union +import numpy as np +import numpy.typing as npt import pyarrow as pa from attrs import define, field from rerun._baseclasses import ( @@ -58,9 +60,11 @@ class AffixFuzzer9Batch(BaseBatch[AffixFuzzer9ArrayLike], ComponentBatchMixin): @staticmethod def _native_to_pa_array(data: AffixFuzzer9ArrayLike, data_type: pa.DataType) -> pa.Array: if isinstance(data, str): - array = [data] + array: Union[list[str], npt.ArrayLike] = [data] elif isinstance(data, Sequence): array = [str(datum) for datum in data] + elif isinstance(data, np.ndarray): + array = data else: array = [str(data)] diff --git a/rerun_py/tests/test_types/datatypes/string_component.py b/rerun_py/tests/test_types/datatypes/string_component.py index b1dc986439ce..eb79d9ba1d2e 100644 --- a/rerun_py/tests/test_types/datatypes/string_component.py +++ b/rerun_py/tests/test_types/datatypes/string_component.py @@ -7,6 +7,8 @@ from typing import Any, Sequence, Union +import numpy as np +import numpy.typing as npt import pyarrow as pa from attrs import define, field from rerun._baseclasses import ( @@ -60,9 +62,11 @@ class StringComponentBatch(BaseBatch[StringComponentArrayLike]): @staticmethod def _native_to_pa_array(data: StringComponentArrayLike, data_type: pa.DataType) -> pa.Array: if isinstance(data, str): - array = [data] + array: Union[list[str], npt.ArrayLike] = [data] elif isinstance(data, Sequence): array = [str(datum) for datum in data] + elif isinstance(data, np.ndarray): + array = data else: array = [str(data)] diff --git a/rerun_py/tests/unit/test_utf8.py b/rerun_py/tests/unit/test_utf8.py new file mode 100644 index 000000000000..6844eb4fb9c2 --- /dev/null +++ b/rerun_py/tests/unit/test_utf8.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import numpy as np +from rerun import datatypes + + +def test_utf8_batch_single() -> None: + single_string = "hello" + list_of_one_string = ["hello"] + array_of_one_string = np.array(["hello"]) + + assert ( + datatypes.Utf8Batch(single_string).as_arrow_array() == datatypes.Utf8Batch(list_of_one_string).as_arrow_array() + ) + + assert ( + datatypes.Utf8Batch(single_string).as_arrow_array() == datatypes.Utf8Batch(array_of_one_string).as_arrow_array() + ) + + +def test_utf8_batch_many() -> None: + list_of_strings = ["hello", "world"] + array_of_strings = np.array(["hello", "world"]) + + assert ( + datatypes.Utf8Batch(list_of_strings).as_arrow_array() == datatypes.Utf8Batch(array_of_strings).as_arrow_array() + )