Skip to content

Commit

Permalink
refactor(python): Expose transform as a submodule for pyiceberg_core (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
Xuanwo authored Sep 11, 2024
1 parent 8a3de4e commit eae9464
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 96 deletions.
4 changes: 2 additions & 2 deletions bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ crate-type = ["cdylib"]

[dependencies]
iceberg = { path = "../../crates/iceberg" }
pyo3 = { version = "0.21.1", features = ["extension-module"] }
arrow = { version = "52.2.0", features = ["pyarrow"] }
pyo3 = { version = "0.21", features = ["extension-module"] }
arrow = { version = "52", features = ["pyarrow"] }
24 changes: 24 additions & 0 deletions bindings/python/src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use pyo3::exceptions::PyValueError;
use pyo3::PyErr;

/// Convert an iceberg error to a python error
pub fn to_py_err(err: iceberg::Error) -> PyErr {
PyValueError::new_err(err.to_string())
}
16 changes: 3 additions & 13 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,13 @@
// specific language governing permissions and limitations
// under the License.

use iceberg::io::FileIOBuilder;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;

mod error;
mod transform;

#[pyfunction]
fn hello_world() -> PyResult<String> {
let _ = FileIOBuilder::new_fs_io().build().unwrap();
Ok("Hello, world!".to_string())
}


#[pymodule]
fn pyiceberg_core_rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(hello_world, m)?)?;

m.add_class::<transform::ArrowArrayTransform>()?;
fn pyiceberg_core_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
transform::register_module(py, m)?;
Ok(())
}
104 changes: 55 additions & 49 deletions bindings/python/src/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,55 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{make_array, Array, ArrayData};
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
use iceberg::spec::Transform;
use iceberg::transform::create_transform_function;
use pyo3::prelude::*;

use arrow::{
array::{make_array, Array, ArrayData},
};
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
use pyo3::{exceptions::PyValueError, prelude::*};
use crate::error::to_py_err;

#[pyfunction]
pub fn identity(py: Python, array: PyObject) -> PyResult<PyObject> {
apply(py, array, Transform::Identity)
}

#[pyfunction]
pub fn void(py: Python, array: PyObject) -> PyResult<PyObject> {
apply(py, array, Transform::Void)
}

#[pyfunction]
pub fn year(py: Python, array: PyObject) -> PyResult<PyObject> {
apply(py, array, Transform::Year)
}

#[pyfunction]
pub fn month(py: Python, array: PyObject) -> PyResult<PyObject> {
apply(py, array, Transform::Month)
}

