From ecbb4c310e6dc479bfee881d3cea95ad62e6b704 Mon Sep 17 00:00:00 2001
From: Sung Yun <107272191+sungwy@users.noreply.github.com>
Date: Mon, 26 Aug 2024 23:57:01 -0400
Subject: [PATCH] Expose Transforms to Python Binding (#556)

* bucket transform rust binding

* format

* poetry x maturin

* ignore poetry.lock in license check

* update bindings_python_ci to use makefile

* newline

* https://github.com/python-poetry/poetry/pull/9135

* use hatch instead of poetry

* refactor

* revert licenserc change

* adopt review feedback

* comments

* unused dependency

* adopt review comment

* newline

* I like this approach a lot better

* more tests
---
 .github/workflows/bindings_python_ci.yml |  2 +-
 .gitignore                               |  2 +
 bindings/python/Cargo.toml               |  3 +-
 bindings/python/pyproject.toml           |  1 +
 bindings/python/src/lib.rs               |  6 ++
 bindings/python/src/transform.rs         | 87 ++++++++++++++++++++++
 bindings/python/tests/test_transform.py  | 91 ++++++++++++++++++++++++
 7 files changed, 190 insertions(+), 2 deletions(-)
 create mode 100644 bindings/python/src/transform.rs
 create mode 100644 bindings/python/tests/test_transform.py

diff --git a/.github/workflows/bindings_python_ci.yml b/.github/workflows/bindings_python_ci.yml
index 1f50a8eb5..d4b1aa922 100644
--- a/.github/workflows/bindings_python_ci.yml
+++ b/.github/workflows/bindings_python_ci.yml
@@ -80,4 +80,4 @@ jobs:
           set -e
           pip install hatch==1.12.0
           hatch run dev:pip install dist/pyiceberg_core-*.whl --force-reinstall
-          hatch run dev:test
\ No newline at end of file
+          hatch run dev:test
diff --git a/.gitignore b/.gitignore
index 05c11eda6..a3f05e817 100644
--- a/.gitignore
+++ b/.gitignore
@@ -24,3 +24,5 @@ dist/*
 **/venv
 *.so
 *.pyc
+*.whl
+*.tar.gz
diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml
index c2c1007b7..0260f788b 100644
--- a/bindings/python/Cargo.toml
+++ b/bindings/python/Cargo.toml
@@ -32,4 +32,5 @@ crate-type = ["cdylib"]
 
 [dependencies]
 iceberg = { path = "../../crates/iceberg" }
-pyo3 = { version = "0.22", features = ["extension-module"] }
+pyo3 = { version = "0.21.1", features = ["extension-module"] }
+arrow = { version = "52.2.0", features = ["pyarrow"] }
diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml
index 910cf50dc..f1f0a100f 100644
--- a/bindings/python/pyproject.toml
+++ b/bindings/python/pyproject.toml
@@ -43,6 +43,7 @@ ignore = ["F403", "F405"]
 dependencies = [
     "maturin>=1.0,<2.0",
     "pytest>=8.3.2",
+    "pyarrow>=17.0.0",
 ]
 
 [tool.hatch.envs.dev.scripts]
diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs
index f0d5d1935..5c3f77ff7 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/lib.rs
@@ -17,6 +17,9 @@
 
 use iceberg::io::FileIOBuilder;
 use pyo3::prelude::*;
+use pyo3::wrap_pyfunction;
+
+mod transform;
 
 #[pyfunction]
 fn hello_world() -> PyResult<String> {
@@ -24,8 +27,11 @@ fn hello_world() -> PyResult<String> {
     Ok("Hello, world!".to_string())
 }
 
+
 #[pymodule]
 fn pyiceberg_core_rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(hello_world, m)?)?;
+
+    m.add_class::<transform::ArrowArrayTransform>()?;
     Ok(())
 }
