Skip to content

Commit

Permalink
Expose Transforms to Python Binding (#556)
Browse files Browse the repository at this point in the history
* bucket transform rust binding

* format

* poetry x maturin

* ignore poetry.lock in license check

* update bindings_python_ci to use makefile

* newline

* python-poetry/poetry#9135

* use hatch instead of poetry

* refactor

* revert licenserc change

* adopt review feedback

* comments

* unused dependency

* adopt review comment

* newline

* I like this approach a lot better

* more tests
  • Loading branch information
sungwy authored Aug 27, 2024
1 parent 905ebd2 commit ecbb4c3
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/bindings_python_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,4 @@ jobs:
set -e
pip install hatch==1.12.0
hatch run dev:pip install dist/pyiceberg_core-*.whl --force-reinstall
hatch run dev:test
hatch run dev:test
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@ dist/*
**/venv
*.so
*.pyc
*.whl
*.tar.gz
3 changes: 2 additions & 1 deletion bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ crate-type = ["cdylib"]

[dependencies]
iceberg = { path = "../../crates/iceberg" }
pyo3 = { version = "0.22", features = ["extension-module"] }
pyo3 = { version = "0.21.1", features = ["extension-module"] }
arrow = { version = "52.2.0", features = ["pyarrow"] }
1 change: 1 addition & 0 deletions bindings/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ ignore = ["F403", "F405"]
dependencies = [
"maturin>=1.0,<2.0",
"pytest>=8.3.2",
"pyarrow>=17.0.0",
]

[tool.hatch.envs.dev.scripts]
Expand Down
6 changes: 6 additions & 0 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,21 @@

use iceberg::io::FileIOBuilder;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;

mod transform;

#[pyfunction]
fn hello_world() -> PyResult<String> {
let _ = FileIOBuilder::new_fs_io().build().unwrap();
Ok("Hello, world!".to_string())
}


#[pymodule]
fn pyiceberg_core_rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(hello_world, m)?)?;

m.add_class::<transform::ArrowArrayTransform>()?;
Ok(())
}
87 changes: 87 additions & 0 deletions bindings/python/src/transform.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use iceberg::spec::Transform;
use iceberg::transform::create_transform_function;

use arrow::{
array::{make_array, Array, ArrayData},
};
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
use pyo3::{exceptions::PyValueError, prelude::*};

fn to_py_err(err: iceberg::Error) -> PyErr {
PyValueError::new_err(err.to_string())
}

#[pyclass]
pub struct ArrowArrayTransform {
}

fn apply(array: PyObject, transform: Transform, py: Python) -> PyResult<PyObject> {
// import
let array = ArrayData::from_pyarrow_bound(array.bind(py))?;
let array = make_array(array);
let transform_function = create_transform_function(&transform).map_err(to_py_err)?;
let array = transform_function.transform(array).map_err(to_py_err)?;
// export
let array = array.into_data();
array.to_pyarrow(py)
}

#[pymethods]
impl ArrowArrayTransform {
#[staticmethod]
pub fn identity(array: PyObject, py: Python) -> PyResult<PyObject> {
apply(array, Transform::Identity, py)
}

#[staticmethod]
pub fn void(array: PyObject, py: Python) -> PyResult<PyObject> {
apply(array, Transform::Void, py)
}

#[staticmethod]
pub fn year(array: PyObject, py: Python) -> PyResult<PyObject> {
apply(array, Transform::Year, py)
}

#[staticmethod]
pub fn month(array: PyObject, py: Python) -> PyResult<PyObject> {
apply(array, Transform::Month, py)
}

#[staticmethod]
pub fn day(array: PyObject, py: Python) -> PyResult<PyObject> {
apply(array, Transform::Day, py)
}

#[staticmethod]
pub fn hour(array: PyObject, py: Python) -> PyResult<PyObject> {
apply(array, Transform::Hour, py)
}

#[staticmethod]
pub fn bucket(array: PyObject, num_buckets: u32, py: Python) -> PyResult<PyObject> {
apply(array, Transform::Bucket(num_buckets), py)
}

#[staticmethod]
pub fn truncate(array: PyObject, width: u32, py: Python) -> PyResult<PyObject> {
apply(array, Transform::Truncate(width), py)
}
}
91 changes: 91 additions & 0 deletions bindings/python/tests/test_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from datetime import date, datetime

import pyarrow as pa
import pytest
from pyiceberg_core import ArrowArrayTransform


def test_identity_transform():
arr = pa.array([1, 2])
result = ArrowArrayTransform.identity(arr)
assert result == arr


def test_bucket_transform():
arr = pa.array([1, 2])
result = ArrowArrayTransform.bucket(arr, 10)
expected = pa.array([6, 2], type=pa.int32())
assert result == expected


def test_bucket_transform_fails_for_list_type_input():
arr = pa.array([[1, 2], [3, 4]])
with pytest.raises(
ValueError,
match=r"FeatureUnsupported => Unsupported data type for bucket transform",
):
ArrowArrayTransform.bucket(arr, 10)


def test_bucket_chunked_array():
chunked = pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])])
result_chunks = []
for arr in chunked.iterchunks():
result_chunks.append(ArrowArrayTransform.bucket(arr, 10))

expected = pa.chunked_array(
[pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]
)
assert pa.chunked_array(result_chunks).equals(expected)


def test_year_transform():
arr = pa.array([date(1970, 1, 1), date(2000, 1, 1)])
result = ArrowArrayTransform.year(arr)
expected = pa.array([0, 30], type=pa.int32())
assert result == expected


def test_month_transform():
arr = pa.array([date(1970, 1, 1), date(2000, 4, 1)])
result = ArrowArrayTransform.month(arr)
expected = pa.array([0, 30 * 12 + 3], type=pa.int32())
assert result == expected


def test_day_transform():
arr = pa.array([date(1970, 1, 1), date(2000, 4, 1)])
result = ArrowArrayTransform.day(arr)
expected = pa.array([0, 11048], type=pa.int32())
assert result == expected


def test_hour_transform():
arr = pa.array([datetime(1970, 1, 1, 19, 1, 23), datetime(2000, 3, 1, 12, 1, 23)])
result = ArrowArrayTransform.hour(arr)
expected = pa.array([19, 264420], type=pa.int32())
assert result == expected


def test_truncate_transform():
arr = pa.array(["this is a long string", "hi my name is sung"])
result = ArrowArrayTransform.truncate(arr, 5)
expected = pa.array(["this ", "hi my"])
assert result == expected

0 comments on commit ecbb4c3

Please sign in to comment.