Skip to content

Commit

Permalink
feat: schema validation
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-gorules committed Oct 24, 2024
1 parent 11666cb commit 9f46fc1
Show file tree
Hide file tree
Showing 10 changed files with 293 additions and 107 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ members = [

[workspace.dependencies]
ahash = "0.8.11"
bincode = "2.0.0-rc.3"
bumpalo = "3.16.0"
chrono = "0.4.38"
criterion = "0.5.1"
Expand Down
19 changes: 9 additions & 10 deletions actions/cargo-version-action/src/index.spec.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { describe } from 'node:test';
import { expect, test } from '@jest/globals';
import { getCargoVersion, updateCargoContents } from './cargo';
import { inc, ReleaseType } from 'semver';
import {describe} from 'node:test';
import {expect, test} from '@jest/globals';
import {getCargoVersion, updateCargoContents} from './cargo';
import {inc, ReleaseType} from 'semver';
import * as fs from 'fs/promises';
import * as path from 'path';

// language=Toml
const makeToml = ({ version }): string => `
const makeToml = ({version}): string => `
[package]
authors = ["GoRules Team <[email protected]>"]
description = "Business rules engine"
Expand All @@ -21,7 +21,6 @@ const makeToml = ({ version }): string => `
anyhow = { workspace = true }
thiserror = { workspace = true }
async-trait = { workspace = true }
bincode = { workspace = true, optional = true }
serde_json = { workspace = true, features = ["arbitrary_precision"] }
serde = { version = "1", features = ["derive"] }
serde_v8 = { version = "0.88.0" }
Expand All @@ -35,21 +34,21 @@ const makeToml = ({ version }): string => `
describe('GitHub Action', () => {
test('Bumps package', () => {
const version = '0.2.0';
const initialToml = makeToml({ version });
const initialToml = makeToml({version});

const releases: ReleaseType[] = ['major', 'minor', 'patch'];
for (const release of releases) {
const newVersion = inc(version, release);
const expectedToml = makeToml({ version: newVersion });
const expectedToml = makeToml({version: newVersion});

expect(updateCargoContents(initialToml, { version: newVersion })).toEqual(expectedToml);
expect(updateCargoContents(initialToml, {version: newVersion})).toEqual(expectedToml);
}
});

test('Extracts package version', () => {
const versions = ['0.1.0', '0.2.0', '0.3.0'];
for (const version of versions) {
const versionedToml = makeToml({ version });
const versionedToml = makeToml({version});
expect(getCargoVersion(versionedToml)).toEqual(version);
}
});
Expand Down
5 changes: 1 addition & 4 deletions core/engine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ doctest = false
ahash = { workspace = true }
anyhow = { workspace = true }
thiserror = { workspace = true }
bincode = { workspace = true, optional = true }
petgraph = { workspace = true }
serde_json = { workspace = true, features = ["arbitrary_precision"] }
serde = { workspace = true, features = ["derive", "rc"] }
Expand All @@ -28,14 +27,12 @@ rquickjs = { version = "0.6.2", features = ["macro", "loader", "rust-alloc", "fu
itertools = { workspace = true }
zen-expression = { path = "../expression", version = "0.33.0" }
zen-tmpl = { path = "../template", version = "0.33.0" }
jsonschema = { version = "0.25.0" }

[dev-dependencies]
tokio = { workspace = true, features = ["rt-multi-thread", "macros"] }
criterion = { workspace = true, features = ["async_tokio"] }

[features]
bincode = ["dep:bincode"]

[[bench]]
harness = false
name = "engine"
58 changes: 55 additions & 3 deletions core/engine/src/decision.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
use std::sync::Arc;

use crate::engine::EvaluationOptions;
use crate::handler::custom_node_adapter::{CustomNodeAdapter, NoopCustomNode};
use crate::handler::graph::{DecisionGraph, DecisionGraphConfig, DecisionGraphResponse};
use crate::loader::{CachedLoader, DecisionLoader, NoopLoader};
use crate::model::DecisionContent;
use crate::{DecisionGraphValidationError, EvaluationError};
use jsonschema::Validator;
use serde_json::Value;
use std::ops::Deref;
use std::sync::Arc;
use tokio::sync::RwLock;
use zen_expression::variable::Variable;

type SharedValidator = Arc<RwLock<Option<Arc<Validator>>>>;

/// Represents a JDM decision which can be evaluated
#[derive(Debug, Clone)]
pub struct Decision<Loader, CustomNode>
Expand All @@ -18,6 +23,9 @@ where
content: Arc<DecisionContent>,
loader: Arc<Loader>,
adapter: Arc<CustomNode>,

input_validator: SharedValidator,
output_validator: SharedValidator,
}

impl From<DecisionContent> for Decision<NoopLoader, NoopCustomNode> {
Expand All @@ -26,6 +34,9 @@ impl From<DecisionContent> for Decision<NoopLoader, NoopCustomNode> {
content: value.into(),
loader: NoopLoader::default().into(),
adapter: NoopCustomNode::default().into(),

input_validator: Arc::new(RwLock::new(None)),
output_validator: Arc::new(RwLock::new(None)),
}
}
}
Expand All @@ -36,6 +47,9 @@ impl From<Arc<DecisionContent>> for Decision<NoopLoader, NoopCustomNode> {
content: value,
loader: NoopLoader::default().into(),
adapter: NoopCustomNode::default().into(),

input_validator: Arc::new(RwLock::new(None)),
output_validator: Arc::new(RwLock::new(None)),
}
}
}
Expand All @@ -53,6 +67,9 @@ where
loader,
adapter: self.adapter,
content: self.content,

input_validator: self.input_validator,
output_validator: self.output_validator,
}
}

