diff --git a/apis/python/node/src/lib.rs b/apis/python/node/src/lib.rs index cdab15435..c3957b422 100644 --- a/apis/python/node/src/lib.rs +++ b/apis/python/node/src/lib.rs @@ -1,11 +1,10 @@ #![allow(clippy::borrow_deref_ref)] // clippy warns about code generated by #[pymethods] -use arrow::pyarrow::PyArrowConvert; use dora_node_api::{DoraNode, EventStream}; -use dora_operator_api_python::{pydict_to_metadata, PyEvent}; +use dora_operator_api_python::{process_python_output, pydict_to_metadata, PyEvent}; use eyre::Context; use pyo3::prelude::*; -use pyo3::types::{PyBytes, PyDict}; +use pyo3::types::PyDict; #[pyclass] pub struct Node { @@ -43,24 +42,9 @@ impl Node { metadata: Option<&PyDict>, py: Python, ) -> eyre::Result<()> { - if let Ok(py_bytes) = data.downcast::(py) { - let data = py_bytes.as_bytes(); + process_python_output(&data, py, |data| { self.send_output_slice(output_id, data.len(), data, metadata) - } else if let Ok(arrow_array) = arrow::array::ArrayData::from_pyarrow(data.as_ref(py)) { - if arrow_array.data_type() != &arrow::datatypes::DataType::UInt8 { - eyre::bail!("only arrow arrays with data type `UInt8` are supported"); - } - if arrow_array.buffers().len() != 1 { - eyre::bail!("output arrow array must contain a single buffer"); - } - - let len = arrow_array.len(); - let slice = &arrow_array.buffer(0)[..len]; - - self.send_output_slice(output_id, len, slice, metadata) - } else { - eyre::bail!("invalid `data` type, must by `PyBytes` or arrow array") - } + }) } } diff --git a/apis/python/operator/src/lib.rs b/apis/python/operator/src/lib.rs index 6db72616f..5c6fd8e65 100644 --- a/apis/python/operator/src/lib.rs +++ b/apis/python/operator/src/lib.rs @@ -127,3 +127,28 @@ pub fn metadata_to_pydict<'a>(metadata: &'a Metadata, py: Python<'a>) -> &'a PyD .unwrap(); dict } + +pub fn process_python_output( + data: &PyObject, + py: Python, + callback: impl FnOnce(&[u8]) -> eyre::Result, +) -> eyre::Result { + if let Ok(py_bytes) = data.downcast::(py) { + let data = py_bytes.as_bytes(); + callback(data) + } else if let Ok(arrow_array) = arrow::array::ArrayData::from_pyarrow(data.as_ref(py)) { + if arrow_array.data_type() != &arrow::datatypes::DataType::UInt8 { + eyre::bail!("only arrow arrays with data type `UInt8` are supported"); + } + if arrow_array.buffers().len() != 1 { + eyre::bail!("output arrow array must contain a single buffer"); + } + + let len = arrow_array.len(); + let slice = &arrow_array.buffer(0)[..len]; + + callback(slice) + } else { + eyre::bail!("invalid `data` type, must by `PyBytes` or arrow array") + } +} diff --git a/binaries/runtime/src/operator/python.rs b/binaries/runtime/src/operator/python.rs index d5c837e10..d5d973f99 100644 --- a/binaries/runtime/src/operator/python.rs +++ b/binaries/runtime/src/operator/python.rs @@ -264,30 +264,29 @@ mod callback_impl { use crate::operator::OperatorEvent; use super::SendOutputCallback; - use dora_operator_api_python::pydict_to_metadata; + use dora_operator_api_python::{process_python_output, pydict_to_metadata}; use eyre::{eyre, Context, Result}; - use pyo3::{ - pymethods, - types::{PyBytes, PyDict}, - }; + use pyo3::{pymethods, types::PyDict, PyObject, Python}; #[pymethods] impl SendOutputCallback { fn __call__( &mut self, output: &str, - data: &PyBytes, + data: PyObject, metadata: Option<&PyDict>, + py: Python, ) -> Result<()> { - let data = data.as_bytes(); + let data = process_python_output(&data, py, |data| Ok(data.to_owned()))?; + let metadata = pydict_to_metadata(metadata) - .wrap_err("Could not parse metadata.")? + .wrap_err("failed to parse metadata")? .into_owned(); let event = OperatorEvent::Output { output_id: output.to_owned().into(), metadata, - data: data.to_owned(), + data, }; self.events_tx