Skip to content

Commit

Permalink
feat: passthrough nodes (#261)
Browse files Browse the repository at this point in the history
* feat: passthrough nodes

* update bindings

* update model

* update exclusion

* fix merge strategy

* fix

* fix merge

* add tests

* fix walk

* fix merge

* fix

* fix: add tests

* fix types

---------

Co-authored-by: Ivan Miletic <[email protected]>
  • Loading branch information
stefan-gorules and ivanmiletic authored Oct 23, 2024
1 parent 8a7e72b commit 4781214
Show file tree
Hide file tree
Showing 38 changed files with 1,438 additions and 251 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/rust.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ jobs:
- uses: actions/checkout@v3
- name: Install Rust
run: rustup install 1.80
- run: cargo test --workspace --all-features --exclude zen-ffi --exclude zen-nodejs
- run: cargo test --workspace --all-features --exclude zen-ffi --exclude zen-nodejs --release
- run: cargo test --workspace --all-features --exclude zen-ffi --exclude zen-nodejs --exclude zen-python
- run: cargo test --workspace --all-features --exclude zen-ffi --exclude zen-nodejs --exclude zen-python --release

build:
name: cargo +${{ matrix.rust }} build
Expand Down
2 changes: 1 addition & 1 deletion bindings/c/src/custom_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl Default for DynamicCustomNode {
}

impl CustomNodeAdapter for DynamicCustomNode {
async fn handle(&self, request: CustomNodeRequest<'_>) -> NodeResult {
async fn handle(&self, request: CustomNodeRequest) -> NodeResult {
match self {
DynamicCustomNode::Noop(cn) => cn.handle(request).await,
DynamicCustomNode::Native(cn) => cn.handle(request).await,
Expand Down
2 changes: 1 addition & 1 deletion bindings/c/src/languages/go.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl GoCustomNode {
}

impl CustomNodeAdapter for GoCustomNode {
async fn handle(&self, request: CustomNodeRequest<'_>) -> NodeResult {
async fn handle(&self, request: CustomNodeRequest) -> NodeResult {
let Some(handler) = self.handler else {
return Err(anyhow!("go handler not found"));
};
Expand Down
2 changes: 1 addition & 1 deletion bindings/c/src/languages/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ impl NativeCustomNode {
}

impl CustomNodeAdapter for NativeCustomNode {
async fn handle(&self, request: CustomNodeRequest<'_>) -> NodeResult {
async fn handle(&self, request: CustomNodeRequest) -> NodeResult {
let Ok(request_value) = serde_json::to_string(&request) else {
return Err(anyhow!("failed to serialize request json"));
};
Expand Down
3 changes: 2 additions & 1 deletion bindings/nodejs/.gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
*.node
*.internal.js
*.internal.js
temp.d.ts
2 changes: 1 addition & 1 deletion bindings/nodejs/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
"url": "git+https://github.com/gorules/zen.git"
},
"scripts": {
"build": "napi build --dts false --platform --release",
"build": "napi build --dts temp.d.ts --platform --release",
"build:debug": "napi build --platform --js index.js --dts index.d.ts",
"watch": "cargo watch --ignore '{index.js,index.d.ts}' -- npm run build:debug",
"test": "jest",
Expand Down
2 changes: 1 addition & 1 deletion bindings/nodejs/src/custom_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl CustomNode {
}

impl CustomNodeAdapter for CustomNode {
async fn handle(&self, request: CustomNodeRequest<'_>) -> NodeResult {
async fn handle(&self, request: CustomNodeRequest) -> NodeResult {
let Some(function) = &self.function else {
return Err(anyhow!("Custom function is undefined"));
};
Expand Down
18 changes: 9 additions & 9 deletions bindings/nodejs/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::collections::HashMap;

use json_dotpath::DotPaths;
use napi::anyhow::{anyhow, Context};
use napi_derive::napi;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;

use zen_engine::handler::custom_node_adapter::CustomDecisionNode;
use zen_engine::{DecisionGraphResponse, DecisionGraphTrace};
Expand Down Expand Up @@ -65,16 +65,16 @@ pub struct DecisionNode {
pub id: String,
pub name: String,
pub kind: String,
pub config: Value,
pub config: Arc<Value>,
}

impl From<CustomDecisionNode<'_>> for DecisionNode {
fn from(value: CustomDecisionNode<'_>) -> Self {
impl From<CustomDecisionNode> for DecisionNode {
fn from(value: CustomDecisionNode) -> Self {
Self {
id: value.id.to_string(),
name: value.name.to_string(),
kind: value.kind.to_string(),
config: value.config.clone(),
id: value.id,
name: value.name,
kind: value.kind,
config: value.config,
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion bindings/python/src/custom_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ fn extract_custom_node_response(py: Python<'_>, result: PyObject) -> NodeResult
}

impl CustomNodeAdapter for PyCustomNode {
async fn handle(&self, request: CustomNodeRequest<'_>) -> NodeResult {
async fn handle(&self, request: CustomNodeRequest) -> NodeResult {
let Some(callable) = &self.0 else {
return Err(anyhow!("Custom node handler not provided"));
};
Expand Down
16 changes: 7 additions & 9 deletions bindings/python/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use json_dotpath::DotPaths;
use pyo3::{pyclass, pymethods, PyObject, PyResult, Python, ToPyObject};
use serde::Serialize;
use serde_json::Value;
use std::sync::Arc;

use crate::value::{value_to_object, PyValue};
use zen_engine::handler::custom_node_adapter::{
Expand All @@ -16,15 +17,15 @@ struct CustomDecisionNode {
pub id: String,
pub name: String,
pub kind: String,
pub config: Value,
pub config: Arc<Value>,
}

impl From<BaseCustomDecisionNode<'_>> for CustomDecisionNode {
impl From<BaseCustomDecisionNode> for CustomDecisionNode {
fn from(value: BaseCustomDecisionNode) -> Self {
Self {
id: value.id.to_string(),
name: value.name.to_string(),
kind: value.kind.to_string(),
id: value.id,
name: value.name,
kind: value.kind,
config: value.config.clone(),
}
}
Expand All @@ -42,10 +43,7 @@ pub struct PyNodeRequest {
}

impl PyNodeRequest {
pub fn from_request(
py: Python,
value: CustomNodeRequest<'_>,
) -> pythonize::Result<PyNodeRequest> {
pub fn from_request(py: Python, value: CustomNodeRequest) -> pythonize::Result<PyNodeRequest> {
let inner_node = value.node.into();
let node_val = serde_json::to_value(&inner_node).unwrap();

Expand Down
2 changes: 1 addition & 1 deletion core/engine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ thiserror = { workspace = true }
bincode = { workspace = true, optional = true }
petgraph = { workspace = true }
serde_json = { workspace = true, features = ["arbitrary_precision"] }
serde = { workspace = true, features = ["derive"] }
serde = { workspace = true, features = ["derive", "rc"] }
once_cell = { workspace = true }
json_dotpath = { workspace = true }
rust_decimal = { workspace = true, features = ["maths-nopanic"] }
Expand Down
4 changes: 2 additions & 2 deletions core/engine/src/decision.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,25 +82,25 @@ where
options: EvaluationOptions,
) -> Result<DecisionGraphResponse, Box<EvaluationError>> {
let mut decision_graph = DecisionGraph::try_new(DecisionGraphConfig {
content: self.content.clone(),
max_depth: options.max_depth.unwrap_or(5),
trace: options.trace.unwrap_or_default(),
loader: Arc::new(CachedLoader::from(self.loader.clone())),
adapter: self.adapter.clone(),
iteration: 0,
content: &self.content,
})?;

Ok(decision_graph.evaluate(context).await?)
}

pub fn validate(&self) -> Result<(), DecisionGraphValidationError> {
let decision_graph = DecisionGraph::try_new(DecisionGraphConfig {
content: self.content.clone(),
max_depth: 1,
trace: false,
loader: Arc::new(CachedLoader::from(self.loader.clone())),
adapter: self.adapter.clone(),
iteration: 0,
content: &self.content,
})?;

decision_graph.validate()
Expand Down
43 changes: 21 additions & 22 deletions core/engine/src/handler/custom_node_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,43 @@ use anyhow::anyhow;
use json_dotpath::DotPaths;
use serde::Serialize;
use serde_json::Value;
use std::ops::Deref;
use std::sync::Arc;
use zen_expression::variable::Variable;
use zen_tmpl::TemplateRenderError;

pub trait CustomNodeAdapter {
fn handle(
&self,
request: CustomNodeRequest<'_>,
) -> impl std::future::Future<Output = NodeResult>;
fn handle(&self, request: CustomNodeRequest) -> impl std::future::Future<Output = NodeResult>;
}

#[derive(Default, Debug)]
pub struct NoopCustomNode;

impl CustomNodeAdapter for NoopCustomNode {
async fn handle(&self, _: CustomNodeRequest<'_>) -> NodeResult {
async fn handle(&self, _: CustomNodeRequest) -> NodeResult {
Err(anyhow!("Custom node handler not provided"))
}
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CustomNodeRequest<'a> {
pub struct CustomNodeRequest {
pub input: Variable,
pub node: CustomDecisionNode<'a>,
pub node: CustomDecisionNode,
}

impl<'a> TryFrom<&'a NodeRequest<'a>> for CustomNodeRequest<'a> {
impl TryFrom<NodeRequest> for CustomNodeRequest {
type Error = ();

fn try_from(value: &'a NodeRequest<'a>) -> Result<Self, Self::Error> {
fn try_from(value: NodeRequest) -> Result<Self, Self::Error> {
Ok(Self {
input: value.input.clone(),
node: value.node.try_into()?,
node: value.node.deref().try_into()?,
})
}
}

impl<'a> CustomNodeRequest<'a> {
impl CustomNodeRequest {
pub fn get_field(&self, path: &str) -> Result<Option<Variable>, TemplateRenderError> {
let Some(selected_value) = self.get_field_raw(path) else {
return Ok(None);
Expand All @@ -62,26 +61,26 @@ impl<'a> CustomNodeRequest<'a> {

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CustomDecisionNode<'a> {
pub id: &'a str,
pub name: &'a str,
pub kind: &'a str,
pub config: &'a Value,
pub struct CustomDecisionNode {
pub id: String,
pub name: String,
pub kind: String,
pub config: Arc<Value>,
}

impl<'a> TryFrom<&'a DecisionNode> for CustomDecisionNode<'a> {
impl TryFrom<&DecisionNode> for CustomDecisionNode {
type Error = ();

fn try_from(value: &'a DecisionNode) -> Result<Self, Self::Error> {
fn try_from(value: &DecisionNode) -> Result<Self, Self::Error> {
let DecisionNodeKind::CustomNode { content } = &value.kind else {
return Err(());
};

Ok(Self {
id: &value.id,
name: &value.name,
kind: &content.kind,
config: &content.config,
id: value.id.clone(),
name: value.name.clone(),
kind: content.kind.clone(),
config: content.config.clone(),
})
}
}
Loading

0 comments on commit 4781214

Please sign in to comment.