Expand All @@ -64,6 +81,9 @@ where
loader: self.loader,
adapter,
content: self.content,

input_validator: self.input_validator,
output_validator: self.output_validator,
}
}

Expand All @@ -81,6 +101,14 @@ where
context: Variable,
options: EvaluationOptions,
) -> Result<DecisionGraphResponse, Box<EvaluationError>> {
if let Some(input_schema) = &self.content.settings.validation.input_schema {
let input_validator =
get_validator(self.input_validator.clone(), &input_schema).await?;

let context_json = context.to_value();
input_validator.validate(&context_json)?;
}

let mut decision_graph = DecisionGraph::try_new(DecisionGraphConfig {
content: self.content.clone(),
max_depth: options.max_depth.unwrap_or(5),
Expand All @@ -90,7 +118,16 @@ where
iteration: 0,
})?;

Ok(decision_graph.evaluate(context).await?)
let response = decision_graph.evaluate(context).await?;
if let Some(output_schema) = &self.content.settings.validation.output_schema {
let output_validator =
get_validator(self.output_validator.clone(), &output_schema).await?;

let output_json = response.result.to_value();
output_validator.validate(&output_json)?;
}

Ok(response)
}

pub fn validate(&self) -> Result<(), DecisionGraphValidationError> {
Expand All @@ -106,3 +143,18 @@ where
decision_graph.validate()
}
}

async fn get_validator(
shared: SharedValidator,
schema: &Value,
) -> Result<Arc<Validator>, Box<EvaluationError>> {
if let Some(validator) = shared.read().await.deref() {
return Ok(validator.clone());
}

let mut w_shared = shared.write().await;
let validator = Arc::new(jsonschema::draft7::new(&schema)?);
w_shared.replace(validator.clone());

Ok(validator)
}
49 changes: 49 additions & 0 deletions core/engine/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use crate::handler::graph::DecisionGraphValidationError;
use crate::handler::node::NodeError;
use crate::loader::LoaderError;
use jsonschema::{ErrorIterator, ValidationError};
use serde::ser::SerializeMap;
use serde::{Serialize, Serializer};
use serde_json::{Map, Value};
use std::iter::once;
use thiserror::Error;

#[derive(Debug, Error)]
Expand All @@ -18,6 +21,9 @@ pub enum EvaluationError {

#[error("Invalid graph")]
InvalidGraph(Box<DecisionGraphValidationError>),

#[error("Validation failed")]
Validation(Box<Value>),
}

impl Serialize for EvaluationError {
Expand Down Expand Up @@ -51,6 +57,10 @@ impl Serialize for EvaluationError {
map.serialize_entry("type", "InvalidGraph")?;
map.serialize_entry("source", err)?;
}
EvaluationError::Validation(err) => {
map.serialize_entry("type", "Validation")?;
map.serialize_entry("source", err)?;
}
}

map.end()
Expand Down Expand Up @@ -86,3 +96,42 @@ impl From<DecisionGraphValidationError> for Box<EvaluationError> {
Box::new(EvaluationError::InvalidGraph(error.into()))
}
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct ValidationErrorJson {
path: String,
message: String,
}

impl<'a> From<ValidationError<'a>> for ValidationErrorJson {
fn from(value: ValidationError<'a>) -> Self {
ValidationErrorJson {
path: value.instance_path.to_string(),
message: format!("{}", value),
}
}
}

impl<'a> From<ErrorIterator<'a>> for Box<EvaluationError> {
fn from(error_iter: ErrorIterator<'a>) -> Self {
let errors: Vec<ValidationErrorJson> = error_iter.into_iter().map(From::from).collect();

let mut json_map = Map::new();
json_map.insert(
"errors".to_string(),
serde_json::to_value(errors).unwrap_or_default(),
);

Box::new(EvaluationError::Validation(Box::new(Value::Object(
json_map,
))))
}
}

impl<'a> From<ValidationError<'a>> for Box<EvaluationError> {
fn from(value: ValidationError<'a>) -> Self {
let iterator: ErrorIterator<'a> = Box::new(once(value));
Box::<EvaluationError>::from(iterator)
}
}
Loading

0 comments on commit 9f46fc1

Please sign in to comment.