Skip to content

Commit

Permalink
Use arrow crate to send events from Rust to Python without copying
Browse files Browse the repository at this point in the history
See #224 for details.
  • Loading branch information
phil-opp committed Mar 21, 2023
1 parent 214f2af commit c948639
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 75 deletions.
2 changes: 1 addition & 1 deletion apis/python/node/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ eyre = "0.6"
serde_yaml = "0.8.23"
flume = "0.10.14"
dora-runtime = { workspace = true }
arrow2 = "0.16"
arrow = { version = "35.0.0", features = ["pyarrow"] }

[lib]
name = "dora"
Expand Down
94 changes: 25 additions & 69 deletions apis/python/node/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
#![allow(clippy::borrow_deref_ref)] // clippy warns about code generated by #[pymethods]

use arrow2::array::PrimitiveArray;
use arrow2::{array::Array, datatypes::Field, ffi};
use arrow::pyarrow::PyArrowConvert;
use dora_node_api::{DoraNode, Event, EventStream};
use dora_operator_api_python::{metadata_to_pydict, pydict_to_metadata};
use eyre::{Context, ContextCompat, Result};
use pyo3::ffi::Py_uintptr_t;
use eyre::{Context, Result};
use pyo3::prelude::*;
use pyo3::types::PyDict;
#[pyclass]
Expand All @@ -26,12 +24,18 @@ impl IntoPy<PyObject> for PyInput {
dict.set_item("id", id.to_string())
.wrap_err("failed to add input ID")
.unwrap();
let array =
unsafe { arrow2::ffi::mmap::slice(&data.as_deref().unwrap_or_default()) };
let array_data = to_py_array(Box::new(array), py).unwrap();
dict.set_item("data", array_data)
.wrap_err("failed to add input data")
.unwrap();
if let Some(data) = data {
let array = data.into_arrow_array();
// TODO: Does this call leak data?
let array_data = array
.to_pyarrow(py)
.wrap_err("failed to convert arrow data to Python")
.unwrap();
dict.set_item("data", array_data)
.wrap_err("failed to add input data")
.unwrap();
}

dict.set_item("metadata", metadata_to_pydict(&metadata, py))
.wrap_err("failed to add input metadata")
.unwrap();
Expand Down Expand Up @@ -89,17 +93,19 @@ impl Node {
metadata: Option<&PyDict>,
py: Python,
) -> Result<()> {
let buffer = to_rust_array(data, py).unwrap();
let data = buffer
.as_any()
.downcast_ref::<PrimitiveArray<u8>>()
.wrap_err("Could not cast sent output to arrow uint8 array")
.unwrap()
.values();
let data = arrow::array::ArrayData::from_pyarrow(data.as_ref(py))
.wrap_err("failed to read data as Arrow array")?;
if data.buffers().len() != 1 {
eyre::bail!("output arrow array must contain a single buffer");
}

let len = data.len();
let slice = &data.buffer(0)[..len];

let metadata = pydict_to_metadata(metadata)?;
self.node
.send_output(output_id.into(), metadata, data.len(), |out| {
out.copy_from_slice(&data);
.send_output(output_id.into(), metadata, len, |out| {
out.copy_from_slice(slice);
})
.wrap_err("Could not send output")
}
Expand All @@ -108,56 +114,6 @@ impl Node {
self.node.id().to_string()
}
}
// Taken from arrow2/example: https://github.com/jorgecarleitao/arrow2/blob/main/arrow-pyarrow-integration-testing/src/lib.rs
fn to_py_array(array: Box<dyn Array>, py: Python) -> PyResult<PyObject> {
let schema = Box::new(ffi::export_field_to_c(&Field::new(
"",
array.data_type().clone(),
true,
)));
let array = Box::new(ffi::export_array_to_c(array));

let schema_ptr: *const arrow2::ffi::ArrowSchema = &*schema;
let array_ptr: *const arrow2::ffi::ArrowArray = &*array;

let pa = py.import("pyarrow")?;

let array = pa.getattr("Array")?.call_method1(
"_import_from_c",
(array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t),
)?;

Ok(array.to_object(py))
}

// Taken from arrow2/example: https://github.com/jorgecarleitao/arrow2/blob/main/arrow-pyarrow-integration-testing/src/lib.rs
fn to_rust_array(ob: PyObject, py: Python) -> PyResult<Box<dyn Array>> {
// prepare a pointer to receive the Array struct
let array = Box::new(ffi::ArrowArray::empty());
let schema = Box::new(ffi::ArrowSchema::empty());

let array_ptr = &*array as *const ffi::ArrowArray;
let schema_ptr = &*schema as *const ffi::ArrowSchema;

// make the conversion through PyArrow's private API
// this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds
ob.call_method1(
py,
"_export_to_c",
(array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t),
)?;

let field = unsafe {
ffi::import_field_from_c(schema.as_ref()).wrap_err("Could not parse output array")?
};

let array = unsafe {
ffi::import_array_from_c(*array, field.data_type)
.wrap_err("Could not parse output array")?
};

Ok(array)
}

#[pyfunction]
fn start_runtime() -> Result<()> {
Expand Down
1 change: 1 addition & 0 deletions apis/rust/node/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ capnp = "0.14.11"
bincode = "1.3.3"
shared_memory = "0.12.0"
dora-tracing = { workspace = true, optional = true }
arrow = "35.0.0"

[dev-dependencies]
tokio = { version = "1.24.2", features = ["rt"] }
8 changes: 4 additions & 4 deletions apis/rust/node/src/daemon_connection/event_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl EventStream {
};

if let Some(tx) = tx.as_ref() {
let (drop_tx, drop_rx) = std::sync::mpsc::channel();
let (drop_tx, drop_rx) = flume::bounded(0);
match tx.send(EventItem::NodeEvent {
event,
ack_channel: drop_tx,
Expand All @@ -104,13 +104,13 @@ impl EventStream {
"Node API should not send anything on ACK channel"
))
}
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
Err(flume::RecvTimeoutError::Timeout) => {
tracing::warn!("timeout: event was not dropped after {timeout:?}");
if let Some(drop_token) = drop_token {
tracing::warn!("leaking drop token {drop_token:?}");
}
}
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
Err(flume::RecvTimeoutError::Disconnected) => {
// the event was dropped -> add the drop token to the list
if let Some(token) = drop_token {
drop_tokens.push(token);
Expand Down Expand Up @@ -220,7 +220,7 @@ impl EventStream {
enum EventItem {
NodeEvent {
event: NodeEvent,
ack_channel: std::sync::mpsc::Sender<()>,
ack_channel: flume::Sender<()>,
},
FatalError(eyre::Report),
}
25 changes: 24 additions & 1 deletion apis/rust/node/src/event.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::marker::PhantomData;
use std::{ptr::NonNull, sync::Arc};

use dora_core::{config::DataId, message::Metadata};
use eyre::Context;
Expand Down Expand Up @@ -26,6 +26,26 @@ pub enum Data {
_drop: flume::Sender<()>,
},
}
impl Data {
pub fn into_arrow_array(self) -> arrow::array::ArrayData {
let ptr = NonNull::new(self.as_ptr() as *mut _).unwrap();
let len = self.len();
let owner = Arc::new(self);

let buffer = unsafe { arrow::buffer::Buffer::from_custom_allocation(ptr, len, owner) };
unsafe {
arrow::array::ArrayData::new_unchecked(
arrow::datatypes::DataType::UInt8,
len,
Some(0),
None,
0,
vec![buffer],
vec![],
)
}
}
}

impl std::ops::Deref for Data {
type Target = [u8];
Expand Down Expand Up @@ -66,3 +86,6 @@ impl std::ops::Deref for MappedInputData {
unsafe { &self.memory.as_slice()[..self.len] }
}
}

unsafe impl Send for MappedInputData {}
unsafe impl Sync for MappedInputData {}

0 comments on commit c948639

Please sign in to comment.