diff --git a/bindings/python/src/transform.rs b/bindings/python/src/transform.rs
new file mode 100644
index 000000000..8f4585b2a
--- /dev/null
+++ b/bindings/python/src/transform.rs
@@ -0,0 +1,87 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use iceberg::spec::Transform;
+use iceberg::transform::create_transform_function;
+
+use arrow::{
+    array::{make_array, Array, ArrayData},
+};
+use arrow::pyarrow::{FromPyArrow, ToPyArrow};
+use pyo3::{exceptions::PyValueError, prelude::*};
+
+fn to_py_err(err: iceberg::Error) -> PyErr {
+    PyValueError::new_err(err.to_string())
+}
+
+#[pyclass]
+pub struct ArrowArrayTransform {
+}
+
+fn apply(array: PyObject, transform: Transform, py: Python) -> PyResult<PyObject> {
+    // import
+    let array = ArrayData::from_pyarrow_bound(array.bind(py))?;
+    let array = make_array(array);
+    let transform_function = create_transform_function(&transform).map_err(to_py_err)?;
+    let array = transform_function.transform(array).map_err(to_py_err)?;
+    // export
+    let array = array.into_data();
+    array.to_pyarrow(py)
+}
+
+#[pymethods]
+impl ArrowArrayTransform {
+    #[staticmethod]
+    pub fn identity(array: PyObject, py: Python) -> PyResult<PyObject> {
+        apply(array, Transform::Identity, py)
+    }
+
+    #[staticmethod]
+    pub fn void(array: PyObject, py: Python) -> PyResult<PyObject> {
+        apply(array, Transform::Void, py)
+    }
+
+    #[staticmethod]
+    pub fn year(array: PyObject, py: Python) -> PyResult<PyObject> {
+        apply(array, Transform::Year, py)
+    }
+
+    #[staticmethod]
+    pub fn month(array: PyObject, py: Python) -> PyResult<PyObject> {
+        apply(array, Transform::Month, py)
+    }
+
+    #[staticmethod]
+    pub fn day(array: PyObject, py: Python) -> PyResult<PyObject> {
+        apply(array, Transform::Day, py)
+    }
+
+    #[staticmethod]
+    pub fn hour(array: PyObject, py: Python) -> PyResult<PyObject> {
+        apply(array, Transform::Hour, py)
+    }
+
+    #[staticmethod]
+    pub fn bucket(array: PyObject, num_buckets: u32, py: Python) -> PyResult<PyObject> {
+        apply(array, Transform::Bucket(num_buckets), py)
+    }
+
+    #[staticmethod]
+    pub fn truncate(array: PyObject, width: u32, py: Python) -> PyResult<PyObject> {
+        apply(array, Transform::Truncate(width), py)
+    }
+}
diff --git a/bindings/python/tests/test_transform.py b/bindings/python/tests/test_transform.py
new file mode 100644
index 000000000..1fa2d577a
--- /dev/null
+++ b/bindings/python/tests/test_transform.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import date, datetime
+
+import pyarrow as pa
+import pytest
+from pyiceberg_core import ArrowArrayTransform
+
+
+def test_identity_transform():
+    arr = pa.array([1, 2])
+    result = ArrowArrayTransform.identity(arr)
+    assert result == arr
+
+
+def test_bucket_transform():
+    arr = pa.array([1, 2])
+    result = ArrowArrayTransform.bucket(arr, 10)
+    expected = pa.array([6, 2], type=pa.int32())
+    assert result == expected
+
+
+def test_bucket_transform_fails_for_list_type_input():
+    arr = pa.array([[1, 2], [3, 4]])
+    with pytest.raises(
+        ValueError,
+        match=r"FeatureUnsupported => Unsupported data type for bucket transform",
+    ):
+        ArrowArrayTransform.bucket(arr, 10)
+
+
+def test_bucket_chunked_array():
+    chunked = pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])])
+    result_chunks = []
+    for arr in chunked.iterchunks():
+        result_chunks.append(ArrowArrayTransform.bucket(arr, 10))
+
+    expected = pa.chunked_array(
+        [pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]
+    )
+    assert pa.chunked_array(result_chunks).equals(expected)
+
+
+def test_year_transform():
+    arr = pa.array([date(1970, 1, 1), date(2000, 1, 1)])
+    result = ArrowArrayTransform.year(arr)
+    expected = pa.array([0, 30], type=pa.int32())
+    assert result == expected
+
+
+def test_month_transform():
+    arr = pa.array([date(1970, 1, 1), date(2000, 4, 1)])
+    result = ArrowArrayTransform.month(arr)
+    expected = pa.array([0, 30 * 12 + 3], type=pa.int32())
+    assert result == expected
+
+
+def test_day_transform():
+    arr = pa.array([date(1970, 1, 1), date(2000, 4, 1)])
+    result = ArrowArrayTransform.day(arr)
+    expected = pa.array([0, 11048], type=pa.int32())
+    assert result == expected
+
+
+def test_hour_transform():
+    arr = pa.array([datetime(1970, 1, 1, 19, 1, 23), datetime(2000, 3, 1, 12, 1, 23)])
+    result = ArrowArrayTransform.hour(arr)
+    expected = pa.array([19, 264420], type=pa.int32())
+    assert result == expected
+
+
+def test_truncate_transform():
+    arr = pa.array(["this is a long string", "hi my name is sung"])
+    result = ArrowArrayTransform.truncate(arr, 5)
+    expected = pa.array(["this ", "hi my"])
+    assert result == expected