Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial Function as Argument support for code2amr #723

Merged
merged 13 commits into from
Dec 7, 2023
28 changes: 24 additions & 4 deletions skema/skema-rs/skema/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3745,9 +3745,7 @@ pub fn create_att_expression(
for att_sub_box in att_box.bf.as_ref().unwrap().iter() {
new_c_args.box_counter = box_counter;
new_c_args.cur_box = att_sub_box.clone();
if att_sub_box.contents.is_some() {
new_c_args.att_idx = att_sub_box.contents.unwrap() as usize;
}
new_c_args.att_idx = c_args.att_idx;
match att_sub_box.function_type {
FunctionType::Literal => {
create_att_literal(
Expand All @@ -3769,6 +3767,17 @@ pub fn create_att_expression(
new_c_args.clone(),
);
}
FunctionType::Expression => {
new_c_args.att_idx = att_sub_box.contents.unwrap() as usize;
create_att_expression(
gromet, // gromet for metadata
nodes, // nodes
edges,
meta_nodes,
start,
new_c_args.clone(),
);
}
_ => {}
}
box_counter += 1;
Expand All @@ -3784,6 +3793,14 @@ pub fn create_att_expression(
c_args.bf_counter,
);

cross_att_wiring(
att_box.clone(),
nodes,
edges,
c_args.att_idx,
c_args.bf_counter,
);

// Now we also perform wopio wiring in case there is an empty expression
if att_box.wopio.is_some() {
wopio_wiring(att_box, nodes, edges, c_args.att_idx - 1, c_args.bf_counter);
Expand Down Expand Up @@ -5195,6 +5212,7 @@ pub fn wff_cross_att_wiring(
bf_counter: u8, // this is the current box
) {
for wire in eboxf.wff.as_ref().unwrap().iter() {
let mut prop = None;
// collect info to identify the opi src node
let src_idx = wire.src; // port index
let src_pif = eboxf.pif.as_ref().unwrap()[(src_idx - 1) as usize].clone(); // src port
Expand Down Expand Up @@ -5390,10 +5408,11 @@ pub fn wff_cross_att_wiring(
// only opo's
if node.n_type == "Primitive" || node.n_type == "Abstract" {
// iterate through port to check for tgt
for p in node.in_indx.as_ref().unwrap().iter() {
for (i, p) in node.in_indx.as_ref().unwrap().iter().enumerate() {
// push the src first, being pif
if (src_idx as u32) == *p {
wff_src_tgt.push(node.node_id.clone());
prop = Some(i);
}
}
}
Expand Down Expand Up @@ -5425,6 +5444,7 @@ pub fn wff_cross_att_wiring(
src: wff_src_tgt[0].clone(),
tgt: wff_src_tgt[1].clone(),
e_type: String::from("Wire"),
prop: Some(prop.unwrap()),
..Default::default()
};
edges.push(e8);
Expand Down
206 changes: 165 additions & 41 deletions skema/skema-rs/skema/src/model_extraction.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use crate::config::Config;

use mathml::ast::operator::Operator;
pub use mathml::mml2pn::{ACSet, Term};

use petgraph::prelude::*;
use petgraph::visit::IntoNeighborsDirected;

use std::string::ToString;

Expand Down Expand Up @@ -38,22 +41,22 @@ pub struct ModelEdge {

/**
* This is the main function call for model extraction.
*
*
* Parameters:
* - module_id: i64 -> This is the top level id of the gromet module in memgraph.
* - module_id: i64 -> This is the top level id of the gromet module in memgraph.
* - config: Config -> This is a config struct for connecting to memgraph
*
*
* Returns:
* - Vector of FirstOrderODE -> This vector of structs is used to construct a PetriNet or RegNet further down the pipeline
*
*
* Assumptions:
* - As of right now, we can always assume the code has been sliced to only one relevant function which contains the
* - As of right now, we can always assume the code has been sliced to only one relevant function which contains the
* core dynamics in it somewhere
*
* Notes:
*
* Notes:
* - FirstOrderODE is primarily composed of a LHS and a RHS,
* - LHS is just a Mi object of the state being differentiated. There are additional fields for the LHS but only the
* content field is used in downstream inference for now.
* content field is used in downstream inference for now.
* - RHS is where the bulk of the inference happens, it produces an expression tree, hence the MET -> Math Expression Tree.
* Every operator has a vector of arguments. (order matters)
*/
Expand Down Expand Up @@ -88,14 +91,14 @@ pub async fn module_id2mathml_MET_ast(module_id: i64, config: Config) -> Vec<Fir
core_dynamics_ast
}

/**
/**
* This function finds the core dynamics and returns a vector of node id's that meet the criteria for identification
*
* Based on the fact we are getting in only the function we expect to have dynamics, this should just be depricated in the future
* and replaced with a inference to move from the module_id to the top level function node id, however for now we should keep it
* as a simple heuristic because it is used in the original code2amr (zip repo) endpoint which would need to be updated first.
*
* Plus the case when it fails defaults to a emptry AMR which is preferable to crashing.
*
* Because we only receive the function we expect to contain the dynamics, this should be deprecated in the future
* and replaced with an inference that moves from the module_id to the top level function node id. For now we keep it
* as a simple heuristic because it is used in the original code2amr (zip repo) endpoint, which would need to be updated first.
*
* Also, when it fails it defaults to an empty AMR, which is preferable to crashing.
*/
#[allow(clippy::if_same_then_else)]
pub async fn find_pn_dynamics(module_id: i64, config: Config) -> Vec<i64> {
Expand Down Expand Up @@ -154,11 +157,10 @@ pub async fn find_pn_dynamics(module_id: i64, config: Config) -> Vec<i64> {
core_id
}


/**
* Once the function node has been identified, this function takes it from there to extract the vector of FirstOrderODE's
*
* This is based heavily on the assumption that each equation is in a seperate expression which breaks for the vector case.
*
* This is based heavily on the assumption that each equation is in a separate expression, which breaks for the vector case.
*/
#[allow(non_snake_case)]
pub async fn subgrapg2_core_dyn_MET_ast(
Expand All @@ -184,16 +186,31 @@ pub async fn subgrapg2_core_dyn_MET_ast(
let mut sub_w = subgraph_wiring(graph[expression_nodes[i]].id, config.clone())
.await
.unwrap();
if sub_w.node_count() > 3 {
let expr = trim_un_named(&mut sub_w).await;
let mut prim_counter = 0;
let mut has_call = false;
for node_index in sub_w.node_indices() {
if sub_w[node_index].label == *"Primitive" {
prim_counter += 1;
if *sub_w[node_index].name.as_ref().unwrap() == "_call" {
has_call = true;
}
}
}
if sub_w.node_count() > 3 && !(prim_counter == 1 && has_call) && prim_counter != 0 {
println!("expression: {}", graph[expression_nodes[i]].id);
// the call expressions get referenced by multiple top level expressions, so deleting the nodes in it breaks the other graphs. Need to pass clone of expression subgraph so references to original has all the nodes.
if has_call {
sub_w = trim_calls(sub_w.clone())
}
let expr = trim_un_named(&mut sub_w);
let mut root_node = Vec::<NodeIndex>::new();
for node_index in expr.node_indices() {
if expr[node_index].label.clone() == *"Opo" {
root_node.push(node_index);
}
}
if root_node.len() >= 2 {
// println!("More than one Opo! Skipping Expression!");
println!("More than one Opo! Skipping Expression!");
} else {
core_dynamics.push(tree_2_MET_ast(expr, root_node[0]).unwrap());
}
Expand All @@ -203,9 +220,8 @@ pub async fn subgrapg2_core_dyn_MET_ast(
Ok(core_dynamics)
}


/**
* This function is designed to take in a petgraph instance of a wires only expression subgraph and output a FirstOrderODE equations representing it.
* This function is designed to take in a petgraph instance of a wires only expression subgraph and output a FirstOrderODE equations representing it.
*/
#[allow(non_snake_case)]
fn tree_2_MET_ast(
Expand Down Expand Up @@ -336,16 +352,16 @@ pub fn get_operator_MET(
}

/**
* This function takes in a wiring only petgraph of an expression and trims off the un-named nodes and unpack nodes.
*
* This is done by creating new edges that bypass the un-named nodes and then deleting them from the graph.
* For deleting the unpacks, the assumption is they are always terminal in the subgraph and can be deleted freely.
*
* Concerns:
* - I don't think this will work if there are multiple un-named nodes changed together. I haven't seen this in practice,
* but I think it's possible. So something to keep in mind.
* This function takes in a wiring only petgraph of an expression and trims off the un-named nodes and unpack nodes.
*
* This is done by creating new edges that bypass the un-named nodes and then deleting them from the graph.
* For deleting the unpacks, the assumption is they are always terminal in the subgraph and can be deleted freely.
*
* Concerns:
* - I don't think this will work if there are multiple un-named nodes chained together. I haven't seen this in practice,
* but I think it's possible, so it is something to keep in mind.
*/
async fn trim_un_named(
fn trim_un_named(
graph: &mut petgraph::Graph<ModelNode, ModelEdge>,
) -> &mut petgraph::Graph<ModelNode, ModelEdge> {
// first create a cloned version of the graph we can modify while iterating over it.
Expand All @@ -354,34 +370,34 @@ async fn trim_un_named(
for node_index in graph.node_indices() {
if graph[node_index].clone().name.unwrap().clone() == *"un-named" {
let mut bypass = Vec::<NodeIndex>::new();
let mut outgoing_bypass = Vec::<NodeIndex>::new();
for node1 in graph.neighbors_directed(node_index, Incoming) {
bypass.push(node1);
}
for node2 in graph.neighbors_directed(node_index, Outgoing) {
bypass.push(node2);
outgoing_bypass.push(node2);
}
// one incoming one outgoing
if bypass.len() == 2 {
if bypass.len() == 1 && outgoing_bypass.len() == 1 {
// annoyingly have to pull the edge/Relation to insert into graph
graph.add_edge(
bypass[0],
bypass[1],
outgoing_bypass[0],
graph
.edge_weight(graph.find_edge(bypass[0], node_index).unwrap())
.unwrap()
.clone(),
);
} else if bypass.len() > 2 {
} else if bypass.len() >= 2 && outgoing_bypass.len() == 1 {
// this operates on the assumption that there maybe multiple references to the port
// (incoming arrows) but only one outgoing arrow, this seems to be the case based on
// data too.

let end_node_idx = bypass.len() - 1;
for (i, _ent) in bypass[0..end_node_idx].iter().enumerate() {
for (i, _ent) in bypass.iter().enumerate() {
// this iterates over all but the last entry in the bypass vec
graph.add_edge(
bypass[i],
bypass[end_node_idx],
outgoing_bypass[0],
graph
.edge_weight(graph.find_edge(bypass[i], node_index).unwrap())
.unwrap()
Expand All @@ -405,7 +421,7 @@ async fn trim_un_named(
graph
}

/// This function takes in a node id (typically that of an expression subgraph) and returns a
/// This function takes in a node id (typically that of an expression subgraph) and returns a
/// petgraph subgraph of only the wire type edges
async fn subgraph_wiring(
module_id: i64,
Expand Down Expand Up @@ -535,7 +551,6 @@ pub async fn get_subgraph(
module_id: i64,
config: Config,
) -> Result<(Vec<ModelNode>, Vec<ModelEdge>), Error> {

let mut node_list = Vec::<ModelNode>::new();
let mut edge_list = Vec::<ModelEdge>::new();

Expand Down Expand Up @@ -591,3 +606,112 @@ pub async fn get_subgraph(

Ok((node_list, edge_list))
}

/// Collapses every `_call` node of an expression wiring graph into a direct bypass edge.
///
/// For each node whose name is `_call`:
/// 1. every outgoing wire whose `index` is `0` is followed with [`to_terminal`]; the node it
///    returns becomes the target of the bypass edge (the last such wire wins),
/// 2. every incoming wire is followed with [`to_primitive`]; the node it returns becomes the
///    source of the bypass edge (the last such wire wins),
/// 3. for each outgoing neighbor of the source that lies on a traced path, an edge
///    source -> target is added, carrying a copy of the weight of the source -> neighbor wire,
/// 4. the `_call` node and all traced path nodes are scheduled for removal.
///
/// The input graph is only read; all additions and removals happen on a copy, which is returned.
/// Nodes without a name are skipped.
///
/// # Panics
/// Panics if a wire leaving a `_call` node has no `index`.
pub fn trim_calls(
    graph: petgraph::Graph<ModelNode, ModelEdge>,
) -> petgraph::Graph<ModelNode, ModelEdge> {
    let mut graph_clone = graph.clone();

    // Every node that has to be deleted once all `_call` nodes have been processed.
    let mut inner_nodes = Vec::<NodeIndex>::new();

    for node_index in graph.node_indices() {
        // Compare by reference; no need to clone the node just to look at its name.
        if graph[node_index].name.as_deref() != Some("_call") {
            continue;
        }

        // Both ends default to the call node itself when no path is found.
        let mut node_start = node_index;
        let mut node_end = node_index;

        // Find the end node: follow the index-0 outgoing wire down to its terminal node.
        for node in graph.neighbors_directed(node_index, Outgoing) {
            let wire = graph.find_edge(node_index, node).unwrap();
            if graph.edge_weight(wire).unwrap().index.unwrap() == 0 {
                // `to_terminal` takes the graph by value, hence the clone.
                let (terminal, mut path) = to_terminal(graph.clone(), node);
                node_end = terminal;
                inner_nodes.append(&mut path);
            }
        }

        // Find the start node: walk the incoming wires up to the first primitive.
        for node in graph.neighbors_directed(node_index, Incoming) {
            // `to_primitive` takes the graph by value, hence the clone.
            let (primitive, mut path) = to_primitive(graph.clone(), node);
            node_start = primitive;
            inner_nodes.append(&mut path);
        }

        // Add the bypass edge from the start node to the end node. Its weight is copied from
        // the wire that leaves the start node towards a node on a traced path.
        for node in graph.neighbors_directed(node_start, Outgoing) {
            for node_p in inner_nodes.iter() {
                if node == *node_p {
                    let weight = graph
                        .edge_weight(graph.find_edge(node_start, node).unwrap())
                        .unwrap()
                        .clone();
                    graph_clone.add_edge(node_start, node_end, weight);
                }
            }
        }

        // The call node itself is removed together with the traced intermediate nodes.
        inner_nodes.push(node_index);
    }

    // `Graph::remove_node` moves the last node into the freed slot, so indices must be removed
    // from highest to lowest, and each index exactly once: removing a duplicate index would
    // delete whichever unrelated node was swapped into that slot.
    inner_nodes.sort();
    inner_nodes.dedup();
    for node in inner_nodes.iter().rev() {
        graph_clone.remove_node(*node);
    }

    graph_clone
}

/// Walks the outgoing wires from `node_index` until a node without outgoing wires is reached.
///
/// Returns `(end_node, path)`:
/// - `end_node` is the terminal node reached last (when a node has several outgoing wires,
///   every branch is walked and the last branch determines `end_node`),
/// - `path` lists every visited node that has at least one outgoing wire, in visiting order.
///
/// If `node_index` itself has no outgoing wires, the result is `(node_index, vec![])`.
///
/// There is no visited-set, so a cycle reachable from `node_index` would recurse without bound.
pub fn to_terminal(
    graph: petgraph::Graph<ModelNode, ModelEdge>,
    node_index: NodeIndex,
) -> (NodeIndex, Vec<NodeIndex>) {
    let mut node_vec = Vec::<NodeIndex>::new();
    let end_node = trace_to_terminal(&graph, node_index, &mut node_vec);
    (end_node, node_vec)
}

/// Recursive worker for [`to_terminal`]: borrows the graph instead of cloning it at every step
/// and appends the traversed non-terminal nodes to `path`.
fn trace_to_terminal(
    graph: &petgraph::Graph<ModelNode, ModelEdge>,
    node_index: NodeIndex,
    path: &mut Vec<NodeIndex>,
) -> NodeIndex {
    let mut end_node = node_index;
    let mut successors = graph
        .neighbors_directed(node_index, Outgoing)
        .peekable();
    // Only nodes that lead somewhere deeper belong to the path; a terminal node is returned
    // as the end node but not recorded.
    if successors.peek().is_some() {
        path.push(node_index);
        for successor in successors {
            end_node = trace_to_terminal(graph, successor, path);
        }
    }
    end_node
}

/// Walks the incoming wires from `node_index` until a node labelled `Primitive` is reached.
///
/// Returns `(end_node, path)`:
/// - `end_node` is the primitive found last (when a node has several incoming wires, every
///   branch is walked and the last branch determines `end_node`); if a branch ends without
///   meeting a primitive, the last node of that branch is returned instead,
/// - `path` lists `node_index` and every non-primitive node visited above it, in visiting order.
///
/// NOTE: assumes the input node is not itself a primitive; its label is never checked.
///
/// There is no visited-set, so a cycle reachable from `node_index` would recurse without bound.
pub fn to_primitive(
    graph: petgraph::Graph<ModelNode, ModelEdge>,
    node_index: NodeIndex,
) -> (NodeIndex, Vec<NodeIndex>) {
    let mut node_vec = Vec::<NodeIndex>::new();
    let end_node = trace_to_primitive(&graph, node_index, &mut node_vec);
    (end_node, node_vec)
}

/// Recursive worker for [`to_primitive`]: borrows the graph instead of cloning it at every step
/// and appends the traversed non-primitive nodes to `path`.
fn trace_to_primitive(
    graph: &petgraph::Graph<ModelNode, ModelEdge>,
    node_index: NodeIndex,
    path: &mut Vec<NodeIndex>,
) -> NodeIndex {
    let mut end_node = node_index;
    path.push(node_index);
    for predecessor in graph.neighbors_directed(node_index, Incoming) {
        end_node = if graph[predecessor].label != *"Primitive" {
            // Not a primitive yet: keep climbing along this branch.
            trace_to_primitive(graph, predecessor, path)
        } else {
            // The primitive ends the walk and is not part of the path.
            predecessor
        };
    }
    end_node
}
Loading