PyO3 bindings for DataFusion
twitu committed Feb 18, 2023
1 parent c4a9f0b commit 8f85d21
Showing 6 changed files with 239 additions and 74 deletions.
@@ -102,8 +102,8 @@
strategy = EMACrossBracket(config=config)
engine.add_strategy(strategy=strategy)

-time.sleep(0.1)
-input("Press Enter to continue...") # noqa (always Python 3)
+# time.sleep(0.1)
+# input("Press Enter to continue...") # noqa (always Python 3)

# Run the engine (from start to end of data)
engine.run()
17 changes: 11 additions & 6 deletions nautilus_core/persistence/src/bin/fusion.rs
@@ -2,21 +2,23 @@
use datafusion::error::Result;
use datafusion::prelude::*;

use nautilus_model::data::tick::QuoteTick;
-use nautilus_persistence::datafusion::PersistenceSession;
use nautilus_persistence::parquet::DecodeFromRecordBatch;
+use nautilus_persistence::session::PersistenceSession;
use std::collections::HashMap;

#[tokio::main]
async fn main() -> Result<()> {
-    let reader = PersistenceSession::new().await?;
+    let reader = PersistenceSession::new();
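+    // Keep the Arrow schema metadata when reading; DataFusion skips it by default.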
+    let mut parquet_options = ParquetReadOptions::default();
+    parquet_options.skip_metadata = Some(false);
reader
.register_parquet(
"quote_tick",
"../tests/test_data/quote_tick_data.parquet",
ParquetReadOptions::default(),
"../../tests/test_data/quote_tick_data.parquet",
parquet_options,
)
.await?;
-    let stream = reader.query("SELECT * FROM quote_tick").await?;
+    let stream = reader.query("SELECT * FROM quote_tick ORDER BY ts_init").await?;

let metadata: HashMap<String, String> = HashMap::from([
("instrument_id".to_string(), "EUR/USD.SIM".to_string()),
@@ -27,7 +29,10 @@ async fn main() -> Result<()> {
// extract row batches from stream and decode them to vec of ticks
let ticks: Vec<QuoteTick> = stream
.into_iter()
-        .flat_map(|batch| QuoteTick::decode_batch(&metadata, batch))
+        .flat_map(|batch| {
+            dbg!(batch.schema().metadata());
+            QuoteTick::decode_batch(&metadata, batch)
+        })
.collect();

let is_ascending_by_init = |ticks: &Vec<QuoteTick>| {
65 changes: 0 additions & 65 deletions nautilus_core/persistence/src/datafusion.rs

This file was deleted.

5 changes: 4 additions & 1 deletion nautilus_core/persistence/src/lib.rs
@@ -13,7 +13,7 @@
// limitations under the License.
// -------------------------------------------------------------------------------------------------

-pub mod datafusion;
+pub mod session;
pub mod parquet;

use std::{collections::BTreeMap, ffi::c_void, fs::File, io::Cursor, ptr::null_mut, slice};
@@ -25,6 +25,7 @@ use parquet::{
};
use pyo3::types::PyBytes;
use pyo3::{prelude::*, types::PyCapsule};
+use session::{PersistenceQuery, PersistenceSession};

#[pyclass(name = "ParquetReader")]
struct PythonParquetReader {
@@ -317,5 +318,7 @@ pub fn persistence(_: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PythonParquetWriter>()?;
m.add_class::<ParquetType>()?;
m.add_class::<ParquetReaderType>()?;
+    m.add_class::<PersistenceQuery>()?;
+    m.add_class::<PersistenceSession>()?;
Ok(())
}
199 changes: 199 additions & 0 deletions nautilus_core/persistence/src/session.rs
@@ -0,0 +1,199 @@
use std::collections::HashMap;
use std::ffi::c_void;
use std::ops::Deref;
use std::ptr::null_mut;

use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion::prelude::*;
use datafusion::{arrow::record_batch::RecordBatch, error::Result};
use futures::executor::{block_on, block_on_stream, BlockingStream};
use nautilus_core::cvec::CVec;
use nautilus_model::data::tick::{QuoteTick, TradeTick};
use pyo3::prelude::*;
use pyo3::types::PyCapsule;

use crate::parquet::{DecodeFromRecordBatch, ParquetType};

/// Store the DataFusion session context
#[pyclass]
pub struct PersistenceSession {
session_ctx: SessionContext,
}

impl Deref for PersistenceSession {
type Target = SessionContext;

fn deref(&self) -> &Self::Target {
&self.session_ctx
}
}

/// Store the result stream created by executing a query
///
/// The async stream has been wrapped into a blocking stream. The nautilus
/// engine is a CPU-intensive process: it consumes events one batch at a
/// time and then requests more. We want to block the thread until more
/// events are available to consume.
#[pyclass]
pub struct PersistenceQueryResult(pub BlockingStream<SendableRecordBatchStream>);

impl Iterator for PersistenceQueryResult {
type Item = RecordBatch;

fn next(&mut self) -> Option<Self::Item> {
if let Some(result) = self.0.next() {
match result {
Ok(batch) => Some(batch),
// TODO log or handle error here
Err(_) => None,
}
} else {
None
}
}
}

impl PersistenceSession {
    /// Create a new DataFusion session
///
/// This can register new files and data sources
pub fn new() -> Self {
PersistenceSession {
session_ctx: SessionContext::new(),
}
}

    /// Takes an SQL query and creates a data frame
///
/// The data frame is the logical plan that can be executed on the
/// data sources registered with the context. The async stream
/// is wrapped into a blocking stream.
pub async fn query(&self, sql: &str) -> Result<PersistenceQueryResult> {
let df = self.sql(sql).await?;
let stream = df.execute_stream().await?;
Ok(PersistenceQueryResult(block_on_stream(stream)))
}
}

/// Persistence session methods exposed to Python
///
/// `session_ctx` has all the methods needed to manipulate the session
/// context. However, only a limited set of relevant methods is exposed
/// through Python.
#[pymethods]
impl PersistenceSession {
#[new]
pub fn new_session() -> Self {
Self::new()
}

pub fn new_query(
slf: PyRef<'_, Self>,
sql: String,
metadata: HashMap<String, String>,
parquet_type: ParquetType,
) -> PersistenceQuery {
match block_on(slf.query(&sql)) {
Ok(query_result) => {
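                // Leak the boxed stream to a raw pointer so it can cross the FFI
                // boundary; `PersistenceQuery`'s `Drop` impl reclaims and frees it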
let boxed =
Box::leak(Box::new(query_result)) as *mut PersistenceQueryResult as *mut c_void;
PersistenceQuery {
query_result: boxed,
metadata,
parquet_type,
current_chunk: None,
}
}
Err(err) => panic!("failed new_query with error {}", err),
}
}

pub fn register_parquet_file(slf: PyRef<'_, Self>, table_name: String, path: String) {
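        // Block on the async registration so it can be called from synchronous Python code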
match block_on(slf.register_parquet(&table_name, &path, ParquetReadOptions::default())) {
Ok(_) => (),
Err(err) => panic!("failed register_parquet_file with error {}", err),
}
}
}

#[pyclass]
pub struct PersistenceQuery {
query_result: *mut c_void,
metadata: HashMap<String, String>,
parquet_type: ParquetType,
current_chunk: Option<CVec>,
}

/// Empty implementation of `Send` to satisfy `pyclass` requirements;
/// however, this is only designed for single-threaded use for now.
unsafe impl Send for PersistenceQuery {}

#[pymethods]
impl PersistenceQuery {
/// The reader implements an iterator.
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}

/// Each iteration returns a chunk of values read from the parquet file.
fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<PyObject> {
slf.drop_chunk();
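        // Temporarily reclaim the boxed stream to pull the next batch; ownership
        // is leaked back below with `Box::into_raw`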
let mut query_result =
unsafe { Box::from_raw(slf.query_result as *mut PersistenceQueryResult) };

let chunk: Option<CVec> = match slf.parquet_type {
ParquetType::QuoteTick => {
if let Some(batch) = query_result.next() {
Some(QuoteTick::decode_batch(&slf.metadata, batch).into())
} else {
None
}
}
// TODO implement decode batch for trade tick
ParquetType::TradeTick => None,
};

        // Leak the reader value back, otherwise it will be dropped after this function
Box::into_raw(query_result);
slf.current_chunk = chunk;
match chunk {
Some(cvec) => Python::with_gil(|py| {
Some(PyCapsule::new::<CVec>(py, cvec, None).unwrap().into_py(py))
}),
None => None,
}
}
}

impl PersistenceQuery {
    /// Chunks generated by iteration must be dropped after use, otherwise
    /// they will leak memory. The current chunk is held by the reader;
    /// drop it if it exists and reset the field.
fn drop_chunk(&mut self) {
if let Some(CVec { ptr, len, cap }) = self.current_chunk {
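            // Rebuild the Vec from its raw parts so the allocation is freed when it is dropped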
match self.parquet_type {
ParquetType::QuoteTick => {
let data: Vec<QuoteTick> =
unsafe { Vec::from_raw_parts(ptr as *mut QuoteTick, len, cap) };
drop(data);
}
ParquetType::TradeTick => {
let data: Vec<TradeTick> =
unsafe { Vec::from_raw_parts(ptr as *mut TradeTick, len, cap) };
drop(data);
}
}

// reset current chunk field
self.current_chunk = None;
};
}
}

impl Drop for PersistenceQuery {
fn drop(&mut self) {
self.drop_chunk();
        // Reclaim the leaked `PersistenceQueryResult` so its memory is freed
        let _query_result =
            unsafe { Box::from_raw(self.query_result as *mut PersistenceQueryResult) };
self.query_result = null_mut();
}
}
23 changes: 23 additions & 0 deletions nautilus_core/persistence/tests/test_persistence_module.py
@@ -20,9 +20,32 @@
from nautilus_trader.core.nautilus_pyo3.persistence import ParquetReaderType
from nautilus_trader.core.nautilus_pyo3.persistence import ParquetType
from nautilus_trader.core.nautilus_pyo3.persistence import ParquetWriter
from nautilus_trader.core.nautilus_pyo3.persistence import PersistenceQuery
from nautilus_trader.core.nautilus_pyo3.persistence import PersistenceSession
from nautilus_trader.model.data.tick import QuoteTick
from nautilus_trader.model.data.tick import TradeTick


def test_python_persistence_reader():
parquet_data_path = os.path.join(PACKAGE_ROOT, "tests/test_data/quote_tick_data.parquet")
session = PersistenceSession()
session.register_parquet_file("quote_ticks", parquet_data_path)

metadata = {
"instrument_id": "EUR/USD.SIM",
"price_precision": "5",
"size_precision": "0",
}
    query_result = session.new_query("SELECT * FROM quote_ticks ORDER BY ts_init", metadata, ParquetType.QuoteTick)
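    # The query result is an iterator; each chunk arrives as a PyCapsule wrapping a CVec of decoded ticks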
total_count = 0
print("query result")
for chunk in query_result:
tick_list = QuoteTick.list_from_capsule(chunk)
total_count += len(tick_list)

assert total_count == 9500
# test on last chunk tick i.e. 9500th record
assert str(tick_list[-1]) == "EUR/USD.SIM,1.12130,1.12132,0,0,1577919652000000125"


def test_python_parquet_reader():
parquet_data_path = os.path.join(PACKAGE_ROOT, "tests/test_data/quote_tick_data.parquet")