fn to_py_err(err: iceberg::Error) -> PyErr {
PyValueError::new_err(err.to_string())
#[pyfunction]
pub fn day(py: Python, array: PyObject) -> PyResult<PyObject> {
apply(py, array, Transform::Day)
}

#[pyclass]
pub struct ArrowArrayTransform {
#[pyfunction]
pub fn hour(py: Python, array: PyObject) -> PyResult<PyObject> {
apply(py, array, Transform::Hour)
}

fn apply(array: PyObject, transform: Transform, py: Python) -> PyResult<PyObject> {
#[pyfunction]
pub fn bucket(py: Python, array: PyObject, num_buckets: u32) -> PyResult<PyObject> {
apply(py, array, Transform::Bucket(num_buckets))
}

#[pyfunction]
pub fn truncate(py: Python, array: PyObject, width: u32) -> PyResult<PyObject> {
apply(py, array, Transform::Truncate(width))
}

fn apply(py: Python, array: PyObject, transform: Transform) -> PyResult<PyObject> {
// import
let array = ArrayData::from_pyarrow_bound(array.bind(py))?;
let array = make_array(array);
Expand All @@ -43,45 +74,20 @@ fn apply(array: PyObject, transform: Transform, py: Python) -> PyResult<PyObject
array.to_pyarrow(py)
}

#[pymethods]
impl ArrowArrayTransform {
#[staticmethod]
pub fn identity(array: PyObject, py: Python) -> PyResult<PyObject> {
apply(array, Transform::Identity, py)
}

#[staticmethod]
pub fn void(array: PyObject, py: Python) -> PyResult<PyObject> {
apply(array, Transform::Void, py)
}

#[staticmethod]
pub fn year(array: PyObject, py: Python) -> PyResult<PyObject> {
apply(array, Transform::Year, py)
}

#[staticmethod]
pub fn month(array: PyObject, py: Python) -> PyResult<PyObject> {
apply(array, Transform::Month, py)
}

#[staticmethod]
pub fn day(array: PyObject, py: Python) -> PyResult<PyObject> {
apply(array, Transform::Day, py)
}

#[staticmethod]
pub fn hour(array: PyObject, py: Python) -> PyResult<PyObject> {
apply(array, Transform::Hour, py)
}
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
let this = PyModule::new_bound(py, "transform")?;

#[staticmethod]
pub fn bucket(array: PyObject, num_buckets: u32, py: Python) -> PyResult<PyObject> {
apply(array, Transform::Bucket(num_buckets), py)
}
this.add_function(wrap_pyfunction!(identity, &this)?)?;
this.add_function(wrap_pyfunction!(void, &this)?)?;
this.add_function(wrap_pyfunction!(year, &this)?)?;
this.add_function(wrap_pyfunction!(month, &this)?)?;
this.add_function(wrap_pyfunction!(day, &this)?)?;
this.add_function(wrap_pyfunction!(hour, &this)?)?;
this.add_function(wrap_pyfunction!(bucket, &this)?)?;
this.add_function(wrap_pyfunction!(truncate, &this)?)?;

#[staticmethod]
pub fn truncate(array: PyObject, width: u32, py: Python) -> PyResult<PyObject> {
apply(array, Transform::Truncate(width), py)
}
m.add_submodule(&this)?;
py.import_bound("sys")?
.getattr("modules")?
.set_item("pyiceberg_core.transform", this)
}
22 changes: 0 additions & 22 deletions bindings/python/tests/test_basic.py

This file was deleted.

28 changes: 18 additions & 10 deletions bindings/python/tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@

import pyarrow as pa
import pytest
from pyiceberg_core import ArrowArrayTransform
from pyiceberg_core import transform


def test_identity_transform():
arr = pa.array([1, 2])
result = ArrowArrayTransform.identity(arr)
result = transform.identity(arr)
assert result == arr


def test_bucket_transform():
arr = pa.array([1, 2])
result = ArrowArrayTransform.bucket(arr, 10)
result = transform.bucket(arr, 10)
expected = pa.array([6, 2], type=pa.int32())
assert result == expected

Expand All @@ -41,14 +41,14 @@ def test_bucket_transform_fails_for_list_type_input():
ValueError,
match=r"FeatureUnsupported => Unsupported data type for bucket transform",
):
ArrowArrayTransform.bucket(arr, 10)
transform.bucket(arr, 10)


def test_bucket_chunked_array():
chunked = pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])])
result_chunks = []
for arr in chunked.iterchunks():
result_chunks.append(ArrowArrayTransform.bucket(arr, 10))
result_chunks.append(transform.bucket(arr, 10))

expected = pa.chunked_array(
[pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]
Expand All @@ -58,34 +58,42 @@ def test_bucket_chunked_array():

def test_year_transform():
arr = pa.array([date(1970, 1, 1), date(2000, 1, 1)])
result = ArrowArrayTransform.year(arr)
result = transform.year(arr)
expected = pa.array([0, 30], type=pa.int32())
assert result == expected


def test_month_transform():
arr = pa.array([date(1970, 1, 1), date(2000, 4, 1)])
result = ArrowArrayTransform.month(arr)
result = transform.month(arr)
expected = pa.array([0, 30 * 12 + 3], type=pa.int32())
assert result == expected


def test_day_transform():
arr = pa.array([date(1970, 1, 1), date(2000, 4, 1)])
result = ArrowArrayTransform.day(arr)
result = transform.day(arr)
expected = pa.array([0, 11048], type=pa.int32())
assert result == expected


def test_hour_transform():
arr = pa.array([datetime(1970, 1, 1, 19, 1, 23), datetime(2000, 3, 1, 12, 1, 23)])
result = ArrowArrayTransform.hour(arr)
result = transform.hour(arr)
expected = pa.array([19, 264420], type=pa.int32())
assert result == expected


def test_truncate_transform():
arr = pa.array(["this is a long string", "hi my name is sung"])
result = ArrowArrayTransform.truncate(arr, 5)
result = transform.truncate(arr, 5)
expected = pa.array(["this ", "hi my"])
assert result == expected


def test_identity_transform_with_direct_import():
from pyiceberg_core.transform import identity

arr = pa.array([1, 2])
result = identity(arr)
assert result == arr

0 comments on commit eae9464

Please sign in to comment.