Skip to content

Commit

Permalink
Rust neo4rs Driver Switch (#704)
Browse files Browse the repository at this point in the history
## Summary of Changes
This should update the Rust side to using neo4rs neo4j drivers from the
rsmgclient client drivers for interacting with memgraph.

Needed a fair bit of refactoring in the model_extraction.rs file in
order to accommodate this change, so I suggest most review effort should
be there.

---------

Co-authored-by: Gus Hahn-Powell <[email protected]>
Co-authored-by: Justin <[email protected]>
  • Loading branch information
3 people authored Nov 30, 2023
1 parent 4dee88c commit 7fdf749
Show file tree
Hide file tree
Showing 16 changed files with 609 additions and 1,490 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests-and-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ env:
# used by skema-rs services
SKEMA_GRAPH_DB_HOST: "127.0.0.1"
SKEMA_GRAPH_DB_PORT: "7687"
SKEMA_GRAPH_DB_PROTO: "bolt://"
SKEMA_RS_HOST: "127.0.0.1"
SKEMA_RS_PORT: "8001"
# used by Python services.
Expand Down
12 changes: 9 additions & 3 deletions skema/rest/morae_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,27 @@


# FIXME: make GrometFunctionModuleCollection a pydantic model via code gen
@router.post("/model", summary="Pushes gromet (function network) to the graph database")
@router.post("/model", summary="Pushes gromet (function network) to the graph database", include_in_schema=False)
async def post_model(gromet: Dict[Text, Any]):
return requests.post(f"{SKEMA_RS_ADDESS}/models", json=gromet).json()


@router.get("/models", summary="Gets function network IDs from the graph database")
async def get_models() -> List[str]:
return requests.get(f"{SKEMA_RS_ADDESS}/models").json()
async def get_models() -> List[int]:
request = requests.get(f"{SKEMA_RS_ADDESS}/models")
print(f"request: {request}")
return request.json()


@router.get("/ping", summary="Status of MORAE service")
async def healthcheck() -> int:
return requests.get(f"{SKEMA_RS_ADDESS}/ping").status_code


@router.get("/version", summary="Status of MORAE service")
async def versioncheck() -> str:
return requests.get(f"{SKEMA_RS_ADDESS}/version").text

@router.post("/mathml/decapodes", summary="Gets Decapodes from a list of MathML strings")
async def get_decapodes(mathml: List[str]) -> Dict[Text, Any]:
return requests.put(f"{SKEMA_RS_ADDESS}/mathml/decapodes", json=mathml).json()
196 changes: 3 additions & 193 deletions skema/skema-rs/mathml/src/acset.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
//! Structs to represent elements of ACSets (Annotated C-Sets, a concept from category theory).
//! JSON-serialized ACSets are the form of model exchange between TA1 and TA2.
use crate::parsers::first_order_ode::{get_terms, FirstOrderODE, PnTerm};
use crate::{
ast::{Math, MathExpression, Mi},
mml2pn::{group_by_operators, Term},
petri_net::{Polarity, Var},
};

use serde::{Deserialize, Serialize};
use std::collections::{BTreeSet, HashMap, HashSet};
use std::collections::BTreeSet;
use utoipa;
use utoipa::ToSchema;

Expand Down Expand Up @@ -367,7 +363,7 @@ impl From<Vec<FirstOrderODE>> for PetriNet {

// now to trim off terms that are for euler methods, dyn_state != exp_state && parameters.len() != 0
for term in dirty_terms.iter() {
if term.dyn_state != term.exp_states[0] || term.parameters.len() != 0 {
if term.dyn_state != term.exp_states[0] || !term.parameters.is_empty() {
terms.push(term.clone());
}
}
Expand Down Expand Up @@ -862,189 +858,3 @@ impl From<Vec<FirstOrderODE>> for RegNet {
}
}
}

// This function takes in a mathml string and returns a Regnet
impl From<Vec<Math>> for RegNet {
fn from(mathml_asts: Vec<Math>) -> RegNet {
// this algorithm to follow should be refactored into a seperate function once it is functional

let mut specie_vars = HashSet::<Var>::new();
let mut vars = HashSet::<Var>::new();
let mut eqns = HashMap::<Var, Vec<Term>>::new();

for ast in mathml_asts.into_iter() {
group_by_operators(ast, &mut specie_vars, &mut vars, &mut eqns);
}

// Get the rate variables
let _rate_vars: HashSet<&Var> = vars.difference(&specie_vars).collect();

// -----------------------------------------------------------
// -----------------------------------------------------------

let mut states_vec = BTreeSet::<RegState>::new();
let mut transitions_vec = BTreeSet::<RegTransition>::new();

for state in specie_vars.clone().into_iter() {
// state bits
let mut rate_const = "temp".to_string();
let mut state_name = "temp".to_string();
let mut term_idx = 0;
let mut rate_sign = false;

//transition bits
let mut trans_name = "temp".to_string();
let mut trans_sign = false;
let _trans_tgt = "temp".to_string();
let mut trans_src = "temp".to_string();

for (i, term) in eqns[&state].iter().enumerate() {
for variable in term.vars.clone().iter() {
if state == variable.clone() && term.vars.len() == 2 {
term_idx = i;
}
}
}

// Positive rate sign: source, negative => sink.
if eqns[&state.clone()][term_idx].polarity == Polarity::Positive {
rate_sign = true;
}

for variable in eqns[&state][term_idx].vars.iter() {
if state.clone() != variable.clone() {
match variable.clone() {
Var(MathExpression::Mi(Mi(x))) => {
rate_const = x.clone();
}
_ => {
println!("Error in rate extraction");
}
};
} else {
match variable.clone() {
Var(MathExpression::Mi(Mi(x))) => {
state_name = x.clone();
}
_ => {
println!("Error in rate extraction");
}
};
}
}

let states = RegState {
id: state_name.clone(),
name: state_name.clone(),
sign: Some(rate_sign),
rate_constant: Some(rate_const.clone()),
..Default::default()
};
states_vec.insert(states.clone());

// now to make the transition part ----------------------------------

for (i, term) in eqns[&state].iter().enumerate() {
if i != term_idx {
if term.polarity == Polarity::Positive {
trans_sign = true;
}
let mut state_indx = 0;
let mut other_state_indx = 0;
for (j, var) in term.vars.iter().enumerate() {
if state.clone() == var.clone() {
state_indx = j;
}
for other_states in specie_vars.clone().into_iter() {
if *var != state && *var == other_states {
// this means it is not the state, but is another state
other_state_indx = j;
}
}
}
for (j, var) in term.vars.iter().enumerate() {
if j == other_state_indx {
match var.clone() {
Var(MathExpression::Mi(Mi(x))) => {
trans_src = x.clone();
}
_ => {
println!("error in trans src extraction");
}
};
} else if j != other_state_indx && j != state_indx {
match var.clone() {
Var(MathExpression::Mi(Mi(x))) => {
trans_name = x.clone();
}
_ => {
println!("error in trans name extraction");
}
};
}
}
}
}

let prop = Properties {
name: trans_name.clone(),
rate_constant: Some(trans_name.clone()),
};

let transitions = RegTransition {
id: trans_name.clone(),
target: Some([state_name.clone()].to_vec()), // tgt
source: Some([trans_src.clone()].to_vec()), // src
sign: Some(trans_sign),
properties: Some(prop.clone()),
..Default::default()
};

transitions_vec.insert(transitions.clone());
}

// -----------------------------------------------------------

let model = ModelRegNet {
vertices: states_vec,
edges: transitions_vec,
parameters: None,
};

let header = Header {
name: "Regnet mathml model".to_string(),
schema: "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/regnet_v0.1/regnet/regnet_schema.json".to_string(),
schema_name: "regnet".to_string(),
description: "This is a Regnet model from mathml equations".to_string(),
model_version: "0.1".to_string(),
};

RegNet {
header,
model,
metadata: None,
}
}
}

/*#[test]
fn test_lotka_volterra_mml_to_regnet() {
let input: serde_json::Value =
serde_json::from_str(&std::fs::read_to_string("tests/mml2amr_input_1.json").unwrap())
.unwrap();
let elements: Vec<Math> = input["mathml"]
.as_array()
.unwrap()
.iter()
.map(|x| x.as_str().unwrap().parse::<Math>().unwrap())
.collect();
let regnet = RegNet::from(elements);
let desired_output: RegNet =
serde_json::from_str(&std::fs::read_to_string("tests/mml2amr_output_1.json").unwrap())
.unwrap();
assert_eq!(regnet, desired_output);
}*/
2 changes: 2 additions & 0 deletions skema/skema-rs/mathml/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ impl Expr {

/// 1) distribute variables and terms over multiplications, e.g., a*(b+c) => a*b+a*c
/// 2) distribute variables and terms over divisions, e.g., a/(b+c)/(e+f) => a/(be+bf+ce+cf)
#[allow(dead_code)] // used in tests I believe
fn distribute_expr(&mut self) {
if let Expr::Expression { ops, args, .. } = self {
let mut ops_copy = ops.clone();
Expand Down Expand Up @@ -1024,6 +1025,7 @@ impl Expression {
}
}

#[allow(dead_code)] // used in tests I believe
fn distribute_expr(&mut self) {
for arg in &mut self.args {
if let Expr::Expression { .. } = arg {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ fn test_serialize_from_image_3_2() {
</math>
";
let expression = input.parse::<MathExpressionTree>().unwrap();
let s_exp = expression.to_string();
let _s_exp = expression.to_string();
let wiring_diagram = to_wiring_diagram(&expression);
let json = to_decapodes_json(wiring_diagram);
assert_eq!(json, "{\"Var\":[{\"type\":\"infer\",\"name\":\"mult_1\"},{\"type\":\"infer\",\"name\":\"mult_2\"},{\"type\":\"infer\",\"name\":\"•1\"},{\"type\":\"Literal\",\"name\":\"2\"},{\"type\":\"infer\",\"name\":\"sum_1\"},{\"type\":\"infer\",\"name\":\"n\"},{\"type\":\"infer\",\"name\":\"A\"},{\"type\":\"infer\",\"name\":\"•2\"},{\"type\":\"infer\",\"name\":\"mult_3\"},{\"type\":\"infer\",\"name\":\"ρ\"},{\"type\":\"infer\",\"name\":\"g\"},{\"type\":\"infer\",\"name\":\"Γ\"}],\"Op1\":[],\"Op2\":[{\"proj1\":4,\"proj2\":5,\"res\":3,\"op2\":\"/\"},{\"proj1\":3,\"proj2\":7,\"res\":2,\"op2\":\"*\"},{\"proj1\":10,\"proj2\":11,\"res\":9,\"op2\":\"*\"},{\"proj1\":9,\"proj2\":6,\"res\":8,\"op2\":\"^\"},{\"proj1\":2,\"proj2\":8,\"res\":1,\"op2\":\"*\"}],\"Σ\":[{\"sum\":5}],\"Summand\":[{\"summand\":6,\"summation\":1},{\"summand\":4,\"summation\":1}]}");
Expand Down
4 changes: 2 additions & 2 deletions skema/skema-rs/mathml/src/parsers/interpreted_mathml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ pub fn absolute_with_msup(input: Span) -> IResult<MathExpression> {
//ws(tag("<mo>|</mo>")),
ws(tuple((
//math_expression,
map(ws(many0(math_expression)), |z| Mrow(z)),
map(ws(many0(math_expression)), Mrow),
preceded(ws(tag("<msup><mo>|</mo>")), ws(math_expression)),
))),
ws(tag("</msup>")),
Expand All @@ -568,7 +568,7 @@ pub fn paren_as_msup(input: Span) -> IResult<MathExpression> {
ws(delimited(
tag("<mo>(</mo>"),
tuple((
map(many0(math_expression), |z| Mrow(z)),
map(many0(math_expression), Mrow),
preceded(tag("<msup><mo>)</mo>"), math_expression),
)),
tag("</msup>"),
Expand Down
Loading

0 comments on commit 7fdf749

Please sign in to comment.