diff --git a/ascent/examples/ascent_agg_clause.rs b/ascent/examples/ascent_agg_clause.rs index 6432fdf..c3b456e 100644 --- a/ascent/examples/ascent_agg_clause.rs +++ b/ascent/examples/ascent_agg_clause.rs @@ -1,14 +1,15 @@ //! Aggregate `agg` clause -use ascent::{aggregators::{count, max, mean, min, sum}, ascent}; +use ascent::aggregators::{count, max, mean, min, sum}; +use ascent::ascent; ascent! { // Facts: relation number(i32); - + // Rules: - + relation lowest(i32); lowest(y) <-- agg y = min(x) in number(x); @@ -31,17 +32,17 @@ ascent! { } fn main() { - let mut prog = AscentProgram::default(); - - prog.number = (1..=5).map(|n| (n,)).collect(); + let mut prog = AscentProgram::default(); + + prog.number = (1..=5).map(|n| (n,)).collect(); - prog.run(); + prog.run(); - let AscentProgram { lowest, greatest, average, total, cardinality, ..} = prog; + let AscentProgram { lowest, greatest, average, total, cardinality, .. } = prog; - assert_eq!(lowest, vec![(1,)]); - assert_eq!(greatest, vec![(5,)]); - assert_eq!(average, vec![(3,)]); - assert_eq!(total, vec![(15,)]); - assert_eq!(cardinality, vec![(5,)]); + assert_eq!(lowest, vec![(1,)]); + assert_eq!(greatest, vec![(5,)]); + assert_eq!(average, vec![(3,)]); + assert_eq!(total, vec![(15,)]); + assert_eq!(cardinality, vec![(5,)]); } diff --git a/ascent/examples/ascent_disjunction_clause.rs b/ascent/examples/ascent_disjunction_clause.rs index e6bdc8a..f71affd 100644 --- a/ascent/examples/ascent_disjunction_clause.rs +++ b/ascent/examples/ascent_disjunction_clause.rs @@ -6,9 +6,9 @@ ascent! { // Facts: relation number(i32); - + // Rules: - + relation square(i32); square(y * y) <-- number(y), number(y * y); @@ -23,23 +23,15 @@ ascent! { } fn main() { - let mut prog = AscentProgram::default(); - - prog.number = (1..=10).map(|n| (n,)).collect(); + let mut prog = AscentProgram::default(); + + prog.number = (1..=10).map(|n| (n,)).collect(); - prog.run(); + prog.run(); - let AscentProgram { mut even_or_square, ..} = prog; + let AscentProgram { mut even_or_square, .. } = prog; - even_or_square.sort_by_key(|(key,)| *key); + even_or_square.sort_by_key(|(key,)| *key); - assert_eq!(even_or_square, vec![ - (1,), - (2,), - (4,), - (6,), - (8,), - (9,), - (10,), - ]); + assert_eq!(even_or_square, vec![(1,), (2,), (4,), (6,), (8,), (9,), (10,),]); } diff --git a/ascent/examples/ascent_for_in_clause.rs b/ascent/examples/ascent_for_in_clause.rs index 6089264..ebe33bb 100644 --- a/ascent/examples/ascent_for_in_clause.rs +++ b/ascent/examples/ascent_for_in_clause.rs @@ -8,32 +8,22 @@ ascent! { relation seed(i32); // Rules: - + relation number(i32); - + number(x + y) <-- seed(x), for y in 0..3; } fn main() { - let mut prog = AscentProgram::default(); - - prog.seed = vec![ - (0,), - (10,), - ]; - - prog.run(); - - let AscentProgram { mut number, ..} = prog; - - number.sort_by_key(|(key,)| *key); - - assert_eq!(number, vec![ - (0,), - (1,), - (2,), - (10,), - (11,), - (12,), - ]); + let mut prog = AscentProgram::default(); + + prog.seed = vec![(0,), (10,)]; + + prog.run(); + + let AscentProgram { mut number, .. } = prog; + + number.sort_by_key(|(key,)| *key); + + assert_eq!(number, vec![(0,), (1,), (2,), (10,), (11,), (12,),]); } diff --git a/ascent/examples/ascent_generic_program.rs b/ascent/examples/ascent_generic_program.rs index 839550a..0c3676e 100644 --- a/ascent/examples/ascent_generic_program.rs +++ b/ascent/examples/ascent_generic_program.rs @@ -11,7 +11,7 @@ pub struct Node(&'static str); ascent! { struct AscentProgram where N: Clone + Eq + Hash; - + // Facts: relation node(N); @@ -26,29 +26,18 @@ ascent! { } fn main() { - let mut prog: AscentProgram = AscentProgram::default(); - - prog.node = vec![ - (Node("A"),), - (Node("B"),), - (Node("C"),), - ]; - - prog.edge = vec![ - (Node("A"), Node("B")), - (Node("B"), Node("C")), - ]; - - prog.run(); - - let AscentProgram { mut reachable, ..} = prog; - - reachable.sort_by_key(|(_, key)| key.0); - reachable.sort_by_key(|(key, _)| key.0); - - assert_eq!(reachable, vec![ - (Node("A"), Node("B")), - (Node("A"), Node("C")), - (Node("B"), Node("C")), - ]); + let mut prog: AscentProgram = AscentProgram::default(); + + prog.node = vec![(Node("A"),), (Node("B"),), (Node("C"),)]; + + prog.edge = vec![(Node("A"), Node("B")), (Node("B"), Node("C"))]; + + prog.run(); + + let AscentProgram { mut reachable, .. } = prog; + + reachable.sort_by_key(|(_, key)| key.0); + reachable.sort_by_key(|(key, _)| key.0); + + assert_eq!(reachable, vec![(Node("A"), Node("B")), (Node("A"), Node("C")), (Node("B"), Node("C")),]); } diff --git a/ascent/examples/ascent_if_clause.rs b/ascent/examples/ascent_if_clause.rs index 9edde71..2be9e4e 100644 --- a/ascent/examples/ascent_if_clause.rs +++ b/ascent/examples/ascent_if_clause.rs @@ -6,39 +6,32 @@ ascent! { // Facts: relation number(isize); - + // Rules: - + relation even(isize); even(x) <-- number(x), if x % 2 == 0; - + relation odd(isize); odd(x) <-- number(x), if x % 2 != 0; } fn main() { - let mut prog = AscentProgram::default(); - - prog.number = (1..=5).map(|n| (n,)).collect(); + let mut prog = AscentProgram::default(); + + prog.number = (1..=5).map(|n| (n,)).collect(); - prog.run(); + prog.run(); - let AscentProgram { mut even, mut odd, ..} = prog; + let AscentProgram { mut even, mut odd, .. } = prog; - even.sort_by_key(|(key,)| *key); + even.sort_by_key(|(key,)| *key); - assert_eq!(even, vec![ - (2,), - (4,), - ]); + assert_eq!(even, vec![(2,), (4,),]); - odd.sort_by_key(|(key,)| *key); + odd.sort_by_key(|(key,)| *key); - assert_eq!(odd, vec![ - (1,), - (3,), - (5,), - ]); + assert_eq!(odd, vec![(1,), (3,), (5,),]); } diff --git a/ascent/examples/ascent_if_let_clause.rs b/ascent/examples/ascent_if_let_clause.rs index fbddf9f..7c6642f 100644 --- a/ascent/examples/ascent_if_let_clause.rs +++ b/ascent/examples/ascent_if_let_clause.rs @@ -6,9 +6,9 @@ ascent! { // Facts: relation option(Option); - + // Rules: - + relation some(isize); some(y) <-- option(x), if let Some(y) = x; @@ -19,24 +19,15 @@ ascent! { } fn main() { - let mut prog = AscentProgram::default(); - - prog.option = vec![ - (None,), - (Some(1),), - (Some(2),), - (Some(3),), - ]; - - prog.run(); - - let AscentProgram { mut some, ..} = prog; - - some.sort_by_key(|(key,)| *key); - - assert_eq!(some, vec![ - (1,), - (2,), - (3,), - ]); + let mut prog = AscentProgram::default(); + + prog.option = vec![(None,), (Some(1),), (Some(2),), (Some(3),)]; + + prog.run(); + + let AscentProgram { mut some, .. } = prog; + + some.sort_by_key(|(key,)| *key); + + assert_eq!(some, vec![(1,), (2,), (3,),]); } diff --git a/ascent/examples/ascent_lattice.rs b/ascent/examples/ascent_lattice.rs index f690de4..566476a 100644 --- a/ascent/examples/ascent_lattice.rs +++ b/ascent/examples/ascent_lattice.rs @@ -1,12 +1,12 @@ //! Aggregate `agg` clause -use ascent::{ascent, Dual}; +use ascent::{Dual, ascent}; pub type Node = &'static str; ascent! { // Facts: - + relation edge(Node, Node, u32); // Rules: @@ -15,35 +15,29 @@ ascent! { shortest_path(x, y, Dual(*w)) <-- edge(x, y, w); - shortest_path(x, z, Dual(w + l)) <-- - edge(x, y, w), + shortest_path(x, z, Dual(w + l)) <-- + edge(x, y, w), shortest_path(y, z, ?Dual(l)); } fn main() { - let mut prog = AscentProgram::default(); - - prog.edge = vec![ - ("A", "B", 1), - ("A", "D", 4), - ("B", "C", 1), - ("B", "D", 1), - ("C", "D", 2), - ]; - - prog.run(); - - let AscentProgram { mut shortest_path, ..} = prog; - - shortest_path.sort_by_key(|(_, key, _)| *key); - shortest_path.sort_by_key(|(key, _, _)| *key); - - assert_eq!(shortest_path, vec![ - ("A", "B", Dual(1)), - ("A", "C", Dual(2)), - ("A", "D", Dual(2)), - ("B", "C", Dual(1)), - ("B", "D", Dual(1)), - ("C", "D", Dual(2)), - ]); + let mut prog = AscentProgram::default(); + + prog.edge = vec![("A", "B", 1), ("A", "D", 4), ("B", "C", 1), ("B", "D", 1), ("C", "D", 2)]; + + prog.run(); + + let AscentProgram { mut shortest_path, .. } = prog; + + shortest_path.sort_by_key(|(_, key, _)| *key); + shortest_path.sort_by_key(|(key, _, _)| *key); + + assert_eq!(shortest_path, vec![ + ("A", "B", Dual(1)), + ("A", "C", Dual(2)), + ("A", "D", Dual(2)), + ("B", "C", Dual(1)), + ("B", "D", Dual(1)), + ("C", "D", Dual(2)), + ]); } diff --git a/ascent/examples/ascent_let_clause.rs b/ascent/examples/ascent_let_clause.rs index 9f7d342..9acd138 100644 --- a/ascent/examples/ascent_let_clause.rs +++ b/ascent/examples/ascent_let_clause.rs @@ -4,22 +4,22 @@ use ascent::ascent; #[derive(Clone, Eq, PartialEq, Hash, Debug)] pub enum List { - Nil, - Cons(usize, Box), + Nil, + Cons(usize, Box), } impl List { - fn as_vec(&self) -> Vec { - let mut items = vec![]; + fn as_vec(&self) -> Vec { + let mut items = vec![]; - let mut list = self; - while let Self::Cons(head, tail) = list { - items.push(*head); - list = tail; - } + let mut list = self; + while let Self::Cons(head, tail) = list { + items.push(*head); + list = tail; + } - items - } + items + } } ascent! { @@ -32,24 +32,22 @@ ascent! { } fn main() { - let mut prog = AscentProgram::default(); + let mut prog = AscentProgram::default(); - prog.run(); + prog.run(); - let AscentProgram { mut list, ..} = prog; + let AscentProgram { mut list, .. } = prog; - list.sort_by_key(|(_, key)| *key); + list.sort_by_key(|(_, key)| *key); - let lists: Vec<_> = list.into_iter().map(|(list, len)| { - (list.as_vec(), len) - }).collect(); + let lists: Vec<_> = list.into_iter().map(|(list, len)| (list.as_vec(), len)).collect(); - assert_eq!(lists, vec![ - (vec![], 0), - (vec![0], 1), - (vec![1, 0], 2), - (vec![2, 1, 0], 3), - (vec![3, 2, 1, 0], 4), - (vec![4, 3, 2, 1, 0], 5), - ]); + assert_eq!(lists, vec![ + (vec![], 0), + (vec![0], 1), + (vec![1, 0], 2), + (vec![2, 1, 0], 3), + (vec![3, 2, 1, 0], 4), + (vec![4, 3, 2, 1, 0], 5), + ]); } diff --git a/ascent/examples/ascent_macros_rule.rs b/ascent/examples/ascent_macros_rule.rs index 0645f1e..d133bca 100644 --- a/ascent/examples/ascent_macros_rule.rs +++ b/ascent/examples/ascent_macros_rule.rs @@ -8,7 +8,7 @@ ascent! { // Facts: relation unique(isize); - + // Macros: macro shared($x: expr) { @@ -16,34 +16,22 @@ ascent! { } // Rules: - + relation shared(Rc); shared!(*x) <-- unique(x); } fn main() { - let mut prog = AscentProgram::default(); - - prog.unique = vec![ - (1,), - (2,), - (3,), - (4,), - (5,), - ]; - - prog.run(); - - let AscentProgram { mut shared, ..} = prog; - - shared.sort_by_key(|(key,)| Rc::clone(key)); - - assert_eq!(shared, vec![ - (Rc::new(1),), - (Rc::new(2),), - (Rc::new(3),), - (Rc::new(4),), - (Rc::new(5),), - ]); + let mut prog = AscentProgram::default(); + + prog.unique = vec![(1,), (2,), (3,), (4,), (5,)]; + + prog.run(); + + let AscentProgram { mut shared, .. } = prog; + + shared.sort_by_key(|(key,)| Rc::clone(key)); + + assert_eq!(shared, vec![(Rc::new(1),), (Rc::new(2),), (Rc::new(3),), (Rc::new(4),), (Rc::new(5),),]); } diff --git a/ascent/examples/ascent_negation_clause.rs b/ascent/examples/ascent_negation_clause.rs index 0c4d5a2..fdea8b7 100644 --- a/ascent/examples/ascent_negation_clause.rs +++ b/ascent/examples/ascent_negation_clause.rs @@ -6,9 +6,9 @@ ascent! { // Facts: relation number(i32); - + // Rules: - + relation even(i32); even(x) <-- number(x), if x % 2 == 0; @@ -19,26 +19,19 @@ ascent! { } fn main() { - let mut prog = AscentProgram::default(); - - prog.number = (1..=5).map(|n| (n,)).collect(); + let mut prog = AscentProgram::default(); + + prog.number = (1..=5).map(|n| (n,)).collect(); - prog.run(); + prog.run(); - let AscentProgram { mut even, mut odd, ..} = prog; + let AscentProgram { mut even, mut odd, .. } = prog; - even.sort_by_key(|(key,)| *key); + even.sort_by_key(|(key,)| *key); - assert_eq!(even, vec![ - (2,), - (4,), - ]); + assert_eq!(even, vec![(2,), (4,),]); - odd.sort_by_key(|(key,)| *key); + odd.sort_by_key(|(key,)| *key); - assert_eq!(odd, vec![ - (1,), - (3,), - (5,), - ]); + assert_eq!(odd, vec![(1,), (3,), (5,),]); } diff --git a/ascent/examples/context_sensitive_flow_graph.rs b/ascent/examples/context_sensitive_flow_graph.rs index 216f6c0..457f7de 100644 --- a/ascent/examples/context_sensitive_flow_graph.rs +++ b/ascent/examples/context_sensitive_flow_graph.rs @@ -23,8 +23,8 @@ pub struct Context(&'static str); #[derive(Clone, Eq, PartialEq, Hash, Debug)] pub enum Res { - Ok, - Err, + Ok, + Err, } ascent! { @@ -33,36 +33,33 @@ ascent! { relation succ(Instr, Context, Instr, Context); // Rules: - + relation flow(Instr, Context, Instr, Context); - + flow(i1, c1, i2, c2) <-- succ(i1, c1, i2, c2); flow(i1, c1, i3, c3) <-- flow(i1, c1, i2, c2), flow(i2, c2, i3, c3); relation res(Res); - + res(Res::Ok) <-- flow(Instr("w1"), Context("c1"), Instr("r2"), Context("c1")); res(Res::Err) <-- flow(Instr("w1"), Context("c1"), Instr("r2"), Context("c2")); } fn main() { - let mut prog = AscentProgram::default(); - - prog.succ = vec![ - (Instr("w1"), Context("c1"), Instr("w2"), Context("c1")), - (Instr("w2"), Context("c1"), Instr("r1"), Context("c1")), - (Instr("r1"), Context("c1"), Instr("r2"), Context("c1")), + let mut prog = AscentProgram::default(); - (Instr("w1"), Context("c2"), Instr("w2"), Context("c2")), - (Instr("w2"), Context("c2"), Instr("r1"), Context("c2")), - (Instr("r1"), Context("c2"), Instr("r2"), Context("c2")), - ]; + prog.succ = vec![ + (Instr("w1"), Context("c1"), Instr("w2"), Context("c1")), + (Instr("w2"), Context("c1"), Instr("r1"), Context("c1")), + (Instr("r1"), Context("c1"), Instr("r2"), Context("c1")), + (Instr("w1"), Context("c2"), Instr("w2"), Context("c2")), + (Instr("w2"), Context("c2"), Instr("r1"), Context("c2")), + (Instr("r1"), Context("c2"), Instr("r2"), Context("c2")), + ]; - prog.run(); + prog.run(); - let AscentProgram { res, ..} = prog; + let AscentProgram { res, .. } = prog; - assert_eq!(res, vec![ - (Res::Ok,), - ]); + assert_eq!(res, vec![(Res::Ok,),]); } diff --git a/ascent/examples/context_sensitive_flow_graph_with_records.rs b/ascent/examples/context_sensitive_flow_graph_with_records.rs index bbe834e..9f43625 100644 --- a/ascent/examples/context_sensitive_flow_graph_with_records.rs +++ b/ascent/examples/context_sensitive_flow_graph_with_records.rs @@ -24,8 +24,8 @@ pub struct ProgPoint(Instr, Context); #[derive(Clone, Eq, PartialEq, Hash, Debug)] pub enum Res { - Ok, - Err, + Ok, + Err, } ascent! { @@ -34,36 +34,33 @@ ascent! { relation succ(ProgPoint, ProgPoint); // Rules: - + relation flow(ProgPoint, ProgPoint); - + flow(p1, p2) <-- succ(p1, p2); flow(p1, p3) <-- flow(p1, p2), flow(p2, p3); relation res(Res); - + res(Res::Ok) <-- flow(ProgPoint(Instr("w1"), Context("c1")), ProgPoint(Instr("r2"), Context("c1"))); res(Res::Err) <-- flow(ProgPoint(Instr("w1"), Context("c1")), ProgPoint(Instr("r2"), Context("c2"))); } fn main() { - let mut prog = AscentProgram::default(); - - prog.succ = vec![ - (ProgPoint(Instr("w1"), Context("c1")), ProgPoint(Instr("w2"), Context("c1"))), - (ProgPoint(Instr("w2"), Context("c1")), ProgPoint(Instr("r1"), Context("c1"))), - (ProgPoint(Instr("r1"), Context("c1")), ProgPoint(Instr("r2"), Context("c1"))), + let mut prog = AscentProgram::default(); - (ProgPoint(Instr("w1"), Context("c2")), ProgPoint(Instr("w2"), Context("c2"))), - (ProgPoint(Instr("w2"), Context("c2")), ProgPoint(Instr("r1"), Context("c2"))), - (ProgPoint(Instr("r1"), Context("c2")), ProgPoint(Instr("r2"), Context("c2"))), - ]; + prog.succ = vec![ + (ProgPoint(Instr("w1"), Context("c1")), ProgPoint(Instr("w2"), Context("c1"))), + (ProgPoint(Instr("w2"), Context("c1")), ProgPoint(Instr("r1"), Context("c1"))), + (ProgPoint(Instr("r1"), Context("c1")), ProgPoint(Instr("r2"), Context("c1"))), + (ProgPoint(Instr("w1"), Context("c2")), ProgPoint(Instr("w2"), Context("c2"))), + (ProgPoint(Instr("w2"), Context("c2")), ProgPoint(Instr("r1"), Context("c2"))), + (ProgPoint(Instr("r1"), Context("c2")), ProgPoint(Instr("r2"), Context("c2"))), + ]; - prog.run(); + prog.run(); - let AscentProgram { res, ..} = prog; + let AscentProgram { res, .. } = prog; - assert_eq!(res, vec![ - (Res::Ok,), - ]); + assert_eq!(res, vec![(Res::Ok,),]); } diff --git a/ascent/examples/def_use_chains.rs b/ascent/examples/def_use_chains.rs index 058d2c7..c5d320e 100644 --- a/ascent/examples/def_use_chains.rs +++ b/ascent/examples/def_use_chains.rs @@ -27,9 +27,9 @@ pub struct Jump(&'static str); #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] pub enum Instr { - Read(Read), - Write(Write), - Jump(Jump), + Read(Read), + Write(Write), + Jump(Jump), } ascent! { @@ -52,37 +52,26 @@ ascent! { } fn main() { - let mut prog = AscentProgram::default(); - - prog.read = vec![ - (Read("r1"), Var("v1")), - (Read("r2"), Var("v1")), - (Read("r3"), Var("v2")), - ]; - - prog.write = vec![ - (Write("w1"), Var("v1")), - (Write("w2"), Var("v2")), - (Write("w3"), Var("v2")), - ]; - - prog.succ = vec![ - (Instr::Write(Write("w1")), Instr::Jump(Jump("o1"))), - (Instr::Jump(Jump("o1")), Instr::Read(Read("r1"))), - (Instr::Jump(Jump("o1")), Instr::Read(Read("r2"))), - (Instr::Read(Read("r2")), Instr::Read(Read("r3"))), - (Instr::Read(Read("r3")), Instr::Write(Write("w2"))), - ]; - - prog.run(); - - let AscentProgram { mut def_use, ..} = prog; - - def_use.sort_by_key(|(key, _)| key.0); - def_use.sort_by_key(|(_, key)| key.0); - - assert_eq!(def_use, vec![ - (Write("w1"), Read("r1")), - (Write("w1"), Read("r2")), - ]); + let mut prog = AscentProgram::default(); + + prog.read = vec![(Read("r1"), Var("v1")), (Read("r2"), Var("v1")), (Read("r3"), Var("v2"))]; + + prog.write = vec![(Write("w1"), Var("v1")), (Write("w2"), Var("v2")), (Write("w3"), Var("v2"))]; + + prog.succ = vec![ + (Instr::Write(Write("w1")), Instr::Jump(Jump("o1"))), + (Instr::Jump(Jump("o1")), Instr::Read(Read("r1"))), + (Instr::Jump(Jump("o1")), Instr::Read(Read("r2"))), + (Instr::Read(Read("r2")), Instr::Read(Read("r3"))), + (Instr::Read(Read("r3")), Instr::Write(Write("w2"))), + ]; + + prog.run(); + + let AscentProgram { mut def_use, .. } = prog; + + def_use.sort_by_key(|(key, _)| key.0); + def_use.sort_by_key(|(_, key)| key.0); + + assert_eq!(def_use, vec![(Write("w1"), Read("r1")), (Write("w1"), Read("r2")),]); } diff --git a/ascent/examples/fibonacci.rs b/ascent/examples/fibonacci.rs index 05097b8..f710199 100644 --- a/ascent/examples/fibonacci.rs +++ b/ascent/examples/fibonacci.rs @@ -6,33 +6,26 @@ ascent! { // Facts: relation number(isize); - + // Rules: relation fib(isize, isize); - + fib(0, 1) <-- number(0); fib(1, 1) <-- number(1); fib(x, y + z) <-- number(x), if *x >= 2, fib(x - 1, y), fib(x - 2, z); } fn main() { - let mut prog = AscentProgram::default(); - - prog.number = (0..6).map(|n| (n,)).collect(); + let mut prog = AscentProgram::default(); + + prog.number = (0..6).map(|n| (n,)).collect(); - prog.run(); + prog.run(); - let AscentProgram { mut fib, ..} = prog; + let AscentProgram { mut fib, .. } = prog; - fib.sort_by_key(|(key, _)| *key); + fib.sort_by_key(|(key, _)| *key); - assert_eq!(fib, vec![ - (0, 1), - (1, 1), - (2, 2), - (3, 3), - (4, 5), - (5, 8), - ]); + assert_eq!(fib, vec![(0, 1), (1, 1), (2, 2), (3, 3), (4, 5), (5, 8),]); } diff --git a/ascent/examples/fizz_buzz.rs b/ascent/examples/fizz_buzz.rs index fb9e2c2..cb6b21b 100644 --- a/ascent/examples/fizz_buzz.rs +++ b/ascent/examples/fizz_buzz.rs @@ -6,9 +6,9 @@ ascent! { // Facts: relation number(isize); - + // Rules: - + relation divisible(isize, isize); divisible(x, 3) <-- number(x), if x % 3 == 0; @@ -32,46 +32,27 @@ ascent! { } fn main() { - let mut prog = AscentProgram::default(); - - prog.number = (1..=15).map(|n| (n,)).collect(); + let mut prog = AscentProgram::default(); + + prog.number = (1..=15).map(|n| (n,)).collect(); - prog.run(); + prog.run(); - let AscentProgram { mut fizz, mut buzz, mut fizz_buzz, mut other, ..} = prog; + let AscentProgram { mut fizz, mut buzz, mut fizz_buzz, mut other, .. } = prog; - fizz.sort_by_key(|(key,)| *key); + fizz.sort_by_key(|(key,)| *key); - assert_eq!(fizz, vec![ - (3,), - (6,), - (9,), - (12,), - ]); + assert_eq!(fizz, vec![(3,), (6,), (9,), (12,),]); - buzz.sort_by_key(|(key,)| *key); + buzz.sort_by_key(|(key,)| *key); - assert_eq!(buzz, vec![ - (5,), - (10,), - ]); + assert_eq!(buzz, vec![(5,), (10,),]); - fizz_buzz.sort_by_key(|(key,)| *key); + fizz_buzz.sort_by_key(|(key,)| *key); - assert_eq!(fizz_buzz, vec![ - (15,), - ]); + assert_eq!(fizz_buzz, vec![(15,),]); - other.sort_by_key(|(key,)| *key); + other.sort_by_key(|(key,)| *key); - assert_eq!(other, vec![ - (1,), - (2,), - (4,), - (7,), - (8,), - (11,), - (13,), - (14,) - ]); + assert_eq!(other, vec![(1,), (2,), (4,), (7,), (8,), (11,), (13,), (14,)]); } diff --git a/ascent/examples/lists_using_recursive_enums.rs b/ascent/examples/lists_using_recursive_enums.rs index 93446b7..4143486 100644 --- a/ascent/examples/lists_using_recursive_enums.rs +++ b/ascent/examples/lists_using_recursive_enums.rs @@ -20,20 +20,20 @@ use ascent::ascent; #[derive(Clone, Eq, PartialEq, Hash, Debug)] pub enum List { - Cons(T, Rc>), - Nil + Cons(T, Rc>), + Nil, } macro_rules! cons { - ($h: expr, $t: expr) => { + ($h: expr, $t: expr) => { Rc::new(List::Cons($h, $t)) - } + }; } macro_rules! nil { - () => { + () => { Rc::new(List::Nil) - } + }; } #[derive(Clone, Eq, PartialEq, Hash, Debug)] @@ -43,19 +43,19 @@ ascent! { // Facts: relation char(char); - + // Rules: - + relation list(Rc>); - + list(nil!()); list(cons!(c.clone(), l.clone())) <-- char(c), list(l), len(l, n), if *n < 5; relation len(Rc>, usize); - + len(nil!(), 0); len(l.clone(), n + 1) <-- char(c), len(r, n), let l = cons!(c.clone(), r.clone()), list(&l); - + relation res(Res); res(Res("-")) <-- list(nil!()); @@ -68,24 +68,15 @@ ascent! { } fn main() { - let mut prog = AscentProgram::default(); - - prog.char = vec![ - ('a',), - ('b',), - ]; - - prog.run(); - - let AscentProgram { mut res, ..} = prog; - - res.sort_by_key(|(key,)| key.0); - - assert_eq!(res, vec![ - (Res("-"),), - (Res("a"),), - (Res("ab"),), - (Res("aba"),), - (Res("b"),), - ]); + let mut prog = AscentProgram::default(); + + prog.char = vec![('a',), ('b',)]; + + prog.run(); + + let AscentProgram { mut res, .. } = prog; + + res.sort_by_key(|(key,)| key.0); + + assert_eq!(res, vec![(Res("-"),), (Res("a"),), (Res("ab"),), (Res("aba"),), (Res("b"),),]); } diff --git a/ascent/examples/transitive_graph_closure.rs b/ascent/examples/transitive_graph_closure.rs index dc0461e..d64bf98 100644 --- a/ascent/examples/transitive_graph_closure.rs +++ b/ascent/examples/transitive_graph_closure.rs @@ -24,7 +24,7 @@ ascent! { // or a non-linear rule: // reachable(x, y) <-- edge(x, y); // reachable(x, z) <-- reachable(x, y), reachable(y, z); - + // While both variants are semantically equivalent the linear rule // tends to be more performant since in the non-linear variant the fact `reachable(x, y)` // is redundantly discovered again at every iteration (and thus n − 1 times). @@ -35,36 +35,22 @@ ascent! { } fn main() { - let mut prog = AscentProgram::default(); - - prog.node = vec![ - (Node("A"),), - (Node("B"),), - (Node("C"),), - ]; - - prog.edge = vec![ - (Node("A"), Node("B")), - (Node("B"), Node("C")), - ]; - - prog.run(); - - let AscentProgram { mut reachable, mut closure_of_a, .. } = prog; - - reachable.sort_by_key(|(key, _)| key.0); - reachable.sort_by_key(|(_, key)| key.0); - - assert_eq!(reachable, vec![ - (Node("A"), Node("B")), - (Node("A"), Node("C")), - (Node("B"), Node("C")), - ]); - - closure_of_a.sort_by_key(|(key,)| key.0); - - assert_eq!(closure_of_a, vec![ - (Node("B"),), - (Node("C"),), - ]); + let mut prog = AscentProgram::default(); + + prog.node = vec![(Node("A"),), (Node("B"),), (Node("C"),)]; + + prog.edge = vec![(Node("A"), Node("B")), (Node("B"), Node("C"))]; + + prog.run(); + + let AscentProgram { mut reachable, mut closure_of_a, .. } = prog; + + reachable.sort_by_key(|(key, _)| key.0); + reachable.sort_by_key(|(_, key)| key.0); + + assert_eq!(reachable, vec![(Node("A"), Node("B")), (Node("A"), Node("C")), (Node("B"), Node("C")),]); + + closure_of_a.sort_by_key(|(key,)| key.0); + + assert_eq!(closure_of_a, vec![(Node("B"),), (Node("C"),),]); } diff --git a/ascent/examples/var_points_to.rs b/ascent/examples/var_points_to.rs index 202382a..2f07b01 100644 --- a/ascent/examples/var_points_to.rs +++ b/ascent/examples/var_points_to.rs @@ -65,48 +65,38 @@ ascent! { } fn main() { - let mut prog = AscentProgram::default(); - - prog.assign = vec![ - (Var("v1"), Var("v2")) - ]; - - prog.new = vec![ - (Var("v1"), Obj("h1")), - (Var("v2"), Obj("h2")), - (Var("v3"), Obj("h3")), - ]; - - prog.st = vec![ - (Var("v1"), Field("f"), Var("v3")), - ]; - - prog.ld = vec![ - (Var("v4"), Var("v1"), Field("f")), - ]; - - prog.run(); - - let AscentProgram { mut alias, mut points_to, .. } = prog; - - alias.sort_by_key(|(_, key)| key.0); - alias.sort_by_key(|(key, _)| key.0); - - assert_eq!(alias, vec![ - (Var("v1"), Var("v1")), - (Var("v1"), Var("v2")), - (Var("v2"), Var("v2")), - (Var("v4"), Var("v3")), - ]); - - points_to.sort_by_key(|(_, key)| key.0); - points_to.sort_by_key(|(key, _)| key.0); - - assert_eq!(points_to, vec![ - (Var("v1"), Obj("h1")), - (Var("v1"), Obj("h2")), - (Var("v2"), Obj("h2")), - (Var("v3"), Obj("h3")), - (Var("v4"), Obj("h3")), - ]); + let mut prog = AscentProgram::default(); + + prog.assign = vec![(Var("v1"), Var("v2"))]; + + prog.new = vec![(Var("v1"), Obj("h1")), (Var("v2"), Obj("h2")), (Var("v3"), Obj("h3"))]; + + prog.st = vec![(Var("v1"), Field("f"), Var("v3"))]; + + prog.ld = vec![(Var("v4"), Var("v1"), Field("f"))]; + + prog.run(); + + let AscentProgram { mut alias, mut points_to, .. } = prog; + + alias.sort_by_key(|(_, key)| key.0); + alias.sort_by_key(|(key, _)| key.0); + + assert_eq!(alias, vec![ + (Var("v1"), Var("v1")), + (Var("v1"), Var("v2")), + (Var("v2"), Var("v2")), + (Var("v4"), Var("v3")), + ]); + + points_to.sort_by_key(|(_, key)| key.0); + points_to.sort_by_key(|(key, _)| key.0); + + assert_eq!(points_to, vec![ + (Var("v1"), Obj("h1")), + (Var("v1"), Obj("h2")), + (Var("v2"), Obj("h2")), + (Var("v3"), Obj("h3")), + (Var("v4"), Obj("h3")), + ]); } diff --git a/ascent/src/aggregators.rs b/ascent/src/aggregators.rs index e81652f..2259965 100644 --- a/ascent/src/aggregators.rs +++ b/ascent/src/aggregators.rs @@ -1,108 +1,93 @@ -//! This module provides aggregators that can be used in Ascent rules. -//! -//! eg: Writing the average of `foo` in `bar` -//! ``` -//! use ascent::ascent; -//! use ascent::aggregators::*; -//! ascent!{ -//! relation foo(i32); -//! relation bar(i32); -//! // ... -//! bar(m as i32) <-- agg m = mean(x) in foo(x); -//! } -//! ``` - -use std::ops::Add; -use std::iter::Sum; - -/// computes the minimum of the input column -pub fn min<'a, N: 'a>(inp: impl Iterator) -> impl Iterator -where N: Ord + Clone -{ - inp.map(|tuple| tuple.0).min().cloned().into_iter() -} - -/// computes the maximum of the input column -pub fn max<'a, N: 'a>(inp: impl Iterator) -> impl Iterator -where N: Ord + Clone -{ - inp.map(|tuple| tuple.0).max().cloned().into_iter() -} - -/// computes the sum of the input column -pub fn sum<'a, N: 'a>(inp: impl Iterator) -> impl Iterator -where N: Ord + Add + Clone + Sum -{ - let sum = inp.map(|tuple| tuple.0).cloned().sum::(); - std::iter::once(sum) -} - -/// returns the number of tuples -/// -/// # Examples -/// -/// ``` -/// # use ascent::ascent_run; -/// # use ascent::aggregators::count; -/// let res = ascent_run!{ -/// relation edge(u32, u32); -/// relation path(u32, u32); -/// relation num_paths(usize); -/// path(a, b) <-- edge(a, b); -/// path(a, c) <-- path(a, b), edge(b, c); -/// -/// edge(1, 2); -/// edge(2, 3); -/// edge(3, 4); -/// -/// num_paths(n) <-- agg n = count() in path(_, _); -/// }; -/// // This example program is expected to produce 6 paths. -/// assert_eq!(res.num_paths[0].0, 6); -///``` -pub fn count(inp: impl Iterator) -> impl Iterator -{ - let (size_floor, size_ceiling)= inp.size_hint(); - let size_ceiling = size_ceiling.unwrap_or(usize::MAX); - let count = if size_floor == size_ceiling { - size_floor - } else { - inp.count() - }; - std::iter::once(count) -} - -/// computes the average of the input column, returning an `f64` -pub fn mean<'a, N: 'a>(inp: impl Iterator) -> impl Iterator -where N: Clone + Into -{ - let (sum, count) = inp.fold((0.0, 0usize), |(sum, count), tuple| (tuple.0.clone().into() + sum, count + 1)); - let res = if count == 0 {None} else { - Some(sum / count as f64) - }; - res.into_iter() -} - -/// computes the value at the given percentile of the input column -pub fn percentile<'a, TItem: 'a, TInputIter>(p: f64) -> impl Fn(TInputIter) -> std::option::IntoIter -where - TInputIter: Iterator, TItem: Ord + Clone, -{ - move |inp| { - let mut sorted: Vec<_> = inp.map(|tuple| tuple.0.clone()).collect(); - sorted.sort(); - let p_index = (sorted.len() as f64 * p / 100.0) as usize; - if !sorted.is_empty() { - Some(sorted.swap_remove(p_index)) - } else { - None - }.into_iter() - } -} - -/// backs negations (eg `!foo(x)`) in `ascent` -pub fn not(mut inp: impl Iterator) -> impl Iterator -{ - let any = inp.next().is_some(); - if any {None} else {Some(())}.into_iter() -} \ No newline at end of file +//! This module provides aggregators that can be used in Ascent rules. +//! +//! eg: Writing the average of `foo` in `bar` +//! ``` +//! use ascent::ascent; +//! use ascent::aggregators::*; +//! ascent!{ +//! relation foo(i32); +//! relation bar(i32); +//! // ... +//! bar(m as i32) <-- agg m = mean(x) in foo(x); +//! } +//! ``` + +use std::iter::Sum; +use std::ops::Add; + +/// computes the minimum of the input column +pub fn min<'a, N: 'a>(inp: impl Iterator) -> impl Iterator +where N: Ord + Clone { + inp.map(|tuple| tuple.0).min().cloned().into_iter() +} + +/// computes the maximum of the input column +pub fn max<'a, N: 'a>(inp: impl Iterator) -> impl Iterator +where N: Ord + Clone { + inp.map(|tuple| tuple.0).max().cloned().into_iter() +} + +/// computes the sum of the input column +pub fn sum<'a, N: 'a>(inp: impl Iterator) -> impl Iterator +where N: Ord + Add + Clone + Sum { + let sum = inp.map(|tuple| tuple.0).cloned().sum::(); + std::iter::once(sum) +} + +/// returns the number of tuples +/// +/// # Examples +/// +/// ``` +/// # use ascent::ascent_run; +/// # use ascent::aggregators::count; +/// let res = ascent_run!{ +/// relation edge(u32, u32); +/// relation path(u32, u32); +/// relation num_paths(usize); +/// path(a, b) <-- edge(a, b); +/// path(a, c) <-- path(a, b), edge(b, c); +/// +/// edge(1, 2); +/// edge(2, 3); +/// edge(3, 4); +/// +/// num_paths(n) <-- agg n = count() in path(_, _); +/// }; +/// // This example program is expected to produce 6 paths. +/// assert_eq!(res.num_paths[0].0, 6); +///``` +pub fn count(inp: impl Iterator) -> impl Iterator { + let (size_floor, size_ceiling) = inp.size_hint(); + let size_ceiling = size_ceiling.unwrap_or(usize::MAX); + let count = if size_floor == size_ceiling { size_floor } else { inp.count() }; + std::iter::once(count) +} + +/// computes the average of the input column, returning an `f64` +pub fn mean<'a, N: 'a>(inp: impl Iterator) -> impl Iterator +where N: Clone + Into { + let (sum, count) = inp.fold((0.0, 0usize), |(sum, count), tuple| (tuple.0.clone().into() + sum, count + 1)); + let res = if count == 0 { None } else { Some(sum / count as f64) }; + res.into_iter() +} + +/// computes the value at the given percentile of the input column +pub fn percentile<'a, TItem: 'a, TInputIter>(p: f64) -> impl Fn(TInputIter) -> std::option::IntoIter +where + TInputIter: Iterator, + TItem: Ord + Clone, +{ + move |inp| { + let mut sorted: Vec<_> = inp.map(|tuple| tuple.0.clone()).collect(); + sorted.sort(); + let p_index = (sorted.len() as f64 * p / 100.0) as usize; + if !sorted.is_empty() { Some(sorted.swap_remove(p_index)) } else { None }.into_iter() + } +} + +/// backs negations (eg `!foo(x)`) in `ascent` +pub fn not(mut inp: impl Iterator) -> impl Iterator { + let any = inp.next().is_some(); + if any { None } else { Some(()) }.into_iter() +} diff --git a/ascent/src/c_lat_index.rs b/ascent/src/c_lat_index.rs index 14d719c..87cfbef 100644 --- a/ascent/src/c_lat_index.rs +++ b/ascent/src/c_lat_index.rs @@ -1,226 +1,234 @@ -use ascent_base::util::update; -use dashmap::{DashMap, SharedValue}; -use instant::Instant; -use rustc_hash::FxHasher; -use std::collections::HashSet; -use std::hash::{Hash, BuildHasherDefault}; - -use crate::c_rel_index::{shards_count, DashMapViewParIter}; -use crate::internal::{RelIndexWrite, CRelIndexWrite, RelIndexMerge, Freezable}; -use crate::internal::{RelIndexRead, RelIndexReadAll, CRelIndexRead, CRelIndexReadAll}; - -type SetType = HashSet; -pub enum CLatIndex { - Unfrozen(DashMap, BuildHasherDefault>), - Frozen(dashmap::ReadOnlyView, BuildHasherDefault>) -} - -impl Freezable for CLatIndex { - fn freeze(&mut self) { - update(self, |_self| match _self { - CLatIndex::Unfrozen(dm) => Self::Frozen(dm.into_read_only()), - CLatIndex::Frozen(_) => _self, - }) - } - - fn unfreeze(&mut self) { - update(self, |_self| match _self { - CLatIndex::Frozen(v) => Self::Unfrozen(v.into_inner()), - CLatIndex::Unfrozen(_) => _self, - }) - } -} - -impl CLatIndex { - - #[inline] - pub fn unwrap_frozen(&self) -> &dashmap::ReadOnlyView, BuildHasherDefault> { - match self { - CLatIndex::Frozen(v) => v, - CLatIndex::Unfrozen(_) => panic!("CRelIndex::unwrap_frozen(): object is Unfrozen"), - } - } - - #[inline] - pub fn unwrap_unfrozen(&self) -> &DashMap, BuildHasherDefault> { - match self { - CLatIndex::Unfrozen(dm) => dm, - CLatIndex::Frozen(_) => panic!("CRelIndex::unwrap_unfrozen(): object is Frozen"), - } - } - - #[inline] - pub fn unwrap_mut_unfrozen(&mut self) -> &mut DashMap, BuildHasherDefault> { - match self { - CLatIndex::Unfrozen(dm) => dm, - CLatIndex::Frozen(_) => panic!("CRelIndex::unwrap_unfrozen(): object is Frozen"), - } - } - - pub fn into_read_only(self) -> dashmap::ReadOnlyView, BuildHasherDefault> { - match self { - CLatIndex::Unfrozen(dm) => dm.into_read_only(), - CLatIndex::Frozen(f) => f, - } - } - - #[inline] - fn insert(&self, key: K, value: V) { - match self.unwrap_unfrozen().entry(key) { - dashmap::mapref::entry::Entry::Occupied(mut occ) => {occ.get_mut().insert(value);}, - dashmap::mapref::entry::Entry::Vacant(vac) => { - let mut set = SetType::default(); - set.insert(value); - vac.insert(set); - }, - } - } - - #[inline] - pub fn hash_usize(&self, k: &K) -> usize { - self.unwrap_unfrozen().hash_usize(k) - } -} - -impl Default for CLatIndex { - fn default() -> Self { - // Self::Unfrozen(Default::default()) - Self::Unfrozen(DashMap::with_hasher_and_shard_amount(Default::default(), shards_count())) - } -} - -impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Clone + Hash + Eq> RelIndexRead<'a> for CLatIndex { - type Key = K; - type Value = &'a V; - - type IteratorType = std::collections::hash_set::Iter<'a, V>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let vals = &self.unwrap_frozen().get(key)?; - let res = vals.iter(); - Some(res) - } - - fn len(&self) -> usize { - self.unwrap_frozen().len() - } -} - -impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Clone + Hash + Eq + Sync> CRelIndexRead<'a> for CLatIndex { - type Key = K; - type Value = &'a V; - - type IteratorType = rayon::collections::hash_set::Iter<'a, V>; - - fn c_index_get(&'a self, key: &Self::Key) -> Option { - use rayon::prelude::*; - let vals = &self.unwrap_frozen().get(key)?; - let res = vals.par_iter(); - Some(res) - } - -} - -impl<'a, K: 'a + Clone + Hash + Eq + Send + Sync, V: 'a + Clone + Hash + Eq + Send + Sync> RelIndexWrite for CLatIndex { - type Key = K; - type Value = V; - - fn index_insert(&mut self, key: Self::Key, value: Self::Value) { - let dm = self.unwrap_mut_unfrozen(); - // let shard = dm.determine_map(&key); - // let entry = dm.shards_mut()[shard].get_mut().entry(key).or_insert(SharedValue::new(Default::default())); - // entry.get_mut().push(value); - - let hash = dm.hash_usize(&key); - let shard = dm.determine_shard(hash); - let entry = dm.shards_mut()[shard].get_mut().raw_entry_mut() - .from_key_hashed_nocheck(hash as u64, &key) - .or_insert(key, SharedValue::new(Default::default())); - entry.1.get_mut().insert(value); - } -} - -impl<'a, K: 'a + Clone + Hash + Eq + Send + Sync, V: 'a + Clone + Hash + Eq + Send + Sync> RelIndexMerge for CLatIndex { - fn move_index_contents(from: &mut Self, to: &mut Self) { - let before = Instant::now(); - let from = from.unwrap_mut_unfrozen(); - let to = to.unwrap_mut_unfrozen(); - - use rayon::prelude::*; - assert_eq!(from.shards().len(), to.shards().len()); - to.shards_mut().par_iter_mut().zip(from.shards_mut().par_iter_mut()).for_each(|(to, from)| { - let from = from.get_mut(); - let to = to.get_mut(); - - if from.len() > to.len() { - std::mem::swap(from, to); - } - - for (k, mut v) in from.drain() { - match to.entry(k) { - hashbrown::hash_map::Entry::Occupied(mut occ) => { - let occ = occ.get_mut().get_mut(); - let v_mut = v.get_mut(); - if v_mut.len() > occ.len() { - std::mem::swap(occ, v_mut); - } - let v = v.into_inner(); - occ.reserve(v.len()); - occ.extend(&mut v.into_iter()); - }, - hashbrown::hash_map::Entry::Vacant(vac) => {vac.insert(v);}, - } - } - - }); - unsafe { - crate::internal::MOVE_REL_INDEX_CONTENTS_TOTAL_TIME += before.elapsed(); - } - - } -} - -impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Clone + Hash + Eq> RelIndexReadAll<'a> for CLatIndex { - type Key = &'a K; - type Value = V; - - type ValueIteratorType = std::iter::Cloned>; - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - // let res = DashMapViewParIter::new(self.unwrap_frozen()).map(|(k, v)| (k, v.iter().cloned())); - let res = self.unwrap_frozen().iter().map(|(k, v)| (k, v.iter().cloned())); - Box::new(res) as _ - } -} - -impl<'a, K: 'a + Clone + Hash + Eq + Sync + Send, V: 'a + Clone + Hash + Eq + Sync + Send> CRelIndexReadAll<'a> for CLatIndex { - type Key = &'a K; - type Value = &'a V; - - type ValueIteratorType = rayon::collections::hash_set::Iter<'a, V>; - - type AllIteratorType = - rayon::iter::Map, BuildHasherDefault>, for<'aa, 'bb> fn((&'aa K, &'bb SetType)) -> (&'aa K, rayon::collections::hash_set::Iter<'bb, V>)>; - - fn c_iter_all(&'a self) -> Self::AllIteratorType { - use rayon::prelude::*; - let res: Self::AllIteratorType = DashMapViewParIter::new(self.unwrap_frozen()).map(|(k, v)| (k, v.par_iter())); - res - } -} - - -impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Clone + Hash + Eq> CRelIndexWrite for CLatIndex { - type Key = K; - type Value = V; - - #[inline(always)] - fn index_insert(&self, key: Self::Key, value: Self::Value) { - // let before = Instant::now(); - self.insert(key, value); - // unsafe { - // crate::internal::INDEX_INSERT_TOTAL_TIME += before.elapsed(); - // } - } -} +use std::collections::HashSet; +use std::hash::{BuildHasherDefault, Hash}; + +use ascent_base::util::update; +use dashmap::{DashMap, SharedValue}; +use instant::Instant; +use rustc_hash::FxHasher; + +use crate::c_rel_index::{DashMapViewParIter, shards_count}; +use crate::internal::{ + CRelIndexRead, CRelIndexReadAll, CRelIndexWrite, Freezable, RelIndexMerge, RelIndexRead, RelIndexReadAll, + RelIndexWrite, +}; + +type SetType = HashSet; +pub enum CLatIndex { + Unfrozen(DashMap, BuildHasherDefault>), + Frozen(dashmap::ReadOnlyView, BuildHasherDefault>), +} + +impl Freezable for CLatIndex { + fn freeze(&mut self) { + update(self, |_self| match _self { + CLatIndex::Unfrozen(dm) => Self::Frozen(dm.into_read_only()), + CLatIndex::Frozen(_) => _self, + }) + } + + fn unfreeze(&mut self) { + update(self, |_self| match _self { + CLatIndex::Frozen(v) => Self::Unfrozen(v.into_inner()), + CLatIndex::Unfrozen(_) => _self, + }) + } +} + +impl CLatIndex { + #[inline] + pub fn unwrap_frozen(&self) -> &dashmap::ReadOnlyView, BuildHasherDefault> { + match self { + CLatIndex::Frozen(v) => v, + CLatIndex::Unfrozen(_) => panic!("CRelIndex::unwrap_frozen(): object is Unfrozen"), + } + } + + #[inline] + pub fn unwrap_unfrozen(&self) -> &DashMap, BuildHasherDefault> { + match self { + CLatIndex::Unfrozen(dm) => dm, + CLatIndex::Frozen(_) => panic!("CRelIndex::unwrap_unfrozen(): object is Frozen"), + } + } + + #[inline] + pub fn unwrap_mut_unfrozen(&mut self) -> &mut DashMap, BuildHasherDefault> { + match self { + CLatIndex::Unfrozen(dm) => dm, + CLatIndex::Frozen(_) => panic!("CRelIndex::unwrap_unfrozen(): object is Frozen"), + } + } + + pub fn into_read_only(self) -> dashmap::ReadOnlyView, BuildHasherDefault> { + match self { + CLatIndex::Unfrozen(dm) => dm.into_read_only(), + CLatIndex::Frozen(f) => f, + } + } + + #[inline] + fn insert(&self, key: K, value: V) { + match self.unwrap_unfrozen().entry(key) { + dashmap::mapref::entry::Entry::Occupied(mut occ) => { + occ.get_mut().insert(value); + }, + dashmap::mapref::entry::Entry::Vacant(vac) => { + let mut set = SetType::default(); + set.insert(value); + vac.insert(set); + }, + } + } + + #[inline] + pub fn hash_usize(&self, k: &K) -> usize { self.unwrap_unfrozen().hash_usize(k) } +} + +impl Default for CLatIndex { + fn default() -> Self { + // Self::Unfrozen(Default::default()) + Self::Unfrozen(DashMap::with_hasher_and_shard_amount(Default::default(), shards_count())) + } +} + +impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Clone + Hash + Eq> RelIndexRead<'a> for CLatIndex { + type Key = K; + type Value = &'a V; + + type IteratorType = std::collections::hash_set::Iter<'a, V>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let vals = &self.unwrap_frozen().get(key)?; + let res = vals.iter(); + Some(res) + } + + fn len(&self) -> usize { self.unwrap_frozen().len() } +} + +impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Clone + Hash + Eq + Sync> CRelIndexRead<'a> for CLatIndex { + type Key = K; + type Value = &'a V; + + type IteratorType = rayon::collections::hash_set::Iter<'a, V>; + + fn c_index_get(&'a self, key: &Self::Key) -> Option { + use rayon::prelude::*; + let vals = &self.unwrap_frozen().get(key)?; + let res = vals.par_iter(); + Some(res) + } +} + +impl<'a, K: 'a + Clone + Hash + Eq + Send + Sync, V: 'a + Clone + Hash + Eq + Send + Sync> RelIndexWrite + for CLatIndex +{ + type Key = K; + type Value = V; + + fn index_insert(&mut self, key: Self::Key, value: Self::Value) { + let dm = self.unwrap_mut_unfrozen(); + // let shard = dm.determine_map(&key); + // let entry = dm.shards_mut()[shard].get_mut().entry(key).or_insert(SharedValue::new(Default::default())); + // entry.get_mut().push(value); + + let hash = dm.hash_usize(&key); + let shard = dm.determine_shard(hash); + let entry = dm.shards_mut()[shard] + .get_mut() + .raw_entry_mut() + .from_key_hashed_nocheck(hash as u64, &key) + .or_insert(key, SharedValue::new(Default::default())); + entry.1.get_mut().insert(value); + } +} + +impl<'a, K: 'a + Clone + Hash + Eq + Send + Sync, V: 'a + Clone + Hash + Eq + Send + Sync> RelIndexMerge + for CLatIndex +{ + fn move_index_contents(from: &mut Self, to: &mut Self) { + let before = Instant::now(); + let from = from.unwrap_mut_unfrozen(); + let to = to.unwrap_mut_unfrozen(); + + use rayon::prelude::*; + assert_eq!(from.shards().len(), to.shards().len()); + to.shards_mut().par_iter_mut().zip(from.shards_mut().par_iter_mut()).for_each(|(to, from)| { + let from = from.get_mut(); + let to = to.get_mut(); + + if from.len() > to.len() { + std::mem::swap(from, to); + } + + for (k, mut v) in from.drain() { + match to.entry(k) { + hashbrown::hash_map::Entry::Occupied(mut occ) => { + let occ = occ.get_mut().get_mut(); + let v_mut = v.get_mut(); + if v_mut.len() > occ.len() { + std::mem::swap(occ, v_mut); + } + let v = v.into_inner(); + occ.reserve(v.len()); + occ.extend(&mut v.into_iter()); + }, + hashbrown::hash_map::Entry::Vacant(vac) => { + vac.insert(v); + }, + } + } + }); + unsafe { + crate::internal::MOVE_REL_INDEX_CONTENTS_TOTAL_TIME += before.elapsed(); + } + } +} + +impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Clone + Hash + Eq> RelIndexReadAll<'a> for CLatIndex { + type Key = &'a K; + type Value = V; + + type ValueIteratorType = std::iter::Cloned>; + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + // let res = DashMapViewParIter::new(self.unwrap_frozen()).map(|(k, v)| (k, v.iter().cloned())); + let res = self.unwrap_frozen().iter().map(|(k, v)| (k, v.iter().cloned())); + Box::new(res) as _ + } +} + +impl<'a, K: 'a + Clone + Hash + Eq + Sync + Send, V: 'a + Clone + Hash + Eq + Sync + Send> CRelIndexReadAll<'a> + for CLatIndex +{ + type Key = &'a K; + type Value = &'a V; + + type ValueIteratorType = rayon::collections::hash_set::Iter<'a, V>; + + type AllIteratorType = rayon::iter::Map< + DashMapViewParIter<'a, K, SetType, BuildHasherDefault>, + for<'aa, 'bb> fn((&'aa K, &'bb SetType)) -> (&'aa K, rayon::collections::hash_set::Iter<'bb, V>), + >; + + fn c_iter_all(&'a self) -> Self::AllIteratorType { + use rayon::prelude::*; + let res: Self::AllIteratorType = DashMapViewParIter::new(self.unwrap_frozen()).map(|(k, v)| (k, v.par_iter())); + res + } +} + +impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Clone + Hash + Eq> CRelIndexWrite for CLatIndex { + type Key = K; + type Value = V; + + #[inline(always)] + fn index_insert(&self, key: Self::Key, value: Self::Value) { + // let before = Instant::now(); + self.insert(key, value); + // unsafe { + // crate::internal::INDEX_INSERT_TOTAL_TIME += before.elapsed(); + // } + } +} diff --git a/ascent/src/c_rel_full_index.rs b/ascent/src/c_rel_full_index.rs index 544c3ac..0aefd7c 100644 --- a/ascent/src/c_rel_full_index.rs +++ b/ascent/src/c_rel_full_index.rs @@ -1,290 +1,285 @@ -use ascent_base::util::update; -use dashmap::{DashMap, SharedValue}; -use instant::Instant; -use rustc_hash::FxHasher; -use std::hash::{Hash, BuildHasherDefault}; - -use crate::c_rel_index::{DashMapViewParIter, shards_count}; -use crate::internal::{RelIndexWrite, CRelIndexWrite, RelFullIndexRead, RelFullIndexWrite, CRelFullIndexWrite, RelIndexMerge, Freezable}; -use crate::internal::{RelIndexRead, RelIndexReadAll, CRelIndexRead, CRelIndexReadAll}; - - -pub enum CRelFullIndex { - Unfrozen(DashMap>), - Frozen(dashmap::ReadOnlyView>) -} - -impl Freezable for CRelFullIndex { - fn freeze(&mut self) { - update(self, |_self| match _self { - CRelFullIndex::Unfrozen(dm) => Self::Frozen(dm.into_read_only()), - CRelFullIndex::Frozen(_) => _self, - }) - } - - fn unfreeze(&mut self) { - update(self, |_self| match _self { - CRelFullIndex::Frozen(v) => Self::Unfrozen(v.into_inner()), - CRelFullIndex::Unfrozen(dm) => CRelFullIndex::Unfrozen(dm), - }) - } -} - -impl CRelFullIndex { - - - pub fn exact_len(&self) -> usize { - match self { - CRelFullIndex::Unfrozen(uf) => uf.len(), - CRelFullIndex::Frozen(f) => f.len(), - } - } - - #[inline] - pub fn unwrap_frozen(&self) -> &dashmap::ReadOnlyView> { - match self { - CRelFullIndex::Frozen(v) => v, - CRelFullIndex::Unfrozen(_) => panic!("CRelFullIndex::unwrap_frozen(): object is Unfrozen"), - } - } - - #[inline] - pub fn unwrap_unfrozen(&self) -> &DashMap> { - match self { - CRelFullIndex::Unfrozen(dm) => dm, - CRelFullIndex::Frozen(_) => panic!("CRelFullIndex::unwrap_unfrozen(): object is Frozen"), - } - } - - #[inline] - pub fn unwrap_mut_unfrozen(&mut self) -> &mut DashMap> { - match self { - CRelFullIndex::Frozen(_) => panic!("CRelFullIndex::unwrap_mut_unfrozen(): object is Frozen"), - CRelFullIndex::Unfrozen(dm) => dm, - } - } - - pub fn into_read_only(self) -> dashmap::ReadOnlyView> { - match self { - CRelFullIndex::Unfrozen(dm) => dm.into_read_only(), - CRelFullIndex::Frozen(f) => f, - } - } - - #[inline] - fn insert(&self, key: K, value: V) { - self.unwrap_unfrozen().insert(key, value); - } - - pub fn hash_usize(&self, k: &K) -> usize { - self.unwrap_unfrozen().hash_usize(k) - } - - pub fn get_cloned(&self, key: &K) -> Option where V: Clone { - match self { - CRelFullIndex::Unfrozen(uf) => uf.get(key).map(|x| x.value().clone()), - CRelFullIndex::Frozen(f) => f.get(key).cloned(), - } - } - - - pub fn insert_if_not_present2(&self, key: &K, value: V) -> bool { - let dm = self.unwrap_unfrozen(); - - let hash = dm.hash_usize(&key); - - let idx = dm.determine_shard(hash); - use dashmap::Map; - let mut shard = unsafe { dm._yield_write_shard(idx) }; - - match shard.raw_entry_mut().from_key_hashed_nocheck(hash as u64, key) { - hashbrown::hash_map::RawEntryMut::Occupied(_) => false, - hashbrown::hash_map::RawEntryMut::Vacant(vac) => { - vac.insert(key.clone(), SharedValue::new(value)); - true - }, - } - } -} - -impl Default for CRelFullIndex { - fn default() -> Self { - // Self::Unfrozen(Default::default()) - Self::Unfrozen(DashMap::with_hasher_and_shard_amount(Default::default(), shards_count())) - } -} - -impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> RelIndexRead<'a> for CRelFullIndex { - type Key = K; - type Value = &'a V; - - type IteratorType = std::iter::Once<&'a V>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let val = self.unwrap_frozen().get(key)?; - let res = std::iter::once(val); - Some(res) - } - - fn len(&self) -> usize { - self.unwrap_frozen().len() - } -} - -impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Sync> CRelIndexRead<'a> for CRelFullIndex { - type Key = K; - - type Value = &'a V; - - type IteratorType = rayon::iter::Once<&'a V>; - - fn c_index_get(&'a self, key: &Self::Key) -> Option { - let val = self.unwrap_frozen().get(key)?; - let res = rayon::iter::once(val); - Some(res) - } - -} - -impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> RelFullIndexRead<'a> for CRelFullIndex { - type Key = K; - - #[inline(always)] - fn contains_key(&self, key: &Self::Key) -> bool { - self.unwrap_frozen().contains_key(key) - } -} - -impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> RelFullIndexWrite for CRelFullIndex { - type Key = K; - type Value = V; - - fn insert_if_not_present(&mut self, key: &Self::Key, v: Self::Value) -> bool { - self.unfreeze(); - match self.unwrap_mut_unfrozen().entry(key.clone()) { - dashmap::mapref::entry::Entry::Occupied(_) => false, - dashmap::mapref::entry::Entry::Vacant(vac) => { - vac.insert(v); - true - }, - } - } -} - -impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> CRelFullIndexWrite for CRelFullIndex { - type Key = K; - type Value = V; - - fn insert_if_not_present(&self, key: &Self::Key, v: Self::Value) -> bool { - - // TODO decide what to do here - // let before = Instant::now(); - - // let res = match self.unwrap_unfrozen().entry(key.clone()) { - // dashmap::mapref::entry::Entry::Occupied(_) => false, - // dashmap::mapref::entry::Entry::Vacant(vac) => { - // vac.insert_quick(v); - // true - // }, - // }; - let res = self.insert_if_not_present2(key, v); - // unsafe { - // crate::internal::INDEX_INSERT_TOTAL_TIME += before.elapsed(); - // } - - #[allow(clippy::let_and_return)] - res - } -} - -impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Clone> RelIndexReadAll<'a> for CRelFullIndex { - type Key = &'a K; - type Value = V; - - type ValueIteratorType = std::iter::Once; - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res = self.unwrap_frozen().iter().map(|(k, v)| (k, std::iter::once(v.clone()))); - Box::new(res) as _ - } -} - -impl<'a, K: 'a + Clone + Hash + Eq + Send + Sync, V: 'a + Clone + Send + Sync> CRelIndexReadAll<'a> for CRelFullIndex { - type Key = &'a K; - type Value = &'a V; - - type ValueIteratorType = rayon::iter::Once<&'a V>; - // type AllIteratorType = Box + 'a>; - - type AllIteratorType = - rayon::iter::Map>, for<'aa, 'bb> fn((&'aa K, &'bb V)) -> (&'aa K, rayon::iter::Once<&'bb V>)>; - - - fn c_iter_all(&'a self) -> Self::AllIteratorType { - use rayon::prelude::*; - let res: Self::AllIteratorType = DashMapViewParIter::new(self.unwrap_frozen()).map(|(k, v)| (k, rayon::iter::once(v))); - res - } -} - -impl<'a, K: 'a + Clone + Hash + Eq + Send + Sync, V: 'a + Send + Sync> RelIndexWrite for CRelFullIndex { - type Key = K; - type Value = V; - - fn index_insert(&mut self, key: Self::Key, value: Self::Value) { - let dm = self.unwrap_mut_unfrozen(); - - // let shard = dm.determine_map(&key); - // dm.shards_mut()[shard].get_mut().insert(key, SharedValue::new(value)); - - let hash = dm.hash_usize(&key); - let shard = dm.determine_shard(hash); - dm.shards_mut()[shard].get_mut().raw_entry_mut() - .from_key_hashed_nocheck(hash as u64, &key) - .insert(key, SharedValue::new(value)); - } -} - -impl<'a, K: 'a + Clone + Hash + Eq + Send + Sync, V: 'a + Send + Sync> RelIndexMerge for CRelFullIndex { - fn move_index_contents(from: &mut Self, to: &mut Self) { - let before = Instant::now(); - - let from = from.unwrap_mut_unfrozen(); - let to = to.unwrap_mut_unfrozen(); - - use rayon::prelude::*; - assert_eq!(from.shards().len(), to.shards().len()); - to.shards_mut().par_iter_mut().zip(from.shards_mut().par_iter_mut()).for_each(|(to, from)| { - let from = from.get_mut(); - let to = to.get_mut(); - - if from.len() > to.len() { - std::mem::swap(from, to); - } - to.reserve(from.len()); - for (k, v) in from.drain() { - to.insert(k, v); - } - }); - - unsafe { - crate::internal::MOVE_FULL_INDEX_CONTENTS_TOTAL_TIME += before.elapsed(); - } - - } -} - -impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> CRelIndexWrite for CRelFullIndex { - type Key = K; - type Value = V; - - #[inline(always)] - fn index_insert(&self, key: Self::Key, value: Self::Value) { - // let before = Instant::now(); - self.insert(key, value); - // unsafe { - // crate::internal::INDEX_INSERT_TOTAL_TIME += before.elapsed(); - // } - } -} \ No newline at end of file +use std::hash::{BuildHasherDefault, Hash}; + +use ascent_base::util::update; +use dashmap::{DashMap, SharedValue}; +use instant::Instant; +use rustc_hash::FxHasher; + +use crate::c_rel_index::{DashMapViewParIter, shards_count}; +use crate::internal::{ + CRelFullIndexWrite, CRelIndexRead, CRelIndexReadAll, CRelIndexWrite, Freezable, RelFullIndexRead, RelFullIndexWrite, + RelIndexMerge, RelIndexRead, RelIndexReadAll, RelIndexWrite, +}; + +pub enum CRelFullIndex { + Unfrozen(DashMap>), + Frozen(dashmap::ReadOnlyView>), +} + +impl Freezable for CRelFullIndex { + fn freeze(&mut self) { + update(self, |_self| match _self { + CRelFullIndex::Unfrozen(dm) => Self::Frozen(dm.into_read_only()), + CRelFullIndex::Frozen(_) => _self, + }) + } + + fn unfreeze(&mut self) { + update(self, |_self| match _self { + CRelFullIndex::Frozen(v) => Self::Unfrozen(v.into_inner()), + CRelFullIndex::Unfrozen(dm) => CRelFullIndex::Unfrozen(dm), + }) + } +} + +impl CRelFullIndex { + pub fn exact_len(&self) -> usize { + match self { + CRelFullIndex::Unfrozen(uf) => uf.len(), + CRelFullIndex::Frozen(f) => f.len(), + } + } + + #[inline] + pub fn unwrap_frozen(&self) -> &dashmap::ReadOnlyView> { + match self { + CRelFullIndex::Frozen(v) => v, + CRelFullIndex::Unfrozen(_) => panic!("CRelFullIndex::unwrap_frozen(): object is Unfrozen"), + } + } + + #[inline] + pub fn unwrap_unfrozen(&self) -> &DashMap> { + match self { + CRelFullIndex::Unfrozen(dm) => dm, + CRelFullIndex::Frozen(_) => panic!("CRelFullIndex::unwrap_unfrozen(): object is Frozen"), + } + } + + #[inline] + pub fn unwrap_mut_unfrozen(&mut self) -> &mut DashMap> { + match self { + CRelFullIndex::Frozen(_) => panic!("CRelFullIndex::unwrap_mut_unfrozen(): object is Frozen"), + CRelFullIndex::Unfrozen(dm) => dm, + } + } + + pub fn into_read_only(self) -> dashmap::ReadOnlyView> { + match self { + CRelFullIndex::Unfrozen(dm) => dm.into_read_only(), + CRelFullIndex::Frozen(f) => f, + } + } + + #[inline] + fn insert(&self, key: K, value: V) { self.unwrap_unfrozen().insert(key, value); } + + pub fn hash_usize(&self, k: &K) -> usize { self.unwrap_unfrozen().hash_usize(k) } + + pub fn get_cloned(&self, key: &K) -> Option + where V: Clone { + match self { + CRelFullIndex::Unfrozen(uf) => uf.get(key).map(|x| x.value().clone()), + CRelFullIndex::Frozen(f) => f.get(key).cloned(), + } + } + + pub fn insert_if_not_present2(&self, key: &K, value: V) -> bool { + let dm = self.unwrap_unfrozen(); + + let hash = dm.hash_usize(&key); + + let idx = dm.determine_shard(hash); + use dashmap::Map; + let mut shard = unsafe { dm._yield_write_shard(idx) }; + + match shard.raw_entry_mut().from_key_hashed_nocheck(hash as u64, key) { + hashbrown::hash_map::RawEntryMut::Occupied(_) => false, + hashbrown::hash_map::RawEntryMut::Vacant(vac) => { + vac.insert(key.clone(), SharedValue::new(value)); + true + }, + } + } +} + +impl Default for CRelFullIndex { + fn default() -> Self { + // Self::Unfrozen(Default::default()) + Self::Unfrozen(DashMap::with_hasher_and_shard_amount(Default::default(), shards_count())) + } +} + +impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> RelIndexRead<'a> for CRelFullIndex { + type Key = K; + type Value = &'a V; + + type IteratorType = std::iter::Once<&'a V>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let val = self.unwrap_frozen().get(key)?; + let res = std::iter::once(val); + Some(res) + } + + fn len(&self) -> usize { self.unwrap_frozen().len() } +} + +impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Sync> CRelIndexRead<'a> for CRelFullIndex { + type Key = K; + + type Value = &'a V; + + type IteratorType = rayon::iter::Once<&'a V>; + + fn c_index_get(&'a self, key: &Self::Key) -> Option { + let val = self.unwrap_frozen().get(key)?; + let res = rayon::iter::once(val); + Some(res) + } +} + +impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> RelFullIndexRead<'a> for CRelFullIndex { + type Key = K; + + #[inline(always)] + fn contains_key(&self, key: &Self::Key) -> bool { self.unwrap_frozen().contains_key(key) } +} + +impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> RelFullIndexWrite for CRelFullIndex { + type Key = K; + type Value = V; + + fn insert_if_not_present(&mut self, key: &Self::Key, v: Self::Value) -> bool { + self.unfreeze(); + match self.unwrap_mut_unfrozen().entry(key.clone()) { + dashmap::mapref::entry::Entry::Occupied(_) => false, + dashmap::mapref::entry::Entry::Vacant(vac) => { + vac.insert(v); + true + }, + } + } +} + +impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> CRelFullIndexWrite for CRelFullIndex { + type Key = K; + type Value = V; + + fn insert_if_not_present(&self, key: &Self::Key, v: Self::Value) -> bool { + // TODO decide what to do here + // let before = Instant::now(); + + // let res = match self.unwrap_unfrozen().entry(key.clone()) { + // dashmap::mapref::entry::Entry::Occupied(_) => false, + // dashmap::mapref::entry::Entry::Vacant(vac) => { + // vac.insert_quick(v); + // true + // }, + // }; + let res = self.insert_if_not_present2(key, v); + // unsafe { + // crate::internal::INDEX_INSERT_TOTAL_TIME += before.elapsed(); + // } + + #[allow(clippy::let_and_return)] + res + } +} + +impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Clone> RelIndexReadAll<'a> for CRelFullIndex { + type Key = &'a K; + type Value = V; + + type ValueIteratorType = std::iter::Once; + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res = self.unwrap_frozen().iter().map(|(k, v)| (k, std::iter::once(v.clone()))); + Box::new(res) as _ + } +} + +impl<'a, K: 'a + Clone + Hash + Eq + Send + Sync, V: 'a + Clone + Send + Sync> CRelIndexReadAll<'a> + for CRelFullIndex +{ + type Key = &'a K; + type Value = &'a V; + + type ValueIteratorType = rayon::iter::Once<&'a V>; + // type AllIteratorType = Box + 'a>; + + type AllIteratorType = rayon::iter::Map< + DashMapViewParIter<'a, K, V, BuildHasherDefault>, + for<'aa, 'bb> fn((&'aa K, &'bb V)) -> (&'aa K, rayon::iter::Once<&'bb V>), + >; + + fn c_iter_all(&'a self) -> Self::AllIteratorType { + use rayon::prelude::*; + let res: Self::AllIteratorType = + DashMapViewParIter::new(self.unwrap_frozen()).map(|(k, v)| (k, rayon::iter::once(v))); + res + } +} + +impl<'a, K: 'a + Clone + Hash + Eq + Send + Sync, V: 'a + Send + Sync> RelIndexWrite for CRelFullIndex { + type Key = K; + type Value = V; + + fn index_insert(&mut self, key: Self::Key, value: Self::Value) { + let dm = self.unwrap_mut_unfrozen(); + + // let shard = dm.determine_map(&key); + // dm.shards_mut()[shard].get_mut().insert(key, SharedValue::new(value)); + + let hash = dm.hash_usize(&key); + let shard = dm.determine_shard(hash); + dm.shards_mut()[shard] + .get_mut() + .raw_entry_mut() + .from_key_hashed_nocheck(hash as u64, &key) + .insert(key, SharedValue::new(value)); + } +} + +impl<'a, K: 'a + Clone + Hash + Eq + Send + Sync, V: 'a + Send + Sync> RelIndexMerge for CRelFullIndex { + fn move_index_contents(from: &mut Self, to: &mut Self) { + let before = Instant::now(); + + let from = from.unwrap_mut_unfrozen(); + let to = to.unwrap_mut_unfrozen(); + + use rayon::prelude::*; + assert_eq!(from.shards().len(), to.shards().len()); + to.shards_mut().par_iter_mut().zip(from.shards_mut().par_iter_mut()).for_each(|(to, from)| { + let from = from.get_mut(); + let to = to.get_mut(); + + if from.len() > to.len() { + std::mem::swap(from, to); + } + to.reserve(from.len()); + for (k, v) in from.drain() { + to.insert(k, v); + } + }); + + unsafe { + crate::internal::MOVE_FULL_INDEX_CONTENTS_TOTAL_TIME += before.elapsed(); + } + } +} + +impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> CRelIndexWrite for CRelFullIndex { + type Key = K; + type Value = V; + + #[inline(always)] + fn index_insert(&self, key: Self::Key, value: Self::Value) { + // let before = Instant::now(); + self.insert(key, value); + // unsafe { + // crate::internal::INDEX_INSERT_TOTAL_TIME += before.elapsed(); + // } + } +} diff --git a/ascent/src/c_rel_index.rs b/ascent/src/c_rel_index.rs index 6819a7f..a633397 100644 --- a/ascent/src/c_rel_index.rs +++ b/ascent/src/c_rel_index.rs @@ -1,329 +1,326 @@ -use ascent_base::util::update; -use dashmap::{DashMap, RwLock, SharedValue, ReadOnlyView}; -use instant::Instant; -use rayon::iter::plumbing::UnindexedConsumer; -use rayon::prelude::{ParallelIterator, IntoParallelRefIterator}; -use rustc_hash::FxHasher; -use std::hash::{Hash, BuildHasherDefault, BuildHasher}; - -use crate::internal::{RelIndexWrite, CRelIndexWrite, RelIndexMerge, Freezable}; -use crate::internal::{RelIndexRead, RelIndexReadAll, CRelIndexRead, CRelIndexReadAll}; - -use rayon::iter::IntoParallelIterator; - - -type VecType = Vec; -pub enum CRelIndex { - Unfrozen(DashMap, BuildHasherDefault>), - Frozen(dashmap::ReadOnlyView, BuildHasherDefault>) -} - -impl Freezable for CRelIndex { - fn freeze(&mut self) { - update(self, |_self| match _self { - CRelIndex::Unfrozen(dm) => Self::Frozen(dm.into_read_only()), - CRelIndex::Frozen(_) => _self, - }) - } - - fn unfreeze(&mut self) { - update(self, |_self| match _self { - CRelIndex::Frozen(v) => Self::Unfrozen(v.into_inner()), - CRelIndex::Unfrozen(_) => _self, - }) - } -} - -impl CRelIndex { - - #[inline] - pub fn unwrap_frozen(&self) -> &dashmap::ReadOnlyView, BuildHasherDefault> { - match self { - CRelIndex::Frozen(v) => v, - CRelIndex::Unfrozen(_) => panic!("CRelIndex::unwrap_frozen(): object is Unfrozen"), - } - } - - #[inline] - pub fn unwrap_unfrozen(&self) -> &DashMap, BuildHasherDefault> { - match self { - CRelIndex::Unfrozen(dm) => dm, - CRelIndex::Frozen(_) => panic!("CRelIndex::unwrap_unfrozen(): object is Frozen"), - } - } - - #[inline] - pub fn unwrap_mut_unfrozen(&mut self) -> &mut DashMap, BuildHasherDefault> { - match self { - CRelIndex::Unfrozen(dm) => dm, - CRelIndex::Frozen(_) => panic!("CRelIndex::unwrap_unfrozen(): object is Frozen"), - } - } - - pub fn into_read_only(self) -> dashmap::ReadOnlyView, BuildHasherDefault> { - match self { - CRelIndex::Unfrozen(dm) => dm.into_read_only(), - CRelIndex::Frozen(f) => f, - } - } - - #[inline] - #[allow(dead_code)] - // TODO remove if not used - fn insert(&self, key: K, value: V) { - match self.unwrap_unfrozen().entry(key) { - dashmap::mapref::entry::Entry::Occupied(mut occ) => {occ.get_mut().push(value);}, - dashmap::mapref::entry::Entry::Vacant(vac) => {vac.insert(vec![value]);}, - } - } - - #[inline] - #[allow(dead_code)] - // TODO remove if not used - fn insert2(&self, key: K, value: V) { - use std::hash::Hasher; - use dashmap::Map; - - let dm = self.unwrap_unfrozen(); - let mut hasher = dm.hasher().build_hasher(); - key.hash(&mut hasher); - let hash = hasher.finish(); - - let idx = dm.determine_shard(hash as usize); - let mut shard = unsafe { dm._yield_write_shard(idx) }; - - match shard.raw_entry_mut().from_key_hashed_nocheck(hash, &key) { - hashbrown::hash_map::RawEntryMut::Occupied(mut occ) => { - occ.get_mut().get_mut().push(value); - }, - hashbrown::hash_map::RawEntryMut::Vacant(vac) => { - vac.insert(key, SharedValue::new(vec![value])); - }, - } - } - - #[inline] - pub fn hash_usize(&self, k: &K) -> usize { - self.unwrap_unfrozen().hash_usize(k) - } -} - -impl Default for CRelIndex { - fn default() -> Self { - // Self::Unfrozen(Default::default()) - Self::Unfrozen(DashMap::with_hasher_and_shard_amount(Default::default(), shards_count())) - } -} - -impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> RelIndexRead<'a> for CRelIndex { - type Key = K; - type Value = &'a V; - - type IteratorType = std::slice::Iter<'a, V>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let vals = &self.unwrap_frozen().get(key)?; - let res = vals.iter(); - Some(res) - } - - fn len(&self) -> usize { - // approximate len - let sample_size = 4; - let shards = self.unwrap_frozen().shards(); - let (count, sum) = shards.iter().take(sample_size).fold((0, 0), |(c, s), shard| (c + 1, s + shard.read().len())); - sum * shards.len() / count - } -} - -impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Sync> CRelIndexRead<'a> for CRelIndex { - type Key = K; - type Value = &'a V; - - type IteratorType = rayon::slice::Iter<'a, V>; - - fn c_index_get(&'a self, key: &Self::Key) -> Option { - use rayon::prelude::*; - let vals = &self.unwrap_frozen().get(key)?; - let res = vals.as_slice().par_iter(); - Some(res) - } - -} - -impl<'a, K: 'a + Clone + Hash + Eq + Send + Sync, V: 'a + Send + Sync> RelIndexWrite for CRelIndex { - type Key = K; - type Value = V; - - fn index_insert(&mut self, key: Self::Key, value: Self::Value) { - let dm = self.unwrap_mut_unfrozen(); - // let shard = dm.determine_map(&key); - // let entry = dm.shards_mut()[shard].get_mut().entry(key).or_insert(SharedValue::new(Default::default())); - // entry.get_mut().push(value); - - let hash = dm.hash_usize(&key); - let shard = dm.determine_shard(hash); - let entry = dm.shards_mut()[shard].get_mut().raw_entry_mut() - .from_key_hashed_nocheck(hash as u64, &key) - .or_insert(key, SharedValue::new(Default::default())); - entry.1.get_mut().push(value); - } -} - -impl<'a, K: 'a + Clone + Hash + Eq + Send + Sync, V: 'a + Send + Sync> RelIndexMerge for CRelIndex { - fn move_index_contents(from: &mut Self, to: &mut Self) { - let before = Instant::now(); - let from = from.unwrap_mut_unfrozen(); - let to = to.unwrap_mut_unfrozen(); - - use rayon::prelude::*; - assert_eq!(from.shards().len(), to.shards().len()); - to.shards_mut().par_iter_mut().zip(from.shards_mut().par_iter_mut()).for_each(|(to, from)| { - let from = from.get_mut(); - let to = to.get_mut(); - - if from.len() > to.len() { - std::mem::swap(from, to); - } - - for (k, mut v) in from.drain() { - match to.entry(k) { - hashbrown::hash_map::Entry::Occupied(mut occ) => { - let occ = occ.get_mut().get_mut(); - let v_mut = v.get_mut(); - if v_mut.len() > occ.len() { - std::mem::swap(occ, v_mut); - } - occ.append(&mut v.into_inner()); - }, - hashbrown::hash_map::Entry::Vacant(vac) => {vac.insert(v);}, - } - } - - }); - unsafe { - crate::internal::MOVE_REL_INDEX_CONTENTS_TOTAL_TIME += before.elapsed(); - } - - } -} - -impl<'a, K: 'a + Clone + Hash + Eq, V: Clone + 'a> RelIndexReadAll<'a> for CRelIndex { - type Key = &'a K; - type Value = &'a V; - - type ValueIteratorType = std::slice::Iter<'a, V>; - - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res = self.unwrap_frozen().iter().map(|(k, v)| (k, v.iter())); - Box::new(res) as _ - } -} -pub struct DashMapViewParIter<'a, K, V, S> { - shards: &'a [RwLock, S>>] -} - -impl<'a, K, V, S> Clone for DashMapViewParIter<'a, K, V, S> { - fn clone(&self) -> Self { - Self { shards: self.shards } - } -} - -impl<'a, K: Eq + Hash, V, S: BuildHasher + Clone> DashMapViewParIter<'a, K, V, S> { - pub fn new(v: &'a ReadOnlyView) -> Self { - Self { - shards: v.shards() - } - } -} - -// taken from DashMap rayon::map::Iter ParallelIterator impl -impl<'a, K, V, S> ParallelIterator for DashMapViewParIter<'a, K, V, S> -where - K: Send + Sync + Eq + Hash, - V: Send + Sync, - S: Send + Sync + Clone + BuildHasher, -{ - type Item = (&'a K, &'a V); - - fn drive_unindexed(self, consumer: C) -> C::Result - where - C: UnindexedConsumer, - { - self.shards - .into_par_iter() - .flat_map(|shard| { - let sref = unsafe { shard.data_ptr().as_ref().unwrap() }; - sref.par_iter().map(move |(k, v)| { - (k, v.get()) - }) - }) - .drive_unindexed(consumer) - } -} - -type CRelIndexReadAllParIterShard = hashbrown::HashMap>, S>; - -pub struct CRelIndexReadAllParIter<'a, K, V, S> { - shards: &'a [RwLock>] -} - -impl<'a, K, V, S> ParallelIterator for CRelIndexReadAllParIter<'a, K, V, S> -where - K: Send + Sync + Eq + Hash, - V: Send + Sync, - S: Send + Sync + Clone + BuildHasher, -{ - type Item = (&'a K, rayon::slice::Iter<'a, V>); - - fn drive_unindexed(self, consumer: C) -> C::Result - where - C: UnindexedConsumer, - { - self.shards.into_par_iter() - .flat_map(|shard| { - let sref = unsafe { shard.data_ptr().as_ref().unwrap() }; - sref.par_iter().map(|(k, v)| (k, v.get().par_iter())) - }).drive_unindexed(consumer) - } -} - -impl<'a, K: 'a + Clone + Hash + Eq + Sync + Send, V: Clone + 'a + Sync + Send> CRelIndexReadAll<'a> for CRelIndex { - type Key = &'a K; - type Value = &'a V; - - type ValueIteratorType = rayon::slice::Iter<'a, V>; - - type AllIteratorType = CRelIndexReadAllParIter<'a, K, V, BuildHasherDefault>; - - #[inline] - fn c_iter_all(&'a self) -> Self::AllIteratorType { - CRelIndexReadAllParIter{shards: self.unwrap_frozen().shards()} - } -} - - -impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> CRelIndexWrite for CRelIndex { - type Key = K; - type Value = V; - - #[inline(always)] - fn index_insert(&self, key: Self::Key, value: Self::Value) { - // let before = Instant::now(); - self.insert(key, value); - // ind.insert2(key, value); - // unsafe { - // crate::internal::INDEX_INSERT_TOTAL_TIME += before.elapsed(); - // } - } -} - -pub fn shards_count() -> usize { - static RES: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| { - (rayon::current_num_threads() * 4).next_power_of_two() - // (std::thread::available_parallelism().map_or(1, usize::from) * 4).next_power_of_two() - }); - *RES -} \ No newline at end of file +use std::hash::{BuildHasher, BuildHasherDefault, Hash}; + +use ascent_base::util::update; +use dashmap::{DashMap, ReadOnlyView, RwLock, SharedValue}; +use instant::Instant; +use rayon::iter::IntoParallelIterator; +use rayon::iter::plumbing::UnindexedConsumer; +use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; +use rustc_hash::FxHasher; + +use crate::internal::{ + CRelIndexRead, CRelIndexReadAll, CRelIndexWrite, Freezable, RelIndexMerge, RelIndexRead, RelIndexReadAll, + RelIndexWrite, +}; + +type VecType = Vec; +pub enum CRelIndex { + Unfrozen(DashMap, BuildHasherDefault>), + Frozen(dashmap::ReadOnlyView, BuildHasherDefault>), +} + +impl Freezable for CRelIndex { + fn freeze(&mut self) { + update(self, |_self| match _self { + CRelIndex::Unfrozen(dm) => Self::Frozen(dm.into_read_only()), + CRelIndex::Frozen(_) => _self, + }) + } + + fn unfreeze(&mut self) { + update(self, |_self| match _self { + CRelIndex::Frozen(v) => Self::Unfrozen(v.into_inner()), + CRelIndex::Unfrozen(_) => _self, + }) + } +} + +impl CRelIndex { + #[inline] + pub fn unwrap_frozen(&self) -> &dashmap::ReadOnlyView, BuildHasherDefault> { + match self { + CRelIndex::Frozen(v) => v, + CRelIndex::Unfrozen(_) => panic!("CRelIndex::unwrap_frozen(): object is Unfrozen"), + } + } + + #[inline] + pub fn unwrap_unfrozen(&self) -> &DashMap, BuildHasherDefault> { + match self { + CRelIndex::Unfrozen(dm) => dm, + CRelIndex::Frozen(_) => panic!("CRelIndex::unwrap_unfrozen(): object is Frozen"), + } + } + + #[inline] + pub fn unwrap_mut_unfrozen(&mut self) -> &mut DashMap, BuildHasherDefault> { + match self { + CRelIndex::Unfrozen(dm) => dm, + CRelIndex::Frozen(_) => panic!("CRelIndex::unwrap_unfrozen(): object is Frozen"), + } + } + + pub fn into_read_only(self) -> dashmap::ReadOnlyView, BuildHasherDefault> { + match self { + CRelIndex::Unfrozen(dm) => dm.into_read_only(), + CRelIndex::Frozen(f) => f, + } + } + + #[inline] + #[allow(dead_code)] + // TODO remove if not used + fn insert(&self, key: K, value: V) { + match self.unwrap_unfrozen().entry(key) { + dashmap::mapref::entry::Entry::Occupied(mut occ) => { + occ.get_mut().push(value); + }, + dashmap::mapref::entry::Entry::Vacant(vac) => { + vac.insert(vec![value]); + }, + } + } + + #[inline] + #[allow(dead_code)] + // TODO remove if not used + fn insert2(&self, key: K, value: V) { + use std::hash::Hasher; + + use dashmap::Map; + + let dm = self.unwrap_unfrozen(); + let mut hasher = dm.hasher().build_hasher(); + key.hash(&mut hasher); + let hash = hasher.finish(); + + let idx = dm.determine_shard(hash as usize); + let mut shard = unsafe { dm._yield_write_shard(idx) }; + + match shard.raw_entry_mut().from_key_hashed_nocheck(hash, &key) { + hashbrown::hash_map::RawEntryMut::Occupied(mut occ) => { + occ.get_mut().get_mut().push(value); + }, + hashbrown::hash_map::RawEntryMut::Vacant(vac) => { + vac.insert(key, SharedValue::new(vec![value])); + }, + } + } + + #[inline] + pub fn hash_usize(&self, k: &K) -> usize { self.unwrap_unfrozen().hash_usize(k) } +} + +impl Default for CRelIndex { + fn default() -> Self { + // Self::Unfrozen(Default::default()) + Self::Unfrozen(DashMap::with_hasher_and_shard_amount(Default::default(), shards_count())) + } +} + +impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> RelIndexRead<'a> for CRelIndex { + type Key = K; + type Value = &'a V; + + type IteratorType = std::slice::Iter<'a, V>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let vals = &self.unwrap_frozen().get(key)?; + let res = vals.iter(); + Some(res) + } + + fn len(&self) -> usize { + // approximate len + let sample_size = 4; + let shards = self.unwrap_frozen().shards(); + let (count, sum) = shards.iter().take(sample_size).fold((0, 0), |(c, s), shard| (c + 1, s + shard.read().len())); + sum * shards.len() / count + } +} + +impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Sync> CRelIndexRead<'a> for CRelIndex { + type Key = K; + type Value = &'a V; + + type IteratorType = rayon::slice::Iter<'a, V>; + + fn c_index_get(&'a self, key: &Self::Key) -> Option { + use rayon::prelude::*; + let vals = &self.unwrap_frozen().get(key)?; + let res = vals.as_slice().par_iter(); + Some(res) + } +} + +impl<'a, K: 'a + Clone + Hash + Eq + Send + Sync, V: 'a + Send + Sync> RelIndexWrite for CRelIndex { + type Key = K; + type Value = V; + + fn index_insert(&mut self, key: Self::Key, value: Self::Value) { + let dm = self.unwrap_mut_unfrozen(); + // let shard = dm.determine_map(&key); + // let entry = dm.shards_mut()[shard].get_mut().entry(key).or_insert(SharedValue::new(Default::default())); + // entry.get_mut().push(value); + + let hash = dm.hash_usize(&key); + let shard = dm.determine_shard(hash); + let entry = dm.shards_mut()[shard] + .get_mut() + .raw_entry_mut() + .from_key_hashed_nocheck(hash as u64, &key) + .or_insert(key, SharedValue::new(Default::default())); + entry.1.get_mut().push(value); + } +} + +impl<'a, K: 'a + Clone + Hash + Eq + Send + Sync, V: 'a + Send + Sync> RelIndexMerge for CRelIndex { + fn move_index_contents(from: &mut Self, to: &mut Self) { + let before = Instant::now(); + let from = from.unwrap_mut_unfrozen(); + let to = to.unwrap_mut_unfrozen(); + + use rayon::prelude::*; + assert_eq!(from.shards().len(), to.shards().len()); + to.shards_mut().par_iter_mut().zip(from.shards_mut().par_iter_mut()).for_each(|(to, from)| { + let from = from.get_mut(); + let to = to.get_mut(); + + if from.len() > to.len() { + std::mem::swap(from, to); + } + + for (k, mut v) in from.drain() { + match to.entry(k) { + hashbrown::hash_map::Entry::Occupied(mut occ) => { + let occ = occ.get_mut().get_mut(); + let v_mut = v.get_mut(); + if v_mut.len() > occ.len() { + std::mem::swap(occ, v_mut); + } + occ.append(&mut v.into_inner()); + }, + hashbrown::hash_map::Entry::Vacant(vac) => { + vac.insert(v); + }, + } + } + }); + unsafe { + crate::internal::MOVE_REL_INDEX_CONTENTS_TOTAL_TIME += before.elapsed(); + } + } +} + +impl<'a, K: 'a + Clone + Hash + Eq, V: Clone + 'a> RelIndexReadAll<'a> for CRelIndex { + type Key = &'a K; + type Value = &'a V; + + type ValueIteratorType = std::slice::Iter<'a, V>; + + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res = self.unwrap_frozen().iter().map(|(k, v)| (k, v.iter())); + Box::new(res) as _ + } +} +pub struct DashMapViewParIter<'a, K, V, S> { + shards: &'a [RwLock, S>>], +} + +impl<'a, K, V, S> Clone for DashMapViewParIter<'a, K, V, S> { + fn clone(&self) -> Self { Self { shards: self.shards } } +} + +impl<'a, K: Eq + Hash, V, S: BuildHasher + Clone> DashMapViewParIter<'a, K, V, S> { + pub fn new(v: &'a ReadOnlyView) -> Self { Self { shards: v.shards() } } +} + +// taken from DashMap rayon::map::Iter ParallelIterator impl +impl<'a, K, V, S> ParallelIterator for DashMapViewParIter<'a, K, V, S> +where + K: Send + Sync + Eq + Hash, + V: Send + Sync, + S: Send + Sync + Clone + BuildHasher, +{ + type Item = (&'a K, &'a V); + + fn drive_unindexed(self, consumer: C) -> C::Result + where C: UnindexedConsumer { + self + .shards + .into_par_iter() + .flat_map(|shard| { + let sref = unsafe { shard.data_ptr().as_ref().unwrap() }; + sref.par_iter().map(move |(k, v)| (k, v.get())) + }) + .drive_unindexed(consumer) + } +} + +type CRelIndexReadAllParIterShard = hashbrown::HashMap>, S>; + +pub struct CRelIndexReadAllParIter<'a, K, V, S> { + shards: &'a [RwLock>], +} + +impl<'a, K, V, S> ParallelIterator for CRelIndexReadAllParIter<'a, K, V, S> +where + K: Send + Sync + Eq + Hash, + V: Send + Sync, + S: Send + Sync + Clone + BuildHasher, +{ + type Item = (&'a K, rayon::slice::Iter<'a, V>); + + fn drive_unindexed(self, consumer: C) -> C::Result + where C: UnindexedConsumer { + self + .shards + .into_par_iter() + .flat_map(|shard| { + let sref = unsafe { shard.data_ptr().as_ref().unwrap() }; + sref.par_iter().map(|(k, v)| (k, v.get().par_iter())) + }) + .drive_unindexed(consumer) + } +} + +impl<'a, K: 'a + Clone + Hash + Eq + Sync + Send, V: Clone + 'a + Sync + Send> CRelIndexReadAll<'a> + for CRelIndex +{ + type Key = &'a K; + type Value = &'a V; + + type ValueIteratorType = rayon::slice::Iter<'a, V>; + + type AllIteratorType = CRelIndexReadAllParIter<'a, K, V, BuildHasherDefault>; + + #[inline] + fn c_iter_all(&'a self) -> Self::AllIteratorType { + CRelIndexReadAllParIter { shards: self.unwrap_frozen().shards() } + } +} + +impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> CRelIndexWrite for CRelIndex { + type Key = K; + type Value = V; + + #[inline(always)] + fn index_insert(&self, key: Self::Key, value: Self::Value) { + // let before = Instant::now(); + self.insert(key, value); + // ind.insert2(key, value); + // unsafe { + // crate::internal::INDEX_INSERT_TOTAL_TIME += before.elapsed(); + // } + } +} + +pub fn shards_count() -> usize { + static RES: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| { + (rayon::current_num_threads() * 4).next_power_of_two() + // (std::thread::available_parallelism().map_or(1, usize::from) * 4).next_power_of_two() + }); + *RES +} diff --git a/ascent/src/c_rel_index_combined.rs b/ascent/src/c_rel_index_combined.rs index 12b92dd..d94905c 100644 --- a/ascent/src/c_rel_index_combined.rs +++ b/ascent/src/c_rel_index_combined.rs @@ -1,38 +1,44 @@ -use crate::rel_index_read::RelIndexCombined; -use crate::internal::{CRelIndexRead, CRelIndexReadAll}; - -use rayon::prelude::*; - -impl <'a, Ind1, Ind2, K, V> CRelIndexRead<'a> for RelIndexCombined<'a, Ind1, Ind2> -where Ind1: CRelIndexRead<'a, Key = K, Value = V>, Ind2: CRelIndexRead<'a, Key = K, Value = V>, { - type Key = K; - - type Value = V; - - type IteratorType = rayon::iter::Chain>::IteratorType>>, - rayon::iter::Flatten>::IteratorType>>>; - - fn c_index_get(&'a self, key: &Self::Key) -> Option { - match (self.ind1.c_index_get(key), self.ind2.c_index_get(key)) { - (None, None) => None, - (iter1, iter2) => { - let res = iter1.into_par_iter().flatten().chain(iter2.into_par_iter().flatten()); - Some(res) - } - } - } -} - -impl <'a, Ind1, Ind2, K: 'a, V: 'a, VTI: ParallelIterator + 'a> CRelIndexReadAll<'a> for RelIndexCombined<'a, Ind1, Ind2> -where Ind1: CRelIndexReadAll<'a, Key = K, ValueIteratorType = VTI>, Ind2: CRelIndexReadAll<'a, Key = K, ValueIteratorType = VTI> -{ - type Key = K; - type Value = V; - - type ValueIteratorType = VTI; - type AllIteratorType = rayon::iter::Chain; - - fn c_iter_all(&'a self) -> Self::AllIteratorType { - self.ind1.c_iter_all().chain(self.ind2.c_iter_all()) - } -} \ No newline at end of file +use rayon::prelude::*; + +use crate::internal::{CRelIndexRead, CRelIndexReadAll}; +use crate::rel_index_read::RelIndexCombined; + +impl<'a, Ind1, Ind2, K, V> CRelIndexRead<'a> for RelIndexCombined<'a, Ind1, Ind2> +where + Ind1: CRelIndexRead<'a, Key = K, Value = V>, + Ind2: CRelIndexRead<'a, Key = K, Value = V>, +{ + type Key = K; + + type Value = V; + + type IteratorType = rayon::iter::Chain< + rayon::iter::Flatten>::IteratorType>>, + rayon::iter::Flatten>::IteratorType>>, + >; + + fn c_index_get(&'a self, key: &Self::Key) -> Option { + match (self.ind1.c_index_get(key), self.ind2.c_index_get(key)) { + (None, None) => None, + (iter1, iter2) => { + let res = iter1.into_par_iter().flatten().chain(iter2.into_par_iter().flatten()); + Some(res) + }, + } + } +} + +impl<'a, Ind1, Ind2, K: 'a, V: 'a, VTI: ParallelIterator + 'a> CRelIndexReadAll<'a> + for RelIndexCombined<'a, Ind1, Ind2> +where + Ind1: CRelIndexReadAll<'a, Key = K, ValueIteratorType = VTI>, + Ind2: CRelIndexReadAll<'a, Key = K, ValueIteratorType = VTI>, +{ + type Key = K; + type Value = V; + + type ValueIteratorType = VTI; + type AllIteratorType = rayon::iter::Chain; + + fn c_iter_all(&'a self) -> Self::AllIteratorType { self.ind1.c_iter_all().chain(self.ind2.c_iter_all()) } +} diff --git a/ascent/src/c_rel_index_read.rs b/ascent/src/c_rel_index_read.rs index 96ade4e..26028f3 100644 --- a/ascent/src/c_rel_index_read.rs +++ b/ascent/src/c_rel_index_read.rs @@ -1,17 +1,16 @@ -use rayon::iter::ParallelIterator; - - -pub trait CRelIndexRead<'a>{ - type Key; - type Value; - type IteratorType: ParallelIterator + Clone + 'a; - fn c_index_get(&'a self, key: &Self::Key) -> Option; -} - -pub trait CRelIndexReadAll<'a>{ - type Key: 'a; - type Value; - type ValueIteratorType: ParallelIterator + 'a; - type AllIteratorType: ParallelIterator + 'a; - fn c_iter_all(&'a self) -> Self::AllIteratorType; -} \ No newline at end of file +use rayon::iter::ParallelIterator; + +pub trait CRelIndexRead<'a> { + type Key; + type Value; + type IteratorType: ParallelIterator + Clone + 'a; + fn c_index_get(&'a self, key: &Self::Key) -> Option; +} + +pub trait CRelIndexReadAll<'a> { + type Key: 'a; + type Value; + type ValueIteratorType: ParallelIterator + 'a; + type AllIteratorType: ParallelIterator + 'a; + fn c_iter_all(&'a self) -> Self::AllIteratorType; +} diff --git a/ascent/src/c_rel_no_index.rs b/ascent/src/c_rel_no_index.rs index f367bd6..d6186e5 100644 --- a/ascent/src/c_rel_no_index.rs +++ b/ascent/src/c_rel_no_index.rs @@ -1,144 +1,146 @@ -use instant::Instant; -use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; - -use crate::internal::{RelIndexWrite, CRelIndexWrite, RelIndexMerge, Freezable}; -use crate::internal::{RelIndexRead, RelIndexReadAll, CRelIndexRead, CRelIndexReadAll}; -use dashmap::RwLock; - -pub struct CRelNoIndex { - // TODO remove pub - pub vec: Vec>>, - // vec: [RwLock>; 32], - frozen: bool, -} - -impl Default for CRelNoIndex { - #[inline] - fn default() -> Self { - let threads = rayon::current_num_threads().max(1); - let mut vec = Vec::with_capacity(threads); - for _ in 0..threads { - vec.push(RwLock::new(vec![])); - } - Self { vec, frozen: false } - - // Self { vec: array_init::array_init(|_| RwLock::new(vec![])), frozen: false } - } -} - -impl CRelNoIndex { - - pub fn hash_usize(&self, _key: &()) -> usize { 0 } -} - -impl Freezable for CRelNoIndex { - fn freeze(&mut self) { self.frozen = true; } - fn unfreeze(&mut self) { self.frozen = false; } -} - -impl<'a, V: 'a> RelIndexRead<'a> for CRelNoIndex { - type Key = (); - type Value = &'a V; - - type IteratorType = std::iter::FlatMap>>, std::slice::Iter<'a, V>, fn(&RwLock>) -> std::slice::Iter>; - - fn index_get(&'a self, _key: &Self::Key) -> Option { - assert!(self.frozen); - let res: Self::IteratorType = self.vec.iter().flat_map(|v| { - let data = unsafe { &*v.data_ptr()}; - data.iter() - }); - Some(res) - } - - #[inline(always)] - fn len(&self) -> usize { 1 } -} - -impl<'a, V: 'a + Sync + Send> CRelIndexRead<'a> for CRelNoIndex { - type Key = (); - type Value = &'a V; - - type IteratorType = rayon::iter::FlatMap>>, fn(&RwLock>) -> rayon::slice::Iter>; - - fn c_index_get(&'a self, _key: &Self::Key) -> Option { - assert!(self.frozen); - let res: Self::IteratorType = self.vec.par_iter().flat_map(|v| { - let data = unsafe {&* v.data_ptr()}; - data.par_iter() - }); - Some(res) - } -} - -impl<'a, V: 'a> RelIndexWrite for CRelNoIndex { - type Key = (); - type Value = V; - - fn index_insert(&mut self, _key: Self::Key, value: Self::Value) { - // not necessary because we have a mut reference - // assert!(!ind.frozen); - let shard_idx = rayon::current_thread_index().unwrap_or(0) % self.vec.len(); - self.vec[shard_idx].get_mut().push(value); - } -} - -impl<'a, V: 'a> RelIndexMerge for CRelNoIndex { - fn move_index_contents(from: &mut Self, to: &mut Self) { - let before = Instant::now(); - assert_eq!(from.len(), to.len()); - // not necessary because we have a mut reference - // assert!(!from.frozen); - // assert!(!to.frozen); - - from.vec.iter_mut().zip(to.vec.iter_mut()).for_each(|(from, to)| { - let from = from.get_mut(); - let to = to.get_mut(); - - if from.len() > to.len() { - std::mem::swap(from, to); - } - to.append(from); - }); - unsafe { - crate::internal::MOVE_NO_INDEX_CONTENTS_TOTAL_TIME += before.elapsed(); - } - } -} - -impl<'a, V: 'a> CRelIndexWrite for CRelNoIndex { - type Key = (); - type Value = V; - - fn index_insert(&self, _key: Self::Key, value: Self::Value) { - assert!(!self.frozen); - let shard_idx = rayon::current_thread_index().unwrap_or(0) % self.vec.len(); - self.vec[shard_idx].write().push(value); - } -} - -impl<'a, V: 'a> RelIndexReadAll<'a> for CRelNoIndex { - type Key = &'a (); - type Value = &'a V; - - type ValueIteratorType = >::IteratorType; - - type AllIteratorType = std::iter::Once<(&'a (), Self::ValueIteratorType)>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - std::iter::once((&(), self.index_get(&()).unwrap())) - } -} - -impl<'a, V: 'a + Sync + Send> CRelIndexReadAll<'a> for CRelNoIndex { - type Key = &'a (); - type Value = &'a V; - - type ValueIteratorType = >::IteratorType; - - type AllIteratorType = rayon::iter::Once<(&'a (), Self::ValueIteratorType)>; - - fn c_iter_all(&'a self) -> Self::AllIteratorType { - rayon::iter::once((&(), self.c_index_get(&()).unwrap())) - } -} \ No newline at end of file +use dashmap::RwLock; +use instant::Instant; +use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; + +use crate::internal::{ + CRelIndexRead, CRelIndexReadAll, CRelIndexWrite, Freezable, RelIndexMerge, RelIndexRead, RelIndexReadAll, + RelIndexWrite, +}; + +pub struct CRelNoIndex { + // TODO remove pub + pub vec: Vec>>, + // vec: [RwLock>; 32], + frozen: bool, +} + +impl Default for CRelNoIndex { + #[inline] + fn default() -> Self { + let threads = rayon::current_num_threads().max(1); + let mut vec = Vec::with_capacity(threads); + for _ in 0..threads { + vec.push(RwLock::new(vec![])); + } + Self { vec, frozen: false } + + // Self { vec: array_init::array_init(|_| RwLock::new(vec![])), frozen: false } + } +} + +impl CRelNoIndex { + pub fn hash_usize(&self, _key: &()) -> usize { 0 } +} + +impl Freezable for CRelNoIndex { + fn freeze(&mut self) { self.frozen = true; } + fn unfreeze(&mut self) { self.frozen = false; } +} + +impl<'a, V: 'a> RelIndexRead<'a> for CRelNoIndex { + type Key = (); + type Value = &'a V; + + type IteratorType = std::iter::FlatMap< + std::slice::Iter<'a, RwLock>>, + std::slice::Iter<'a, V>, + fn(&RwLock>) -> std::slice::Iter, + >; + + fn index_get(&'a self, _key: &Self::Key) -> Option { + assert!(self.frozen); + let res: Self::IteratorType = self.vec.iter().flat_map(|v| { + let data = unsafe { &*v.data_ptr() }; + data.iter() + }); + Some(res) + } + + #[inline(always)] + fn len(&self) -> usize { 1 } +} + +impl<'a, V: 'a + Sync + Send> CRelIndexRead<'a> for CRelNoIndex { + type Key = (); + type Value = &'a V; + + type IteratorType = + rayon::iter::FlatMap>>, fn(&RwLock>) -> rayon::slice::Iter>; + + fn c_index_get(&'a self, _key: &Self::Key) -> Option { + assert!(self.frozen); + let res: Self::IteratorType = self.vec.par_iter().flat_map(|v| { + let data = unsafe { &*v.data_ptr() }; + data.par_iter() + }); + Some(res) + } +} + +impl<'a, V: 'a> RelIndexWrite for CRelNoIndex { + type Key = (); + type Value = V; + + fn index_insert(&mut self, _key: Self::Key, value: Self::Value) { + // not necessary because we have a mut reference + // assert!(!ind.frozen); + let shard_idx = rayon::current_thread_index().unwrap_or(0) % self.vec.len(); + self.vec[shard_idx].get_mut().push(value); + } +} + +impl<'a, V: 'a> RelIndexMerge for CRelNoIndex { + fn move_index_contents(from: &mut Self, to: &mut Self) { + let before = Instant::now(); + assert_eq!(from.len(), to.len()); + // not necessary because we have a mut reference + // assert!(!from.frozen); + // assert!(!to.frozen); + + from.vec.iter_mut().zip(to.vec.iter_mut()).for_each(|(from, to)| { + let from = from.get_mut(); + let to = to.get_mut(); + + if from.len() > to.len() { + std::mem::swap(from, to); + } + to.append(from); + }); + unsafe { + crate::internal::MOVE_NO_INDEX_CONTENTS_TOTAL_TIME += before.elapsed(); + } + } +} + +impl<'a, V: 'a> CRelIndexWrite for CRelNoIndex { + type Key = (); + type Value = V; + + fn index_insert(&self, _key: Self::Key, value: Self::Value) { + assert!(!self.frozen); + let shard_idx = rayon::current_thread_index().unwrap_or(0) % self.vec.len(); + self.vec[shard_idx].write().push(value); + } +} + +impl<'a, V: 'a> RelIndexReadAll<'a> for CRelNoIndex { + type Key = &'a (); + type Value = &'a V; + + type ValueIteratorType = >::IteratorType; + + type AllIteratorType = std::iter::Once<(&'a (), Self::ValueIteratorType)>; + + fn iter_all(&'a self) -> Self::AllIteratorType { std::iter::once((&(), self.index_get(&()).unwrap())) } +} + +impl<'a, V: 'a + Sync + Send> CRelIndexReadAll<'a> for CRelNoIndex { + type Key = &'a (); + type Value = &'a V; + + type ValueIteratorType = >::IteratorType; + + type AllIteratorType = rayon::iter::Once<(&'a (), Self::ValueIteratorType)>; + + fn c_iter_all(&'a self) -> Self::AllIteratorType { rayon::iter::once((&(), self.c_index_get(&()).unwrap())) } +} diff --git a/ascent/src/convert.rs b/ascent/src/convert.rs index 8e72c2a..29bb493 100644 --- a/ascent/src/convert.rs +++ b/ascent/src/convert.rs @@ -1,35 +1,30 @@ - -use std::rc::Rc; -use std::sync::Arc; - -pub trait Convert { - fn convert(source: TSource) -> Self; -} - -impl Convert for T { - #[inline(always)] - fn convert(source: T) -> T {source} -} - -impl Convert<&T> for T where T: Clone { - #[inline(always)] - fn convert(source: &T) -> T {source.clone()} -} - -impl Convert<&str> for String { - fn convert(source: &str) -> Self { - source.to_string() - } -} - -impl Convert<&Rc> for T { - fn convert(source: &Rc) -> Self { - source.as_ref().clone() - } -} - -impl Convert<&Arc> for T { - fn convert(source: &Arc) -> Self { - source.as_ref().clone() - } -} +use std::rc::Rc; +use std::sync::Arc; + +pub trait Convert { + fn convert(source: TSource) -> Self; +} + +impl Convert for T { + #[inline(always)] + fn convert(source: T) -> T { source } +} + +impl Convert<&T> for T +where T: Clone +{ + #[inline(always)] + fn convert(source: &T) -> T { source.clone() } +} + +impl Convert<&str> for String { + fn convert(source: &str) -> Self { source.to_string() } +} + +impl Convert<&Rc> for T { + fn convert(source: &Rc) -> Self { source.as_ref().clone() } +} + +impl Convert<&Arc> for T { + fn convert(source: &Arc) -> Self { source.as_ref().clone() } +} diff --git a/ascent/src/exps.rs b/ascent/src/exps.rs index 712e807..d4973d9 100644 --- a/ascent/src/exps.rs +++ b/ascent/src/exps.rs @@ -1,197 +1,186 @@ -#![cfg(all(test, feature = "par"))] -#![allow(dead_code)] - -use std::sync::Mutex; -use std::sync::atomic::AtomicBool; -use std::time::Instant; - -use rayon::prelude::*; - -use crate::c_rel_index::CRelIndex; -use crate::internal::{RelIndexWrite, Freezable}; -use crate::rel_index_read::RelIndexRead; -use std::sync::atomic::Ordering::Relaxed; - -// #[test] -fn bench_aovec() { - type AOVec = boxcar::Vec; - let size = 125_000_000; - - println!("pushing ..."); - let before = Instant::now(); - let mut vec = vec![]; - for i in 0..size { - vec.push(i); - } - let elapsed = before.elapsed(); - println!("vec time: {:?}", elapsed); - - let before = Instant::now(); - let vec = AOVec::new(); - for i in 0..size { - vec.push(i); - } - let elapsed = before.elapsed(); - println!("ao vec time: {:?}", elapsed); - - ///////////////////////////////// - - println!("\nparallel pushing ..."); - - let before = Instant::now(); - let vec = Mutex::new(vec![]); - (0..size).into_par_iter().for_each(|i| { - vec.lock().unwrap().push(i); - }); - let elapsed = before.elapsed(); - assert_eq!(vec.lock().unwrap().len(), size); - println!("parallel Mutex time: {:?}", elapsed); - - - let before = Instant::now(); - let vec = AOVec::new(); - (0..size).into_par_iter().for_each(|i| { - vec.push(i); - }); - let elapsed = before.elapsed(); - assert_eq!(vec.len(), size); - println!("parallel ao vec time: {:?}", elapsed); -} - -// #[test] -fn bench_atomic_changed() { - type AOVec = boxcar::Vec; - let size = 125_000_000; - - { - - let before = Instant::now(); - let vec = AOVec::new(); - let changed = AtomicBool::new(false); - (0..size).into_par_iter().for_each(|i| { - vec.push(i); - changed.store(true, Relaxed); - }); - let elapsed = before.elapsed(); - println!("changed: {}", changed.load(Relaxed)); - assert_eq!(vec.len(), size); - println!("atomic changed ao vec time: {:?}", elapsed); - } - - { - let before = Instant::now(); - let vec = AOVec::new(); - let changed = (0..size).into_par_iter().fold_with(false, |_changed, i| { - vec.push(i); - true - }); - // let changed = changed.reduce(|| false, |x, y| x | y); - println!("changed count: {}", changed.count()); - let elapsed = before.elapsed(); - // println!("changed: {}", changed); - assert_eq!(vec.len(), size); - println!("therad-local changed ao vec time: {:?}", elapsed); - } -} - - -// #[test] -fn bench_crel_index() { - let mut rel_index = CRelIndex::default(); - - let before = Instant::now(); - for i in 0..1_000_000 { - RelIndexWrite::index_insert(&mut rel_index, i, i); - } - let elapsed = before.elapsed(); - println!("insert time: {:?}", elapsed); - - let iters = 1_000_000; - - let before = Instant::now(); - let mut _sum = 0; - for _ in 0..iters { - crate::internal::Freezable::freeze(&mut rel_index as _); - _sum += rel_index.index_get(&42).unwrap().next().unwrap(); - rel_index.unfreeze(); - } - - let elapsed = before.elapsed(); - - println!("freeze_unfreeze for {} iterations time: {:?}", iters, elapsed); -} - -// #[test] -fn bench_par_iter() { - - let arr = (1..1_000_000).collect::>(); - - let before = Instant::now(); - arr.par_iter().for_each(|x| { - if *x == 42 { - println!("x is 42"); - } - }); - println!("par_iter took {:?}", before.elapsed()); - - let before = Instant::now(); - arr.iter().par_bridge().for_each(|x| { - if *x == 42 { - println!("x is 42"); - } - }); - println!("par_bridge took {:?}", before.elapsed()); -} - -#[test] -fn bench_par_flat_map() { - - fn calc_sum(x: usize) -> usize { - let mut res = 0; - for i in 0..=(x >> 5) { - assert!(res < usize::MAX); - res += i; - } - res - } - - let len = 40; - let mut arr = Vec::with_capacity(len); - - for _ in 0..len { - let vec = (0..1000).collect::>(); - arr.push(vec); - } - - - let before = Instant::now(); - arr.iter().flat_map(|v| v.iter()).for_each(|x| { - calc_sum(*x); - }); - println!("ser flat_map took {:?}", before.elapsed()); - - let before = Instant::now(); - arr.par_iter().flat_map(|v| v.par_iter()).for_each(|x| { - calc_sum(*x); - }); - println!("par flat_map took {:?}", before.elapsed()); - - let before = Instant::now(); - arr.par_iter().flat_map_iter(|v| v.iter()).for_each(|x| { - calc_sum(*x); - }); - println!("par flat_map_iter took {:?}", before.elapsed()); - -} - -#[test] -fn exp_rayon_scope() { - rayon::scope(|__scope| { - __scope.spawn(|_| { - - }); - __scope.spawn(|_| { - - }); - }); -} \ No newline at end of file +#![cfg(all(test, feature = "par"))] +#![allow(dead_code)] + +use std::sync::Mutex; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering::Relaxed; +use std::time::Instant; + +use rayon::prelude::*; + +use crate::c_rel_index::CRelIndex; +use crate::internal::{Freezable, RelIndexWrite}; +use crate::rel_index_read::RelIndexRead; + +// #[test] +fn bench_aovec() { + type AOVec = boxcar::Vec; + let size = 125_000_000; + + println!("pushing ..."); + let before = Instant::now(); + let mut vec = vec![]; + for i in 0..size { + vec.push(i); + } + let elapsed = before.elapsed(); + println!("vec time: {:?}", elapsed); + + let before = Instant::now(); + let vec = AOVec::new(); + for i in 0..size { + vec.push(i); + } + let elapsed = before.elapsed(); + println!("ao vec time: {:?}", elapsed); + + ///////////////////////////////// + + println!("\nparallel pushing ..."); + + let before = Instant::now(); + let vec = Mutex::new(vec![]); + (0..size).into_par_iter().for_each(|i| { + vec.lock().unwrap().push(i); + }); + let elapsed = before.elapsed(); + assert_eq!(vec.lock().unwrap().len(), size); + println!("parallel Mutex time: {:?}", elapsed); + + let before = Instant::now(); + let vec = AOVec::new(); + (0..size).into_par_iter().for_each(|i| { + vec.push(i); + }); + let elapsed = before.elapsed(); + assert_eq!(vec.len(), size); + println!("parallel ao vec time: {:?}", elapsed); +} + +// #[test] +fn bench_atomic_changed() { + type AOVec = boxcar::Vec; + let size = 125_000_000; + + { + let before = Instant::now(); + let vec = AOVec::new(); + let changed = AtomicBool::new(false); + (0..size).into_par_iter().for_each(|i| { + vec.push(i); + changed.store(true, Relaxed); + }); + let elapsed = before.elapsed(); + println!("changed: {}", changed.load(Relaxed)); + assert_eq!(vec.len(), size); + println!("atomic changed ao vec time: {:?}", elapsed); + } + + { + let before = Instant::now(); + let vec = AOVec::new(); + let changed = (0..size).into_par_iter().fold_with(false, |_changed, i| { + vec.push(i); + true + }); + // let changed = changed.reduce(|| false, |x, y| x | y); + println!("changed count: {}", changed.count()); + let elapsed = before.elapsed(); + // println!("changed: {}", changed); + assert_eq!(vec.len(), size); + println!("therad-local changed ao vec time: {:?}", elapsed); + } +} + +// #[test] +fn bench_crel_index() { + let mut rel_index = CRelIndex::default(); + + let before = Instant::now(); + for i in 0..1_000_000 { + RelIndexWrite::index_insert(&mut rel_index, i, i); + } + let elapsed = before.elapsed(); + println!("insert time: {:?}", elapsed); + + let iters = 1_000_000; + + let before = Instant::now(); + let mut _sum = 0; + for _ in 0..iters { + crate::internal::Freezable::freeze(&mut rel_index as _); + _sum += rel_index.index_get(&42).unwrap().next().unwrap(); + rel_index.unfreeze(); + } + + let elapsed = before.elapsed(); + + println!("freeze_unfreeze for {} iterations time: {:?}", iters, elapsed); +} + +// #[test] +fn bench_par_iter() { + let arr = (1..1_000_000).collect::>(); + + let before = Instant::now(); + arr.par_iter().for_each(|x| { + if *x == 42 { + println!("x is 42"); + } + }); + println!("par_iter took {:?}", before.elapsed()); + + let before = Instant::now(); + arr.iter().par_bridge().for_each(|x| { + if *x == 42 { + println!("x is 42"); + } + }); + println!("par_bridge took {:?}", before.elapsed()); +} + +#[test] +fn bench_par_flat_map() { + fn calc_sum(x: usize) -> usize { + let mut res = 0; + for i in 0..=(x >> 5) { + assert!(res < usize::MAX); + res += i; + } + res + } + + let len = 40; + let mut arr = Vec::with_capacity(len); + + for _ in 0..len { + let vec = (0..1000).collect::>(); + arr.push(vec); + } + + let before = Instant::now(); + arr.iter().flat_map(|v| v.iter()).for_each(|x| { + calc_sum(*x); + }); + println!("ser flat_map took {:?}", before.elapsed()); + + let before = Instant::now(); + arr.par_iter().flat_map(|v| v.par_iter()).for_each(|x| { + calc_sum(*x); + }); + println!("par flat_map took {:?}", before.elapsed()); + + let before = Instant::now(); + arr.par_iter().flat_map_iter(|v| v.iter()).for_each(|x| { + calc_sum(*x); + }); + println!("par flat_map_iter took {:?}", before.elapsed()); +} + +#[test] +fn exp_rayon_scope() { + rayon::scope(|__scope| { + __scope.spawn(|_| {}); + __scope.spawn(|_| {}); + }); +} diff --git a/ascent/src/internal.rs b/ascent/src/internal.rs index e2e41ef..a0427c0 100644 --- a/ascent/src/internal.rs +++ b/ascent/src/internal.rs @@ -1,245 +1,241 @@ -//! Provides definitions required for the `ascent` macro(s), plus traits that custom relations need to implement. - -pub use crate::convert::*; - -use std::time::Duration; -use std::hash::{BuildHasherDefault, Hash}; -use std::collections::{HashMap, HashSet}; - -use cfg_if::cfg_if; -pub use instant::Instant; - -use ascent_base::Lattice; -use rustc_hash::FxHasher; - -pub use crate::rel_index_read::RelIndexCombined; -pub use crate::rel_index_read::RelIndexRead; -pub use crate::rel_index_read::RelIndexReadAll; - -pub type RelIndexType = RelIndexType1; - -pub type LatticeIndexType = HashMap>, BuildHasherDefault>; - -pub(crate) type HashBrownRelFullIndexType = hashbrown::HashMap>; -pub type RelFullIndexType = HashBrownRelFullIndexType; - -pub type RelNoIndexType = Vec; - -cfg_if! { - if #[cfg(feature = "par")] { - pub use crate::c_rel_index_read::CRelIndexRead; - pub use crate::c_rel_index_read::CRelIndexReadAll; - - pub use crate::c_rel_index::shards_count; - - pub use crate::c_rel_index::CRelIndex; - pub use crate::c_rel_full_index::CRelFullIndex; - pub use crate::c_lat_index::CLatIndex; - pub use crate::c_rel_no_index::CRelNoIndex; - pub use crate::c_rel_index::DashMapViewParIter; - } -} - -pub use crate::to_rel_index::{ToRelIndex0, ToRelIndex}; -pub use crate::tuple_of_borrowed::TupleOfBorrowed; - - -pub trait Freezable { - fn freeze(&mut self) { } - fn unfreeze(&mut self) { } -} - -pub trait RelIndexWrite: Sized { - type Key; - type Value; - fn index_insert(&mut self, key: Self::Key, value: Self::Value); -} - -pub trait RelIndexMerge: Sized { - fn move_index_contents(from: &mut Self, to: &mut Self); - fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { - Self::move_index_contents(delta, total); - std::mem::swap(new, delta); - } - - /// Called once at the start of the SCC - #[allow(unused_variables)] - fn init(new: &mut Self, delta: &mut Self, total: &mut Self) { } -} - -pub trait CRelIndexWrite{ - type Key; - type Value; - fn index_insert(&self, key: Self::Key, value: Self::Value); -} - -pub trait RelFullIndexRead<'a> { - type Key; - fn contains_key(&'a self, key: &Self::Key) -> bool; -} - - -pub trait RelFullIndexWrite { - type Key: Clone; - type Value; - /// if an entry for `key` does not exist, inserts `v` for it and returns true. - fn insert_if_not_present(&mut self, key: &Self::Key, v: Self::Value) -> bool; -} - -pub trait CRelFullIndexWrite { - type Key: Clone; - type Value; - /// if an entry for `key` does not exist, inserts `v` for it and returns true. - fn insert_if_not_present(&self, key: &Self::Key, v: Self::Value) -> bool; -} - - -pub type RelIndexType1 = HashMap, BuildHasherDefault>; - -pub static mut MOVE_REL_INDEX_CONTENTS_TOTAL_TIME : Duration = Duration::ZERO; -pub static mut INDEX_INSERT_TOTAL_TIME : Duration = Duration::ZERO; - -impl RelIndexWrite for RelIndexType1{ - type Key = K; - type Value = V; - - fn index_insert(&mut self, key: K, value: V) { - // let before = Instant::now(); - use std::collections::hash_map::Entry::*; - match self.entry(key){ - Occupied(mut vec) => vec.get_mut().push(value), - Vacant(vacant) => { - let mut vec = Vec::with_capacity(4); - vec.push(value); - vacant.insert(vec); - }, - } - // unsafe { - // INDEX_INSERT_TOTAL_TIME += before.elapsed(); - // } - } -} - -impl RelIndexMerge for RelIndexType1 { - fn move_index_contents(from: &mut RelIndexType1, to: &mut RelIndexType1) { - let before = Instant::now(); - if from.len() > to.len() { - std::mem::swap(from, to); - } - use std::collections::hash_map::Entry::*; - for (k, mut v) in from.drain() { - match to.entry(k) { - Occupied(existing) => { - let existing = existing.into_mut(); - if v.len() > existing.len() { - std::mem::swap(&mut v, existing); - } - existing.append(&mut v); - }, - Vacant(vacant) => { - vacant.insert(v); - }, - } - } - unsafe { - MOVE_REL_INDEX_CONTENTS_TOTAL_TIME += before.elapsed(); - } - } -} - -impl RelIndexWrite for RelNoIndexType { - type Key = (); - type Value = usize; - - fn index_insert(&mut self, _key: Self::Key, tuple_index: usize) { - self.push(tuple_index); - } -} - -impl RelIndexMerge for RelNoIndexType { - fn move_index_contents(ind1: &mut Self, ind2: &mut Self) { - ind2.append(ind1); - } -} - -impl RelIndexWrite for LatticeIndexType{ - type Key = K; - type Value = V; - - #[inline(always)] - fn index_insert(&mut self, key: Self::Key, tuple_index: V) { - self.entry(key).or_default().insert(tuple_index); - } -} - -impl RelIndexMerge for LatticeIndexType{ - #[inline(always)] - fn move_index_contents(hm1: &mut LatticeIndexType, hm2: &mut LatticeIndexType) { - for (k,v) in hm1.drain(){ - let set = hm2.entry(k).or_default(); - set.extend(v); - } - } -} - - -pub static mut MOVE_FULL_INDEX_CONTENTS_TOTAL_TIME : Duration = Duration::ZERO; -pub static mut MOVE_NO_INDEX_CONTENTS_TOTAL_TIME : Duration = Duration::ZERO; - -impl RelIndexWrite for HashBrownRelFullIndexType{ - type Key = K; - type Value = V; - - #[inline(always)] - fn index_insert(&mut self, key: Self::Key, value: V) { - self.insert(key, value); - } -} - -impl RelIndexMerge for HashBrownRelFullIndexType { - fn move_index_contents(from: &mut Self, to: &mut Self) { - let before = Instant::now(); - if from.len() > to.len() { - std::mem::swap(from, to); - } - to.reserve(from.len()); - for (k, v) in from.drain() { - to.insert(k, v); // TODO could be improved - } - unsafe { - MOVE_FULL_INDEX_CONTENTS_TOTAL_TIME += before.elapsed(); - } - } -} - -impl RelFullIndexWrite for HashBrownRelFullIndexType { - type Key = K; - type Value = V; - #[inline] - fn insert_if_not_present(&mut self, key: &K, v: V) -> bool { - match self.raw_entry_mut().from_key(key) { - hashbrown::hash_map::RawEntryMut::Occupied(_) => false, - hashbrown::hash_map::RawEntryMut::Vacant(vacant) => {vacant.insert(key.clone(), v); true}, - } - } -} - -impl<'a, K: Hash + Eq, V> RelFullIndexRead<'a> for HashBrownRelFullIndexType { - type Key = K; - - fn contains_key(&self, key: &Self::Key) -> bool { - self.contains_key(key) - } -} - - -/// type constraints for relation columns -pub struct TypeConstraints where T : Clone + Eq + Hash{_t: ::core::marker::PhantomData} -/// type constraints for a lattice -pub struct LatTypeConstraints where T : Clone + Eq + Hash + Lattice{_t: ::core::marker::PhantomData} - -/// type constraints for parallel Ascent -pub struct ParTypeConstraints where T: Send + Sync {_t: ::core::marker::PhantomData} - -#[inline(always)] -pub fn comment(_: &str){} +//! Provides definitions required for the `ascent` macro(s), plus traits that custom relations need to implement. + +use std::collections::{HashMap, HashSet}; +use std::hash::{BuildHasherDefault, Hash}; +use std::time::Duration; + +use ascent_base::Lattice; +use cfg_if::cfg_if; +pub use instant::Instant; +use rustc_hash::FxHasher; + +pub use crate::convert::*; +pub use crate::rel_index_read::{RelIndexCombined, RelIndexRead, RelIndexReadAll}; + +pub type RelIndexType = RelIndexType1; + +pub type LatticeIndexType = HashMap>, BuildHasherDefault>; + +pub(crate) type HashBrownRelFullIndexType = hashbrown::HashMap>; +pub type RelFullIndexType = HashBrownRelFullIndexType; + +pub type RelNoIndexType = Vec; + +cfg_if! { + if #[cfg(feature = "par")] { + pub use crate::c_rel_index_read::CRelIndexRead; + pub use crate::c_rel_index_read::CRelIndexReadAll; + + pub use crate::c_rel_index::shards_count; + + pub use crate::c_rel_index::CRelIndex; + pub use crate::c_rel_full_index::CRelFullIndex; + pub use crate::c_lat_index::CLatIndex; + pub use crate::c_rel_no_index::CRelNoIndex; + pub use crate::c_rel_index::DashMapViewParIter; + } +} + +pub use crate::to_rel_index::{ToRelIndex, ToRelIndex0}; +pub use crate::tuple_of_borrowed::TupleOfBorrowed; + +pub trait Freezable { + fn freeze(&mut self) {} + fn unfreeze(&mut self) {} +} + +pub trait RelIndexWrite: Sized { + type Key; + type Value; + fn index_insert(&mut self, key: Self::Key, value: Self::Value); +} + +pub trait RelIndexMerge: Sized { + fn move_index_contents(from: &mut Self, to: &mut Self); + fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { + Self::move_index_contents(delta, total); + std::mem::swap(new, delta); + } + + /// Called once at the start of the SCC + #[allow(unused_variables)] + fn init(new: &mut Self, delta: &mut Self, total: &mut Self) {} +} + +pub trait CRelIndexWrite { + type Key; + type Value; + fn index_insert(&self, key: Self::Key, value: Self::Value); +} + +pub trait RelFullIndexRead<'a> { + type Key; + fn contains_key(&'a self, key: &Self::Key) -> bool; +} + +pub trait RelFullIndexWrite { + type Key: Clone; + type Value; + /// if an entry for `key` does not exist, inserts `v` for it and returns true. + fn insert_if_not_present(&mut self, key: &Self::Key, v: Self::Value) -> bool; +} + +pub trait CRelFullIndexWrite { + type Key: Clone; + type Value; + /// if an entry for `key` does not exist, inserts `v` for it and returns true. + fn insert_if_not_present(&self, key: &Self::Key, v: Self::Value) -> bool; +} + +pub type RelIndexType1 = HashMap, BuildHasherDefault>; + +pub static mut MOVE_REL_INDEX_CONTENTS_TOTAL_TIME: Duration = Duration::ZERO; +pub static mut INDEX_INSERT_TOTAL_TIME: Duration = Duration::ZERO; + +impl RelIndexWrite for RelIndexType1 { + type Key = K; + type Value = V; + + fn index_insert(&mut self, key: K, value: V) { + // let before = Instant::now(); + use std::collections::hash_map::Entry::*; + match self.entry(key) { + Occupied(mut vec) => vec.get_mut().push(value), + Vacant(vacant) => { + let mut vec = Vec::with_capacity(4); + vec.push(value); + vacant.insert(vec); + }, + } + // unsafe { + // INDEX_INSERT_TOTAL_TIME += before.elapsed(); + // } + } +} + +impl RelIndexMerge for RelIndexType1 { + fn move_index_contents(from: &mut RelIndexType1, to: &mut RelIndexType1) { + let before = Instant::now(); + if from.len() > to.len() { + std::mem::swap(from, to); + } + use std::collections::hash_map::Entry::*; + for (k, mut v) in from.drain() { + match to.entry(k) { + Occupied(existing) => { + let existing = existing.into_mut(); + if v.len() > existing.len() { + std::mem::swap(&mut v, existing); + } + existing.append(&mut v); + }, + Vacant(vacant) => { + vacant.insert(v); + }, + } + } + unsafe { + MOVE_REL_INDEX_CONTENTS_TOTAL_TIME += before.elapsed(); + } + } +} + +impl RelIndexWrite for RelNoIndexType { + type Key = (); + type Value = usize; + + fn index_insert(&mut self, _key: Self::Key, tuple_index: usize) { self.push(tuple_index); } +} + +impl RelIndexMerge for RelNoIndexType { + fn move_index_contents(ind1: &mut Self, ind2: &mut Self) { ind2.append(ind1); } +} + +impl RelIndexWrite for LatticeIndexType { + type Key = K; + type Value = V; + + #[inline(always)] + fn index_insert(&mut self, key: Self::Key, tuple_index: V) { self.entry(key).or_default().insert(tuple_index); } +} + +impl RelIndexMerge for LatticeIndexType { + #[inline(always)] + fn move_index_contents(hm1: &mut LatticeIndexType, hm2: &mut LatticeIndexType) { + for (k, v) in hm1.drain() { + let set = hm2.entry(k).or_default(); + set.extend(v); + } + } +} + +pub static mut MOVE_FULL_INDEX_CONTENTS_TOTAL_TIME: Duration = Duration::ZERO; +pub static mut MOVE_NO_INDEX_CONTENTS_TOTAL_TIME: Duration = Duration::ZERO; + +impl RelIndexWrite for HashBrownRelFullIndexType { + type Key = K; + type Value = V; + + #[inline(always)] + fn index_insert(&mut self, key: Self::Key, value: V) { self.insert(key, value); } +} + +impl RelIndexMerge for HashBrownRelFullIndexType { + fn move_index_contents(from: &mut Self, to: &mut Self) { + let before = Instant::now(); + if from.len() > to.len() { + std::mem::swap(from, to); + } + to.reserve(from.len()); + for (k, v) in from.drain() { + to.insert(k, v); // TODO could be improved + } + unsafe { + MOVE_FULL_INDEX_CONTENTS_TOTAL_TIME += before.elapsed(); + } + } +} + +impl RelFullIndexWrite for HashBrownRelFullIndexType { + type Key = K; + type Value = V; + #[inline] + fn insert_if_not_present(&mut self, key: &K, v: V) -> bool { + match self.raw_entry_mut().from_key(key) { + hashbrown::hash_map::RawEntryMut::Occupied(_) => false, + hashbrown::hash_map::RawEntryMut::Vacant(vacant) => { + vacant.insert(key.clone(), v); + true + }, + } + } +} + +impl<'a, K: Hash + Eq, V> RelFullIndexRead<'a> for HashBrownRelFullIndexType { + type Key = K; + + fn contains_key(&self, key: &Self::Key) -> bool { self.contains_key(key) } +} + +/// type constraints for relation columns +pub struct TypeConstraints +where T: Clone + Eq + Hash +{ + _t: ::core::marker::PhantomData, +} +/// type constraints for a lattice +pub struct LatTypeConstraints +where T: Clone + Eq + Hash + Lattice +{ + _t: ::core::marker::PhantomData, +} + +/// type constraints for parallel Ascent +pub struct ParTypeConstraints +where T: Send + Sync +{ + _t: ::core::marker::PhantomData, +} + +#[inline(always)] +pub fn comment(_: &str) {} diff --git a/ascent/src/lib.rs b/ascent/src/lib.rs index 9f1fa28..be4091a 100644 --- a/ascent/src/lib.rs +++ b/ascent/src/lib.rs @@ -1,5 +1,5 @@ //! Ascent enables writing logic programs in the style of Datalog in Rust. -//! +//! //! See the documentation for [`ascent`], one of the main macros of this crate, for more information. #![deny(unused_crate_dependencies)] @@ -26,15 +26,12 @@ mod to_rel_index; mod tuple_of_borrowed; mod rel_index_boilerplate; +pub use ascent_base::*; pub use ascent_macro::{ascent, ascent_run}; #[cfg(feature = "par")] pub use ascent_macro::{ascent_par, ascent_run_par}; - -pub use ascent_base::*; - -pub use hashbrown; #[cfg(feature = "par")] pub use dashmap; -pub use boxcar; #[cfg(feature = "par")] pub use rayon; +pub use {boxcar, hashbrown}; diff --git a/ascent/src/rel.rs b/ascent/src/rel.rs index 9382ff0..5e60902 100644 --- a/ascent/src/rel.rs +++ b/ascent/src/rel.rs @@ -1,169 +1,209 @@ -//! The default data structure provider for Ascent relations - -macro_rules! _rel_type_template { - ($field_types: ty, $indices: expr, $par: ident) => {}; -} - -macro_rules! _rel_ind_template { - ($field_types: ty, $indices: expr, $par: ident, $ind: expr) => {}; -} - -#[doc(hidden)] -#[macro_export] -macro_rules! rel_codegen { - ( $($tt: tt)* ) => { }; -} -pub use rel_codegen; - -#[doc(hidden)] -#[macro_export] -macro_rules! rel { - ($name: ident, $field_types: ty, $indices: expr, ser, ()) => { - ::std::vec::Vec<$field_types> - }; - ($name: ident, $field_types: ty, $indices: expr, par, ()) => { - ::ascent::boxcar::Vec<$field_types> - }; -} -pub use rel; - -#[doc(hidden)] -#[macro_export] -macro_rules! rel_ind_common { - ($name: ident, $field_types: ty, $indices: expr, ser, ()) => { - () - }; - ($name: ident, $field_types: ty, $indices: expr, par, ()) => { - () - }; -} -pub use rel_ind_common; - -#[doc(hidden)] -#[macro_export] -macro_rules! rel_full_ind { - ($name: ident, $field_types: ty, $indices: expr, ser, (), $key: ty, $val: ty) => { - ascent::internal::RelFullIndexType<$key, $val> - }; - ($name: ident, $field_types: ty, $indices: expr, par, (), $key: ty, $val: ty) => { - ascent::internal::CRelFullIndex<$key, $val> - }; -} -pub use rel_full_ind; - -#[doc(hidden)] -#[macro_export] -macro_rules! rel_ind { - ($name: ident, $field_types: ty, $indices: expr, ser, (), $ind: expr, $key: ty, $val: ty) => { - ascent::rel::ToRelIndexType<$key, $val> - }; - ($name: ident, $field_types: ty, $indices: expr, par, (), [], $key: ty, $val: ty) => { - ascent::internal::CRelNoIndex<$val> - }; - ($name: ident, $field_types: ty, $indices: expr, par, (), $ind: expr, $key: ty, $val: ty) => { - ascent::internal::CRelIndex<$key, $val> - }; -} -pub use rel_ind; - -#[derive(Clone)] -pub struct ToRelIndexType(pub RelIndexType1); - -impl Default for ToRelIndexType { - #[inline(always)] - fn default() -> Self { Self(Default::default()) } -} - -impl ToRelIndex for ToRelIndexType { - type RelIndex<'a> = &'a RelIndexType1 where Self: 'a, R: 'a; - - #[inline(always)] - fn to_rel_index<'a>(&'a self, _rel: &'a R) -> Self::RelIndex<'a> { - &self.0 - } - - type RelIndexWrite<'a> = &'a mut RelIndexType1 where Self: 'a, R: 'a; - - #[inline(always)] - fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut R) -> Self::RelIndexWrite<'a> { - &mut self.0 - } -} - -use crate::internal::{Freezable, RelFullIndexType, RelIndexMerge, RelIndexType1}; -use crate::to_rel_index::ToRelIndex; - -impl ToRelIndex for RelIndexType1 { - type RelIndex<'a> = &'a Self where Self: 'a, Rel: 'a; - #[inline(always)] - fn to_rel_index<'a>(&'a self, _rel: &'a Rel) -> Self::RelIndex<'a> { self } - - type RelIndexWrite<'a> = &'a mut Self where Self: 'a, Rel: 'a; - #[inline(always)] - fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { self } -} - -impl ToRelIndex for RelFullIndexType { - - type RelIndex<'a> = &'a Self where Self: 'a, Rel: 'a; - #[inline(always)] - fn to_rel_index<'a>(&'a self, _rel: &'a Rel) -> Self::RelIndex<'a> { self } - - type RelIndexWrite<'a> = &'a mut Self where Self: 'a, Rel: 'a; - #[inline(always)] - fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { self } -} - -#[cfg(feature = "par")] -mod par { - use crate::c_rel_full_index::CRelFullIndex; - use crate::c_rel_index::CRelIndex; - use crate::c_rel_no_index::CRelNoIndex; - use crate::to_rel_index::ToRelIndex; - - impl ToRelIndex for CRelIndex { - - type RelIndex<'a> = &'a Self where Self: 'a, Rel: 'a; - #[inline(always)] - fn to_rel_index<'a>(&'a self, _rel: &'a Rel) -> Self::RelIndex<'a> { self } - - type RelIndexWrite<'a> = &'a mut Self where Self: 'a, Rel: 'a; - #[inline(always)] - fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { self } - } - - impl ToRelIndex for CRelNoIndex { - type RelIndex<'a> = &'a Self where Self: 'a, Rel: 'a; - #[inline(always)] - fn to_rel_index<'a>(&'a self, _rel: &'a Rel) -> Self::RelIndex<'a> { self } - - type RelIndexWrite<'a> = &'a mut Self where Self: 'a, Rel: 'a; - #[inline(always)] - fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { self } - } - - impl ToRelIndex for CRelFullIndex { - - type RelIndex<'a> = &'a Self where Self: 'a, Rel: 'a; - #[inline(always)] - fn to_rel_index<'a>(&'a self, _rel: &'a Rel) -> Self::RelIndex<'a> { self } - - type RelIndexWrite<'a> = &'a mut Self where Self: 'a, Rel: 'a; - #[inline(always)] - fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { self } - } -} - - -impl RelIndexMerge for () { - #[inline(always)] - fn move_index_contents(_from: &mut Self, _to: &mut Self) { } - - #[inline(always)] - fn merge_delta_to_total_new_to_delta(_new: &mut Self, _delta: &mut Self, _total: &mut Self) { } -} - -impl Freezable for () { - fn freeze(&mut self) { } - fn unfreeze(&mut self) { } -} \ No newline at end of file +//! The default data structure provider for Ascent relations + +macro_rules! _rel_type_template { + ($field_types: ty, $indices: expr, $par: ident) => {}; +} + +macro_rules! _rel_ind_template { + ($field_types: ty, $indices: expr, $par: ident, $ind: expr) => {}; +} + +#[doc(hidden)] +#[macro_export] +macro_rules! rel_codegen { + ( $($tt: tt)* ) => { }; +} +pub use rel_codegen; + +#[doc(hidden)] +#[macro_export] +macro_rules! rel { + ($name: ident, $field_types: ty, $indices: expr, ser, ()) => { + ::std::vec::Vec<$field_types> + }; + ($name: ident, $field_types: ty, $indices: expr, par, ()) => { + ::ascent::boxcar::Vec<$field_types> + }; +} +pub use rel; + +#[doc(hidden)] +#[macro_export] +macro_rules! rel_ind_common { + ($name: ident, $field_types: ty, $indices: expr, ser, ()) => { + () + }; + ($name: ident, $field_types: ty, $indices: expr, par, ()) => { + () + }; +} +pub use rel_ind_common; + +#[doc(hidden)] +#[macro_export] +macro_rules! rel_full_ind { + ($name: ident, $field_types: ty, $indices: expr, ser, (), $key: ty, $val: ty) => { + ascent::internal::RelFullIndexType<$key, $val> + }; + ($name: ident, $field_types: ty, $indices: expr, par, (), $key: ty, $val: ty) => { + ascent::internal::CRelFullIndex<$key, $val> + }; +} +pub use rel_full_ind; + +#[doc(hidden)] +#[macro_export] +macro_rules! rel_ind { + ($name: ident, $field_types: ty, $indices: expr, ser, (), $ind: expr, $key: ty, $val: ty) => { + ascent::rel::ToRelIndexType<$key, $val> + }; + ($name: ident, $field_types: ty, $indices: expr, par, (), [], $key: ty, $val: ty) => { + ascent::internal::CRelNoIndex<$val> + }; + ($name: ident, $field_types: ty, $indices: expr, par, (), $ind: expr, $key: ty, $val: ty) => { + ascent::internal::CRelIndex<$key, $val> + }; +} +pub use rel_ind; + +#[derive(Clone)] +pub struct ToRelIndexType(pub RelIndexType1); + +impl Default for ToRelIndexType { + #[inline(always)] + fn default() -> Self { Self(Default::default()) } +} + +impl ToRelIndex for ToRelIndexType { + type RelIndex<'a> + = &'a RelIndexType1 + where + Self: 'a, + R: 'a; + + #[inline(always)] + fn to_rel_index<'a>(&'a self, _rel: &'a R) -> Self::RelIndex<'a> { &self.0 } + + type RelIndexWrite<'a> + = &'a mut RelIndexType1 + where + Self: 'a, + R: 'a; + + #[inline(always)] + fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut R) -> Self::RelIndexWrite<'a> { &mut self.0 } +} + +use crate::internal::{Freezable, RelFullIndexType, RelIndexMerge, RelIndexType1}; +use crate::to_rel_index::ToRelIndex; + +impl ToRelIndex for RelIndexType1 { + type RelIndex<'a> + = &'a Self + where + Self: 'a, + Rel: 'a; + #[inline(always)] + fn to_rel_index<'a>(&'a self, _rel: &'a Rel) -> Self::RelIndex<'a> { self } + + type RelIndexWrite<'a> + = &'a mut Self + where + Self: 'a, + Rel: 'a; + #[inline(always)] + fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { self } +} + +impl ToRelIndex for RelFullIndexType { + type RelIndex<'a> + = &'a Self + where + Self: 'a, + Rel: 'a; + #[inline(always)] + fn to_rel_index<'a>(&'a self, _rel: &'a Rel) -> Self::RelIndex<'a> { self } + + type RelIndexWrite<'a> + = &'a mut Self + where + Self: 'a, + Rel: 'a; + #[inline(always)] + fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { self } +} + +#[cfg(feature = "par")] +mod par { + use crate::c_rel_full_index::CRelFullIndex; + use crate::c_rel_index::CRelIndex; + use crate::c_rel_no_index::CRelNoIndex; + use crate::to_rel_index::ToRelIndex; + + impl ToRelIndex for CRelIndex { + type RelIndex<'a> + = &'a Self + where + Self: 'a, + Rel: 'a; + #[inline(always)] + fn to_rel_index<'a>(&'a self, _rel: &'a Rel) -> Self::RelIndex<'a> { self } + + type RelIndexWrite<'a> + = &'a mut Self + where + Self: 'a, + Rel: 'a; + #[inline(always)] + fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { self } + } + + impl ToRelIndex for CRelNoIndex { + type RelIndex<'a> + = &'a Self + where + Self: 'a, + Rel: 'a; + #[inline(always)] + fn to_rel_index<'a>(&'a self, _rel: &'a Rel) -> Self::RelIndex<'a> { self } + + type RelIndexWrite<'a> + = &'a mut Self + where + Self: 'a, + Rel: 'a; + #[inline(always)] + fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { self } + } + + impl ToRelIndex for CRelFullIndex { + type RelIndex<'a> + = &'a Self + where + Self: 'a, + Rel: 'a; + #[inline(always)] + fn to_rel_index<'a>(&'a self, _rel: &'a Rel) -> Self::RelIndex<'a> { self } + + type RelIndexWrite<'a> + = &'a mut Self + where + Self: 'a, + Rel: 'a; + #[inline(always)] + fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { self } + } +} + +impl RelIndexMerge for () { + #[inline(always)] + fn move_index_contents(_from: &mut Self, _to: &mut Self) {} + + #[inline(always)] + fn merge_delta_to_total_new_to_delta(_new: &mut Self, _delta: &mut Self, _total: &mut Self) {} +} + +impl Freezable for () { + fn freeze(&mut self) {} + fn unfreeze(&mut self) {} +} diff --git a/ascent/src/rel_index_boilerplate.rs b/ascent/src/rel_index_boilerplate.rs index 1466f99..5f92f6d 100644 --- a/ascent/src/rel_index_boilerplate.rs +++ b/ascent/src/rel_index_boilerplate.rs @@ -1,117 +1,120 @@ -use crate::internal::{RelFullIndexWrite, RelIndexWrite, RelFullIndexRead, RelIndexMerge, CRelIndexWrite, CRelFullIndexWrite}; -use crate::rel_index_read::{RelIndexRead, RelIndexReadAll}; - -impl<'a, T> RelIndexWrite for &'a mut T where T: RelIndexWrite { - type Key = T::Key; - type Value = T::Value; - - #[inline(always)] - fn index_insert(&mut self, key: Self::Key, value: Self::Value) { (**self).index_insert(key, value) } -} - -impl<'a, T> RelFullIndexWrite for &'a mut T where T:RelFullIndexWrite { - type Key = T::Key; - type Value = T::Value; - - #[inline(always)] - fn insert_if_not_present(&mut self, key: &Self::Key, v: Self::Value) -> bool { - (**self).insert_if_not_present(key, v) - } -} - -impl<'a, T> CRelIndexWrite for &'a T where T: CRelIndexWrite { - type Key = T::Key; - type Value = T::Value; - - #[inline(always)] - fn index_insert(&self, key: Self::Key, value: Self::Value) { - (**self).index_insert(key, value) - } -} - -impl<'a, T> CRelFullIndexWrite for &'a T where T: CRelFullIndexWrite { - type Key = T::Key; - type Value = T::Value; - - #[inline(always)] - fn insert_if_not_present(&self, key: &Self::Key, v: Self::Value) -> bool { - (**self).insert_if_not_present(key, v) - } -} - -impl<'a, T> RelIndexMerge for &'a mut T where T:RelIndexMerge { - fn move_index_contents(from: &mut Self, to: &mut Self) { - T::move_index_contents(*from, *to) - } - - fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { - T::merge_delta_to_total_new_to_delta(*new, *delta, *total) - } - - fn init(new: &mut Self, delta: &mut Self, total: &mut Self) { - T::init(new, delta, total) - } -} - -impl<'a, T> RelIndexRead<'a> for &'a T where T: RelIndexRead<'a> { - type Key = T::Key; - type Value = T::Value; - type IteratorType = T::IteratorType; - - #[inline(always)] - fn index_get(&'a self, key: &Self::Key) -> Option { - (**self).index_get(key) - } - - #[inline(always)] - fn len(&self) -> usize { - (**self).len() - } -} - -impl<'a, T> RelIndexReadAll<'a> for &'a T where T: RelIndexReadAll<'a> { - type Key = T::Key; - type Value = T::Value; - type ValueIteratorType = T::ValueIteratorType; - type AllIteratorType = T::AllIteratorType; - - #[inline(always)] - fn iter_all(&'a self) -> Self::AllIteratorType { - (**self).iter_all() - } -} - -impl<'a, T> RelFullIndexRead<'a> for &'a T where T:RelFullIndexRead<'a> { - type Key = T::Key; - #[inline(always)] - fn contains_key(&self, key: &Self::Key) -> bool { (**self).contains_key(key) } -} - -#[cfg(feature = "par")] -mod par { - use crate::internal::{CRelIndexRead, CRelIndexReadAll}; - - impl<'a, T> CRelIndexRead<'a> for &'a T where T: CRelIndexRead<'a> { - type Key = T::Key; - type Value = T::Value; - type IteratorType = T::IteratorType; - - #[inline(always)] - fn c_index_get(&'a self, key: &Self::Key) -> Option { - (**self).c_index_get(key) - } - - } - - impl<'a, T> CRelIndexReadAll<'a> for &'a T where T: CRelIndexReadAll<'a> { - type Key = T::Key; - type Value = T::Value; - type ValueIteratorType = T::ValueIteratorType; - type AllIteratorType = T::AllIteratorType; - - #[inline(always)] - fn c_iter_all(&'a self) -> Self::AllIteratorType { - (**self).c_iter_all() - } - } -} +use crate::internal::{ + CRelFullIndexWrite, CRelIndexWrite, RelFullIndexRead, RelFullIndexWrite, RelIndexMerge, RelIndexWrite, +}; +use crate::rel_index_read::{RelIndexRead, RelIndexReadAll}; + +impl<'a, T> RelIndexWrite for &'a mut T +where T: RelIndexWrite +{ + type Key = T::Key; + type Value = T::Value; + + #[inline(always)] + fn index_insert(&mut self, key: Self::Key, value: Self::Value) { (**self).index_insert(key, value) } +} + +impl<'a, T> RelFullIndexWrite for &'a mut T +where T: RelFullIndexWrite +{ + type Key = T::Key; + type Value = T::Value; + + #[inline(always)] + fn insert_if_not_present(&mut self, key: &Self::Key, v: Self::Value) -> bool { + (**self).insert_if_not_present(key, v) + } +} + +impl<'a, T> CRelIndexWrite for &'a T +where T: CRelIndexWrite +{ + type Key = T::Key; + type Value = T::Value; + + #[inline(always)] + fn index_insert(&self, key: Self::Key, value: Self::Value) { (**self).index_insert(key, value) } +} + +impl<'a, T> CRelFullIndexWrite for &'a T +where T: CRelFullIndexWrite +{ + type Key = T::Key; + type Value = T::Value; + + #[inline(always)] + fn insert_if_not_present(&self, key: &Self::Key, v: Self::Value) -> bool { (**self).insert_if_not_present(key, v) } +} + +impl<'a, T> RelIndexMerge for &'a mut T +where T: RelIndexMerge +{ + fn move_index_contents(from: &mut Self, to: &mut Self) { T::move_index_contents(*from, *to) } + + fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { + T::merge_delta_to_total_new_to_delta(*new, *delta, *total) + } + + fn init(new: &mut Self, delta: &mut Self, total: &mut Self) { T::init(new, delta, total) } +} + +impl<'a, T> RelIndexRead<'a> for &'a T +where T: RelIndexRead<'a> +{ + type Key = T::Key; + type Value = T::Value; + type IteratorType = T::IteratorType; + + #[inline(always)] + fn index_get(&'a self, key: &Self::Key) -> Option { (**self).index_get(key) } + + #[inline(always)] + fn len(&self) -> usize { (**self).len() } +} + +impl<'a, T> RelIndexReadAll<'a> for &'a T +where T: RelIndexReadAll<'a> +{ + type Key = T::Key; + type Value = T::Value; + type ValueIteratorType = T::ValueIteratorType; + type AllIteratorType = T::AllIteratorType; + + #[inline(always)] + fn iter_all(&'a self) -> Self::AllIteratorType { (**self).iter_all() } +} + +impl<'a, T> RelFullIndexRead<'a> for &'a T +where T: RelFullIndexRead<'a> +{ + type Key = T::Key; + #[inline(always)] + fn contains_key(&self, key: &Self::Key) -> bool { (**self).contains_key(key) } +} + +#[cfg(feature = "par")] +mod par { + use crate::internal::{CRelIndexRead, CRelIndexReadAll}; + + impl<'a, T> CRelIndexRead<'a> for &'a T + where T: CRelIndexRead<'a> + { + type Key = T::Key; + type Value = T::Value; + type IteratorType = T::IteratorType; + + #[inline(always)] + fn c_index_get(&'a self, key: &Self::Key) -> Option { (**self).c_index_get(key) } + } + + impl<'a, T> CRelIndexReadAll<'a> for &'a T + where T: CRelIndexReadAll<'a> + { + type Key = T::Key; + type Value = T::Value; + type ValueIteratorType = T::ValueIteratorType; + type AllIteratorType = T::AllIteratorType; + + #[inline(always)] + fn c_iter_all(&'a self) -> Self::AllIteratorType { (**self).c_iter_all() } + } +} diff --git a/ascent/src/rel_index_read.rs b/ascent/src/rel_index_read.rs index 6b73070..cb6c2bb 100644 --- a/ascent/src/rel_index_read.rs +++ b/ascent/src/rel_index_read.rs @@ -1,211 +1,214 @@ - -use rustc_hash::FxHasher; - -use crate::internal::*; - -use std::collections::HashSet; -use core::slice::Iter; -use std::hash::BuildHasherDefault; -use std::iter::Chain; - -#[allow(clippy::len_without_is_empty)] -pub trait RelIndexRead<'a>{ - type Key; - type Value; - type IteratorType: Iterator + Clone + 'a; - fn index_get(&'a self, key: &Self::Key) -> Option; - fn len(&'a self) -> usize; -} - -pub trait RelIndexReadAll<'a>{ - type Key: 'a; - type Value; - type ValueIteratorType: Iterator + 'a; - type AllIteratorType: Iterator + 'a; - fn iter_all(&'a self) -> Self::AllIteratorType; -} - -impl<'a, K: Eq + std::hash::Hash + 'a, V: Clone + 'a> RelIndexRead<'a> for RelIndexType1 { - type IteratorType = core::slice::Iter<'a, V>; - type Key = K; - type Value = &'a V; - - #[inline] - fn index_get(&'a self, key: &K) -> Option { - let v = self.get(key)?; - Some(v.iter()) - } - - #[inline(always)] - fn len(&self) -> usize { - Self::len(self) - } -} - -impl<'a, K: Eq + std::hash::Hash + 'a, V: 'a + Clone> RelIndexReadAll<'a> for RelIndexType1 { - - type Key = &'a K; - type Value = &'a V; - type ValueIteratorType = core::slice::Iter<'a, V>; - - type AllIteratorType = std::iter::Map>, for <'aa, 'bb> fn ((&'aa K, &'bb Vec)) -> (&'aa K, Iter<'bb, V>)>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res: Self::AllIteratorType = self.iter().map(|(k, v)| (k, v.iter())); - res - } -} - - -impl<'a, K: Eq + std::hash::Hash, V: 'a + Clone> RelIndexRead<'a> for HashBrownRelFullIndexType { - type IteratorType = std::iter::Once<&'a V>; - type Key = K; - type Value = &'a V; - - - #[inline] - fn index_get(&'a self, key: &K) -> Option { - let res = self.get(key)?; - Some(std::iter::once(res)) - } - - #[inline(always)] - fn len(&self) -> usize { - Self::len(self) - } -} - -impl<'a, K: Eq + std::hash::Hash + 'a, V: 'a + Clone> RelIndexReadAll<'a> for HashBrownRelFullIndexType { - - type Key = &'a K; - type Value = &'a V; - type ValueIteratorType = std::iter::Once<&'a V>; - - type AllIteratorType = std::iter::Map, for <'aa, 'bb> fn ((&'aa K, &'bb V)) -> (&'aa K, std::iter::Once<&'bb V>)>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res: Self::AllIteratorType = self.iter().map(|(k, v)| (k, std::iter::once(v))); - res - } -} - - -impl<'a, K: Eq + std::hash::Hash, V: 'a + Clone> RelIndexRead<'a> for LatticeIndexType { - type IteratorType = std::collections::hash_set::Iter<'a, V>; - type Key = K; - type Value = &'a V; - - - #[inline] - fn index_get(&'a self, key: &K) -> Option { - let res: Option> = - self.get(key).map(HashSet::iter); - res - } - - #[inline(always)] - fn len(&self) -> usize { - Self::len(self) - } -} - -impl<'a, K: Eq + std::hash::Hash + 'a, V: 'a + Clone> RelIndexReadAll<'a> for LatticeIndexType { - - type Key = &'a K; - type Value = &'a V; - type ValueIteratorType = std::collections::hash_set::Iter<'a, V>; - - type AllIteratorType = std::iter::Map>>, for <'aa, 'bb> fn ((&'aa K, &'bb std::collections::HashSet>)) -> (&'aa K, std::collections::hash_set::Iter<'bb, V>)>; - - #[inline] - fn iter_all(&'a self) -> Self::AllIteratorType { - let res: Self::AllIteratorType = self.iter().map(|(k, v)| (k, v.iter())); - res - } -} - - -pub struct RelIndexCombined<'a, Ind1, Ind2> { - pub ind1: &'a Ind1, - pub ind2: &'a Ind2, -} - -impl <'a, Ind1, Ind2> RelIndexCombined<'a, Ind1, Ind2> { - #[inline] - pub fn new(ind1: &'a Ind1, ind2: &'a Ind2) -> Self { Self { ind1, ind2 } } -} - -impl <'a, Ind1, Ind2, K, V> RelIndexRead<'a> for RelIndexCombined<'a, Ind1, Ind2> -where Ind1: RelIndexRead<'a, Key = K, Value = V>, Ind2: RelIndexRead<'a, Key = K, Value = V>, { - type Key = K; - type Value = V; - - type IteratorType = Chain>, - std::iter::Flatten>>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - match (self.ind1.index_get(key), self.ind2.index_get(key)) { - (None, None) => None, - (iter1, iter2) => { - let res = iter1.into_iter().flatten().chain(iter2.into_iter().flatten()); - Some(res) - } - } - } - - #[inline(always)] - fn len(&self) -> usize { self.ind1.len() + self.ind2.len() } -} - -// impl <'a, Ind> RelIndexRead<'a> for RelIndexCombined<'a, Ind, Ind> -// where Ind: RelIndexTrait2<'a> { -// type Key = Ind::Key; - -// type IteratorType = EitherIter>; - -// fn index_get(&'a self, key: &Self::Key) -> Option { -// let res: Option::IteratorType, Chain<::IteratorType, ::IteratorType>>> = -// match (self.ind1.index_get(key), self.ind2.index_get(key)) { -// (None, None) => None, -// (Some(it), None) | (None, Some(it)) => Some(EitherIter::Left(it)), -// (Some(it1), Some(it2)) => Some(EitherIter::Right(it1.chain(it2))) -// }; -// res -// } - -// #[inline(always)] -// fn len(&self) -> usize { self.ind1.len() + self.ind2.len() } -// } - -// #[derive(Clone)] -// pub enum EitherIter { -// Left(L), -// Right(R) -// } - -// impl , R: Iterator, T> Iterator for EitherIter { -// type Item = T; - -// fn next(&mut self) -> Option { -// match self { -// EitherIter::Left(l) => l.next(), -// EitherIter::Right(r) => r.next(), -// } -// } -// } - -impl <'a, Ind1, Ind2, K: 'a, V: 'a, VTI: Iterator + 'a> RelIndexReadAll<'a> for RelIndexCombined<'a, Ind1, Ind2> -where Ind1: RelIndexReadAll<'a, Key = K, ValueIteratorType = VTI>, Ind2: RelIndexReadAll<'a, Key = K, ValueIteratorType = VTI> -{ - type Key = K; - type Value = V; - - type ValueIteratorType = VTI; - - type AllIteratorType = Chain; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res = self.ind1.iter_all().chain(self.ind2.iter_all()); - res - } -} \ No newline at end of file +use core::slice::Iter; +use std::collections::HashSet; +use std::hash::BuildHasherDefault; +use std::iter::Chain; + +use rustc_hash::FxHasher; + +use crate::internal::*; + +#[allow(clippy::len_without_is_empty)] +pub trait RelIndexRead<'a> { + type Key; + type Value; + type IteratorType: Iterator + Clone + 'a; + fn index_get(&'a self, key: &Self::Key) -> Option; + fn len(&'a self) -> usize; +} + +pub trait RelIndexReadAll<'a> { + type Key: 'a; + type Value; + type ValueIteratorType: Iterator + 'a; + type AllIteratorType: Iterator + 'a; + fn iter_all(&'a self) -> Self::AllIteratorType; +} + +impl<'a, K: Eq + std::hash::Hash + 'a, V: Clone + 'a> RelIndexRead<'a> for RelIndexType1 { + type IteratorType = core::slice::Iter<'a, V>; + type Key = K; + type Value = &'a V; + + #[inline] + fn index_get(&'a self, key: &K) -> Option { + let v = self.get(key)?; + Some(v.iter()) + } + + #[inline(always)] + fn len(&self) -> usize { Self::len(self) } +} + +impl<'a, K: Eq + std::hash::Hash + 'a, V: 'a + Clone> RelIndexReadAll<'a> for RelIndexType1 { + type Key = &'a K; + type Value = &'a V; + type ValueIteratorType = core::slice::Iter<'a, V>; + + type AllIteratorType = std::iter::Map< + std::collections::hash_map::Iter<'a, K, Vec>, + for<'aa, 'bb> fn((&'aa K, &'bb Vec)) -> (&'aa K, Iter<'bb, V>), + >; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res: Self::AllIteratorType = self.iter().map(|(k, v)| (k, v.iter())); + res + } +} + +impl<'a, K: Eq + std::hash::Hash, V: 'a + Clone> RelIndexRead<'a> for HashBrownRelFullIndexType { + type IteratorType = std::iter::Once<&'a V>; + type Key = K; + type Value = &'a V; + + #[inline] + fn index_get(&'a self, key: &K) -> Option { + let res = self.get(key)?; + Some(std::iter::once(res)) + } + + #[inline(always)] + fn len(&self) -> usize { Self::len(self) } +} + +impl<'a, K: Eq + std::hash::Hash + 'a, V: 'a + Clone> RelIndexReadAll<'a> for HashBrownRelFullIndexType { + type Key = &'a K; + type Value = &'a V; + type ValueIteratorType = std::iter::Once<&'a V>; + + type AllIteratorType = std::iter::Map< + hashbrown::hash_map::Iter<'a, K, V>, + for<'aa, 'bb> fn((&'aa K, &'bb V)) -> (&'aa K, std::iter::Once<&'bb V>), + >; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res: Self::AllIteratorType = self.iter().map(|(k, v)| (k, std::iter::once(v))); + res + } +} + +impl<'a, K: Eq + std::hash::Hash, V: 'a + Clone> RelIndexRead<'a> for LatticeIndexType { + type IteratorType = std::collections::hash_set::Iter<'a, V>; + type Key = K; + type Value = &'a V; + + #[inline] + fn index_get(&'a self, key: &K) -> Option { + let res: Option> = self.get(key).map(HashSet::iter); + res + } + + #[inline(always)] + fn len(&self) -> usize { Self::len(self) } +} + +impl<'a, K: Eq + std::hash::Hash + 'a, V: 'a + Clone> RelIndexReadAll<'a> for LatticeIndexType { + type Key = &'a K; + type Value = &'a V; + type ValueIteratorType = std::collections::hash_set::Iter<'a, V>; + + type AllIteratorType = std::iter::Map< + std::collections::hash_map::Iter<'a, K, HashSet>>, + for<'aa, 'bb> fn( + (&'aa K, &'bb std::collections::HashSet>), + ) -> (&'aa K, std::collections::hash_set::Iter<'bb, V>), + >; + + #[inline] + fn iter_all(&'a self) -> Self::AllIteratorType { + let res: Self::AllIteratorType = self.iter().map(|(k, v)| (k, v.iter())); + res + } +} + +pub struct RelIndexCombined<'a, Ind1, Ind2> { + pub ind1: &'a Ind1, + pub ind2: &'a Ind2, +} + +impl<'a, Ind1, Ind2> RelIndexCombined<'a, Ind1, Ind2> { + #[inline] + pub fn new(ind1: &'a Ind1, ind2: &'a Ind2) -> Self { Self { ind1, ind2 } } +} + +impl<'a, Ind1, Ind2, K, V> RelIndexRead<'a> for RelIndexCombined<'a, Ind1, Ind2> +where + Ind1: RelIndexRead<'a, Key = K, Value = V>, + Ind2: RelIndexRead<'a, Key = K, Value = V>, +{ + type Key = K; + type Value = V; + + type IteratorType = Chain< + std::iter::Flatten>, + std::iter::Flatten>, + >; + + fn index_get(&'a self, key: &Self::Key) -> Option { + match (self.ind1.index_get(key), self.ind2.index_get(key)) { + (None, None) => None, + (iter1, iter2) => { + let res = iter1.into_iter().flatten().chain(iter2.into_iter().flatten()); + Some(res) + }, + } + } + + #[inline(always)] + fn len(&self) -> usize { self.ind1.len() + self.ind2.len() } +} + +// impl <'a, Ind> RelIndexRead<'a> for RelIndexCombined<'a, Ind, Ind> +// where Ind: RelIndexTrait2<'a> { +// type Key = Ind::Key; + +// type IteratorType = EitherIter>; + +// fn index_get(&'a self, key: &Self::Key) -> Option { +// let res: Option::IteratorType, Chain<::IteratorType, ::IteratorType>>> = +// match (self.ind1.index_get(key), self.ind2.index_get(key)) { +// (None, None) => None, +// (Some(it), None) | (None, Some(it)) => Some(EitherIter::Left(it)), +// (Some(it1), Some(it2)) => Some(EitherIter::Right(it1.chain(it2))) +// }; +// res +// } + +// #[inline(always)] +// fn len(&self) -> usize { self.ind1.len() + self.ind2.len() } +// } + +// #[derive(Clone)] +// pub enum EitherIter { +// Left(L), +// Right(R) +// } + +// impl , R: Iterator, T> Iterator for EitherIter { +// type Item = T; + +// fn next(&mut self) -> Option { +// match self { +// EitherIter::Left(l) => l.next(), +// EitherIter::Right(r) => r.next(), +// } +// } +// } + +impl<'a, Ind1, Ind2, K: 'a, V: 'a, VTI: Iterator + 'a> RelIndexReadAll<'a> + for RelIndexCombined<'a, Ind1, Ind2> +where + Ind1: RelIndexReadAll<'a, Key = K, ValueIteratorType = VTI>, + Ind2: RelIndexReadAll<'a, Key = K, ValueIteratorType = VTI>, +{ + type Key = K; + type Value = V; + + type ValueIteratorType = VTI; + + type AllIteratorType = Chain; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res = self.ind1.iter_all().chain(self.ind2.iter_all()); + res + } +} diff --git a/ascent/src/to_rel_index.rs b/ascent/src/to_rel_index.rs index f8fdadb..97e0453 100644 --- a/ascent/src/to_rel_index.rs +++ b/ascent/src/to_rel_index.rs @@ -1,42 +1,62 @@ - -pub trait ToRelIndex0 { - - type RelIndex<'a> where Self: 'a, Rel: 'a; - fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a>; - - type RelIndexWrite<'a> where Self: 'a, Rel: 'a; - fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut Rel) -> Self::RelIndexWrite<'a>; - - type CRelIndexWrite<'a> where Self: 'a, Rel: 'a; - fn to_c_rel_index_write<'a>(&'a self, rel: &'a Rel) -> Self::CRelIndexWrite<'a>; -} - -pub trait ToRelIndex { - - type RelIndex<'a> where Self: 'a, Rel: 'a; - fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a>; - - type RelIndexWrite<'a> where Self: 'a, Rel: 'a; - fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut Rel) -> Self::RelIndexWrite<'a>; -} - -impl ToRelIndex0 for T where T: ToRelIndex { - type RelIndex<'a> = T::RelIndex<'a> where Self: 'a, Rel: 'a; - - #[inline(always)] - fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { - self.to_rel_index(rel) - } - - type RelIndexWrite<'a> = T::RelIndexWrite<'a> where Self: 'a, Rel: 'a; - #[inline(always)] - fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { - self.to_rel_index_write(rel) - } - - type CRelIndexWrite<'a> = T::RelIndex<'a> where Self: 'a, Rel: 'a; - #[inline(always)] - fn to_c_rel_index_write<'a>(&'a self, rel: &'a Rel) -> Self::CRelIndexWrite<'a> { - self.to_rel_index(rel) - } -} +pub trait ToRelIndex0 { + type RelIndex<'a> + where + Self: 'a, + Rel: 'a; + fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a>; + + type RelIndexWrite<'a> + where + Self: 'a, + Rel: 'a; + fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut Rel) -> Self::RelIndexWrite<'a>; + + type CRelIndexWrite<'a> + where + Self: 'a, + Rel: 'a; + fn to_c_rel_index_write<'a>(&'a self, rel: &'a Rel) -> Self::CRelIndexWrite<'a>; +} + +pub trait ToRelIndex { + type RelIndex<'a> + where + Self: 'a, + Rel: 'a; + fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a>; + + type RelIndexWrite<'a> + where + Self: 'a, + Rel: 'a; + fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut Rel) -> Self::RelIndexWrite<'a>; +} + +impl ToRelIndex0 for T +where T: ToRelIndex +{ + type RelIndex<'a> + = T::RelIndex<'a> + where + Self: 'a, + Rel: 'a; + + #[inline(always)] + fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { self.to_rel_index(rel) } + + type RelIndexWrite<'a> + = T::RelIndexWrite<'a> + where + Self: 'a, + Rel: 'a; + #[inline(always)] + fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { self.to_rel_index_write(rel) } + + type CRelIndexWrite<'a> + = T::RelIndex<'a> + where + Self: 'a, + Rel: 'a; + #[inline(always)] + fn to_c_rel_index_write<'a>(&'a self, rel: &'a Rel) -> Self::CRelIndexWrite<'a> { self.to_rel_index(rel) } +} diff --git a/ascent/src/tuple_of_borrowed.rs b/ascent/src/tuple_of_borrowed.rs index 45e4625..193261b 100644 --- a/ascent/src/tuple_of_borrowed.rs +++ b/ascent/src/tuple_of_borrowed.rs @@ -1,59 +1,55 @@ -use paste::paste; - -pub trait TupleOfBorrowed { - type Tuple; - fn tuple_of_borrowed(self) -> Self::Tuple; -} - -impl<'a, T1> TupleOfBorrowed for &'a (T1,) { - type Tuple = (&'a T1,); - - #[inline(always)] - fn tuple_of_borrowed(self) -> Self::Tuple { - (&self.0,) - } -} - -impl<'a, T1> TupleOfBorrowed for (&'a T1,) { - type Tuple = Self; - - #[inline(always)] - fn tuple_of_borrowed(self) -> Self::Tuple { - self - } -} - -macro_rules! impl_tuple_of_borrowed { - ($($i: literal),*) => { paste!{ - impl<'a, $([]),*> TupleOfBorrowed for &'a ($([]),*) { - type Tuple = ($(&'a []),*); - #[allow(clippy::unused_unit)] - #[inline(always)] - fn tuple_of_borrowed(self) -> Self::Tuple { - ($(&self.$i),*) - } - } - - impl<'a, $([]),*> TupleOfBorrowed for ($(&'a []),*) { - type Tuple = Self; - #[inline(always)] - fn tuple_of_borrowed(self) -> Self::Tuple { - self - } - } - }}; -} - -impl_tuple_of_borrowed!(); -// impl_tuple_of_borrowed!(0); -impl_tuple_of_borrowed!(0, 1); -impl_tuple_of_borrowed!(0, 1, 2); -impl_tuple_of_borrowed!(0, 1, 2, 3); -impl_tuple_of_borrowed!(0, 1, 2, 3, 4); -impl_tuple_of_borrowed!(0, 1, 2, 3, 4, 5); -impl_tuple_of_borrowed!(0, 1, 2, 3, 4, 5, 6); -impl_tuple_of_borrowed!(0, 1, 2, 3, 4, 5, 6, 7); -impl_tuple_of_borrowed!(0, 1, 2, 3, 4, 5, 6, 7, 8); -impl_tuple_of_borrowed!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); -impl_tuple_of_borrowed!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10); -impl_tuple_of_borrowed!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11); +use paste::paste; + +pub trait TupleOfBorrowed { + type Tuple; + fn tuple_of_borrowed(self) -> Self::Tuple; +} + +impl<'a, T1> TupleOfBorrowed for &'a (T1,) { + type Tuple = (&'a T1,); + + #[inline(always)] + fn tuple_of_borrowed(self) -> Self::Tuple { (&self.0,) } +} + +impl<'a, T1> TupleOfBorrowed for (&'a T1,) { + type Tuple = Self; + + #[inline(always)] + fn tuple_of_borrowed(self) -> Self::Tuple { self } +} + +macro_rules! impl_tuple_of_borrowed { + ($($i: literal),*) => { paste!{ + impl<'a, $([]),*> TupleOfBorrowed for &'a ($([]),*) { + type Tuple = ($(&'a []),*); + #[allow(clippy::unused_unit)] + #[inline(always)] + fn tuple_of_borrowed(self) -> Self::Tuple { + ($(&self.$i),*) + } + } + + impl<'a, $([]),*> TupleOfBorrowed for ($(&'a []),*) { + type Tuple = Self; + #[inline(always)] + fn tuple_of_borrowed(self) -> Self::Tuple { + self + } + } + }}; +} + +impl_tuple_of_borrowed!(); +// impl_tuple_of_borrowed!(0); +impl_tuple_of_borrowed!(0, 1); +impl_tuple_of_borrowed!(0, 1, 2); +impl_tuple_of_borrowed!(0, 1, 2, 3); +impl_tuple_of_borrowed!(0, 1, 2, 3, 4); +impl_tuple_of_borrowed!(0, 1, 2, 3, 4, 5); +impl_tuple_of_borrowed!(0, 1, 2, 3, 4, 5, 6); +impl_tuple_of_borrowed!(0, 1, 2, 3, 4, 5, 6, 7); +impl_tuple_of_borrowed!(0, 1, 2, 3, 4, 5, 6, 7, 8); +impl_tuple_of_borrowed!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); +impl_tuple_of_borrowed!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10); +impl_tuple_of_borrowed!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11); diff --git a/ascent_base/src/lattice.rs b/ascent_base/src/lattice.rs index f1ee2be..1219229 100644 --- a/ascent_base/src/lattice.rs +++ b/ascent_base/src/lattice.rs @@ -1,247 +1,238 @@ -//! Defines the `Lattice` trait and provides implementations for standard types - -pub mod constant_propagation; -pub mod set; -pub mod product; -pub mod ord_lattice; -pub mod bounded_set; -pub use product::Product; -pub mod tuple; -use std::cmp::{Ordering, Reverse}; -use std::sync::Arc; -use std::rc::Rc; -mod dual; -pub use dual::Dual; - -/// A `Lattice` is a `PartialOrd` where each pair of elements has a least upper bound (`join`) and a greatest lower bound (`meet`) -pub trait Lattice: PartialOrd + Sized { - /// ensures `self` is the join of `self` and `other` - /// - /// Returns true if `self` was changed. - fn meet_mut(&mut self, other: Self) -> bool; - - /// ensures `self` is the meet of `self` and `other`. - /// - /// Returns true if `self` was changed. - fn join_mut(&mut self, other: Self) -> bool; - - /// The greatest lower bound of two elements. `meet(x, y)` is the biggest value `z` - /// s.t. `z <= x` and `z <= y` - fn meet(mut self, other: Self) -> Self { - self.meet_mut(other); - self - } - - /// The least upper bound of two elements. `join(x, y)` is the smallest value `z` - /// s.t. `z >= x` and `z >= y`. - fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } -} - -pub trait BoundedLattice: Lattice { - fn bottom() -> Self; - fn top() -> Self; -} - -macro_rules! ord_lattice_impl { - ($t: ty) => { - impl Lattice for $t { - fn meet_mut(&mut self, other: Self) -> bool { - #[allow(clippy::neg_cmp_op_on_partial_ord)] - let changed = !(*self <= other); - if changed { - *self = other; - } - changed - } - - fn join_mut(&mut self, other: Self) -> bool { - #[allow(clippy::neg_cmp_op_on_partial_ord)] - let changed = !(*self >= other); - if changed { - *self = other; - } - changed - } - } - }; -} - -ord_lattice_impl!(bool); - -impl BoundedLattice for bool { - #[inline] - fn bottom() -> Self { false } - #[inline] - fn top() -> Self { true } -} - -macro_rules! num_lattice_impl { - ($int:ty) => { - ord_lattice_impl!($int); - impl BoundedLattice for $int { - fn bottom() -> Self { Self::MIN } - fn top() -> Self { Self::MAX } - } - }; -} - -num_lattice_impl!(i8); -num_lattice_impl!(u8); -num_lattice_impl!(i16); -num_lattice_impl!(u16); -num_lattice_impl!(i32); -num_lattice_impl!(u32); -num_lattice_impl!(i64); -num_lattice_impl!(u64); -num_lattice_impl!(i128); -num_lattice_impl!(u128); - -num_lattice_impl!(isize); -num_lattice_impl!(usize); - -impl Lattice for Option { - fn meet_mut(&mut self, other: Self) -> bool { - match (self, other) { - (Some(x), Some(y)) => { - x.meet_mut(y) - }, - (this @ Some(_), None) => { - *this = None; - true - }, - (None, _) => false - } - } - - fn join_mut(&mut self, other: Self) -> bool { - match (self, other) { - (Some(x), Some(y)) => { - x.join_mut(y) - }, - (this @ None, Some(y)) => { - *this = Some(y); - true - } - (_, None) => false - } - } -} - -impl BoundedLattice for Option { - #[inline] - fn bottom() -> Self { None } - #[inline] - fn top() -> Self { Some(T::top()) } -} - - -impl Lattice for Rc { - fn meet_mut(&mut self, other: Self) -> bool { - match self.as_ref().partial_cmp(&other) { - Some(Ordering::Less | Ordering::Equal) => false, - Some(Ordering::Greater) => { - *self = other; - true - } - // Stable in 1.76: - // None => Rc::make_mut(self).meet_mut(Rc::unwrap_or_clone(other)) - None => Rc::make_mut(self).meet_mut(Rc::try_unwrap(other).unwrap_or_else(|rc| (*rc).clone())) - } - } - - fn join_mut(&mut self, other: Self) -> bool { - match self.as_ref().partial_cmp(&other) { - Some(Ordering::Greater | Ordering::Equal) => false, - Some(Ordering::Less) => { - *self = other; - true - }, - // Stable in 1.76: - // None => Rc::make_mut(self).join_mut(Rc::unwrap_or_clone(other)) - None => Rc::make_mut(self).join_mut(Rc::try_unwrap(other).unwrap_or_else(|rc| (*rc).clone())) - } - } -} - -impl Lattice for Arc { - fn meet_mut(&mut self, other: Self) -> bool { - match self.as_ref().partial_cmp(&other) { - Some(Ordering::Less | Ordering::Equal) => false, - Some(Ordering::Greater) => { - *self = other; - true - } - // Stable in 1.76: - // None => Arc::make_mut(self).meet_mut(Arc::unwrap_or_clone(other)) - None => Arc::make_mut(self).meet_mut(Arc::try_unwrap(other).unwrap_or_else(|rc| (*rc).clone())) - } - } - - fn join_mut(&mut self, other: Self) -> bool { - match self.as_ref().partial_cmp(&other) { - Some(Ordering::Greater | Ordering::Equal) => false, - Some(Ordering::Less) => { - *self = other; - true - }, - // Stable in 1.76: - // None => Arc::make_mut(self).join_mut(Arc::unwrap_or_clone(other)) - None => Arc::make_mut(self).join_mut(Arc::try_unwrap(other).unwrap_or_else(|rc| (*rc).clone())) - } - } -} - -impl Lattice for Box { - fn meet_mut(&mut self, other: Self) -> bool { - self.as_mut().meet_mut(*other) - } - - fn join_mut(&mut self, other: Self) -> bool { - self.as_mut().join_mut(*other) - } -} - -impl Lattice for Reverse { - #[inline] - fn meet(self, other: Self) -> Self { Reverse(self.0.join(other.0)) } - - #[inline] - fn join(self, other: Self) -> Self { Reverse(self.0.meet(other.0)) } - - #[inline] - fn meet_mut(&mut self, other: Self) -> bool { self.0.join_mut(other.0) } - - #[inline] - fn join_mut(&mut self, other: Self) -> bool { self.0.meet_mut(other.0) } -} - -impl BoundedLattice for Reverse { - #[inline] - fn bottom() -> Self { Reverse(T::top()) } - - #[inline] - fn top() -> Self { Reverse(T::bottom()) } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use crate::Lattice; - - #[test] - fn test_arc_lattice() { - let x = Arc::new(42); - let y = Arc::new(17); - assert_eq!(*x.clone().meet(y.clone()), 17); - assert_eq!(*x.meet(y), 17); - - let x = Arc::new(42); - let y = Arc::new(17); - assert_eq!(*x.clone().join(y.clone()), 42); - assert_eq!(*x.join(y), 42); - } -} +//! Defines the `Lattice` trait and provides implementations for standard types + +pub mod constant_propagation; +pub mod set; +pub mod product; +pub mod ord_lattice; +pub mod bounded_set; +pub use product::Product; +pub mod tuple; +use std::cmp::{Ordering, Reverse}; +use std::rc::Rc; +use std::sync::Arc; +mod dual; +pub use dual::Dual; + +/// A `Lattice` is a `PartialOrd` where each pair of elements has a least upper bound (`join`) and a greatest lower bound (`meet`) +pub trait Lattice: PartialOrd + Sized { + /// ensures `self` is the join of `self` and `other` + /// + /// Returns true if `self` was changed. + fn meet_mut(&mut self, other: Self) -> bool; + + /// ensures `self` is the meet of `self` and `other`. + /// + /// Returns true if `self` was changed. + fn join_mut(&mut self, other: Self) -> bool; + + /// The greatest lower bound of two elements. `meet(x, y)` is the biggest value `z` + /// s.t. `z <= x` and `z <= y` + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + /// The least upper bound of two elements. `join(x, y)` is the smallest value `z` + /// s.t. `z >= x` and `z >= y`. + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } +} + +pub trait BoundedLattice: Lattice { + fn bottom() -> Self; + fn top() -> Self; +} + +macro_rules! ord_lattice_impl { + ($t: ty) => { + impl Lattice for $t { + fn meet_mut(&mut self, other: Self) -> bool { + #[allow(clippy::neg_cmp_op_on_partial_ord)] + let changed = !(*self <= other); + if changed { + *self = other; + } + changed + } + + fn join_mut(&mut self, other: Self) -> bool { + #[allow(clippy::neg_cmp_op_on_partial_ord)] + let changed = !(*self >= other); + if changed { + *self = other; + } + changed + } + } + }; +} + +ord_lattice_impl!(bool); + +impl BoundedLattice for bool { + #[inline] + fn bottom() -> Self { false } + #[inline] + fn top() -> Self { true } +} + +macro_rules! num_lattice_impl { + ($int:ty) => { + ord_lattice_impl!($int); + impl BoundedLattice for $int { + fn bottom() -> Self { Self::MIN } + fn top() -> Self { Self::MAX } + } + }; +} + +num_lattice_impl!(i8); +num_lattice_impl!(u8); +num_lattice_impl!(i16); +num_lattice_impl!(u16); +num_lattice_impl!(i32); +num_lattice_impl!(u32); +num_lattice_impl!(i64); +num_lattice_impl!(u64); +num_lattice_impl!(i128); +num_lattice_impl!(u128); + +num_lattice_impl!(isize); +num_lattice_impl!(usize); + +impl Lattice for Option { + fn meet_mut(&mut self, other: Self) -> bool { + match (self, other) { + (Some(x), Some(y)) => x.meet_mut(y), + (this @ Some(_), None) => { + *this = None; + true + }, + (None, _) => false, + } + } + + fn join_mut(&mut self, other: Self) -> bool { + match (self, other) { + (Some(x), Some(y)) => x.join_mut(y), + (this @ None, Some(y)) => { + *this = Some(y); + true + }, + (_, None) => false, + } + } +} + +impl BoundedLattice for Option { + #[inline] + fn bottom() -> Self { None } + #[inline] + fn top() -> Self { Some(T::top()) } +} + +impl Lattice for Rc { + fn meet_mut(&mut self, other: Self) -> bool { + match self.as_ref().partial_cmp(&other) { + Some(Ordering::Less | Ordering::Equal) => false, + Some(Ordering::Greater) => { + *self = other; + true + }, + // Stable in 1.76: + // None => Rc::make_mut(self).meet_mut(Rc::unwrap_or_clone(other)) + None => Rc::make_mut(self).meet_mut(Rc::try_unwrap(other).unwrap_or_else(|rc| (*rc).clone())), + } + } + + fn join_mut(&mut self, other: Self) -> bool { + match self.as_ref().partial_cmp(&other) { + Some(Ordering::Greater | Ordering::Equal) => false, + Some(Ordering::Less) => { + *self = other; + true + }, + // Stable in 1.76: + // None => Rc::make_mut(self).join_mut(Rc::unwrap_or_clone(other)) + None => Rc::make_mut(self).join_mut(Rc::try_unwrap(other).unwrap_or_else(|rc| (*rc).clone())), + } + } +} + +impl Lattice for Arc { + fn meet_mut(&mut self, other: Self) -> bool { + match self.as_ref().partial_cmp(&other) { + Some(Ordering::Less | Ordering::Equal) => false, + Some(Ordering::Greater) => { + *self = other; + true + }, + // Stable in 1.76: + // None => Arc::make_mut(self).meet_mut(Arc::unwrap_or_clone(other)) + None => Arc::make_mut(self).meet_mut(Arc::try_unwrap(other).unwrap_or_else(|rc| (*rc).clone())), + } + } + + fn join_mut(&mut self, other: Self) -> bool { + match self.as_ref().partial_cmp(&other) { + Some(Ordering::Greater | Ordering::Equal) => false, + Some(Ordering::Less) => { + *self = other; + true + }, + // Stable in 1.76: + // None => Arc::make_mut(self).join_mut(Arc::unwrap_or_clone(other)) + None => Arc::make_mut(self).join_mut(Arc::try_unwrap(other).unwrap_or_else(|rc| (*rc).clone())), + } + } +} + +impl Lattice for Box { + fn meet_mut(&mut self, other: Self) -> bool { self.as_mut().meet_mut(*other) } + + fn join_mut(&mut self, other: Self) -> bool { self.as_mut().join_mut(*other) } +} + +impl Lattice for Reverse { + #[inline] + fn meet(self, other: Self) -> Self { Reverse(self.0.join(other.0)) } + + #[inline] + fn join(self, other: Self) -> Self { Reverse(self.0.meet(other.0)) } + + #[inline] + fn meet_mut(&mut self, other: Self) -> bool { self.0.join_mut(other.0) } + + #[inline] + fn join_mut(&mut self, other: Self) -> bool { self.0.meet_mut(other.0) } +} + +impl BoundedLattice for Reverse { + #[inline] + fn bottom() -> Self { Reverse(T::top()) } + + #[inline] + fn top() -> Self { Reverse(T::bottom()) } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::Lattice; + + #[test] + fn test_arc_lattice() { + let x = Arc::new(42); + let y = Arc::new(17); + assert_eq!(*x.clone().meet(y.clone()), 17); + assert_eq!(*x.meet(y), 17); + + let x = Arc::new(42); + let y = Arc::new(17); + assert_eq!(*x.clone().join(y.clone()), 42); + assert_eq!(*x.join(y), 42); + } +} diff --git a/ascent_base/src/lattice/bounded_set.rs b/ascent_base/src/lattice/bounded_set.rs index 65d3bc4..37d45cc 100644 --- a/ascent_base/src/lattice/bounded_set.rs +++ b/ascent_base/src/lattice/bounded_set.rs @@ -1,165 +1,136 @@ -use std::hash::Hash; - -use crate::Lattice; - -use super::{set::Set, BoundedLattice}; - -/// `BoundedSet` is a generalization of the flat lattice. -/// -/// A `BoundedSet` stores at most `BOUND` items, and if asked to store more, will go to `TOP`. -#[derive(Clone, PartialEq, Eq, Hash)] -pub struct BoundedSet(Option>); - - -impl Default for BoundedSet { - fn default() -> Self {Self::new()} -} - -impl BoundedSet { - - /// A set containing everything - pub const TOP: Self = BoundedSet(None); - - /// Creates an empty `BoundedSet` - pub fn new() -> Self { - BoundedSet(Some(Set::default())) - } - - /// Creates a `BoundedSet` containing only `item` - pub fn singleton(item: T) -> Self { - Self::from_set(Set::singleton(item)) - } - - /// Creates a `BoundedSet` from a `Set`, ensuring the `BOUND` is not exceeded - pub fn from_set(set: Set) -> Self { - if set.len() <= BOUND { - BoundedSet(Some(set)) - } else { - BoundedSet(None) - } - } - - /// Returns the size of the set. In case of the set being `TOP`, returns `None` - pub fn count(&self) -> Option { - self.0.as_ref().map(|s| s.len()) - } - - /// Returns `true` if the set contains the `item`. For a set that `is_top()`, always returns `true`. - pub fn contains(&self, item: &T) -> bool { - match &self.0 { - Some(set) => set.0.contains(item), - None => true, - } - } - - pub fn is_top(&self) -> bool { - self.0.is_none() - } -} - -impl PartialOrd for BoundedSet { - fn partial_cmp(&self, other: &Self) -> Option { - use std::cmp::Ordering; - match (&self.0, &other.0) { - (None, None) => Some(Ordering::Equal), - (None, _) => Some(Ordering::Greater), - (_, None) => Some(Ordering::Less), - (Some(set1), Some(set2)) => { - set1.partial_cmp(set2) - } - } - } -} - -impl Lattice for BoundedSet { - fn meet_mut(&mut self, other: Self) -> bool { - match (&mut self.0, other.0){ - (None, None) => false, - (this @ None, Some(set2)) => { - *this = Some(set2); - true - }, - (Some(_), None) => false, - (Some(set1), Some(set2)) => { - set1.meet_mut(set2) - }, - } - } - - fn join_mut(&mut self, other: Self) -> bool { - match (&mut self.0, other.0){ - (None, _) => false, - (this @ Some(_), None) => { - *this = None; - true - }, - (Some(set1), Some(set2)) => { - let changed = set1.join_mut(set2); - if set1.len() > BOUND { - self.0 = None; - true - } else { - changed - } - }, - } - } - fn meet(self, other: Self) -> Self { - match (self.0, other.0){ - (None, None) => BoundedSet(None), - (None, set2@Some(_)) => BoundedSet(set2), - (set1@Some(_), None) => BoundedSet(set1), - (Some(set1), Some(set2)) => { - let res = set1.meet(set2); - BoundedSet(Some(res)) - }, - } - } - - fn join(self, other: Self) -> Self { - match (self.0, other.0){ - (None, _) => BoundedSet(None), - (_, None) => BoundedSet(None), - (Some(set1), Some(set2)) => { - let res = set1.join(set2); - if res.len() > BOUND { - BoundedSet(None) - } else { - BoundedSet(Some(res)) - } - }, - } - } -} - -impl BoundedLattice for BoundedSet { - fn bottom() -> Self { - Self::new() - } - - /// top is meant to represent a set containing everything - fn top() -> Self { - BoundedSet(None) - } -} - - -#[test] -fn test_bounded_set() { - let set1 = BoundedSet::<2, i32>::singleton(10); - - let mut set2_by_mut = set1.clone(); - assert!(set2_by_mut.join_mut(BoundedSet::singleton(11))); - - let set2 = set1.join(BoundedSet::singleton(11)); - - assert!(set2_by_mut == set2); - assert_eq!(set2.count(), Some(2)); - assert!(set2.contains(&10)); - assert!(!set2.contains(&20)); - - let set3 = set2.join(BoundedSet::singleton(12)); - assert!(set3.is_top()); - assert!(set3 == BoundedSet::TOP); - assert!(set3.contains(&15)); -} \ No newline at end of file +use std::hash::Hash; + +use super::BoundedLattice; +use super::set::Set; +use crate::Lattice; + +/// `BoundedSet` is a generalization of the flat lattice. +/// +/// A `BoundedSet` stores at most `BOUND` items, and if asked to store more, will go to `TOP`. +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct BoundedSet(Option>); + +impl Default for BoundedSet { + fn default() -> Self { Self::new() } +} + +impl BoundedSet { + /// A set containing everything + pub const TOP: Self = BoundedSet(None); + + /// Creates an empty `BoundedSet` + pub fn new() -> Self { BoundedSet(Some(Set::default())) } + + /// Creates a `BoundedSet` containing only `item` + pub fn singleton(item: T) -> Self { Self::from_set(Set::singleton(item)) } + + /// Creates a `BoundedSet` from a `Set`, ensuring the `BOUND` is not exceeded + pub fn from_set(set: Set) -> Self { if set.len() <= BOUND { BoundedSet(Some(set)) } else { BoundedSet(None) } } + + /// Returns the size of the set. In case of the set being `TOP`, returns `None` + pub fn count(&self) -> Option { self.0.as_ref().map(|s| s.len()) } + + /// Returns `true` if the set contains the `item`. For a set that `is_top()`, always returns `true`. + pub fn contains(&self, item: &T) -> bool { + match &self.0 { + Some(set) => set.0.contains(item), + None => true, + } + } + + pub fn is_top(&self) -> bool { self.0.is_none() } +} + +impl PartialOrd for BoundedSet { + fn partial_cmp(&self, other: &Self) -> Option { + use std::cmp::Ordering; + match (&self.0, &other.0) { + (None, None) => Some(Ordering::Equal), + (None, _) => Some(Ordering::Greater), + (_, None) => Some(Ordering::Less), + (Some(set1), Some(set2)) => set1.partial_cmp(set2), + } + } +} + +impl Lattice for BoundedSet { + fn meet_mut(&mut self, other: Self) -> bool { + match (&mut self.0, other.0) { + (None, None) => false, + (this @ None, Some(set2)) => { + *this = Some(set2); + true + }, + (Some(_), None) => false, + (Some(set1), Some(set2)) => set1.meet_mut(set2), + } + } + + fn join_mut(&mut self, other: Self) -> bool { + match (&mut self.0, other.0) { + (None, _) => false, + (this @ Some(_), None) => { + *this = None; + true + }, + (Some(set1), Some(set2)) => { + let changed = set1.join_mut(set2); + if set1.len() > BOUND { + self.0 = None; + true + } else { + changed + } + }, + } + } + fn meet(self, other: Self) -> Self { + match (self.0, other.0) { + (None, None) => BoundedSet(None), + (None, set2 @ Some(_)) => BoundedSet(set2), + (set1 @ Some(_), None) => BoundedSet(set1), + (Some(set1), Some(set2)) => { + let res = set1.meet(set2); + BoundedSet(Some(res)) + }, + } + } + + fn join(self, other: Self) -> Self { + match (self.0, other.0) { + (None, _) => BoundedSet(None), + (_, None) => BoundedSet(None), + (Some(set1), Some(set2)) => { + let res = set1.join(set2); + if res.len() > BOUND { BoundedSet(None) } else { BoundedSet(Some(res)) } + }, + } + } +} + +impl BoundedLattice for BoundedSet { + fn bottom() -> Self { Self::new() } + + /// top is meant to represent a set containing everything + fn top() -> Self { BoundedSet(None) } +} + +#[test] +fn test_bounded_set() { + let set1 = BoundedSet::<2, i32>::singleton(10); + + let mut set2_by_mut = set1.clone(); + assert!(set2_by_mut.join_mut(BoundedSet::singleton(11))); + + let set2 = set1.join(BoundedSet::singleton(11)); + + assert!(set2_by_mut == set2); + assert_eq!(set2.count(), Some(2)); + assert!(set2.contains(&10)); + assert!(!set2.contains(&20)); + + let set3 = set2.join(BoundedSet::singleton(12)); + assert!(set3.is_top()); + assert!(set3 == BoundedSet::TOP); + assert!(set3.contains(&15)); +} diff --git a/ascent_base/src/lattice/constant_propagation.rs b/ascent_base/src/lattice/constant_propagation.rs index 60039af..195858f 100644 --- a/ascent_base/src/lattice/constant_propagation.rs +++ b/ascent_base/src/lattice/constant_propagation.rs @@ -1,6 +1,6 @@ use std::cmp::Ordering; -use super::{BoundedLattice, Lattice}; +use super::{BoundedLattice, Lattice}; /// A flat `Lattice`: `Bottom` <= everything <= `Top`, and `Constant(x) == Constant(y)` iff `x == y` #[derive(PartialEq, Eq, Clone, Copy, Debug, Hash)] @@ -18,7 +18,12 @@ impl PartialOrd for ConstPropagation { (Bottom, Bottom) => Some(Ordering::Equal), (Bottom, _) => Some(Ordering::Less), (Constant(_x), Bottom) => Some(Ordering::Greater), - (Constant(x), Constant(y)) => if x == y {Some(Ordering::Equal)} else {None}, + (Constant(x), Constant(y)) => + if x == y { + Some(Ordering::Equal) + } else { + None + }, (Constant(_), Top) => Some(Ordering::Less), (Top, Top) => Some(Ordering::Equal), (Top, _) => Some(Ordering::Greater), @@ -32,7 +37,12 @@ impl Lattice for ConstPropagation { match (self, other) { (Bottom, _) => Self::Bottom, (Constant(_x), Bottom) => Self::Bottom, - (Constant(x), Constant(y)) => if x == y {Constant(x)} else {Self::Bottom}, + (Constant(x), Constant(y)) => + if x == y { + Constant(x) + } else { + Self::Bottom + }, (Constant(x), Top) => Constant(x), (Top, other) => other, } @@ -43,12 +53,17 @@ impl Lattice for ConstPropagation { match (self, other) { (Bottom, other) => other, (Constant(x), Bottom) => Constant(x), - (Constant(x), Constant(y)) => if x == y {Constant(x)} else {Self::Top}, + (Constant(x), Constant(y)) => + if x == y { + Constant(x) + } else { + Self::Top + }, (Constant(_x), Top) => Top, (Top, _) => Top, } } - + fn meet_mut(&mut self, other: Self) -> bool { use ConstPropagation::*; match (self, other) { @@ -61,11 +76,11 @@ impl Lattice for ConstPropagation { (_, Top) => false, (this @ Top, other) => { *this = other; - true + true }, } } - + fn join_mut(&mut self, other: Self) -> bool { use ConstPropagation::*; match (self, other) { @@ -84,13 +99,15 @@ impl Lattice for ConstPropagation { } } -impl BoundedLattice for ConstPropagation where ConstPropagation: Lattice { +impl BoundedLattice for ConstPropagation +where ConstPropagation: Lattice +{ fn top() -> Self { Self::Top } fn bottom() -> Self { Self::Bottom } } #[test] -fn test_constant_propagation(){ +fn test_constant_propagation() { let const_1 = ConstPropagation::Constant(1); assert!(const_1 > ConstPropagation::Bottom); assert!(const_1 < ConstPropagation::Top); @@ -98,7 +115,7 @@ fn test_constant_propagation(){ } #[test] -fn test_constant_propagation_lattice(){ +fn test_constant_propagation_lattice() { let const_1 = ConstPropagation::Constant(1); let mut x = const_1.clone(); @@ -114,4 +131,4 @@ fn test_constant_propagation_lattice(){ assert_eq!(x, ConstPropagation::Top); assert!(!x.join_mut(ConstPropagation::Constant(2))); -} \ No newline at end of file +} diff --git a/ascent_base/src/lattice/dual.rs b/ascent_base/src/lattice/dual.rs index 6f7f23d..f9cb155 100644 --- a/ascent_base/src/lattice/dual.rs +++ b/ascent_base/src/lattice/dual.rs @@ -1,66 +1,66 @@ -use std::{cmp::Ordering, fmt::Debug, fmt::Display, fmt::Formatter, ops::Deref}; - -use crate::Lattice; - -use super::BoundedLattice; - -#[derive(PartialEq, Eq, Clone, Copy, Hash)] -// TODO uncomment for a major release -// #[repr(transparent)] -/// A wrapper type that swaps (`<=` and `>=`) for `PartialOrd`s, (`meet` and `join`) for `Lattice`s, -/// and (`top` and `bottom`) for `BoundedLattice`s. -/// -/// # Example -/// ``` -/// # use ascent_base::lattice::Dual; -/// assert!(Dual(2) < Dual(1)); -/// ``` -pub struct Dual(pub T); - -impl Deref for Dual{ - type Target = T; - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl Debug for Dual { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } -} - -impl Display for Dual { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } -} - - -impl PartialOrd for Dual where T: PartialOrd { - fn partial_cmp(&self, other: &Self) -> Option { - other.0.partial_cmp(&self.0) - } -} - -impl Ord for Dual where T: Ord { - fn cmp(&self, other: &Self) -> Ordering { other.0.cmp(&self.0) } -} - -impl Lattice for Dual { - #[inline] - fn meet(self, other: Self) -> Self { Dual(self.0.join(other.0)) } - - #[inline] - fn join(self, other: Self) -> Self { Dual(self.0.meet(other.0)) } - - #[inline] - fn meet_mut(&mut self, other: Self) -> bool { self.0.join_mut(other.0) } - - #[inline] - fn join_mut(&mut self, other: Self) -> bool { self.0.meet_mut(other.0) } -} - -impl BoundedLattice for Dual { - #[inline] - fn top() -> Self { Dual(T::bottom()) } - - #[inline] - fn bottom() -> Self { Dual(T::top()) } -} +use std::cmp::Ordering; +use std::fmt::{Debug, Display, Formatter}; +use std::ops::Deref; + +use super::BoundedLattice; +use crate::Lattice; + +#[derive(PartialEq, Eq, Clone, Copy, Hash)] +// TODO uncomment for a major release +// #[repr(transparent)] +/// A wrapper type that swaps (`<=` and `>=`) for `PartialOrd`s, (`meet` and `join`) for `Lattice`s, +/// and (`top` and `bottom`) for `BoundedLattice`s. +/// +/// # Example +/// ``` +/// # use ascent_base::lattice::Dual; +/// assert!(Dual(2) < Dual(1)); +/// ``` +pub struct Dual(pub T); + +impl Deref for Dual { + type Target = T; + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl Debug for Dual { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } +} + +impl Display for Dual { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } +} + +impl PartialOrd for Dual +where T: PartialOrd +{ + fn partial_cmp(&self, other: &Self) -> Option { other.0.partial_cmp(&self.0) } +} + +impl Ord for Dual +where T: Ord +{ + fn cmp(&self, other: &Self) -> Ordering { other.0.cmp(&self.0) } +} + +impl Lattice for Dual { + #[inline] + fn meet(self, other: Self) -> Self { Dual(self.0.join(other.0)) } + + #[inline] + fn join(self, other: Self) -> Self { Dual(self.0.meet(other.0)) } + + #[inline] + fn meet_mut(&mut self, other: Self) -> bool { self.0.join_mut(other.0) } + + #[inline] + fn join_mut(&mut self, other: Self) -> bool { self.0.meet_mut(other.0) } +} + +impl BoundedLattice for Dual { + #[inline] + fn top() -> Self { Dual(T::bottom()) } + + #[inline] + fn bottom() -> Self { Dual(T::top()) } +} diff --git a/ascent_base/src/lattice/ord_lattice.rs b/ascent_base/src/lattice/ord_lattice.rs index b1d1e5c..ac97e41 100644 --- a/ascent_base/src/lattice/ord_lattice.rs +++ b/ascent_base/src/lattice/ord_lattice.rs @@ -1,45 +1,40 @@ -use crate::Lattice; - -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] -pub struct OrdLattice(pub T); - - -impl Lattice for OrdLattice { - #[inline(always)] - fn meet(self, other: Self) -> Self { - self.min(other) - } - - #[inline(always)] - fn join(self, other: Self) -> Self { - self.max(other) - } - - fn meet_mut(&mut self, other: Self) -> bool { - if self.0 > other.0 { - self.0 = other.0; - true - } else { - false - } - } - fn join_mut(&mut self, other: Self) -> bool { - if self.0 < other.0 { - self.0 = other.0; - true - } else { - false - } - } -} - -#[test] -fn test_ord_lattice(){ - assert_eq!(OrdLattice(42).meet(OrdLattice(22)), OrdLattice(22)); - - let mut x = OrdLattice(42); - assert!(!x.join_mut(OrdLattice(42))); - assert_eq!(x.0, 42); - assert!(!x.meet_mut(OrdLattice(42))); - assert_eq!(x.0, 42); -} \ No newline at end of file +use crate::Lattice; + +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +pub struct OrdLattice(pub T); + +impl Lattice for OrdLattice { + #[inline(always)] + fn meet(self, other: Self) -> Self { self.min(other) } + + #[inline(always)] + fn join(self, other: Self) -> Self { self.max(other) } + + fn meet_mut(&mut self, other: Self) -> bool { + if self.0 > other.0 { + self.0 = other.0; + true + } else { + false + } + } + fn join_mut(&mut self, other: Self) -> bool { + if self.0 < other.0 { + self.0 = other.0; + true + } else { + false + } + } +} + +#[test] +fn test_ord_lattice() { + assert_eq!(OrdLattice(42).meet(OrdLattice(22)), OrdLattice(22)); + + let mut x = OrdLattice(42); + assert!(!x.join_mut(OrdLattice(42))); + assert_eq!(x.0, 42); + assert!(!x.meet_mut(OrdLattice(42))); + assert_eq!(x.0, 42); +} diff --git a/ascent_base/src/lattice/product.rs b/ascent_base/src/lattice/product.rs index 5481c4d..29b29d4 100644 --- a/ascent_base/src/lattice/product.rs +++ b/ascent_base/src/lattice/product.rs @@ -1,187 +1,183 @@ -use paste::paste; - -use super::{BoundedLattice, Lattice}; -use std::cmp::Ordering; - -#[derive(Clone, Copy, PartialEq, Eq, Debug)] -/// A wrapper for tuple types and arrays that implements `PartialOrd` using -/// [product-order](https://en.wikipedia.org/wiki/Product_order) semantics. -/// -/// `Lattice` and `BoundedLattice` traits are also implemented. -/// -/// Difference from lexicographical ordering (the `PartialOrd` implementation for tuple types): -/// ``` -/// # use ascent_base::lattice::Product; -/// assert!(!{Product((1,4)) < Product((2,3))}); -/// assert!((1,4) < (2,3)); -/// -/// ``` -pub struct Product(pub T); - -#[inline] -fn combine_orderings(ord1: Ordering, ord2: Ordering) -> Option{ - use Ordering::*; - match (ord1, ord2) { - (Equal, _) => Some(ord2), - (_, Equal) => Some(ord1), - (Less, Less) => Some(Less), - (Greater, Greater) => Some(Greater), - _ => None - } -} - -macro_rules! tuple_lattice_impl{ - ($($i:tt),*) => { paste!( - impl< $([]: PartialOrd),* > PartialOrd for Product<($([]),*,)> { - fn partial_cmp(&self, other: &Self) -> Option { - let mut res = Ordering::Equal; - $( - match self.0.$i.partial_cmp(&other.0.$i) { - None => return None, - Some(ord) => { - match combine_orderings(ord, res) { - None => return None, - Some(new_res) => res = new_res, - } - } - }; - )* - Some(res) - } - } - impl< $([]: Lattice),* > Lattice for Product<($([]),*,)> { - fn meet_mut(&mut self, other: Self) -> bool { - let mut changed = false; - $(changed |= self.0.$i.meet_mut(other.0.$i);)* - changed - } - - fn join_mut(&mut self, other: Self) -> bool { - let mut changed = false; - $(changed |= self.0.$i.join_mut(other.0.$i);)* - changed - } - - fn meet(self, other: Self) -> Self { - Product(($(self.0.$i.meet(other.0.$i)),*,)) - } - - fn join(self, other: Self) -> Self { - Product(($(self.0.$i.join(other.0.$i)),*,)) - } - } - - impl< $([]: BoundedLattice),* > BoundedLattice for Product<($([]),*,)> where Product<($([]),*,)>: Lattice { - fn bottom() -> Self { - Product(($([]::bottom()),*,)) - } - - fn top() -> Self { - Product(($([]::top()),*,)) - } - } - );}; -} -tuple_lattice_impl!(0); -tuple_lattice_impl!(0, 1); -tuple_lattice_impl!(0, 1, 2); -tuple_lattice_impl!(0, 1, 2, 3); -tuple_lattice_impl!(0, 1, 2, 3, 4); -tuple_lattice_impl!(0, 1, 2, 3, 4, 5); -tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6); -tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6, 7); -tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6, 7, 8); -tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); -tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10); - - -#[test] -fn test_product_lattice(){ - let t1 = Product((1, 3)); - let t2 = Product((0, 10)); - - assert_eq!(Lattice::meet(t1, t2), Product((0, 3))); - assert_eq!(Lattice::join(t1, t2), Product((1, 10))); - assert_eq!(Product::<(u32, u32)>::bottom(), Product((0,0))); - - - assert!(Product((1,3)) < Product((2,3))); - assert!(!{Product((1,4)) < Product((2,3))}); - assert!(Product((1,4)).partial_cmp(&Product((2,3))).is_none()); -} - - -impl PartialOrd for Product<[T; N]> { - fn partial_cmp(&self, other: &Self) -> Option { - let mut ord = Ordering::Equal; - for i in 0..N { - let ith_ord = self.0[i].partial_cmp(&other.0[i]); - match ith_ord { - None => return None, - Some(ith_ord) => match combine_orderings(ith_ord, ord) { - Some(new_ord) => ord = new_ord, - None => return None, - }, - } - } - Some(ord) - } -} - -#[test] -fn test_product_of_array_partial_ord() { - let a1 = Product([1, 2, 3]); - let a2 = Product([1, 3, 4]); - assert_eq!(a1.partial_cmp(&a2), Some(Ordering::Less)); - assert_eq!(a2.partial_cmp(&a1), Some(Ordering::Greater)); - - let a3 = Product([0, 2, 4]); - assert_eq!(a1.partial_cmp(&a3), None); - assert_eq!(a3.partial_cmp(&a1), None); - - assert_eq!(a2.partial_cmp(&a3), Some(Ordering::Greater)); - assert_eq!(a3.partial_cmp(&a2), Some(Ordering::Less)); -} - -impl Lattice for Product<[T; N]> { - fn meet_mut(&mut self, other: Self) -> bool { - let mut changed = false; - for (l, r) in self.0.iter_mut().zip(other.0){ - changed |= l.meet_mut(r); - } - changed - } - - fn join_mut(&mut self, other: Self) -> bool { - let mut changed = false; - for (l, r) in self.0.iter_mut().zip(other.0){ - changed |= l.join_mut(r); - } - changed - } -} - -impl BoundedLattice for Product<[T; N]> { - fn bottom() -> Self { - // unstable: - // Product(std::array::from_fn(|_| T::bottom())) - Product([(); N].map(|_| T::bottom())) - } - - fn top() -> Self { - Product([(); N].map(|_| T::top())) - } -} - -#[test] -fn test_product_of_array_lattice() { - let a1 = Product([1, 5, 3]); - let a2 = Product([1, 3, 4]); - let a1_a2_meet = Product([1,3,3]); - let a1_a2_join = Product([1,5,4]); - assert_eq!(a1.meet(a2), a1_a2_meet); - assert_eq!(a1.join(a2), a1_a2_join); - - assert_eq!(Product([0; 3]), Product::<[u32; 3]>::bottom()); - assert_eq!(Product([true; 4]), Product::<[bool; 4]>::top()); -} \ No newline at end of file +use std::cmp::Ordering; + +use paste::paste; + +use super::{BoundedLattice, Lattice}; + +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +/// A wrapper for tuple types and arrays that implements `PartialOrd` using +/// [product-order](https://en.wikipedia.org/wiki/Product_order) semantics. +/// +/// `Lattice` and `BoundedLattice` traits are also implemented. +/// +/// Difference from lexicographical ordering (the `PartialOrd` implementation for tuple types): +/// ``` +/// # use ascent_base::lattice::Product; +/// assert!(!{Product((1,4)) < Product((2,3))}); +/// assert!((1,4) < (2,3)); +/// +/// ``` +pub struct Product(pub T); + +#[inline] +fn combine_orderings(ord1: Ordering, ord2: Ordering) -> Option { + use Ordering::*; + match (ord1, ord2) { + (Equal, _) => Some(ord2), + (_, Equal) => Some(ord1), + (Less, Less) => Some(Less), + (Greater, Greater) => Some(Greater), + _ => None, + } +} + +macro_rules! tuple_lattice_impl { + ($($i:tt),*) => { paste!( + impl< $([]: PartialOrd),* > PartialOrd for Product<($([]),*,)> { + fn partial_cmp(&self, other: &Self) -> Option { + let mut res = Ordering::Equal; + $( + match self.0.$i.partial_cmp(&other.0.$i) { + None => return None, + Some(ord) => { + match combine_orderings(ord, res) { + None => return None, + Some(new_res) => res = new_res, + } + } + }; + )* + Some(res) + } + } + impl< $([]: Lattice),* > Lattice for Product<($([]),*,)> { + fn meet_mut(&mut self, other: Self) -> bool { + let mut changed = false; + $(changed |= self.0.$i.meet_mut(other.0.$i);)* + changed + } + + fn join_mut(&mut self, other: Self) -> bool { + let mut changed = false; + $(changed |= self.0.$i.join_mut(other.0.$i);)* + changed + } + + fn meet(self, other: Self) -> Self { + Product(($(self.0.$i.meet(other.0.$i)),*,)) + } + + fn join(self, other: Self) -> Self { + Product(($(self.0.$i.join(other.0.$i)),*,)) + } + } + + impl< $([]: BoundedLattice),* > BoundedLattice for Product<($([]),*,)> where Product<($([]),*,)>: Lattice { + fn bottom() -> Self { + Product(($([]::bottom()),*,)) + } + + fn top() -> Self { + Product(($([]::top()),*,)) + } + } + );}; +} +tuple_lattice_impl!(0); +tuple_lattice_impl!(0, 1); +tuple_lattice_impl!(0, 1, 2); +tuple_lattice_impl!(0, 1, 2, 3); +tuple_lattice_impl!(0, 1, 2, 3, 4); +tuple_lattice_impl!(0, 1, 2, 3, 4, 5); +tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6); +tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6, 7); +tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6, 7, 8); +tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); +tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + +#[test] +fn test_product_lattice() { + let t1 = Product((1, 3)); + let t2 = Product((0, 10)); + + assert_eq!(Lattice::meet(t1, t2), Product((0, 3))); + assert_eq!(Lattice::join(t1, t2), Product((1, 10))); + assert_eq!(Product::<(u32, u32)>::bottom(), Product((0, 0))); + + assert!(Product((1, 3)) < Product((2, 3))); + assert!(!{ Product((1, 4)) < Product((2, 3)) }); + assert!(Product((1, 4)).partial_cmp(&Product((2, 3))).is_none()); +} + +impl PartialOrd for Product<[T; N]> { + fn partial_cmp(&self, other: &Self) -> Option { + let mut ord = Ordering::Equal; + for i in 0..N { + let ith_ord = self.0[i].partial_cmp(&other.0[i]); + match ith_ord { + None => return None, + Some(ith_ord) => match combine_orderings(ith_ord, ord) { + Some(new_ord) => ord = new_ord, + None => return None, + }, + } + } + Some(ord) + } +} + +#[test] +fn test_product_of_array_partial_ord() { + let a1 = Product([1, 2, 3]); + let a2 = Product([1, 3, 4]); + assert_eq!(a1.partial_cmp(&a2), Some(Ordering::Less)); + assert_eq!(a2.partial_cmp(&a1), Some(Ordering::Greater)); + + let a3 = Product([0, 2, 4]); + assert_eq!(a1.partial_cmp(&a3), None); + assert_eq!(a3.partial_cmp(&a1), None); + + assert_eq!(a2.partial_cmp(&a3), Some(Ordering::Greater)); + assert_eq!(a3.partial_cmp(&a2), Some(Ordering::Less)); +} + +impl Lattice for Product<[T; N]> { + fn meet_mut(&mut self, other: Self) -> bool { + let mut changed = false; + for (l, r) in self.0.iter_mut().zip(other.0) { + changed |= l.meet_mut(r); + } + changed + } + + fn join_mut(&mut self, other: Self) -> bool { + let mut changed = false; + for (l, r) in self.0.iter_mut().zip(other.0) { + changed |= l.join_mut(r); + } + changed + } +} + +impl BoundedLattice for Product<[T; N]> { + fn bottom() -> Self { + // unstable: + // Product(std::array::from_fn(|_| T::bottom())) + Product([(); N].map(|_| T::bottom())) + } + + fn top() -> Self { Product([(); N].map(|_| T::top())) } +} + +#[test] +fn test_product_of_array_lattice() { + let a1 = Product([1, 5, 3]); + let a2 = Product([1, 3, 4]); + let a1_a2_meet = Product([1, 3, 3]); + let a1_a2_join = Product([1, 5, 4]); + assert_eq!(a1.meet(a2), a1_a2_meet); + assert_eq!(a1.join(a2), a1_a2_join); + + assert_eq!(Product([0; 3]), Product::<[u32; 3]>::bottom()); + assert_eq!(Product([true; 4]), Product::<[bool; 4]>::top()); +} diff --git a/ascent_base/src/lattice/set.rs b/ascent_base/src/lattice/set.rs index 85bcfeb..1903750 100644 --- a/ascent_base/src/lattice/set.rs +++ b/ascent_base/src/lattice/set.rs @@ -1,85 +1,82 @@ -use std::cmp::Ordering; -use std::collections::{BTreeSet}; -use std::hash::Hash; -use std::ops::Deref; - -use super::Lattice; - -/// A set type that implements the `Lattice` trait -#[derive(Clone, PartialEq, Eq, Hash)] -pub struct Set(pub BTreeSet); - -impl Set { - - /// Creates a `Set` containing only `item` - pub fn singleton(item: T) -> Self { - let mut set = BTreeSet::new(); - set.insert(item); - Set(set) - } -} - -impl Default for Set { - fn default() -> Self {Self(Default::default())} -} - -impl Deref for Set{ - type Target = BTreeSet; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl PartialOrd for Set { - fn partial_cmp(&self, other: &Self) -> Option { - if self.0 == other.0 { - Some(Ordering::Equal) - } else if self.0.is_subset(&other.0) { - Some(Ordering::Less) - } else if self.0.is_superset(&other.0) { - Some(Ordering::Greater) - } else { - None - } - } -} - -impl Lattice for Set { - fn meet_mut(&mut self, mut other: Self) -> bool { - let self_len = self.0.len(); - let mut old_self = BTreeSet::new(); - std::mem::swap(&mut self.0, &mut old_self); - if self.0.len() > other.0.len() { - std::mem::swap(self, &mut other); - } - for item in old_self.into_iter() { - if other.0.contains(&item) { - self.0.insert(item); - } - } - self_len != self.0.len() - } - - fn join_mut(&mut self, mut other: Self) -> bool { - let self_len = self.0.len(); - if self_len < other.0.len() { - std::mem::swap(self, &mut other); - } - for item in other.0.into_iter() { - self.0.insert(item); - } - - self_len != self.0.len() - } - - fn meet(mut self, other: Self) -> Self { - self.meet_mut(other); - self - } - - fn join(mut self, other: Self) -> Self { - self.join_mut(other); - self - } -} +use std::cmp::Ordering; +use std::collections::BTreeSet; +use std::hash::Hash; +use std::ops::Deref; + +use super::Lattice; + +/// A set type that implements the `Lattice` trait +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct Set(pub BTreeSet); + +impl Set { + /// Creates a `Set` containing only `item` + pub fn singleton(item: T) -> Self { + let mut set = BTreeSet::new(); + set.insert(item); + Set(set) + } +} + +impl Default for Set { + fn default() -> Self { Self(Default::default()) } +} + +impl Deref for Set { + type Target = BTreeSet; + + fn deref(&self) -> &Self::Target { &self.0 } +} + +impl PartialOrd for Set { + fn partial_cmp(&self, other: &Self) -> Option { + if self.0 == other.0 { + Some(Ordering::Equal) + } else if self.0.is_subset(&other.0) { + Some(Ordering::Less) + } else if self.0.is_superset(&other.0) { + Some(Ordering::Greater) + } else { + None + } + } +} + +impl Lattice for Set { + fn meet_mut(&mut self, mut other: Self) -> bool { + let self_len = self.0.len(); + let mut old_self = BTreeSet::new(); + std::mem::swap(&mut self.0, &mut old_self); + if self.0.len() > other.0.len() { + std::mem::swap(self, &mut other); + } + for item in old_self.into_iter() { + if other.0.contains(&item) { + self.0.insert(item); + } + } + self_len != self.0.len() + } + + fn join_mut(&mut self, mut other: Self) -> bool { + let self_len = self.0.len(); + if self_len < other.0.len() { + std::mem::swap(self, &mut other); + } + for item in other.0.into_iter() { + self.0.insert(item); + } + + self_len != self.0.len() + } + + fn meet(mut self, other: Self) -> Self { + self.meet_mut(other); + self + } + + fn join(mut self, other: Self) -> Self { + self.join_mut(other); + self + } +} diff --git a/ascent_base/src/lattice/tuple.rs b/ascent_base/src/lattice/tuple.rs index d5e6887..7d35c70 100644 --- a/ascent_base/src/lattice/tuple.rs +++ b/ascent_base/src/lattice/tuple.rs @@ -1,75 +1,75 @@ -use paste::paste; - -use super::{BoundedLattice, Lattice}; - -macro_rules! tuple_lattice_impl{ - ($($i:tt),*) => { - paste!( - impl< $([]),* > Lattice for ($([]),*,) where ($([]),*,): Ord { - fn meet_mut(&mut self, other: Self) -> bool { - use std::cmp::Ordering::*; - match (&*self).cmp(&other) { - Less | Equal => false, - Greater => { - *self = other; - true - } - } - } - - fn join_mut(&mut self, other: Self) -> bool { - use std::cmp::Ordering::*; - match (&*self).cmp(&other) { - Greater | Equal => false, - Less => { - *self = other; - true - } - } - } - - fn meet(self, other: Self) -> Self { - self.min(other) - } - - fn join(self, other: Self) -> Self { - self.max(other) - } - } - - impl< $([]),* > BoundedLattice for ($([]),*,) where $([]: BoundedLattice + Ord),* { - fn bottom() -> Self { - ($([]::bottom(),)*) - } - - fn top() -> Self { - ($([]::top(),)*) - } - } - ); - }; -} - -tuple_lattice_impl!(0); -tuple_lattice_impl!(0, 1); -tuple_lattice_impl!(0, 1, 2); -tuple_lattice_impl!(0, 1, 2, 3); -tuple_lattice_impl!(0, 1, 2, 3, 4); -tuple_lattice_impl!(0, 1, 2, 3, 4, 5); -tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6); -tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6, 7); -tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6, 7, 8); -tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); -tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10); - -impl super::Lattice for () { - fn meet_mut(&mut self, _other: Self) -> bool { false } - fn join_mut(&mut self, _other: Self) -> bool { false } - fn meet(self, _other: Self) -> Self {} - fn join(self, _other: Self) -> Self {} -} - -impl BoundedLattice for () { - fn bottom() -> Self {} - fn top() -> Self {} -} +use paste::paste; + +use super::{BoundedLattice, Lattice}; + +macro_rules! tuple_lattice_impl { + ($($i:tt),*) => { + paste!( + impl< $([]),* > Lattice for ($([]),*,) where ($([]),*,): Ord { + fn meet_mut(&mut self, other: Self) -> bool { + use std::cmp::Ordering::*; + match (&*self).cmp(&other) { + Less | Equal => false, + Greater => { + *self = other; + true + } + } + } + + fn join_mut(&mut self, other: Self) -> bool { + use std::cmp::Ordering::*; + match (&*self).cmp(&other) { + Greater | Equal => false, + Less => { + *self = other; + true + } + } + } + + fn meet(self, other: Self) -> Self { + self.min(other) + } + + fn join(self, other: Self) -> Self { + self.max(other) + } + } + + impl< $([]),* > BoundedLattice for ($([]),*,) where $([]: BoundedLattice + Ord),* { + fn bottom() -> Self { + ($([]::bottom(),)*) + } + + fn top() -> Self { + ($([]::top(),)*) + } + } + ); + }; +} + +tuple_lattice_impl!(0); +tuple_lattice_impl!(0, 1); +tuple_lattice_impl!(0, 1, 2); +tuple_lattice_impl!(0, 1, 2, 3); +tuple_lattice_impl!(0, 1, 2, 3, 4); +tuple_lattice_impl!(0, 1, 2, 3, 4, 5); +tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6); +tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6, 7); +tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6, 7, 8); +tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); +tuple_lattice_impl!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + +impl super::Lattice for () { + fn meet_mut(&mut self, _other: Self) -> bool { false } + fn join_mut(&mut self, _other: Self) -> bool { false } + fn meet(self, _other: Self) -> Self {} + fn join(self, _other: Self) -> Self {} +} + +impl BoundedLattice for () { + fn bottom() -> Self {} + fn top() -> Self {} +} diff --git a/ascent_base/src/lib.rs b/ascent_base/src/lib.rs index cc12c23..487dd5f 100644 --- a/ascent_base/src/lib.rs +++ b/ascent_base/src/lib.rs @@ -1,5 +1,4 @@ -pub mod lattice; -#[doc(hidden)] -pub mod util; -pub use lattice::Lattice; -pub use lattice::Dual; \ No newline at end of file +pub mod lattice; +#[doc(hidden)] +pub mod util; +pub use lattice::{Dual, Lattice}; diff --git a/ascent_base/src/util.rs b/ascent_base/src/util.rs index 7c8c84b..01b83d7 100644 --- a/ascent_base/src/util.rs +++ b/ascent_base/src/util.rs @@ -1,17 +1,20 @@ -//! internal utility functions defined here -//! -//! CAUTION: anything defined here is subject to change in semver-compatible releases - -/// update `reference` in-place using the provided closure -pub fn update(reference: &mut T, f: impl FnOnce(T) -> T) { - let ref_taken = std::mem::take(reference); - let new_val = f(ref_taken); - *reference = new_val; -} - -#[test] -fn test_update(){ - let mut vec = vec![1, 2, 3]; - update(&mut vec, |mut v| {v.push(4); v}); - assert_eq!(vec, vec![1, 2, 3, 4]); -} \ No newline at end of file +//! internal utility functions defined here +//! +//! CAUTION: anything defined here is subject to change in semver-compatible releases + +/// update `reference` in-place using the provided closure +pub fn update(reference: &mut T, f: impl FnOnce(T) -> T) { + let ref_taken = std::mem::take(reference); + let new_val = f(ref_taken); + *reference = new_val; +} + +#[test] +fn test_update() { + let mut vec = vec![1, 2, 3]; + update(&mut vec, |mut v| { + v.push(4); + v + }); + assert_eq!(vec, vec![1, 2, 3, 4]); +} diff --git a/ascent_macro/src/ascent_codegen.rs b/ascent_macro/src/ascent_codegen.rs index 82549ec..8ddfeea 100644 --- a/ascent_macro/src/ascent_codegen.rs +++ b/ascent_macro/src/ascent_codegen.rs @@ -1,1278 +1,1393 @@ -#![deny(warnings)] -use std::collections::HashSet; - -use itertools::{Itertools, Either}; -use proc_macro2::{Ident, Span, TokenStream}; -use syn::{Expr, Type, parse2, spanned::Spanned, parse_quote, parse_quote_spanned}; - -use crate::{ascent_hir::{IrRelation, IndexValType}, ascent_syntax::{CondClause, RelationIdentity}, utils::TokenStreamExtensions}; -use crate::utils::{exp_cloned, expr_to_ident, tuple, tuple_spanned, tuple_type}; -use crate::ascent_mir::{AscentMir, MirBodyItem, MirRelation, MirRelationVersion, MirRule, MirScc, ir_relation_version_var_name, mir_rule_summary, mir_summary}; -use crate::ascent_mir::MirRelationVersion::*; - -pub(crate) fn compile_mir(mir: &AscentMir, is_ascent_run: bool) -> proc_macro2::TokenStream { - - let mut relation_fields = vec![]; - let mut field_defaults = vec![]; - - let sorted_relations_ir_relations = mir.relations_ir_relations.iter().sorted_by_key(|(rel, _)| &rel.name); - for (rel, rel_indices) in sorted_relations_ir_relations { - let name = &rel.name; - let rel_attrs = &mir.relations_metadata[rel].attributes; - let sorted_rel_index_names = rel_indices.iter().map(|ind| format!("{}", ind.ir_name())).sorted(); - let rel_indices_comment = format!("\nlogical indices: {}", sorted_rel_index_names.into_iter().join("; ")); - let rel_type = rel_type(rel, mir); - let rel_ind_common = rel_ind_common_var_name(rel); - let rel_ind_common_type = rel_ind_common_type(rel, mir); - relation_fields.push(quote! { - #(#rel_attrs)* - #[doc = #rel_indices_comment] - pub #name: #rel_type, - pub #rel_ind_common: #rel_ind_common_type, - }); - field_defaults.push(quote! {#name : Default::default(), #rel_ind_common: Default::default(),}); - if rel.is_lattice && mir.is_parallel { - let lattice_mutex_name = lattice_insertion_mutex_var_name(rel); - relation_fields.push(quote! { - pub #lattice_mutex_name: ::std::vec::Vec>, - }); - field_defaults.push(quote! {#lattice_mutex_name: { - let len = ::ascent::internal::shards_count(); - let mut v = Vec::with_capacity(len); for _ in 0..len {v.push(Default::default())}; - v }, - }) - } - let sorted_indices = rel_indices.iter().sorted_by_cached_key(|ind| ind.ir_name()); - for ind in sorted_indices { - let name = &ind.ir_name(); - let rel_index_type = rel_index_type(ind, mir); - relation_fields.push(quote!{ - pub #name: #rel_index_type, - }); - field_defaults.push(quote! {#name : Default::default(),}); - } - } - - let sccs_ordered = &mir.sccs; - let mut rule_time_fields = vec![]; - let mut rule_time_fields_defaults = vec![]; - for i in 0..mir.sccs.len(){ - for (rule_ind, _rule) in mir.sccs[i].rules.iter().enumerate() { - let name = rule_time_field_name(i, rule_ind); - rule_time_fields.push(quote!{ - pub #name: std::time::Duration, - }); - rule_time_fields_defaults.push(quote!{ - #name: std::time::Duration::ZERO, - }); - } - } - - - let mut sccs_compiled = vec![]; - for (i, _scc) in sccs_ordered.iter().enumerate() { - let msg = format!("scc {}", i); - let scc_compiled = compile_mir_scc(mir, i); - sccs_compiled.push(quote!{ - ascent::internal::comment(#msg); - { - let _scc_start_time = ::ascent::internal::Instant::now(); - #scc_compiled - _self.scc_times[#i] += _scc_start_time.elapsed(); - } - }); - } - - let update_indices_body = compile_update_indices_function_body(mir); - let relation_sizes_body = compile_relation_sizes_body(mir); - let scc_times_summary_body = compile_scc_times_summary_body(mir); - - let mut type_constraints = vec![]; - let mut field_type_names = HashSet::::new(); - let mut lat_field_type_names = HashSet::::new(); - - for relation in mir.relations_ir_relations.keys().sorted_by_key(|rel| &rel.name) { - use crate::quote::ToTokens; - for (i,field_type) in relation.field_types.iter().enumerate() { - let is_lat = relation.is_lattice && i == relation.field_types.len() - 1; - let add = if let Type::Path(path) = field_type { - let container = if is_lat {&mut lat_field_type_names} else {&mut field_type_names}; - container.insert(path.path.clone().into_token_stream().to_string()) - } else {true}; - if add { - let type_constraints_type = - if is_lat {quote_spanned!(field_type.span()=>LatTypeConstraints)} - else {quote_spanned!(field_type.span()=>TypeConstraints)}; - type_constraints.push(quote_spanned!{field_type.span()=> - let _type_constraints : ascent::internal::#type_constraints_type<#field_type>; - }); - if mir.is_parallel { - type_constraints.push(quote_spanned!{field_type.span()=> - let _par_constraints : ascent::internal::ParTypeConstraints<#field_type>; - }); - } - } - } - } - - let mut relation_initializations = vec![]; - for (rel, md) in mir.relations_metadata.iter().sorted_by_key(|(rel, _)| &rel.name) { - if let Some(ref init) = md.initialization { - let rel_name = &rel.name; - relation_initializations.push(quote! { - _self.#rel_name = #init; - }); - } - } - if !relation_initializations.is_empty() { - relation_initializations.push(quote! { - _self.update_indices_priv(); - }) - } - - let par_usings = if mir.is_parallel {quote! { - use ascent::rayon::iter::ParallelBridge; - use ascent::rayon::iter::ParallelIterator; - use ascent::internal::CRelIndexRead; - use ascent::internal::CRelIndexReadAll; - use ascent::internal::Freezable; - }} else {quote!{}}; - - let more_usings = if !mir.is_parallel {quote! { - use ascent::internal::RelIndexWrite; - }} else { quote! { - use ascent::internal::CRelIndexWrite; - }}; - - let run_usings = quote! { - use core::cmp::PartialEq; - use ascent::internal::{RelIndexRead, RelIndexReadAll, ToRelIndex0, TupleOfBorrowed}; - #more_usings - #par_usings - }; - - let generate_run_timeout = !is_ascent_run && mir.config.generate_run_partial; - let run_func = if is_ascent_run {quote!{}} - else if generate_run_timeout { - quote! { - #[doc = "Runs the Ascent program to a fixed point."] - pub fn run(&mut self) { - self.run_timeout(::std::time::Duration::MAX); - } - } - } else { - quote! { - #[allow(unused_imports, noop_method_call, suspicious_double_ref_op)] - #[doc = "Runs the Ascent program to a fixed point."] - pub fn run(&mut self) { - macro_rules! __check_return_conditions {() => {};} - #run_usings - self.update_indices_priv(); - let _self = self; - #(#sccs_compiled)* - } - } - }; - let run_timeout_func = if !generate_run_timeout {quote!{}} else { - quote! { - #[allow(unused_imports, noop_method_call, suspicious_double_ref_op)] - #[doc = "Runs the Ascent program to a fixed point or until the timeout is reached. In case of a timeout returns false"] - pub fn run_timeout(&mut self, timeout: ::std::time::Duration) -> bool { - let __start_time = ::ascent::internal::Instant::now(); - macro_rules! __check_return_conditions {() => { - if timeout < ::std::time::Duration::MAX && __start_time.elapsed() >= timeout {return false;} - };} - #run_usings - self.update_indices_priv(); - let _self = self; - #(#sccs_compiled)* - true - } - } - }; - let run_code = if !is_ascent_run {quote!{}} else { - quote! { - macro_rules! __check_return_conditions {() => {};} - #run_usings - let _self = &mut __run_res; - #(#relation_initializations)* - #(#sccs_compiled)* - } - }; - - let relation_initializations_for_default_impl = - if is_ascent_run {vec![]} else {relation_initializations}; - let summary = mir_summary(mir); - - let (ty_impl_generics, ty_ty_generics, ty_where_clause) = mir.signatures.split_ty_generics_for_impl(); - let (impl_impl_generics, impl_ty_generics, impl_where_clause) = mir.signatures.split_impl_generics_for_impl(); - - let ty_signature = &mir.signatures.declaration; - if let Some(impl_signature) = &mir.signatures.implementation { - assert_eq!(ty_signature.ident, impl_signature.ident, "The identifiers of struct and impl must match"); - } - - let ty_ty_generics_str = quote!(#ty_ty_generics).to_string(); - let impl_ty_generics_str = quote!(#impl_ty_generics).to_string(); - assert_eq!(ty_ty_generics_str, impl_ty_generics_str, "The generic parameters of struct ({ty_ty_generics_str}) and impl ({impl_ty_generics_str}) must match"); - - let vis = &ty_signature.visibility; - let struct_name = &ty_signature.ident; - let struct_attrs = &ty_signature.attrs; - let summary_fn = if is_ascent_run { quote! { - pub fn summary(&self) -> &'static str {#summary} - }} else { quote! { - pub fn summary() -> &'static str {#summary} - }}; - let rule_time_fields = if mir.config.include_rule_times {rule_time_fields} else {vec![]}; - let rule_time_fields_defaults = if mir.config.include_rule_times {rule_time_fields_defaults} else {vec![]}; - - let mut rel_codegens = vec![]; - for rel in mir.relations_ir_relations.keys() { - let macro_path = &mir.relations_metadata[rel].ds_macro_path; - let macro_input = rel_ds_macro_input(rel, mir); - rel_codegens.push(quote_spanned!{ macro_path.span()=> #macro_path::rel_codegen!{#macro_input} }); - } - - let sccs_count = sccs_ordered.len(); - let res = quote! { - #(#rel_codegens)* - - #(#struct_attrs)* - #vis struct #struct_name #ty_impl_generics #ty_where_clause { - #(#relation_fields)* - scc_times: [std::time::Duration; #sccs_count], - scc_iters: [usize; #sccs_count], - #(#rule_time_fields)* - pub update_time_nanos: std::sync::atomic::AtomicU64, - pub update_indices_duration: std::time::Duration, - } - impl #impl_impl_generics #struct_name #impl_ty_generics #impl_where_clause { - #run_func - - #run_timeout_func - // TODO remove pub update_indices at some point - #[allow(noop_method_call, suspicious_double_ref_op)] - fn update_indices_priv(&mut self) { - let before = ::ascent::internal::Instant::now(); - #update_indices_body - self.update_indices_duration += before.elapsed(); - } - - #[deprecated = "Explicit call to update_indices not required anymore."] - pub fn update_indices(&mut self) { - self.update_indices_priv(); - } - fn type_constraints() { - #(#type_constraints)* - } - #summary_fn - - pub fn relation_sizes_summary(&self) -> String { - #relation_sizes_body - } - pub fn scc_times_summary(&self) -> String { - #scc_times_summary_body - } - } - impl #impl_impl_generics Default for #struct_name #impl_ty_generics #impl_where_clause { - fn default() -> Self { - let mut _self = #struct_name { - #(#field_defaults)* - scc_times: [std::time::Duration::ZERO; #sccs_count], - scc_iters: [0; #sccs_count], - #(#rule_time_fields_defaults)* - update_time_nanos: Default::default(), - update_indices_duration: std::time::Duration::default() - }; - #(#relation_initializations_for_default_impl)* - _self - } - } - }; - if !is_ascent_run {res} else { - quote! { - { - #res - let mut __run_res: #struct_name #ty_ty_generics = #struct_name::default(); - #[allow(unused_imports, noop_method_call, suspicious_double_ref_op)] - { - ascent::internal::comment("running..."); - #run_code - } - __run_res - } - } - } -} - -fn rel_ind_common_type(rel: &RelationIdentity, mir: &AscentMir) -> Type { - if rel.is_lattice { - parse_quote!{ () } - } else { - let macro_path = &mir.relations_metadata[rel].ds_macro_path; - let macro_input = rel_ds_macro_input(rel, mir); - parse_quote_spanned!{ macro_path.span()=> #macro_path::rel_ind_common!(#macro_input) } - } -} - -fn rel_index_type(rel: &IrRelation, mir: &AscentMir) -> Type { - let span = rel.relation.name.span(); - let key_type = rel.key_type(); - let value_type = rel.value_type(); - - let is_lat_full_index = rel.relation.is_lattice && &mir.lattices_full_indices[&rel.relation] == rel; - - if rel.relation.is_lattice { - let res = if !mir.is_parallel { - if is_lat_full_index { - quote_spanned! { span=>ascent::internal::RelFullIndexType<#key_type, #value_type> } - } else { - quote_spanned!{ span=>ascent::internal::LatticeIndexType<#key_type, #value_type> } - } - } else { - // parallel - if is_lat_full_index { - quote_spanned! { span=>ascent::internal::CRelFullIndex<#key_type, #value_type> } - } else if rel.is_no_index() { - quote_spanned! { span=>ascent::internal::CRelNoIndex<#value_type> } - } else { - quote_spanned! { span=>ascent::internal::CRelIndex<#key_type, #value_type> } - } - }; - syn::parse2(res).unwrap() - } else { - let macro_path = &mir.relations_metadata[&rel.relation].ds_macro_path; - let span = macro_path.span(); - let macro_input = rel_ds_macro_input(&rel.relation, mir); - if rel.is_full_index() { - parse_quote_spanned! {span=> #macro_path::rel_full_ind!(#macro_input, #key_type, #value_type)} - } else { - let ind = rel_index_to_macro_input(&rel.indices); - parse_quote_spanned! {span=> #macro_path::rel_ind!(#macro_input, #ind, #key_type, #value_type)} - } - } -} - -fn rel_type(rel: &RelationIdentity, mir: &AscentMir) -> Type { - let field_types = tuple_type(&rel.field_types); - - if rel.is_lattice { - if mir.is_parallel { - parse_quote! {::ascent::boxcar::Vec<::std::sync::RwLock<#field_types>>} - } else { - parse_quote! {::std::vec::Vec<#field_types>} - } - } else { - let macro_path = &mir.relations_metadata[rel].ds_macro_path; - let macro_input = rel_ds_macro_input(rel, mir); - parse_quote_spanned! {macro_path.span()=> #macro_path::rel!(#macro_input) } - } -} - - -fn rel_index_to_macro_input(ind: &[usize]) -> TokenStream { - let indices =ind.iter().cloned().map(syn::Index::from); - quote!{ [#(#indices),*] } -} - -fn rel_ds_macro_input(rel: &RelationIdentity, mir: &AscentMir) -> TokenStream { - let span = rel.name.span(); - let field_types = tuple_type(&rel.field_types); - let indices = mir.relations_ir_relations[rel].iter() - .sorted_by_key(|r| &r.indices) - .map(|ir_rel| rel_index_to_macro_input(&ir_rel.indices)); - let args = &mir.relations_metadata[rel].ds_macro_args; - let par: Ident = if mir.is_parallel { parse_quote_spanned! {span=> par} } else { parse_quote_spanned! {span=> ser} }; - let name = Ident::new(&format!("{}_{}", mir.signatures.declaration.ident, rel.name), span); - quote! { - #name, - #field_types, - [#(#indices),*], - #par, - (#args) - } -} - -fn rule_time_field_name(scc_ind: usize, rule_ind: usize) ->Ident { - Ident::new(&format!("rule{}_{}_duration", scc_ind, rule_ind), Span::call_site()) -} - -fn compile_mir_scc(mir: &AscentMir, scc_ind: usize) -> proc_macro2::TokenStream { - - let scc = &mir.sccs[scc_ind]; - let mut move_total_to_delta = vec![]; - let mut shift_delta_to_total_new_to_delta = vec![]; - let mut move_total_to_field = vec![]; - let mut freeze_code = vec![]; - let mut unfreeze_code = vec![]; - - let _self = quote! { _self }; - - use std::iter::once; - let sorted_dynamic_relations = scc.dynamic_relations.iter().sorted_by_cached_key(|(rel, _)| rel.name.clone()); - for rel in sorted_dynamic_relations.flat_map(|(rel, indices)| { - once(Either::Left(rel)).chain(indices.iter().sorted_by_cached_key(|rel| rel.ir_name()).map(Either::Right)) - }) { - let (ir_name, ty) = match rel { - Either::Left(rel) => (rel_ind_common_var_name(rel), rel_ind_common_type(rel, mir)), - Either::Right(rel_ind) => (rel_ind.ir_name(), rel_index_type(rel_ind, mir)), - }; - let delta_var_name = ir_relation_version_var_name(&ir_name, MirRelationVersion::Delta); - let total_var_name = ir_relation_version_var_name(&ir_name, MirRelationVersion::Total); - let new_var_name = ir_relation_version_var_name(&ir_name, MirRelationVersion::New); - let total_field = &ir_name; - move_total_to_delta.push(quote_spanned! {ir_name.span()=> - let mut #delta_var_name: #ty = ::std::mem::take(&mut #_self.#total_field); - let mut #total_var_name : #ty = Default::default(); - let mut #new_var_name : #ty = Default::default(); - }); - - match rel { - Either::Left(_rel_ind_common) => { - shift_delta_to_total_new_to_delta.push(quote_spanned!{ir_name.span()=> - ::ascent::internal::RelIndexMerge::merge_delta_to_total_new_to_delta(&mut #new_var_name, &mut #delta_var_name, &mut #total_var_name); - }); - move_total_to_delta.push(quote_spanned! {ir_name.span()=> - ::ascent::internal::RelIndexMerge::init(&mut #new_var_name, &mut #delta_var_name, &mut #total_var_name); - }); - }, - Either::Right(ir_rel) => { - let delta_expr = expr_for_rel_write(&MirRelation::from(ir_rel.clone(), Delta), mir); - let total_expr = expr_for_rel_write(&MirRelation::from(ir_rel.clone(), Total), mir); - let new_expr = expr_for_rel_write(&MirRelation::from(ir_rel.clone(), New), mir); - - shift_delta_to_total_new_to_delta.push(quote_spanned!{ir_name.span()=> - ::ascent::internal::RelIndexMerge::merge_delta_to_total_new_to_delta(&mut #new_expr, &mut #delta_expr, &mut #total_expr); - }); - move_total_to_delta.push(quote_spanned! {ir_name.span()=> - ::ascent::internal::RelIndexMerge::init(&mut #new_expr, &mut #delta_expr, &mut #total_expr); - }); - }, - } - - move_total_to_field.push(quote_spanned!{ir_name.span()=> - #_self.#total_field = #total_var_name; - }); - - if mir.is_parallel { - freeze_code.push(quote_spanned!{ir_name.span()=> - #total_var_name.freeze(); - #delta_var_name.freeze(); - }); - - unfreeze_code.push(quote_spanned!{ir_name.span()=> - #total_var_name.unfreeze(); - #delta_var_name.unfreeze(); - }); - } - } - let sorted_body_only_relations = scc.body_only_relations.iter().sorted_by_cached_key(|(rel, _)| rel.name.clone()); - for rel in sorted_body_only_relations.flat_map(|(rel, indices)| { - let sorted_indices = indices.iter().sorted_by_cached_key(|rel| rel.ir_name()); - once(Either::Left(rel)).chain(sorted_indices.map(Either::Right)) - }) { - let (ir_name, ty) = match rel { - Either::Left(rel) => (rel_ind_common_var_name(rel), rel_ind_common_type(rel, mir)), - Either::Right(rel_ind) => (rel_ind.ir_name(), rel_index_type(rel_ind, mir)), - }; - let total_var_name = ir_relation_version_var_name(&ir_name, MirRelationVersion::Total); - let total_field = &ir_name; - - if mir.is_parallel { - move_total_to_delta.push(quote_spanned!{ir_name.span()=> - #_self.#total_field.freeze(); - }); - } - - move_total_to_delta.push(quote_spanned! {ir_name.span()=> - let #total_var_name: #ty = std::mem::take(&mut #_self.#total_field); - }); - - move_total_to_field.push(quote_spanned!{ir_name.span()=> - #_self.#total_field = #total_var_name; - }); - } - - let rule_parallelism = mir.config.inter_rule_parallelism && mir.is_parallel; - - let mut evaluate_rules = vec![]; - - for (i, rule) in scc.rules.iter().enumerate() { - let msg = mir_rule_summary(rule); - let rule_compiled = compile_mir_rule(rule, scc, mir); - let rule_time_field = rule_time_field_name(scc_ind, i); - let (before_rule_var, update_rule_time_field) = if mir.config.include_rule_times { - (quote! {let before_rule = ::ascent::internal::Instant::now();}, - quote!{_self.#rule_time_field += before_rule.elapsed();}) - } else {(quote!{}, quote!{})}; - evaluate_rules.push(if rule_parallelism { quote! { - ascent::internal::comment(#msg); - __scope.spawn(|_| { - #before_rule_var - #rule_compiled - #update_rule_time_field - }); - }} else { quote! { - #before_rule_var - ascent::internal::comment(#msg); - { - #rule_compiled - } - #update_rule_time_field - }}); - } - - let evaluate_rules = if rule_parallelism { - quote! { - ascent::rayon::scope(|__scope| { - #(#evaluate_rules)* - }); - } - } else { - quote! { #(#evaluate_rules)* } - }; - - let changed_var_def_code = if !mir.is_parallel { - quote! { let mut __changed = false; } - } else { - quote! { let __changed = std::sync::atomic::AtomicBool::new(false); } - }; - let check_changed_code = if !mir.is_parallel { - quote! {__changed} - } else { - quote! {__changed.load(std::sync::atomic::Ordering::Relaxed)} - }; - - let evaluate_rules_loop = if scc.is_looping { quote! { - #[allow(unused_assignments, unused_variables)] - loop { - #changed_var_def_code - - #(#freeze_code)* - // evaluate rules - #evaluate_rules - - #(#unfreeze_code)* - #(#shift_delta_to_total_new_to_delta)* - _self.scc_iters[#scc_ind] += 1; - if !#check_changed_code {break;} - __check_return_conditions!(); - } - - }} else {quote! { - #[allow(unused_assignments, unused_variables)] - { - // let mut __changed = false; - #changed_var_def_code - #(#freeze_code)* - - #evaluate_rules - - #(#unfreeze_code)* - - #(#shift_delta_to_total_new_to_delta)* - #(#shift_delta_to_total_new_to_delta)* - _self.scc_iters[#scc_ind] += 1; - __check_return_conditions!(); - } - }}; - quote! { - // define variables for delta and new versions of dynamic relations in the scc - // move total versions of dynamic indices to delta - #(#move_total_to_delta)* - - #evaluate_rules_loop - - #(#move_total_to_field)* - } -} - - -fn compile_relation_sizes_body(mir: &AscentMir) -> proc_macro2::TokenStream { - let mut write_sizes = vec![]; - for r in mir.relations_ir_relations.keys().sorted_by_key(|r| &r.name) { - let rel_name = &r.name; - let rel_name_str = r.name.to_string(); - write_sizes.push(quote! { - writeln!(&mut res, "{} size: {}", #rel_name_str, self.#rel_name.len()).unwrap(); - }); - } - quote! { - use std::fmt::Write; - let mut res = String::new(); - #(#write_sizes)* - res - } -} - -fn compile_scc_times_summary_body(mir: &AscentMir) -> proc_macro2::TokenStream { - let mut res = vec![]; - for i in 0..mir.sccs.len(){ - let i_str = format!("{}", i); - res.push(quote!{ - writeln!(&mut res, "scc {}: iterations: {}, time: {:?}", #i_str, self.scc_iters[#i], self.scc_times[#i]).unwrap(); - }); - if mir.config.include_rule_times { - let mut sum_of_rule_times = quote!{ std::time::Duration::ZERO }; - for (rule_ind, _rule) in mir.sccs[i].rules.iter().enumerate() { - let rule_time_field = rule_time_field_name(i, rule_ind); - sum_of_rule_times = quote!{ #sum_of_rule_times + self.#rule_time_field}; - } - res.push(quote! { - let sum_of_rule_times = #sum_of_rule_times; - writeln!(&mut res, " sum of rule times: {:?}", sum_of_rule_times).unwrap(); - }); - for (rule_ind, rule) in mir.sccs[i].rules.iter().enumerate() { - let rule_time_field = rule_time_field_name(i, rule_ind); - let rule_summary = mir_rule_summary(rule); - res.push(quote!{ - writeln!(&mut res, " rule {}\n time: {:?}", #rule_summary, self.#rule_time_field).unwrap(); - }); - } - res.push(quote!{ - writeln!(&mut res, "-----------------------------------------").unwrap(); - }); - } - } - let update_indices_time_code = quote! { - writeln!(&mut res, "update_indices time: {:?}", self.update_indices_duration).unwrap(); - }; - quote!{ - use std::fmt::Write; - let mut res = String::new(); - #update_indices_time_code - #(#res)* - res - } -} - -fn compile_update_indices_function_body(mir: &AscentMir) -> proc_macro2::TokenStream { - let par = mir.is_parallel; - let mut res = vec![]; - if par { - res.push(quote! { use ascent::rayon::iter::{IntoParallelIterator, ParallelIterator}; }) - } - let (rel_index_write_trait, index_insert_fn) = if !par { - (quote! {ascent::internal::RelIndexWrite}, quote! {index_insert}) - } else { - (quote! {ascent::internal::CRelIndexWrite}, quote!{index_insert}) - }; - let sorted_relations_ir_relations = mir.relations_ir_relations.iter().sorted_by_key(|(rel, _)| &rel.name); - for (r,indices_set) in sorted_relations_ir_relations { - - let _ref = if !par { quote!{&mut} } else { quote!{&} }.with_span(r.name.span()); - let ind_common = rel_ind_common_var_name(r); - let rel_index_write_trait = rel_index_write_trait.clone().with_span(r.name.span()); - let _self = quote_spanned!{r.name.span().resolved_at(Span::call_site())=> self }; - let to_rel_index_fn = if !par { quote!{to_rel_index_write} } else { quote!{to_c_rel_index_write} }; - let to_rel_index = if r.is_lattice { quote!{} } else { - quote! {.#to_rel_index_fn(#_ref #_self.#ind_common) } - }; - - let mut update_indices = vec![]; - for ind in indices_set.iter().sorted_by_cached_key(|rel| rel.ir_name()) { - let ind_name = &ind.ir_name(); - let selection_tuple : Vec = ind.indices.iter().map(|&i| { - let ind = syn::Index::from(i); - parse_quote_spanned! {r.name.span()=> tuple.#ind.clone()} - }).collect_vec(); - let selection_tuple = tuple_spanned(&selection_tuple, r.name.span()); - let entry_val = index_get_entry_val_for_insert( - ind, &parse_quote_spanned!{r.name.span()=> tuple}, &parse_quote_spanned!{r.name.span()=> _i}); - let _pre_ref = if r.is_lattice {quote!()} else {_ref.clone()}; - update_indices.push(quote_spanned! {r.name.span()=> - let selection_tuple = #selection_tuple; - let rel_ind = #_ref #_self.#ind_name; - #rel_index_write_trait::#index_insert_fn(#_pre_ref rel_ind #to_rel_index, selection_tuple, #entry_val); - }); - - } - let rel_name = &r.name; - let maybe_lock = if r.is_lattice && mir.is_parallel { - quote_spanned!{r.name.span()=> let tuple = tuple.read().unwrap(); } - } else { quote!{} }; - if !par { - res.push(quote_spanned! {r.name.span()=> - for (_i, tuple) in #_self.#rel_name.iter().enumerate() { - #maybe_lock - #(#update_indices)* - } - }); - } else { - res.push(quote_spanned! {r.name.span()=> - (0..#_self.#rel_name.len()).into_par_iter().for_each(|_i| { - let tuple = &#_self.#rel_name[_i]; - #maybe_lock - #(#update_indices)* - }); - }); - } - } - - quote! { - use ascent::internal::ToRelIndex0; - use #rel_index_write_trait; - #(#res)* - } -} - -fn compile_cond_clause(cond: &CondClause, body: proc_macro2::TokenStream) -> proc_macro2::TokenStream { - match cond { - CondClause::IfLet(if_let_clause) => { - let pat = &if_let_clause.pattern; - let expr = &if_let_clause.exp; - quote_spanned! {if_let_clause.if_keyword.span()=> - if let #pat = #expr { - #body - } - } - } - CondClause::Let(let_clause) => { - let pat = &let_clause.pattern; - let expr = &let_clause.exp; - quote_spanned! {let_clause.let_keyword.span()=> - let #pat = #expr; - #body - } - } - CondClause::If(if_clause) => { - let cond = &if_clause.cond; - quote_spanned! {if_clause.if_keyword.span()=> - if #cond { - #body - } - } - } - } -} - -fn compile_mir_rule(rule: &MirRule, scc: &MirScc, mir: &AscentMir) -> proc_macro2::TokenStream { - let (head_rels_structs_and_vars, head_update_code) = - head_clauses_structs_and_update_code(rule, scc, mir); - - const MAX_PAR_ITERS: usize = 2; - - // do parallel iteration up to this clause index (exclusive) - let par_iter_to_ind = if mir.is_parallel { - use itertools::FoldWhile::*; - rule.body_items.iter().fold_while((0, 0), |(count, i), bi| { - let new_count = count + matches!(bi, MirBodyItem::Clause(..)) as usize; - if new_count > MAX_PAR_ITERS { Done((new_count, i)) } else { Continue((new_count, i + 1)) } - }).into_inner().1 - } else { 0 }; - let rule_code = compile_mir_rule_inner(rule, scc, mir, par_iter_to_ind, head_update_code, 0); - quote!{ - #head_rels_structs_and_vars - #rule_code - } -} - -fn compile_mir_rule_inner(rule: &MirRule, _scc: &MirScc, mir: &AscentMir, par_iter_to_ind: usize, head_update_code: proc_macro2::TokenStream, clause_ind: usize) --> proc_macro2::TokenStream -{ - if Some(clause_ind) == rule.simple_join_start_index && rule.reorderable { - let mut rule_cp1 = rule.clone(); - let mut rule_cp2 = rule.clone(); - rule_cp1.reorderable = false; - rule_cp2.reorderable = false; - rule_cp2.body_items.swap(clause_ind, clause_ind + 1); - let rule_cp1_compiled = compile_mir_rule_inner(&rule_cp1, _scc, mir, par_iter_to_ind, head_update_code.clone(), clause_ind); - let rule_cp2_compiled = compile_mir_rule_inner(&rule_cp2, _scc, mir, par_iter_to_ind, head_update_code , clause_ind); - - if let [MirBodyItem::Clause(bcl1), MirBodyItem::Clause(bcl2)] = &rule.body_items[clause_ind..clause_ind+2]{ - let rel1_var_name = expr_for_rel(&bcl1.rel, mir); - let rel2_var_name = expr_for_rel(&bcl2.rel, mir); - - return quote_spanned!{bcl1.rel_args_span=> - if #rel1_var_name.len() <= #rel2_var_name.len() { - #rule_cp1_compiled - } else { - #rule_cp2_compiled - } - }; - } else { - panic!("unexpected body items in reorderable rule") - } - } - if clause_ind < rule.body_items.len(){ - let bitem = &rule.body_items[clause_ind]; - let doing_simple_join = rule.simple_join_start_index == Some(clause_ind); - let next_loop = if doing_simple_join { - compile_mir_rule_inner(rule, _scc, mir, par_iter_to_ind, head_update_code, clause_ind + 2) - } else { - compile_mir_rule_inner(rule, _scc, mir, par_iter_to_ind, head_update_code, clause_ind + 1) - }; - - match bitem { - MirBodyItem::Clause(bclause) => { - let (clause_ind, bclause) = - if doing_simple_join {(clause_ind + 1, rule.body_items[clause_ind + 1].unwrap_clause())} - else {(clause_ind, bclause)}; - - let bclause_rel_name = &bclause.rel.relation.name; - let selected_args = &bclause.selected_args(); - let pre_clause_vars = rule.body_items.iter().take(clause_ind) - .flat_map(MirBodyItem::bound_vars) - .collect::>(); - - let clause_vars = bclause.vars(); - let common_vars = clause_vars.iter().filter(|(_i,v)| pre_clause_vars.contains(v)).collect::>(); - let common_vars_no_indices = common_vars.iter().map(|(_i,v)| v.clone()).collect::>(); - - let cloning_needed = true; - - let matched_val_ident = Ident::new("__val", bclause.rel_args_span); - let new_vars_assignments = clause_var_assignments( - &bclause.rel, clause_vars.iter().filter(|(_i, var)| !common_vars_no_indices.contains(var)).cloned(), - &matched_val_ident, &parse_quote!{_self.#bclause_rel_name}, cloning_needed, mir - ); - - let selected_args_cloned = selected_args.iter().map(exp_cloned).collect_vec(); - let selected_args_tuple = tuple_spanned(&selected_args_cloned, bclause.args_span); - let rel_version_exp = expr_for_rel(&bclause.rel, mir); - - let mut conds_then_next_loop = next_loop; - for cond in bclause.cond_clauses.iter().rev() { - conds_then_next_loop = compile_cond_clause(cond, conds_then_next_loop); - } - - let span = bclause.rel_args_span; - - let matching_dot_iter = quote_spanned!{bclause.rel_args_span=> __matching}; - - let (index_get, iter_all) = if clause_ind < par_iter_to_ind { - (quote_spanned! {span=> c_index_get}, quote_spanned! {span=> c_iter_all}) - } else { - (quote_spanned! {span=> index_get}, quote_spanned! {span=> iter_all}) - }; - - // The special case where the first clause has indices, but there are no expressions - // in the args of the first clause - if doing_simple_join { - let cl1 = rule.body_items[rule.simple_join_start_index.unwrap()].unwrap_clause(); - let cl2 = bclause; - let cl1_var_name = expr_for_rel(&cl1.rel, mir); - let cl2_var_name = expr_for_rel(&cl2.rel, mir); - let cl1_vars = cl1.vars(); - - let cl1_rel_name = &cl1.rel.relation.name; - - let mut cl1_join_vars_assignments = vec![]; - for (tuple_ind, &i) in cl1.rel.indices.iter().enumerate() { - let var = expr_to_ident(&cl1.args[i]).unwrap(); - let tuple_ind = syn::Index{index: tuple_ind as u32, span: var.span()}; - cl1_join_vars_assignments.push(quote_spanned! {var.span()=> let #var = __cl1_joined_columns.#tuple_ind;}); - } - - let cl1_matched_val_ident = syn::Ident::new("cl1_val", cl1.rel_args_span); - let cl1_vars_assignments = clause_var_assignments( - &cl1.rel, cl1_vars.iter().filter(|(i, _var)| !cl1.rel.indices.contains(i)).cloned(), - &cl1_matched_val_ident, &parse_quote!{_self.#cl1_rel_name}, cloning_needed, mir - ); - let cl1_vars_assignments = vec![cl1_vars_assignments]; - - let joined_args_for_cl2_cloned = cl2.selected_args().iter().map(exp_cloned).collect_vec(); - let joined_args_tuple_for_cl2 = tuple_spanned(&joined_args_for_cl2_cloned, cl2.args_span); - - - let cl1_tuple_indices_iter = quote_spanned!(cl1.rel_args_span=> __cl1_tuple_indices); - - let mut cl1_conds_then_rest = quote_spanned! {bclause.rel_args_span=> - #matching_dot_iter.clone().for_each(|__val| { - // TODO we may be doing excessive cloning - #new_vars_assignments - #conds_then_next_loop - }); - }; - for cond in cl1.cond_clauses.iter().rev() { - cl1_conds_then_rest = compile_cond_clause(cond, cl1_conds_then_rest); - } - quote_spanned! {cl1.rel_args_span=> - #cl1_var_name.#iter_all().for_each(|(__cl1_joined_columns, __cl1_tuple_indices)| { - let __cl1_joined_columns = __cl1_joined_columns.tuple_of_borrowed(); - #(#cl1_join_vars_assignments)* - if let Some(__matching) = #cl2_var_name.#index_get(&#joined_args_tuple_for_cl2) { - #cl1_tuple_indices_iter.for_each(|cl1_val| { - #(#cl1_vars_assignments)* - #cl1_conds_then_rest - }); - } - }); - } - } else { - quote_spanned! {bclause.rel_args_span=> - if let Some(__matching) = #rel_version_exp.#index_get( &#selected_args_tuple) { - #matching_dot_iter.for_each(|__val| { - // TODO we may be doing excessive cloning - #new_vars_assignments - #conds_then_next_loop - }); - } - } - } - - }, - MirBodyItem::Generator(gen) => { - let pat = &gen.pattern; - let expr = &gen.expr; - quote_spanned! {gen.for_keyword.span()=> - for #pat in #expr { - #next_loop - } - } - } - MirBodyItem::Cond(cond) => compile_cond_clause(cond, next_loop), - MirBodyItem::Agg(agg) => { - - let pat = &agg.pat; - let rel_name = &agg.rel.relation.name; - let mir_relation = MirRelation::from(agg.rel.clone(), Total); - // let rel_version_var_name = mir_relation.var_name(); - let rel_expr = expr_for_rel(&mir_relation, mir); - let selected_args = mir_relation.indices.iter().map(|&i| &agg.rel_args[i]); - let selected_args_cloned = selected_args.map(exp_cloned).collect_vec(); - let selected_args_tuple = tuple_spanned(&selected_args_cloned, agg.span); - let agg_args_tuple_indices = - agg.bound_args.iter() - .map(|arg| (agg.rel_args.iter() - .find_position(|rel_arg| expr_to_ident(rel_arg) == Some(arg.clone())).unwrap().0, arg.clone())); - - let agg_args_tuple = tuple_spanned(&agg.bound_args.iter().map(|v| parse_quote!{#v}).collect_vec(), agg.span); - - let vars_assignments = clause_var_assignments( - &MirRelation::from(agg.rel.clone(), MirRelationVersion::Total), agg_args_tuple_indices, - &parse_quote_spanned!{agg.span=> __val}, &parse_quote!{_self.#rel_name}, - false, mir - ); - - let agg_func = &agg.aggregator; - let _self = quote!{ _self }; - quote_spanned! {agg.span=> - let __aggregated_rel = #rel_expr; - let __matching = __aggregated_rel.index_get( &#selected_args_tuple); - let __agg_args = __matching.into_iter().flatten().map(|__val| { - #vars_assignments - #agg_args_tuple - }); - for #pat in #agg_func(__agg_args) { - #next_loop - } - - } - } - } - } else { - quote! { - // let before_update = ::ascent::internal::Instant::now(); - #head_update_code - // let update_took = before_update.elapsed(); - // _self.update_time_nanos.fetch_add(update_took.as_nanos() as u64, std::sync::atomic::Ordering::Relaxed); - } - } -} - -fn head_clauses_structs_and_update_code(rule: &MirRule, scc: &MirScc, mir: &AscentMir) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { - let mut add_rows = vec![]; - - let set_changed_true_code = if !mir.is_parallel { - quote! { __changed = true; } - } else { - quote! { __changed.store(true, std::sync::atomic::Ordering::Relaxed);} - }; - - for hcl in rule.head_clause.iter() { - - let head_rel_name = Ident::new(&hcl.rel.name.to_string(), hcl.span); - let hcl_args_converted = hcl.args.iter().cloned().map(convert_head_arg).collect_vec(); - let new_row_tuple = tuple_spanned(&hcl_args_converted, hcl.args_span); - - let head_relation = &hcl.rel; - let row_type = tuple_type(&head_relation.field_types); - - let mut update_indices = vec![]; - let rel_indices = scc.dynamic_relations.get(head_relation); - let (rel_index_write_trait, index_insert_fn) = if !mir.is_parallel { - (quote! { ::ascent::internal::RelIndexWrite }, quote! {index_insert}) - } else { - (quote! { ::ascent::internal::CRelIndexWrite }, quote! {index_insert}) - }; - let (rel_index_write_trait, index_insert_fn) = (rel_index_write_trait.with_span(hcl.span), index_insert_fn.with_span(hcl.span)); - let new_ref = if !mir.is_parallel { quote!{&mut} } else { quote!{&} }; - let mut used_fields = HashSet::new(); - if let Some(rel_indices) = rel_indices { - for rel_ind in rel_indices.iter().sorted_by_cached_key(|rel| rel.ir_name()) { - if rel_ind.is_full_index() {continue}; - let var_name = if !mir.is_parallel { - expr_for_rel_write(&MirRelation::from(rel_ind.clone(), New), mir) - } else { - expr_for_c_rel_write(&MirRelation::from(rel_ind.clone(), New), mir) - }; - let args_tuple : Vec = rel_ind.indices.iter().map(|&i| { - let i_ind = syn::Index::from(i); - syn::parse2(quote_spanned!{hcl.span=> __new_row.#i_ind.clone()}).unwrap() - }).collect(); - used_fields.extend(rel_ind.indices.iter().cloned()); - if let IndexValType::Direct(direct) = &rel_ind.val_type { - used_fields.extend(direct.iter().cloned()); - } - let args_tuple = tuple(&args_tuple); - let entry_val = index_get_entry_val_for_insert( - rel_ind, &parse_quote_spanned!{hcl.span=> __new_row}, &parse_quote_spanned!{hcl.span=> __new_row_ind}); - update_indices.push(quote_spanned! {hcl.span=> - #rel_index_write_trait::#index_insert_fn(#new_ref #var_name, #args_tuple, #entry_val); - }); - } - } - - let head_rel_full_index = &mir.relations_full_indices[head_relation]; - - let expr_for_rel_maybe_mut = if mir.is_parallel { expr_for_c_rel_write } else { expr_for_rel_write }; - let head_rel_full_index_expr_new = expr_for_rel_maybe_mut(&MirRelation::from(head_rel_full_index.clone(), New), mir); - let head_rel_full_index_expr_delta = expr_for_rel(&MirRelation::from(head_rel_full_index.clone(), Delta), mir); - let head_rel_full_index_expr_total = expr_for_rel(&MirRelation::from(head_rel_full_index.clone(), Total), mir); - - - let rel_full_index_write_trait = if !mir.is_parallel { - quote! { ::ascent::internal::RelFullIndexWrite } - } else { - quote! { ::ascent::internal::CRelFullIndexWrite } - }.with_span(hcl.span); - - let new_row_to_be_pushed = (0..hcl.rel.field_types.len()).map(|i| { - let ind = syn::Index::from(i); - let clone = if used_fields.contains(&i) { quote!{.clone()} } else { quote!{} }; - parse_quote_spanned!{hcl.span=> __new_row.#ind #clone } - }).collect_vec(); - let new_row_to_be_pushed = tuple_spanned(&new_row_to_be_pushed, hcl.span); - - let push_code = if !mir.is_parallel { quote! { - let __new_row_ind = _self.#head_rel_name.len(); - _self.#head_rel_name.push(#new_row_to_be_pushed); - }} else { quote! { - let __new_row_ind = _self.#head_rel_name.push(#new_row_to_be_pushed); - }}; - - if !hcl.rel.is_lattice { - let add_row = quote_spanned!{hcl.span=> - let __new_row: #row_type = #new_row_tuple; - - if !::ascent::internal::RelFullIndexRead::contains_key(&#head_rel_full_index_expr_total, &__new_row) && - !::ascent::internal::RelFullIndexRead::contains_key(&#head_rel_full_index_expr_delta, &__new_row) { - if #rel_full_index_write_trait::insert_if_not_present(#new_ref #head_rel_full_index_expr_new, - &__new_row, ()) - { - #push_code - #(#update_indices)* - #set_changed_true_code - } - } - }; - add_rows.push(add_row); - } else { // rel.is_lattice: - let lattice_insertion_mutex = lattice_insertion_mutex_var_name(head_relation); - let head_lat_full_index = &mir.lattices_full_indices[head_relation]; - let head_lat_full_index_var_name_new = ir_relation_version_var_name(&head_lat_full_index.ir_name(), New); - let head_lat_full_index_var_name_delta = ir_relation_version_var_name(&head_lat_full_index.ir_name(), Delta); - let head_lat_full_index_var_name_full = ir_relation_version_var_name(&head_lat_full_index.ir_name(), Total); - let tuple_lat_index = syn::Index::from(hcl.rel.field_types.len() - 1); - let lattice_key_args : Vec = (0..hcl.args.len() - 1).map(|i| { - let i_ind = syn::Index::from(i); - syn::parse2(quote_spanned!{hcl.span=> __new_row.#i_ind}).unwrap() - }).map(|e| exp_cloned(&e)).collect_vec(); - let lattice_key_tuple = tuple(&lattice_key_args); - - let _self = quote! { _self }; - let add_row = if !mir.is_parallel { quote_spanned! {hcl.span=> - let __new_row: #row_type = #new_row_tuple; - let __lattice_key = #lattice_key_tuple; - if let Some(mut __existing_ind) = #head_lat_full_index_var_name_new.index_get(&__lattice_key) - .or_else(|| #head_lat_full_index_var_name_delta.index_get(&__lattice_key)) - .or_else(|| #head_lat_full_index_var_name_full.index_get(&__lattice_key)) - { - let __existing_ind = *__existing_ind.next().unwrap(); - // TODO possible excessive cloning here? - let __lat_changed = ::ascent::Lattice::join_mut(&mut #_self.#head_rel_name[__existing_ind].#tuple_lat_index, __new_row.#tuple_lat_index.clone()); - if __lat_changed { - let __new_row_ind = __existing_ind; - #(#update_indices)* - #set_changed_true_code - } - } else { - let __new_row_ind = #_self.#head_rel_name.len(); - #(#update_indices)* - #_self.#head_rel_name.push(#new_row_to_be_pushed); - #set_changed_true_code - } - }} else { quote_spanned! {hcl.span=> // mir.is_parallel: - let __new_row: #row_type = #new_row_tuple; - let __lattice_key = #lattice_key_tuple; - let __existing_ind_in_new = #head_lat_full_index_var_name_new.get_cloned(&__lattice_key); - let __new_has_ind = __existing_ind_in_new.is_some(); - if let Some(__existing_ind) = __existing_ind_in_new - .or_else(|| #head_lat_full_index_var_name_delta.get_cloned(&__lattice_key)) - .or_else(|| #head_lat_full_index_var_name_full.get_cloned(&__lattice_key)) - { - let __lat_changed = ::ascent::Lattice::join_mut(&mut #_self.#head_rel_name[__existing_ind].write().unwrap().#tuple_lat_index, - __new_row.#tuple_lat_index.clone()); - if __lat_changed && !__new_has_ind{ - let __new_row_ind = __existing_ind; - #(#update_indices)* - #set_changed_true_code - } - } else { - let __hash = #head_lat_full_index_var_name_new.hash_usize(&__lattice_key); - let __lock = #_self.#lattice_insertion_mutex.get(__hash % #_self.#lattice_insertion_mutex.len()).expect("lattice_insertion_mutex index out of bounds").lock().unwrap(); - if let Some(__existing_ind) = #head_lat_full_index_var_name_new.get_cloned(&__lattice_key) { - ::ascent::Lattice::join_mut(&mut #_self.#head_rel_name[__existing_ind].write().unwrap().#tuple_lat_index, - __new_row.#tuple_lat_index.clone()); - } else { - let __new_row_ind = #_self.#head_rel_name.push(::std::sync::RwLock::new(#new_row_to_be_pushed)); - #(#update_indices)* - #set_changed_true_code - } - } - }}; - add_rows.push(add_row); - } - } - ( - quote!{}, - quote!{#(#add_rows)*} - ) -} - -fn lattice_insertion_mutex_var_name(head_relation: &RelationIdentity) -> Ident { - Ident::new(&format!("__{}_mutex", head_relation.name), head_relation.name.span()) -} - -fn rel_ind_common_var_name(relation: &RelationIdentity) -> Ident { - Ident::new(&format!("__{}_ind_common", relation.name), relation.name.span()) -} - -fn convert_head_arg(arg: Expr) -> Expr { - if let Some(var) = expr_to_ident(&arg){ - parse2(quote_spanned!{arg.span()=> ascent::internal::Convert::convert(#var)}).unwrap() - } else { - arg - } -} - -fn expr_for_rel(rel: &MirRelation, mir: &AscentMir) -> proc_macro2::TokenStream { - fn expr_for_rel_inner(ir_name: &Ident, version: MirRelationVersion, _mir: &AscentMir, mir_rel: &MirRelation) -> (TokenStream, bool) { - let var = ir_relation_version_var_name(ir_name, version); - if mir_rel.relation.is_lattice { - (quote! { & #var }, true) - } else { - let rel_ind_common = ir_relation_version_var_name(&rel_ind_common_var_name(&mir_rel.relation), version); - (quote! { #var.to_rel_index(& #rel_ind_common) }, false) - } - } - - if rel.version == MirRelationVersion::TotalDelta { - let total_expr = expr_for_rel_inner(&rel.ir_name, MirRelationVersion::Total, mir, rel).0; - let delta_expr = expr_for_rel_inner(&rel.ir_name, MirRelationVersion::Delta, mir, rel).0; - quote! { - ascent::internal::RelIndexCombined::new(& #total_expr, & #delta_expr) - } - } else { - let (res, borrowed) = expr_for_rel_inner(&rel.ir_name, rel.version, mir, rel); - if !borrowed {res} else {quote!{(#res)}} - } -} - -fn expr_for_rel_write(mir_rel: &MirRelation, _mir: &AscentMir) -> proc_macro2::TokenStream { - let var = mir_rel.var_name(); - if mir_rel.relation.is_lattice { - quote!{ #var } - } else { - let rel_ind_common = ir_relation_version_var_name(&rel_ind_common_var_name(&mir_rel.relation), mir_rel.version); - quote! { #var.to_rel_index_write(&mut #rel_ind_common) } - } -} - -fn expr_for_c_rel_write(mir_rel: &MirRelation, _mir: &AscentMir) -> proc_macro2::TokenStream { - let var = mir_rel.var_name(); - if mir_rel.relation.is_lattice { - quote!{ #var } - } else { - let rel_ind_common = ir_relation_version_var_name(&rel_ind_common_var_name(&mir_rel.relation), mir_rel.version); - quote! { #var.to_c_rel_index_write(&#rel_ind_common) } - } -} - -fn clause_var_assignments( - rel: &MirRelation, - vars: impl Iterator, - val_ident: &Ident, - relation_expr: &Expr, - cloning_needed: bool, - mir: &AscentMir -) -> proc_macro2::TokenStream { - let mut assignments = vec![]; - - let mut any_vars = false; - for (ind_in_tuple, var) in vars { - let var_type_ascription = { - let ty = &rel.relation.field_types[ind_in_tuple]; - quote! { : & #ty} - }; - any_vars = true; - match &rel.val_type { - IndexValType::Reference => { - let ind = syn::Index::from(ind_in_tuple); - assignments.push(quote! { - let #var #var_type_ascription = &__row.#ind; - }) - }, - IndexValType::Direct(inds) => { - let ind = inds.iter().enumerate().find(|(_i, ind)| **ind == ind_in_tuple).unwrap().0; - let ind = syn::Index::from(ind); - - assignments.push(quote! { - let #var #var_type_ascription = #val_ident.#ind; - }) - }, - } - } - - if any_vars { - match &rel.val_type { - IndexValType::Reference => { - let maybe_lock = if rel.relation.is_lattice && mir.is_parallel { - quote! {.read().unwrap()} - } else { quote! {} }; - let maybe_clone = if cloning_needed { - quote! {.clone()} - } else { quote! {} }; - assignments.insert(0, quote! { - let __row = &#relation_expr[*#val_ident]#maybe_lock #maybe_clone; - }); - }, - IndexValType::Direct(_) => { - assignments.insert(0, quote! { - let #val_ident = #val_ident.tuple_of_borrowed(); - }); - } - } - } - - quote! { - #(#assignments)* - } -} - -fn index_get_entry_val_for_insert(rel_ind: &IrRelation, tuple_expr: &Expr, ind_expr: &Expr) -> Expr { - match &rel_ind.val_type { - IndexValType::Reference => ind_expr.clone(), - IndexValType::Direct(inds) => { - let val = inds.iter().map(|&ind| { - let ind = syn::Index::from(ind); - parse_quote!{ - #tuple_expr.#ind.clone() - } - }).collect_vec(); - tuple(&val) - }, - } -} +#![deny(warnings)] +use std::collections::HashSet; + +use itertools::{Either, Itertools}; +use proc_macro2::{Ident, Span, TokenStream}; +use syn::spanned::Spanned; +use syn::{Expr, Type, parse_quote, parse_quote_spanned, parse2}; + +use crate::ascent_hir::{IndexValType, IrRelation}; +use crate::ascent_mir::MirRelationVersion::*; +use crate::ascent_mir::{ + AscentMir, MirBodyItem, MirRelation, MirRelationVersion, MirRule, MirScc, ir_relation_version_var_name, + mir_rule_summary, mir_summary, +}; +use crate::ascent_syntax::{CondClause, RelationIdentity}; +use crate::utils::{TokenStreamExtensions, exp_cloned, expr_to_ident, tuple, tuple_spanned, tuple_type}; + +pub(crate) fn compile_mir(mir: &AscentMir, is_ascent_run: bool) -> proc_macro2::TokenStream { + let mut relation_fields = vec![]; + let mut field_defaults = vec![]; + + let sorted_relations_ir_relations = mir.relations_ir_relations.iter().sorted_by_key(|(rel, _)| &rel.name); + for (rel, rel_indices) in sorted_relations_ir_relations { + let name = &rel.name; + let rel_attrs = &mir.relations_metadata[rel].attributes; + let sorted_rel_index_names = rel_indices.iter().map(|ind| format!("{}", ind.ir_name())).sorted(); + let rel_indices_comment = format!("\nlogical indices: {}", sorted_rel_index_names.into_iter().join("; ")); + let rel_type = rel_type(rel, mir); + let rel_ind_common = rel_ind_common_var_name(rel); + let rel_ind_common_type = rel_ind_common_type(rel, mir); + relation_fields.push(quote! { + #(#rel_attrs)* + #[doc = #rel_indices_comment] + pub #name: #rel_type, + pub #rel_ind_common: #rel_ind_common_type, + }); + field_defaults.push(quote! {#name : Default::default(), #rel_ind_common: Default::default(),}); + if rel.is_lattice && mir.is_parallel { + let lattice_mutex_name = lattice_insertion_mutex_var_name(rel); + relation_fields.push(quote! { + pub #lattice_mutex_name: ::std::vec::Vec>, + }); + field_defaults.push(quote! {#lattice_mutex_name: { + let len = ::ascent::internal::shards_count(); + let mut v = Vec::with_capacity(len); for _ in 0..len {v.push(Default::default())}; + v }, + }) + } + let sorted_indices = rel_indices.iter().sorted_by_cached_key(|ind| ind.ir_name()); + for ind in sorted_indices { + let name = &ind.ir_name(); + let rel_index_type = rel_index_type(ind, mir); + relation_fields.push(quote! { + pub #name: #rel_index_type, + }); + field_defaults.push(quote! {#name : Default::default(),}); + } + } + + let sccs_ordered = &mir.sccs; + let mut rule_time_fields = vec![]; + let mut rule_time_fields_defaults = vec![]; + for i in 0..mir.sccs.len() { + for (rule_ind, _rule) in mir.sccs[i].rules.iter().enumerate() { + let name = rule_time_field_name(i, rule_ind); + rule_time_fields.push(quote! { + pub #name: std::time::Duration, + }); + rule_time_fields_defaults.push(quote! { + #name: std::time::Duration::ZERO, + }); + } + } + + let mut sccs_compiled = vec![]; + for (i, _scc) in sccs_ordered.iter().enumerate() { + let msg = format!("scc {}", i); + let scc_compiled = compile_mir_scc(mir, i); + sccs_compiled.push(quote! { + ascent::internal::comment(#msg); + { + let _scc_start_time = ::ascent::internal::Instant::now(); + #scc_compiled + _self.scc_times[#i] += _scc_start_time.elapsed(); + } + }); + } + + let update_indices_body = compile_update_indices_function_body(mir); + let relation_sizes_body = compile_relation_sizes_body(mir); + let scc_times_summary_body = compile_scc_times_summary_body(mir); + + let mut type_constraints = vec![]; + let mut field_type_names = HashSet::::new(); + let mut lat_field_type_names = HashSet::::new(); + + for relation in mir.relations_ir_relations.keys().sorted_by_key(|rel| &rel.name) { + use crate::quote::ToTokens; + for (i, field_type) in relation.field_types.iter().enumerate() { + let is_lat = relation.is_lattice && i == relation.field_types.len() - 1; + let add = if let Type::Path(path) = field_type { + let container = if is_lat { &mut lat_field_type_names } else { &mut field_type_names }; + container.insert(path.path.clone().into_token_stream().to_string()) + } else { + true + }; + if add { + let type_constraints_type = if is_lat { + quote_spanned!(field_type.span()=>LatTypeConstraints) + } else { + quote_spanned!(field_type.span()=>TypeConstraints) + }; + type_constraints.push(quote_spanned! {field_type.span()=> + let _type_constraints : ascent::internal::#type_constraints_type<#field_type>; + }); + if mir.is_parallel { + type_constraints.push(quote_spanned! {field_type.span()=> + let _par_constraints : ascent::internal::ParTypeConstraints<#field_type>; + }); + } + } + } + } + + let mut relation_initializations = vec![]; + for (rel, md) in mir.relations_metadata.iter().sorted_by_key(|(rel, _)| &rel.name) { + if let Some(ref init) = md.initialization { + let rel_name = &rel.name; + relation_initializations.push(quote! { + _self.#rel_name = #init; + }); + } + } + if !relation_initializations.is_empty() { + relation_initializations.push(quote! { + _self.update_indices_priv(); + }) + } + + let par_usings = if mir.is_parallel { + quote! { + use ascent::rayon::iter::ParallelBridge; + use ascent::rayon::iter::ParallelIterator; + use ascent::internal::CRelIndexRead; + use ascent::internal::CRelIndexReadAll; + use ascent::internal::Freezable; + } + } else { + quote! {} + }; + + let more_usings = if !mir.is_parallel { + quote! { + use ascent::internal::RelIndexWrite; + } + } else { + quote! { + use ascent::internal::CRelIndexWrite; + } + }; + + let run_usings = quote! { + use core::cmp::PartialEq; + use ascent::internal::{RelIndexRead, RelIndexReadAll, ToRelIndex0, TupleOfBorrowed}; + #more_usings + #par_usings + }; + + let generate_run_timeout = !is_ascent_run && mir.config.generate_run_partial; + let run_func = if is_ascent_run { + quote! {} + } else if generate_run_timeout { + quote! { + #[doc = "Runs the Ascent program to a fixed point."] + pub fn run(&mut self) { + self.run_timeout(::std::time::Duration::MAX); + } + } + } else { + quote! { + #[allow(unused_imports, noop_method_call, suspicious_double_ref_op)] + #[doc = "Runs the Ascent program to a fixed point."] + pub fn run(&mut self) { + macro_rules! __check_return_conditions {() => {};} + #run_usings + self.update_indices_priv(); + let _self = self; + #(#sccs_compiled)* + } + } + }; + let run_timeout_func = if !generate_run_timeout { + quote! {} + } else { + quote! { + #[allow(unused_imports, noop_method_call, suspicious_double_ref_op)] + #[doc = "Runs the Ascent program to a fixed point or until the timeout is reached. In case of a timeout returns false"] + pub fn run_timeout(&mut self, timeout: ::std::time::Duration) -> bool { + let __start_time = ::ascent::internal::Instant::now(); + macro_rules! __check_return_conditions {() => { + if timeout < ::std::time::Duration::MAX && __start_time.elapsed() >= timeout {return false;} + };} + #run_usings + self.update_indices_priv(); + let _self = self; + #(#sccs_compiled)* + true + } + } + }; + let run_code = if !is_ascent_run { + quote! {} + } else { + quote! { + macro_rules! __check_return_conditions {() => {};} + #run_usings + let _self = &mut __run_res; + #(#relation_initializations)* + #(#sccs_compiled)* + } + }; + + let relation_initializations_for_default_impl = if is_ascent_run { vec![] } else { relation_initializations }; + let summary = mir_summary(mir); + + let (ty_impl_generics, ty_ty_generics, ty_where_clause) = mir.signatures.split_ty_generics_for_impl(); + let (impl_impl_generics, impl_ty_generics, impl_where_clause) = mir.signatures.split_impl_generics_for_impl(); + + let ty_signature = &mir.signatures.declaration; + if let Some(impl_signature) = &mir.signatures.implementation { + assert_eq!(ty_signature.ident, impl_signature.ident, "The identifiers of struct and impl must match"); + } + + let ty_ty_generics_str = quote!(#ty_ty_generics).to_string(); + let impl_ty_generics_str = quote!(#impl_ty_generics).to_string(); + assert_eq!( + ty_ty_generics_str, impl_ty_generics_str, + "The generic parameters of struct ({ty_ty_generics_str}) and impl ({impl_ty_generics_str}) must match" + ); + + let vis = &ty_signature.visibility; + let struct_name = &ty_signature.ident; + let struct_attrs = &ty_signature.attrs; + let summary_fn = if is_ascent_run { + quote! { + pub fn summary(&self) -> &'static str {#summary} + } + } else { + quote! { + pub fn summary() -> &'static str {#summary} + } + }; + let rule_time_fields = if mir.config.include_rule_times { rule_time_fields } else { vec![] }; + let rule_time_fields_defaults = if mir.config.include_rule_times { rule_time_fields_defaults } else { vec![] }; + + let mut rel_codegens = vec![]; + for rel in mir.relations_ir_relations.keys() { + let macro_path = &mir.relations_metadata[rel].ds_macro_path; + let macro_input = rel_ds_macro_input(rel, mir); + rel_codegens.push(quote_spanned! { macro_path.span()=> #macro_path::rel_codegen!{#macro_input} }); + } + + let sccs_count = sccs_ordered.len(); + let res = quote! { + #(#rel_codegens)* + + #(#struct_attrs)* + #vis struct #struct_name #ty_impl_generics #ty_where_clause { + #(#relation_fields)* + scc_times: [std::time::Duration; #sccs_count], + scc_iters: [usize; #sccs_count], + #(#rule_time_fields)* + pub update_time_nanos: std::sync::atomic::AtomicU64, + pub update_indices_duration: std::time::Duration, + } + impl #impl_impl_generics #struct_name #impl_ty_generics #impl_where_clause { + #run_func + + #run_timeout_func + // TODO remove pub update_indices at some point + #[allow(noop_method_call, suspicious_double_ref_op)] + fn update_indices_priv(&mut self) { + let before = ::ascent::internal::Instant::now(); + #update_indices_body + self.update_indices_duration += before.elapsed(); + } + + #[deprecated = "Explicit call to update_indices not required anymore."] + pub fn update_indices(&mut self) { + self.update_indices_priv(); + } + fn type_constraints() { + #(#type_constraints)* + } + #summary_fn + + pub fn relation_sizes_summary(&self) -> String { + #relation_sizes_body + } + pub fn scc_times_summary(&self) -> String { + #scc_times_summary_body + } + } + impl #impl_impl_generics Default for #struct_name #impl_ty_generics #impl_where_clause { + fn default() -> Self { + let mut _self = #struct_name { + #(#field_defaults)* + scc_times: [std::time::Duration::ZERO; #sccs_count], + scc_iters: [0; #sccs_count], + #(#rule_time_fields_defaults)* + update_time_nanos: Default::default(), + update_indices_duration: std::time::Duration::default() + }; + #(#relation_initializations_for_default_impl)* + _self + } + } + }; + if !is_ascent_run { + res + } else { + quote! { + { + #res + let mut __run_res: #struct_name #ty_ty_generics = #struct_name::default(); + #[allow(unused_imports, noop_method_call, suspicious_double_ref_op)] + { + ascent::internal::comment("running..."); + #run_code + } + __run_res + } + } + } +} + +fn rel_ind_common_type(rel: &RelationIdentity, mir: &AscentMir) -> Type { + if rel.is_lattice { + parse_quote! { () } + } else { + let macro_path = &mir.relations_metadata[rel].ds_macro_path; + let macro_input = rel_ds_macro_input(rel, mir); + parse_quote_spanned! { macro_path.span()=> #macro_path::rel_ind_common!(#macro_input) } + } +} + +fn rel_index_type(rel: &IrRelation, mir: &AscentMir) -> Type { + let span = rel.relation.name.span(); + let key_type = rel.key_type(); + let value_type = rel.value_type(); + + let is_lat_full_index = rel.relation.is_lattice && &mir.lattices_full_indices[&rel.relation] == rel; + + if rel.relation.is_lattice { + let res = if !mir.is_parallel { + if is_lat_full_index { + quote_spanned! { span=>ascent::internal::RelFullIndexType<#key_type, #value_type> } + } else { + quote_spanned! { span=>ascent::internal::LatticeIndexType<#key_type, #value_type> } + } + } else { + // parallel + if is_lat_full_index { + quote_spanned! { span=>ascent::internal::CRelFullIndex<#key_type, #value_type> } + } else if rel.is_no_index() { + quote_spanned! { span=>ascent::internal::CRelNoIndex<#value_type> } + } else { + quote_spanned! { span=>ascent::internal::CRelIndex<#key_type, #value_type> } + } + }; + syn::parse2(res).unwrap() + } else { + let macro_path = &mir.relations_metadata[&rel.relation].ds_macro_path; + let span = macro_path.span(); + let macro_input = rel_ds_macro_input(&rel.relation, mir); + if rel.is_full_index() { + parse_quote_spanned! {span=> #macro_path::rel_full_ind!(#macro_input, #key_type, #value_type)} + } else { + let ind = rel_index_to_macro_input(&rel.indices); + parse_quote_spanned! {span=> #macro_path::rel_ind!(#macro_input, #ind, #key_type, #value_type)} + } + } +} + +fn rel_type(rel: &RelationIdentity, mir: &AscentMir) -> Type { + let field_types = tuple_type(&rel.field_types); + + if rel.is_lattice { + if mir.is_parallel { + parse_quote! {::ascent::boxcar::Vec<::std::sync::RwLock<#field_types>>} + } else { + parse_quote! {::std::vec::Vec<#field_types>} + } + } else { + let macro_path = &mir.relations_metadata[rel].ds_macro_path; + let macro_input = rel_ds_macro_input(rel, mir); + parse_quote_spanned! {macro_path.span()=> #macro_path::rel!(#macro_input) } + } +} + +fn rel_index_to_macro_input(ind: &[usize]) -> TokenStream { + let indices = ind.iter().cloned().map(syn::Index::from); + quote! { [#(#indices),*] } +} + +fn rel_ds_macro_input(rel: &RelationIdentity, mir: &AscentMir) -> TokenStream { + let span = rel.name.span(); + let field_types = tuple_type(&rel.field_types); + let indices = mir.relations_ir_relations[rel] + .iter() + .sorted_by_key(|r| &r.indices) + .map(|ir_rel| rel_index_to_macro_input(&ir_rel.indices)); + let args = &mir.relations_metadata[rel].ds_macro_args; + let par: Ident = if mir.is_parallel { + parse_quote_spanned! {span=> par} + } else { + parse_quote_spanned! {span=> ser} + }; + let name = Ident::new(&format!("{}_{}", mir.signatures.declaration.ident, rel.name), span); + quote! { + #name, + #field_types, + [#(#indices),*], + #par, + (#args) + } +} + +fn rule_time_field_name(scc_ind: usize, rule_ind: usize) -> Ident { + Ident::new(&format!("rule{}_{}_duration", scc_ind, rule_ind), Span::call_site()) +} + +fn compile_mir_scc(mir: &AscentMir, scc_ind: usize) -> proc_macro2::TokenStream { + let scc = &mir.sccs[scc_ind]; + let mut move_total_to_delta = vec![]; + let mut shift_delta_to_total_new_to_delta = vec![]; + let mut move_total_to_field = vec![]; + let mut freeze_code = vec![]; + let mut unfreeze_code = vec![]; + + let _self = quote! { _self }; + + use std::iter::once; + let sorted_dynamic_relations = scc.dynamic_relations.iter().sorted_by_cached_key(|(rel, _)| rel.name.clone()); + for rel in sorted_dynamic_relations.flat_map(|(rel, indices)| { + once(Either::Left(rel)).chain(indices.iter().sorted_by_cached_key(|rel| rel.ir_name()).map(Either::Right)) + }) { + let (ir_name, ty) = match rel { + Either::Left(rel) => (rel_ind_common_var_name(rel), rel_ind_common_type(rel, mir)), + Either::Right(rel_ind) => (rel_ind.ir_name(), rel_index_type(rel_ind, mir)), + }; + let delta_var_name = ir_relation_version_var_name(&ir_name, MirRelationVersion::Delta); + let total_var_name = ir_relation_version_var_name(&ir_name, MirRelationVersion::Total); + let new_var_name = ir_relation_version_var_name(&ir_name, MirRelationVersion::New); + let total_field = &ir_name; + move_total_to_delta.push(quote_spanned! {ir_name.span()=> + let mut #delta_var_name: #ty = ::std::mem::take(&mut #_self.#total_field); + let mut #total_var_name : #ty = Default::default(); + let mut #new_var_name : #ty = Default::default(); + }); + + match rel { + Either::Left(_rel_ind_common) => { + shift_delta_to_total_new_to_delta.push(quote_spanned!{ir_name.span()=> + ::ascent::internal::RelIndexMerge::merge_delta_to_total_new_to_delta(&mut #new_var_name, &mut #delta_var_name, &mut #total_var_name); + }); + move_total_to_delta.push(quote_spanned! {ir_name.span()=> + ::ascent::internal::RelIndexMerge::init(&mut #new_var_name, &mut #delta_var_name, &mut #total_var_name); + }); + }, + Either::Right(ir_rel) => { + let delta_expr = expr_for_rel_write(&MirRelation::from(ir_rel.clone(), Delta), mir); + let total_expr = expr_for_rel_write(&MirRelation::from(ir_rel.clone(), Total), mir); + let new_expr = expr_for_rel_write(&MirRelation::from(ir_rel.clone(), New), mir); + + shift_delta_to_total_new_to_delta.push(quote_spanned!{ir_name.span()=> + ::ascent::internal::RelIndexMerge::merge_delta_to_total_new_to_delta(&mut #new_expr, &mut #delta_expr, &mut #total_expr); + }); + move_total_to_delta.push(quote_spanned! {ir_name.span()=> + ::ascent::internal::RelIndexMerge::init(&mut #new_expr, &mut #delta_expr, &mut #total_expr); + }); + }, + } + + move_total_to_field.push(quote_spanned! {ir_name.span()=> + #_self.#total_field = #total_var_name; + }); + + if mir.is_parallel { + freeze_code.push(quote_spanned! {ir_name.span()=> + #total_var_name.freeze(); + #delta_var_name.freeze(); + }); + + unfreeze_code.push(quote_spanned! {ir_name.span()=> + #total_var_name.unfreeze(); + #delta_var_name.unfreeze(); + }); + } + } + let sorted_body_only_relations = scc.body_only_relations.iter().sorted_by_cached_key(|(rel, _)| rel.name.clone()); + for rel in sorted_body_only_relations.flat_map(|(rel, indices)| { + let sorted_indices = indices.iter().sorted_by_cached_key(|rel| rel.ir_name()); + once(Either::Left(rel)).chain(sorted_indices.map(Either::Right)) + }) { + let (ir_name, ty) = match rel { + Either::Left(rel) => (rel_ind_common_var_name(rel), rel_ind_common_type(rel, mir)), + Either::Right(rel_ind) => (rel_ind.ir_name(), rel_index_type(rel_ind, mir)), + }; + let total_var_name = ir_relation_version_var_name(&ir_name, MirRelationVersion::Total); + let total_field = &ir_name; + + if mir.is_parallel { + move_total_to_delta.push(quote_spanned! {ir_name.span()=> + #_self.#total_field.freeze(); + }); + } + + move_total_to_delta.push(quote_spanned! {ir_name.span()=> + let #total_var_name: #ty = std::mem::take(&mut #_self.#total_field); + }); + + move_total_to_field.push(quote_spanned! {ir_name.span()=> + #_self.#total_field = #total_var_name; + }); + } + + let rule_parallelism = mir.config.inter_rule_parallelism && mir.is_parallel; + + let mut evaluate_rules = vec![]; + + for (i, rule) in scc.rules.iter().enumerate() { + let msg = mir_rule_summary(rule); + let rule_compiled = compile_mir_rule(rule, scc, mir); + let rule_time_field = rule_time_field_name(scc_ind, i); + let (before_rule_var, update_rule_time_field) = if mir.config.include_rule_times { + ( + quote! {let before_rule = ::ascent::internal::Instant::now();}, + quote! {_self.#rule_time_field += before_rule.elapsed();}, + ) + } else { + (quote! {}, quote! {}) + }; + evaluate_rules.push(if rule_parallelism { + quote! { + ascent::internal::comment(#msg); + __scope.spawn(|_| { + #before_rule_var + #rule_compiled + #update_rule_time_field + }); + } + } else { + quote! { + #before_rule_var + ascent::internal::comment(#msg); + { + #rule_compiled + } + #update_rule_time_field + } + }); + } + + let evaluate_rules = if rule_parallelism { + quote! { + ascent::rayon::scope(|__scope| { + #(#evaluate_rules)* + }); + } + } else { + quote! { #(#evaluate_rules)* } + }; + + let changed_var_def_code = if !mir.is_parallel { + quote! { let mut __changed = false; } + } else { + quote! { let __changed = std::sync::atomic::AtomicBool::new(false); } + }; + let check_changed_code = if !mir.is_parallel { + quote! {__changed} + } else { + quote! {__changed.load(std::sync::atomic::Ordering::Relaxed)} + }; + + let evaluate_rules_loop = if scc.is_looping { + quote! { + #[allow(unused_assignments, unused_variables)] + loop { + #changed_var_def_code + + #(#freeze_code)* + // evaluate rules + #evaluate_rules + + #(#unfreeze_code)* + #(#shift_delta_to_total_new_to_delta)* + _self.scc_iters[#scc_ind] += 1; + if !#check_changed_code {break;} + __check_return_conditions!(); + } + + } + } else { + quote! { + #[allow(unused_assignments, unused_variables)] + { + #changed_var_def_code + #(#freeze_code)* + + #evaluate_rules + + #(#unfreeze_code)* + + #(#shift_delta_to_total_new_to_delta)* + #(#shift_delta_to_total_new_to_delta)* + _self.scc_iters[#scc_ind] += 1; + __check_return_conditions!(); + } + } + }; + quote! { + // define variables for delta and new versions of dynamic relations in the scc + // move total versions of dynamic indices to delta + #(#move_total_to_delta)* + + #evaluate_rules_loop + + #(#move_total_to_field)* + } +} + +fn compile_relation_sizes_body(mir: &AscentMir) -> proc_macro2::TokenStream { + let mut write_sizes = vec![]; + for r in mir.relations_ir_relations.keys().sorted_by_key(|r| &r.name) { + let rel_name = &r.name; + let rel_name_str = r.name.to_string(); + write_sizes.push(quote! { + writeln!(&mut res, "{} size: {}", #rel_name_str, self.#rel_name.len()).unwrap(); + }); + } + quote! { + use std::fmt::Write; + let mut res = String::new(); + #(#write_sizes)* + res + } +} + +fn compile_scc_times_summary_body(mir: &AscentMir) -> proc_macro2::TokenStream { + let mut res = vec![]; + for i in 0..mir.sccs.len() { + let i_str = format!("{}", i); + res.push(quote!{ + writeln!(&mut res, "scc {}: iterations: {}, time: {:?}", #i_str, self.scc_iters[#i], self.scc_times[#i]).unwrap(); + }); + if mir.config.include_rule_times { + let mut sum_of_rule_times = quote! { std::time::Duration::ZERO }; + for (rule_ind, _rule) in mir.sccs[i].rules.iter().enumerate() { + let rule_time_field = rule_time_field_name(i, rule_ind); + sum_of_rule_times = quote! { #sum_of_rule_times + self.#rule_time_field}; + } + res.push(quote! { + let sum_of_rule_times = #sum_of_rule_times; + writeln!(&mut res, " sum of rule times: {:?}", sum_of_rule_times).unwrap(); + }); + for (rule_ind, rule) in mir.sccs[i].rules.iter().enumerate() { + let rule_time_field = rule_time_field_name(i, rule_ind); + let rule_summary = mir_rule_summary(rule); + res.push(quote! { + writeln!(&mut res, " rule {}\n time: {:?}", #rule_summary, self.#rule_time_field).unwrap(); + }); + } + res.push(quote! { + writeln!(&mut res, "-----------------------------------------").unwrap(); + }); + } + } + let update_indices_time_code = quote! { + writeln!(&mut res, "update_indices time: {:?}", self.update_indices_duration).unwrap(); + }; + quote! { + use std::fmt::Write; + let mut res = String::new(); + #update_indices_time_code + #(#res)* + res + } +} + +fn compile_update_indices_function_body(mir: &AscentMir) -> proc_macro2::TokenStream { + let par = mir.is_parallel; + let mut res = vec![]; + if par { + res.push(quote! { use ascent::rayon::iter::{IntoParallelIterator, ParallelIterator}; }) + } + let (rel_index_write_trait, index_insert_fn) = if !par { + (quote! {ascent::internal::RelIndexWrite}, quote! {index_insert}) + } else { + (quote! {ascent::internal::CRelIndexWrite}, quote! {index_insert}) + }; + let sorted_relations_ir_relations = mir.relations_ir_relations.iter().sorted_by_key(|(rel, _)| &rel.name); + for (r, indices_set) in sorted_relations_ir_relations { + let _ref = if !par { + quote! {&mut} + } else { + quote! {&} + } + .with_span(r.name.span()); + let ind_common = rel_ind_common_var_name(r); + let rel_index_write_trait = rel_index_write_trait.clone().with_span(r.name.span()); + let _self = quote_spanned! {r.name.span().resolved_at(Span::call_site())=> self }; + let to_rel_index_fn = if !par { + quote! {to_rel_index_write} + } else { + quote! {to_c_rel_index_write} + }; + let to_rel_index = if r.is_lattice { + quote! {} + } else { + quote! {.#to_rel_index_fn(#_ref #_self.#ind_common) } + }; + + let mut update_indices = vec![]; + for ind in indices_set.iter().sorted_by_cached_key(|rel| rel.ir_name()) { + let ind_name = &ind.ir_name(); + let selection_tuple: Vec = ind + .indices + .iter() + .map(|&i| { + let ind = syn::Index::from(i); + parse_quote_spanned! {r.name.span()=> tuple.#ind.clone()} + }) + .collect_vec(); + let selection_tuple = tuple_spanned(&selection_tuple, r.name.span()); + let entry_val = index_get_entry_val_for_insert( + ind, + &parse_quote_spanned! {r.name.span()=> tuple}, + &parse_quote_spanned! {r.name.span()=> _i}, + ); + let _pre_ref = if r.is_lattice { quote!() } else { _ref.clone() }; + update_indices.push(quote_spanned! {r.name.span()=> + let selection_tuple = #selection_tuple; + let rel_ind = #_ref #_self.#ind_name; + #rel_index_write_trait::#index_insert_fn(#_pre_ref rel_ind #to_rel_index, selection_tuple, #entry_val); + }); + } + let rel_name = &r.name; + let maybe_lock = if r.is_lattice && mir.is_parallel { + quote_spanned! {r.name.span()=> let tuple = tuple.read().unwrap(); } + } else { + quote! {} + }; + if !par { + res.push(quote_spanned! {r.name.span()=> + for (_i, tuple) in #_self.#rel_name.iter().enumerate() { + #maybe_lock + #(#update_indices)* + } + }); + } else { + res.push(quote_spanned! {r.name.span()=> + (0..#_self.#rel_name.len()).into_par_iter().for_each(|_i| { + let tuple = &#_self.#rel_name[_i]; + #maybe_lock + #(#update_indices)* + }); + }); + } + } + + quote! { + use ascent::internal::ToRelIndex0; + use #rel_index_write_trait; + #(#res)* + } +} + +fn compile_cond_clause(cond: &CondClause, body: proc_macro2::TokenStream) -> proc_macro2::TokenStream { + match cond { + CondClause::IfLet(if_let_clause) => { + let pat = &if_let_clause.pattern; + let expr = &if_let_clause.exp; + quote_spanned! {if_let_clause.if_keyword.span()=> + if let #pat = #expr { + #body + } + } + }, + CondClause::Let(let_clause) => { + let pat = &let_clause.pattern; + let expr = &let_clause.exp; + quote_spanned! {let_clause.let_keyword.span()=> + let #pat = #expr; + #body + } + }, + CondClause::If(if_clause) => { + let cond = &if_clause.cond; + quote_spanned! {if_clause.if_keyword.span()=> + if #cond { + #body + } + } + }, + } +} + +fn compile_mir_rule(rule: &MirRule, scc: &MirScc, mir: &AscentMir) -> proc_macro2::TokenStream { + let head_update_code = head_update_code(rule, scc, mir); + + const MAX_PAR_ITERS: usize = 2; + + // do parallel iteration up to this clause index (exclusive) + let par_iter_to_ind = if mir.is_parallel { + use itertools::FoldWhile::*; + rule + .body_items + .iter() + .fold_while((0, 0), |(count, i), bi| { + let new_count = count + matches!(bi, MirBodyItem::Clause(..)) as usize; + if new_count > MAX_PAR_ITERS { Done((new_count, i)) } else { Continue((new_count, i + 1)) } + }) + .into_inner() + .1 + } else { + 0 + }; + compile_mir_rule_inner(rule, scc, mir, par_iter_to_ind, head_update_code, 0) +} + +fn compile_mir_rule_inner( + rule: &MirRule, _scc: &MirScc, mir: &AscentMir, par_iter_to_ind: usize, head_update_code: proc_macro2::TokenStream, + clause_ind: usize, +) -> proc_macro2::TokenStream { + if Some(clause_ind) == rule.simple_join_start_index && rule.reorderable { + let mut rule_cp1 = rule.clone(); + let mut rule_cp2 = rule.clone(); + rule_cp1.reorderable = false; + rule_cp2.reorderable = false; + rule_cp2.body_items.swap(clause_ind, clause_ind + 1); + let rule_cp1_compiled = + compile_mir_rule_inner(&rule_cp1, _scc, mir, par_iter_to_ind, head_update_code.clone(), clause_ind); + let rule_cp2_compiled = + compile_mir_rule_inner(&rule_cp2, _scc, mir, par_iter_to_ind, head_update_code, clause_ind); + + if let [MirBodyItem::Clause(bcl1), MirBodyItem::Clause(bcl2)] = &rule.body_items[clause_ind..clause_ind + 2] { + let rel1_var_name = expr_for_rel(&bcl1.rel, mir); + let rel2_var_name = expr_for_rel(&bcl2.rel, mir); + + return quote_spanned! {bcl1.rel_args_span=> + if #rel1_var_name.len() <= #rel2_var_name.len() { + #rule_cp1_compiled + } else { + #rule_cp2_compiled + } + }; + } else { + panic!("unexpected body items in reorderable rule") + } + } + if clause_ind < rule.body_items.len() { + let bitem = &rule.body_items[clause_ind]; + let doing_simple_join = rule.simple_join_start_index == Some(clause_ind); + let next_loop = if doing_simple_join { + compile_mir_rule_inner(rule, _scc, mir, par_iter_to_ind, head_update_code, clause_ind + 2) + } else { + compile_mir_rule_inner(rule, _scc, mir, par_iter_to_ind, head_update_code, clause_ind + 1) + }; + + match bitem { + MirBodyItem::Clause(bclause) => { + let (clause_ind, bclause) = if doing_simple_join { + (clause_ind + 1, rule.body_items[clause_ind + 1].unwrap_clause()) + } else { + (clause_ind, bclause) + }; + + let bclause_rel_name = &bclause.rel.relation.name; + let selected_args = &bclause.selected_args(); + let pre_clause_vars = + rule.body_items.iter().take(clause_ind).flat_map(MirBodyItem::bound_vars).collect::>(); + + let clause_vars = bclause.vars(); + let common_vars = clause_vars.iter().filter(|(_i, v)| pre_clause_vars.contains(v)).collect::>(); + let common_vars_no_indices = common_vars.iter().map(|(_i, v)| v.clone()).collect::>(); + + let cloning_needed = true; + + let matched_val_ident = Ident::new("__val", bclause.rel_args_span); + let new_vars_assignments = clause_var_assignments( + &bclause.rel, + clause_vars.iter().filter(|(_i, var)| !common_vars_no_indices.contains(var)).cloned(), + &matched_val_ident, + &parse_quote! {_self.#bclause_rel_name}, + cloning_needed, + mir, + ); + + let selected_args_cloned = selected_args.iter().map(exp_cloned).collect_vec(); + let selected_args_tuple = tuple_spanned(&selected_args_cloned, bclause.args_span); + let rel_version_exp = expr_for_rel(&bclause.rel, mir); + + let mut conds_then_next_loop = next_loop; + for cond in bclause.cond_clauses.iter().rev() { + conds_then_next_loop = compile_cond_clause(cond, conds_then_next_loop); + } + + let span = bclause.rel_args_span; + + let matching_dot_iter = quote_spanned! {bclause.rel_args_span=> __matching}; + + let (index_get, iter_all) = if clause_ind < par_iter_to_ind { + (quote_spanned! {span=> c_index_get}, quote_spanned! {span=> c_iter_all}) + } else { + (quote_spanned! {span=> index_get}, quote_spanned! {span=> iter_all}) + }; + + // The special case where the first clause has indices, but there are no expressions + // in the args of the first clause + if doing_simple_join { + let cl1 = rule.body_items[rule.simple_join_start_index.unwrap()].unwrap_clause(); + let cl2 = bclause; + let cl1_var_name = expr_for_rel(&cl1.rel, mir); + let cl2_var_name = expr_for_rel(&cl2.rel, mir); + let cl1_vars = cl1.vars(); + + let cl1_rel_name = &cl1.rel.relation.name; + + let mut cl1_join_vars_assignments = vec![]; + for (tuple_ind, &i) in cl1.rel.indices.iter().enumerate() { + let var = expr_to_ident(&cl1.args[i]).unwrap(); + let tuple_ind = syn::Index { index: tuple_ind as u32, span: var.span() }; + cl1_join_vars_assignments + .push(quote_spanned! {var.span()=> let #var = __cl1_joined_columns.#tuple_ind;}); + } + + let cl1_matched_val_ident = syn::Ident::new("cl1_val", cl1.rel_args_span); + let cl1_vars_assignments = clause_var_assignments( + &cl1.rel, + cl1_vars.iter().filter(|(i, _var)| !cl1.rel.indices.contains(i)).cloned(), + &cl1_matched_val_ident, + &parse_quote! {_self.#cl1_rel_name}, + cloning_needed, + mir, + ); + let cl1_vars_assignments = vec![cl1_vars_assignments]; + + let joined_args_for_cl2_cloned = cl2.selected_args().iter().map(exp_cloned).collect_vec(); + let joined_args_tuple_for_cl2 = tuple_spanned(&joined_args_for_cl2_cloned, cl2.args_span); + + let cl1_tuple_indices_iter = quote_spanned!(cl1.rel_args_span=> __cl1_tuple_indices); + + let mut cl1_conds_then_rest = quote_spanned! {bclause.rel_args_span=> + #matching_dot_iter.clone().for_each(|__val| { + // TODO we may be doing excessive cloning + #new_vars_assignments + #conds_then_next_loop + }); + }; + for cond in cl1.cond_clauses.iter().rev() { + cl1_conds_then_rest = compile_cond_clause(cond, cl1_conds_then_rest); + } + quote_spanned! {cl1.rel_args_span=> + #cl1_var_name.#iter_all().for_each(|(__cl1_joined_columns, __cl1_tuple_indices)| { + let __cl1_joined_columns = __cl1_joined_columns.tuple_of_borrowed(); + #(#cl1_join_vars_assignments)* + if let Some(__matching) = #cl2_var_name.#index_get(&#joined_args_tuple_for_cl2) { + #cl1_tuple_indices_iter.for_each(|cl1_val| { + #(#cl1_vars_assignments)* + #cl1_conds_then_rest + }); + } + }); + } + } else { + quote_spanned! {bclause.rel_args_span=> + if let Some(__matching) = #rel_version_exp.#index_get( &#selected_args_tuple) { + #matching_dot_iter.for_each(|__val| { + // TODO we may be doing excessive cloning + #new_vars_assignments + #conds_then_next_loop + }); + } + } + } + }, + MirBodyItem::Generator(gen) => { + let pat = &gen.pattern; + let expr = &gen.expr; + quote_spanned! {gen.for_keyword.span()=> + for #pat in #expr { + #next_loop + } + } + }, + MirBodyItem::Cond(cond) => compile_cond_clause(cond, next_loop), + MirBodyItem::Agg(agg) => { + let pat = &agg.pat; + let rel_name = &agg.rel.relation.name; + let mir_relation = MirRelation::from(agg.rel.clone(), Total); + // let rel_version_var_name = mir_relation.var_name(); + let rel_expr = expr_for_rel(&mir_relation, mir); + let selected_args = mir_relation.indices.iter().map(|&i| &agg.rel_args[i]); + let selected_args_cloned = selected_args.map(exp_cloned).collect_vec(); + let selected_args_tuple = tuple_spanned(&selected_args_cloned, agg.span); + let agg_args_tuple_indices = agg.bound_args.iter().map(|arg| { + ( + agg.rel_args.iter().find_position(|rel_arg| expr_to_ident(rel_arg) == Some(arg.clone())).unwrap().0, + arg.clone(), + ) + }); + + let agg_args_tuple = + tuple_spanned(&agg.bound_args.iter().map(|v| parse_quote! {#v}).collect_vec(), agg.span); + + let vars_assignments = clause_var_assignments( + &MirRelation::from(agg.rel.clone(), MirRelationVersion::Total), + agg_args_tuple_indices, + &parse_quote_spanned! {agg.span=> __val}, + &parse_quote! {_self.#rel_name}, + false, + mir, + ); + + let agg_func = &agg.aggregator; + let _self = quote! { _self }; + quote_spanned! {agg.span=> + let __aggregated_rel = #rel_expr; + let __matching = __aggregated_rel.index_get( &#selected_args_tuple); + let __agg_args = __matching.into_iter().flatten().map(|__val| { + #vars_assignments + #agg_args_tuple + }); + for #pat in #agg_func(__agg_args) { + #next_loop + } + + } + }, + } + } else { + quote! { + // let before_update = ::ascent::internal::Instant::now(); + #head_update_code + // let update_took = before_update.elapsed(); + // _self.update_time_nanos.fetch_add(update_took.as_nanos() as u64, std::sync::atomic::Ordering::Relaxed); + } + } +} + +fn head_update_code(rule: &MirRule, scc: &MirScc, mir: &AscentMir) -> proc_macro2::TokenStream { + let mut add_rows = vec![]; + + let set_changed_true_code = if !mir.is_parallel { + quote! { __changed = true; } + } else { + quote! { __changed.store(true, std::sync::atomic::Ordering::Relaxed);} + }; + + for hcl in rule.head_clause.iter() { + let head_rel_name = Ident::new(&hcl.rel.name.to_string(), hcl.span); + let hcl_args_converted = hcl.args.iter().cloned().map(convert_head_arg).collect_vec(); + let new_row_tuple = tuple_spanned(&hcl_args_converted, hcl.args_span); + + let head_relation = &hcl.rel; + let row_type = tuple_type(&head_relation.field_types); + + let mut update_indices = vec![]; + let rel_indices = scc.dynamic_relations.get(head_relation); + let (rel_index_write_trait, index_insert_fn) = if !mir.is_parallel { + (quote! { ::ascent::internal::RelIndexWrite }, quote! {index_insert}) + } else { + (quote! { ::ascent::internal::CRelIndexWrite }, quote! {index_insert}) + }; + let (rel_index_write_trait, index_insert_fn) = + (rel_index_write_trait.with_span(hcl.span), index_insert_fn.with_span(hcl.span)); + let new_ref = if !mir.is_parallel { + quote! {&mut} + } else { + quote! {&} + }; + let mut used_fields = HashSet::new(); + if let Some(rel_indices) = rel_indices { + for rel_ind in rel_indices.iter().sorted_by_cached_key(|rel| rel.ir_name()) { + if rel_ind.is_full_index() { + continue + }; + let var_name = if !mir.is_parallel { + expr_for_rel_write(&MirRelation::from(rel_ind.clone(), New), mir) + } else { + expr_for_c_rel_write(&MirRelation::from(rel_ind.clone(), New), mir) + }; + let args_tuple: Vec = rel_ind + .indices + .iter() + .map(|&i| { + let i_ind = syn::Index::from(i); + syn::parse2(quote_spanned! {hcl.span=> __new_row.#i_ind.clone()}).unwrap() + }) + .collect(); + used_fields.extend(rel_ind.indices.iter().cloned()); + if let IndexValType::Direct(direct) = &rel_ind.val_type { + used_fields.extend(direct.iter().cloned()); + } + let args_tuple = tuple(&args_tuple); + let entry_val = index_get_entry_val_for_insert( + rel_ind, + &parse_quote_spanned! {hcl.span=> __new_row}, + &parse_quote_spanned! {hcl.span=> __new_row_ind}, + ); + update_indices.push(quote_spanned! {hcl.span=> + #rel_index_write_trait::#index_insert_fn(#new_ref #var_name, #args_tuple, #entry_val); + }); + } + } + + let head_rel_full_index = &mir.relations_full_indices[head_relation]; + + let expr_for_rel_maybe_mut = if mir.is_parallel { expr_for_c_rel_write } else { expr_for_rel_write }; + let head_rel_full_index_expr_new = + expr_for_rel_maybe_mut(&MirRelation::from(head_rel_full_index.clone(), New), mir); + let head_rel_full_index_expr_delta = expr_for_rel(&MirRelation::from(head_rel_full_index.clone(), Delta), mir); + let head_rel_full_index_expr_total = expr_for_rel(&MirRelation::from(head_rel_full_index.clone(), Total), mir); + + let rel_full_index_write_trait = if !mir.is_parallel { + quote! { ::ascent::internal::RelFullIndexWrite } + } else { + quote! { ::ascent::internal::CRelFullIndexWrite } + } + .with_span(hcl.span); + + let new_row_to_be_pushed = (0..hcl.rel.field_types.len()) + .map(|i| { + let ind = syn::Index::from(i); + let clone = if used_fields.contains(&i) { + quote! {.clone()} + } else { + quote! {} + }; + parse_quote_spanned! {hcl.span=> __new_row.#ind #clone } + }) + .collect_vec(); + let new_row_to_be_pushed = tuple_spanned(&new_row_to_be_pushed, hcl.span); + + let push_code = if !mir.is_parallel { + quote! { + let __new_row_ind = _self.#head_rel_name.len(); + _self.#head_rel_name.push(#new_row_to_be_pushed); + } + } else { + quote! { + let __new_row_ind = _self.#head_rel_name.push(#new_row_to_be_pushed); + } + }; + + if !hcl.rel.is_lattice { + let add_row = quote_spanned! {hcl.span=> + let __new_row: #row_type = #new_row_tuple; + + if !::ascent::internal::RelFullIndexRead::contains_key(&#head_rel_full_index_expr_total, &__new_row) && + !::ascent::internal::RelFullIndexRead::contains_key(&#head_rel_full_index_expr_delta, &__new_row) { + if #rel_full_index_write_trait::insert_if_not_present(#new_ref #head_rel_full_index_expr_new, + &__new_row, ()) + { + #push_code + #(#update_indices)* + #set_changed_true_code + } + } + }; + add_rows.push(add_row); + } else { + // rel.is_lattice: + let lattice_insertion_mutex = lattice_insertion_mutex_var_name(head_relation); + let head_lat_full_index = &mir.lattices_full_indices[head_relation]; + let head_lat_full_index_var_name_new = ir_relation_version_var_name(&head_lat_full_index.ir_name(), New); + let head_lat_full_index_var_name_delta = ir_relation_version_var_name(&head_lat_full_index.ir_name(), Delta); + let head_lat_full_index_var_name_full = ir_relation_version_var_name(&head_lat_full_index.ir_name(), Total); + let tuple_lat_index = syn::Index::from(hcl.rel.field_types.len() - 1); + let lattice_key_args: Vec = (0..hcl.args.len() - 1) + .map(|i| { + let i_ind = syn::Index::from(i); + syn::parse2(quote_spanned! {hcl.span=> __new_row.#i_ind}).unwrap() + }) + .map(|e| exp_cloned(&e)) + .collect_vec(); + let lattice_key_tuple = tuple(&lattice_key_args); + + let _self = quote! { _self }; + let add_row = if !mir.is_parallel { + quote_spanned! {hcl.span=> + let __new_row: #row_type = #new_row_tuple; + let __lattice_key = #lattice_key_tuple; + if let Some(mut __existing_ind) = #head_lat_full_index_var_name_new.index_get(&__lattice_key) + .or_else(|| #head_lat_full_index_var_name_delta.index_get(&__lattice_key)) + .or_else(|| #head_lat_full_index_var_name_full.index_get(&__lattice_key)) + { + let __existing_ind = *__existing_ind.next().unwrap(); + // TODO possible excessive cloning here? + let __lat_changed = ::ascent::Lattice::join_mut(&mut #_self.#head_rel_name[__existing_ind].#tuple_lat_index, __new_row.#tuple_lat_index.clone()); + if __lat_changed { + let __new_row_ind = __existing_ind; + #(#update_indices)* + #set_changed_true_code + } + } else { + let __new_row_ind = #_self.#head_rel_name.len(); + #(#update_indices)* + #_self.#head_rel_name.push(#new_row_to_be_pushed); + #set_changed_true_code + } + } + } else { + quote_spanned! {hcl.span=> // mir.is_parallel: + let __new_row: #row_type = #new_row_tuple; + let __lattice_key = #lattice_key_tuple; + let __existing_ind_in_new = #head_lat_full_index_var_name_new.get_cloned(&__lattice_key); + let __new_has_ind = __existing_ind_in_new.is_some(); + if let Some(__existing_ind) = __existing_ind_in_new + .or_else(|| #head_lat_full_index_var_name_delta.get_cloned(&__lattice_key)) + .or_else(|| #head_lat_full_index_var_name_full.get_cloned(&__lattice_key)) + { + let __lat_changed = ::ascent::Lattice::join_mut(&mut #_self.#head_rel_name[__existing_ind].write().unwrap().#tuple_lat_index, + __new_row.#tuple_lat_index.clone()); + if __lat_changed && !__new_has_ind { + let __new_row_ind = __existing_ind; + #(#update_indices)* + #set_changed_true_code + } + } else { + let __hash = #head_lat_full_index_var_name_new.hash_usize(&__lattice_key); + let __lock = #_self.#lattice_insertion_mutex.get(__hash % #_self.#lattice_insertion_mutex.len()).expect("lattice_insertion_mutex index out of bounds").lock().unwrap(); + if let Some(__existing_ind) = #head_lat_full_index_var_name_new.get_cloned(&__lattice_key) { + ::ascent::Lattice::join_mut(&mut #_self.#head_rel_name[__existing_ind].write().unwrap().#tuple_lat_index, + __new_row.#tuple_lat_index.clone()); + } else { + let __new_row_ind = #_self.#head_rel_name.push(::std::sync::RwLock::new(#new_row_to_be_pushed)); + #(#update_indices)* + #set_changed_true_code + } + } + } + }; + add_rows.push(add_row); + } + } + quote! {#(#add_rows)*} +} + +fn lattice_insertion_mutex_var_name(head_relation: &RelationIdentity) -> Ident { + Ident::new(&format!("__{}_mutex", head_relation.name), head_relation.name.span()) +} + +fn rel_ind_common_var_name(relation: &RelationIdentity) -> Ident { + Ident::new(&format!("__{}_ind_common", relation.name), relation.name.span()) +} + +fn convert_head_arg(arg: Expr) -> Expr { + if let Some(var) = expr_to_ident(&arg) { + parse2(quote_spanned! {arg.span()=> ascent::internal::Convert::convert(#var)}).unwrap() + } else { + arg + } +} + +fn expr_for_rel(rel: &MirRelation, mir: &AscentMir) -> proc_macro2::TokenStream { + fn expr_for_rel_inner( + ir_name: &Ident, version: MirRelationVersion, _mir: &AscentMir, mir_rel: &MirRelation, + ) -> (TokenStream, bool) { + let var = ir_relation_version_var_name(ir_name, version); + if mir_rel.relation.is_lattice { + (quote! { & #var }, true) + } else { + let rel_ind_common = ir_relation_version_var_name(&rel_ind_common_var_name(&mir_rel.relation), version); + (quote! { #var.to_rel_index(& #rel_ind_common) }, false) + } + } + + if rel.version == MirRelationVersion::TotalDelta { + let total_expr = expr_for_rel_inner(&rel.ir_name, MirRelationVersion::Total, mir, rel).0; + let delta_expr = expr_for_rel_inner(&rel.ir_name, MirRelationVersion::Delta, mir, rel).0; + quote! { + ascent::internal::RelIndexCombined::new(& #total_expr, & #delta_expr) + } + } else { + let (res, borrowed) = expr_for_rel_inner(&rel.ir_name, rel.version, mir, rel); + if !borrowed { + res + } else { + quote! {(#res)} + } + } +} + +fn expr_for_rel_write(mir_rel: &MirRelation, _mir: &AscentMir) -> proc_macro2::TokenStream { + let var = mir_rel.var_name(); + if mir_rel.relation.is_lattice { + quote! { #var } + } else { + let rel_ind_common = ir_relation_version_var_name(&rel_ind_common_var_name(&mir_rel.relation), mir_rel.version); + quote! { #var.to_rel_index_write(&mut #rel_ind_common) } + } +} + +fn expr_for_c_rel_write(mir_rel: &MirRelation, _mir: &AscentMir) -> proc_macro2::TokenStream { + let var = mir_rel.var_name(); + if mir_rel.relation.is_lattice { + quote! { #var } + } else { + let rel_ind_common = ir_relation_version_var_name(&rel_ind_common_var_name(&mir_rel.relation), mir_rel.version); + quote! { #var.to_c_rel_index_write(&#rel_ind_common) } + } +} + +fn clause_var_assignments( + rel: &MirRelation, vars: impl Iterator, val_ident: &Ident, relation_expr: &Expr, + cloning_needed: bool, mir: &AscentMir, +) -> proc_macro2::TokenStream { + let mut assignments = vec![]; + + let mut any_vars = false; + for (ind_in_tuple, var) in vars { + let var_type_ascription = { + let ty = &rel.relation.field_types[ind_in_tuple]; + quote! { : & #ty} + }; + any_vars = true; + match &rel.val_type { + IndexValType::Reference => { + let ind = syn::Index::from(ind_in_tuple); + assignments.push(quote! { + let #var #var_type_ascription = &__row.#ind; + }) + }, + IndexValType::Direct(inds) => { + let ind = inds.iter().enumerate().find(|(_i, ind)| **ind == ind_in_tuple).unwrap().0; + let ind = syn::Index::from(ind); + + assignments.push(quote! { + let #var #var_type_ascription = #val_ident.#ind; + }) + }, + } + } + + if any_vars { + match &rel.val_type { + IndexValType::Reference => { + let maybe_lock = if rel.relation.is_lattice && mir.is_parallel { + quote! {.read().unwrap()} + } else { + quote! {} + }; + let maybe_clone = if cloning_needed { + quote! {.clone()} + } else { + quote! {} + }; + assignments.insert(0, quote! { + let __row = &#relation_expr[*#val_ident]#maybe_lock #maybe_clone; + }); + }, + IndexValType::Direct(_) => { + assignments.insert(0, quote! { + let #val_ident = #val_ident.tuple_of_borrowed(); + }); + }, + } + } + + quote! { + #(#assignments)* + } +} + +fn index_get_entry_val_for_insert(rel_ind: &IrRelation, tuple_expr: &Expr, ind_expr: &Expr) -> Expr { + match &rel_ind.val_type { + IndexValType::Reference => ind_expr.clone(), + IndexValType::Direct(inds) => { + let val = inds + .iter() + .map(|&ind| { + let ind = syn::Index::from(ind); + parse_quote! { + #tuple_expr.#ind.clone() + } + }) + .collect_vec(); + tuple(&val) + }, + } +} diff --git a/ascent_macro/src/ascent_hir.rs b/ascent_macro/src/ascent_hir.rs index 74dea24..b2c095d 100644 --- a/ascent_macro/src/ascent_hir.rs +++ b/ascent_macro/src/ascent_hir.rs @@ -1,489 +1,524 @@ -#![deny(warnings)] -use std::{collections::{HashMap, HashSet}, rc::Rc}; - -use itertools::Itertools; -use proc_macro2::{Ident, Span, TokenStream}; -use syn::{Attribute, Error, Expr, Pat, Type, parse2, spanned::Spanned, parse_quote, Path}; - -use crate::{AscentProgram, ascent_syntax::{RelationNode, DsAttributeContents, Signatures}, utils::{expr_to_ident, is_wild_card, tuple_type}, syn_utils::{expr_get_vars, pattern_get_vars}}; -use crate::ascent_syntax::{BodyClauseArg, BodyItemNode, CondClause, GeneratorNode, RelationIdentity, RuleNode}; - -#[derive(Clone)] -pub(crate) struct AscentConfig { - #[allow(dead_code)] - pub attrs: Vec, - pub include_rule_times: bool, - pub generate_run_partial: bool, - pub inter_rule_parallelism: bool, - pub default_ds: DsAttributeContents, -} - -impl AscentConfig { - const MEASURE_RULE_TIMES_ATTR: &'static str = "measure_rule_times"; - const GENERATE_RUN_TIMEOUT_ATTR: &'static str = "generate_run_timeout"; - const INTER_RULE_PARALLELISM_ATTR: &'static str = "inter_rule_parallelism"; - - pub fn new(attrs: Vec, is_parallel: bool) -> syn::Result { - let include_rule_times = attrs.iter().find(|attr| attr.meta.path().is_ident(Self::MEASURE_RULE_TIMES_ATTR)) - .map(|attr| attr.meta.require_path_only()).transpose()?.is_some(); - let generate_run_partial = attrs.iter().find(|attr| attr.meta.path().is_ident(Self::GENERATE_RUN_TIMEOUT_ATTR)) - .map(|attr| attr.meta.require_path_only()).transpose()?.is_some(); - let inter_rule_parallelism = attrs.iter().find(|attr| attr.meta.path().is_ident(Self::INTER_RULE_PARALLELISM_ATTR)) - .map(|attr| attr.meta.require_path_only()).transpose()?; - - let recognized_attrs = - [Self::MEASURE_RULE_TIMES_ATTR, Self::GENERATE_RUN_TIMEOUT_ATTR, Self::INTER_RULE_PARALLELISM_ATTR, REL_DS_ATTR]; - for attr in attrs.iter() { - if !recognized_attrs.iter().any(|recognized_attr| attr.meta.path().is_ident(recognized_attr)) { - return Err(Error::new_spanned(attr, - format!("unrecognized attribute. recognized attributes are: {}", - recognized_attrs.join(", ")))); - } - } - if inter_rule_parallelism.is_some() && !is_parallel { - return Err(Error::new_spanned(inter_rule_parallelism, "attribute only allowed in parallel Ascent")); - } - let default_ds = get_ds_attr(&attrs)?.unwrap_or_else(|| - DsAttributeContents { path: parse_quote! {::ascent::rel}, args: TokenStream::default() } - ); - Ok(AscentConfig { - inter_rule_parallelism: inter_rule_parallelism.is_some(), - attrs, - include_rule_times, - generate_run_partial, - default_ds - }) - } -} - -pub(crate) struct AscentIr { - pub relations_ir_relations: HashMap>, - pub relations_full_indices: HashMap, - pub lattices_full_indices: HashMap, - // pub relations_no_indices: HashMap, - pub relations_metadata: HashMap, - pub rules: Vec, - pub signatures: Signatures, - pub config: AscentConfig, - pub is_parallel: bool, -} - -#[derive(Clone)] -pub(crate) struct RelationMetadata{ - pub initialization: Option>, - pub attributes: Rc>, - pub ds_macro_path: Path, - pub ds_macro_args: TokenStream -} - -pub(crate) struct IrRule { - pub head_clauses: Vec, - pub body_items: Vec, - pub simple_join_start_index: Option -} - -#[allow(unused)] -pub(crate) fn ir_rule_summary(rule: &IrRule) -> String { - fn bitem_to_str(bi: &IrBodyItem) -> String { - match bi { - IrBodyItem::Clause(cl) => cl.rel.ir_name().to_string(), - IrBodyItem::Generator(_) => "for ⋯".into(), - IrBodyItem::Cond(CondClause::If(..)) => format!("if ⋯"), - IrBodyItem::Cond(CondClause::IfLet(..)) => format!("if let ⋯"), - IrBodyItem::Cond(CondClause::Let(..)) => format!("let ⋯"), - IrBodyItem::Agg(agg) => format!("agg {}", agg.rel.ir_name()), - } - } - format!("{} <-- {}", - rule.head_clauses.iter().map(|hcl| hcl.rel.name.to_string()).join(", "), - rule.body_items.iter().map(bitem_to_str).join(", ")) -} - -#[derive(Clone)] -pub(crate) struct IrHeadClause{ - pub rel : RelationIdentity, - pub args : Vec, - pub span: Span, - pub args_span: Span, -} - -pub(crate) enum IrBodyItem { - Clause(IrBodyClause), - Generator(GeneratorNode), - Cond(CondClause), - Agg(IrAggClause) -} - -impl IrBodyItem { - pub(crate) fn rel(&self) -> Option<&IrRelation> { - match self { - IrBodyItem::Clause(bcl) => Some(&bcl.rel), - IrBodyItem::Agg(agg) => Some(&agg.rel), - IrBodyItem::Generator(_) | - IrBodyItem::Cond(_) => None, - } - } -} - -#[derive(Clone)] -pub(crate) struct IrBodyClause { - pub rel : IrRelation, - pub args : Vec, - pub rel_args_span: Span, - pub args_span: Span, - pub cond_clauses : Vec -} - -impl IrBodyClause { - #[allow(dead_code)] - pub fn selected_args(&self) -> Vec { - self.rel.indices.iter().map(|&i| self.args[i].clone()).collect() - } -} - -#[derive(Clone)] -pub(crate) struct IrAggClause { - pub span: Span, - pub pat: Pat, - pub aggregator: Expr, - pub bound_args: Vec, - pub rel: IrRelation, - pub rel_args: Vec -} - -#[derive(Clone, Hash, PartialEq, Eq, Debug)] -pub(crate) struct IrRelation { - pub relation: RelationIdentity, - pub indices: Vec, - pub val_type: IndexValType, -} - -#[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub enum IndexValType { - Reference, - Direct(Vec) -} - -impl IrRelation { - pub fn new(relation: RelationIdentity, indices: Vec) -> Self { - // TODO this is not the right place for this - let val_type = if relation.is_lattice //|| indices.len() == relation.field_types.len() - { - IndexValType::Reference - } else { - IndexValType::Direct((0..relation.field_types.len()).filter(|i| !indices.contains(i)).collect_vec()) - }; - IrRelation { relation, indices, val_type } - } - - pub fn key_type(&self) -> Type { - let index_types : Vec<_> = self.indices.iter().map(|&i| self.relation.field_types[i].clone()).collect(); - tuple_type(&index_types) - } - pub fn ir_name(&self) -> Ident { - ir_name_for_rel_indices(&self.relation.name, &self.indices) - } - pub fn is_full_index(&self) -> bool { - self.relation.field_types.len() == self.indices.len() - } - pub fn is_no_index(&self) -> bool { - self.indices.is_empty() - } - - pub fn value_type(&self) -> Type { - match &self.val_type { - IndexValType::Reference => parse_quote!{usize}, - IndexValType::Direct(cols) => { - let index_types : Vec<_> = cols.iter().map(|&i| self.relation.field_types[i].clone()).collect(); - tuple_type(&index_types) - }, - } - } -} - -const REL_DS_ATTR: &str = "ds"; -const RECOGNIIZED_REL_ATTRS: [&str; 1] = [REL_DS_ATTR]; - -pub(crate) fn compile_ascent_program_to_hir(prog: &AscentProgram, is_parallel: bool) -> syn::Result{ - let ir_rules : Vec<(IrRule, Vec)> = prog.rules.iter().map(|r| compile_rule_to_ir_rule(r, prog)).try_collect()?; - let config = AscentConfig::new(prog.attributes.clone(), is_parallel)?; - let num_relations = prog.relations.len(); - let mut relations_ir_relations: HashMap> = HashMap::with_capacity(num_relations); - let mut relations_full_indices = HashMap::with_capacity(num_relations); - let mut relations_initializations = HashMap::new(); - let mut relations_metadata = HashMap::with_capacity(num_relations); - // let mut relations_no_indices = HashMap::new(); - let mut lattices_full_indices = HashMap::new(); - for rel in prog.relations.iter(){ - let rel_identity = RelationIdentity::from(rel); - - if rel.is_lattice { - let indices = (0 .. rel_identity.field_types.len() - 1).collect_vec(); - let lat_full_index = IrRelation::new(rel_identity.clone(), indices); - relations_ir_relations.entry(rel_identity.clone()).or_default().insert(lat_full_index.clone()); - lattices_full_indices.insert(rel_identity.clone(), lat_full_index); - } - - let full_indices = (0 .. rel_identity.field_types.len()).collect_vec(); - let rel_full_index = IrRelation::new(rel_identity.clone(),full_indices); - - relations_ir_relations.entry(rel_identity.clone()).or_default().insert(rel_full_index.clone()); - // relations_ir_relations.entry(rel_identity.clone()).or_default().insert(rel_no_index.clone()); - relations_full_indices.insert(rel_identity.clone(), rel_full_index); - if let Some(init_expr) = &rel.initialization { - relations_initializations.insert(rel_identity.clone(), Rc::new(init_expr.clone())); - } - let ds_attribute = get_ds_attr(&rel.attrs)?.unwrap_or_else(|| config.default_ds.clone()); - - relations_metadata.insert( - rel_identity.clone(), - RelationMetadata { - initialization: rel.initialization.clone().map(Rc::new), - attributes: Rc::new(rel.attrs.iter().filter(|attr| attr.meta.path().get_ident().map_or(true, |ident| !RECOGNIIZED_REL_ATTRS.iter().any(|ra| ident == ra))).cloned().collect_vec()), - ds_macro_path: ds_attribute.path, - ds_macro_args: ds_attribute.args - } - ); - // relations_no_indices.insert(rel_identity, rel_no_index); - } - for (ir_rule, extra_relations) in ir_rules.iter(){ - for bitem in ir_rule.body_items.iter(){ - let rel = match bitem { - IrBodyItem::Clause(bcl) => Some(&bcl.rel), - IrBodyItem::Agg(agg) => Some(&agg.rel), - _ => None - }; - if let Some(rel) = rel { - let relation = &rel.relation; - relations_ir_relations.entry(relation.clone()).or_default().insert(rel.clone()); - } - } - for extra_rel in extra_relations.iter(){ - relations_ir_relations.entry(extra_rel.relation.clone()).or_default().insert(extra_rel.clone()); - } - } - let signatures = prog.signatures.clone().unwrap_or_else(|| parse2(quote! {pub struct AscentProgram;}).unwrap()); - Ok(AscentIr { - rules: ir_rules.into_iter().map(|(rule, _extra_rels)| rule).collect_vec(), - relations_ir_relations, - relations_full_indices, - lattices_full_indices, - relations_metadata, - // relations_no_indices, - signatures, - config, - is_parallel - }) -} - -fn get_ds_attr(attrs: &[Attribute]) -> syn::Result> { - let ds_attrs = attrs.iter() - .filter(|attr| attr.meta.path().get_ident().map_or(false, |ident| ident == REL_DS_ATTR)) - .collect_vec(); - match &ds_attrs[..] { - [] => Ok(None), - [attr] => { - let res = syn::parse2::(attr.meta.require_list()?.tokens.clone())?; - Ok(Some(res)) - }, - [_attr1, attr2, ..] => Err(Error::new(attr2.bracket_token.span.join(), "multiple `ds` attributes specified")), - } -} - -fn compile_rule_to_ir_rule(rule: &RuleNode, prog: &AscentProgram) -> syn::Result<(IrRule, Vec)> { - let mut body_items = vec![]; - let mut grounded_vars = vec![]; - fn extend_grounded_vars(grounded_vars: &mut Vec, new_vars: impl IntoIterator) -> syn::Result<()> { - for v in new_vars.into_iter() { - if grounded_vars.contains(&v) { - // TODO may someday this will work - let other_var = grounded_vars.iter().find(|&x| x == &v).unwrap(); - let other_err = Error::new(other_var.span(), "variable being shadowed"); - let mut err = Error::new(v.span(), format!("'{}' shadows another variable with the same name", v)); - err.combine(other_err); - return Err(err); - } - grounded_vars.push(v); - } - Ok(()) - } - - let first_clause_ind = rule.body_items.iter().enumerate().find(|(_, bi)| matches!(bi, BodyItemNode::Clause(..))).map(|(i, _)| i); - let mut first_two_clauses_simple = first_clause_ind.is_some() && - matches!(rule.body_items.get(first_clause_ind.unwrap() + 1), Some(BodyItemNode::Clause(..))); - for (bitem_ind, bitem) in rule.body_items.iter().enumerate() { - match bitem { - BodyItemNode::Clause(ref bcl) => { - if first_clause_ind == Some(bitem_ind) && bcl.cond_clauses.iter().any(|c| matches!(c, &CondClause::IfLet(_))) - { - first_two_clauses_simple = false; - } - - if first_clause_ind.map(|x| x + 1) == Some(bitem_ind) && first_two_clauses_simple{ - let mut self_vars = HashSet::new(); - for var in bcl.args.iter().filter_map(|arg| expr_to_ident(arg.unwrap_expr_ref())) { - if !self_vars.insert(var) { - first_two_clauses_simple = false; - } - } - for cond_cl in bcl.cond_clauses.iter(){ - let cond_expr = cond_cl.expr(); - let expr_idents = expr_get_vars(cond_expr); - if !expr_idents.iter().all(|v| self_vars.contains(v)){ - first_two_clauses_simple = false; - break; - } - self_vars.extend(cond_cl.bound_vars()); - } - } - let mut indices = vec![]; - for (i,arg) in bcl.args.iter().enumerate() { - if let Some(var) = expr_to_ident(arg.unwrap_expr_ref()) { - if grounded_vars.contains(&var){ - indices.push(i); - if first_clause_ind == Some(bitem_ind) { - first_two_clauses_simple = false; - } - } else { - grounded_vars.push(var); - } - } else { - indices.push(i); - if bitem_ind < 2 + first_clause_ind.unwrap_or(0) { - first_two_clauses_simple = false; - } - } - } - let relation = prog_get_relation(prog, &bcl.rel, bcl.args.len())?; - - for cond_clause in bcl.cond_clauses.iter() { - extend_grounded_vars(&mut grounded_vars, cond_clause.bound_vars())?; - } - - let ir_rel = IrRelation::new(relation.into(), indices); - let ir_bcl = IrBodyClause { - rel: ir_rel, - args: bcl.args.iter().cloned().map(BodyClauseArg::unwrap_expr).collect(), - rel_args_span: bcl.rel.span().join(bcl.args.span()).unwrap_or_else(|| bcl.rel.span()), - args_span: bcl.args.span(), - cond_clauses: bcl.cond_clauses.clone() - }; - body_items.push(IrBodyItem::Clause(ir_bcl)); - }, - BodyItemNode::Generator(ref gen) => { - extend_grounded_vars(&mut grounded_vars, pattern_get_vars(&gen.pattern))?; - body_items.push(IrBodyItem::Generator(gen.clone())); - }, - BodyItemNode::Cond(ref cl) => { - body_items.push(IrBodyItem::Cond(cl.clone())); - extend_grounded_vars(&mut grounded_vars, cl.bound_vars())?; - }, - BodyItemNode::Agg(ref agg) => { - extend_grounded_vars(&mut grounded_vars, pattern_get_vars(&agg.pat))?; - let indices = agg.rel_args.iter().enumerate().filter(|(_i, expr)| { - if is_wild_card(expr) { - return false; - } else if let Some(ident) = expr_to_ident(expr) { - if agg.bound_args.iter().contains(&ident) { - return false; - } - } - true - }).map(|(i, _expr)| i).collect_vec(); - let relation = prog_get_relation(prog, &agg.rel, agg.rel_args.len())?; - - let ir_rel = IrRelation::new(relation.into(), indices); - let ir_agg_clause = IrAggClause { - span: agg.agg_kw.span, - pat: agg.pat.clone(), - aggregator: agg.aggregator.get_expr(), - bound_args: agg.bound_args.iter().cloned().collect_vec(), - rel: ir_rel, - rel_args: agg.rel_args.iter().cloned().collect_vec(), - }; - body_items.push(IrBodyItem::Agg(ir_agg_clause)); - }, - _ => panic!("unrecognized body item") - } - - } - let mut head_clauses = vec![]; - for hcl_node in rule.head_clauses.iter(){ - let hcl_node = hcl_node.clause(); - let rel = prog.relations.iter().find(|r| hcl_node.rel == r.name); - let rel = match rel { - Some(rel) => rel, - None => return Err(Error::new(hcl_node.rel.span(), format!("relation {} not defined", hcl_node.rel))), - }; - - let rel = RelationIdentity::from(rel); - let head_clause = IrHeadClause { - rel, - args : hcl_node.args.iter().cloned().collect(), - span: hcl_node.span(), - args_span: hcl_node.args.span() - }; - head_clauses.push(head_clause); - } - - let is_simple_join = first_two_clauses_simple && body_items.len() >= 2; - let simple_join_start_index = if is_simple_join {first_clause_ind} else {None}; - - let simple_join_ir_relations = if let Some(start_ind) = simple_join_start_index { - let (bcl1, bcl2) = match &body_items[start_ind..start_ind + 2] { - [IrBodyItem::Clause(bcl1), IrBodyItem::Clause(bcl2)] => (bcl1, bcl2), - _ => panic!("incorrect simple join handling in ascent_hir") - }; - let bcl2_vars = bcl2.args.iter().filter_map(expr_to_ident).collect_vec(); - let indices = get_indices_given_grounded_variables(&bcl1.args, &bcl2_vars); - let new_cl1_ir_relation = IrRelation::new(bcl1.rel.relation.clone(), indices); - vec![new_cl1_ir_relation] - } else {vec![]}; - - if let Some(start_ind) = simple_join_start_index { - if let IrBodyItem::Clause(cl1) = &mut body_items[start_ind] { - cl1.rel = simple_join_ir_relations[0].clone(); - } - } - - Ok((IrRule { - simple_join_start_index, - head_clauses, - body_items, - }, vec![])) -} - -pub fn ir_name_for_rel_indices(rel: &Ident, indices: &[usize]) -> Ident { - let indices_str = if indices.is_empty() {format!("none")} else {indices.iter().join("_")}; - let name = format!("{}_indices_{}", rel, indices_str); - Ident::new(&name, rel.span()) -} - -/// for a clause with args, returns the indices assuming vars are grounded. -pub fn get_indices_given_grounded_variables(args: &[Expr], vars: &[Ident]) -> Vec{ - let mut res = vec![]; - for (i, arg) in args.iter().enumerate(){ - if let Some(arg_var) = expr_to_ident(arg){ - if vars.contains(&arg_var) { - res.push(i); - } - } else { - res.push(i); - } - } - res -} - -pub(crate) fn prog_get_relation<'a>(prog: &'a AscentProgram, name: &Ident, arity: usize) -> syn::Result<&'a RelationNode> { - let relation = prog.relations.iter().find(|r| name == &r.name); - match relation { - Some(rel) => { - if rel.field_types.len() != arity { - Err(Error::new(name.span(), format!("Wrong arity for relation {}. Actual arity: {}", name, rel.field_types.len()))) - } else { - Ok(rel) - } - }, - None => Err(Error::new(name.span(), format!("Relation {} not defined", name))), - } -} \ No newline at end of file +#![deny(warnings)] +use std::collections::{HashMap, HashSet}; +use std::rc::Rc; + +use itertools::Itertools; +use proc_macro2::{Ident, Span, TokenStream}; +use syn::spanned::Spanned; +use syn::{Attribute, Error, Expr, Pat, Path, Type, parse_quote, parse2}; + +use crate::AscentProgram; +use crate::ascent_syntax::{ + BodyClauseArg, BodyItemNode, CondClause, DsAttributeContents, GeneratorNode, RelationIdentity, RelationNode, + RuleNode, Signatures, +}; +use crate::syn_utils::{expr_get_vars, pattern_get_vars}; +use crate::utils::{expr_to_ident, is_wild_card, tuple_type}; + +#[derive(Clone)] +pub(crate) struct AscentConfig { + #[allow(dead_code)] + pub attrs: Vec, + pub include_rule_times: bool, + pub generate_run_partial: bool, + pub inter_rule_parallelism: bool, + pub default_ds: DsAttributeContents, +} + +impl AscentConfig { + const MEASURE_RULE_TIMES_ATTR: &'static str = "measure_rule_times"; + const GENERATE_RUN_TIMEOUT_ATTR: &'static str = "generate_run_timeout"; + const INTER_RULE_PARALLELISM_ATTR: &'static str = "inter_rule_parallelism"; + + pub fn new(attrs: Vec, is_parallel: bool) -> syn::Result { + let include_rule_times = attrs + .iter() + .find(|attr| attr.meta.path().is_ident(Self::MEASURE_RULE_TIMES_ATTR)) + .map(|attr| attr.meta.require_path_only()) + .transpose()? + .is_some(); + let generate_run_partial = attrs + .iter() + .find(|attr| attr.meta.path().is_ident(Self::GENERATE_RUN_TIMEOUT_ATTR)) + .map(|attr| attr.meta.require_path_only()) + .transpose()? + .is_some(); + let inter_rule_parallelism = attrs + .iter() + .find(|attr| attr.meta.path().is_ident(Self::INTER_RULE_PARALLELISM_ATTR)) + .map(|attr| attr.meta.require_path_only()) + .transpose()?; + + let recognized_attrs = [ + Self::MEASURE_RULE_TIMES_ATTR, + Self::GENERATE_RUN_TIMEOUT_ATTR, + Self::INTER_RULE_PARALLELISM_ATTR, + REL_DS_ATTR, + ]; + for attr in attrs.iter() { + if !recognized_attrs.iter().any(|recognized_attr| attr.meta.path().is_ident(recognized_attr)) { + return Err(Error::new_spanned( + attr, + format!("unrecognized attribute. recognized attributes are: {}", recognized_attrs.join(", ")), + )); + } + } + if inter_rule_parallelism.is_some() && !is_parallel { + return Err(Error::new_spanned(inter_rule_parallelism, "attribute only allowed in parallel Ascent")); + } + let default_ds = get_ds_attr(&attrs)? + .unwrap_or_else(|| DsAttributeContents { path: parse_quote! {::ascent::rel}, args: TokenStream::default() }); + Ok(AscentConfig { + inter_rule_parallelism: inter_rule_parallelism.is_some(), + attrs, + include_rule_times, + generate_run_partial, + default_ds, + }) + } +} + +pub(crate) struct AscentIr { + pub relations_ir_relations: HashMap>, + pub relations_full_indices: HashMap, + pub lattices_full_indices: HashMap, + // pub relations_no_indices: HashMap, + pub relations_metadata: HashMap, + pub rules: Vec, + pub signatures: Signatures, + pub config: AscentConfig, + pub is_parallel: bool, +} + +#[derive(Clone)] +pub(crate) struct RelationMetadata { + pub initialization: Option>, + pub attributes: Rc>, + pub ds_macro_path: Path, + pub ds_macro_args: TokenStream, +} + +pub(crate) struct IrRule { + pub head_clauses: Vec, + pub body_items: Vec, + pub simple_join_start_index: Option, +} + +#[allow(unused)] +pub(crate) fn ir_rule_summary(rule: &IrRule) -> String { + fn bitem_to_str(bi: &IrBodyItem) -> String { + match bi { + IrBodyItem::Clause(cl) => cl.rel.ir_name().to_string(), + IrBodyItem::Generator(_) => "for ⋯".into(), + IrBodyItem::Cond(CondClause::If(..)) => format!("if ⋯"), + IrBodyItem::Cond(CondClause::IfLet(..)) => format!("if let ⋯"), + IrBodyItem::Cond(CondClause::Let(..)) => format!("let ⋯"), + IrBodyItem::Agg(agg) => format!("agg {}", agg.rel.ir_name()), + } + } + format!( + "{} <-- {}", + rule.head_clauses.iter().map(|hcl| hcl.rel.name.to_string()).join(", "), + rule.body_items.iter().map(bitem_to_str).join(", ") + ) +} + +#[derive(Clone)] +pub(crate) struct IrHeadClause { + pub rel: RelationIdentity, + pub args: Vec, + pub span: Span, + pub args_span: Span, +} + +pub(crate) enum IrBodyItem { + Clause(IrBodyClause), + Generator(GeneratorNode), + Cond(CondClause), + Agg(IrAggClause), +} + +impl IrBodyItem { + pub(crate) fn rel(&self) -> Option<&IrRelation> { + match self { + IrBodyItem::Clause(bcl) => Some(&bcl.rel), + IrBodyItem::Agg(agg) => Some(&agg.rel), + IrBodyItem::Generator(_) | IrBodyItem::Cond(_) => None, + } + } +} + +#[derive(Clone)] +pub(crate) struct IrBodyClause { + pub rel: IrRelation, + pub args: Vec, + pub rel_args_span: Span, + pub args_span: Span, + pub cond_clauses: Vec, +} + +impl IrBodyClause { + #[allow(dead_code)] + pub fn selected_args(&self) -> Vec { self.rel.indices.iter().map(|&i| self.args[i].clone()).collect() } +} + +#[derive(Clone)] +pub(crate) struct IrAggClause { + pub span: Span, + pub pat: Pat, + pub aggregator: Expr, + pub bound_args: Vec, + pub rel: IrRelation, + pub rel_args: Vec, +} + +#[derive(Clone, Hash, PartialEq, Eq, Debug)] +pub(crate) struct IrRelation { + pub relation: RelationIdentity, + pub indices: Vec, + pub val_type: IndexValType, +} + +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub enum IndexValType { + Reference, + Direct(Vec), +} + +impl IrRelation { + pub fn new(relation: RelationIdentity, indices: Vec) -> Self { + // TODO this is not the right place for this + let val_type = if relation.is_lattice + //|| indices.len() == relation.field_types.len() + { + IndexValType::Reference + } else { + IndexValType::Direct((0..relation.field_types.len()).filter(|i| !indices.contains(i)).collect_vec()) + }; + IrRelation { relation, indices, val_type } + } + + pub fn key_type(&self) -> Type { + let index_types: Vec<_> = self.indices.iter().map(|&i| self.relation.field_types[i].clone()).collect(); + tuple_type(&index_types) + } + pub fn ir_name(&self) -> Ident { ir_name_for_rel_indices(&self.relation.name, &self.indices) } + pub fn is_full_index(&self) -> bool { self.relation.field_types.len() == self.indices.len() } + pub fn is_no_index(&self) -> bool { self.indices.is_empty() } + + pub fn value_type(&self) -> Type { + match &self.val_type { + IndexValType::Reference => parse_quote! {usize}, + IndexValType::Direct(cols) => { + let index_types: Vec<_> = cols.iter().map(|&i| self.relation.field_types[i].clone()).collect(); + tuple_type(&index_types) + }, + } + } +} + +const REL_DS_ATTR: &str = "ds"; +const RECOGNIIZED_REL_ATTRS: [&str; 1] = [REL_DS_ATTR]; + +pub(crate) fn compile_ascent_program_to_hir(prog: &AscentProgram, is_parallel: bool) -> syn::Result { + let ir_rules: Vec<(IrRule, Vec)> = + prog.rules.iter().map(|r| compile_rule_to_ir_rule(r, prog)).try_collect()?; + let config = AscentConfig::new(prog.attributes.clone(), is_parallel)?; + let num_relations = prog.relations.len(); + let mut relations_ir_relations: HashMap> = + HashMap::with_capacity(num_relations); + let mut relations_full_indices = HashMap::with_capacity(num_relations); + let mut relations_initializations = HashMap::new(); + let mut relations_metadata = HashMap::with_capacity(num_relations); + // let mut relations_no_indices = HashMap::new(); + let mut lattices_full_indices = HashMap::new(); + for rel in prog.relations.iter() { + let rel_identity = RelationIdentity::from(rel); + + if rel.is_lattice { + let indices = (0..rel_identity.field_types.len() - 1).collect_vec(); + let lat_full_index = IrRelation::new(rel_identity.clone(), indices); + relations_ir_relations.entry(rel_identity.clone()).or_default().insert(lat_full_index.clone()); + lattices_full_indices.insert(rel_identity.clone(), lat_full_index); + } + + let full_indices = (0..rel_identity.field_types.len()).collect_vec(); + let rel_full_index = IrRelation::new(rel_identity.clone(), full_indices); + + relations_ir_relations.entry(rel_identity.clone()).or_default().insert(rel_full_index.clone()); + // relations_ir_relations.entry(rel_identity.clone()).or_default().insert(rel_no_index.clone()); + relations_full_indices.insert(rel_identity.clone(), rel_full_index); + if let Some(init_expr) = &rel.initialization { + relations_initializations.insert(rel_identity.clone(), Rc::new(init_expr.clone())); + } + let ds_attribute = get_ds_attr(&rel.attrs)?.unwrap_or_else(|| config.default_ds.clone()); + + relations_metadata.insert(rel_identity.clone(), RelationMetadata { + initialization: rel.initialization.clone().map(Rc::new), + attributes: Rc::new( + rel.attrs + .iter() + .filter(|attr| { + attr.meta.path().get_ident().map_or(true, |ident| !RECOGNIIZED_REL_ATTRS.iter().any(|ra| ident == ra)) + }) + .cloned() + .collect_vec(), + ), + ds_macro_path: ds_attribute.path, + ds_macro_args: ds_attribute.args, + }); + // relations_no_indices.insert(rel_identity, rel_no_index); + } + for (ir_rule, extra_relations) in ir_rules.iter() { + for bitem in ir_rule.body_items.iter() { + let rel = match bitem { + IrBodyItem::Clause(bcl) => Some(&bcl.rel), + IrBodyItem::Agg(agg) => Some(&agg.rel), + _ => None, + }; + if let Some(rel) = rel { + let relation = &rel.relation; + relations_ir_relations.entry(relation.clone()).or_default().insert(rel.clone()); + } + } + for extra_rel in extra_relations.iter() { + relations_ir_relations.entry(extra_rel.relation.clone()).or_default().insert(extra_rel.clone()); + } + } + let signatures = prog.signatures.clone().unwrap_or_else(|| parse2(quote! {pub struct AscentProgram;}).unwrap()); + Ok(AscentIr { + rules: ir_rules.into_iter().map(|(rule, _extra_rels)| rule).collect_vec(), + relations_ir_relations, + relations_full_indices, + lattices_full_indices, + relations_metadata, + // relations_no_indices, + signatures, + config, + is_parallel, + }) +} + +fn get_ds_attr(attrs: &[Attribute]) -> syn::Result> { + let ds_attrs = attrs + .iter() + .filter(|attr| attr.meta.path().get_ident().map_or(false, |ident| ident == REL_DS_ATTR)) + .collect_vec(); + match &ds_attrs[..] { + [] => Ok(None), + [attr] => { + let res = syn::parse2::(attr.meta.require_list()?.tokens.clone())?; + Ok(Some(res)) + }, + [_attr1, attr2, ..] => Err(Error::new(attr2.bracket_token.span.join(), "multiple `ds` attributes specified")), + } +} + +fn compile_rule_to_ir_rule(rule: &RuleNode, prog: &AscentProgram) -> syn::Result<(IrRule, Vec)> { + let mut body_items = vec![]; + let mut grounded_vars = vec![]; + fn extend_grounded_vars( + grounded_vars: &mut Vec, new_vars: impl IntoIterator, + ) -> syn::Result<()> { + for v in new_vars.into_iter() { + if grounded_vars.contains(&v) { + // TODO may someday this will work + let other_var = grounded_vars.iter().find(|&x| x == &v).unwrap(); + let other_err = Error::new(other_var.span(), "variable being shadowed"); + let mut err = Error::new(v.span(), format!("'{}' shadows another variable with the same name", v)); + err.combine(other_err); + return Err(err); + } + grounded_vars.push(v); + } + Ok(()) + } + + let first_clause_ind = + rule.body_items.iter().enumerate().find(|(_, bi)| matches!(bi, BodyItemNode::Clause(..))).map(|(i, _)| i); + let mut first_two_clauses_simple = first_clause_ind.is_some() + && matches!(rule.body_items.get(first_clause_ind.unwrap() + 1), Some(BodyItemNode::Clause(..))); + for (bitem_ind, bitem) in rule.body_items.iter().enumerate() { + match bitem { + BodyItemNode::Clause(ref bcl) => { + if first_clause_ind == Some(bitem_ind) + && bcl.cond_clauses.iter().any(|c| matches!(c, &CondClause::IfLet(_))) + { + first_two_clauses_simple = false; + } + + if first_clause_ind.map(|x| x + 1) == Some(bitem_ind) && first_two_clauses_simple { + let mut self_vars = HashSet::new(); + for var in bcl.args.iter().filter_map(|arg| expr_to_ident(arg.unwrap_expr_ref())) { + if !self_vars.insert(var) { + first_two_clauses_simple = false; + } + } + for cond_cl in bcl.cond_clauses.iter() { + let cond_expr = cond_cl.expr(); + let expr_idents = expr_get_vars(cond_expr); + if !expr_idents.iter().all(|v| self_vars.contains(v)) { + first_two_clauses_simple = false; + break; + } + self_vars.extend(cond_cl.bound_vars()); + } + } + let mut indices = vec![]; + for (i, arg) in bcl.args.iter().enumerate() { + if let Some(var) = expr_to_ident(arg.unwrap_expr_ref()) { + if grounded_vars.contains(&var) { + indices.push(i); + if first_clause_ind == Some(bitem_ind) { + first_two_clauses_simple = false; + } + } else { + grounded_vars.push(var); + } + } else { + indices.push(i); + if bitem_ind < 2 + first_clause_ind.unwrap_or(0) { + first_two_clauses_simple = false; + } + } + } + let relation = prog_get_relation(prog, &bcl.rel, bcl.args.len())?; + + for cond_clause in bcl.cond_clauses.iter() { + extend_grounded_vars(&mut grounded_vars, cond_clause.bound_vars())?; + } + + let ir_rel = IrRelation::new(relation.into(), indices); + let ir_bcl = IrBodyClause { + rel: ir_rel, + args: bcl.args.iter().cloned().map(BodyClauseArg::unwrap_expr).collect(), + rel_args_span: bcl.rel.span().join(bcl.args.span()).unwrap_or_else(|| bcl.rel.span()), + args_span: bcl.args.span(), + cond_clauses: bcl.cond_clauses.clone(), + }; + body_items.push(IrBodyItem::Clause(ir_bcl)); + }, + BodyItemNode::Generator(ref gen) => { + extend_grounded_vars(&mut grounded_vars, pattern_get_vars(&gen.pattern))?; + body_items.push(IrBodyItem::Generator(gen.clone())); + }, + BodyItemNode::Cond(ref cl) => { + body_items.push(IrBodyItem::Cond(cl.clone())); + extend_grounded_vars(&mut grounded_vars, cl.bound_vars())?; + }, + BodyItemNode::Agg(ref agg) => { + extend_grounded_vars(&mut grounded_vars, pattern_get_vars(&agg.pat))?; + let indices = agg + .rel_args + .iter() + .enumerate() + .filter(|(_i, expr)| { + if is_wild_card(expr) { + return false; + } else if let Some(ident) = expr_to_ident(expr) { + if agg.bound_args.iter().contains(&ident) { + return false; + } + } + true + }) + .map(|(i, _expr)| i) + .collect_vec(); + let relation = prog_get_relation(prog, &agg.rel, agg.rel_args.len())?; + + let ir_rel = IrRelation::new(relation.into(), indices); + let ir_agg_clause = IrAggClause { + span: agg.agg_kw.span, + pat: agg.pat.clone(), + aggregator: agg.aggregator.get_expr(), + bound_args: agg.bound_args.iter().cloned().collect_vec(), + rel: ir_rel, + rel_args: agg.rel_args.iter().cloned().collect_vec(), + }; + body_items.push(IrBodyItem::Agg(ir_agg_clause)); + }, + _ => panic!("unrecognized body item"), + } + } + let mut head_clauses = vec![]; + for hcl_node in rule.head_clauses.iter() { + let hcl_node = hcl_node.clause(); + let rel = prog.relations.iter().find(|r| hcl_node.rel == r.name); + let rel = match rel { + Some(rel) => rel, + None => return Err(Error::new(hcl_node.rel.span(), format!("relation {} not defined", hcl_node.rel))), + }; + + let rel = RelationIdentity::from(rel); + let head_clause = IrHeadClause { + rel, + args: hcl_node.args.iter().cloned().collect(), + span: hcl_node.span(), + args_span: hcl_node.args.span(), + }; + head_clauses.push(head_clause); + } + + let is_simple_join = first_two_clauses_simple && body_items.len() >= 2; + let simple_join_start_index = if is_simple_join { first_clause_ind } else { None }; + + let simple_join_ir_relations = if let Some(start_ind) = simple_join_start_index { + let (bcl1, bcl2) = match &body_items[start_ind..start_ind + 2] { + [IrBodyItem::Clause(bcl1), IrBodyItem::Clause(bcl2)] => (bcl1, bcl2), + _ => panic!("incorrect simple join handling in ascent_hir"), + }; + let bcl2_vars = bcl2.args.iter().filter_map(expr_to_ident).collect_vec(); + let indices = get_indices_given_grounded_variables(&bcl1.args, &bcl2_vars); + let new_cl1_ir_relation = IrRelation::new(bcl1.rel.relation.clone(), indices); + vec![new_cl1_ir_relation] + } else { + vec![] + }; + + if let Some(start_ind) = simple_join_start_index { + if let IrBodyItem::Clause(cl1) = &mut body_items[start_ind] { + cl1.rel = simple_join_ir_relations[0].clone(); + } + } + + Ok((IrRule { simple_join_start_index, head_clauses, body_items }, vec![])) +} + +pub fn ir_name_for_rel_indices(rel: &Ident, indices: &[usize]) -> Ident { + let indices_str = if indices.is_empty() { format!("none") } else { indices.iter().join("_") }; + let name = format!("{}_indices_{}", rel, indices_str); + Ident::new(&name, rel.span()) +} + +/// for a clause with args, returns the indices assuming vars are grounded. +pub fn get_indices_given_grounded_variables(args: &[Expr], vars: &[Ident]) -> Vec { + let mut res = vec![]; + for (i, arg) in args.iter().enumerate() { + if let Some(arg_var) = expr_to_ident(arg) { + if vars.contains(&arg_var) { + res.push(i); + } + } else { + res.push(i); + } + } + res +} + +pub(crate) fn prog_get_relation<'a>( + prog: &'a AscentProgram, name: &Ident, arity: usize, +) -> syn::Result<&'a RelationNode> { + let relation = prog.relations.iter().find(|r| name == &r.name); + match relation { + Some(rel) => + if rel.field_types.len() != arity { + Err(Error::new( + name.span(), + format!("Wrong arity for relation {}. Actual arity: {}", name, rel.field_types.len()), + )) + } else { + Ok(rel) + }, + None => Err(Error::new(name.span(), format!("Relation {} not defined", name))), + } +} diff --git a/ascent_macro/src/ascent_mir.rs b/ascent_macro/src/ascent_mir.rs index 991b7b8..94018d3 100644 --- a/ascent_macro/src/ascent_mir.rs +++ b/ascent_macro/src/ascent_mir.rs @@ -1,409 +1,427 @@ -#![deny(warnings)] -use std::collections::{HashMap, HashSet}; -use std::fmt::Write; -use itertools::Itertools; -use petgraph::{algo::condensation, graphmap::DiGraphMap}; -use proc_macro2::{Ident, Span}; -use syn::{Expr, Type}; -use crate::{ascent_mir::MirRelationVersion::*, ascent_syntax::Signatures, syn_utils::pattern_get_vars}; -use crate::utils::{expr_to_ident, pat_to_ident, tuple_type, intersects}; -use crate::ascent_syntax::{CondClause, GeneratorNode, RelationIdentity}; -use crate::ascent_hir::{AscentConfig, AscentIr, IndexValType, IrAggClause, IrBodyClause, IrBodyItem, IrHeadClause, IrRelation, IrRule, RelationMetadata}; - -pub(crate) struct AscentMir { - pub sccs: Vec, - #[allow(unused)] - pub deps: HashMap>, - pub relations_ir_relations: HashMap>, - pub relations_full_indices: HashMap, - pub relations_metadata: HashMap, - pub lattices_full_indices: HashMap, - pub signatures: Signatures, - pub config: AscentConfig, - pub is_parallel: bool, -} - -pub(crate) struct MirScc { - pub rules: Vec, - pub dynamic_relations: HashMap>, - pub body_only_relations: HashMap>, - pub is_looping: bool -} - - -pub(crate) fn mir_summary(mir: &AscentMir) -> String { - let mut res = String::new(); - for (i, scc) in mir.sccs.iter().enumerate() { - writeln!(&mut res, "scc {}, is_looping: {}:", i, scc.is_looping).unwrap(); - for r in scc.rules.iter() { - writeln!(&mut res, " {}", mir_rule_summary(r)).unwrap(); - } - let sorted_dynamic_relation_keys = scc.dynamic_relations.keys().sorted_by_key(|rel| &rel.name); - write!(&mut res, " dynamic relations: ").unwrap(); - writeln!(&mut res, "{}", sorted_dynamic_relation_keys.map(|r| r.name.to_string()).join(", ")).unwrap(); - } - res -} - -#[derive(Clone)] -pub(crate) struct MirRule { - // TODO rename to head_clauses - pub head_clause: Vec, - pub body_items: Vec, - pub simple_join_start_index: Option, - pub reorderable: bool -} - -pub(crate) fn mir_rule_summary(rule: &MirRule) -> String { - fn bitem_to_str(bitem: &MirBodyItem) -> String { - match bitem { - MirBodyItem::Clause(bcl) => format!("{}_{}", bcl.rel.ir_name, bcl.rel.version.to_string()), - MirBodyItem::Generator(gen) => format!("for_{}", pat_to_ident(&gen.pattern).map(|x| x.to_string()).unwrap_or_default()), - MirBodyItem::Cond(CondClause::If(..)) => format!("if ⋯"), - MirBodyItem::Cond(CondClause::IfLet(..)) => format!("if let ⋯"), - MirBodyItem::Cond(CondClause::Let(..)) => format!("let ⋯"), - MirBodyItem::Agg(agg) => format!("agg {}", agg.rel.ir_name()), - } - } - format!("{} <-- {}{simple_join}{reorderable}", - rule.head_clause.iter().map(|hcl| hcl.rel.name.to_string()).join(", "), - rule.body_items.iter().map(bitem_to_str).join(", "), - simple_join = if rule.simple_join_start_index.is_some() {" [SIMPLE JOIN]"} else {""}, - reorderable = if rule.simple_join_start_index.is_some() && !rule.reorderable {" [NOT REORDERABLE]"} else {""}) -} - -#[derive(Clone)] -pub(crate) enum MirBodyItem { - Clause(MirBodyClause), - Generator(GeneratorNode), - Cond(CondClause), - Agg(IrAggClause) -} - -impl MirBodyItem { - pub fn unwrap_clause(&self) -> &MirBodyClause { - match self { - MirBodyItem::Clause(cl) => cl, - _ => panic!("MirBodyItem: unwrap_clause called on non_clause") - } - } - - pub fn bound_vars(&self) -> Vec { - match self { - MirBodyItem::Clause(cl) => { - let cl_vars = cl.args.iter().filter_map(expr_to_ident); - let cond_cl_vars = cl.cond_clauses.iter().flat_map(|cc| cc.bound_vars()); - cl_vars.chain(cond_cl_vars).collect() - }, - MirBodyItem::Generator(gen) => pattern_get_vars(&gen.pattern), - MirBodyItem::Cond(cond) => cond.bound_vars(), - MirBodyItem::Agg(agg) => pattern_get_vars(&agg.pat), - } - } -} - -#[derive(Clone)] -pub(crate) struct MirBodyClause { - pub rel: MirRelation, - pub args: Vec, - pub rel_args_span: Span, - pub args_span: Span, - pub cond_clauses : Vec -} -impl MirBodyClause { - pub fn selected_args(&self) -> Vec { - self.rel.indices.iter().map(|&i| self.args[i].clone()).collect() - } - - /// returns a vec of (var_ind, var) of all the variables in the clause - pub fn vars(&self) -> Vec<(usize, Ident)> { - self.args.iter().enumerate() - .filter_map(|(i,v)| expr_to_ident(v).map(|v| (i, v))) - .collect::>() - } - - #[allow(dead_code)] - pub fn from(ir_body_clause: IrBodyClause, rel: MirRelation) -> MirBodyClause{ - MirBodyClause { - rel, - args: ir_body_clause.args, - rel_args_span: ir_body_clause.rel_args_span, - args_span: ir_body_clause.args_span, - cond_clauses: ir_body_clause.cond_clauses, - } - } -} - -#[derive(Clone, PartialEq, Eq, Hash)] -pub(crate) struct MirRelation { - pub relation: RelationIdentity, - pub indices: Vec, - pub ir_name: Ident, - pub version: MirRelationVersion, - pub is_full_index: bool, - pub is_no_index: bool, - pub val_type: IndexValType, -} - -pub(crate) fn ir_relation_version_var_name(ir_name: &Ident, version : MirRelationVersion) -> Ident{ - let name = format!("{}_{}", ir_name, version.to_string()); - Ident::new(&name, ir_name.span()) -} - -impl MirRelation { - pub fn var_name(&self) -> Ident { - ir_relation_version_var_name(&self.ir_name, self.version) - } - - #[allow(dead_code)] - pub fn key_type(&self) -> Type { - let index_types : Vec<_> = self.indices.iter().map(|&i| self.relation.field_types[i].clone()).collect(); - tuple_type(&index_types) - } - - pub fn from(ir_relation : IrRelation, version : MirRelationVersion) -> MirRelation { - MirRelation { - ir_name: ir_relation.ir_name(), - is_full_index: ir_relation.is_full_index(), - is_no_index: ir_relation.is_no_index(), - relation: ir_relation.relation, - indices: ir_relation.indices, - version, - val_type: ir_relation.val_type - } - } -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub(crate) enum MirRelationVersion { - TotalDelta, - Total, - Delta, - New, -} - -impl MirRelationVersion { - pub fn to_string(self) -> &'static str{ - use MirRelationVersion::*; - match self { TotalDelta => "total+delta", Delta => "delta", Total => "total", New => "new" } - } -} - -fn get_hir_dep_graph(hir: &AscentIr) -> Vec<(usize,usize)> { - let mut relations_to_rules_in_head : HashMap<&RelationIdentity, HashSet> = HashMap::with_capacity(hir.rules.len()); - for (i, rule) in hir.rules.iter().enumerate(){ - for head_rel in rule.head_clauses.iter().map(|hcl| &hcl.rel){ - relations_to_rules_in_head.entry(head_rel).or_default().insert(i); - } - } - - let mut edges = vec![]; - for (i, rule) in hir.rules.iter().enumerate() { - for bitem in rule.body_items.iter() { - if let Some(body_rel) = bitem.rel() { - let body_rel_identity = &body_rel.relation; - if let Some(set) = relations_to_rules_in_head.get(body_rel_identity){ - for &rule_with_rel_in_head in set.iter().sorted(){ - edges.push((rule_with_rel_in_head, i)); - } - } - } - } - } - edges -} - -pub(crate) fn compile_hir_to_mir(hir: &AscentIr) -> syn::Result{ - - let dep_graph = get_hir_dep_graph(hir); - let mut dep_graph = DiGraphMap::<_,()>::from_edges(&dep_graph); - for i in 0..hir.rules.len() {dep_graph.add_node(i);} - let dep_graph = dep_graph.into_graph::(); - // println!("{:?}", Dot::with_config(&dep_graph, &[Config::EdgeNoLabel])); - let mut sccs = condensation(dep_graph, true); - - let mut mir_sccs = vec![]; - for scc in sccs.node_weights().collect_vec().iter().rev(){ - let mut dynamic_relations: HashMap> = HashMap::new(); - let mut body_only_relations: HashMap> = HashMap::new(); - - let mut dynamic_relations_set = HashSet::new(); - for &rule_ind in scc.iter(){ - let rule = &hir.rules[rule_ind]; - for bitem in rule.body_items.iter() { - if let Some(rel) = bitem.rel(){ - body_only_relations.entry(rel.relation.clone()).or_default().insert(rel.clone()); - } - } - - for hcl in hir.rules[rule_ind].head_clauses.iter() { - - dynamic_relations_set.insert(hcl.rel.clone()); - dynamic_relations.entry(hcl.rel.clone()).or_default(); - // TODO why this? - // ... we can add only indices used in bodies in the scc, that requires the codegen to be updated. - for rel_ind in &hir.relations_ir_relations[&hcl.rel]{ - dynamic_relations.get_mut(&hcl.rel).unwrap().insert(rel_ind.clone()); - } - } - - } - - let mut is_looping = false; - for rel in dynamic_relations.keys().cloned().collect_vec() { - if let Some(indices) = body_only_relations.remove(&rel){ - is_looping = true; - for ind in indices { - dynamic_relations.entry(rel.clone()).or_default().insert(ind); - } - } - } - - let rules = scc.iter() - .flat_map(|&ind| compile_hir_rule_to_mir_rules(&hir.rules[ind], &dynamic_relations_set)) - .collect_vec(); - - for rule in rules.iter() { - for bi in rule.body_items.iter() { - if let MirBodyItem::Agg(agg) = bi { - if dynamic_relations.contains_key(&agg.rel.relation) { - return Err(syn::Error::new(agg.span, - format!("use of aggregated relation {} cannot be stratified", &agg.rel.relation.name))); - } - } - } - } - let mir_scc = MirScc{ - rules, - dynamic_relations, - body_only_relations, - is_looping - }; - mir_sccs.push(mir_scc); - } - - sccs.reverse(); - let sccs_nodes_count = sccs.node_indices().count(); - let mut sccs_dep_graph = HashMap::with_capacity(sccs_nodes_count); - for n in sccs.node_indices() { - //the nodes in the sccs graph is in reverse topological order, so we do this - sccs_dep_graph.insert(sccs_nodes_count - n.index() - 1, sccs.neighbors(n).map(|n| sccs_nodes_count - n.index() - 1).collect()); - } - - Ok(AscentMir { - sccs: mir_sccs, - deps: sccs_dep_graph, - relations_ir_relations: hir.relations_ir_relations.clone(), - relations_full_indices: hir.relations_full_indices.clone(), - lattices_full_indices: hir.lattices_full_indices.clone(), - // relations_no_indices: hir.relations_no_indices.clone(), - relations_metadata: hir.relations_metadata.clone(), - signatures: hir.signatures.clone(), - config: hir.config.clone(), - is_parallel: hir.is_parallel, - }) -} - -fn compile_hir_rule_to_mir_rules(rule: &IrRule, dynamic_relations: &HashSet) -> Vec { - - fn versions_base(count: usize) -> Vec> { - if count == 0 { - vec![] - } else { - let mut res = versions_base(count - 1); - for v in &mut res { - v.push(MirRelationVersion::TotalDelta); - } - let mut new_combination = vec![MirRelationVersion::Total; count]; - new_combination[count - 1] = MirRelationVersion::Delta; - res.push(new_combination); - res - } - } - - // TODO is it worth it? - fn versions(dynamic_cls: &[usize], simple_join_start_index: Option) -> Vec> { - fn remove_total_delta_at_index(ind: usize, res: &mut Vec>) { - let mut i = 0; - while i < res.len() { - if res[i].get(ind) == Some(&TotalDelta) { - res.insert(i + 1, res[i].clone()); - res[i][ind] = Total; - res[i + 1][ind] = Delta; - } - i += 1; - } - } - - let count = dynamic_cls.len(); - let mut res = versions_base(count); - let no_total_delta_at_beginning = false; - if no_total_delta_at_beginning { - if let Some(ind) = simple_join_start_index { - remove_total_delta_at_index(ind, &mut res); - remove_total_delta_at_index(ind + 1, &mut res); - } else if dynamic_cls.get(0) == Some(&0) { - remove_total_delta_at_index(0, &mut res); - } - } - res - } - - fn hir_body_item_to_mir_body_item(hir_bitem : &IrBodyItem, version: Option) -> MirBodyItem{ - match hir_bitem { - IrBodyItem::Clause(_) => { }, - _ => assert!(version.is_none()) - } - match hir_bitem{ - IrBodyItem::Clause(hir_bcl) => { - let ver = version.unwrap_or(MirRelationVersion::Total); - let mir_relation = MirRelation::from(hir_bcl.rel.clone(), ver); - let mir_bcl = MirBodyClause{ - rel: mir_relation, - args : hir_bcl.args.clone(), - rel_args_span: hir_bcl.rel_args_span, - args_span: hir_bcl.args_span, - cond_clauses: hir_bcl.cond_clauses.clone() - }; - MirBodyItem::Clause(mir_bcl) - }, - IrBodyItem::Cond(cl) => MirBodyItem::Cond(cl.clone()), - IrBodyItem::Generator(gen) => MirBodyItem::Generator(gen.clone()), - IrBodyItem::Agg(agg) => MirBodyItem::Agg(agg.clone()) - } - } - - - let dynamic_cls = rule.body_items.iter().enumerate().filter_map(|(i, cl)| match cl { - IrBodyItem::Clause(cl) if dynamic_relations.contains(&cl.rel.relation) => Some(i), - _ => None, - }).collect_vec(); - - let version_combinations = if dynamic_cls.is_empty() {vec![vec![]]} else {versions(&dynamic_cls[..], rule.simple_join_start_index)}; - - let mut mir_body_items = Vec::with_capacity(version_combinations.len()); - - for version_combination in version_combinations { - let versions = dynamic_cls.iter().zip(version_combination) - .fold(vec![None; rule.body_items.len()], |mut acc, (i, v)| {acc[*i] = Some(v); acc}); - let mir_bodys = rule.body_items.iter().zip(versions).map(|(bi, v)| hir_body_item_to_mir_body_item(bi, v)) - .collect_vec(); - mir_body_items.push(mir_bodys) - } - - mir_body_items.into_iter().map(|bcls| { - - // rule is reorderable if it is a simple join and the second clause does not depend on items - // before the first clause (e.g., let z = &1, foo(x, y), bar(y, z) is not reorderable) - let reorderable = rule.simple_join_start_index.map_or(false, |ind| { - let pre_first_clause_vars = bcls.iter().take(ind).flat_map(MirBodyItem::bound_vars); - !intersects(pre_first_clause_vars, bcls[ind + 1].bound_vars()) - }); - MirRule { - body_items: bcls, - head_clause: rule.head_clauses.clone(), - simple_join_start_index: rule.simple_join_start_index, - reorderable - } - }).collect() -} +#![deny(warnings)] +use std::collections::{HashMap, HashSet}; +use std::fmt::Write; + +use itertools::Itertools; +use petgraph::algo::condensation; +use petgraph::graphmap::DiGraphMap; +use proc_macro2::{Ident, Span}; +use syn::{Expr, Type}; + +use crate::ascent_hir::{ + AscentConfig, AscentIr, IndexValType, IrAggClause, IrBodyClause, IrBodyItem, IrHeadClause, IrRelation, IrRule, + RelationMetadata, +}; +use crate::ascent_mir::MirRelationVersion::*; +use crate::ascent_syntax::{CondClause, GeneratorNode, RelationIdentity, Signatures}; +use crate::syn_utils::pattern_get_vars; +use crate::utils::{expr_to_ident, intersects, pat_to_ident, tuple_type}; + +pub(crate) struct AscentMir { + pub sccs: Vec, + #[allow(unused)] + pub deps: HashMap>, + pub relations_ir_relations: HashMap>, + pub relations_full_indices: HashMap, + pub relations_metadata: HashMap, + pub lattices_full_indices: HashMap, + pub signatures: Signatures, + pub config: AscentConfig, + pub is_parallel: bool, +} + +pub(crate) struct MirScc { + pub rules: Vec, + pub dynamic_relations: HashMap>, + pub body_only_relations: HashMap>, + pub is_looping: bool, +} + +pub(crate) fn mir_summary(mir: &AscentMir) -> String { + let mut res = String::new(); + for (i, scc) in mir.sccs.iter().enumerate() { + writeln!(&mut res, "scc {}, is_looping: {}:", i, scc.is_looping).unwrap(); + for r in scc.rules.iter() { + writeln!(&mut res, " {}", mir_rule_summary(r)).unwrap(); + } + let sorted_dynamic_relation_keys = scc.dynamic_relations.keys().sorted_by_key(|rel| &rel.name); + write!(&mut res, " dynamic relations: ").unwrap(); + writeln!(&mut res, "{}", sorted_dynamic_relation_keys.map(|r| r.name.to_string()).join(", ")).unwrap(); + } + res +} + +#[derive(Clone)] +pub(crate) struct MirRule { + // TODO rename to head_clauses + pub head_clause: Vec, + pub body_items: Vec, + pub simple_join_start_index: Option, + pub reorderable: bool, +} + +pub(crate) fn mir_rule_summary(rule: &MirRule) -> String { + fn bitem_to_str(bitem: &MirBodyItem) -> String { + match bitem { + MirBodyItem::Clause(bcl) => format!("{}_{}", bcl.rel.ir_name, bcl.rel.version.to_string()), + MirBodyItem::Generator(gen) => + format!("for_{}", pat_to_ident(&gen.pattern).map(|x| x.to_string()).unwrap_or_default()), + MirBodyItem::Cond(CondClause::If(..)) => format!("if ⋯"), + MirBodyItem::Cond(CondClause::IfLet(..)) => format!("if let ⋯"), + MirBodyItem::Cond(CondClause::Let(..)) => format!("let ⋯"), + MirBodyItem::Agg(agg) => format!("agg {}", agg.rel.ir_name()), + } + } + format!( + "{} <-- {}{simple_join}{reorderable}", + rule.head_clause.iter().map(|hcl| hcl.rel.name.to_string()).join(", "), + rule.body_items.iter().map(bitem_to_str).join(", "), + simple_join = if rule.simple_join_start_index.is_some() { " [SIMPLE JOIN]" } else { "" }, + reorderable = if rule.simple_join_start_index.is_some() && !rule.reorderable { " [NOT REORDERABLE]" } else { "" } + ) +} + +#[derive(Clone)] +pub(crate) enum MirBodyItem { + Clause(MirBodyClause), + Generator(GeneratorNode), + Cond(CondClause), + Agg(IrAggClause), +} + +impl MirBodyItem { + pub fn unwrap_clause(&self) -> &MirBodyClause { + match self { + MirBodyItem::Clause(cl) => cl, + _ => panic!("MirBodyItem: unwrap_clause called on non_clause"), + } + } + + pub fn bound_vars(&self) -> Vec { + match self { + MirBodyItem::Clause(cl) => { + let cl_vars = cl.args.iter().filter_map(expr_to_ident); + let cond_cl_vars = cl.cond_clauses.iter().flat_map(|cc| cc.bound_vars()); + cl_vars.chain(cond_cl_vars).collect() + }, + MirBodyItem::Generator(gen) => pattern_get_vars(&gen.pattern), + MirBodyItem::Cond(cond) => cond.bound_vars(), + MirBodyItem::Agg(agg) => pattern_get_vars(&agg.pat), + } + } +} + +#[derive(Clone)] +pub(crate) struct MirBodyClause { + pub rel: MirRelation, + pub args: Vec, + pub rel_args_span: Span, + pub args_span: Span, + pub cond_clauses: Vec, +} +impl MirBodyClause { + pub fn selected_args(&self) -> Vec { self.rel.indices.iter().map(|&i| self.args[i].clone()).collect() } + + /// returns a vec of (var_ind, var) of all the variables in the clause + pub fn vars(&self) -> Vec<(usize, Ident)> { + self.args.iter().enumerate().filter_map(|(i, v)| expr_to_ident(v).map(|v| (i, v))).collect::>() + } + + #[allow(dead_code)] + pub fn from(ir_body_clause: IrBodyClause, rel: MirRelation) -> MirBodyClause { + MirBodyClause { + rel, + args: ir_body_clause.args, + rel_args_span: ir_body_clause.rel_args_span, + args_span: ir_body_clause.args_span, + cond_clauses: ir_body_clause.cond_clauses, + } + } +} + +#[derive(Clone, PartialEq, Eq, Hash)] +pub(crate) struct MirRelation { + pub relation: RelationIdentity, + pub indices: Vec, + pub ir_name: Ident, + pub version: MirRelationVersion, + pub is_full_index: bool, + pub is_no_index: bool, + pub val_type: IndexValType, +} + +pub(crate) fn ir_relation_version_var_name(ir_name: &Ident, version: MirRelationVersion) -> Ident { + let name = format!("{}_{}", ir_name, version.to_string()); + Ident::new(&name, ir_name.span()) +} + +impl MirRelation { + pub fn var_name(&self) -> Ident { ir_relation_version_var_name(&self.ir_name, self.version) } + + #[allow(dead_code)] + pub fn key_type(&self) -> Type { + let index_types: Vec<_> = self.indices.iter().map(|&i| self.relation.field_types[i].clone()).collect(); + tuple_type(&index_types) + } + + pub fn from(ir_relation: IrRelation, version: MirRelationVersion) -> MirRelation { + MirRelation { + ir_name: ir_relation.ir_name(), + is_full_index: ir_relation.is_full_index(), + is_no_index: ir_relation.is_no_index(), + relation: ir_relation.relation, + indices: ir_relation.indices, + version, + val_type: ir_relation.val_type, + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub(crate) enum MirRelationVersion { + TotalDelta, + Total, + Delta, + New, +} + +impl MirRelationVersion { + pub fn to_string(self) -> &'static str { + use MirRelationVersion::*; + match self { + TotalDelta => "total+delta", + Delta => "delta", + Total => "total", + New => "new", + } + } +} + +fn get_hir_dep_graph(hir: &AscentIr) -> Vec<(usize, usize)> { + let mut relations_to_rules_in_head: HashMap<&RelationIdentity, HashSet> = + HashMap::with_capacity(hir.rules.len()); + for (i, rule) in hir.rules.iter().enumerate() { + for head_rel in rule.head_clauses.iter().map(|hcl| &hcl.rel) { + relations_to_rules_in_head.entry(head_rel).or_default().insert(i); + } + } + + let mut edges = vec![]; + for (i, rule) in hir.rules.iter().enumerate() { + for bitem in rule.body_items.iter() { + if let Some(body_rel) = bitem.rel() { + let body_rel_identity = &body_rel.relation; + if let Some(set) = relations_to_rules_in_head.get(body_rel_identity) { + for &rule_with_rel_in_head in set.iter().sorted() { + edges.push((rule_with_rel_in_head, i)); + } + } + } + } + } + edges +} + +pub(crate) fn compile_hir_to_mir(hir: &AscentIr) -> syn::Result { + let dep_graph = get_hir_dep_graph(hir); + let mut dep_graph = DiGraphMap::<_, ()>::from_edges(&dep_graph); + for i in 0..hir.rules.len() { + dep_graph.add_node(i); + } + let dep_graph = dep_graph.into_graph::(); + // println!("{:?}", Dot::with_config(&dep_graph, &[Config::EdgeNoLabel])); + let mut sccs = condensation(dep_graph, true); + + let mut mir_sccs = vec![]; + for scc in sccs.node_weights().collect_vec().iter().rev() { + let mut dynamic_relations: HashMap> = HashMap::new(); + let mut body_only_relations: HashMap> = HashMap::new(); + + let mut dynamic_relations_set = HashSet::new(); + for &rule_ind in scc.iter() { + let rule = &hir.rules[rule_ind]; + for bitem in rule.body_items.iter() { + if let Some(rel) = bitem.rel() { + body_only_relations.entry(rel.relation.clone()).or_default().insert(rel.clone()); + } + } + + for hcl in hir.rules[rule_ind].head_clauses.iter() { + dynamic_relations_set.insert(hcl.rel.clone()); + dynamic_relations.entry(hcl.rel.clone()).or_default(); + // TODO why this? + // ... we can add only indices used in bodies in the scc, that requires the codegen to be updated. + for rel_ind in &hir.relations_ir_relations[&hcl.rel] { + dynamic_relations.get_mut(&hcl.rel).unwrap().insert(rel_ind.clone()); + } + } + } + + let mut is_looping = false; + for rel in dynamic_relations.keys().cloned().collect_vec() { + if let Some(indices) = body_only_relations.remove(&rel) { + is_looping = true; + for ind in indices { + dynamic_relations.entry(rel.clone()).or_default().insert(ind); + } + } + } + + let rules = scc + .iter() + .flat_map(|&ind| compile_hir_rule_to_mir_rules(&hir.rules[ind], &dynamic_relations_set)) + .collect_vec(); + + for rule in rules.iter() { + for bi in rule.body_items.iter() { + if let MirBodyItem::Agg(agg) = bi { + if dynamic_relations.contains_key(&agg.rel.relation) { + return Err(syn::Error::new( + agg.span, + format!("use of aggregated relation {} cannot be stratified", &agg.rel.relation.name), + )); + } + } + } + } + let mir_scc = MirScc { rules, dynamic_relations, body_only_relations, is_looping }; + mir_sccs.push(mir_scc); + } + + sccs.reverse(); + let sccs_nodes_count = sccs.node_indices().count(); + let mut sccs_dep_graph = HashMap::with_capacity(sccs_nodes_count); + for n in sccs.node_indices() { + //the nodes in the sccs graph is in reverse topological order, so we do this + sccs_dep_graph.insert( + sccs_nodes_count - n.index() - 1, + sccs.neighbors(n).map(|n| sccs_nodes_count - n.index() - 1).collect(), + ); + } + + Ok(AscentMir { + sccs: mir_sccs, + deps: sccs_dep_graph, + relations_ir_relations: hir.relations_ir_relations.clone(), + relations_full_indices: hir.relations_full_indices.clone(), + lattices_full_indices: hir.lattices_full_indices.clone(), + // relations_no_indices: hir.relations_no_indices.clone(), + relations_metadata: hir.relations_metadata.clone(), + signatures: hir.signatures.clone(), + config: hir.config.clone(), + is_parallel: hir.is_parallel, + }) +} + +fn compile_hir_rule_to_mir_rules(rule: &IrRule, dynamic_relations: &HashSet) -> Vec { + fn versions_base(count: usize) -> Vec> { + if count == 0 { + vec![] + } else { + let mut res = versions_base(count - 1); + for v in &mut res { + v.push(MirRelationVersion::TotalDelta); + } + let mut new_combination = vec![MirRelationVersion::Total; count]; + new_combination[count - 1] = MirRelationVersion::Delta; + res.push(new_combination); + res + } + } + + // TODO is it worth it? + fn versions(dynamic_cls: &[usize], simple_join_start_index: Option) -> Vec> { + fn remove_total_delta_at_index(ind: usize, res: &mut Vec>) { + let mut i = 0; + while i < res.len() { + if res[i].get(ind) == Some(&TotalDelta) { + res.insert(i + 1, res[i].clone()); + res[i][ind] = Total; + res[i + 1][ind] = Delta; + } + i += 1; + } + } + + let count = dynamic_cls.len(); + let mut res = versions_base(count); + let no_total_delta_at_beginning = false; + if no_total_delta_at_beginning { + if let Some(ind) = simple_join_start_index { + remove_total_delta_at_index(ind, &mut res); + remove_total_delta_at_index(ind + 1, &mut res); + } else if dynamic_cls.get(0) == Some(&0) { + remove_total_delta_at_index(0, &mut res); + } + } + res + } + + fn hir_body_item_to_mir_body_item(hir_bitem: &IrBodyItem, version: Option) -> MirBodyItem { + match hir_bitem { + IrBodyItem::Clause(_) => {}, + _ => assert!(version.is_none()), + } + match hir_bitem { + IrBodyItem::Clause(hir_bcl) => { + let ver = version.unwrap_or(MirRelationVersion::Total); + let mir_relation = MirRelation::from(hir_bcl.rel.clone(), ver); + let mir_bcl = MirBodyClause { + rel: mir_relation, + args: hir_bcl.args.clone(), + rel_args_span: hir_bcl.rel_args_span, + args_span: hir_bcl.args_span, + cond_clauses: hir_bcl.cond_clauses.clone(), + }; + MirBodyItem::Clause(mir_bcl) + }, + IrBodyItem::Cond(cl) => MirBodyItem::Cond(cl.clone()), + IrBodyItem::Generator(gen) => MirBodyItem::Generator(gen.clone()), + IrBodyItem::Agg(agg) => MirBodyItem::Agg(agg.clone()), + } + } + + let dynamic_cls = rule + .body_items + .iter() + .enumerate() + .filter_map(|(i, cl)| match cl { + IrBodyItem::Clause(cl) if dynamic_relations.contains(&cl.rel.relation) => Some(i), + _ => None, + }) + .collect_vec(); + + let version_combinations = + if dynamic_cls.is_empty() { vec![vec![]] } else { versions(&dynamic_cls[..], rule.simple_join_start_index) }; + + let mut mir_body_items = Vec::with_capacity(version_combinations.len()); + + for version_combination in version_combinations { + let versions = + dynamic_cls.iter().zip(version_combination).fold(vec![None; rule.body_items.len()], |mut acc, (i, v)| { + acc[*i] = Some(v); + acc + }); + let mir_bodys = + rule.body_items.iter().zip(versions).map(|(bi, v)| hir_body_item_to_mir_body_item(bi, v)).collect_vec(); + mir_body_items.push(mir_bodys) + } + + mir_body_items + .into_iter() + .map(|bcls| { + // rule is reorderable if it is a simple join and the second clause does not depend on items + // before the first clause (e.g., let z = &1, foo(x, y), bar(y, z) is not reorderable) + let reorderable = rule.simple_join_start_index.map_or(false, |ind| { + let pre_first_clause_vars = bcls.iter().take(ind).flat_map(MirBodyItem::bound_vars); + !intersects(pre_first_clause_vars, bcls[ind + 1].bound_vars()) + }); + MirRule { + body_items: bcls, + head_clause: rule.head_clauses.clone(), + simple_join_start_index: rule.simple_join_start_index, + reorderable, + } + }) + .collect() +} diff --git a/ascent_macro/src/ascent_syntax.rs b/ascent_macro/src/ascent_syntax.rs index 2ac2f34..655ac49 100644 --- a/ascent_macro/src/ascent_syntax.rs +++ b/ascent_macro/src/ascent_syntax.rs @@ -1,30 +1,29 @@ #![deny(warnings)] extern crate proc_macro; +use std::collections::{HashMap, HashSet}; +use std::sync::Mutex; + use ascent_base::util::update; +use derive_syn_parse::Parse; +use itertools::Itertools; use proc_macro2::{Span, TokenStream}; -use syn::{ImplGenerics, TypeGenerics}; +use quote::ToTokens; +use syn::parse::{Parse, ParseStream, Parser}; +use syn::punctuated::Punctuated; +use syn::spanned::Spanned; use syn::{ - braced, parenthesized, parse2, punctuated::Punctuated, spanned::Spanned, Attribute, Error, Expr, ExprMacro, - Generics, Ident, Pat, Result, Token, Type, Visibility, - WhereClause, + Attribute, Error, Expr, ExprMacro, Generics, Ident, ImplGenerics, Pat, Result, Token, Type, TypeGenerics, + Visibility, WhereClause, braced, parenthesized, parse2, }; -use syn::parse::{Parse, ParseStream, Parser}; -use std::{collections::{HashMap, HashSet}, sync::Mutex}; - -use quote::ToTokens; -use itertools::Itertools; -use derive_syn_parse::Parse; -use crate::utils::{ - expr_to_ident, expr_to_ident_mut, flatten_punctuated, is_wild_card, pat_to_ident, - punctuated_map, punctuated_singleton, punctuated_try_map, punctuated_try_unwrap, spans_eq, - token_stream_replace_macro_idents, Piper, -}; use crate::syn_utils::{ - expr_get_vars, expr_visit_free_vars_mut, expr_visit_idents_in_macros_mut, - pattern_get_vars, pattern_visit_vars_mut, token_stream_idents, token_stream_replace_ident, + expr_get_vars, expr_visit_free_vars_mut, expr_visit_idents_in_macros_mut, pattern_get_vars, pattern_visit_vars_mut, + token_stream_idents, token_stream_replace_ident, +}; +use crate::utils::{ + Piper, expr_to_ident, expr_to_ident_mut, flatten_punctuated, is_wild_card, pat_to_ident, punctuated_map, + punctuated_singleton, punctuated_try_map, punctuated_try_unwrap, spans_eq, token_stream_replace_macro_idents, }; - // resources: // https://blog.rust-lang.org/2018/12/21/Procedural-Macros-in-Rust-2018.html @@ -67,11 +66,7 @@ impl Signatures { impl Parse for Signatures { fn parse(input: ParseStream) -> Result { let declaration = TypeSignature::parse(input)?; - let implementation = if input.peek(Token![impl]) { - Some(ImplSignature::parse(input)?) - } else { - None - }; + let implementation = if input.peek(Token![impl]) { Some(ImplSignature::parse(input)?) } else { None }; Ok(Signatures { declaration, implementation }) } } @@ -86,7 +81,7 @@ pub struct TypeSignature { pub ident: Ident, #[call(parse_generics_with_where_clause)] pub generics: Generics, - pub _semi: Token![;] + pub _semi: Token![;], } #[derive(Clone, Parse, Debug)] @@ -96,7 +91,7 @@ pub struct ImplSignature { pub ident: Ident, #[call(parse_generics_with_where_clause)] pub generics: Generics, - pub _semi: Token![;] + pub _semi: Token![;], } /// Parse impl on Generics does not parse WhereClauses, hence this function @@ -109,10 +104,10 @@ fn parse_generics_with_where_clause(input: ParseStream) -> Result { } // #[derive(Clone)] -pub struct RelationNode{ - pub attrs : Vec, +pub struct RelationNode { + pub attrs: Vec, pub name: Ident, - pub field_types : Punctuated, + pub field_types: Punctuated, pub initialization: Option, pub _semi_colon: Token![;], pub is_lattice: bool, @@ -120,21 +115,27 @@ pub struct RelationNode{ impl Parse for RelationNode { fn parse(input: ParseStream) -> Result { let is_lattice = input.peek(kw::lattice); - if is_lattice {input.parse::()?;} else {input.parse::()?;} - let name : Ident = input.parse()?; + if is_lattice { + input.parse::()?; + } else { + input.parse::()?; + } + let name: Ident = input.parse()?; let content; parenthesized!(content in input); let field_types = content.parse_terminated(Type::parse, Token![,])?; let initialization = if input.peek(Token![=]) { input.parse::()?; Some(input.parse::()?) - } else {None}; + } else { + None + }; let _semi_colon = input.parse::()?; if is_lattice && field_types.empty_or_trailing() { return Err(input.error("empty lattice is not allowed")); } - Ok(RelationNode{attrs: vec![], name, field_types, _semi_colon, is_lattice, initialization}) + Ok(RelationNode { attrs: vec![], name, field_types, _semi_colon, is_lattice, initialization }) } } @@ -160,9 +161,7 @@ fn peek_macro_invocation(parse_stream: ParseStream) -> bool { parse_stream.peek(Ident) && parse_stream.peek2(Token![!]) } -fn peek_if_or_let(parse_stream: ParseStream) -> bool { - parse_stream.peek(Token![if]) || parse_stream.peek(Token![let]) -} +fn peek_if_or_let(parse_stream: ParseStream) -> bool { parse_stream.peek(Token![if]) || parse_stream.peek(Token![let]) } #[derive(Clone)] pub struct DisjunctionNode { @@ -175,28 +174,28 @@ impl Parse for DisjunctionNode { let content; let paren = parenthesized!(content in input); let res: Punctuated, Token![||]> = - Punctuated::, Token![||]>::parse_terminated_with(&content, Punctuated::::parse_separated_nonempty)?; - Ok(DisjunctionNode{paren, disjuncts: res}) + Punctuated::, Token![||]>::parse_terminated_with( + &content, + Punctuated::::parse_separated_nonempty, + )?; + Ok(DisjunctionNode { paren, disjuncts: res }) } } - - - #[derive(Parse, Clone)] pub struct GeneratorNode { pub for_keyword: Token![for], #[call(Pat::parse_multi)] pub pattern: Pat, pub _in_keyword: Token![in], - pub expr: Expr + pub expr: Expr, } #[derive(Clone)] pub struct BodyClauseNode { - pub rel : Ident, - pub args : Punctuated, - pub cond_clauses: Vec + pub rel: Ident, + pub args: Punctuated, + pub cond_clauses: Vec, } #[derive(Parse, Clone, PartialEq, Eq, Debug)] @@ -211,14 +210,14 @@ impl BodyClauseArg { pub fn unwrap_expr(self) -> Expr { match self { Self::Expr(exp) => exp, - Self::Pat(_) => panic!("unwrap_expr(): BodyClauseArg is not an expr") + Self::Pat(_) => panic!("unwrap_expr(): BodyClauseArg is not an expr"), } } pub fn unwrap_expr_ref(&self) -> &Expr { match self { Self::Expr(exp) => exp, - Self::Pat(_) => panic!("unwrap_expr(): BodyClauseArg is not an expr") + Self::Pat(_) => panic!("unwrap_expr(): BodyClauseArg is not an expr"), } } @@ -231,7 +230,7 @@ impl BodyClauseArg { } impl ToTokens for BodyClauseArg { fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { - match self{ + match self { BodyClauseArg::Pat(pat) => { pat.huh_token.to_tokens(tokens); pat.pattern.to_tokens(tokens); @@ -245,7 +244,7 @@ impl ToTokens for BodyClauseArg { pub struct ClauseArgPattern { pub huh_token: Token![?], #[call(Pat::parse_multi)] - pub pattern : Pat, + pub pattern: Pat, } #[derive(Parse, Clone, PartialEq, Eq, Hash, Debug)] @@ -254,14 +253,14 @@ pub struct IfLetClause { pub let_keyword: Token![let], #[call(Pat::parse_multi)] pub pattern: Pat, - pub eq_symbol : Token![=], + pub eq_symbol: Token![=], pub exp: syn::Expr, } #[derive(Parse, Clone, PartialEq, Eq, Hash, Debug)] pub struct IfClause { pub if_keyword: Token![if], - pub cond: Expr + pub cond: Expr, } #[derive(Parse, Clone, PartialEq, Eq, Hash, Debug)] @@ -269,7 +268,7 @@ pub struct LetClause { pub let_keyword: Token![let], #[call(Pat::parse_multi)] pub pattern: Pat, - pub eq_symbol : Token![=], + pub eq_symbol: Token![=], pub exp: syn::Expr, } @@ -283,16 +282,16 @@ pub enum CondClause { impl CondClause { pub fn bound_vars(&self) -> Vec { match self { - CondClause::IfLet(cl) => pattern_get_vars(&cl.pattern), - CondClause::If(_) => vec![], - CondClause::Let(cl) => pattern_get_vars(&cl.pattern), + CondClause::IfLet(cl) => pattern_get_vars(&cl.pattern), + CondClause::If(_) => vec![], + CondClause::Let(cl) => pattern_get_vars(&cl.pattern), } } - /// returns the expression associated with the CondClause. + /// returns the expression associated with the CondClause. /// Useful for determining clause dependencies pub fn expr(&self) -> &Expr { - match self { + match self { CondClause::IfLet(cl) => &cl.exp, CondClause::If(cl) => &cl.cond, CondClause::Let(cl) => &cl.exp, @@ -325,32 +324,31 @@ impl Parse for CondClause { // } // } -impl Parse for BodyClauseNode{ +impl Parse for BodyClauseNode { fn parse(input: ParseStream) -> Result { - let rel : Ident = input.parse()?; + let rel: Ident = input.parse()?; let args_content; parenthesized!(args_content in input); let args = args_content.parse_terminated(BodyClauseArg::parse, Token![,])?; let mut cond_clauses = vec![]; - while let Ok(cl) = input.parse(){ + while let Ok(cl) = input.parse() { cond_clauses.push(cl); } - Ok(BodyClauseNode{rel, args, cond_clauses}) + Ok(BodyClauseNode { rel, args, cond_clauses }) } } #[derive(Parse, Clone)] pub struct NegationClauseNode { neg_token: Token![!], - pub rel : Ident, + pub rel: Ident, #[paren] _rel_arg_paren: syn::token::Paren, #[inside(_rel_arg_paren)] #[call(Punctuated::parse_terminated)] - pub args : Punctuated, + pub args: Punctuated, } - #[derive(Clone, Parse)] pub enum HeadItemNode { #[peek_with(peek_macro_invocation, name = "macro invocation")] @@ -370,8 +368,8 @@ impl HeadItemNode { #[derive(Clone)] pub struct HeadClauseNode { - pub rel : Ident, - pub args : Punctuated, + pub rel: Ident, + pub args: Punctuated, } impl ToTokens for HeadClauseNode { fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { @@ -380,13 +378,13 @@ impl ToTokens for HeadClauseNode { } } -impl Parse for HeadClauseNode{ +impl Parse for HeadClauseNode { fn parse(input: ParseStream) -> Result { - let rel : Ident = input.parse()?; + let rel: Ident = input.parse()?; let args_content; parenthesized!(args_content in input); let args = args_content.parse_terminated(Expr::parse, Token![,])?; - Ok(HeadClauseNode{rel, args}) + Ok(HeadClauseNode { rel, args }) } } @@ -403,18 +401,18 @@ pub struct AggClauseNode { #[call(Punctuated::parse_terminated)] pub bound_args: Punctuated, pub _in_kw: Token![in], - pub rel : Ident, + pub rel: Ident, #[paren] _rel_arg_paren: syn::token::Paren, #[inside(_rel_arg_paren)] #[call(Punctuated::parse_terminated)] - pub rel_args : Punctuated + pub rel_args: Punctuated, } #[derive(Clone)] pub enum AggregatorNode { Path(syn::Path), - Expr(Expr) + Expr(Expr), } impl Parse for AggregatorNode { @@ -431,7 +429,7 @@ impl Parse for AggregatorNode { impl AggregatorNode { pub fn get_expr(&self) -> Expr { match self { - AggregatorNode::Path(path) => parse2(quote!{#path}).unwrap(), + AggregatorNode::Path(path) => parse2(quote! {#path}).unwrap(), AggregatorNode::Expr(expr) => expr.clone(), } } @@ -439,7 +437,7 @@ impl AggregatorNode { pub struct RuleNode { pub head_clauses: Punctuated, - pub body_items: Vec// Punctuated, + pub body_items: Vec, // Punctuated, } impl Parse for RuleNode { @@ -452,10 +450,10 @@ impl Parse for RuleNode { Punctuated::::parse_separated_nonempty(input)? }; - if input.peek(Token![;]){ + if input.peek(Token![;]) { // println!("fact rule!!!"); input.parse::()?; - Ok(RuleNode{head_clauses, body_items: vec![] /*Punctuated::default()*/}) + Ok(RuleNode { head_clauses, body_items: vec![] /*Punctuated::default()*/ }) } else { input.parse::()?; input.parse::()?; @@ -465,7 +463,7 @@ impl Parse for RuleNode { let body_items = Punctuated::::parse_separated_nonempty(input)?; input.parse::()?; - Ok(RuleNode{ head_clauses, body_items: body_items.into_iter().collect()}) + Ok(RuleNode { head_clauses, body_items: body_items.into_iter().collect() }) } } } @@ -475,7 +473,8 @@ impl Parse for RuleNode { pub(crate) fn rule_node_summary(rule: &RuleNode) -> String { fn bitem_to_str(bitem: &BodyItemNode) -> String { match bitem { - BodyItemNode::Generator(gen) => format!("for_{}", pat_to_ident(&gen.pattern).map(|x| x.to_string()).unwrap_or_default()), + BodyItemNode::Generator(gen) => + format!("for_{}", pat_to_ident(&gen.pattern).map(|x| x.to_string()).unwrap_or_default()), BodyItemNode::Clause(bcl) => format!("{}", bcl.rel), BodyItemNode::Disjunction(_) => todo!(), BodyItemNode::Cond(_cl) => format!("if_"), @@ -490,9 +489,11 @@ pub(crate) fn rule_node_summary(rule: &RuleNode) -> String { HeadItemNode::HeadClause(cl) => cl.rel.to_string(), } } - format!("{} <-- {}", - rule.head_clauses.iter().map(hitem_to_str).join(", "), - rule.body_items.iter().map(bitem_to_str).join(", ")) + format!( + "{} <-- {}", + rule.head_clauses.iter().map(hitem_to_str).join(", "), + rule.body_items.iter().map(bitem_to_str).join(", ") + ) } #[derive(Parse)] @@ -500,7 +501,7 @@ pub struct MacroDefParam { _dollar: Token![$], name: Ident, _colon: Token![:], - kind: MacroParamKind + kind: MacroParamKind, } #[derive(Parse)] @@ -509,7 +510,7 @@ pub enum MacroParamKind { #[peek(kw::ident, name = "ident")] Expr(Ident), #[peek(kw::expr, name = "expr")] - Ident(Ident) + Ident(Ident), } #[derive(Parse)] @@ -529,11 +530,11 @@ pub struct MacroDefNode { // #[derive(Clone)] pub(crate) struct AscentProgram { - pub rules : Vec, - pub relations : Vec, + pub rules: Vec, + pub relations: Vec, pub signatures: Option, pub attributes: Vec, - pub macros: Vec + pub macros: Vec, } impl Parse for AscentProgram { @@ -551,8 +552,9 @@ impl Parse for AscentProgram { let mut relations = vec![]; let mut macros = vec![]; while !input.is_empty() { - let attrs = if !struct_attrs.is_empty() {std::mem::take(&mut struct_attrs)} else {Attribute::parse_outer(input)?}; - if input.peek(kw::relation) || input.peek(kw::lattice){ + let attrs = + if !struct_attrs.is_empty() { std::mem::take(&mut struct_attrs) } else { Attribute::parse_outer(input)? }; + if input.peek(kw::relation) || input.peek(kw::lattice) { let mut relation_node = RelationNode::parse(input)?; relation_node.attrs = attrs; relations.push(relation_node); @@ -568,7 +570,7 @@ impl Parse for AscentProgram { rules.push(RuleNode::parse(input)?); } } - Ok(AscentProgram{rules, relations, signatures, attributes, macros}) + Ok(AscentProgram { rules, relations, signatures, attributes, macros }) } } @@ -579,15 +581,15 @@ pub(crate) struct RelationIdentity { pub is_lattice: bool, } -impl From<&RelationNode> for RelationIdentity{ +impl From<&RelationNode> for RelationIdentity { fn from(relation_node: &RelationNode) -> Self { RelationIdentity { name: relation_node.name.clone(), field_types: relation_node.field_types.iter().cloned().collect(), - is_lattice: relation_node.is_lattice + is_lattice: relation_node.is_lattice, } } -} +} #[derive(Clone)] pub(crate) struct DsAttributeContents { @@ -623,19 +625,19 @@ fn rule_desugar_disjunction_nodes(rule: RuleNode) -> Vec { BodyItemNode::Disjunction(d) => { let mut res = vec![]; for disjunt in d.disjuncts.iter() { - for conjunction in bitems_desugar(&disjunt.iter().cloned().collect_vec()){ + for conjunction in bitems_desugar(&disjunt.iter().cloned().collect_vec()) { res.push(conjunction); } - } - res + } + res }, - BodyItemNode::MacroInvocation(m) => panic!("unexpected macro invocation: {:?}", m.mac.path), + BodyItemNode::MacroInvocation(m) => panic!("unexpected macro invocation: {:?}", m.mac.path), } } fn bitems_desugar(bitems: &[BodyItemNode]) -> Vec> { let mut res = vec![]; if !bitems.is_empty() { - let sub_res = bitems_desugar(&bitems[0 .. bitems.len() - 1]); + let sub_res = bitems_desugar(&bitems[0..bitems.len() - 1]); let last_desugared = bitem_desugar(&bitems[bitems.len() - 1]); for sub_res_item in sub_res.into_iter() { for last_item in last_desugared.iter() { @@ -652,11 +654,8 @@ fn rule_desugar_disjunction_nodes(rule: RuleNode) -> Vec { } let mut res = vec![]; - for conjunction in bitems_desugar(&rule.body_items){ - res.push(RuleNode { - body_items: conjunction, - head_clauses: rule.head_clauses.clone() - }) + for conjunction in bitems_desugar(&rule.body_items) { + res.push(RuleNode { body_items: conjunction, head_clauses: rule.head_clauses.clone() }) } res } @@ -667,9 +666,8 @@ fn body_item_get_bound_vars(bi: &BodyItemNode) -> Vec { BodyItemNode::Agg(agg) => pattern_get_vars(&agg.pat), BodyItemNode::Clause(cl) => cl.args.iter().flat_map(|arg| arg.get_vars()).collect(), BodyItemNode::Negation(_cl) => vec![], - BodyItemNode::Disjunction(disj) => disj.disjuncts.iter() - .flat_map(|conj| conj.iter().flat_map(body_item_get_bound_vars)) - .collect(), + BodyItemNode::Disjunction(disj) => + disj.disjuncts.iter().flat_map(|conj| conj.iter().flat_map(body_item_get_bound_vars)).collect(), BodyItemNode::Cond(cl) => cl.bound_vars(), BodyItemNode::MacroInvocation(_) => vec![], } @@ -679,22 +677,23 @@ fn body_item_visit_bound_vars_mut(bi: &mut BodyItemNode, visitor: &mut dyn FnMut match bi { BodyItemNode::Generator(gen) => pattern_visit_vars_mut(&mut gen.pattern, visitor), BodyItemNode::Agg(agg) => pattern_visit_vars_mut(&mut agg.pat, visitor), - BodyItemNode::Clause(cl) => { + BodyItemNode::Clause(cl) => for arg in cl.args.iter_mut() { match arg { BodyClauseArg::Pat(p) => pattern_visit_vars_mut(&mut p.pattern, visitor), - BodyClauseArg::Expr(e) => if let Some(ident) = expr_to_ident_mut(e) {visitor(ident)}, + BodyClauseArg::Expr(e) => + if let Some(ident) = expr_to_ident_mut(e) { + visitor(ident) + }, } - } - }, - BodyItemNode::Negation(_cl) =>(), - BodyItemNode::Disjunction(disj) =>{ + }, + BodyItemNode::Negation(_cl) => (), + BodyItemNode::Disjunction(disj) => for conj in disj.disjuncts.iter_mut() { for bi in conj.iter_mut() { body_item_visit_bound_vars_mut(bi, visitor) } - } - }, + }, BodyItemNode::Cond(cl) => match cl { CondClause::IfLet(cl) => pattern_visit_vars_mut(&mut cl.pattern, visitor), CondClause::If(_cl) => (), @@ -704,7 +703,9 @@ fn body_item_visit_bound_vars_mut(bi: &mut BodyItemNode, visitor: &mut dyn FnMut } } -fn body_item_visit_exprs_free_vars_mut(bi: &mut BodyItemNode, visitor: &mut dyn FnMut(&mut Ident), visit_macro_idents: bool){ +fn body_item_visit_exprs_free_vars_mut( + bi: &mut BodyItemNode, visitor: &mut dyn FnMut(&mut Ident), visit_macro_idents: bool, +) { let mut visit = |expr: &mut Expr| { expr_visit_free_vars_mut(expr, visitor); if visit_macro_idents { @@ -714,28 +715,29 @@ fn body_item_visit_exprs_free_vars_mut(bi: &mut BodyItemNode, visitor: &mut dyn match bi { BodyItemNode::Generator(gen) => visit(&mut gen.expr), BodyItemNode::Agg(agg) => { - for arg in agg.rel_args.iter_mut() {visit(arg)} - if let AggregatorNode::Expr(e) = &mut agg.aggregator {visit(e)} + for arg in agg.rel_args.iter_mut() { + visit(arg) + } + if let AggregatorNode::Expr(e) = &mut agg.aggregator { + visit(e) + } }, - BodyItemNode::Clause(cl) => { + BodyItemNode::Clause(cl) => for arg in cl.args.iter_mut() { if let BodyClauseArg::Expr(e) = arg { visit(e); } - } - }, - BodyItemNode::Negation(cl) => { + }, + BodyItemNode::Negation(cl) => for arg in cl.args.iter_mut() { visit(arg); - } - }, - BodyItemNode::Disjunction(disj) => { + }, + BodyItemNode::Disjunction(disj) => for conj in disj.disjuncts.iter_mut() { for bi in conj.iter_mut() { body_item_visit_exprs_free_vars_mut(bi, visitor, visit_macro_idents); } - } - }, + }, BodyItemNode::Cond(cl) => match cl { CondClause::IfLet(cl) => visit(&mut cl.exp), CondClause::If(cl) => visit(&mut cl.cond), @@ -752,35 +754,38 @@ struct GenSym(HashMap, fn(&str) -> String); impl GenSym { pub fn next(&mut self, ident: &str) -> String { match self.0.get_mut(ident) { - Some(n) => {*n += 1; format!("{}{}", self.1(ident), *n - 1)}, - None => {self.0.insert(ident.into(), 1); self.1(ident)}, + Some(n) => { + *n += 1; + format!("{}{}", self.1(ident), *n - 1) + }, + None => { + self.0.insert(ident.into(), 1); + self.1(ident) + }, } } - pub fn next_ident(&mut self, ident: &str, span: Span) -> Ident { - Ident::new(&self.next(ident), span) - } - pub fn new(transformer: fn(&str) -> String) -> Self { - Self(Default::default(), transformer) - } + pub fn next_ident(&mut self, ident: &str, span: Span) -> Ident { Ident::new(&self.next(ident), span) } + pub fn new(transformer: fn(&str) -> String) -> Self { Self(Default::default(), transformer) } } + impl Default for GenSym { - fn default() -> Self { - Self(Default::default(), |x| format!("{}_", x)) - } + fn default() -> Self { Self(Default::default(), |x| format!("{}_", x)) } } -fn body_items_rename_macro_originated_vars(bis: &mut [&mut BodyItemNode], macro_def: &MacroDefNode, gensym: &mut GenSym) { +fn body_items_rename_macro_originated_vars( + bis: &mut [&mut BodyItemNode], macro_def: &MacroDefNode, gensym: &mut GenSym, +) { let bi_vars = bis.iter().flat_map(|bi| body_item_get_bound_vars(bi)).collect_vec(); let mut mac_body_idents = token_stream_idents(macro_def.body.clone()); mac_body_idents.retain(|ident| bi_vars.contains(ident)); - let macro_originated_vars = bi_vars.iter() + let macro_originated_vars = bi_vars + .iter() .filter(|v| mac_body_idents.iter().any(|ident| spans_eq(&v.span(), &ident.span()))) .cloned() .collect::>(); - - let var_mappings = macro_originated_vars.iter() - .map(|v| (v, gensym.next(&v.to_string()))).collect::>(); + + let var_mappings = macro_originated_vars.iter().map(|v| (v, gensym.next(&v.to_string()))).collect::>(); let mut visitor = |ident: &mut Ident| { if let Some(replacement) = var_mappings.get(ident) { if mac_body_idents.iter().any(|mac_ident| spans_eq(&mac_ident.span(), &ident.span())) { @@ -805,50 +810,53 @@ fn rule_desugar_pattern_args(rule: RuleNode) -> RuleNode { BodyClauseArg::Pat(pat) => { let pattern = pat.pattern; let ident = gensym.next_ident("__arg_pattern", pattern.span()); - let new_cond_clause = quote!{ if let #pattern = #ident}; + let new_cond_clause = quote! { if let #pattern = #ident}; let new_cond_clause = CondClause::IfLet(syn::parse2(new_cond_clause).unwrap()); new_cond_clauses.push(new_cond_clause); - BodyClauseArg::Expr(syn::parse2(quote!{#ident}).unwrap()) - } + BodyClauseArg::Expr(syn::parse2(quote! {#ident}).unwrap()) + }, }; new_args.push_value(new_arg); - if let Some(punc) = punc {new_args.push_punct(punc)} + if let Some(punc) = punc { + new_args.push_punct(punc) + } } new_cond_clauses.extend(body_clause.cond_clauses); - BodyClauseNode{ - args: new_args, - cond_clauses: new_cond_clauses, - rel: body_clause.rel - } + BodyClauseNode { args: new_args, cond_clauses: new_cond_clauses, rel: body_clause.rel } } let mut gensym = GenSym::default(); - use BodyItemNode::*; RuleNode { - body_items: rule.body_items.into_iter().map(|bi| match bi { - Clause(cl) => Clause(clause_desugar_pattern_args(cl, &mut gensym)), - _ => bi}).collect(), - head_clauses: rule.head_clauses + body_items: rule + .body_items + .into_iter() + .map(|bi| match bi { + BodyItemNode::Clause(cl) => BodyItemNode::Clause(clause_desugar_pattern_args(cl, &mut gensym)), + _ => bi, + }) + .collect(), + head_clauses: rule.head_clauses, } } fn rule_desugar_repeated_vars(mut rule: RuleNode) -> RuleNode { - let mut grounded_vars = HashMap::::new(); - for i in 0..rule.body_items.len(){ + for i in 0..rule.body_items.len() { let bitem = &mut rule.body_items[i]; match bitem { BodyItemNode::Clause(cl) => { let mut new_cond_clauses = vec![]; for arg_ind in 0..cl.args.len() { let expr = cl.args[arg_ind].unwrap_expr_ref(); - let expr_has_vars_from_same_clause = - expr_get_vars(expr).iter() - .any(|var| if let Some(cl_ind) = grounded_vars.get(var) {*cl_ind == i} else {false}); + let expr_has_vars_from_same_clause = expr_get_vars(expr) + .iter() + .any(|var| if let Some(cl_ind) = grounded_vars.get(var) { *cl_ind == i } else { false }); if expr_has_vars_from_same_clause { - let new_ident = fresh_ident(&expr_to_ident(expr).map(|e| e.to_string()).unwrap_or_else(|| "expr_replaced".to_string()), expr.span()); - new_cond_clauses.push(CondClause::If( - parse2(quote_spanned! {expr.span()=> if #new_ident.eq(&(#expr))}).unwrap() - )); + let new_ident = fresh_ident( + &expr_to_ident(expr).map(|e| e.to_string()).unwrap_or_else(|| "expr_replaced".to_string()), + expr.span(), + ); + new_cond_clauses + .push(CondClause::If(parse2(quote_spanned! {expr.span()=> if #new_ident.eq(&(#expr))}).unwrap())); cl.args[arg_ind] = BodyClauseArg::Expr(parse2(new_ident.to_token_stream()).unwrap()); } else if let Some(ident) = expr_to_ident(expr) { grounded_vars.entry(ident).or_insert(i); @@ -858,23 +866,20 @@ fn rule_desugar_repeated_vars(mut rule: RuleNode) -> RuleNode { cl.cond_clauses.insert(0, new_cond_cl); } }, - BodyItemNode::Generator(gen) => { + BodyItemNode::Generator(gen) => for ident in pattern_get_vars(&gen.pattern) { grounded_vars.entry(ident).or_insert(i); - } - }, - BodyItemNode::Cond(ref cond_cl @ CondClause::IfLet(_)) | - BodyItemNode::Cond(ref cond_cl @ CondClause::Let(_)) => { + }, + BodyItemNode::Cond(ref cond_cl @ CondClause::IfLet(_)) + | BodyItemNode::Cond(ref cond_cl @ CondClause::Let(_)) => for ident in cond_cl.bound_vars() { grounded_vars.entry(ident).or_insert(i); - } - } + }, BodyItemNode::Cond(CondClause::If(_)) => (), - BodyItemNode::Agg(agg) => { - for ident in pattern_get_vars(&agg.pat){ + BodyItemNode::Agg(agg) => + for ident in pattern_get_vars(&agg.pat) { grounded_vars.entry(ident).or_insert(i); - } - }, + }, BodyItemNode::Negation(_) => (), BodyItemNode::Disjunction(_) => panic!("unrecognized BodyItemNode variant"), BodyItemNode::MacroInvocation(m) => panic!("unexpected macro invocation: {:?}", m.mac.path), @@ -890,12 +895,11 @@ fn rule_desugar_wildcards(mut rule: RuleNode) -> RuleNode { if let BodyItemNode::Clause(bcl) = bi { for arg in bcl.args.iter_mut() { match arg { - BodyClauseArg::Expr(expr) => { + BodyClauseArg::Expr(expr) => if is_wild_card(expr) { let new_ident = gensym.next_ident("_", expr.span()); *expr = parse2(quote! {#new_ident}).unwrap(); - } - } + }, BodyClauseArg::Pat(_) => (), } } @@ -909,12 +913,12 @@ fn rule_desugar_negation(mut rule: RuleNode) -> RuleNode { if let BodyItemNode::Negation(neg) = bi { let rel = &neg.rel; let args = &neg.args; - let replacement = quote_spanned! {neg.neg_token.span=> + let replacement = quote_spanned! {neg.neg_token.span=> agg () = ::ascent::aggregators::not() in #rel(#args) }; let replacement: AggClauseNode = parse2(replacement).unwrap(); *bi = BodyItemNode::Agg(replacement); - } + } } rule } @@ -925,7 +929,7 @@ fn invoke_macro(invocation: &ExprMacro, definition: &MacroDefNode) -> Result Result> { let mut ident_replacement = HashMap::new(); - for pair in definition.params.pairs(){ + for pair in definition.params.pairs() { if args.is_empty() { return Err(Error::new(span, "expected more arguments")); } @@ -934,7 +938,7 @@ fn invoke_macro(invocation: &ExprMacro, definition: &MacroDefNode) -> Result args.parse::()?.into_token_stream(), MacroParamKind::Ident(_) => args.parse::()?.into_token_stream(), }; - + ident_replacement.insert(param.name.clone(), arg); if comma.is_some() { if args.is_empty() { @@ -943,7 +947,7 @@ fn invoke_macro(invocation: &ExprMacro, definition: &MacroDefNode) -> Result()?; } } - + Ok(ident_replacement) } @@ -955,50 +959,57 @@ fn invoke_macro(invocation: &ExprMacro, definition: &MacroDefNode) -> Result) -> Result { - const RECURSIVE_MACRO_ERROR: &'static str = "recursively defined Ascent macro"; - fn body_item_expand_macros(bi: BodyItemNode, macros: &HashMap, gensym: &mut GenSym, depth: i16, span: Option) - -> Result> - { + fn body_item_expand_macros( + bi: BodyItemNode, macros: &HashMap, gensym: &mut GenSym, depth: i16, span: Option, + ) -> Result> { if depth <= 0 { return Err(Error::new(span.unwrap_or_else(Span::call_site), RECURSIVE_MACRO_ERROR)) } match bi { BodyItemNode::MacroInvocation(m) => { - let mac_def = macros.get(m.mac.path.get_ident().unwrap()) - .ok_or_else(|| Error::new(m.span(), "undefined macro"))?; + let mac_def = + macros.get(m.mac.path.get_ident().unwrap()).ok_or_else(|| Error::new(m.span(), "undefined macro"))?; let macro_invoked = invoke_macro(&m, mac_def)?; let expanded_bis = Parser::parse2(Punctuated::::parse_terminated, macro_invoked)?; - let mut recursively_expanded = - punctuated_try_map(expanded_bis, |ebi| body_item_expand_macros(ebi, macros, gensym, depth - 1, Some(m.span())))? - .pipe(flatten_punctuated); - body_items_rename_macro_originated_vars(&mut recursively_expanded.iter_mut().collect_vec(), mac_def, gensym); + let mut recursively_expanded = punctuated_try_map(expanded_bis, |ebi| { + body_item_expand_macros(ebi, macros, gensym, depth - 1, Some(m.span())) + })? + .pipe(flatten_punctuated); + body_items_rename_macro_originated_vars( + &mut recursively_expanded.iter_mut().collect_vec(), + mac_def, + gensym, + ); Ok(recursively_expanded) }, BodyItemNode::Disjunction(disj) => { - let new_disj: Punctuated, _> = punctuated_map(disj.disjuncts, |bis|{ - let new_bis = punctuated_map(bis,|bi| { + let new_disj: Punctuated, _> = punctuated_map(disj.disjuncts, |bis| { + let new_bis = punctuated_map(bis, |bi| { body_item_expand_macros(bi, macros, gensym, depth - 1, Some(disj.paren.span.join())) }); Ok(flatten_punctuated(punctuated_try_unwrap(new_bis)?)) }); - - Ok(punctuated_singleton(BodyItemNode::Disjunction(DisjunctionNode{disjuncts: punctuated_try_unwrap(new_disj)?, .. disj}))) + + Ok(punctuated_singleton(BodyItemNode::Disjunction(DisjunctionNode { + disjuncts: punctuated_try_unwrap(new_disj)?, + ..disj + }))) }, - _ => Ok(punctuated_singleton(bi)) + _ => Ok(punctuated_singleton(bi)), } } - fn head_item_expand_macros(hi: HeadItemNode, macros: &HashMap, depth: i16, span: Option) - -> Result> - { + fn head_item_expand_macros( + hi: HeadItemNode, macros: &HashMap, depth: i16, span: Option, + ) -> Result> { if depth <= 0 { return Err(Error::new(span.unwrap_or_else(Span::call_site), RECURSIVE_MACRO_ERROR)) } match hi { HeadItemNode::MacroInvocation(m) => { - let mac_def = macros.get(m.mac.path.get_ident().unwrap()) - .ok_or_else(|| Error::new(m.span(), "undefined macro"))?; + let mac_def = + macros.get(m.mac.path.get_ident().unwrap()).ok_or_else(|| Error::new(m.span(), "undefined macro"))?; let macro_invoked = invoke_macro(&m, mac_def)?; let expanded_his = Parser::parse2(Punctuated::::parse_terminated, macro_invoked)?; @@ -1012,27 +1023,29 @@ fn rule_expand_macro_invocations(rule: RuleNode, macros: &HashMap>>()? - .into_iter().flatten().collect_vec(); + let new_body_items = rule + .body_items + .into_iter() + .map(|bi| body_item_expand_macros(bi, macros, &mut gensym, 100, None)) + .collect::>>()? + .into_iter() + .flatten() + .collect_vec(); let new_head_items = punctuated_map(rule.head_clauses, |hi| head_item_expand_macros(hi, macros, 100, None)) - .pipe(punctuated_try_unwrap)? - .pipe(flatten_punctuated); - - Ok(RuleNode {body_items: new_body_items, head_clauses: new_head_items}) + .pipe(punctuated_try_unwrap)? + .pipe(flatten_punctuated); + + Ok(RuleNode { body_items: new_body_items, head_clauses: new_head_items }) } pub(crate) fn desugar_ascent_program(mut prog: AscentProgram) -> Result { - let macros = prog.macros.iter().map(|m| (m.name.clone(), m)).collect::>(); - let rules_macro_expanded = - prog.rules.into_iter() - .map(|r| rule_expand_macro_invocations(r, ¯os)) - .collect::>>()?; - - prog.rules = - rules_macro_expanded.into_iter() + let macros = prog.macros.iter().map(|m| (m.name.clone(), m)).collect::>(); + let rules_macro_expanded = + prog.rules.into_iter().map(|r| rule_expand_macro_invocations(r, ¯os)).collect::>>()?; + + prog.rules = rules_macro_expanded + .into_iter() .flat_map(rule_desugar_disjunction_nodes) .map(rule_desugar_pattern_args) .map(rule_desugar_wildcards) @@ -1056,6 +1069,6 @@ fn fresh_ident(prefix: &str, span: Span) -> Ident { ident_counters_lock.insert(prefix.to_string(), 1); "".to_string() }; - + Ident::new(&format!("{}_{}", prefix, counter), span) -} \ No newline at end of file +} diff --git a/ascent_macro/src/lib.rs b/ascent_macro/src/lib.rs index 0160ef0..577026b 100644 --- a/ascent_macro/src/lib.rs +++ b/ascent_macro/src/lib.rs @@ -17,10 +17,13 @@ extern crate proc_macro; use ascent_syntax::{AscentProgram, desugar_ascent_program}; use proc_macro::TokenStream; use syn::Result; -use crate::{ascent_codegen::compile_mir, ascent_hir::compile_ascent_program_to_hir, ascent_mir::compile_hir_to_mir}; + +use crate::ascent_codegen::compile_mir; +use crate::ascent_hir::compile_ascent_program_to_hir; +use crate::ascent_mir::compile_hir_to_mir; /// The main macro of the ascent library. Allows writing logical inference rules similar to Datalog. -/// +/// /// Example: /// ``` /// # #[macro_use] extern crate ascent_macro; @@ -32,7 +35,7 @@ use crate::{ascent_codegen::compile_mir, ascent_hir::compile_ascent_program_to_h /// path(x, y) <-- edge(x,y); /// path(x, z) <-- edge(x,y), path(y, z); /// } -/// +/// /// fn main() { /// let mut tc_comp = AscentProgram::default(); /// tc_comp.edge = vec![(1,2), (2,3)]; @@ -45,7 +48,7 @@ use crate::{ascent_codegen::compile_mir, ascent_hir::compile_ascent_program_to_h #[proc_macro] pub fn ascent(input: TokenStream) -> TokenStream { let res = ascent_impl(input.into(), false, false); - + match res { Ok(res) => res.into(), Err(err) => TokenStream::from(err.to_compile_error()), @@ -53,22 +56,21 @@ pub fn ascent(input: TokenStream) -> TokenStream { } /// Similar to `ascent`, allows writing logic programs in Rust. -/// -/// The difference is that `ascent_par` generates parallelized code. +/// +/// The difference is that `ascent_par` generates parallelized code. #[proc_macro] pub fn ascent_par(input: TokenStream) -> TokenStream { let res = ascent_impl(input.into(), false, true); - + match res { Ok(res) => res.into(), Err(err) => TokenStream::from(err.to_compile_error()), } } - /// Like `ascent`, except that the result of an `ascent_run` invocation is a value containing all the relations /// defined inside the macro body, and computed to a fixed point. -/// +/// /// The advantage of `ascent_run` compared to `ascent` is the fact that `ascent_run` has access to local variables /// in scope: /// ``` @@ -85,7 +87,7 @@ pub fn ascent_par(input: TokenStream) -> TokenStream { #[proc_macro] pub fn ascent_run(input: TokenStream) -> TokenStream { let res = ascent_impl(input.into(), true, false); - + match res { Ok(res) => res.into(), Err(err) => TokenStream::from(err.to_compile_error()), @@ -96,20 +98,22 @@ pub fn ascent_run(input: TokenStream) -> TokenStream { #[proc_macro] pub fn ascent_run_par(input: TokenStream) -> TokenStream { let res = ascent_impl(input.into(), true, true); - + match res { Ok(res) => res.into(), Err(err) => TokenStream::from(err.to_compile_error()), } } -pub(crate) fn ascent_impl(input: proc_macro2::TokenStream, is_ascent_run: bool, is_parallel: bool) -> Result { +pub(crate) fn ascent_impl( + input: proc_macro2::TokenStream, is_ascent_run: bool, is_parallel: bool, +) -> Result { let prog: AscentProgram = syn::parse2(input)?; // println!("prog relations: {}", prog.relations.len()); // println!("parse res: {} relations, {} rules", prog.relations.len(), prog.rules.len()); let prog = desugar_ascent_program(prog)?; - + let hir = compile_ascent_program_to_hir(&prog, is_parallel)?; // println!("hir relations: {}", hir.relations_ir_relations.keys().map(|r| &r.name).join(", ")); diff --git a/ascent_macro/src/syn_utils.rs b/ascent_macro/src/syn_utils.rs index 24b6d9e..60a9b16 100644 --- a/ascent_macro/src/syn_utils.rs +++ b/ascent_macro/src/syn_utils.rs @@ -1,507 +1,511 @@ -#![deny(warnings)] -use std::collections::HashSet; -use std::ops::{Deref, DerefMut}; -use proc_macro2::{Ident, TokenStream, TokenTree, Group}; - -use quote::ToTokens; -use syn::visit_mut::VisitMut; -use syn::{Block, Stmt, ExprMacro}; -use syn::{Expr, Pat, Path}; -use crate::utils::{collect_set, into_set}; -use duplicate::duplicate_item; -use ascent_base::util::update; - -#[cfg(test)] -use syn::parse2; - -// TODO maybe remove? -#[allow(unused)] -pub fn block_get_vars(block: &Block) -> Vec { - let mut bound_vars = HashSet::new(); - let mut used_vars = vec![]; - for stmt in block.stmts.iter() { - let (stmt_bound_vars, stmt_used_vars) = stmt_get_vars(stmt); - for used_var in stmt_used_vars.into_iter() { - if !bound_vars.contains(&used_var) { - used_vars.push(used_var); - } - } - bound_vars.extend(stmt_bound_vars); - } - used_vars -} - -pub fn pattern_get_vars(pat: &Pat) -> Vec { - let mut res = vec![]; - match pat { - Pat::Ident(pat_ident) => { - res.push(pat_ident.ident.clone()); - if let Some(subpat) = &pat_ident.subpat { - res.extend(pattern_get_vars(&subpat.1)) - } - }, - Pat::Lit(_) => {}, - Pat::Macro(_) => {}, - Pat::Or(or_pat) => { - let cases_vars = or_pat.cases.iter().map(pattern_get_vars).map(into_set); - let intersection = cases_vars.reduce(|case_vars, accu| collect_set(case_vars.intersection(&accu).cloned())); - if let Some(intersection) = intersection { - res.extend(intersection); - } - }, - Pat::Path(_) => {}, - Pat::Range(_) => {}, - Pat::Reference(ref_pat) => res.extend(pattern_get_vars(&ref_pat.pat)), - Pat::Rest(_) => {}, - Pat::Slice(slice_pat) => { - for sub_pat in slice_pat.elems.iter(){ - res.extend(pattern_get_vars(sub_pat)); - } - }, - Pat::Struct(struct_pat) => { - for field_pat in struct_pat.fields.iter() { - res.extend(pattern_get_vars(&field_pat.pat)); - } - }, - Pat::Tuple(tuple_pat) => { - for elem_pat in tuple_pat.elems.iter() { - res.extend(pattern_get_vars(elem_pat)); - } - } - Pat::TupleStruct(tuple_strcut_pat) => { - for elem_pat in tuple_strcut_pat.elems.iter(){ - res.extend(pattern_get_vars(elem_pat)); - } - }, - Pat::Type(type_pat) => { - res.extend(pattern_get_vars(&type_pat.pat)); - }, - Pat::Verbatim(_) => {}, - Pat::Wild(_) => {}, - _ => {} - } - // println!("pattern vars {} : {}", pat.to_token_stream(), res.iter().map(|ident| ident.to_string()).join(", ")); - res -} - - -pub fn pattern_visit_vars_mut(pat: &mut Pat, visitor: &mut dyn FnMut(&mut Ident)) { - macro_rules! visit { - ($e: expr) => { - pattern_visit_vars_mut($e, visitor) - }; - } - match pat { - Pat::Ident(pat_ident) => { - visitor(&mut pat_ident.ident); - if let Some(subpat) = &mut pat_ident.subpat { - visit!(&mut subpat.1); - } - }, - Pat::Lit(_) => {}, - Pat::Macro(_) => {}, - Pat::Or(or_pat) => { - for case in or_pat.cases.iter_mut() { - visit!(case) - } - }, - Pat::Path(_) => {}, - Pat::Range(_) => {}, - Pat::Reference(ref_pat) => visit!(&mut ref_pat.pat), - Pat::Rest(_) => {}, - Pat::Slice(slice_pat) => { - for sub_pat in slice_pat.elems.iter_mut(){ - visit!(sub_pat); - } - }, - Pat::Struct(struct_pat) => { - for field_pat in struct_pat.fields.iter_mut() { - visit!(&mut field_pat.pat); - } - }, - Pat::Tuple(tuple_pat) => { - for elem_pat in tuple_pat.elems.iter_mut() { - visit!(elem_pat); - } - } - Pat::TupleStruct(tuple_strcut_pat) => { - for elem_pat in tuple_strcut_pat.elems.iter_mut(){ - visit!(elem_pat); - } - }, - Pat::Type(type_pat) => { - visit!(&mut type_pat.pat); - }, - Pat::Verbatim(_) => {}, - Pat::Wild(_) => {}, - _ => {} - } -} - -#[test] -fn test_pattern_get_vars(){ - use syn::parse::Parser; - - let pattern = quote! { - SomePair(x, (y, z)) - }; - let pat = Pat::parse_single.parse2(pattern).unwrap(); - assert_eq!(collect_set(["x", "y", "z"].iter().map(ToString::to_string)), - pattern_get_vars(&pat).into_iter().map(|id| id.to_string()).collect()); - -} - -/// if the expression is a let expression (for example in `if let Some(foo) = bar {..}`), -/// returns the variables bound by the let expression -pub fn expr_get_let_bound_vars(expr: &Expr) -> Vec { - match expr { - Expr::Let(l) => pattern_get_vars(&l.pat), - _ => vec![] - } -} - -pub fn stmt_get_vars(stmt: &Stmt) -> (Vec, Vec) { - let mut bound_vars = vec![]; - let mut used_vars = vec![]; - match stmt { - Stmt::Local(l) => { - bound_vars.extend(pattern_get_vars(&l.pat)); - if let Some(init) = &l.init { - used_vars.extend(expr_get_vars(&init.expr)); - if let Some(diverge) = &init.diverge { - used_vars.extend(expr_get_let_bound_vars(&diverge.1)); - } - } - }, - Stmt::Item(_) => {}, - Stmt::Expr(e, _) => used_vars.extend(expr_get_vars(e)), - Stmt::Macro(m) => { - eprintln!("WARNING: cannot determine variables of macro invocations. macro invocation:\n{}", - m.to_token_stream()); - } - } - (bound_vars, used_vars) -} - -pub fn stmt_visit_free_vars_mut(stmt: &mut Stmt, visitor: &mut dyn FnMut(&mut Ident)) { - match stmt { - Stmt::Local(l) => { - if let Some(init) = &mut l.init { - expr_visit_free_vars_mut(&mut init.expr, visitor); - if let Some(diverge) = &mut init.diverge { - expr_visit_free_vars_mut(&mut diverge.1, visitor); - } - } - }, - Stmt::Item(_) => {}, - Stmt::Expr(e, _) => expr_visit_free_vars_mut(e, visitor), - Stmt::Macro(m) => { - eprintln!("WARNING: cannot determine free variables of macro invocations. macro invocation:\n{}", - m.to_token_stream()); - } - } -} - -pub fn stmt_visit_free_vars(stmt: &Stmt, visitor: &mut dyn FnMut(& Ident)) { - match stmt { - Stmt::Local(l) => { - if let Some(init) = &l.init { - expr_visit_free_vars(&init.expr, visitor); - if let Some(diverge) = &init.diverge { - expr_visit_free_vars(&diverge.1, visitor); - } - } - }, - Stmt::Item(_) => {}, - Stmt::Expr(e, _) => expr_visit_free_vars(e, visitor), - Stmt::Macro(m) => { - eprintln!("WARNING: cannot determine free variables of macro invocations. macro invocation:\n{}", - m.to_token_stream()); - } - } -} - -pub fn block_visit_free_vars_mut(block: &mut Block, visitor: &mut dyn FnMut(&mut Ident)) { - let mut bound_vars = HashSet::new(); - for stmt in block.stmts.iter_mut() { - let (stmt_bound_vars, _) = stmt_get_vars(stmt); - stmt_visit_free_vars_mut(stmt, &mut |ident| if !bound_vars.contains(ident) {visitor(ident)}); - bound_vars.extend(stmt_bound_vars); - } -} - -pub fn block_visit_free_vars(block: &Block, visitor: &mut dyn FnMut(&Ident)) { - let mut bound_vars = HashSet::new(); - for stmt in block.stmts.iter() { - let (stmt_bound_vars, _) = stmt_get_vars(stmt); - stmt_visit_free_vars(stmt, &mut |ident| if !bound_vars.contains(ident) {visitor(ident)}); - bound_vars.extend(stmt_bound_vars); - } -} - - -// all this nonsense to have two versions of this function: -//expr_visit_free_vars and expr_visit_free_vars_mut -#[duplicate_item( - reft(type) expr_visit_free_vars_mbm block_visit_free_vars_mbm iter_mbm path_get_ident_mbm deref_mbm; - [&mut type] [expr_visit_free_vars_mut] [block_visit_free_vars_mut] [iter_mut] [path_get_ident_mut] [deref_mut]; - [&type] [expr_visit_free_vars] [block_visit_free_vars] [iter] [Path::get_ident] [deref]; - )] -/// visits free variables in the expr -pub fn expr_visit_free_vars_mbm(expr: reft([Expr]), visitor: &mut dyn FnMut(reft([Ident]))) { - macro_rules! visit { - ($e: expr) => { expr_visit_free_vars_mbm(reft([$e]), visitor)}; - } - macro_rules! visitor_except { - ($excluded: expr) => { - &mut |ident| {if ! $excluded.contains(ident) {visitor(ident)}} - }; - } - macro_rules! visit_except { - ($e: expr, $excluded: expr) => { expr_visit_free_vars_mbm($e, visitor_except!($excluded))}; - } - match expr { - Expr::Array(arr) => { - for elem in arr.elems.iter_mbm() { - expr_visit_free_vars_mbm(elem, visitor); - } - } - Expr::Assign(assign) => { - visit!(assign.left); - visit!(assign.right) - }, - Expr::Async(a) => block_visit_free_vars_mbm(reft([a.block]), visitor), - Expr::Await(a) => visit!(a.base), - Expr::Binary(b) => { - visit!(b.left); - visit!(b.right) - } - Expr::Block(b) => block_visit_free_vars_mbm(reft([b.block]), visitor), - Expr::Break(b) => if let Some(b_e) = reft([b.expr]) {expr_visit_free_vars_mbm(b_e, visitor)}, - Expr::Call(c) => { - visit!(c.func); - for arg in c.args.iter_mbm() { - expr_visit_free_vars_mbm(arg, visitor) - } - } - Expr::Cast(c) => visit!(c.expr), - Expr::Closure(c) => { - let input_vars : HashSet<_> = c.inputs.iter().flat_map(pattern_get_vars).collect(); - visit_except!(reft([c.body]), input_vars); - }, - Expr::Continue(_c) => {} - Expr::Field(f) => visit!(f.base), - Expr::ForLoop(f) => { - let pat_vars: HashSet<_> = pattern_get_vars(&f.pat).into_iter().collect(); - visit!(f.expr); - block_visit_free_vars_mbm(reft([f.body]), visitor_except!(pat_vars)); - }, - Expr::Group(g) => visit!(g.expr), - Expr::If(e) => { - let bound_vars = expr_get_let_bound_vars(&e.cond).into_iter().collect::>(); - visit!(e.cond); - block_visit_free_vars_mbm(reft([e.then_branch]), visitor_except!(bound_vars)); - if let Some(eb) = reft([e.else_branch]) { - visit!(eb.1) - } - } - Expr::Index(i) => { - visit!(i.expr); - visit!(i.index) - } - Expr::Let(l) => visit!(l.expr), - Expr::Lit(_) => {} - Expr::Loop(l) => block_visit_free_vars_mbm(reft([l.body]), visitor), - Expr::Macro(_m) => { - eprintln!("WARNING: cannot determine free variables of macro invocations. macro invocation:\n{}", - expr.to_token_stream()) - }, - Expr::Match(m) => { - visit!(m.expr); - for arm in m.arms.iter_mbm() { - if let Some(g) = reft([arm.guard]) { - visit!(g.1); - } - let arm_vars = pattern_get_vars(&arm.pat).into_iter().collect::>(); - visit_except!(reft([arm.body]), arm_vars); - } - } - Expr::MethodCall(c) => { - visit!(c.receiver); - for arg in c.args.iter_mbm() { - expr_visit_free_vars_mbm(arg, visitor) - } - } - Expr::Paren(p) => visit!(p.expr), - Expr::Path(p) => { - if let Some(ident) = path_get_ident_mbm(reft([p.path])) { - visitor(ident) - } - } - Expr::Range(r) => { - if let Some(start) = reft([r.start]) { - expr_visit_free_vars_mbm(start, visitor) - }; - if let Some(end) = reft([r.end]) { - expr_visit_free_vars_mbm(end, visitor) - }; - } - Expr::Reference(r) => visit!(r.expr), - Expr::Repeat(r) => { - visit!(r.expr); - visit!(r.len) - } - Expr::Return(r) => { - if let Some(e) = reft([r.expr]) { expr_visit_free_vars_mbm(e, visitor) } - } - Expr::Struct(s) => { - for f in s.fields.iter_mbm() { - visit!(f.expr) - } - if let Some(rest) = reft([s.rest]) { - expr_visit_free_vars_mbm(rest.deref_mbm(), visitor) - } - } - Expr::Try(t) => visit!(t.expr), - Expr::TryBlock(t) => block_visit_free_vars_mbm(reft([t.block]), visitor), - Expr::Tuple(t) => { - for e in t.elems.iter_mbm() { - expr_visit_free_vars_mbm(e, visitor) - } - } - Expr::Unary(u) => visit!(u.expr), - Expr::Unsafe(u) => block_visit_free_vars_mbm(reft([u.block]), visitor), - Expr::Verbatim(_) => {} - Expr::While(w) => { - let bound_vars = expr_get_let_bound_vars(&w.cond).into_iter().collect::>(); - visit!(w.cond); - block_visit_free_vars_mbm(reft([w.body]), visitor_except!(bound_vars)) - } - Expr::Yield(y) => { - if let Some(e) = reft([y.expr]) { - expr_visit_free_vars_mbm(e.deref_mbm(), visitor) - } - } - _ => {} - } -} - -/// like `Path::get_ident(&self)`, but `mut` -pub fn path_get_ident_mut(path: &mut Path) -> Option<&mut Ident> { - if path.segments.len() != 1 || path.leading_colon.is_some() {return None} - let res = path.segments.first_mut()?; - if res.arguments.is_empty() { - Some(&mut res.ident) - } else { - None - } -} - -pub fn expr_get_vars(expr: &Expr) -> Vec { - let mut res = vec![]; - expr_visit_free_vars(expr, &mut |ident| res.push(ident.clone())); - res -} - -#[test] -fn test_expr_get_vars(){ - let test_cases = [ - (quote! { - { - let res = 0; - for i in [0..10] { - let x = i + a; - res += x / {|m, (n, o)| m + n - o}(2, (b, 42)) - } - res - } - }, vec!["a", "b"]), - (quote! { - |x1: u32, x2: u32| { - if y > x1 { - if let Some(z) = foo(x1) { - z + w - } else { - t = 42; - x2 - } - } - } - }, vec!["y", "foo", "w", "t"]) - ]; - - for (expr, expected) in test_cases { - let mut expr = parse2(expr).unwrap(); - let result = expr_get_vars(&mut expr); - let result = result.into_iter().map(|v| v.to_string()).collect::>(); - let expected = expected.into_iter().map(|v| v.to_string()).collect::>(); - println!("result: {:?}", result); - println!("expected: {:?}\n", expected); - assert_eq!(result, expected) - } -} - - -pub fn token_stream_replace_ident(ts: TokenStream, visitor: &mut dyn FnMut(&mut Ident)) -> TokenStream { - - fn token_tree_replace_ident(mut tt: TokenTree, visitor: &mut dyn FnMut(&mut Ident)) -> TokenTree { - match tt { - TokenTree::Group(grp) => { - let updated_ts = token_stream_replace_ident(grp.stream(), visitor); - let new_grp = Group::new(grp.delimiter(), updated_ts); - TokenTree::Group(new_grp) - }, - TokenTree::Ident(ref mut ident) =>{visitor(ident); tt}, - TokenTree::Punct(_) => tt, - TokenTree::Literal(_) => tt, - } - } - - let mut new_tts = vec![]; - for tt in ts.into_iter() { - new_tts.push(token_tree_replace_ident(tt, visitor)); - } - TokenStream::from_iter(new_tts) -} - -pub fn token_stream_visit_idents(ts: TokenStream, visitor: &mut impl FnMut(&Ident)) { - - fn token_tree_visit_idents(mut tt: TokenTree, visitor: &mut impl FnMut(&Ident)) { - match tt { - TokenTree::Group(grp) => { - token_stream_visit_idents(grp.stream(), visitor); - }, - TokenTree::Ident(ref mut ident) =>visitor(ident), - TokenTree::Punct(_) => (), - TokenTree::Literal(_) => (), - } - } - - for tt in ts.into_iter() { - token_tree_visit_idents(tt, visitor); - } -} - -pub fn token_stream_idents(ts: TokenStream) -> Vec { - let mut res = vec![]; - token_stream_visit_idents(ts, &mut |ident| res.push(ident.clone())); - res -} - -pub fn expr_visit_macros_mut(expr: &mut Expr, visitor: &mut dyn FnMut(&mut ExprMacro)) { - struct Visitor<'a>(&'a mut dyn FnMut(&mut ExprMacro)); - impl<'a> syn::visit_mut::VisitMut for Visitor<'a> { - fn visit_expr_macro_mut(&mut self, node: &mut ExprMacro) { - (self.0)(node) - } - } - Visitor(visitor).visit_expr_mut(expr) -} - -pub fn expr_visit_idents_in_macros_mut(expr: &mut Expr, visitor: &mut dyn FnMut(&mut Ident)) { - let mut mac_visitor = |mac: &mut ExprMacro| { - update(&mut mac.mac.tokens, |ts| token_stream_replace_ident(ts, visitor)); - }; - expr_visit_macros_mut(expr, &mut mac_visitor) -} - +#![deny(warnings)] +use std::collections::HashSet; +use std::ops::{Deref, DerefMut}; + +use ascent_base::util::update; +use duplicate::duplicate_item; +use proc_macro2::{Group, Ident, TokenStream, TokenTree}; +use quote::ToTokens; +#[cfg(test)] +use syn::parse2; +use syn::visit_mut::VisitMut; +use syn::{Block, Expr, ExprMacro, Pat, Path, Stmt}; + +use crate::utils::{collect_set, into_set}; + +// TODO maybe remove? +#[allow(unused)] +pub fn block_get_vars(block: &Block) -> Vec { + let mut bound_vars = HashSet::new(); + let mut used_vars = vec![]; + for stmt in block.stmts.iter() { + let (stmt_bound_vars, stmt_used_vars) = stmt_get_vars(stmt); + for used_var in stmt_used_vars.into_iter() { + if !bound_vars.contains(&used_var) { + used_vars.push(used_var); + } + } + bound_vars.extend(stmt_bound_vars); + } + used_vars +} + +pub fn pattern_get_vars(pat: &Pat) -> Vec { + let mut res = vec![]; + match pat { + Pat::Ident(pat_ident) => { + res.push(pat_ident.ident.clone()); + if let Some(subpat) = &pat_ident.subpat { + res.extend(pattern_get_vars(&subpat.1)) + } + }, + Pat::Lit(_) => {}, + Pat::Macro(_) => {}, + Pat::Or(or_pat) => { + let cases_vars = or_pat.cases.iter().map(pattern_get_vars).map(into_set); + let intersection = cases_vars.reduce(|case_vars, accu| collect_set(case_vars.intersection(&accu).cloned())); + if let Some(intersection) = intersection { + res.extend(intersection); + } + }, + Pat::Path(_) => {}, + Pat::Range(_) => {}, + Pat::Reference(ref_pat) => res.extend(pattern_get_vars(&ref_pat.pat)), + Pat::Rest(_) => {}, + Pat::Slice(slice_pat) => + for sub_pat in slice_pat.elems.iter() { + res.extend(pattern_get_vars(sub_pat)); + }, + Pat::Struct(struct_pat) => + for field_pat in struct_pat.fields.iter() { + res.extend(pattern_get_vars(&field_pat.pat)); + }, + Pat::Tuple(tuple_pat) => + for elem_pat in tuple_pat.elems.iter() { + res.extend(pattern_get_vars(elem_pat)); + }, + Pat::TupleStruct(tuple_strcut_pat) => + for elem_pat in tuple_strcut_pat.elems.iter() { + res.extend(pattern_get_vars(elem_pat)); + }, + Pat::Type(type_pat) => { + res.extend(pattern_get_vars(&type_pat.pat)); + }, + Pat::Verbatim(_) => {}, + Pat::Wild(_) => {}, + _ => {}, + } + // println!("pattern vars {} : {}", pat.to_token_stream(), res.iter().map(|ident| ident.to_string()).join(", ")); + res +} + +pub fn pattern_visit_vars_mut(pat: &mut Pat, visitor: &mut dyn FnMut(&mut Ident)) { + macro_rules! visit { + ($e: expr) => { + pattern_visit_vars_mut($e, visitor) + }; + } + match pat { + Pat::Ident(pat_ident) => { + visitor(&mut pat_ident.ident); + if let Some(subpat) = &mut pat_ident.subpat { + visit!(&mut subpat.1); + } + }, + Pat::Lit(_) => {}, + Pat::Macro(_) => {}, + Pat::Or(or_pat) => + for case in or_pat.cases.iter_mut() { + visit!(case) + }, + Pat::Path(_) => {}, + Pat::Range(_) => {}, + Pat::Reference(ref_pat) => visit!(&mut ref_pat.pat), + Pat::Rest(_) => {}, + Pat::Slice(slice_pat) => + for sub_pat in slice_pat.elems.iter_mut() { + visit!(sub_pat); + }, + Pat::Struct(struct_pat) => + for field_pat in struct_pat.fields.iter_mut() { + visit!(&mut field_pat.pat); + }, + Pat::Tuple(tuple_pat) => + for elem_pat in tuple_pat.elems.iter_mut() { + visit!(elem_pat); + }, + Pat::TupleStruct(tuple_strcut_pat) => + for elem_pat in tuple_strcut_pat.elems.iter_mut() { + visit!(elem_pat); + }, + Pat::Type(type_pat) => { + visit!(&mut type_pat.pat); + }, + Pat::Verbatim(_) => {}, + Pat::Wild(_) => {}, + _ => {}, + } +} + +#[test] +fn test_pattern_get_vars() { + use syn::parse::Parser; + + let pattern = quote! { + SomePair(x, (y, z)) + }; + let pat = Pat::parse_single.parse2(pattern).unwrap(); + assert_eq!( + collect_set(["x", "y", "z"].iter().map(ToString::to_string)), + pattern_get_vars(&pat).into_iter().map(|id| id.to_string()).collect() + ); +} + +/// if the expression is a let expression (for example in `if let Some(foo) = bar {..}`), +/// returns the variables bound by the let expression +pub fn expr_get_let_bound_vars(expr: &Expr) -> Vec { + match expr { + Expr::Let(l) => pattern_get_vars(&l.pat), + _ => vec![], + } +} + +pub fn stmt_get_vars(stmt: &Stmt) -> (Vec, Vec) { + let mut bound_vars = vec![]; + let mut used_vars = vec![]; + match stmt { + Stmt::Local(l) => { + bound_vars.extend(pattern_get_vars(&l.pat)); + if let Some(init) = &l.init { + used_vars.extend(expr_get_vars(&init.expr)); + if let Some(diverge) = &init.diverge { + used_vars.extend(expr_get_let_bound_vars(&diverge.1)); + } + } + }, + Stmt::Item(_) => {}, + Stmt::Expr(e, _) => used_vars.extend(expr_get_vars(e)), + Stmt::Macro(m) => { + eprintln!( + "WARNING: cannot determine variables of macro invocations. macro invocation:\n{}", + m.to_token_stream() + ); + }, + } + (bound_vars, used_vars) +} + +pub fn stmt_visit_free_vars_mut(stmt: &mut Stmt, visitor: &mut dyn FnMut(&mut Ident)) { + match stmt { + Stmt::Local(l) => + if let Some(init) = &mut l.init { + expr_visit_free_vars_mut(&mut init.expr, visitor); + if let Some(diverge) = &mut init.diverge { + expr_visit_free_vars_mut(&mut diverge.1, visitor); + } + }, + Stmt::Item(_) => {}, + Stmt::Expr(e, _) => expr_visit_free_vars_mut(e, visitor), + Stmt::Macro(m) => { + eprintln!( + "WARNING: cannot determine free variables of macro invocations. macro invocation:\n{}", + m.to_token_stream() + ); + }, + } +} + +pub fn stmt_visit_free_vars(stmt: &Stmt, visitor: &mut dyn FnMut(&Ident)) { + match stmt { + Stmt::Local(l) => + if let Some(init) = &l.init { + expr_visit_free_vars(&init.expr, visitor); + if let Some(diverge) = &init.diverge { + expr_visit_free_vars(&diverge.1, visitor); + } + }, + Stmt::Item(_) => {}, + Stmt::Expr(e, _) => expr_visit_free_vars(e, visitor), + Stmt::Macro(m) => { + eprintln!( + "WARNING: cannot determine free variables of macro invocations. macro invocation:\n{}", + m.to_token_stream() + ); + }, + } +} + +pub fn block_visit_free_vars_mut(block: &mut Block, visitor: &mut dyn FnMut(&mut Ident)) { + let mut bound_vars = HashSet::new(); + for stmt in block.stmts.iter_mut() { + let (stmt_bound_vars, _) = stmt_get_vars(stmt); + stmt_visit_free_vars_mut(stmt, &mut |ident| { + if !bound_vars.contains(ident) { + visitor(ident) + } + }); + bound_vars.extend(stmt_bound_vars); + } +} + +pub fn block_visit_free_vars(block: &Block, visitor: &mut dyn FnMut(&Ident)) { + let mut bound_vars = HashSet::new(); + for stmt in block.stmts.iter() { + let (stmt_bound_vars, _) = stmt_get_vars(stmt); + stmt_visit_free_vars(stmt, &mut |ident| { + if !bound_vars.contains(ident) { + visitor(ident) + } + }); + bound_vars.extend(stmt_bound_vars); + } +} + +// all this nonsense to have two versions of this function: +//expr_visit_free_vars and expr_visit_free_vars_mut +#[duplicate_item( + reft(type) expr_visit_free_vars_mbm block_visit_free_vars_mbm iter_mbm path_get_ident_mbm deref_mbm; + [&mut type] [expr_visit_free_vars_mut] [block_visit_free_vars_mut] [iter_mut] [path_get_ident_mut] [deref_mut]; + [&type] [expr_visit_free_vars] [block_visit_free_vars] [iter] [Path::get_ident] [deref]; + )] +/// visits free variables in the expr +pub fn expr_visit_free_vars_mbm(expr: reft([Expr]), visitor: &mut dyn FnMut(reft([Ident]))) { + macro_rules! visit { + ($e: expr) => { expr_visit_free_vars_mbm(reft([$e]), visitor)}; + } + macro_rules! visitor_except { + ($excluded: expr) => { + &mut |ident| {if ! $excluded.contains(ident) {visitor(ident)}} + }; + } + macro_rules! visit_except { + ($e: expr, $excluded: expr) => { expr_visit_free_vars_mbm($e, visitor_except!($excluded))}; + } + match expr { + Expr::Array(arr) => + for elem in arr.elems.iter_mbm() { + expr_visit_free_vars_mbm(elem, visitor); + }, + Expr::Assign(assign) => { + visit!(assign.left); + visit!(assign.right) + }, + Expr::Async(a) => block_visit_free_vars_mbm(reft([a.block]), visitor), + Expr::Await(a) => visit!(a.base), + Expr::Binary(b) => { + visit!(b.left); + visit!(b.right) + }, + Expr::Block(b) => block_visit_free_vars_mbm(reft([b.block]), visitor), + Expr::Break(b) => + if let Some(b_e) = reft([b.expr]) { + expr_visit_free_vars_mbm(b_e, visitor) + }, + Expr::Call(c) => { + visit!(c.func); + for arg in c.args.iter_mbm() { + expr_visit_free_vars_mbm(arg, visitor) + } + }, + Expr::Cast(c) => visit!(c.expr), + Expr::Closure(c) => { + let input_vars: HashSet<_> = c.inputs.iter().flat_map(pattern_get_vars).collect(); + visit_except!(reft([c.body]), input_vars); + }, + Expr::Continue(_c) => {}, + Expr::Field(f) => visit!(f.base), + Expr::ForLoop(f) => { + let pat_vars: HashSet<_> = pattern_get_vars(&f.pat).into_iter().collect(); + visit!(f.expr); + block_visit_free_vars_mbm(reft([f.body]), visitor_except!(pat_vars)); + }, + Expr::Group(g) => visit!(g.expr), + Expr::If(e) => { + let bound_vars = expr_get_let_bound_vars(&e.cond).into_iter().collect::>(); + visit!(e.cond); + block_visit_free_vars_mbm(reft([e.then_branch]), visitor_except!(bound_vars)); + if let Some(eb) = reft([e.else_branch]) { + visit!(eb.1) + } + }, + Expr::Index(i) => { + visit!(i.expr); + visit!(i.index) + }, + Expr::Let(l) => visit!(l.expr), + Expr::Lit(_) => {}, + Expr::Loop(l) => block_visit_free_vars_mbm(reft([l.body]), visitor), + Expr::Macro(_m) => { + eprintln!( + "WARNING: cannot determine free variables of macro invocations. macro invocation:\n{}", + expr.to_token_stream() + ) + }, + Expr::Match(m) => { + visit!(m.expr); + for arm in m.arms.iter_mbm() { + if let Some(g) = reft([arm.guard]) { + visit!(g.1); + } + let arm_vars = pattern_get_vars(&arm.pat).into_iter().collect::>(); + visit_except!(reft([arm.body]), arm_vars); + } + }, + Expr::MethodCall(c) => { + visit!(c.receiver); + for arg in c.args.iter_mbm() { + expr_visit_free_vars_mbm(arg, visitor) + } + }, + Expr::Paren(p) => visit!(p.expr), + Expr::Path(p) => + if let Some(ident) = path_get_ident_mbm(reft([p.path])) { + visitor(ident) + }, + Expr::Range(r) => { + if let Some(start) = reft([r.start]) { + expr_visit_free_vars_mbm(start, visitor) + }; + if let Some(end) = reft([r.end]) { + expr_visit_free_vars_mbm(end, visitor) + }; + }, + Expr::Reference(r) => visit!(r.expr), + Expr::Repeat(r) => { + visit!(r.expr); + visit!(r.len) + }, + Expr::Return(r) => + if let Some(e) = reft([r.expr]) { + expr_visit_free_vars_mbm(e, visitor) + }, + Expr::Struct(s) => { + for f in s.fields.iter_mbm() { + visit!(f.expr) + } + if let Some(rest) = reft([s.rest]) { + expr_visit_free_vars_mbm(rest.deref_mbm(), visitor) + } + }, + Expr::Try(t) => visit!(t.expr), + Expr::TryBlock(t) => block_visit_free_vars_mbm(reft([t.block]), visitor), + Expr::Tuple(t) => + for e in t.elems.iter_mbm() { + expr_visit_free_vars_mbm(e, visitor) + }, + Expr::Unary(u) => visit!(u.expr), + Expr::Unsafe(u) => block_visit_free_vars_mbm(reft([u.block]), visitor), + Expr::Verbatim(_) => {}, + Expr::While(w) => { + let bound_vars = expr_get_let_bound_vars(&w.cond).into_iter().collect::>(); + visit!(w.cond); + block_visit_free_vars_mbm(reft([w.body]), visitor_except!(bound_vars)) + }, + Expr::Yield(y) => + if let Some(e) = reft([y.expr]) { + expr_visit_free_vars_mbm(e.deref_mbm(), visitor) + }, + _ => {}, + } +} + +/// like `Path::get_ident(&self)`, but `mut` +pub fn path_get_ident_mut(path: &mut Path) -> Option<&mut Ident> { + if path.segments.len() != 1 || path.leading_colon.is_some() { + return None + } + let res = path.segments.first_mut()?; + if res.arguments.is_empty() { Some(&mut res.ident) } else { None } +} + +pub fn expr_get_vars(expr: &Expr) -> Vec { + let mut res = vec![]; + expr_visit_free_vars(expr, &mut |ident| res.push(ident.clone())); + res +} + +#[test] +fn test_expr_get_vars() { + let test_cases = [ + ( + quote! { + { + let res = 0; + for i in [0..10] { + let x = i + a; + res += x / {|m, (n, o)| m + n - o}(2, (b, 42)) + } + res + } + }, + vec!["a", "b"], + ), + ( + quote! { + |x1: u32, x2: u32| { + if y > x1 { + if let Some(z) = foo(x1) { + z + w + } else { + t = 42; + x2 + } + } + } + }, + vec!["y", "foo", "w", "t"], + ), + ]; + + for (expr, expected) in test_cases { + let mut expr = parse2(expr).unwrap(); + let result = expr_get_vars(&mut expr); + let result = result.into_iter().map(|v| v.to_string()).collect::>(); + let expected = expected.into_iter().map(|v| v.to_string()).collect::>(); + println!("result: {:?}", result); + println!("expected: {:?}\n", expected); + assert_eq!(result, expected) + } +} + +pub fn token_stream_replace_ident(ts: TokenStream, visitor: &mut dyn FnMut(&mut Ident)) -> TokenStream { + fn token_tree_replace_ident(mut tt: TokenTree, visitor: &mut dyn FnMut(&mut Ident)) -> TokenTree { + match tt { + TokenTree::Group(grp) => { + let updated_ts = token_stream_replace_ident(grp.stream(), visitor); + let new_grp = Group::new(grp.delimiter(), updated_ts); + TokenTree::Group(new_grp) + }, + TokenTree::Ident(ref mut ident) => { + visitor(ident); + tt + }, + TokenTree::Punct(_) => tt, + TokenTree::Literal(_) => tt, + } + } + + let mut new_tts = vec![]; + for tt in ts.into_iter() { + new_tts.push(token_tree_replace_ident(tt, visitor)); + } + TokenStream::from_iter(new_tts) +} + +pub fn token_stream_visit_idents(ts: TokenStream, visitor: &mut impl FnMut(&Ident)) { + fn token_tree_visit_idents(mut tt: TokenTree, visitor: &mut impl FnMut(&Ident)) { + match tt { + TokenTree::Group(grp) => { + token_stream_visit_idents(grp.stream(), visitor); + }, + TokenTree::Ident(ref mut ident) => visitor(ident), + TokenTree::Punct(_) => (), + TokenTree::Literal(_) => (), + } + } + + for tt in ts.into_iter() { + token_tree_visit_idents(tt, visitor); + } +} + +pub fn token_stream_idents(ts: TokenStream) -> Vec { + let mut res = vec![]; + token_stream_visit_idents(ts, &mut |ident| res.push(ident.clone())); + res +} + +pub fn expr_visit_macros_mut(expr: &mut Expr, visitor: &mut dyn FnMut(&mut ExprMacro)) { + struct Visitor<'a>(&'a mut dyn FnMut(&mut ExprMacro)); + impl<'a> syn::visit_mut::VisitMut for Visitor<'a> { + fn visit_expr_macro_mut(&mut self, node: &mut ExprMacro) { (self.0)(node) } + } + Visitor(visitor).visit_expr_mut(expr) +} + +pub fn expr_visit_idents_in_macros_mut(expr: &mut Expr, visitor: &mut dyn FnMut(&mut Ident)) { + let mut mac_visitor = |mac: &mut ExprMacro| { + update(&mut mac.mac.tokens, |ts| token_stream_replace_ident(ts, visitor)); + }; + expr_visit_macros_mut(expr, &mut mac_visitor) +} diff --git a/ascent_macro/src/test_errors.rs b/ascent_macro/src/test_errors.rs index 9a99add..a2c474e 100644 --- a/ascent_macro/src/test_errors.rs +++ b/ascent_macro/src/test_errors.rs @@ -1,22 +1,21 @@ -#![cfg(test)] -use crate::ascent_impl; - - -#[test] -fn test_agg_not_stratifiable() { - let inp = quote!{ - relation foo(i32, i32, i32); - relation bar(i32, i32); - relation baz(i32); - - baz(x) <-- - foo(x, _, _), - !bar(_, x); - - bar(x, x + 1) <-- baz(x); - }; - let res = ascent_impl(inp, false, false); - println!("res: {:?}", res); - assert!(res.is_err()); - assert!(res.unwrap_err().to_string().contains("bar")); -} +#![cfg(test)] +use crate::ascent_impl; + +#[test] +fn test_agg_not_stratifiable() { + let inp = quote! { + relation foo(i32, i32, i32); + relation bar(i32, i32); + relation baz(i32); + + baz(x) <-- + foo(x, _, _), + !bar(_, x); + + bar(x, x + 1) <-- baz(x); + }; + let res = ascent_impl(inp, false, false); + println!("res: {:?}", res); + assert!(res.is_err()); + assert!(res.unwrap_err().to_string().contains("bar")); +} diff --git a/ascent_macro/src/tests.rs b/ascent_macro/src/tests.rs index cefc4c3..200e4eb 100644 --- a/ascent_macro/src/tests.rs +++ b/ascent_macro/src/tests.rs @@ -1,544 +1,545 @@ -#![cfg(test)] -use petgraph::dot::{Config, Dot}; -use proc_macro2::TokenStream; - -use crate::ascent_impl; - - -#[test] -fn test_macro0() { - let inp = quote!{ - struct Polonius; - relation subset(T::Origin, T::Origin, T::Point);// = ctx.subset_base.clone(); - relation cfg_edge(T::Point, T::Point); - relation origin_live_on_entry(T::Origin, T::Point); - relation origin_contains_loan_on_entry(T::Origin, T::Loan, T::Point); - relation loan_live_at(T::Loan, T::Point); - relation loan_invalidated_at(T::Loan, T::Point); - relation errors(T::Loan, T::Point); - relation placeholder_origin(T::Origin); - relation subset_error(T::Origin, T::Origin, T::Point); - relation loan_killed_at(T::Loan, T::Point);// = loan_killed_at.iter().cloned().collect(); - relation known_placeholder_subset(T::Origin, T::Origin);// = known_placeholder_subset.iter().cloned().collect(); - - subset(origin1, origin3, point) <-- - subset(origin1, origin2, point), - subset(origin2, origin3, point), - if origin1 != origin3; - - subset(origin1, origin2, point2) <-- - subset(origin1, origin2, point1), - cfg_edge(point1, point2), - origin_live_on_entry(origin1, point2), - origin_live_on_entry(origin2, point2); - - origin_contains_loan_on_entry(origin2, loan, point) <-- - origin_contains_loan_on_entry(origin1, loan, point), - subset(origin1, origin2, point); - - origin_contains_loan_on_entry(origin, loan, point2) <-- - origin_contains_loan_on_entry(origin, loan, point1), - cfg_edge(point1, point2), - !loan_killed_at(loan, point1), - origin_live_on_entry(origin, point2); - - loan_live_at(loan, point) <-- - origin_contains_loan_on_entry(origin, loan, point), - origin_live_on_entry(origin, point); - - errors(loan, point) <-- - loan_invalidated_at(loan, point), - loan_live_at(loan, point); - - subset_error(origin1, origin2, point) <-- - subset(origin1, origin2, point), - placeholder_origin(origin1), - placeholder_origin(origin2), - !known_placeholder_subset(origin1, origin2), - if origin1 != origin2; - }; - // write_ascent_run_to_scratchpad(inp); - write_ascent_run_par_to_scratchpad(inp); -} -#[test] -fn test_macro_generic_tc() { - let inp = quote!{ - #![ds(custom_ds)] - struct TC where TNode: Clone + std::cmp::Eq + std::hash::Hash + Sync + Send; - #[ds(ascent::rel)] - relation edge(TNode, TNode); - relation path(TNode, TNode); - - path(x, z) <-- edge(x, y), path(y, z); - // path(x, z) <-- path(x, y), path(y, z); - }; - - // write_to_scratchpad(inp); - write_par_to_scratchpad(inp); -} - -#[test] -fn test_macro_multiple_dynamic_clauses() { - let inp = quote! { - relation a(i32, i32); - relation b(i32, i32); - relation c(i32, i32); - - a(y, z), - b(z, w), - c(x, y) <-- - a(x, y), - b(y, z), - c(z, w); - }; - write_to_scratchpad(inp); -} - -#[test] -fn test_macro_tc() { - let inp = quote!{ - // #![measure_rule_times] - struct TC; - relation edge(i32, i32); - relation path(i32, i32); - - path(x, y) <-- edge(x, y); - // path(x, z) <-- edge(x, y), path(y, z); - path(x, z) <-- path(x, y), path(y, z); - }; - - // write_to_scratchpad(inp); - write_par_to_scratchpad(inp); -} - -#[test] -fn test_macro2() { - let input = quote! { - relation foo(i32, Option); - relation bar(i32, i32); - relation baz(i32, i32, i32); - foo(1, Some(2)); - foo(2, None); - foo(3, Some(5)); - foo(4, Some(10)); - - - bar(3, 6); - bar(5, 10); - bar(10, 20); - - baz(*x, *y, *z) <-- foo(x, ?Some(y)), bar(y , z); - }; - - write_to_scratchpad(input); -} - -#[test] -fn test_clone_warning() { - let input = quote! { - struct Ancestory<'a>; - relation parent(&'a str, &'a str); - relation ancestor(&'a str,&'a str); - - ancestor(p, c) <-- parent(p, c); - - ancestor(p, gc) <-- - parent(p, c), ancestor(c, gc); - }; - - write_to_scratchpad(input); -} - -#[test] -fn test_macro_unary_rels() { - let input = quote! { - relation foo(i32); - relation bar(i32); - relation baz(i32, i32); - foo(1); - - bar(3); - - foo(x), bar(y) <-- baz(x, y); - baz(x, x + 1) <-- foo(x), bar(x); - }; - - write_to_scratchpad(input); -} - -#[test] -fn test_macro3() { - let input = quote! { - relation bar(i32, i32); - relation foo(i32, i32); - relation baz(i32, i32); - - foo(1, 2); - foo(10, 2); - bar(2, 3); - bar(2, 1); - - baz(*x, *z) <-- foo(x, y) if *x != 10, bar(y, z), if x != z; - foo(*x, *y), bar(*x, *y) <-- baz(x, y); - }; - - write_to_scratchpad(input); -} - -#[test] -fn test_macro_agg() { - let inp = quote! { - relation foo(i32); - relation bar(i32, i32, i32); - lattice baz(i32, i32); - - baz(x, min_z) <-- - foo(x), - agg min_z = min(z) in bar(x, _, z); - }; - write_to_scratchpad(inp); -} - -#[test] -fn test_macro_generator() { - let input = quote! { - relation edge(i32, i32); - relation path(i32, i32); - edge(x, x + 1) <-- for x in 0..100; - path(*x, *y) <-- edge(x,y); - path(*x, *z) <-- edge(x,y), path(y, z); - }; - - write_par_to_scratchpad(input); -} - -#[test] -fn test_macro_patterns() { - let input = quote! { - relation foo(i32, Option); - relation bar(i32, i32); - foo(1, None); - foo(2, Some(2)); - foo(3, Some(30)); - bar(*x, *y) <-- foo(x, ?Some(y)) if y != x; - bar(*x, *y) <-- foo(x, y_opt) if let Some(y) = y_opt if y != x; - }; - - write_to_scratchpad(input); -} - -#[test] -fn test_macro_sp(){ - let input = quote!{ - relation edge(i32, i32, u32); - lattice shortest_path(i32, i32, Dual); - - edge(1, 2, 30); - - shortest_path(x, y, Dual(*len)) <-- edge(x, y, len); - shortest_path(x, z, Dual(len + plen)) <-- edge(x, y, len), shortest_path(y, z, ?Dual(plen)); - }; - // write_to_scratchpad(input); - write_par_to_scratchpad(input); -} - -#[test] -fn test_lattice(){ - let input = quote! { - relation foo(i32, i32); - relation bar(i32, i32); - - bar(x, x+1) <-- for x in 0..10; - foo(*x, *y) <-- bar(x, y); - - lattice foo_as_set(ascent::lattice::set::Set<(i32, i32)>); - foo_as_set(ascent::lattice::set::Set::singleton((*x, *y))) <-- foo(x, y); - - relation baz(i32, i32); - baz(1, 2); - baz(1, 3); - - relation res(i32, i32); - res(*x, *y) <-- baz(x, y), foo_as_set(all_foos), if !all_foos.contains(&(*x, *y)); - }; - write_to_scratchpad(input); -} - -#[test] -fn test_macro_lattices(){ - let input = quote!{ - lattice longest_path(i32, i32, u32); - relation edge(i32, i32, u32); - - longest_path(x, y, ew) <-- edge(x, y, ew); - // longest_path(x, z, *ew + *w) <-- edge(x, y, ew), longest_path(y, z, w); - longest_path(x, z, *l1 + *l2) <-- longest_path(x, y, l1), longest_path(y, z, l2); - - - // edge(1,2, 3); - // edge(2,3, 5); - // edge(1,3, 4); - // edge(2,4, 10); - - }; - // write_to_scratchpad(input); - write_par_to_scratchpad(input); -} - -#[test] -fn test_no_generic(){ - let input = quote!{ - struct AscentProgram; - relation dummy(usize); - }; - // write_to_scratchpad(input); - write_to_scratchpad(input); -} - -#[test] -fn test_generic_ty(){ - let input = quote!{ - struct AscentProgram; - relation dummy(T); - }; - // write_to_scratchpad(input); - write_to_scratchpad(input); -} - -#[test] -fn test_generic_ty_where_clause(){ - let input = quote!{ - struct AscentProgram where T: Clone + Hash + Eq; - relation dummy(T); - }; - write_to_scratchpad(input); -} - -#[test] -fn test_generic_ty_with_divergent_impl_generics(){ - let input = quote!{ - struct AscentProgram; - impl AscentProgram; - relation dummy(T); - }; - write_to_scratchpad(input); -} - -#[test] -fn test_generic_ty_with_divergent_impl_generics_where_clause(){ - let input = quote!{ - /// Type DOC COMMENT - struct AscentProgram; - impl AscentProgram where T: Clone + Hash + Eq; - /// dummy REL DOC COMEMNT - relation dummy(T); - }; - write_to_scratchpad(input); -} - -#[test] -fn exp_borrowing(){ - // let mut v: Vec = vec![]; - // let mut u: Vec = vec![]; - // for i in 0..v.len(){ - // let v_row = &v[i]; - - // for j in 0..u.len(){ - // let u_row = &u[j]; - // let new_row = *u_row + *v_row; - // v.push(new_row); - // } - // } - - // let x: Vec = vec![42]; - // let y: Vec = Convert::convert(&x); - // let z: Vec = Convert::convert(x); -} - -#[test] -fn exp_condensation() { - use petgraph::algo::condensation; - use petgraph::prelude::*; - use petgraph::Graph; - - let mut graph: Graph<&'static str, (), Directed> = Graph::new(); - let a = graph.add_node("a"); // node with no weight - let b = graph.add_node("b"); - let c = graph.add_node("c"); - let d = graph.add_node("d"); - let e = graph.add_node("e"); - let f = graph.add_node("f"); - let g = graph.add_node("g"); - let h = graph.add_node("h"); - - // a ----> b ----> e ----> f - // ^ | ^ | - // | v | v - // d <---- c h <---- g - graph.extend_with_edges(&[(a, b), (b, c), (c, d), (d, a), (b, e), (e, f), (f, g), (g, h), (h, e)]); - let acyclic_condensed_graph = condensation(graph.clone(), true); - #[allow(non_snake_case)] - let (A, B) = (NodeIndex::new(0), NodeIndex::new(1)); - assert_eq!(acyclic_condensed_graph.node_count(), 2); - assert_eq!(acyclic_condensed_graph.edge_count(), 1); - assert_eq!(acyclic_condensed_graph.neighbors(B).collect::>(), vec![A]); - - println!("{:?}", Dot::with_config(&acyclic_condensed_graph, &[Config::EdgeNoLabel])); - - let sccs = petgraph::algo::tarjan_scc(&graph); - println!("sccs ordered:"); - for scc in sccs.iter(){ - println!("{:?}", scc); - } -} - -#[test] -fn exp_items_in_fn(){ - let mut p = Default::default(); - for i in 0..10 { - p = { - #[derive(Debug, Default)] - struct Point{x: i32, y: i32} - impl Point { - pub fn size(&self) -> i32 {self.x * self.x + self.y * self.y} - } - Point{x:i, y: i+1} - }; - } - println!("point is {:?}, with size {}", p, p.size()); -} - -fn write_to_scratchpad_base(tokens: TokenStream, prefix: TokenStream, is_ascent_run: bool, is_parallel: bool) -> TokenStream { - let code = ascent_impl(tokens, is_ascent_run, is_parallel); - let code = code.unwrap(); - let template = std::fs::read_to_string("src/scratchpad_template.rs").unwrap(); - let code_in_template = template.replace("todo!(\"here\");", &code.to_string()); - std::fs::write("src/scratchpad.rs", prefix.to_string()).unwrap(); - std::fs::write("src/scratchpad.rs", code_in_template).unwrap(); - std::process::Command::new("rustfmt").args(&["src/scratchpad.rs"]).spawn().unwrap().wait().unwrap(); - code -} - -fn write_to_scratchpad(tokens: TokenStream) -> TokenStream { - write_to_scratchpad_base(tokens, quote!{}, false, false) -} - -fn write_with_prefix_to_scratchpad(tokens: TokenStream, prefix: TokenStream) -> TokenStream { - write_to_scratchpad_base(tokens, prefix, false, false) -} - -fn write_par_to_scratchpad(tokens: TokenStream) -> TokenStream { - write_to_scratchpad_base(tokens, quote!{}, false, true) -} - -#[allow(unused)] -fn write_ascent_run_to_scratchpad(tokens: TokenStream) -> TokenStream { - write_to_scratchpad_base(tokens, quote!{}, true, false) -} - -fn write_ascent_run_par_to_scratchpad(tokens: TokenStream) -> TokenStream { - write_to_scratchpad_base(tokens, quote!{}, true, true) -} - - -#[test] -fn test_macro_lambda_calc(){ - let prefix = quote! { - #[derive(Clone, PartialEq, Eq, Debug, Hash)] - pub enum LambdaCalcExpr{ - Ref(&'static str), - Lam(&'static str, Rc), - App(Rc, Rc) - } - - use LambdaCalcExpr::*; - - impl LambdaCalcExpr { - #[allow(dead_code)] - fn depth(&self) -> usize { - match self{ - LambdaCalcExpr::Ref(_) => 0, - LambdaCalcExpr::Lam(_x,b) => 1 + b.depth(), - LambdaCalcExpr::App(f,e) => 1 + max(f.depth(), e.depth()) - } - } - } - fn app(f: LambdaCalcExpr, a: LambdaCalcExpr) -> LambdaCalcExpr { - App(Rc::new(f), Rc::new(a)) - } - fn lam(x: &'static str, e: LambdaCalcExpr) -> LambdaCalcExpr { - Lam(x, Rc::new(e)) - } - - fn sub(exp: &LambdaCalcExpr, var: &str, e: &LambdaCalcExpr) -> LambdaCalcExpr { - match exp { - Ref(x) if *x == var => e.clone(), - Ref(_x) => exp.clone(), - App(ef,ea) => app(sub(ef, var, e), sub(ea, var, e)), - Lam(x, _eb) if *x == var => exp.clone(), - Lam(x, eb) => lam(x, sub(eb, var, e)) - } - } - - #[allow(non_snake_case)] - fn U() -> LambdaCalcExpr {lam("x", app(Ref("x"), Ref("x")))} - #[allow(non_snake_case)] - fn I() -> LambdaCalcExpr {lam("x", Ref("x"))} - - fn min<'a>(inp: impl Iterator) -> impl Iterator { - inp.map(|tuple| tuple.0).min().cloned().into_iter() - } - }; - let inp = quote!{ - relation output(LambdaCalcExpr); - relation input(LambdaCalcExpr); - relation do_eval(LambdaCalcExpr); - relation eval(LambdaCalcExpr, LambdaCalcExpr); - - input(app(U(), I())); - do_eval(exp.clone()) <-- input(exp); - output(res.clone()) <-- input(exp), eval(exp, res); - - eval(exp.clone(), exp.clone()) <-- do_eval(?exp @Ref(_)); - - eval(exp.clone(), exp.clone()) <-- do_eval(exp), if let Lam(_,_) = exp; - - do_eval(ef.as_ref().clone()) <-- do_eval(?App(ef,_ea)); - - do_eval(sub(fb, fx, ea)) <-- - do_eval(?App(ef, ea)), - eval(ef.deref(), ?Lam(fx, fb)); - - eval(exp.clone(), final_res.clone()) <-- - do_eval(?exp @ App(ef, ea)), // this requires nightly - eval(ef.deref(), ?Lam(fx, fb)), - eval(sub(fb, fx, ea), final_res); - }; - write_with_prefix_to_scratchpad(inp, prefix); -} - -#[test] -fn test_macro_in_macro() { - let inp = quote!{ - relation foo(i32, i32); - relation bar(i32, i32); - - macro foo_($x: expr, $y: expr) { - foo($x, $y) - } - - macro foo($x: expr, $y: expr) { - let _x = $x, let _y = $y, foo_!(_x, _y) - } - - foo(0, 1); - foo(1, 2); - foo(2, 3); - foo(3, 4); - - bar(x, y) <-- foo(x, y), foo!(x + 1, y + 1), foo!(x + 2, y + 2), foo!(x + 3, y + 3); - - }; - - write_to_scratchpad(inp); -} +#![cfg(test)] +use petgraph::dot::{Config, Dot}; +use proc_macro2::TokenStream; + +use crate::ascent_impl; + +#[test] +fn test_macro0() { + let inp = quote! { + struct Polonius; + relation subset(T::Origin, T::Origin, T::Point);// = ctx.subset_base.clone(); + relation cfg_edge(T::Point, T::Point); + relation origin_live_on_entry(T::Origin, T::Point); + relation origin_contains_loan_on_entry(T::Origin, T::Loan, T::Point); + relation loan_live_at(T::Loan, T::Point); + relation loan_invalidated_at(T::Loan, T::Point); + relation errors(T::Loan, T::Point); + relation placeholder_origin(T::Origin); + relation subset_error(T::Origin, T::Origin, T::Point); + relation loan_killed_at(T::Loan, T::Point);// = loan_killed_at.iter().cloned().collect(); + relation known_placeholder_subset(T::Origin, T::Origin);// = known_placeholder_subset.iter().cloned().collect(); + + subset(origin1, origin3, point) <-- + subset(origin1, origin2, point), + subset(origin2, origin3, point), + if origin1 != origin3; + + subset(origin1, origin2, point2) <-- + subset(origin1, origin2, point1), + cfg_edge(point1, point2), + origin_live_on_entry(origin1, point2), + origin_live_on_entry(origin2, point2); + + origin_contains_loan_on_entry(origin2, loan, point) <-- + origin_contains_loan_on_entry(origin1, loan, point), + subset(origin1, origin2, point); + + origin_contains_loan_on_entry(origin, loan, point2) <-- + origin_contains_loan_on_entry(origin, loan, point1), + cfg_edge(point1, point2), + !loan_killed_at(loan, point1), + origin_live_on_entry(origin, point2); + + loan_live_at(loan, point) <-- + origin_contains_loan_on_entry(origin, loan, point), + origin_live_on_entry(origin, point); + + errors(loan, point) <-- + loan_invalidated_at(loan, point), + loan_live_at(loan, point); + + subset_error(origin1, origin2, point) <-- + subset(origin1, origin2, point), + placeholder_origin(origin1), + placeholder_origin(origin2), + !known_placeholder_subset(origin1, origin2), + if origin1 != origin2; + }; + // write_ascent_run_to_scratchpad(inp); + write_ascent_run_par_to_scratchpad(inp); +} +#[test] +fn test_macro_generic_tc() { + let inp = quote! { + #![ds(custom_ds)] + struct TC where TNode: Clone + std::cmp::Eq + std::hash::Hash + Sync + Send; + #[ds(ascent::rel)] + relation edge(TNode, TNode); + relation path(TNode, TNode); + + path(x, z) <-- edge(x, y), path(y, z); + // path(x, z) <-- path(x, y), path(y, z); + }; + + // write_to_scratchpad(inp); + write_par_to_scratchpad(inp); +} + +#[test] +fn test_macro_multiple_dynamic_clauses() { + let inp = quote! { + relation a(i32, i32); + relation b(i32, i32); + relation c(i32, i32); + + a(y, z), + b(z, w), + c(x, y) <-- + a(x, y), + b(y, z), + c(z, w); + }; + write_to_scratchpad(inp); +} + +#[test] +fn test_macro_tc() { + let inp = quote! { + // #![measure_rule_times] + struct TC; + relation edge(i32, i32); + relation path(i32, i32); + + path(x, y) <-- edge(x, y); + // path(x, z) <-- edge(x, y), path(y, z); + path(x, z) <-- path(x, y), path(y, z); + }; + + // write_to_scratchpad(inp); + write_par_to_scratchpad(inp); +} + +#[test] +fn test_macro2() { + let input = quote! { + relation foo(i32, Option); + relation bar(i32, i32); + relation baz(i32, i32, i32); + foo(1, Some(2)); + foo(2, None); + foo(3, Some(5)); + foo(4, Some(10)); + + + bar(3, 6); + bar(5, 10); + bar(10, 20); + + baz(*x, *y, *z) <-- foo(x, ?Some(y)), bar(y , z); + }; + + write_to_scratchpad(input); +} + +#[test] +fn test_clone_warning() { + let input = quote! { + struct Ancestory<'a>; + relation parent(&'a str, &'a str); + relation ancestor(&'a str,&'a str); + + ancestor(p, c) <-- parent(p, c); + + ancestor(p, gc) <-- + parent(p, c), ancestor(c, gc); + }; + + write_to_scratchpad(input); +} + +#[test] +fn test_macro_unary_rels() { + let input = quote! { + relation foo(i32); + relation bar(i32); + relation baz(i32, i32); + foo(1); + + bar(3); + + foo(x), bar(y) <-- baz(x, y); + baz(x, x + 1) <-- foo(x), bar(x); + }; + + write_to_scratchpad(input); +} + +#[test] +fn test_macro3() { + let input = quote! { + relation bar(i32, i32); + relation foo(i32, i32); + relation baz(i32, i32); + + foo(1, 2); + foo(10, 2); + bar(2, 3); + bar(2, 1); + + baz(*x, *z) <-- foo(x, y) if *x != 10, bar(y, z), if x != z; + foo(*x, *y), bar(*x, *y) <-- baz(x, y); + }; + + write_to_scratchpad(input); +} + +#[test] +fn test_macro_agg() { + let inp = quote! { + relation foo(i32); + relation bar(i32, i32, i32); + lattice baz(i32, i32); + + baz(x, min_z) <-- + foo(x), + agg min_z = min(z) in bar(x, _, z); + }; + write_to_scratchpad(inp); +} + +#[test] +fn test_macro_generator() { + let input = quote! { + relation edge(i32, i32); + relation path(i32, i32); + edge(x, x + 1) <-- for x in 0..100; + path(*x, *y) <-- edge(x,y); + path(*x, *z) <-- edge(x,y), path(y, z); + }; + + write_par_to_scratchpad(input); +} + +#[test] +fn test_macro_patterns() { + let input = quote! { + relation foo(i32, Option); + relation bar(i32, i32); + foo(1, None); + foo(2, Some(2)); + foo(3, Some(30)); + bar(*x, *y) <-- foo(x, ?Some(y)) if y != x; + bar(*x, *y) <-- foo(x, y_opt) if let Some(y) = y_opt if y != x; + }; + + write_to_scratchpad(input); +} + +#[test] +fn test_macro_sp() { + let input = quote! { + relation edge(i32, i32, u32); + lattice shortest_path(i32, i32, Dual); + + edge(1, 2, 30); + + shortest_path(x, y, Dual(*len)) <-- edge(x, y, len); + shortest_path(x, z, Dual(len + plen)) <-- edge(x, y, len), shortest_path(y, z, ?Dual(plen)); + }; + // write_to_scratchpad(input); + write_par_to_scratchpad(input); +} + +#[test] +fn test_lattice() { + let input = quote! { + relation foo(i32, i32); + relation bar(i32, i32); + + bar(x, x+1) <-- for x in 0..10; + foo(*x, *y) <-- bar(x, y); + + lattice foo_as_set(ascent::lattice::set::Set<(i32, i32)>); + foo_as_set(ascent::lattice::set::Set::singleton((*x, *y))) <-- foo(x, y); + + relation baz(i32, i32); + baz(1, 2); + baz(1, 3); + + relation res(i32, i32); + res(*x, *y) <-- baz(x, y), foo_as_set(all_foos), if !all_foos.contains(&(*x, *y)); + }; + write_to_scratchpad(input); +} + +#[test] +fn test_macro_lattices() { + let input = quote! { + lattice longest_path(i32, i32, u32); + relation edge(i32, i32, u32); + + longest_path(x, y, ew) <-- edge(x, y, ew); + // longest_path(x, z, *ew + *w) <-- edge(x, y, ew), longest_path(y, z, w); + longest_path(x, z, *l1 + *l2) <-- longest_path(x, y, l1), longest_path(y, z, l2); + + + // edge(1,2, 3); + // edge(2,3, 5); + // edge(1,3, 4); + // edge(2,4, 10); + + }; + // write_to_scratchpad(input); + write_par_to_scratchpad(input); +} + +#[test] +fn test_no_generic() { + let input = quote! { + struct AscentProgram; + relation dummy(usize); + }; + // write_to_scratchpad(input); + write_to_scratchpad(input); +} + +#[test] +fn test_generic_ty() { + let input = quote! { + struct AscentProgram; + relation dummy(T); + }; + // write_to_scratchpad(input); + write_to_scratchpad(input); +} + +#[test] +fn test_generic_ty_where_clause() { + let input = quote! { + struct AscentProgram where T: Clone + Hash + Eq; + relation dummy(T); + }; + write_to_scratchpad(input); +} + +#[test] +fn test_generic_ty_with_divergent_impl_generics() { + let input = quote! { + struct AscentProgram; + impl AscentProgram; + relation dummy(T); + }; + write_to_scratchpad(input); +} + +#[test] +fn test_generic_ty_with_divergent_impl_generics_where_clause() { + let input = quote! { + /// Type DOC COMMENT + struct AscentProgram; + impl AscentProgram where T: Clone + Hash + Eq; + /// dummy REL DOC COMEMNT + relation dummy(T); + }; + write_to_scratchpad(input); +} + +#[test] +fn exp_borrowing() { + // let mut v: Vec = vec![]; + // let mut u: Vec = vec![]; + // for i in 0..v.len(){ + // let v_row = &v[i]; + + // for j in 0..u.len(){ + // let u_row = &u[j]; + // let new_row = *u_row + *v_row; + // v.push(new_row); + // } + // } + + // let x: Vec = vec![42]; + // let y: Vec = Convert::convert(&x); + // let z: Vec = Convert::convert(x); +} + +#[test] +fn exp_condensation() { + use petgraph::Graph; + use petgraph::algo::condensation; + use petgraph::prelude::*; + + let mut graph: Graph<&'static str, (), Directed> = Graph::new(); + let a = graph.add_node("a"); // node with no weight + let b = graph.add_node("b"); + let c = graph.add_node("c"); + let d = graph.add_node("d"); + let e = graph.add_node("e"); + let f = graph.add_node("f"); + let g = graph.add_node("g"); + let h = graph.add_node("h"); + + // a ----> b ----> e ----> f + // ^ | ^ | + // | v | v + // d <---- c h <---- g + graph.extend_with_edges(&[(a, b), (b, c), (c, d), (d, a), (b, e), (e, f), (f, g), (g, h), (h, e)]); + let acyclic_condensed_graph = condensation(graph.clone(), true); + #[allow(non_snake_case)] + let (A, B) = (NodeIndex::new(0), NodeIndex::new(1)); + assert_eq!(acyclic_condensed_graph.node_count(), 2); + assert_eq!(acyclic_condensed_graph.edge_count(), 1); + assert_eq!(acyclic_condensed_graph.neighbors(B).collect::>(), vec![A]); + + println!("{:?}", Dot::with_config(&acyclic_condensed_graph, &[Config::EdgeNoLabel])); + + let sccs = petgraph::algo::tarjan_scc(&graph); + println!("sccs ordered:"); + for scc in sccs.iter() { + println!("{:?}", scc); + } +} + +#[test] +fn exp_items_in_fn() { + let mut p = Default::default(); + for i in 0..10 { + p = { + #[derive(Debug, Default)] + struct Point { + x: i32, + y: i32, + } + impl Point { + pub fn size(&self) -> i32 { self.x * self.x + self.y * self.y } + } + Point { x: i, y: i + 1 } + }; + } + println!("point is {:?}, with size {}", p, p.size()); +} + +fn write_to_scratchpad_base( + tokens: TokenStream, prefix: TokenStream, is_ascent_run: bool, is_parallel: bool, +) -> TokenStream { + let code = ascent_impl(tokens, is_ascent_run, is_parallel); + let code = code.unwrap(); + let template = std::fs::read_to_string("src/scratchpad_template.rs").unwrap(); + let code_in_template = template.replace("todo!(\"here\");", &code.to_string()); + std::fs::write("src/scratchpad.rs", prefix.to_string()).unwrap(); + std::fs::write("src/scratchpad.rs", code_in_template).unwrap(); + std::process::Command::new("rustfmt").args(&["src/scratchpad.rs"]).spawn().unwrap().wait().unwrap(); + code +} + +fn write_to_scratchpad(tokens: TokenStream) -> TokenStream { write_to_scratchpad_base(tokens, quote! {}, false, false) } + +fn write_with_prefix_to_scratchpad(tokens: TokenStream, prefix: TokenStream) -> TokenStream { + write_to_scratchpad_base(tokens, prefix, false, false) +} + +fn write_par_to_scratchpad(tokens: TokenStream) -> TokenStream { + write_to_scratchpad_base(tokens, quote! {}, false, true) +} + +#[allow(unused)] +fn write_ascent_run_to_scratchpad(tokens: TokenStream) -> TokenStream { + write_to_scratchpad_base(tokens, quote! {}, true, false) +} + +fn write_ascent_run_par_to_scratchpad(tokens: TokenStream) -> TokenStream { + write_to_scratchpad_base(tokens, quote! {}, true, true) +} + +#[test] +fn test_macro_lambda_calc() { + let prefix = quote! { + #[derive(Clone, PartialEq, Eq, Debug, Hash)] + pub enum LambdaCalcExpr{ + Ref(&'static str), + Lam(&'static str, Rc), + App(Rc, Rc) + } + + use LambdaCalcExpr::*; + + impl LambdaCalcExpr { + #[allow(dead_code)] + fn depth(&self) -> usize { + match self{ + LambdaCalcExpr::Ref(_) => 0, + LambdaCalcExpr::Lam(_x,b) => 1 + b.depth(), + LambdaCalcExpr::App(f,e) => 1 + max(f.depth(), e.depth()) + } + } + } + fn app(f: LambdaCalcExpr, a: LambdaCalcExpr) -> LambdaCalcExpr { + App(Rc::new(f), Rc::new(a)) + } + fn lam(x: &'static str, e: LambdaCalcExpr) -> LambdaCalcExpr { + Lam(x, Rc::new(e)) + } + + fn sub(exp: &LambdaCalcExpr, var: &str, e: &LambdaCalcExpr) -> LambdaCalcExpr { + match exp { + Ref(x) if *x == var => e.clone(), + Ref(_x) => exp.clone(), + App(ef,ea) => app(sub(ef, var, e), sub(ea, var, e)), + Lam(x, _eb) if *x == var => exp.clone(), + Lam(x, eb) => lam(x, sub(eb, var, e)) + } + } + + #[allow(non_snake_case)] + fn U() -> LambdaCalcExpr {lam("x", app(Ref("x"), Ref("x")))} + #[allow(non_snake_case)] + fn I() -> LambdaCalcExpr {lam("x", Ref("x"))} + + fn min<'a>(inp: impl Iterator) -> impl Iterator { + inp.map(|tuple| tuple.0).min().cloned().into_iter() + } + }; + let inp = quote! { + relation output(LambdaCalcExpr); + relation input(LambdaCalcExpr); + relation do_eval(LambdaCalcExpr); + relation eval(LambdaCalcExpr, LambdaCalcExpr); + + input(app(U(), I())); + do_eval(exp.clone()) <-- input(exp); + output(res.clone()) <-- input(exp), eval(exp, res); + + eval(exp.clone(), exp.clone()) <-- do_eval(?exp @Ref(_)); + + eval(exp.clone(), exp.clone()) <-- do_eval(exp), if let Lam(_,_) = exp; + + do_eval(ef.as_ref().clone()) <-- do_eval(?App(ef,_ea)); + + do_eval(sub(fb, fx, ea)) <-- + do_eval(?App(ef, ea)), + eval(ef.deref(), ?Lam(fx, fb)); + + eval(exp.clone(), final_res.clone()) <-- + do_eval(?exp @ App(ef, ea)), // this requires nightly + eval(ef.deref(), ?Lam(fx, fb)), + eval(sub(fb, fx, ea), final_res); + }; + write_with_prefix_to_scratchpad(inp, prefix); +} + +#[test] +fn test_macro_in_macro() { + let inp = quote! { + relation foo(i32, i32); + relation bar(i32, i32); + + macro foo_($x: expr, $y: expr) { + foo($x, $y) + } + + macro foo($x: expr, $y: expr) { + let _x = $x, let _y = $y, foo_!(_x, _y) + } + + foo(0, 1); + foo(1, 2); + foo(2, 3); + foo(3, 4); + + bar(x, y) <-- foo(x, y), foo!(x + 1, y + 1), foo!(x + 2, y + 2), foo!(x + 3, y + 3); + + }; + + write_to_scratchpad(inp); +} diff --git a/ascent_macro/src/utils.rs b/ascent_macro/src/utils.rs index db1b384..8fc9599 100644 --- a/ascent_macro/src/utils.rs +++ b/ascent_macro/src/utils.rs @@ -1,262 +1,270 @@ -#![deny(warnings)] -use std::collections::HashMap; -use std::collections::HashSet; -use std::hash::Hash; - -use proc_macro2::{Ident, Span, TokenStream, TokenTree, Group}; - -use syn::{Expr, Pat, Type, punctuated::Punctuated, spanned::Spanned}; - -use crate::syn_utils::path_get_ident_mut; - -pub fn tuple_type(types: &[Type]) -> Type { - let res = match types.len() { - 1 => { - let ty = &types[0]; - quote! { ( #ty, ) } - }, - _ => quote! { ( #(#types),* ) } - }; - syn::parse2(res).unwrap() -} - -pub fn tuple(exprs: &[Expr]) -> Expr { - let span = if !exprs.is_empty() {exprs[0].span()} else {Span::call_site()}; - tuple_spanned(exprs, span) -} -pub fn tuple_spanned(exprs: &[Expr], span: Span) -> Expr { - let res = match exprs.len() { - 1 => { - let exp = &exprs[0]; - quote_spanned! {span=> ( #exp, ) } - }, - _ => quote_spanned! {span=> ( #(#exprs),* ) } - }; - syn::parse2(res).unwrap() -} - -pub fn exp_cloned(exp: &Expr) -> Expr { - let exp_span = exp.span(); - let res = match exp { - Expr::Path(_) | - Expr::Field(_) | - Expr::Paren(_) => quote_spanned! {exp_span=> #exp.clone()}, - _ => quote_spanned! {exp_span=> (#exp).clone()} - }; - syn::parse2(res).unwrap() -} - -pub fn collect_set(iter : impl Iterator) -> HashSet { - iter.collect() -} - -pub fn into_set(iter : impl IntoIterator) -> HashSet { - iter.into_iter().collect() -} - -pub fn punctuated_map(punc: Punctuated, mut f: impl FnMut (T) -> U) -> Punctuated { - let mut res = Punctuated::new(); - for pair in punc.into_pairs() { - let (t, p) = pair.into_tuple(); - res.push_value(f(t)); - if let Some(p) = p {res.push_punct(p)} - }; - res -} - -pub fn punctuated_try_map(punc: Punctuated, mut f: impl FnMut (T) -> Result) -> Result, E> { - let mut res = Punctuated::new(); - for pair in punc.into_pairs() { - let (t, p) = pair.into_tuple(); - res.push_value(f(t)?); - if let Some(p) = p {res.push_punct(p)} - }; - Ok(res) -} - -pub fn flatten_punctuated(punc: Punctuated, P>) -> Punctuated { - let mut res = Punctuated::new(); - for inner_punc in punc.into_pairs() { - let (inner_punc, p) = inner_punc.into_tuple(); - let inner_punc_len = inner_punc.len(); - for (ind, item) in inner_punc.into_pairs().enumerate() { - let (t, p) = item.into_tuple(); - res.push_value(t); - if ind != inner_punc_len - 1 { - res.push_punct(p.unwrap()) - } - } - if let Some(p) = p {res.push_punct(p)} - } - res -} - -pub fn punctuated_try_unwrap(punc: Punctuated,P>) -> Result, E> { - let mut res = Punctuated::new(); - for pair in punc.into_pairs() { - let (t, p) = pair.into_tuple(); - res.push_value(t?); - if let Some(p) = p {res.push_punct(p)} - } - Ok(res) -} - -pub fn punctuated_singleton(item: T) -> Punctuated { - let mut res = Punctuated::new(); - res.push_value(item); - res -} - -pub fn expr_to_ident(expr: &Expr) -> Option { - match expr { - Expr::Path(p) => p.path.get_ident().cloned(), - _ => None - } -} -pub fn expr_to_ident_mut(expr: &mut Expr) -> Option<&mut Ident> { - match expr { - Expr::Path(p) => path_get_ident_mut(&mut p.path), - _ => None - } -} - -pub fn pat_to_ident(pat: &Pat) -> Option { - match pat { - Pat::Ident(ident) => Some(ident.ident.clone()), - _ => None - } -} - -pub fn is_wild_card(expr: &Expr) -> bool { - match expr { - Expr::Infer(_) => true, - Expr::Verbatim(ts) => ts.to_string() == "_", - _ => false - } -} - - -pub fn token_stream_replace_macro_idents(input: TokenStream, ident_replacements: &HashMap) -> TokenStream { - - fn ts_replace(ts: TokenStream, ident_replacements: &HashMap, res: &mut Vec){ - - - let mut last_dollar = None; - for tt in ts { - if let Some(dollar) = last_dollar.take() { - let is_match = match &tt { - TokenTree::Ident(after_dollar_ident) => ident_replacements.get(after_dollar_ident), - _ => None - }; - if let Some(replacement) = is_match { - res.extend(replacement.clone()); - continue; - } else { - res.push(dollar); - } - } - let is_dollar = match &tt { - TokenTree::Punct(punct) => punct.as_char() == '$', - _ => false - }; - if is_dollar { - last_dollar = Some(tt); - } else { - match tt { - TokenTree::Group(grp) => { - let replaced = token_stream_replace_macro_idents(grp.stream(), ident_replacements); - let updated_group = Group::new(grp.delimiter(), replaced); - res.push(TokenTree::Group(updated_group)); - }, - _ => { - res.push(tt) - } - } - } - } - if let Some(dollar) = last_dollar { - res.push(dollar); - } - } - - let mut res = vec![]; - ts_replace(input, ident_replacements, &mut res); - - res.into_iter().collect() -} - -pub fn spans_eq(span1: &Span, span2: &Span) -> bool { - format!("{:?}", span1) == format!("{:?}", span2) -} - -// I don't know why I'm like this -pub trait Piper : Sized { - /// applies `f` to `self`, i.e., `f(self)` - fn pipe(self, f: impl FnOnce(Self) -> Res) -> Res; -} - -impl Piper for T where T: Sized { - #[inline(always)] - fn pipe(self, f: impl FnOnce(Self) -> Res) -> Res { - f(self) - } -} - - -/// sets the span of only top-level tokens to `span` -pub fn with_span(ts: TokenStream, span: Span) -> TokenStream { - ts.into_iter().map(|mut tt| { - tt.set_span(span); - tt - }).collect() -} - -pub(crate) trait TokenStreamExtensions { - fn with_span(self, span: Span) -> TokenStream; -} - -impl TokenStreamExtensions for TokenStream { - fn with_span(self, span: Span) -> TokenStream { with_span(self, span) } -} - -fn check_lazy_set_contains(hs: &mut HashSet, iter: &mut impl Iterator, x: T) -> bool { - if hs.contains(&x) { return true } - - for item in iter { - let eq = item == x; - hs.insert(item); - if eq { return true } - } - false -} - -pub fn intersects(set1: Iter1, set2: Iter2) -> bool -where T: Hash + Eq, Iter1: IntoIterator, Iter2: IntoIterator -{ - let mut hs = HashSet::default(); - let mut set1 = set1.into_iter(); - for x in set2 { - if check_lazy_set_contains(&mut hs, &mut set1, x) { return true } - } - false -} - -#[test] -fn test_subsumes_and_intersects() { - let cases = [ - (vec![1, 2, 3], vec![3, 3, 4], false, true), - (vec![1, 2, 3], vec![3, 3, 2], true, true), - (vec![1, 2, 3, 4], vec![1, 3, 4, 2], true, true), - (vec![1, 2, 3, 4], vec![], true, false), - (vec![1, 2, 3, 4], vec![4, 2, 3, 1, 1, 2, 3, 4], true, true), - (vec![1, 2, 3], vec![4], false, false), - (vec![], vec![4], false, false), - (vec![], vec![], true, false), - (vec![1, 2], vec![3, 4], false, false), - (vec![1, 2, 3, 4], vec![5, 6, 7, 1], false, true) - ]; - for (s1, s2, _subsumes_expected, intersects_expected) in cases { - println!("s1: {:?}, s2: {:?}", s1, s2); - assert_eq!(intersects(s1.iter(), s2.iter()), intersects_expected); - } -} \ No newline at end of file +#![deny(warnings)] +use std::collections::{HashMap, HashSet}; +use std::hash::Hash; + +use proc_macro2::{Group, Ident, Span, TokenStream, TokenTree}; +use syn::punctuated::Punctuated; +use syn::spanned::Spanned; +use syn::{Expr, Pat, Type}; + +use crate::syn_utils::path_get_ident_mut; + +pub fn tuple_type(types: &[Type]) -> Type { + let res = match types.len() { + 1 => { + let ty = &types[0]; + quote! { ( #ty, ) } + }, + _ => quote! { ( #(#types),* ) }, + }; + syn::parse2(res).unwrap() +} + +pub fn tuple(exprs: &[Expr]) -> Expr { + let span = if !exprs.is_empty() { exprs[0].span() } else { Span::call_site() }; + tuple_spanned(exprs, span) +} +pub fn tuple_spanned(exprs: &[Expr], span: Span) -> Expr { + let res = match exprs.len() { + 1 => { + let exp = &exprs[0]; + quote_spanned! {span=> ( #exp, ) } + }, + _ => quote_spanned! {span=> ( #(#exprs),* ) }, + }; + syn::parse2(res).unwrap() +} + +pub fn exp_cloned(exp: &Expr) -> Expr { + let exp_span = exp.span(); + let res = match exp { + Expr::Path(_) | Expr::Field(_) | Expr::Paren(_) => quote_spanned! {exp_span=> #exp.clone()}, + _ => quote_spanned! {exp_span=> (#exp).clone()}, + }; + syn::parse2(res).unwrap() +} + +pub fn collect_set(iter: impl Iterator) -> HashSet { iter.collect() } + +pub fn into_set(iter: impl IntoIterator) -> HashSet { iter.into_iter().collect() } + +pub fn punctuated_map(punc: Punctuated, mut f: impl FnMut(T) -> U) -> Punctuated { + let mut res = Punctuated::new(); + for pair in punc.into_pairs() { + let (t, p) = pair.into_tuple(); + res.push_value(f(t)); + if let Some(p) = p { + res.push_punct(p) + } + } + res +} + +pub fn punctuated_try_map( + punc: Punctuated, mut f: impl FnMut(T) -> Result, +) -> Result, E> { + let mut res = Punctuated::new(); + for pair in punc.into_pairs() { + let (t, p) = pair.into_tuple(); + res.push_value(f(t)?); + if let Some(p) = p { + res.push_punct(p) + } + } + Ok(res) +} + +pub fn flatten_punctuated(punc: Punctuated, P>) -> Punctuated { + let mut res = Punctuated::new(); + for inner_punc in punc.into_pairs() { + let (inner_punc, p) = inner_punc.into_tuple(); + let inner_punc_len = inner_punc.len(); + for (ind, item) in inner_punc.into_pairs().enumerate() { + let (t, p) = item.into_tuple(); + res.push_value(t); + if ind != inner_punc_len - 1 { + res.push_punct(p.unwrap()) + } + } + if let Some(p) = p { + res.push_punct(p) + } + } + res +} + +pub fn punctuated_try_unwrap(punc: Punctuated, P>) -> Result, E> { + let mut res = Punctuated::new(); + for pair in punc.into_pairs() { + let (t, p) = pair.into_tuple(); + res.push_value(t?); + if let Some(p) = p { + res.push_punct(p) + } + } + Ok(res) +} + +pub fn punctuated_singleton(item: T) -> Punctuated { + let mut res = Punctuated::new(); + res.push_value(item); + res +} + +pub fn expr_to_ident(expr: &Expr) -> Option { + match expr { + Expr::Path(p) => p.path.get_ident().cloned(), + _ => None, + } +} +pub fn expr_to_ident_mut(expr: &mut Expr) -> Option<&mut Ident> { + match expr { + Expr::Path(p) => path_get_ident_mut(&mut p.path), + _ => None, + } +} + +pub fn pat_to_ident(pat: &Pat) -> Option { + match pat { + Pat::Ident(ident) => Some(ident.ident.clone()), + _ => None, + } +} + +pub fn is_wild_card(expr: &Expr) -> bool { + match expr { + Expr::Infer(_) => true, + Expr::Verbatim(ts) => ts.to_string() == "_", + _ => false, + } +} + +pub fn token_stream_replace_macro_idents( + input: TokenStream, ident_replacements: &HashMap, +) -> TokenStream { + fn ts_replace(ts: TokenStream, ident_replacements: &HashMap, res: &mut Vec) { + let mut last_dollar = None; + for tt in ts { + if let Some(dollar) = last_dollar.take() { + let is_match = match &tt { + TokenTree::Ident(after_dollar_ident) => ident_replacements.get(after_dollar_ident), + _ => None, + }; + if let Some(replacement) = is_match { + res.extend(replacement.clone()); + continue; + } else { + res.push(dollar); + } + } + let is_dollar = match &tt { + TokenTree::Punct(punct) => punct.as_char() == '$', + _ => false, + }; + if is_dollar { + last_dollar = Some(tt); + } else { + match tt { + TokenTree::Group(grp) => { + let replaced = token_stream_replace_macro_idents(grp.stream(), ident_replacements); + let updated_group = Group::new(grp.delimiter(), replaced); + res.push(TokenTree::Group(updated_group)); + }, + _ => res.push(tt), + } + } + } + if let Some(dollar) = last_dollar { + res.push(dollar); + } + } + + let mut res = vec![]; + ts_replace(input, ident_replacements, &mut res); + + res.into_iter().collect() +} + +pub fn spans_eq(span1: &Span, span2: &Span) -> bool { format!("{:?}", span1) == format!("{:?}", span2) } + +// I don't know why I'm like this +pub trait Piper: Sized { + /// applies `f` to `self`, i.e., `f(self)` + fn pipe(self, f: impl FnOnce(Self) -> Res) -> Res; +} + +impl Piper for T +where T: Sized +{ + #[inline(always)] + fn pipe(self, f: impl FnOnce(Self) -> Res) -> Res { f(self) } +} + +/// sets the span of only top-level tokens to `span` +pub fn with_span(ts: TokenStream, span: Span) -> TokenStream { + ts.into_iter() + .map(|mut tt| { + tt.set_span(span); + tt + }) + .collect() +} + +pub(crate) trait TokenStreamExtensions { + fn with_span(self, span: Span) -> TokenStream; +} + +impl TokenStreamExtensions for TokenStream { + fn with_span(self, span: Span) -> TokenStream { with_span(self, span) } +} + +fn check_lazy_set_contains(hs: &mut HashSet, iter: &mut impl Iterator, x: T) -> bool { + if hs.contains(&x) { + return true + } + + for item in iter { + let eq = item == x; + hs.insert(item); + if eq { + return true + } + } + false +} + +pub fn intersects(set1: Iter1, set2: Iter2) -> bool +where + T: Hash + Eq, + Iter1: IntoIterator, + Iter2: IntoIterator, +{ + let mut hs = HashSet::default(); + let mut set1 = set1.into_iter(); + for x in set2 { + if check_lazy_set_contains(&mut hs, &mut set1, x) { + return true + } + } + false +} + +#[test] +fn test_subsumes_and_intersects() { + let cases = [ + (vec![1, 2, 3], vec![3, 3, 4], false, true), + (vec![1, 2, 3], vec![3, 3, 2], true, true), + (vec![1, 2, 3, 4], vec![1, 3, 4, 2], true, true), + (vec![1, 2, 3, 4], vec![], true, false), + (vec![1, 2, 3, 4], vec![4, 2, 3, 1, 1, 2, 3, 4], true, true), + (vec![1, 2, 3], vec![4], false, false), + (vec![], vec![4], false, false), + (vec![], vec![], true, false), + (vec![1, 2], vec![3, 4], false, false), + (vec![1, 2, 3, 4], vec![5, 6, 7, 1], false, true), + ]; + for (s1, s2, _subsumes_expected, intersects_expected) in cases { + println!("s1: {:?}, s2: {:?}", s1, s2); + assert_eq!(intersects(s1.iter(), s2.iter()), intersects_expected); + } +} diff --git a/ascent_tests/benches/benches.rs b/ascent_tests/benches/benches.rs index c0c0675..cddd1ff 100644 --- a/ascent_tests/benches/benches.rs +++ b/ascent_tests/benches/benches.rs @@ -1,160 +1,156 @@ -#![allow(dead_code)] - -use std::collections::{BTreeMap, HashMap}; -use std::time::Instant; -use ascent_tests::ascent_m_par; -use stopwatch::Stopwatch; -use ascent::{ascent}; -use ascent::lattice::Dual; - -mod tc { - use ascent::ascent; - - ascent! { - relation edge(i32, i32); - relation path(i32, i32); - // edge(x, x + 1) <-- for x in (0..1000); - path(*x, *y) <-- edge(x,y); - path(*x, *z) <-- edge(x,y), path(y, z); - // path(*x, *z) <-- path(x,y), edge(y, z); - - } -} - -fn loop_graph(nodes: usize) -> Vec<(i32, i32)> { - let mut res = vec![]; - let nodes = nodes as i32; - for x in 0..nodes { - res.push((x, (x + 1) % nodes)); - } - res -} - -fn complete_graph(nodes: usize) -> Vec<(i32, i32, u32)> { - let mut res = vec![]; - let nodes = nodes as i32; - for x in 0..nodes { - for y in 0..nodes { - if x != y { - res.push((x, y, 1)); - } - } - } - res -} - -fn bench_tc(nodes_count: i32) { - let mut tc = tc::AscentProgram::default(); - - for i in 0..nodes_count { - tc.edge.push((i, i + 1)); - } - - let mut stopwatch = Stopwatch::start_new(); - tc.run(); - stopwatch.stop(); - - println!("tc for {} nodes took {:?}", nodes_count, stopwatch.elapsed()); - println!("path size: {}", tc.path.len()); -} - -fn test_dl_lattice1(){ - ascent!{ - lattice shortest_path(i32, i32, Dual); - relation edge(i32, i32, u32); - - shortest_path(*x, *y, Dual(*w)) <-- edge(x, y, w); - shortest_path(*x, *z, Dual(w + l.0)) <-- edge(x, y, w), shortest_path(y, z, l); - - edge(1, 2, x + 30) <-- for x in 0..10000; - edge(2, 3, x + 50) <-- for x in 0..10000; - edge(1, 3, x + 40) <-- for x in 0..10000; - edge(2, 4, x + 100) <-- for x in 0..10000; - edge(1, 4, x + 200) <-- for x in 0..10000; - } - let mut prog = AscentProgram::default(); - prog.run(); - // println!("shortest_path ({} tuples):", prog.shortest_path.len()); - //println!("{:?}", prog.shortest_path); - for _i in prog.shortest_path.iter() { - - } - // println!("{}", AscentProgram::summary()); - // assert!(rels_equal(prog.shortest_path, [(1,2, Dual(30)), (1, 3, Dual(40)), (1,4, Dual(130)), (2,3, Dual(50)), (2, 4, Dual(100))])) -} - -fn bench_lattice(){ - let iterations = 100; - let before = Instant::now(); - for _ in 0..iterations { - test_dl_lattice1(); - } - let elapsed = before.elapsed(); - println!("average time: {:?}", elapsed / iterations); -} - - -fn bench_tc_path_join_path(nodes_count: i32) { - ascent_m_par! { - // #![include_rule_times] - struct TCPathJoinPath; - relation edge(i32, i32); - relation path(i32, i32); - path(x, z) <-- path(x,y), path(y, z); - path(x, y) <-- edge(x,y); - } - let mut tc = TCPathJoinPath::default(); - println!("{}", TCPathJoinPath::summary()); - - for i in 0..nodes_count { - tc.edge.push((i, i + 1)); - } - - let mut stopwatch = Stopwatch::start_new(); - tc.run(); - stopwatch.stop(); - println!("tc path_join_path for {} nodes took {:?}", nodes_count, stopwatch.elapsed()); - // println!("summary: \n{}", tc.scc_times_summary()); - println!("path size: {}", tc.path.len()); -} - -fn bench_hash(){ - - let mut hm = HashMap::new(); - let mut bt = BTreeMap::new(); - - let iters = 10_000_000; - - let random_nums = rand::seq::index::sample(&mut rand::thread_rng(), iters, iters); - - let before = Instant::now(); - for i in random_nums.iter() { - hm.insert((i, i, i), i * 2); - } - println!("hm took {:?}", before.elapsed()); - - let before = Instant::now(); - for i in random_nums.iter() { - bt.insert((i, i, i), i * 2); - } - println!("btree took {:?}", before.elapsed()); -} -fn bench_tc_for_graph(graph: Vec<(i32, i32)>, name: &str) { - - let before = Instant::now(); - let mut tc = tc::AscentProgram::default(); - tc.edge = graph; - tc.run(); - let elapsed = before.elapsed(); - println!("tc for {} took {:?}", name, elapsed); - // println!("summary: \n{}", tc.scc_times_summary()); - println!("path size: {}", tc.path.len()); -} - -fn main() { - // bench_tc(1000); - bench_tc_path_join_path(1000); - // bench_tc_for_graph(loop_graph(4000), "loop 4000"); - //bench_lattice(); - // bench_hash(); -} +#![allow(dead_code)] + +use std::collections::{BTreeMap, HashMap}; +use std::time::Instant; + +use ascent::ascent; +use ascent::lattice::Dual; +use ascent_tests::ascent_m_par; +use stopwatch::Stopwatch; + +mod tc { + use ascent::ascent; + + ascent! { + relation edge(i32, i32); + relation path(i32, i32); + // edge(x, x + 1) <-- for x in (0..1000); + path(*x, *y) <-- edge(x,y); + path(*x, *z) <-- edge(x,y), path(y, z); + // path(*x, *z) <-- path(x,y), edge(y, z); + + } +} + +fn loop_graph(nodes: usize) -> Vec<(i32, i32)> { + let mut res = vec![]; + let nodes = nodes as i32; + for x in 0..nodes { + res.push((x, (x + 1) % nodes)); + } + res +} + +fn complete_graph(nodes: usize) -> Vec<(i32, i32, u32)> { + let mut res = vec![]; + let nodes = nodes as i32; + for x in 0..nodes { + for y in 0..nodes { + if x != y { + res.push((x, y, 1)); + } + } + } + res +} + +fn bench_tc(nodes_count: i32) { + let mut tc = tc::AscentProgram::default(); + + for i in 0..nodes_count { + tc.edge.push((i, i + 1)); + } + + let mut stopwatch = Stopwatch::start_new(); + tc.run(); + stopwatch.stop(); + + println!("tc for {} nodes took {:?}", nodes_count, stopwatch.elapsed()); + println!("path size: {}", tc.path.len()); +} + +fn test_dl_lattice1() { + ascent! { + lattice shortest_path(i32, i32, Dual); + relation edge(i32, i32, u32); + + shortest_path(*x, *y, Dual(*w)) <-- edge(x, y, w); + shortest_path(*x, *z, Dual(w + l.0)) <-- edge(x, y, w), shortest_path(y, z, l); + + edge(1, 2, x + 30) <-- for x in 0..10000; + edge(2, 3, x + 50) <-- for x in 0..10000; + edge(1, 3, x + 40) <-- for x in 0..10000; + edge(2, 4, x + 100) <-- for x in 0..10000; + edge(1, 4, x + 200) <-- for x in 0..10000; + } + let mut prog = AscentProgram::default(); + prog.run(); + // println!("shortest_path ({} tuples):", prog.shortest_path.len()); + //println!("{:?}", prog.shortest_path); + for _i in prog.shortest_path.iter() {} + // println!("{}", AscentProgram::summary()); + // assert!(rels_equal(prog.shortest_path, [(1,2, Dual(30)), (1, 3, Dual(40)), (1,4, Dual(130)), (2,3, Dual(50)), (2, 4, Dual(100))])) +} + +fn bench_lattice() { + let iterations = 100; + let before = Instant::now(); + for _ in 0..iterations { + test_dl_lattice1(); + } + let elapsed = before.elapsed(); + println!("average time: {:?}", elapsed / iterations); +} + +fn bench_tc_path_join_path(nodes_count: i32) { + ascent_m_par! { + // #![include_rule_times] + struct TCPathJoinPath; + relation edge(i32, i32); + relation path(i32, i32); + path(x, z) <-- path(x,y), path(y, z); + path(x, y) <-- edge(x,y); + } + let mut tc = TCPathJoinPath::default(); + println!("{}", TCPathJoinPath::summary()); + + for i in 0..nodes_count { + tc.edge.push((i, i + 1)); + } + + let mut stopwatch = Stopwatch::start_new(); + tc.run(); + stopwatch.stop(); + println!("tc path_join_path for {} nodes took {:?}", nodes_count, stopwatch.elapsed()); + // println!("summary: \n{}", tc.scc_times_summary()); + println!("path size: {}", tc.path.len()); +} + +fn bench_hash() { + let mut hm = HashMap::new(); + let mut bt = BTreeMap::new(); + + let iters = 10_000_000; + + let random_nums = rand::seq::index::sample(&mut rand::thread_rng(), iters, iters); + + let before = Instant::now(); + for i in random_nums.iter() { + hm.insert((i, i, i), i * 2); + } + println!("hm took {:?}", before.elapsed()); + + let before = Instant::now(); + for i in random_nums.iter() { + bt.insert((i, i, i), i * 2); + } + println!("btree took {:?}", before.elapsed()); +} +fn bench_tc_for_graph(graph: Vec<(i32, i32)>, name: &str) { + let before = Instant::now(); + let mut tc = tc::AscentProgram::default(); + tc.edge = graph; + tc.run(); + let elapsed = before.elapsed(); + println!("tc for {} took {:?}", name, elapsed); + // println!("summary: \n{}", tc.scc_times_summary()); + println!("path size: {}", tc.path.len()); +} + +fn main() { + // bench_tc(1000); + bench_tc_path_join_path(1000); + // bench_tc_for_graph(loop_graph(4000), "loop 4000"); + //bench_lattice(); + // bench_hash(); +} diff --git a/ascent_tests/src/agg_tests.rs b/ascent_tests/src/agg_tests.rs index c09d4d3..1f45bd1 100644 --- a/ascent_tests/src/agg_tests.rs +++ b/ascent_tests/src/agg_tests.rs @@ -1,189 +1,185 @@ -#![cfg(test)] -use ascent::ascent; -use ascent::ascent_run; -use itertools::Itertools; - -use crate::ascent_run_m_par; -use crate::assert_rels_eq; -use crate::utils::rels_equal; - -fn percentile<'a, TInputIter>(p: f32) -> impl Fn(TInputIter) -> std::option::IntoIter -where - TInputIter: Iterator, -{ - move |inp| { - let sorted = inp.map(|tuple| *tuple.0).sorted().collect_vec(); - let p_index = (sorted.len() as f32 * p / 100.0) as usize; - let p_index = p_index.clamp(0, sorted.len() - 1); - sorted.get(p_index).cloned().into_iter() - } -} - -#[test] -fn test_ascent_agg3(){ - let res = ascent_run_m_par!{ - relation foo(i32, i32); - relation bar(i32, i32, i32); - relation baz(i32, i32); - foo(1, 2); - foo(10, 11); - - bar(1, x, y), - bar(10, x * 10, y * 10), - bar(100, x * 100, y * 100) <-- for (x, y) in (1..100).map(|x| (x, x * 2)); - - baz(a, x_75th_p) <-- - foo(a, _), - agg x_75th_p = (percentile(75.0))(x) in bar(a, x, _); - }; - // println!("{}", res.summary()); - println!("baz: {:?}", res.baz); - assert!(rels_equal([(1, 75), (10, 750)], res.baz)); -} - -#[test] -fn test_ascent_agg4(){ - use ascent::aggregators::*; - let res = ascent_run_m_par!{ - relation foo(i32, i32); - relation bar(i32, i32, i32); - relation baz(i32, i32, i32); - foo(1, 2); - foo(10, 11); - - bar(1, x, y), - bar(10, x * 10, y * 10), - bar(100, x * 100, y * 100) <-- for (x, y) in (1..100).map(|x| (x, x * 2)); - - baz(a, x_mean as i32, y_mean as i32) <-- - foo(a, _), - agg x_mean = mean(x) in bar(a, x, _), - agg y_mean = mean(y) in bar(a, _, y); - - }; - // println!("{}", res.summary()); - println!("baz: {:?}", res.baz); - assert!(rels_equal([(1, 50, 100), (10, 500, 1000)], res.baz)); -} - -#[test] -fn test_ascent_negation(){ - use ascent::aggregators::*; - let res = ascent_run_m_par!{ - relation foo(i32, i32); - relation bar(i32, i32, i32); - relation baz(i32, i32); - relation baz2(i32, i32); - - foo(0, 1); - foo(1, 2); - foo(10, 11); - foo(100, 101); - - bar(1, 2, 102); - bar(10, 11, 20); - bar(10, 11, 12); - - baz(x, y) <-- - foo(x, y), - !bar(x, y, _); - - // equivalent to: - baz2(x, y) <-- - foo(x, y), - agg () = not() in bar(x, y, _); - }; - // println!("{}", res.summary()); - println!("baz: {:?}", res.baz); - assert!(rels_equal([(0, 1), (100, 101)], res.baz)); - assert!(rels_equal([(0, 1), (100, 101)], res.baz2)); -} - -#[test] -fn test_ascent_negation2(){ - use ascent::aggregators::*; - let res = ascent_run_m_par!{ - relation foo(i32, i32); - relation bar(i32, i32); - relation baz(i32, i32); - relation baz2(i32, i32); - - foo(0, 1); - foo(1, 2); - foo(10, 11); - foo(100, 101); - - bar(1, 2); - bar(10, 11); - bar(10, 11); - - baz(x, y) <-- - foo(x, y), - !bar(x, y); - - // equivalent to: - baz2(x, y) <-- - foo(x, y), - agg () = not() in bar(x, y); - }; - // println!("{}", res.summary()); - println!("baz: {:?}", res.baz); - assert_rels_eq!([(0, 1), (100, 101)], res.baz); - assert_rels_eq!([(0, 1), (100, 101)], res.baz2); -} - -#[test] -fn test_ascent_negation3(){ - use ascent::aggregators::*; - let res = ascent_run_m_par!{ - relation foo(i32, i32); - relation bar(i32, i32, i32); - relation baz(i32, i32); - - foo(0, 1); - foo(1, 2); - foo(10, 11); - foo(100, 101); - - bar(1, 2, 3); - bar(10, 11, 13); - - baz(x, y) <-- - foo(x, y), - !bar(x, y, y + 1); - }; - // println!("{}", res.summary()); - println!("baz: {:?}", res.baz); - assert!(rels_equal([(0, 1), (10, 11), (100, 101)], res.baz)); -} - -#[test] -fn test_ascent_agg_simple(){ - use ascent::aggregators::*; - let res = ascent_run_m_par!{ - relation foo(i32); - foo(0); foo(10); - - relation bar(i32); - bar(m as i32) <-- agg m = mean(x) in foo(x); - }; - assert!(rels_equal([(5,)], res.bar)); -} - -// Must fail to compile: -// #[test] -// fn test_ascent_agg_not_stratifiable(){ -// use ascent::aggregators::*; -// let res = ascent_run!{ -// relation foo(i32, i32, i32); -// relation bar(i32, i32); -// relation baz(i32); - -// baz(x) <-- -// foo(x, _, _), -// !bar(_, x); - -// bar(x, x + 1) <-- baz(x); -// }; -// assert!(rels_equal([(5,)], res.bar)); -// } \ No newline at end of file +#![cfg(test)] +use ascent::{ascent, ascent_run}; +use itertools::Itertools; + +use crate::utils::rels_equal; +use crate::{ascent_run_m_par, assert_rels_eq}; + +fn percentile<'a, TInputIter>(p: f32) -> impl Fn(TInputIter) -> std::option::IntoIter +where TInputIter: Iterator { + move |inp| { + let sorted = inp.map(|tuple| *tuple.0).sorted().collect_vec(); + let p_index = (sorted.len() as f32 * p / 100.0) as usize; + let p_index = p_index.clamp(0, sorted.len() - 1); + sorted.get(p_index).cloned().into_iter() + } +} + +#[test] +fn test_ascent_agg3() { + let res = ascent_run_m_par! { + relation foo(i32, i32); + relation bar(i32, i32, i32); + relation baz(i32, i32); + foo(1, 2); + foo(10, 11); + + bar(1, x, y), + bar(10, x * 10, y * 10), + bar(100, x * 100, y * 100) <-- for (x, y) in (1..100).map(|x| (x, x * 2)); + + baz(a, x_75th_p) <-- + foo(a, _), + agg x_75th_p = (percentile(75.0))(x) in bar(a, x, _); + }; + // println!("{}", res.summary()); + println!("baz: {:?}", res.baz); + assert!(rels_equal([(1, 75), (10, 750)], res.baz)); +} + +#[test] +fn test_ascent_agg4() { + use ascent::aggregators::*; + let res = ascent_run_m_par! { + relation foo(i32, i32); + relation bar(i32, i32, i32); + relation baz(i32, i32, i32); + foo(1, 2); + foo(10, 11); + + bar(1, x, y), + bar(10, x * 10, y * 10), + bar(100, x * 100, y * 100) <-- for (x, y) in (1..100).map(|x| (x, x * 2)); + + baz(a, x_mean as i32, y_mean as i32) <-- + foo(a, _), + agg x_mean = mean(x) in bar(a, x, _), + agg y_mean = mean(y) in bar(a, _, y); + + }; + // println!("{}", res.summary()); + println!("baz: {:?}", res.baz); + assert!(rels_equal([(1, 50, 100), (10, 500, 1000)], res.baz)); +} + +#[test] +fn test_ascent_negation() { + use ascent::aggregators::*; + let res = ascent_run_m_par! { + relation foo(i32, i32); + relation bar(i32, i32, i32); + relation baz(i32, i32); + relation baz2(i32, i32); + + foo(0, 1); + foo(1, 2); + foo(10, 11); + foo(100, 101); + + bar(1, 2, 102); + bar(10, 11, 20); + bar(10, 11, 12); + + baz(x, y) <-- + foo(x, y), + !bar(x, y, _); + + // equivalent to: + baz2(x, y) <-- + foo(x, y), + agg () = not() in bar(x, y, _); + }; + // println!("{}", res.summary()); + println!("baz: {:?}", res.baz); + assert!(rels_equal([(0, 1), (100, 101)], res.baz)); + assert!(rels_equal([(0, 1), (100, 101)], res.baz2)); +} + +#[test] +fn test_ascent_negation2() { + use ascent::aggregators::*; + let res = ascent_run_m_par! { + relation foo(i32, i32); + relation bar(i32, i32); + relation baz(i32, i32); + relation baz2(i32, i32); + + foo(0, 1); + foo(1, 2); + foo(10, 11); + foo(100, 101); + + bar(1, 2); + bar(10, 11); + bar(10, 11); + + baz(x, y) <-- + foo(x, y), + !bar(x, y); + + // equivalent to: + baz2(x, y) <-- + foo(x, y), + agg () = not() in bar(x, y); + }; + // println!("{}", res.summary()); + println!("baz: {:?}", res.baz); + assert_rels_eq!([(0, 1), (100, 101)], res.baz); + assert_rels_eq!([(0, 1), (100, 101)], res.baz2); +} + +#[test] +fn test_ascent_negation3() { + use ascent::aggregators::*; + let res = ascent_run_m_par! { + relation foo(i32, i32); + relation bar(i32, i32, i32); + relation baz(i32, i32); + + foo(0, 1); + foo(1, 2); + foo(10, 11); + foo(100, 101); + + bar(1, 2, 3); + bar(10, 11, 13); + + baz(x, y) <-- + foo(x, y), + !bar(x, y, y + 1); + }; + // println!("{}", res.summary()); + println!("baz: {:?}", res.baz); + assert!(rels_equal([(0, 1), (10, 11), (100, 101)], res.baz)); +} + +#[test] +fn test_ascent_agg_simple() { + use ascent::aggregators::*; + let res = ascent_run_m_par! { + relation foo(i32); + foo(0); foo(10); + + relation bar(i32); + bar(m as i32) <-- agg m = mean(x) in foo(x); + }; + assert!(rels_equal([(5,)], res.bar)); +} + +// Must fail to compile: +// #[test] +// fn test_ascent_agg_not_stratifiable(){ +// use ascent::aggregators::*; +// let res = ascent_run!{ +// relation foo(i32, i32, i32); +// relation bar(i32, i32); +// relation baz(i32); + +// baz(x) <-- +// foo(x, _, _), +// !bar(_, x); + +// bar(x, x + 1) <-- baz(x); +// }; +// assert!(rels_equal([(5,)], res.bar)); +// } diff --git a/ascent_tests/src/analysis_exp.rs b/ascent_tests/src/analysis_exp.rs index 6c6251b..4380941 100644 --- a/ascent_tests/src/analysis_exp.rs +++ b/ascent_tests/src/analysis_exp.rs @@ -1,267 +1,263 @@ -#![allow(dead_code)] -///! k-cfa on lambda calculus + numbers - -use std::collections::BTreeMap; -use std::ops::Deref; -use std::rc::Rc; - -use arrayvec::ArrayVec; -use ascent::ascent; -use ascent::ascent_run; - -use Expr::*; -use ascent::lattice::constant_propagation::ConstPropagation; -use crate::utils::*; -use itertools::Itertools; -type Var = &'static str; -type NumConcrete = isize; -type Num = ConstPropagation; - -#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] -pub enum Op { - Add, Mul, Sub, Div -} - -#[derive(Clone, PartialEq, Eq, Debug, Hash)] -pub enum Expr{ - Ref(Var), - Lam(Var, Rc), - App(Rc, Rc), - Lit(Num), - Binop(Op, Rc, Rc) -} -fn app(f: Expr, a: Expr) -> Expr { - App(Rc::new(f), Rc::new(a)) -} -fn lam(x: Var, e: Expr) -> Expr { - Lam(x, Rc::new(e)) -} -fn binop(op: Op, e1: Expr, e2: Expr) -> Expr { - Binop(op, Rc::new(e1), Rc::new(e2)) -} - -fn sub(exp: &Expr, var: &str, e: &Expr) -> Expr { - match exp { - Ref(x) if *x == var => e.clone(), - Ref(_x) => exp.clone(), - App(ef,ea) => app(sub(ef, var, e), sub(ea, var, e)), - Lam(x, _eb) if *x == var => exp.clone(), - Lam(x, eb) => lam(x, sub(eb, var, e)), - Lit(_) => exp.clone(), - Binop(op, e1, e2) => Binop(*op, Rc::new(sub(e1, var, e)), Rc::new(sub(e2, var, e))), - } -} - -#[allow(non_snake_case)] -fn U() -> Expr {lam("ux", app(Ref("ux"), Ref("ux")))} -#[allow(non_snake_case)] -fn I() -> Expr {lam("ix", Ref("ix"))} - -const K: usize = 1; -type Contour = ArrayVec; -type Lab = Expr; -type Time = (Option, Contour); - -#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] -enum Either {Left(L), Right(R)} -type Addr = (Either, Contour); -#[derive(Clone, Hash, PartialEq, Eq, Debug)] -enum Storable { - Value(Expr, Env), - Kont(Continuation) -} -use Storable::*; -#[derive(Clone, Hash, PartialEq, Eq, Debug)] -enum Continuation { - Fn(Expr, Env, Addr), - Ar(Expr, Env, Addr), - BinopAr2(Op, Expr, Env, Addr), - BinopAr1(Op, Num, Addr), - Mt -} -use Continuation::*; -type Env = Rc>; - -fn array_vec_cons(x: T, array_vec: &ArrayVec) -> ArrayVec { - if N == 0 {return ArrayVec::new()} - let mut res = ArrayVec::new(); - res.insert(0, x); - let to_take = array_vec.len(); - let to_take = if to_take == N {N - 1} else {to_take}; - res.extend(array_vec[0..to_take].iter().cloned()); - res -} -fn tick(e: &Expr, _ρ: &Env, _a: &Addr, t: &Time, k: &Continuation) -> Time { - let (lt, δ) = t; - match e { - Ref(_) => t.clone(), - App(..) => (Some(e.clone()), t.1.clone()), - Lam(..) => match k { - BinopAr2(_, _, _, _) | - BinopAr1(_, _, _) | // TODO this line is a judgment call - Ar(_, _, _) => t.clone(), - Fn(_, _, _) => (None, match lt {Some(lt) => array_vec_cons(lt.clone(), δ), None => δ.clone()}), - Mt => panic!("tick on Mt continuation") - }, - Lit(_) => t.clone(), - Binop(_, _, _) => t.clone(), // TODO this line is a judgment call - } -} -fn alloc(e: &Expr, _ρ: &Env, _a: &Addr, t: &Time, k: &Continuation) -> Addr { - let (_lt, δ) = t; - match e { - Binop(..) | - App(..) => (Either::Left(e.clone()), δ.clone()), - Lit(_) | - Lam(_, _) => match k { - BinopAr2(_, ek, _, _) | // TODO judgment call - Ar(ek, _, _) => (Either::Left(ek.clone()), δ.clone()), - Fn(Lam(x, _e), _, _) => (Either::Right(x), δ.clone()), - Fn(..) => panic!("alloc() call with Fn continuation with non-lambda expression"), - BinopAr1(_, _, _) => (Either::Right("IT"), δ.clone()), - Mt => panic!("alloc() called with Mt continuation"), - }, - Ref(_) => panic!("alloc with Ref(_) as expression"), - } -} -fn upd(ρ: &Env, var: Var, addr: Addr) -> Env{ - let mut ρ = ρ.deref().clone(); - ρ.insert(var, addr); - Rc::new(ρ) -} - -fn atom(e: &Expr) -> bool { - match e { - Lit(_) => true, - Lam(_, _) => true, - _ => false - } -} - -fn apply_op_concrete(op: Op, x: NumConcrete, y: NumConcrete) -> NumConcrete { - match op { - Op::Add => x + y, - Op::Mul => x * y, - Op::Sub => x - y, - Op::Div => x / y, - } -} - -fn apply_op(op: Op, x: &Num, y: &Num) -> Num { - use ascent::lattice::constant_propagation::ConstPropagation::*; - match (x, y) { - (Bottom, _) => Bottom, - (_, Bottom) => Bottom, - (Constant(x), Constant(y)) => Constant(apply_op_concrete(op, *x, *y)), - (Constant(0), Top) if op == Op::Mul => Constant(0), - (Top, Constant(0)) if op == Op::Mul => Constant(0), - (_, Top) => Top, - (Top, _) => Top, - } -} - -ascent!{ - struct CESK; - relation σ(Addr, Storable); - lattice σnum(Addr, Num); - relation ς(Expr, Env, Addr, Time); - - ς(v.clone(), ρ2, a, tick(e, ρ, a, t, k)) <-- - ς(?e@Ref(x), ρ, a, t), - (σ(ρ[x], ?Value(v, ρ2)) || - σnum(ρ[x], lit), let v = Lit(*lit), let ρ2 = ρ), - σ(a, ?Kont(k)); - - σ(b.clone(), Kont(Ar(e1.deref().clone(), ρ.clone(), a.clone()))), - ς(e0, ρ, b, tick(e, ρ, a, t, k)) <-- - ς(?e@App(e0, e1), ρ, a, t), - σ(a, ?Kont(k)), - let b = alloc(e, ρ, a, t, k); - - σ(b.clone(), Kont(BinopAr2(*op, e2.deref().clone(), ρ.clone(), a.clone()))), - ς(e1, ρ, b, tick(e, ρ, a, t, k)) <-- - ς(?e@Binop(op, e1, e2), ρ, a, t), - σ(a, ?Kont(k)), - let b = alloc(e, ρ, a, t, k); - - σ(b.clone(), Kont(Fn(v.clone(), ρ.clone(), c.clone()))), - ς(e, ρ2, b, tick(e, ρ, a, t, k)) <-- - ς(?v@Lam(..), ρ, a, t), - σ(a, ?Kont(k)), - if let Ar(e, ρ2, c) = k, - let b = alloc(v, ρ, a, t, k); - - σ(b.clone(), Kont(BinopAr1(*op, *l, c.clone()))), - ς(e, ρ2, b, tick(v, ρ, a, t, k)) <-- - ς(?v@Lit(l), ρ, a, t), - σ(a, ?Kont(k)), - if let BinopAr2(op, e, ρ2, c) = k, - let b = alloc(v, ρ, a, t, k); - - σnum(op_addr.clone(), apply_op(*op, l1, l2)), - ς(Ref("IT"), upd(ρ, "IT", op_addr), c, tick(v2, ρ, a, t, k)) <-- - ς(?v2@Lit(l2), ρ, a, t), - σ(a, ?Kont(k)), - if let BinopAr1(op, l1, c) = k, - let op_addr = alloc(v2, ρ, a, t, k); - - σ(b.clone(), Value(v.clone(), ρ.clone())), - ς(e, upd(ρ2, x, b), c, tick(v, ρ, a, t, k)) <-- - ς(?v@Lam(..), ρ, a, t), - σ(a, ?Kont(k)), - if let Fn(Lam(x, e), ρ2, c) = k, - let b = alloc(v, ρ, a, t, k); - - σnum(b.clone(), lit), - ς(e, upd(ρ2, x, b), c, tick(v, ρ, a, t, k)) <-- - ς(?v@Lit(lit), ρ, a, t), - σ(a, ?Kont(k)), - if let Fn(Lam(x, e), ρ2, c) = k, - let b = alloc(v, ρ, a, t, k); - - relation output(Expr, Env); - relation input(Expr); - σ((Either::Right("__a0"), Contour::default()), Storable::Kont(Continuation::Mt)), - ς(e, Env::default(), (Either::Right("__a0"), Contour::default()), (None, Contour::default())) <-- - input(e); - - output(e, ρ) <-- - ς(e, ρ, a, t) if atom(e), - σ(a, Kont(Continuation::Mt)); -} - -fn let_(x: &'static str, e0: Expr, e1: Expr) -> Expr { - app(lam(x, e1), e0) -} -#[allow(non_snake_case)] -fn Y() -> Expr{ - // (λ (f) (let ([u′ (λ (x) (f (λ (v) (let ([xx (x x)]) (xx v)))))]) - // (u′ u′))) - - // u' = (λ (x) (f (λ (v) (let ([xx (x x)]) (xx v))))) - let uprime = lam("x", app(Ref("f"), lam("v", let_("xx", app(Ref("x"), Ref("x")), app(Ref("xx"), Ref("v")))))); - lam("f", let_("u'", uprime, app(Ref("u'"), Ref("u'")))) -} - -// #[test] -pub fn analysis_exp(){ - use ascent::lattice::constant_propagation::ConstPropagation::*; - // println!("CESK summary:\n{}", CESK::summary()); - - // term = (λx. 42 + x) 58 - // let term = app(lam("x", binop(Op::Add, Lit(Constant(42)), Ref("x"))), - // Lit(Constant(58))); - - //equal to f, where f = λ x. f(x + 1) - let f = app(Y(), lam("self", lam("x", app(Ref("self"), binop(Op::Add, Ref("x"), Lit(Constant(1))))))); - let term = app(f, Lit(Constant(0))); - let mut cesk = CESK::default(); - cesk.input = vec![(term,)]; - cesk.run(); - // println!("ς: \n{}", cesk.ς.iter().map(|x| format!("{:?}", x)).join("\n")); - // println!("σ: \n{}", cesk.σ.iter().map(|x| format!("{:?}", x)).join("\n")); - println!("σnum: \n{}", cesk.σnum.iter().map(|x| format!("{:?}", x)).join("\n")); - - println!("summary: \n{}", cesk.relation_sizes_summary()); - println!("output: \n{}", cesk.output.iter().map(|x| format!("{:?}", x.0)).join("\n\n")); -} +#![allow(dead_code)] +///! k-cfa on lambda calculus + numbers +use std::collections::BTreeMap; +use std::ops::Deref; +use std::rc::Rc; + +use Expr::*; +use arrayvec::ArrayVec; +use ascent::lattice::constant_propagation::ConstPropagation; +use ascent::{ascent, ascent_run}; +use itertools::Itertools; + +use crate::utils::*; +type Var = &'static str; +type NumConcrete = isize; +type Num = ConstPropagation; + +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] +pub enum Op { + Add, + Mul, + Sub, + Div, +} + +#[derive(Clone, PartialEq, Eq, Debug, Hash)] +pub enum Expr { + Ref(Var), + Lam(Var, Rc), + App(Rc, Rc), + Lit(Num), + Binop(Op, Rc, Rc), +} +fn app(f: Expr, a: Expr) -> Expr { App(Rc::new(f), Rc::new(a)) } +fn lam(x: Var, e: Expr) -> Expr { Lam(x, Rc::new(e)) } +fn binop(op: Op, e1: Expr, e2: Expr) -> Expr { Binop(op, Rc::new(e1), Rc::new(e2)) } + +fn sub(exp: &Expr, var: &str, e: &Expr) -> Expr { + match exp { + Ref(x) if *x == var => e.clone(), + Ref(_x) => exp.clone(), + App(ef, ea) => app(sub(ef, var, e), sub(ea, var, e)), + Lam(x, _eb) if *x == var => exp.clone(), + Lam(x, eb) => lam(x, sub(eb, var, e)), + Lit(_) => exp.clone(), + Binop(op, e1, e2) => Binop(*op, Rc::new(sub(e1, var, e)), Rc::new(sub(e2, var, e))), + } +} + +#[allow(non_snake_case)] +fn U() -> Expr { lam("ux", app(Ref("ux"), Ref("ux"))) } +#[allow(non_snake_case)] +fn I() -> Expr { lam("ix", Ref("ix")) } + +const K: usize = 1; +type Contour = ArrayVec; +type Lab = Expr; +type Time = (Option, Contour); + +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +enum Either { + Left(L), + Right(R), +} +type Addr = (Either, Contour); +#[derive(Clone, Hash, PartialEq, Eq, Debug)] +enum Storable { + Value(Expr, Env), + Kont(Continuation), +} +use Storable::*; +#[derive(Clone, Hash, PartialEq, Eq, Debug)] +enum Continuation { + Fn(Expr, Env, Addr), + Ar(Expr, Env, Addr), + BinopAr2(Op, Expr, Env, Addr), + BinopAr1(Op, Num, Addr), + Mt, +} +use Continuation::*; +type Env = Rc>; + +fn array_vec_cons(x: T, array_vec: &ArrayVec) -> ArrayVec { + if N == 0 { + return ArrayVec::new() + } + let mut res = ArrayVec::new(); + res.insert(0, x); + let to_take = array_vec.len(); + let to_take = if to_take == N { N - 1 } else { to_take }; + res.extend(array_vec[0..to_take].iter().cloned()); + res +} +fn tick(e: &Expr, _ρ: &Env, _a: &Addr, t: &Time, k: &Continuation) -> Time { + let (lt, δ) = t; + match e { + Ref(_) => t.clone(), + App(..) => (Some(e.clone()), t.1.clone()), + Lam(..) => match k { + BinopAr2(_, _, _, _) | + BinopAr1(_, _, _) | // TODO this line is a judgment call + Ar(_, _, _) => t.clone(), + Fn(_, _, _) => (None, match lt {Some(lt) => array_vec_cons(lt.clone(), δ), None => δ.clone()}), + Mt => panic!("tick on Mt continuation") + }, + Lit(_) => t.clone(), + Binop(_, _, _) => t.clone(), // TODO this line is a judgment call + } +} +fn alloc(e: &Expr, _ρ: &Env, _a: &Addr, t: &Time, k: &Continuation) -> Addr { + let (_lt, δ) = t; + match e { + Binop(..) | App(..) => (Either::Left(e.clone()), δ.clone()), + Lit(_) | Lam(_, _) => match k { + BinopAr2(_, ek, _, _) | // TODO judgment call + Ar(ek, _, _) => (Either::Left(ek.clone()), δ.clone()), + Fn(Lam(x, _e), _, _) => (Either::Right(x), δ.clone()), + Fn(..) => panic!("alloc() call with Fn continuation with non-lambda expression"), + BinopAr1(_, _, _) => (Either::Right("IT"), δ.clone()), + Mt => panic!("alloc() called with Mt continuation"), + }, + Ref(_) => panic!("alloc with Ref(_) as expression"), + } +} +fn upd(ρ: &Env, var: Var, addr: Addr) -> Env { + let mut ρ = ρ.deref().clone(); + ρ.insert(var, addr); + Rc::new(ρ) +} + +fn atom(e: &Expr) -> bool { + match e { + Lit(_) => true, + Lam(_, _) => true, + _ => false, + } +} + +fn apply_op_concrete(op: Op, x: NumConcrete, y: NumConcrete) -> NumConcrete { + match op { + Op::Add => x + y, + Op::Mul => x * y, + Op::Sub => x - y, + Op::Div => x / y, + } +} + +fn apply_op(op: Op, x: &Num, y: &Num) -> Num { + use ascent::lattice::constant_propagation::ConstPropagation::*; + match (x, y) { + (Bottom, _) => Bottom, + (_, Bottom) => Bottom, + (Constant(x), Constant(y)) => Constant(apply_op_concrete(op, *x, *y)), + (Constant(0), Top) if op == Op::Mul => Constant(0), + (Top, Constant(0)) if op == Op::Mul => Constant(0), + (_, Top) => Top, + (Top, _) => Top, + } +} + +ascent! { + struct CESK; + relation σ(Addr, Storable); + lattice σnum(Addr, Num); + relation ς(Expr, Env, Addr, Time); + + ς(v.clone(), ρ2, a, tick(e, ρ, a, t, k)) <-- + ς(?e@Ref(x), ρ, a, t), + (σ(ρ[x], ?Value(v, ρ2)) || + σnum(ρ[x], lit), let v = Lit(*lit), let ρ2 = ρ), + σ(a, ?Kont(k)); + + σ(b.clone(), Kont(Ar(e1.deref().clone(), ρ.clone(), a.clone()))), + ς(e0, ρ, b, tick(e, ρ, a, t, k)) <-- + ς(?e@App(e0, e1), ρ, a, t), + σ(a, ?Kont(k)), + let b = alloc(e, ρ, a, t, k); + + σ(b.clone(), Kont(BinopAr2(*op, e2.deref().clone(), ρ.clone(), a.clone()))), + ς(e1, ρ, b, tick(e, ρ, a, t, k)) <-- + ς(?e@Binop(op, e1, e2), ρ, a, t), + σ(a, ?Kont(k)), + let b = alloc(e, ρ, a, t, k); + + σ(b.clone(), Kont(Fn(v.clone(), ρ.clone(), c.clone()))), + ς(e, ρ2, b, tick(e, ρ, a, t, k)) <-- + ς(?v@Lam(..), ρ, a, t), + σ(a, ?Kont(k)), + if let Ar(e, ρ2, c) = k, + let b = alloc(v, ρ, a, t, k); + + σ(b.clone(), Kont(BinopAr1(*op, *l, c.clone()))), + ς(e, ρ2, b, tick(v, ρ, a, t, k)) <-- + ς(?v@Lit(l), ρ, a, t), + σ(a, ?Kont(k)), + if let BinopAr2(op, e, ρ2, c) = k, + let b = alloc(v, ρ, a, t, k); + + σnum(op_addr.clone(), apply_op(*op, l1, l2)), + ς(Ref("IT"), upd(ρ, "IT", op_addr), c, tick(v2, ρ, a, t, k)) <-- + ς(?v2@Lit(l2), ρ, a, t), + σ(a, ?Kont(k)), + if let BinopAr1(op, l1, c) = k, + let op_addr = alloc(v2, ρ, a, t, k); + + σ(b.clone(), Value(v.clone(), ρ.clone())), + ς(e, upd(ρ2, x, b), c, tick(v, ρ, a, t, k)) <-- + ς(?v@Lam(..), ρ, a, t), + σ(a, ?Kont(k)), + if let Fn(Lam(x, e), ρ2, c) = k, + let b = alloc(v, ρ, a, t, k); + + σnum(b.clone(), lit), + ς(e, upd(ρ2, x, b), c, tick(v, ρ, a, t, k)) <-- + ς(?v@Lit(lit), ρ, a, t), + σ(a, ?Kont(k)), + if let Fn(Lam(x, e), ρ2, c) = k, + let b = alloc(v, ρ, a, t, k); + + relation output(Expr, Env); + relation input(Expr); + σ((Either::Right("__a0"), Contour::default()), Storable::Kont(Continuation::Mt)), + ς(e, Env::default(), (Either::Right("__a0"), Contour::default()), (None, Contour::default())) <-- + input(e); + + output(e, ρ) <-- + ς(e, ρ, a, t) if atom(e), + σ(a, Kont(Continuation::Mt)); +} + +fn let_(x: &'static str, e0: Expr, e1: Expr) -> Expr { app(lam(x, e1), e0) } +#[allow(non_snake_case)] +fn Y() -> Expr { + // (λ (f) (let ([u′ (λ (x) (f (λ (v) (let ([xx (x x)]) (xx v)))))]) + // (u′ u′))) + + // u' = (λ (x) (f (λ (v) (let ([xx (x x)]) (xx v))))) + let uprime = lam("x", app(Ref("f"), lam("v", let_("xx", app(Ref("x"), Ref("x")), app(Ref("xx"), Ref("v")))))); + lam("f", let_("u'", uprime, app(Ref("u'"), Ref("u'")))) +} + +// #[test] +pub fn analysis_exp() { + use ascent::lattice::constant_propagation::ConstPropagation::*; + // println!("CESK summary:\n{}", CESK::summary()); + + // term = (λx. 42 + x) 58 + // let term = app(lam("x", binop(Op::Add, Lit(Constant(42)), Ref("x"))), + // Lit(Constant(58))); + + //equal to f, where f = λ x. f(x + 1) + let f = app(Y(), lam("self", lam("x", app(Ref("self"), binop(Op::Add, Ref("x"), Lit(Constant(1))))))); + let term = app(f, Lit(Constant(0))); + let mut cesk = CESK::default(); + cesk.input = vec![(term,)]; + cesk.run(); + // println!("ς: \n{}", cesk.ς.iter().map(|x| format!("{:?}", x)).join("\n")); + // println!("σ: \n{}", cesk.σ.iter().map(|x| format!("{:?}", x)).join("\n")); + println!("σnum: \n{}", cesk.σnum.iter().map(|x| format!("{:?}", x)).join("\n")); + + println!("summary: \n{}", cesk.relation_sizes_summary()); + println!("output: \n{}", cesk.output.iter().map(|x| format!("{:?}", x.0)).join("\n\n")); +} diff --git a/ascent_tests/src/ascent_maybe_par.rs b/ascent_tests/src/ascent_maybe_par.rs index da3a085..8e76e4a 100644 --- a/ascent_tests/src/ascent_maybe_par.rs +++ b/ascent_tests/src/ascent_maybe_par.rs @@ -1,46 +1,43 @@ -use std::sync::RwLock; - - -#[cfg(not(feature = "par"))] -#[macro_export] -macro_rules! ascent_m_par { - ($($tt: tt)*) => { - ascent::ascent!{ $($tt)* } - }; -} - -#[cfg(feature = "par")] -#[macro_export] -macro_rules! ascent_m_par { - ($($tt: tt)*) => { - ascent::ascent_par!{ $($tt)* } - }; -} - -#[cfg(not(feature = "par"))] -#[macro_export] -macro_rules! ascent_run_m_par { - ($($tt: tt)*) => { - ascent::ascent_run!{ $($tt)* } - }; -} - -#[cfg(feature = "par")] -#[macro_export] -macro_rules! ascent_run_m_par { - ($($tt: tt)*) => { - ascent::ascent_run_par!{ $($tt)* } - }; -} - -#[cfg(not(feature = "par"))] -#[allow(dead_code)] -pub fn lat_to_vec(vec: Vec) -> Vec { - vec -} - -#[cfg(feature = "par")] -#[allow(dead_code)] -pub fn lat_to_vec(vec: ascent::boxcar::Vec>) -> Vec { - vec.into_iter().map(|x| x.into_inner().unwrap()).collect() -} \ No newline at end of file +use std::sync::RwLock; + +#[cfg(not(feature = "par"))] +#[macro_export] +macro_rules! ascent_m_par { + ($($tt: tt)*) => { + ascent::ascent!{ $($tt)* } + }; +} + +#[cfg(feature = "par")] +#[macro_export] +macro_rules! ascent_m_par { + ($($tt: tt)*) => { + ascent::ascent_par!{ $($tt)* } + }; +} + +#[cfg(not(feature = "par"))] +#[macro_export] +macro_rules! ascent_run_m_par { + ($($tt: tt)*) => { + ascent::ascent_run!{ $($tt)* } + }; +} + +#[cfg(feature = "par")] +#[macro_export] +macro_rules! ascent_run_m_par { + ($($tt: tt)*) => { + ascent::ascent_run_par!{ $($tt)* } + }; +} + +#[cfg(not(feature = "par"))] +#[allow(dead_code)] +pub fn lat_to_vec(vec: Vec) -> Vec { vec } + +#[cfg(feature = "par")] +#[allow(dead_code)] +pub fn lat_to_vec(vec: ascent::boxcar::Vec>) -> Vec { + vec.into_iter().map(|x| x.into_inner().unwrap()).collect() +} diff --git a/ascent_tests/src/bin/tc.rs b/ascent_tests/src/bin/tc.rs index 85bdf7b..f8e2876 100644 --- a/ascent_tests/src/bin/tc.rs +++ b/ascent_tests/src/bin/tc.rs @@ -1,24 +1,25 @@ -use ascent_tests::ascent_m_par; -use std::time::Instant; - -ascent_m_par!{ - struct TC; - relation edge(i32, i32); - relation path(i32, i32); - - path(*x, *y) <-- edge(x,y); - path(*x, *z) <-- path(x, y), edge(y, z); -} - -fn main() { - let edges = (0..5000).map(|x| (x, x + 1)).collect(); - let mut prog = TC::default(); - - prog.edge = edges; - - let before = Instant::now(); - prog.run(); - let took = before.elapsed(); - println!("path len: {}", prog.path.len()); - println!("took {took:?}"); -} \ No newline at end of file +use std::time::Instant; + +use ascent_tests::ascent_m_par; + +ascent_m_par! { + struct TC; + relation edge(i32, i32); + relation path(i32, i32); + + path(*x, *y) <-- edge(x,y); + path(*x, *z) <-- path(x, y), edge(y, z); +} + +fn main() { + let edges = (0..5000).map(|x| (x, x + 1)).collect(); + let mut prog = TC::default(); + + prog.edge = edges; + + let before = Instant::now(); + prog.run(); + let took = before.elapsed(); + println!("path len: {}", prog.path.len()); + println!("took {took:?}"); +} diff --git a/ascent_tests/src/example_tests.rs b/ascent_tests/src/example_tests.rs index 13fc352..f3acc94 100644 --- a/ascent_tests/src/example_tests.rs +++ b/ascent_tests/src/example_tests.rs @@ -1,19 +1,18 @@ -use ascent::ascent_run; -use ascent::ascent; +use std::hash::Hash; use std::rc::Rc; + use ascent::aggregators::mean; -use crate::ascent_m_par; -use crate::ascent_run_m_par; -use crate::assert_rels_eq; +use ascent::{ascent, ascent_run}; + use crate::utils::rels_equal; -use std::hash::Hash; +use crate::{ascent_m_par, ascent_run_m_par, assert_rels_eq}; #[test] fn test_generators_conditions_example() { let res = ascent_run! { relation node(i32, Rc>); relation edge(i32, i32); - + node(1, Rc::new(vec![2, 3])); node(2, Rc::new(vec![3, 4])); @@ -40,7 +39,7 @@ fn test_agg_example() { agg avg = mean(g) in course_grade(s, _, g); } let mut prog = AscentProgram::default(); - prog.student = FromIterator::from_iter([(1, ), (2, )]); + prog.student = FromIterator::from_iter([(1,), (2,)]); prog.course_grade = FromIterator::from_iter([(1, 600, 60), (1, 602, 80), (2, 602, 70), (2, 605, 90)]); prog.run(); println!("avg grade: {:?}", prog.avg_grade); @@ -50,39 +49,61 @@ fn test_agg_example() { #[test] fn test_tc_example() { fn tc(r: Vec<(i32, i32)>, reflexive: bool) -> Vec<(i32, i32)> { - ascent_run!{ + ascent_run! { relation r(i32, i32) = r; relation tc(i32, i32); tc(x, y) <-- r(x, y); tc(x, z) <-- r(x, y), tc(y, z); tc(x, x), tc(y, y) <-- if reflexive, r(x, y); - }.tc + } + .tc } let r = vec![(1, 2), (2, 4), (3, 1)]; println!("tc: {:?}", tc(r.clone(), true)); println!("reflexive tc: {:?}", tc(r.clone(), true)); - assert_rels_eq!(tc(r.clone(), true), - vec![(1,1), (2,2), (3,3), (4,4), (1, 2), (1, 4), (2, 4), (3, 1), (3, 2), (3, 4)]); + assert_rels_eq!(tc(r.clone(), true), vec![ + (1, 1), + (2, 2), + (3, 3), + (4, 4), + (1, 2), + (1, 4), + (2, 4), + (3, 1), + (3, 2), + (3, 4) + ]); } - #[test] fn test_generic_tc_example() { - fn tc(r: Vec<(N, N)>, reflexive: bool) -> Vec<(N, N)> where N: Clone + Hash + Eq{ - ascent_run!{ + fn tc(r: Vec<(N, N)>, reflexive: bool) -> Vec<(N, N)> + where N: Clone + Hash + Eq { + ascent_run! { struct TC; relation r(N, N) = r; relation tc(N, N); tc(x, y) <-- r(x, y); tc(x, z) <-- r(x, y), tc(y, z); tc(x, x), tc(y, y) <-- if reflexive, r(x, y); - }.tc + } + .tc } let r = vec![(1, 2), (2, 4), (3, 1)]; println!("tc: {:?}", tc(r.clone(), true)); println!("reflexive tc: {:?}", tc(r.clone(), true)); - assert_rels_eq!(tc(r.clone(), true), - vec![(1,1), (2,2), (3,3), (4,4), (1, 2), (1, 4), (2, 4), (3, 1), (3, 2), (3, 4)]); + assert_rels_eq!(tc(r.clone(), true), vec![ + (1, 1), + (2, 2), + (3, 3), + (4, 4), + (1, 2), + (1, 4), + (2, 4), + (3, 1), + (3, 2), + (3, 4) + ]); } #[test] @@ -92,15 +113,13 @@ fn test_generic_ty() { relation dummy(T); } - struct Container(AscentProgram) where T: Clone + Hash + Eq; + struct Container(AscentProgram) + where T: Clone + Hash + Eq; impl Container - where - T: Clone + Hash + Eq + where T: Clone + Hash + Eq { - fn run(&mut self) { - self.0.run(); - } + fn run(&mut self) { self.0.run(); } } let mut container: Container = Container(AscentProgram::default()); @@ -118,12 +137,9 @@ fn test_generic_ty_with_divergent_impl_generics() { struct Container(AscentProgram); impl Container - where - T: Clone + Hash + Eq + where T: Clone + Hash + Eq { - fn run(&mut self) { - self.0.run(); - } + fn run(&mut self) { self.0.run(); } } let mut container: Container = Container(AscentProgram::default()); @@ -159,7 +175,6 @@ fn test_borrowed_strings() { #[test] fn test_borrowed_strings_2() { - fn ancestry_fn<'a>(parent_rel: impl Iterator) -> Vec<(&'a str, &'a str)> { ascent_run_m_par! { struct Ancestry<'a>; @@ -170,7 +185,10 @@ fn test_borrowed_strings_2() { ancestor(p, gc) <-- parent(p, c), ancestor(c, gc); - }.ancestor.into_iter().collect() + } + .ancestor + .into_iter() + .collect() } let james = "James".to_string(); diff --git a/ascent_tests/src/exps.rs b/ascent_tests/src/exps.rs index 932a4f7..6d44b35 100644 --- a/ascent_tests/src/exps.rs +++ b/ascent_tests/src/exps.rs @@ -1,65 +1,57 @@ -#![allow(warnings)] -use ascent::ascent; -fn exp_type_eq(){ - -} - -fn types_eq() -> bool{ - todo!() -} - -macro_rules! run_combined{ - ($($b:tt)*) => { - - }; -} - -// trait TypesEq { -// const ARE_EQUAL : bool; -// } - -// impl TypesEq for T { -// const ARE_EQUAL : bool = true; -// } - -// impl TypesEq for T { -// const ARE_EQUAL : bool = false; -// } - -fn exp_type_param(inp: T) { - struct Inner { - one: T, - } - let instance = Inner {one: inp}; - -} - -fn exp_rel_traits(){ - trait HasRel { - type Fields; - } - - ascent!{ - struct AscentProg1; - relation foo(i32, i32); - } - - ascent!{ - struct AscentProg2; - relation foo(i32, i32); - } - - impl HasRel<{mangle("foo")}> for AscentProg1 { - type Fields = (i32, i32); - } - - impl HasRel<{mangle("foo")}> for AscentProg2 { - type Fields = (i32, i32); - } - - run_combined!(AscentProg1::default(), AscentProg2::default() on foo); - -} -const fn mangle(input: &str) -> u64{ - const_fnv1a_hash::fnv1a_hash_str_64(input) -} +#![allow(warnings)] +use ascent::ascent; +fn exp_type_eq() {} + +fn types_eq() -> bool { todo!() } + +macro_rules! run_combined { + ($($b:tt)*) => { + + }; +} + +// trait TypesEq { +// const ARE_EQUAL : bool; +// } + +// impl TypesEq for T { +// const ARE_EQUAL : bool = true; +// } + +// impl TypesEq for T { +// const ARE_EQUAL : bool = false; +// } + +fn exp_type_param(inp: T) { + struct Inner { + one: T, + } + let instance = Inner { one: inp }; +} + +fn exp_rel_traits() { + trait HasRel { + type Fields; + } + + ascent! { + struct AscentProg1; + relation foo(i32, i32); + } + + ascent! { + struct AscentProg2; + relation foo(i32, i32); + } + + impl HasRel<{ mangle("foo") }> for AscentProg1 { + type Fields = (i32, i32); + } + + impl HasRel<{ mangle("foo") }> for AscentProg2 { + type Fields = (i32, i32); + } + + run_combined!(AscentProg1::default(), AscentProg2::default() on foo); +} +const fn mangle(input: &str) -> u64 { const_fnv1a_hash::fnv1a_hash_str_64(input) } diff --git a/ascent_tests/src/macros_tests.rs b/ascent_tests/src/macros_tests.rs index 4203434..35be62a 100644 --- a/ascent_tests/src/macros_tests.rs +++ b/ascent_tests/src/macros_tests.rs @@ -6,7 +6,7 @@ use crate::assert_rels_eq; #[test] fn test_macro_in_macro() { - ascent!{ + ascent! { relation foo1(i32, i32); relation foo2(i32, i32); relation bar(i32 , i32); @@ -38,7 +38,6 @@ fn test_macro_in_macro() { #[test] fn test_macro_in_macro2() { - type Var = String; type Val = isize; #[derive(Clone, Eq, PartialEq, Hash)] @@ -70,16 +69,15 @@ fn test_macro_in_macro2() { prog.res = vec![(Atomic::Val(1000),), (Atomic::Var("x1".into()),)]; prog.run(); - println!("res_val: {}\n{:?}" , prog.res_val.len(), prog.res_val); + println!("res_val: {}\n{:?}", prog.res_val.len(), prog.res_val); println!("res_val2: {}\n{:?}", prog.res_val2.len(), prog.res_val2); assert_eq!(prog.res_val2.len(), prog.res_val.len().pow(2)); - assert_rels_eq!(prog.res_val, [(100, ), (1000, )]); + assert_rels_eq!(prog.res_val, [(100,), (1000,)]); } #[test] fn test_macro_in_macro3() { - ascent! { relation edge(i32, i32); relation edge_rev(i32, i32); @@ -101,7 +99,6 @@ fn test_macro_in_macro3() { #[test] fn test_macro_in_macro4() { - ascent! { relation foo(i32, i32); relation bar(i32, i32); @@ -124,7 +121,6 @@ fn test_macro_in_macro4() { #[test] fn test_macro_in_macro5() { - type Lang = &'static str; type CompilerName = &'static str; ascent! { @@ -145,9 +141,13 @@ fn test_macro_in_macro5() { } let mut prog = AscentProgram::default(); - prog.compiler = vec![("Rustc", "Rust", "X86"), ("Rustc", "Rust", "WASM"), - ("MyRandomCompiler", "Python", "Rust"), - ("Cython", "Python", "C"), ("Clang", "C", "X86")]; + prog.compiler = vec![ + ("Rustc", "Rust", "X86"), + ("Rustc", "Rust", "WASM"), + ("MyRandomCompiler", "Python", "Rust"), + ("Cython", "Python", "C"), + ("Clang", "C", "X86"), + ]; prog.bad_compiler = vec![("MyRandomCompiler",)]; prog.run(); @@ -160,7 +160,6 @@ fn test_macro_in_macro5() { #[test] fn test_macro_in_macro6() { - ascent! { relation foo(i32, i32) = vec![(0, 1), (1, 2), (2, 3), (3, 4)]; @@ -194,7 +193,6 @@ fn test_macro_in_macro6() { #[test] fn test_macro_in_macro7() { - ascent! { relation foo(i32, i32) = vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 6), (3, 7)]; relation bar(Option, i32) = vec![(Some(1), 2), (Some(2), 3), (None, 4)]; @@ -202,8 +200,8 @@ fn test_macro_in_macro7() { macro foo($x: expr, $y: expr) { foo($x, $y), } macro bar($x: ident, $y: expr) { bar(?Some($x), $y) } - macro foo2($x: expr, $y: expr) { - foo!($x, $y), let x = $x, for x2 in [1, 2], ((foo(x, x2), let y = $y, let _ = println!("{}", y)) || if true, for y in [$y, $y]), + macro foo2($x: expr, $y: expr) { + foo!($x, $y), let x = $x, for x2 in [1, 2], ((foo(x, x2), let y = $y, let _ = println!("{}", y)) || if true, for y in [$y, $y]), foo!(x + 0, y - 0), foo(x, y), foo!(x, y), let z = |x: i32| {x}, foo(z(*x), z(*y)) } @@ -215,7 +213,7 @@ fn test_macro_in_macro7() { baz!(x, z) <-- bar!(x, y), foo!(y, z); baz!(a, c) <-- bar!(a, b), foo!(b, c); baz_e(x, z) <-- bar(?Some(x), y), foo(y, z); - + relation quax(i32, i32); relation quax_e(i32, i32); @@ -228,15 +226,13 @@ fn test_macro_in_macro7() { println!("baz : {:?}", prog.baz); assert_rels_eq!(prog.baz, prog.baz_e); - + println!("quax: {:?}", prog.quax); assert_rels_eq!(prog.quax, prog.quax_e); - } #[test] fn test_macro_in_macro8() { - macro_rules! id { ($($inp: tt)*) => { $($inp)* }; } @@ -248,7 +244,7 @@ fn test_macro_in_macro8() { macro foo2($x: expr, $y: expr) { foo($x, $y), } macro bar($x: ident, $y: expr) { bar(?Some($x), $y) } - macro foo($xx: expr, $yy: expr) { + macro foo($xx: expr, $yy: expr) { foo2!($xx, $yy), let x = id!($xx), let y = id!($yy), for x2 in [1, 2], let _ = assert!(x == $xx && y == $yy), foo2!(id!(id!(x) + 0), id!(y - 0)), foo(x, y), foo2!(x, y), @@ -258,7 +254,7 @@ fn test_macro_in_macro8() { relation baz(i32, i32); relation baz_e(i32, i32); baz(x, z) <-- bar!(x, y), foo!(y, z); - + relation baz2(i32, i32); baz2(a, c) <-- bar!(a, b), foo!(b, c); diff --git a/ascent_tests/src/se.rs b/ascent_tests/src/se.rs index 91f2306..e5b65b7 100644 --- a/ascent_tests/src/se.rs +++ b/ascent_tests/src/se.rs @@ -1,61 +1,59 @@ -#![allow(warnings)] -use std::rc::Rc; - -use ascent::*; - -type SrcLine = u32; - -type Register = &'static str; - -#[derive(Clone, Copy, PartialEq, Eq, Hash)] -pub enum Val { - Ref(Register), - Lit(i32) -} - -#[derive(Clone, PartialEq, Eq, Hash)] -pub enum Trace { - Null, - Cons(SrcLine, Rc) -} - -#[derive(Clone, Copy, PartialEq, Eq, Hash)] -pub enum Instr { - Mov(Register, Val), - Cmp(Register, Val, Val), - Brz(Register, SrcLine), - Add(Register, Val, Val), - Br(SrcLine), -} -fn instr_vals(instr: &Instr) -> Vec<&Val>{ - todo!() -} - -use Instr::*; -ascent! { - relation trace(SrcLine, Trace); //scrline duplicates the head of trace for fast lookup - relation source(SrcLine, Instr); - relation store(Trace, Register, i32); - relation aeval(Trace, Val, i32); - - aeval(time.clone(), val, eval) <-- - source(pc, instr), - trace(pc, time), - for (&val, &eval) in instr_vals(instr).into_iter().filter_map(|v| match v {Val::Lit(l) => Some((v,l)), Val::Ref(r) => None}); - - aeval(time.clone(), val, eval) <-- - source(pc, instr), - trace(pc, time), - for (&val, ®) in instr_vals(instr).into_iter().filter_map(|v| match v {Val::Lit(x) => None, Val::Ref(r) => Some((v, r))}), - store(time, reg, eval); - - trace(pc + 1, Trace::Cons(pc + 1, Rc::new(time.clone()) )) <-- - source(pc, ?Mov(target, val)), - trace(pc, time), - aeval(time, val, eval); - - store(Trace::Cons(pc + 1, Rc::new(time.clone())), target, *target_eval) <-- - source(pc, ?Mov(target, val)), - trace(pc, time), - aeval(time, val, target_eval); -} +#![allow(warnings)] +use std::rc::Rc; + +use ascent::*; + +type SrcLine = u32; + +type Register = &'static str; + +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub enum Val { + Ref(Register), + Lit(i32), +} + +#[derive(Clone, PartialEq, Eq, Hash)] +pub enum Trace { + Null, + Cons(SrcLine, Rc), +} + +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub enum Instr { + Mov(Register, Val), + Cmp(Register, Val, Val), + Brz(Register, SrcLine), + Add(Register, Val, Val), + Br(SrcLine), +} +fn instr_vals(instr: &Instr) -> Vec<&Val> { todo!() } + +use Instr::*; +ascent! { + relation trace(SrcLine, Trace); //scrline duplicates the head of trace for fast lookup + relation source(SrcLine, Instr); + relation store(Trace, Register, i32); + relation aeval(Trace, Val, i32); + + aeval(time.clone(), val, eval) <-- + source(pc, instr), + trace(pc, time), + for (&val, &eval) in instr_vals(instr).into_iter().filter_map(|v| match v {Val::Lit(l) => Some((v,l)), Val::Ref(r) => None}); + + aeval(time.clone(), val, eval) <-- + source(pc, instr), + trace(pc, time), + for (&val, ®) in instr_vals(instr).into_iter().filter_map(|v| match v {Val::Lit(x) => None, Val::Ref(r) => Some((v, r))}), + store(time, reg, eval); + + trace(pc + 1, Trace::Cons(pc + 1, Rc::new(time.clone()) )) <-- + source(pc, ?Mov(target, val)), + trace(pc, time), + aeval(time, val, eval); + + store(Trace::Cons(pc + 1, Rc::new(time.clone())), target, *target_eval) <-- + source(pc, ?Mov(target, val)), + trace(pc, time), + aeval(time, val, target_eval); +} diff --git a/ascent_tests/src/tests.rs b/ascent_tests/src/tests.rs index daf8b53..5a31c79 100644 --- a/ascent_tests/src/tests.rs +++ b/ascent_tests/src/tests.rs @@ -1,949 +1,974 @@ -#![cfg(test)] -#![allow(irrefutable_let_patterns)] -use std::collections::{BTreeMap, BTreeSet, HashMap}; -use std::fmt::Debug; -use std::ops::Deref; -use std::primitive; -use std::sync::Arc; -use std::time::Duration; -use std::{cmp::max, rc::Rc}; -use ascent::Dual; -use std::hash::Hash; - -use ascent::ascent; -use ascent::ascent_run; - -use LambdaCalcExpr::*; -use crate::ascent_maybe_par::lat_to_vec; -use crate::{utils::*, assert_rels_eq, ascent_m_par, ascent_run_m_par}; -use itertools::Itertools; - -#[derive(Clone, PartialEq, Eq, Debug, Hash)] -pub enum LambdaCalcExpr{ - Ref(&'static str), - Lam(&'static str, Arc), - App(Arc, Arc) -} - -impl LambdaCalcExpr { - #[allow(dead_code)] - fn depth(&self) -> usize { - match self{ - LambdaCalcExpr::Ref(_) => 0, - LambdaCalcExpr::Lam(_x,b) => 1 + b.depth(), - LambdaCalcExpr::App(f,e) => 1 + max(f.depth(), e.depth()) - } - } -} -fn app(f: LambdaCalcExpr, a: LambdaCalcExpr) -> LambdaCalcExpr { - App(Arc::new(f), Arc::new(a)) -} -fn lam(x: &'static str, e: LambdaCalcExpr) -> LambdaCalcExpr { - Lam(x, Arc::new(e)) -} - -fn sub(exp: &LambdaCalcExpr, var: &str, e: &LambdaCalcExpr) -> LambdaCalcExpr { - match exp { - Ref(x) if *x == var => e.clone(), - Ref(_x) => exp.clone(), - App(ef,ea) => app(sub(ef, var, e), sub(ea, var, e)), - Lam(x, _eb) if *x == var => exp.clone(), - Lam(x, eb) => lam(x, sub(eb, var, e)) - } -} - -#[allow(non_snake_case)] -fn U() -> LambdaCalcExpr {lam("x", app(Ref("x"), Ref("x")))} -#[allow(non_snake_case)] -fn I() -> LambdaCalcExpr {lam("x", Ref("x"))} - -#[test] -fn test_dl_lambda(){ - ascent_m_par!{ - relation output(LambdaCalcExpr); - relation input(LambdaCalcExpr); - relation eval(LambdaCalcExpr, LambdaCalcExpr); - relation do_eval(LambdaCalcExpr); - - input(app(U(), I())); - do_eval(exp.clone()) <-- input(exp); - output(res.clone()) <-- input(exp), eval(exp, res); - - eval(exp.clone(), exp.clone()) <-- do_eval(?exp @Ref(_)); - - eval(exp.clone(), exp.clone()) <-- do_eval(?exp @Lam(_,_)); - - do_eval(ef.as_ref().clone()) <-- do_eval(?App(ef,_ea)); - - do_eval(sub(fb, fx, ea)) <-- - do_eval(?App(ef, ea)), - eval(ef.deref(), ?Lam(fx, fb)); - - eval(exp.clone(), final_res.clone()) <-- - do_eval(?exp @ App(ef, ea)), // this requires nightly - eval(ef.deref(), ?Lam(fx, fb)), - eval(sub(fb, fx, ea), final_res); - }; - - let mut prog = AscentProgram::default(); - // println!("{}", AscentProgram::summary()); - prog.run(); - // println!("input:{:?}\n", prog.input); - // println!("eval: {}\n", prog.eval.iter().map(|(e,v)| format!("{:?} ===> {:?}", e, v)).join("\n")); - println!("output: {:?}", prog.output); - assert!(prog.output.iter().contains(&(I(),))); - assert!(prog.output.len() == 1); -} - -#[allow(dead_code)] -fn _test_dl_lambda2(){ - type Time = u32; - type Addr = u32; - #[derive(Clone, Hash, PartialEq, Eq, Debug)] - enum Storable { - Value(LambdaCalcExpr, Env), - Kont(Continuation) - } - use Storable::*; - #[derive(Clone, Hash, PartialEq, Eq, Debug)] - enum Continuation { - Fn(LambdaCalcExpr, Env, Addr), - Ar(LambdaCalcExpr, Env, Addr) - } - use Continuation::*; - type Env = Rc>; - fn tick(_t: &Time, _k: &Continuation) -> Time { - todo!() - } - fn alloc(_e: &LambdaCalcExpr, _ρ: &Env, _a: &Addr, _t: &Time, _k: &Continuation) -> Addr { todo!()} - fn upd(ρ: &Env, var: &'static str, addr: Addr) -> Env{ - let mut ρ = ρ.deref().clone(); - ρ.insert(var, addr); - Rc::new(ρ) - } - // #[derive(Clone, Hash, PartialEq, Eq)] - // struct State(LambdaCalcExpr, Env, Addr, Time); - ascent!{ - struct CESK; - relation output(LambdaCalcExpr); - relation input(LambdaCalcExpr); - relation σ(Addr, Storable); - relation ς(LambdaCalcExpr, Env, Addr, Time); - input(app(U(), I())); - - - ς(v, ρ2, a, tick(t, k)) <-- - ς(?Ref(x), ρ, a, t), - σ(ρ[x], ?Value(v, ρ2)), - σ(a, ?Kont(k)); - - σ(b, Kont(Ar(e1.deref().clone(), ρ.clone(), *a))), - ς(e0.deref().clone(), ρ, b, tick(t, k)) <-- - ς(?e@App(e0, e1), ρ, addr, t), - σ(a, ?Kont(k)), - let b = alloc(e, ρ, addr, t, k); - - σ(b, Kont(Fn(v.clone(), ρ.clone(), *c))), - ς(e, ρ2, b, tick(t, k)) <-- - ς(?v@Lam(..), ρ, a, t), - σ(a, ?Kont(k)), - if let Ar(e, ρ2, c) = k, - let b = alloc(v, ρ, a, t, k); - - σ(b, Value(v.clone(), ρ.clone())), - ς(e, upd(&ρ2, x, b), b, tick(t, k)) <-- - ς(?v@Lam(..), ρ, a, t), - σ(a, ?Kont(k)), - if let Fn(Lam(x, e), ρ2, c) = k, - let b = alloc(v, ρ, a, t, k); - }; - use std::collections::HashSet; - type Store = Rc>>; - ascent!{ - struct CeskLocalStore; - relation output(LambdaCalcExpr); - relation input(LambdaCalcExpr); - relation ς(LambdaCalcExpr, Env, Store, Addr, Time); - // rule 1: - ς(v.clone(), ρ2.clone(), σ.clone(), *a, tick(t, k)) <-- - ς(?Ref(x), ρ, σ, a, t), - for sv in σ[&ρ[x]].iter(), if let Value(v, ρ2) = sv, - for sa in σ[a].iter(), if let Kont(k) = sa; - // ... - } - let mut prog = CESK::default(); - // println!("{}", AscentProgram::summary()); - prog.run(); - // println!("input:{:?}\n", prog.input); - // println!("eval: {}\n", prog.eval.iter().map(|(e,v)| format!("{:?} ===> {:?}", e, v)).join("\n")); -} - - -#[test] -fn test_dl_patterns(){ - // ascent!{ - ascent_m_par!{ - #![measure_rule_times] - relation foo(i32, Option); - relation bar(i32, i32); - foo(1, None); - foo(2, Some(2)); - foo(3, Some(30)); - bar(*x, *y) <-- foo(x, y_opt) if let Some(y) = y_opt if y != x; - }; - let mut prog = AscentProgram::default(); - prog.run(); - println!("bar: {:?}", prog.bar); - assert!(prog.bar.iter().contains(&(3,30))); - assert!(prog.bar.len() == 1); -} - -#[test] -fn test_dl_pattern_args(){ - ascent_m_par!{ - relation foo(i32, Option); - relation bar(i32, i32); - foo(1, None); - foo(2, Some(2)); - foo(3, Some(30)); - foo(3, None); - bar(*x, *y) <-- foo(x, ?Some(y)) if y != x; - }; - let mut prog = AscentProgram::default(); - prog.run(); - println!("bar: {:?}", prog.bar); - assert!(prog.bar.iter().contains(&(3,30))); - assert!(prog.bar.len() == 1); -} - -#[test] -fn test_dl2(){ - ascent_m_par!{ - relation bar(i32, i32); - relation foo1(i32, i32); - relation foo2(i32, i32); - - foo1(1, 2); - foo1(10, 20); - foo1(0, 2); - - bar(*x, y + z) <-- foo1(x, y) if *x != 0, foo2(y, z); - } - - let mut prog = AscentProgram::default(); - - let foo2 = vec![ - (2, 4), - (2, 1), - (20, 40), - (20, 0), - ]; - prog.foo2 = FromIterator::from_iter(foo2); - - prog.run(); - - println!("bar: {:?}", prog.bar); - assert!(rels_equal([(1, 3), (1, 6), (10, 60), (10, 20)], prog.bar)); -} - -#[test] -fn test_ascent_expressions_and_inits(){ - ascent_m_par!{ - relation foo(i32, i32) = vec![(1, 2)].into_iter().collect(); - foo(2, 3); - foo(3, 5); - - relation bar(i32, i32); - relation baz(i32, i32, i32); - - bar(3, 6); - bar(5, 10); - - baz(*x, *y, *z) <-- foo(x, y), bar(x + y , z); - }; - let mut prog = AscentProgram::default(); - prog.run(); - println!("baz: {:?}", prog.baz); - assert!(rels_equal([(1,2,6), (2,3,10)], prog.baz)); -} - -#[test] -fn test_dl_cross_join(){ - ascent_m_par!{ - relation foo(i32, i32); - relation bar(i32, i32); - relation baz(i32, i32, i32, i32); - foo(x, x + 1) <-- for x in 0..5; - - bar(11, 12); - bar(12, 13); - - baz(a, b, c, d) <-- foo(a, b), bar(c , d); - } - let mut prog = AscentProgram::default(); - prog.run(); - println!("baz: {:?}", prog.baz); - assert_eq!(prog.baz.len(), prog.foo.len() * prog.bar.len()); -} - -#[test] -fn test_dl_vars_bound_in_patterns(){ - ascent_m_par!{ - relation foo(i32, Option); - relation bar(i32, i32); - relation baz(i32, i32, i32); - foo(1, Some(2)); - foo(2, None); - foo(3, Some(5)); - foo(4, Some(10)); - - - bar(3, 6); - bar(5, 10); - bar(10, 20); - - baz(*x, *y, *z) <-- foo(x, ?Some(y)), bar(y , z); - }; - let mut prog = AscentProgram::default(); - prog.run(); - println!("baz: {:?}", prog.baz); - assert!(rels_equal([(3, 5, 10), (4, 10, 20)], prog.baz)); -} - -#[test] -fn test_dl_generators(){ - ascent!{ - relation foo(i32, i32); - relation bar(i32); - - foo(x, y) <-- for x in 0..10, for y in (x+1)..10; - - bar(*x) <-- foo(x, y); - bar(*y) <-- foo(x, y); - }; - let mut prog = AscentProgram::default(); - prog.run(); - println!("foo: {:?}", prog.foo); - assert_eq!(prog.foo.len(), 45); -} - -#[test] -fn test_dl_generators2(){ - ascent!{ - relation foo(i32, i32); - relation bar(i32); - - foo(3, 4); - foo(4, 6); - foo(20, 21); - bar(x) <-- for (x, y) in (0..10).map(|x| (x, x+1)), foo(x, y); - - }; - let mut prog = AscentProgram::default(); - prog.run(); - println!("bar: {:?}", prog.bar); - assert!(rels_equal([(3,)], prog.bar)); -} - - -#[test] -fn test_dl_multiple_head_clauses(){ - ascent_m_par!{ - relation foo(Vec, Vec); - relation foo2(Vec); - relation foo1(Vec); - - relation bar(i32); - - foo(vec![3], vec![4]); - foo(vec![1, 2], vec![4, 5]); - foo(vec![10, 11], vec![20]); - - foo1(x.clone()), foo2(y.clone()) <-- foo(x, y) if x.len() > 1; - }; - let mut prog = AscentProgram::default(); - prog.run(); - println!("foo1: {:?}", prog.foo1); - println!("foo2: {:?}", prog.foo2); - - assert!(rels_equal([(vec![1, 2],), (vec![10,11],)], prog.foo1)); - assert!(rels_equal([(vec![4, 5],), (vec![20],)], prog.foo2)); -} - -#[test] -fn test_dl_multiple_head_clauses2(){ - ascent_m_par!{ - relation foo(Vec); - relation foo_left(Vec); - relation foo_right(Vec); - - foo(vec![1,2,3,4]); - - foo_left(xs[..i].into()), foo_right(xs[i..].into()) <-- foo(xs), for i in 0..xs.len(); - foo(xs.clone()) <-- foo_left(xs); - foo(xs.clone()) <-- foo_right(xs); - }; - - let mut prog = AscentProgram::default(); - prog.run(); - println!("foo: {:?}", prog.foo); - - assert!(rels_equal([(vec![],), (vec![1],), (vec![1, 2],), (vec![1, 2, 3],), (vec![1, 2, 3, 4],), (vec![2],), (vec![2, 3],), (vec![2, 3, 4],), (vec![3],), (vec![3, 4],), (vec![4],)], - prog.foo)); -} - -#[test] -fn test_dl_disjunctions(){ - ascent!{ - relation foo1(i32, i32); - relation foo2(i32, i32); - relation small(i32); - relation bar(i32, i32); - - small(x) <-- for x in 0..5; - foo1(0, 4); - foo1(1, 4); - foo1(2, 6); - - foo2(3, 30); - foo2(2, 20); - foo2(8, 21); - foo2(9, 21); - - bar(x.clone(), y.clone()) <-- ((for x in 3..10), small(x) || foo1(x,_y)), (foo2(x,y)); - - }; - let mut prog = AscentProgram::default(); - prog.run(); - println!("bar: {:?}", prog.bar); - assert_rels_eq!([(3,30), (2, 20)], prog.bar); -} - -#[test] -fn test_dl_repeated_vars(){ - ascent_m_par!{ - relation foo(i32); - relation bar(i32, i32); - relation res(i32); - relation bar_refl(i32); - relation bar3(i32, i32, i32); - relation bar3_res(i32); - - foo(3); - bar(2, 1); - bar(1, 1); - bar(3, 3); - - bar_refl(*x) <-- bar(x, x); - - res(*x) <-- foo(x), bar(x, x); - - bar3(10,10,11); - bar3(1,1,1); - bar3(1,2,3); - bar3(2,1,3); - - bar3_res(*x) <-- bar3(x, x, *x + 1); - }; - let mut prog = AscentProgram::default(); - prog.run(); - println!("res: {:?}", prog.res); - assert!(rels_equal([(3,)], prog.res)); - assert!(rels_equal([(1,), (3,)], prog.bar_refl)); - assert!(rels_equal([(10,)], prog.bar3_res)); -} - - -#[test] -fn test_dl_lattice1(){ - ascent_m_par!{ - lattice shortest_path(i32, i32, Dual); - relation edge(i32, i32, u32); - - shortest_path(*x, *y, Dual(*w)) <-- edge(x, y, w); - shortest_path(*x, *z, Dual(w + l.0)) <-- edge(x, y, w), shortest_path(y, z, l); - - edge(1, 2, x + 30) <-- for x in 0..100; - edge(2, 3, x + 50) <-- for x in 0..100; - edge(1, 3, x + 40) <-- for x in 0..100; - edge(2, 4, x + 100) <-- for x in 0..100; - edge(1, 4, x + 200) <-- for x in 0..100; - } - let mut prog = AscentProgram::default(); - prog.run(); - println!("shortest_path ({} tuples):", prog.shortest_path.len()); - println!("\n{:?}", prog.shortest_path); - println!("{}", AscentProgram::summary()); - assert!(rels_equal(lat_to_vec(prog.shortest_path), [(1,2, Dual(30)), (1, 3, Dual(40)), (1,4, Dual(130)), (2,3, Dual(50)), (2, 4, Dual(100))])) -} - -#[test] -fn test_dl_lattice2(){ - ascent!{ - lattice shortest_path(i32, i32, Dual); - relation edge(i32, i32, u32); - - shortest_path(*x,*y, {println!("adding sp({},{},{})?", x, y, w ); Dual(*w)}) <-- edge(x, y, w); - shortest_path(*x, *z, {println!("adding sp({},{},{})?", x, z, *w + len.0); Dual(w + len.0)}) <-- edge(x, y, w), shortest_path(y, z, len); - - edge(1, 2, 30); - edge(2, 3, 50); - edge(1, 3, 40); - edge(2, 4, 100); - edge(4, 1, 1000); - - }; - let mut prog = AscentProgram::default(); - prog.run(); - println!("shortest_path ({} tuples):\n{:?}", prog.shortest_path.len(), prog.shortest_path); -} - -#[test] -fn test_ascent_run(){ - let foo_contents = (0..10).flat_map(|x| (x+1..10).map(move |y| (x,y))).collect_vec(); - let res = ascent_run!{ - relation foo(i32, i32); - relation bar(i32); - - foo(x,y) <-- for &(x,y) in foo_contents.iter(); - - bar(*x) <-- foo(x, y); - bar(*y) <-- foo(x, y); - }; - - let _res2 = ascent_run!{ - relation foo(i32); - foo(42); - }; - println!("foo: {:?}", res.foo); - assert_eq!(res.foo.len(), 45); -} - -#[test] -fn test_ascent_run_rel_init(){ - let foo_contents = (0..10).flat_map(|x| (x+1..10).map(move |y| (x,y))).collect_vec(); - let res = ascent_run!{ - relation foo(i32, i32) = foo_contents; - relation bar(i32); - - bar(*x), bar(*y) <-- foo(x, y); - }; - - println!("foo: {:?}", res.foo); - assert_eq!(res.foo.len(), 45); -} - -#[test] -fn test_ascentception(){ - let res = ascent_run!{ - relation ascentception_input(i32); - ascentception_input(0); - ascentception_input(100); - - relation funny(i32, i32); - funny(x, y) <-- ascentception_input(inp), for (x, y) in { - ascent_run! { - relation ascentception(i32, i32); - ascentception(x, x + 1) <-- for x in *inp..(inp + 10); - }.ascentception - }; - }; - println!("funny: {:?}", res.funny); - assert_eq!(res.funny.len(), 20); -} - -#[test] -fn test_ascent_run_tc(){ - fn compute_tc(inp: Vec<(i32, i32)>) -> Vec<(i32,i32)> { - ascent_run_m_par!{ - relation r(i32, i32) = FromIterator::from_iter(inp); - relation tc(i32, i32); - tc(x, y) <-- r(x, y); - tc(x, z) <-- r(x, y), tc(y, z); - }.tc.into_iter().collect() - } - assert!(rels_equal([(1,2), (2, 3), (1, 3)], compute_tc(vec![(1,2), (2,3)]))); -} - -#[test] -fn test_ascent_run_tc_generic(){ - fn compute_tc(r: &[(TNode, TNode)]) -> Vec<(TNode,TNode)> { - ascent_run_m_par!{ - struct TC; - relation tc(TNode, TNode); - tc(x.clone(), y.clone()) <-- for (x, y) in r.iter(); - tc(x.clone(), z.clone()) <-- for (x, y) in r.iter(), tc(y, z); - }.tc.into_iter().collect() - } - assert!(rels_equal([(1,2), (2, 3), (1, 3)], compute_tc(&[(1,2), (2,3)]))); -} - -#[test] -fn test_ascent_tc_generic(){ - ascent!{ - struct TC; - relation tc(TNode, TNode); - relation r(TNode, TNode); - tc(x.clone(), y.clone()) <-- r(x,y); - tc(x.clone(), z.clone()) <-- r(x, y), tc(y, z); - } - let mut prog = TC::default(); - prog.r = vec![(1,2), (2,3)]; - prog.run(); - assert!(rels_equal([(1,2), (2, 3), (1, 3)], prog.tc)); -} - - -#[test] -fn test_ascent_negation_through_lattices(){ - use ascent::lattice::set::Set; - let res = ascent_run_m_par!{ - relation foo(i32, i32); - relation bar(i32, i32); - - bar(x, x+1) <-- for x in 0..10; - foo(*x, *y) <-- bar(x, y); - - lattice foo_as_set(Set<(i32, i32)>); - foo_as_set(Set::singleton((*x, *y))) <-- foo(x, y); - - relation baz(i32, i32); - baz(1, 2); - baz(1, 3); - - relation res(i32, i32); - res(*x, *y) <-- baz(x, y), foo_as_set(all_foos), if !all_foos.contains(&(*x, *y)); - }; - println!("res: {:?}", res.res); - assert!(rels_equal([(1,3)], res.res)); -} - -#[test] -fn test_ascent_run_explicit_decl(){ - fn compute_tc(edges: &[(TNode, TNode)]) -> Vec<(TNode, TNode)> { - ascent_run!{ - struct TC where TNode: Clone + Eq + Hash; - relation edge(TNode, TNode); - relation path(TNode, TNode); - edge(x.clone(), y.clone()) <-- for (x, y) in edges.iter(); - - path(x.clone(), y.clone()) <-- edge(x,y); - path(x.clone(), z.clone()) <-- edge(x,y), path(y, z); - }.path - } - - let res = compute_tc(&[(1,2), (2, 3)]); - println!("res: {:?}", res); - assert!(rels_equal([(1,2), (2,3), (1,3)], res)); - -} - -#[test] -fn test_ascent_fac(){ - ascent_m_par!{ - struct Fac; - relation fac(u64, u64); - relation do_fac(u64); - - fac(0, 1) <-- do_fac(0); - do_fac(x - 1) <-- do_fac(x), if *x > 0; - fac(*x, x * sub1fac) <-- do_fac(x) if *x > 0, fac(x - 1, sub1fac); - - do_fac(10); - } - let mut prog = Fac::default(); - prog.run(); - println!("fac: {:?}", prog.fac); - println!("{}", Fac::summary()); - println!("{}", prog.relation_sizes_summary()); - println!("{}", prog.scc_times_summary()); - - assert!(prog.fac.iter().contains(&(5, 120))); -} - -#[test] -fn test_consuming_ascent_run_tc(){ - fn compute_tc(inp: impl Iterator) -> Vec<(i32,i32)> { - ascent_run!{ - relation tc(i32, i32); - relation r(i32, i32); - - r(x, y) <-- for (x, y) in inp; - tc(*x, *y) <-- r(x, y); - tc(*x, *z) <-- r(x, y), tc(y, z); - }.tc - } - let res = compute_tc([(1,2), (2,3)].into_iter()); - println!("res: {:?}", res); - assert!(rels_equal([(1,2), (2, 3), (1, 3)], compute_tc([(1,2), (2,3)].into_iter()))); -} - -#[test] -fn test_ascent_simple_join(){ - let res = ascent_run_m_par!{ - relation bar(i32, i32); - relation foo(i32, i32); - relation baz(i32, i32); - - bar(2, 3); - foo(1, 2); - bar(2, 1); - - baz(*x, *z) <-- foo(x, y), bar(y, z), if x != z; - foo(*x, *y), bar(*x, y) <-- baz(x, y); - }; - println!("baz: {:?}", res.baz); - assert!(rels_equal([(1, 3)], res.baz)); -} - -#[test] -fn test_ascent_simple_join2(){ - let res = ascent_run_m_par!{ - relation bar(i32, i32); - relation foo(i32, i32); - relation baz(i32, i32); - - foo(1, 2); - foo(10, 2); - bar(2, 3); - bar(2, 1); - - baz(*x, *z) <-- foo(x, y) if *x != 10, bar(y, z), if x != z; - foo(*x, *y), bar(*x, *y) <-- baz(x, y); - }; - println!("baz: {:?}", res.baz); - assert_rels_eq!([(1, 3)], res.baz); -} - -#[test] -fn test_ascent_simple_join3(){ - let res = ascent_run_m_par!{ - relation bar(i32, i32); - relation foo(i32, i32); - relation baz(i32, i32); - - foo(1, 2); - foo(10, 2); - bar(2, 3); - bar(2, 1); - - baz(*x, *z) <-- foo(x, y) if *x != 10, bar(y, ?z) if *z < 4, if x != z; - - baz(*x, *z) <-- foo(x, y) if *x != 10, bar(y, z) if *z * x != 444, if x != z; - foo(*x, *y), bar(*x, *y) <-- baz(x, y); - }; - println!("baz: {:?}", res.baz); - assert_rels_eq!([(1, 3)], res.baz); -} - -#[test] -fn test_ascent_simple_join4() { - - #[derive(Default, Clone, Copy)] - struct Prop { transitive: bool, reflexive: bool, symmetric: bool } - let no_prop = Prop::default(); - - let input_rel = vec![(1, 2), (2, 3)]; - - let test_cases = vec![ - (Prop { transitive: true, ..no_prop }, vec![(1, 2), (2, 3), (1, 3)]), - (Prop { reflexive: true, ..no_prop }, vec![(1, 2), (2, 3), (1, 1), (2, 2), (3, 3)]), - (Prop { symmetric: true, ..no_prop }, vec![(1, 2), (2, 3), (3, 2), (2, 1)]), - (Prop { reflexive: true, transitive: true, symmetric: true }, vec![(1, 2), (2, 3), (1, 3), (1, 1), (2, 2), (3, 3), (2, 1), (3, 2), (3, 1)]) - ]; - - for (prop, expected) in test_cases { - let res = ascent_run_m_par! { - relation rel(i32, i32) = input_rel.iter().cloned().collect(); - - rel(y, x) <-- if prop.symmetric, rel(x, y); - rel(y, y), rel(x, x) <-- if prop.reflexive, rel(x, y); - rel(x, z) <-- if prop.transitive, rel(x, y), rel(y, z); - }; - assert_rels_eq!(res.rel, expected); - } -} - -#[test] -fn test_ascent_simple_join5() { - let res = ascent_run_m_par! { - relation foo(i32, i32); - foo(1,2), foo(2, 3); - foo(x, z) <-- let x = &42, foo(x, y), foo(y, z); - - relation bar(i32, i32); - bar(1, 2), bar(2, 3); - bar(x, z) <-- let z = &42, bar(x, y), bar(y, z); - - relation baz(i32, i32); - baz(x, y) <-- foo(x, y), bar(x, y); - baz(x, z) <-- let _ = 42, if let Some(w) = Some(42), baz(x, y), baz(y, z); - - }; - - assert_rels_eq!(res.foo, [(1, 2), (2, 3)]); - assert_rels_eq!(res.bar, [(1, 2), (2, 3)]); - assert_rels_eq!(res.baz, [(1, 2), (2, 3), (1, 3)]); -} - -#[test] -fn test_ascent_wildcards(){ - let res = ascent_run_m_par!{ - relation foo(i32, i32, i32); - relation bar(i32, i32); - relation baz(i32); - - foo(1, 2, 3); - foo(2, 3, 4); - bar(1, 1); - bar(1, 2); - bar(1, 3); - - baz(x) <-- - foo(x, _, _), - bar(_, x); - }; - println!("baz: {:?}", res.baz); - assert_rels_eq!([(1,), (2,)], res.baz); -} - -fn min<'a>(inp: impl Iterator) -> impl Iterator { - inp.map(|tuple| tuple.0).min().cloned().into_iter() -} - -#[test] -fn test_ascent_agg(){ - let res = ascent_run_m_par!{ - relation foo(i32, i32); - relation bar(i32, i32, i32); - relation baz(i32, i32, i32); - - foo(1, 2); - foo(2, 3); - bar(1, 2, 10); - bar(1, 2, 100); - - baz(x, y, min_z) <-- - foo(x, y), - agg min_z = min(z) in bar(x, y, z); - }; - println!("{}", res.summary()); - println!("baz: {:?}", res.baz); - assert_rels_eq!([(1, 2, 10)], res.baz); -} - -#[test] -fn test_run_timeout() { - ascent! { - #![generate_run_timeout] - /// A diverging Ascent program - struct Diverging; - /// foooooooooooo - relation foo(u128); - foo(0); - foo(x + 1) <-- foo(x); - } - - let mut prog = Diverging::default(); - prog.foo = vec![(1,), (2,)]; - let run_timeout_res = prog.run_timeout(Duration::from_millis(5)); - assert!(!run_timeout_res); -} - -#[test] -fn test_ascent_bounded_set() { - use ascent::lattice::bounded_set::BoundedSet; - ascent_m_par! { struct AscentProgram; - - lattice num_store(BoundedSet); - relation init(i32); - - init(x) <-- for x in 0..20; - num_store(BoundedSet::singleton(*x)) <-- - init(x); - } - - let mut prog = AscentProgram::<10>::default(); - prog.run(); - let store = &lat_to_vec(prog.num_store)[0].0; - for (x, ) in prog.init.iter() { - assert!(store.contains(x)); - } -} - -#[test] -fn test_issue3() { - #![allow(non_snake_case)] - - ascent_m_par!{ - relation a__(i32, i32); - relation c__(i32, i32, i32); - relation e__(i32); - relation h__(i32, i32, i32); - - e__(a) <-- a__(b, a); - h__(e, e, e) <-- a__(d, e), c__(e, f, e), e__(e); - } - let mut prog = AscentProgram::default(); - prog.a__ = FromIterator::from_iter(vec![(88,5), (37,24), (11,91)]); - prog.c__ = FromIterator::from_iter(vec![(32,83,88), (2,8,5)]); - prog.e__ = FromIterator::from_iter(vec![(44,), (83,)]); - prog.h__ = FromIterator::from_iter(vec![(38,88,18), (76,18,65), (86,73,91), (98,26,91), (76,10,14)]); - - prog.run(); - println!("h__: {:?}", prog.h__); - assert_rels_eq!(prog.h__, [(38, 88, 18), (76, 18, 65), (86, 73, 91), (98, 26, 91), (76, 10, 14)]); -} - -#[test] -fn test_repeated_vars_simple_joins() { - ascent_m_par! { - relation foo1(i32, i32); - relation foo2(i32, i32); - relation bar(i32, i32); - - // filler: - foo2(100, 100), foo2(101, 101), foo2(102, 102); - - foo1(1, 1), foo2(1, 2), foo1(10, 11), foo2(11, 12); - - bar(x, y) <-- foo2(x, y), foo1(x, x); - } - let mut prog = AscentProgram::default(); - prog.run(); - - println!("bar: {:?}", prog.bar); - assert_rels_eq!(prog.bar, [(1, 2)]); -} - -#[test] -fn test_aggregated_lattice() { - let res = ascent::ascent_run! { - relation foo(i32, i32); - lattice bar(i32, i32); - - bar(x, y) <-- for x in 0..2, for y in 5..10; - - foo(x, z) <-- - for x in 0..2, - agg z = ascent::aggregators::sum(y) in bar(x, y); - }; - assert_rels_eq!(res.bar, [(0, 9), (1, 9)]); -} - -#[test] -fn test_ds_attr() { - use ascent::rel as my_rel; - let res = ascent::ascent_run! { - #![ds(my_rel)] - - #[ds(ascent::rel)] - relation foo(i32, i32) = vec![(0, 1), (1, 0)]; - - relation bar(i32, i32); - - bar(x, y) <-- foo(x, y), if x < y; - }; - - assert_rels_eq!(res.bar, [(0, 1)]); -} \ No newline at end of file +#![cfg(test)] +#![allow(irrefutable_let_patterns)] +use std::cmp::max; +use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::fmt::Debug; +use std::hash::Hash; +use std::ops::Deref; +use std::primitive; +use std::rc::Rc; +use std::sync::Arc; +use std::time::Duration; + +use LambdaCalcExpr::*; +use ascent::{Dual, ascent, ascent_run}; +use itertools::Itertools; + +use crate::ascent_maybe_par::lat_to_vec; +use crate::utils::*; +use crate::{ascent_m_par, ascent_run_m_par, assert_rels_eq}; + +#[derive(Clone, PartialEq, Eq, Debug, Hash)] +pub enum LambdaCalcExpr { + Ref(&'static str), + Lam(&'static str, Arc), + App(Arc, Arc), +} + +impl LambdaCalcExpr { + #[allow(dead_code)] + fn depth(&self) -> usize { + match self { + LambdaCalcExpr::Ref(_) => 0, + LambdaCalcExpr::Lam(_x, b) => 1 + b.depth(), + LambdaCalcExpr::App(f, e) => 1 + max(f.depth(), e.depth()), + } + } +} +fn app(f: LambdaCalcExpr, a: LambdaCalcExpr) -> LambdaCalcExpr { App(Arc::new(f), Arc::new(a)) } +fn lam(x: &'static str, e: LambdaCalcExpr) -> LambdaCalcExpr { Lam(x, Arc::new(e)) } + +fn sub(exp: &LambdaCalcExpr, var: &str, e: &LambdaCalcExpr) -> LambdaCalcExpr { + match exp { + Ref(x) if *x == var => e.clone(), + Ref(_x) => exp.clone(), + App(ef, ea) => app(sub(ef, var, e), sub(ea, var, e)), + Lam(x, _eb) if *x == var => exp.clone(), + Lam(x, eb) => lam(x, sub(eb, var, e)), + } +} + +#[allow(non_snake_case)] +fn U() -> LambdaCalcExpr { lam("x", app(Ref("x"), Ref("x"))) } +#[allow(non_snake_case)] +fn I() -> LambdaCalcExpr { lam("x", Ref("x")) } + +#[test] +fn test_dl_lambda() { + ascent_m_par! { + relation output(LambdaCalcExpr); + relation input(LambdaCalcExpr); + relation eval(LambdaCalcExpr, LambdaCalcExpr); + relation do_eval(LambdaCalcExpr); + + input(app(U(), I())); + do_eval(exp.clone()) <-- input(exp); + output(res.clone()) <-- input(exp), eval(exp, res); + + eval(exp.clone(), exp.clone()) <-- do_eval(?exp @Ref(_)); + + eval(exp.clone(), exp.clone()) <-- do_eval(?exp @Lam(_,_)); + + do_eval(ef.as_ref().clone()) <-- do_eval(?App(ef,_ea)); + + do_eval(sub(fb, fx, ea)) <-- + do_eval(?App(ef, ea)), + eval(ef.deref(), ?Lam(fx, fb)); + + eval(exp.clone(), final_res.clone()) <-- + do_eval(?exp @ App(ef, ea)), // this requires nightly + eval(ef.deref(), ?Lam(fx, fb)), + eval(sub(fb, fx, ea), final_res); + }; + + let mut prog = AscentProgram::default(); + // println!("{}", AscentProgram::summary()); + prog.run(); + // println!("input:{:?}\n", prog.input); + // println!("eval: {}\n", prog.eval.iter().map(|(e,v)| format!("{:?} ===> {:?}", e, v)).join("\n")); + println!("output: {:?}", prog.output); + assert!(prog.output.iter().contains(&(I(),))); + assert!(prog.output.len() == 1); +} + +#[allow(dead_code)] +fn _test_dl_lambda2() { + type Time = u32; + type Addr = u32; + #[derive(Clone, Hash, PartialEq, Eq, Debug)] + enum Storable { + Value(LambdaCalcExpr, Env), + Kont(Continuation), + } + use Storable::*; + #[derive(Clone, Hash, PartialEq, Eq, Debug)] + enum Continuation { + Fn(LambdaCalcExpr, Env, Addr), + Ar(LambdaCalcExpr, Env, Addr), + } + use Continuation::*; + type Env = Rc>; + fn tick(_t: &Time, _k: &Continuation) -> Time { todo!() } + fn alloc(_e: &LambdaCalcExpr, _ρ: &Env, _a: &Addr, _t: &Time, _k: &Continuation) -> Addr { todo!() } + fn upd(ρ: &Env, var: &'static str, addr: Addr) -> Env { + let mut ρ = ρ.deref().clone(); + ρ.insert(var, addr); + Rc::new(ρ) + } + // #[derive(Clone, Hash, PartialEq, Eq)] + // struct State(LambdaCalcExpr, Env, Addr, Time); + ascent! { + struct CESK; + relation output(LambdaCalcExpr); + relation input(LambdaCalcExpr); + relation σ(Addr, Storable); + relation ς(LambdaCalcExpr, Env, Addr, Time); + input(app(U(), I())); + + + ς(v, ρ2, a, tick(t, k)) <-- + ς(?Ref(x), ρ, a, t), + σ(ρ[x], ?Value(v, ρ2)), + σ(a, ?Kont(k)); + + σ(b, Kont(Ar(e1.deref().clone(), ρ.clone(), *a))), + ς(e0.deref().clone(), ρ, b, tick(t, k)) <-- + ς(?e@App(e0, e1), ρ, addr, t), + σ(a, ?Kont(k)), + let b = alloc(e, ρ, addr, t, k); + + σ(b, Kont(Fn(v.clone(), ρ.clone(), *c))), + ς(e, ρ2, b, tick(t, k)) <-- + ς(?v@Lam(..), ρ, a, t), + σ(a, ?Kont(k)), + if let Ar(e, ρ2, c) = k, + let b = alloc(v, ρ, a, t, k); + + σ(b, Value(v.clone(), ρ.clone())), + ς(e, upd(&ρ2, x, b), b, tick(t, k)) <-- + ς(?v@Lam(..), ρ, a, t), + σ(a, ?Kont(k)), + if let Fn(Lam(x, e), ρ2, c) = k, + let b = alloc(v, ρ, a, t, k); + }; + use std::collections::HashSet; + type Store = Rc>>; + ascent! { + struct CeskLocalStore; + relation output(LambdaCalcExpr); + relation input(LambdaCalcExpr); + relation ς(LambdaCalcExpr, Env, Store, Addr, Time); + // rule 1: + ς(v.clone(), ρ2.clone(), σ.clone(), *a, tick(t, k)) <-- + ς(?Ref(x), ρ, σ, a, t), + for sv in σ[&ρ[x]].iter(), if let Value(v, ρ2) = sv, + for sa in σ[a].iter(), if let Kont(k) = sa; + // ... + } + let mut prog = CESK::default(); + // println!("{}", AscentProgram::summary()); + prog.run(); + // println!("input:{:?}\n", prog.input); + // println!("eval: {}\n", prog.eval.iter().map(|(e,v)| format!("{:?} ===> {:?}", e, v)).join("\n")); +} + +#[test] +fn test_dl_patterns() { + // ascent!{ + ascent_m_par! { + #![measure_rule_times] + relation foo(i32, Option); + relation bar(i32, i32); + foo(1, None); + foo(2, Some(2)); + foo(3, Some(30)); + bar(*x, *y) <-- foo(x, y_opt) if let Some(y) = y_opt if y != x; + }; + let mut prog = AscentProgram::default(); + prog.run(); + println!("bar: {:?}", prog.bar); + assert!(prog.bar.iter().contains(&(3, 30))); + assert!(prog.bar.len() == 1); +} + +#[test] +fn test_dl_pattern_args() { + ascent_m_par! { + relation foo(i32, Option); + relation bar(i32, i32); + foo(1, None); + foo(2, Some(2)); + foo(3, Some(30)); + foo(3, None); + bar(*x, *y) <-- foo(x, ?Some(y)) if y != x; + }; + let mut prog = AscentProgram::default(); + prog.run(); + println!("bar: {:?}", prog.bar); + assert!(prog.bar.iter().contains(&(3, 30))); + assert!(prog.bar.len() == 1); +} + +#[test] +fn test_dl2() { + ascent_m_par! { + relation bar(i32, i32); + relation foo1(i32, i32); + relation foo2(i32, i32); + + foo1(1, 2); + foo1(10, 20); + foo1(0, 2); + + bar(*x, y + z) <-- foo1(x, y) if *x != 0, foo2(y, z); + } + + let mut prog = AscentProgram::default(); + + let foo2 = vec![(2, 4), (2, 1), (20, 40), (20, 0)]; + prog.foo2 = FromIterator::from_iter(foo2); + + prog.run(); + + println!("bar: {:?}", prog.bar); + assert!(rels_equal([(1, 3), (1, 6), (10, 60), (10, 20)], prog.bar)); +} + +#[test] +fn test_ascent_expressions_and_inits() { + ascent_m_par! { + relation foo(i32, i32) = vec![(1, 2)].into_iter().collect(); + foo(2, 3); + foo(3, 5); + + relation bar(i32, i32); + relation baz(i32, i32, i32); + + bar(3, 6); + bar(5, 10); + + baz(*x, *y, *z) <-- foo(x, y), bar(x + y , z); + }; + let mut prog = AscentProgram::default(); + prog.run(); + println!("baz: {:?}", prog.baz); + assert!(rels_equal([(1, 2, 6), (2, 3, 10)], prog.baz)); +} + +#[test] +fn test_dl_cross_join() { + ascent_m_par! { + relation foo(i32, i32); + relation bar(i32, i32); + relation baz(i32, i32, i32, i32); + foo(x, x + 1) <-- for x in 0..5; + + bar(11, 12); + bar(12, 13); + + baz(a, b, c, d) <-- foo(a, b), bar(c , d); + } + let mut prog = AscentProgram::default(); + prog.run(); + println!("baz: {:?}", prog.baz); + assert_eq!(prog.baz.len(), prog.foo.len() * prog.bar.len()); +} + +#[test] +fn test_dl_vars_bound_in_patterns() { + ascent_m_par! { + relation foo(i32, Option); + relation bar(i32, i32); + relation baz(i32, i32, i32); + foo(1, Some(2)); + foo(2, None); + foo(3, Some(5)); + foo(4, Some(10)); + + + bar(3, 6); + bar(5, 10); + bar(10, 20); + + baz(*x, *y, *z) <-- foo(x, ?Some(y)), bar(y , z); + }; + let mut prog = AscentProgram::default(); + prog.run(); + println!("baz: {:?}", prog.baz); + assert!(rels_equal([(3, 5, 10), (4, 10, 20)], prog.baz)); +} + +#[test] +fn test_dl_generators() { + ascent! { + relation foo(i32, i32); + relation bar(i32); + + foo(x, y) <-- for x in 0..10, for y in (x+1)..10; + + bar(*x) <-- foo(x, y); + bar(*y) <-- foo(x, y); + }; + let mut prog = AscentProgram::default(); + prog.run(); + println!("foo: {:?}", prog.foo); + assert_eq!(prog.foo.len(), 45); +} + +#[test] +fn test_dl_generators2() { + ascent! { + relation foo(i32, i32); + relation bar(i32); + + foo(3, 4); + foo(4, 6); + foo(20, 21); + bar(x) <-- for (x, y) in (0..10).map(|x| (x, x+1)), foo(x, y); + + }; + let mut prog = AscentProgram::default(); + prog.run(); + println!("bar: {:?}", prog.bar); + assert!(rels_equal([(3,)], prog.bar)); +} + +#[test] +fn test_dl_multiple_head_clauses() { + ascent_m_par! { + relation foo(Vec, Vec); + relation foo2(Vec); + relation foo1(Vec); + + relation bar(i32); + + foo(vec![3], vec![4]); + foo(vec![1, 2], vec![4, 5]); + foo(vec![10, 11], vec![20]); + + foo1(x.clone()), foo2(y.clone()) <-- foo(x, y) if x.len() > 1; + }; + let mut prog = AscentProgram::default(); + prog.run(); + println!("foo1: {:?}", prog.foo1); + println!("foo2: {:?}", prog.foo2); + + assert!(rels_equal([(vec![1, 2],), (vec![10, 11],)], prog.foo1)); + assert!(rels_equal([(vec![4, 5],), (vec![20],)], prog.foo2)); +} + +#[test] +fn test_dl_multiple_head_clauses2() { + ascent_m_par! { + relation foo(Vec); + relation foo_left(Vec); + relation foo_right(Vec); + + foo(vec![1,2,3,4]); + + foo_left(xs[..i].into()), foo_right(xs[i..].into()) <-- foo(xs), for i in 0..xs.len(); + foo(xs.clone()) <-- foo_left(xs); + foo(xs.clone()) <-- foo_right(xs); + }; + + let mut prog = AscentProgram::default(); + prog.run(); + println!("foo: {:?}", prog.foo); + + assert!(rels_equal( + [ + (vec![],), + (vec![1],), + (vec![1, 2],), + (vec![1, 2, 3],), + (vec![1, 2, 3, 4],), + (vec![2],), + (vec![2, 3],), + (vec![2, 3, 4],), + (vec![3],), + (vec![3, 4],), + (vec![4],) + ], + prog.foo + )); +} + +#[test] +fn test_dl_disjunctions() { + ascent! { + relation foo1(i32, i32); + relation foo2(i32, i32); + relation small(i32); + relation bar(i32, i32); + + small(x) <-- for x in 0..5; + foo1(0, 4); + foo1(1, 4); + foo1(2, 6); + + foo2(3, 30); + foo2(2, 20); + foo2(8, 21); + foo2(9, 21); + + bar(x.clone(), y.clone()) <-- ((for x in 3..10), small(x) || foo1(x,_y)), (foo2(x,y)); + + }; + let mut prog = AscentProgram::default(); + prog.run(); + println!("bar: {:?}", prog.bar); + assert_rels_eq!([(3, 30), (2, 20)], prog.bar); +} + +#[test] +fn test_dl_repeated_vars() { + ascent_m_par! { + relation foo(i32); + relation bar(i32, i32); + relation res(i32); + relation bar_refl(i32); + relation bar3(i32, i32, i32); + relation bar3_res(i32); + + foo(3); + bar(2, 1); + bar(1, 1); + bar(3, 3); + + bar_refl(*x) <-- bar(x, x); + + res(*x) <-- foo(x), bar(x, x); + + bar3(10,10,11); + bar3(1,1,1); + bar3(1,2,3); + bar3(2,1,3); + + bar3_res(*x) <-- bar3(x, x, *x + 1); + }; + let mut prog = AscentProgram::default(); + prog.run(); + println!("res: {:?}", prog.res); + assert!(rels_equal([(3,)], prog.res)); + assert!(rels_equal([(1,), (3,)], prog.bar_refl)); + assert!(rels_equal([(10,)], prog.bar3_res)); +} + +#[test] +fn test_dl_lattice1() { + ascent_m_par! { + lattice shortest_path(i32, i32, Dual); + relation edge(i32, i32, u32); + + shortest_path(*x, *y, Dual(*w)) <-- edge(x, y, w); + shortest_path(*x, *z, Dual(w + l.0)) <-- edge(x, y, w), shortest_path(y, z, l); + + edge(1, 2, x + 30) <-- for x in 0..100; + edge(2, 3, x + 50) <-- for x in 0..100; + edge(1, 3, x + 40) <-- for x in 0..100; + edge(2, 4, x + 100) <-- for x in 0..100; + edge(1, 4, x + 200) <-- for x in 0..100; + } + let mut prog = AscentProgram::default(); + prog.run(); + println!("shortest_path ({} tuples):", prog.shortest_path.len()); + println!("\n{:?}", prog.shortest_path); + println!("{}", AscentProgram::summary()); + assert!(rels_equal(lat_to_vec(prog.shortest_path), [ + (1, 2, Dual(30)), + (1, 3, Dual(40)), + (1, 4, Dual(130)), + (2, 3, Dual(50)), + (2, 4, Dual(100)) + ])) +} + +#[test] +fn test_dl_lattice2() { + ascent! { + lattice shortest_path(i32, i32, Dual); + relation edge(i32, i32, u32); + + shortest_path(*x,*y, {println!("adding sp({},{},{})?", x, y, w ); Dual(*w)}) <-- edge(x, y, w); + shortest_path(*x, *z, {println!("adding sp({},{},{})?", x, z, *w + len.0); Dual(w + len.0)}) <-- edge(x, y, w), shortest_path(y, z, len); + + edge(1, 2, 30); + edge(2, 3, 50); + edge(1, 3, 40); + edge(2, 4, 100); + edge(4, 1, 1000); + + }; + let mut prog = AscentProgram::default(); + prog.run(); + println!("shortest_path ({} tuples):\n{:?}", prog.shortest_path.len(), prog.shortest_path); +} + +#[test] +fn test_ascent_run() { + let foo_contents = (0..10).flat_map(|x| (x + 1..10).map(move |y| (x, y))).collect_vec(); + let res = ascent_run! { + relation foo(i32, i32); + relation bar(i32); + + foo(x,y) <-- for &(x,y) in foo_contents.iter(); + + bar(*x) <-- foo(x, y); + bar(*y) <-- foo(x, y); + }; + + let _res2 = ascent_run! { + relation foo(i32); + foo(42); + }; + println!("foo: {:?}", res.foo); + assert_eq!(res.foo.len(), 45); +} + +#[test] +fn test_ascent_run_rel_init() { + let foo_contents = (0..10).flat_map(|x| (x + 1..10).map(move |y| (x, y))).collect_vec(); + let res = ascent_run! { + relation foo(i32, i32) = foo_contents; + relation bar(i32); + + bar(*x), bar(*y) <-- foo(x, y); + }; + + println!("foo: {:?}", res.foo); + assert_eq!(res.foo.len(), 45); +} + +#[test] +fn test_ascentception() { + let res = ascent_run! { + relation ascentception_input(i32); + ascentception_input(0); + ascentception_input(100); + + relation funny(i32, i32); + funny(x, y) <-- ascentception_input(inp), for (x, y) in { + ascent_run! { + relation ascentception(i32, i32); + ascentception(x, x + 1) <-- for x in *inp..(inp + 10); + }.ascentception + }; + }; + println!("funny: {:?}", res.funny); + assert_eq!(res.funny.len(), 20); +} + +#[test] +fn test_ascent_run_tc() { + fn compute_tc(inp: Vec<(i32, i32)>) -> Vec<(i32, i32)> { + ascent_run_m_par! { + relation r(i32, i32) = FromIterator::from_iter(inp); + relation tc(i32, i32); + tc(x, y) <-- r(x, y); + tc(x, z) <-- r(x, y), tc(y, z); + } + .tc + .into_iter() + .collect() + } + assert!(rels_equal([(1, 2), (2, 3), (1, 3)], compute_tc(vec![(1, 2), (2, 3)]))); +} + +#[test] +fn test_ascent_run_tc_generic() { + fn compute_tc(r: &[(TNode, TNode)]) -> Vec<(TNode, TNode)> { + ascent_run_m_par! { + struct TC; + relation tc(TNode, TNode); + tc(x.clone(), y.clone()) <-- for (x, y) in r.iter(); + tc(x.clone(), z.clone()) <-- for (x, y) in r.iter(), tc(y, z); + } + .tc + .into_iter() + .collect() + } + assert!(rels_equal([(1, 2), (2, 3), (1, 3)], compute_tc(&[(1, 2), (2, 3)]))); +} + +#[test] +fn test_ascent_tc_generic() { + ascent! { + struct TC; + relation tc(TNode, TNode); + relation r(TNode, TNode); + tc(x.clone(), y.clone()) <-- r(x,y); + tc(x.clone(), z.clone()) <-- r(x, y), tc(y, z); + } + let mut prog = TC::default(); + prog.r = vec![(1, 2), (2, 3)]; + prog.run(); + assert!(rels_equal([(1, 2), (2, 3), (1, 3)], prog.tc)); +} + +#[test] +fn test_ascent_negation_through_lattices() { + use ascent::lattice::set::Set; + let res = ascent_run_m_par! { + relation foo(i32, i32); + relation bar(i32, i32); + + bar(x, x+1) <-- for x in 0..10; + foo(*x, *y) <-- bar(x, y); + + lattice foo_as_set(Set<(i32, i32)>); + foo_as_set(Set::singleton((*x, *y))) <-- foo(x, y); + + relation baz(i32, i32); + baz(1, 2); + baz(1, 3); + + relation res(i32, i32); + res(*x, *y) <-- baz(x, y), foo_as_set(all_foos), if !all_foos.contains(&(*x, *y)); + }; + println!("res: {:?}", res.res); + assert!(rels_equal([(1, 3)], res.res)); +} + +#[test] +fn test_ascent_run_explicit_decl() { + fn compute_tc(edges: &[(TNode, TNode)]) -> Vec<(TNode, TNode)> { + ascent_run! { + struct TC where TNode: Clone + Eq + Hash; + relation edge(TNode, TNode); + relation path(TNode, TNode); + edge(x.clone(), y.clone()) <-- for (x, y) in edges.iter(); + + path(x.clone(), y.clone()) <-- edge(x,y); + path(x.clone(), z.clone()) <-- edge(x,y), path(y, z); + } + .path + } + + let res = compute_tc(&[(1, 2), (2, 3)]); + println!("res: {:?}", res); + assert!(rels_equal([(1, 2), (2, 3), (1, 3)], res)); +} + +#[test] +fn test_ascent_fac() { + ascent_m_par! { + struct Fac; + relation fac(u64, u64); + relation do_fac(u64); + + fac(0, 1) <-- do_fac(0); + do_fac(x - 1) <-- do_fac(x), if *x > 0; + fac(*x, x * sub1fac) <-- do_fac(x) if *x > 0, fac(x - 1, sub1fac); + + do_fac(10); + } + let mut prog = Fac::default(); + prog.run(); + println!("fac: {:?}", prog.fac); + println!("{}", Fac::summary()); + println!("{}", prog.relation_sizes_summary()); + println!("{}", prog.scc_times_summary()); + + assert!(prog.fac.iter().contains(&(5, 120))); +} + +#[test] +fn test_consuming_ascent_run_tc() { + fn compute_tc(inp: impl Iterator) -> Vec<(i32, i32)> { + ascent_run! { + relation tc(i32, i32); + relation r(i32, i32); + + r(x, y) <-- for (x, y) in inp; + tc(*x, *y) <-- r(x, y); + tc(*x, *z) <-- r(x, y), tc(y, z); + } + .tc + } + let res = compute_tc([(1, 2), (2, 3)].into_iter()); + println!("res: {:?}", res); + assert!(rels_equal([(1, 2), (2, 3), (1, 3)], compute_tc([(1, 2), (2, 3)].into_iter()))); +} + +#[test] +fn test_ascent_simple_join() { + let res = ascent_run_m_par! { + relation bar(i32, i32); + relation foo(i32, i32); + relation baz(i32, i32); + + bar(2, 3); + foo(1, 2); + bar(2, 1); + + baz(*x, *z) <-- foo(x, y), bar(y, z), if x != z; + foo(*x, *y), bar(*x, y) <-- baz(x, y); + }; + println!("baz: {:?}", res.baz); + assert!(rels_equal([(1, 3)], res.baz)); +} + +#[test] +fn test_ascent_simple_join2() { + let res = ascent_run_m_par! { + relation bar(i32, i32); + relation foo(i32, i32); + relation baz(i32, i32); + + foo(1, 2); + foo(10, 2); + bar(2, 3); + bar(2, 1); + + baz(*x, *z) <-- foo(x, y) if *x != 10, bar(y, z), if x != z; + foo(*x, *y), bar(*x, *y) <-- baz(x, y); + }; + println!("baz: {:?}", res.baz); + assert_rels_eq!([(1, 3)], res.baz); +} + +#[test] +fn test_ascent_simple_join3() { + let res = ascent_run_m_par! { + relation bar(i32, i32); + relation foo(i32, i32); + relation baz(i32, i32); + + foo(1, 2); + foo(10, 2); + bar(2, 3); + bar(2, 1); + + baz(*x, *z) <-- foo(x, y) if *x != 10, bar(y, ?z) if *z < 4, if x != z; + + baz(*x, *z) <-- foo(x, y) if *x != 10, bar(y, z) if *z * x != 444, if x != z; + foo(*x, *y), bar(*x, *y) <-- baz(x, y); + }; + println!("baz: {:?}", res.baz); + assert_rels_eq!([(1, 3)], res.baz); +} + +#[test] +fn test_ascent_simple_join4() { + #[derive(Default, Clone, Copy)] + struct Prop { + transitive: bool, + reflexive: bool, + symmetric: bool, + } + let no_prop = Prop::default(); + + let input_rel = vec![(1, 2), (2, 3)]; + + let test_cases = vec![ + (Prop { transitive: true, ..no_prop }, vec![(1, 2), (2, 3), (1, 3)]), + (Prop { reflexive: true, ..no_prop }, vec![(1, 2), (2, 3), (1, 1), (2, 2), (3, 3)]), + (Prop { symmetric: true, ..no_prop }, vec![(1, 2), (2, 3), (3, 2), (2, 1)]), + (Prop { reflexive: true, transitive: true, symmetric: true }, vec![ + (1, 2), + (2, 3), + (1, 3), + (1, 1), + (2, 2), + (3, 3), + (2, 1), + (3, 2), + (3, 1), + ]), + ]; + + for (prop, expected) in test_cases { + let res = ascent_run_m_par! { + relation rel(i32, i32) = input_rel.iter().cloned().collect(); + + rel(y, x) <-- if prop.symmetric, rel(x, y); + rel(y, y), rel(x, x) <-- if prop.reflexive, rel(x, y); + rel(x, z) <-- if prop.transitive, rel(x, y), rel(y, z); + }; + assert_rels_eq!(res.rel, expected); + } +} + +#[test] +fn test_ascent_simple_join5() { + let res = ascent_run_m_par! { + relation foo(i32, i32); + foo(1,2), foo(2, 3); + foo(x, z) <-- let x = &42, foo(x, y), foo(y, z); + + relation bar(i32, i32); + bar(1, 2), bar(2, 3); + bar(x, z) <-- let z = &42, bar(x, y), bar(y, z); + + relation baz(i32, i32); + baz(x, y) <-- foo(x, y), bar(x, y); + baz(x, z) <-- let _ = 42, if let Some(w) = Some(42), baz(x, y), baz(y, z); + + }; + + assert_rels_eq!(res.foo, [(1, 2), (2, 3)]); + assert_rels_eq!(res.bar, [(1, 2), (2, 3)]); + assert_rels_eq!(res.baz, [(1, 2), (2, 3), (1, 3)]); +} + +#[test] +fn test_ascent_wildcards() { + let res = ascent_run_m_par! { + relation foo(i32, i32, i32); + relation bar(i32, i32); + relation baz(i32); + + foo(1, 2, 3); + foo(2, 3, 4); + bar(1, 1); + bar(1, 2); + bar(1, 3); + + baz(x) <-- + foo(x, _, _), + bar(_, x); + }; + println!("baz: {:?}", res.baz); + assert_rels_eq!([(1,), (2,)], res.baz); +} + +fn min<'a>(inp: impl Iterator) -> impl Iterator { + inp.map(|tuple| tuple.0).min().cloned().into_iter() +} + +#[test] +fn test_ascent_agg() { + let res = ascent_run_m_par! { + relation foo(i32, i32); + relation bar(i32, i32, i32); + relation baz(i32, i32, i32); + + foo(1, 2); + foo(2, 3); + bar(1, 2, 10); + bar(1, 2, 100); + + baz(x, y, min_z) <-- + foo(x, y), + agg min_z = min(z) in bar(x, y, z); + }; + println!("{}", res.summary()); + println!("baz: {:?}", res.baz); + assert_rels_eq!([(1, 2, 10)], res.baz); +} + +#[test] +fn test_run_timeout() { + ascent! { + #![generate_run_timeout] + /// A diverging Ascent program + struct Diverging; + /// foooooooooooo + relation foo(u128); + foo(0); + foo(x + 1) <-- foo(x); + } + + let mut prog = Diverging::default(); + prog.foo = vec![(1,), (2,)]; + let run_timeout_res = prog.run_timeout(Duration::from_millis(5)); + assert!(!run_timeout_res); +} + +#[test] +fn test_ascent_bounded_set() { + use ascent::lattice::bounded_set::BoundedSet; + ascent_m_par! { struct AscentProgram; + + lattice num_store(BoundedSet); + relation init(i32); + + init(x) <-- for x in 0..20; + num_store(BoundedSet::singleton(*x)) <-- + init(x); + } + + let mut prog = AscentProgram::<10>::default(); + prog.run(); + let store = &lat_to_vec(prog.num_store)[0].0; + for (x,) in prog.init.iter() { + assert!(store.contains(x)); + } +} + +#[test] +fn test_issue3() { + #![allow(non_snake_case)] + + ascent_m_par! { + relation a__(i32, i32); + relation c__(i32, i32, i32); + relation e__(i32); + relation h__(i32, i32, i32); + + e__(a) <-- a__(b, a); + h__(e, e, e) <-- a__(d, e), c__(e, f, e), e__(e); + } + let mut prog = AscentProgram::default(); + prog.a__ = FromIterator::from_iter(vec![(88, 5), (37, 24), (11, 91)]); + prog.c__ = FromIterator::from_iter(vec![(32, 83, 88), (2, 8, 5)]); + prog.e__ = FromIterator::from_iter(vec![(44,), (83,)]); + prog.h__ = FromIterator::from_iter(vec![(38, 88, 18), (76, 18, 65), (86, 73, 91), (98, 26, 91), (76, 10, 14)]); + + prog.run(); + println!("h__: {:?}", prog.h__); + assert_rels_eq!(prog.h__, [(38, 88, 18), (76, 18, 65), (86, 73, 91), (98, 26, 91), (76, 10, 14)]); +} + +#[test] +fn test_repeated_vars_simple_joins() { + ascent_m_par! { + relation foo1(i32, i32); + relation foo2(i32, i32); + relation bar(i32, i32); + + // filler: + foo2(100, 100), foo2(101, 101), foo2(102, 102); + + foo1(1, 1), foo2(1, 2), foo1(10, 11), foo2(11, 12); + + bar(x, y) <-- foo2(x, y), foo1(x, x); + } + let mut prog = AscentProgram::default(); + prog.run(); + + println!("bar: {:?}", prog.bar); + assert_rels_eq!(prog.bar, [(1, 2)]); +} + +#[test] +fn test_aggregated_lattice() { + let res = ascent::ascent_run! { + relation foo(i32, i32); + lattice bar(i32, i32); + + bar(x, y) <-- for x in 0..2, for y in 5..10; + + foo(x, z) <-- + for x in 0..2, + agg z = ascent::aggregators::sum(y) in bar(x, y); + }; + assert_rels_eq!(res.bar, [(0, 9), (1, 9)]); +} + +#[test] +fn test_ds_attr() { + use ascent::rel as my_rel; + let res = ascent::ascent_run! { + #![ds(my_rel)] + + #[ds(ascent::rel)] + relation foo(i32, i32) = vec![(0, 1), (1, 0)]; + + relation bar(i32, i32); + + bar(x, y) <-- foo(x, y), if x < y; + }; + + assert_rels_eq!(res.bar, [(0, 1)]); +} diff --git a/ascent_tests/src/utils.rs b/ascent_tests/src/utils.rs index 235302e..a211f30 100644 --- a/ascent_tests/src/utils.rs +++ b/ascent_tests/src/utils.rs @@ -1,26 +1,23 @@ -#![allow(dead_code)] -use std::collections::HashSet; - - -pub fn collect_set(iter : impl Iterator) -> HashSet { - iter.collect() -} - -pub fn into_set(iter : impl IntoIterator) -> HashSet { - iter.into_iter().collect() -} - -pub fn rels_equal(rel1: impl IntoIterator, rel2: impl IntoIterator) -> bool { - rel1.into_iter().collect::>() == rel2.into_iter().collect::>() -} - -#[macro_export] -macro_rules! assert_rels_eq { - ($rel1: expr, $rel2: expr) => { - let (rel1, rel2) = ($rel1.into_iter().collect::>(), $rel2.into_iter().collect::>()); - if rel1 != rel2 { - panic!("Expected rels to be equal. \nrel1: {:?} \nrel2: {:?}", - rel1, rel2); - } - }; -} \ No newline at end of file +#![allow(dead_code)] +use std::collections::HashSet; + +pub fn collect_set(iter: impl Iterator) -> HashSet { iter.collect() } + +pub fn into_set(iter: impl IntoIterator) -> HashSet { iter.into_iter().collect() } + +pub fn rels_equal( + rel1: impl IntoIterator, rel2: impl IntoIterator, +) -> bool { + rel1.into_iter().collect::>() == rel2.into_iter().collect::>() +} + +#[macro_export] +macro_rules! assert_rels_eq { + ($rel1: expr, $rel2: expr) => { + let (rel1, rel2) = ($rel1.into_iter().collect::>(), $rel2.into_iter().collect::>()); + if rel1 != rel2 { + panic!("Expected rels to be equal. \nrel1: {:?} \nrel2: {:?}", + rel1, rel2); + } + }; +} diff --git a/byods/ascent-byods-rels/src/adaptor/bin_rel.rs b/byods/ascent-byods-rels/src/adaptor/bin_rel.rs index 10b32bd..989c0b7 100644 --- a/byods/ascent-byods-rels/src/adaptor/bin_rel.rs +++ b/byods/ascent-byods-rels/src/adaptor/bin_rel.rs @@ -1,261 +1,269 @@ -use std::iter::{Map, once, Once}; -use std::marker::PhantomData; - -use ascent::internal::{RelIndexMerge, RelIndexWrite, RelFullIndexWrite, RelFullIndexRead}; -use ascent::internal::{RelIndexRead, RelIndexReadAll}; - -use crate::iterator_from_dyn::IteratorFromDyn; - - -/// a helper trait for implementing binary relation data structures -pub trait ByodsBinRel: RelIndexMerge + Default { - type T0; - type T1; - - fn contains(&self, x0: &Self::T0, x1: &Self::T1) -> bool; - - type AllIter<'a>: Iterator where Self: 'a; - fn iter_all<'a>(&'a self) -> Self::AllIter<'a>; - fn len_estimate(&self) -> usize; - - type Ind0AllIterValsIter<'a>: Iterator where Self: 'a; - type Ind0AllIter<'a>: Iterator)> where Self: 'a; - fn ind0_iter_all<'a>(&'a self) -> Self::Ind0AllIter<'a>; - fn ind0_len_estimate(&self) -> usize; - - type Ind0ValsIter<'a>: Iterator + Clone where Self: 'a; - fn ind0_index_get<'a>(&'a self, key: &Self::T0) -> Option>; - - - type Ind1AllIterValsIter<'a>: Iterator where Self: 'a; - type Ind1AllIter<'a>: Iterator)> where Self: 'a; - fn ind1_iter_all<'a>(&'a self) -> Self::Ind1AllIter<'a>; - fn ind1_len_estimate(&self) -> usize; - - type Ind1ValsIter<'a>: Iterator + Clone where Self: 'a; - fn ind1_index_get<'a>(&'a self, key: &Self::T1) -> Option>; - - fn insert(&mut self, x0: Self::T0, x1: Self::T1) -> bool; - - fn is_empty(&self) -> bool { - self.iter_all().next().is_none() - } -} - -pub struct ByodsBinRelInd0<'a, TBinRel>(&'a TBinRel); - -impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelInd0<'a, TBinRel> { - type Key = (TBinRel::T0,); - type Value = (&'a TBinRel::T1, ); - - type IteratorType = std::iter::Map, fn(&TBinRel::T1) -> (&TBinRel::T1,)>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let res = self.0.ind0_index_get(&key.0)?; - let res: Self::IteratorType = res.map(|v| (v, )); - Some(res) - } - - fn len(&'a self) -> usize { - self.0.ind0_len_estimate() - } -} - -impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelInd0<'a, TBinRel> { - type Key = (&'a TBinRel::T0, ); - type Value = (&'a TBinRel::T1, ); - - type ValueIteratorType = std::iter::Map, fn(&TBinRel::T1) -> (&TBinRel::T1,)>; - type AllIteratorType = Map, for<'aa> fn((&'aa TBinRel::T0, TBinRel::Ind0AllIterValsIter<'a>)) -> ((&'aa TBinRel::T0,), Map, fn(&TBinRel::T1) -> (&TBinRel::T1,)>)>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res: Self::AllIteratorType = self.0.ind0_iter_all().map(|(k, vals_iter)| { - let new_vals_iter: Self::ValueIteratorType = vals_iter.map(|v| (v, )); - ((k, ), new_vals_iter) - }); - res - } -} - -pub struct ByodsBinRelInd1<'a, TBinRel>(&'a TBinRel); - -impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelInd1<'a, TBinRel> { - type Key = (TBinRel::T1,); - type Value = (&'a TBinRel::T0, ); - - type IteratorType = std::iter::Map, fn(&TBinRel::T0) -> (&TBinRel::T0,)>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let res = self.0.ind1_index_get(&key.0)?; - let res: Self::IteratorType = res.map(|v| (v, )); - Some(res) - } - - fn len(&'a self) -> usize { - self.0.ind1_len_estimate() - } -} - -impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelInd1<'a, TBinRel> { - type Key = (&'a TBinRel::T1, ); - type Value = (&'a TBinRel::T0, ); - - type ValueIteratorType = std::iter::Map, fn(&TBinRel::T0) -> (&TBinRel::T0,)>; - type AllIteratorType = Map, for<'aa> fn((&'aa TBinRel::T1, TBinRel::Ind1AllIterValsIter<'a>)) -> ((&'aa TBinRel::T1,), Map, fn(&TBinRel::T0) -> (&TBinRel::T0,)>)>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res: Self::AllIteratorType = self.0.ind1_iter_all().map(|(k, vals_iter)| { - let new_vals_iter: Self::ValueIteratorType = vals_iter.map(|v| (v, )); - ((k, ), new_vals_iter) - }); - res - } -} - -pub struct ByodsBinRelInd0_1<'a, TBinRel>(&'a TBinRel); - -impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelInd0_1<'a, TBinRel> { - type Key = (TBinRel::T0, TBinRel::T1); - type Value = (); - - type IteratorType = Once<()>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - if self.0.contains(&key.0, &key.1) { - Some(once(())) - } else { - None - } - } - - fn len(&'a self) -> usize { - self.0.len_estimate() - } -} - -impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelInd0_1<'a, TBinRel> { - type Key = (&'a TBinRel::T0, &'a TBinRel::T1); - type Value = (); - - type ValueIteratorType = Once<()>; - type AllIteratorType = Map, for<'aa> fn((&'aa TBinRel::T0, &'aa TBinRel::T1)) -> ((&'aa TBinRel::T0, &'aa TBinRel::T1), Once<()>)>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res: Self::AllIteratorType = self.0.iter_all().map(|t| (t, once(()))); - res - } -} - -impl<'a, TBinRel: ByodsBinRel> RelFullIndexRead<'a> for ByodsBinRelInd0_1<'a, TBinRel> { - type Key = (TBinRel::T0, TBinRel::T1); - - fn contains_key(&'a self, key: &Self::Key) -> bool { - self.0.contains(&key.0, &key.1) - } -} - -pub struct ByodsBinRelInd0_1Write<'a, TBinRel>(&'a mut TBinRel); - -impl<'a, TBinRel> RelIndexMerge for ByodsBinRelInd0_1Write<'a, TBinRel> { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { } //noop -} - -impl<'a, TBinRel: ByodsBinRel> RelIndexWrite for ByodsBinRelInd0_1Write<'a, TBinRel> { - type Key = (TBinRel::T0, TBinRel::T1); - type Value = (); - - fn index_insert(&mut self, key: Self::Key, (): Self::Value) { - self.0.insert(key.0, key.1); - } -} - -impl<'a, TBinRel: ByodsBinRel> RelFullIndexWrite for ByodsBinRelInd0_1Write<'a, TBinRel> -where TBinRel::T0: Clone, TBinRel::T1: Clone -{ - type Key = (TBinRel::T0, TBinRel::T1); - type Value = (); - - fn insert_if_not_present(&mut self, key: &Self::Key, (): Self::Value) -> bool { - self.0.insert(key.0.clone(), key.1.clone()) - } -} - - -pub struct ByodsBinRelIndNone<'a, TBinRel>(&'a TBinRel); - -impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelIndNone<'a, TBinRel> { - type Key = (); - - type Value = (&'a TBinRel::T0, &'a TBinRel::T1); - type IteratorType = IteratorFromDyn<'a, Self::Value>; - - fn index_get(&'a self, (): &Self::Key) -> Option { - let res = || self.0.iter_all(); - Some(IteratorFromDyn::new(res)) - } - - fn len(&'a self) -> usize { - 1 - } -} - -impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelIndNone<'a, TBinRel> { - type Key = (); - type Value = (&'a TBinRel::T0, &'a TBinRel::T1); - - type ValueIteratorType = TBinRel::AllIter<'a>; - type AllIteratorType = Once<((), Self::ValueIteratorType)>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res = once(((), self.0.iter_all())); - res - } -} - - -use ascent::internal::ToRelIndex; -use crate::rel_boilerplate::NoopRelIndexWrite; -macro_rules! to_rel_ind { - ($name: ident, $key: ty, $val: ty) => {paste::paste!{ - pub struct [](PhantomData<(T0, T1)>); - - impl Default for [] { - fn default() -> Self { Self(PhantomData) } - } - - impl ToRelIndex for [] - where Rel: ByodsBinRel, - { - type RelIndex<'a> = $name<'a, Rel> where Self: 'a, Rel: 'a; - fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { $name(rel) } - - type RelIndexWrite<'a> = NoopRelIndexWrite<$key, $val> where Self: 'a, Rel: 'a; - fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { - NoopRelIndexWrite::default() - } - } - }}; -} - -to_rel_ind!(ByodsBinRelIndNone, (), (T0, T1)); -to_rel_ind!(ByodsBinRelInd0, (T0, ), (T1, )); -to_rel_ind!(ByodsBinRelInd1, (T1, ), (T0, )); - -pub struct ToByodsBinRelInd0_1(PhantomData<(T0, T1)>); - -impl Default for ToByodsBinRelInd0_1 { - fn default() -> Self { Self(PhantomData) } -} -impl ToRelIndex for ToByodsBinRelInd0_1 -where - Rel: ByodsBinRel, -{ - type RelIndex<'a> = ByodsBinRelInd0_1<'a, Rel> where Self:'a, Rel:'a; - fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { ByodsBinRelInd0_1(rel) } - - type RelIndexWrite<'a> = ByodsBinRelInd0_1Write<'a, Rel> where Self:'a, Rel:'a; - fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { - ByodsBinRelInd0_1Write(rel) - } -} - +use std::iter::{Map, Once, once}; +use std::marker::PhantomData; + +use ascent::internal::{ + RelFullIndexRead, RelFullIndexWrite, RelIndexMerge, RelIndexRead, RelIndexReadAll, RelIndexWrite, +}; + +use crate::iterator_from_dyn::IteratorFromDyn; + +/// a helper trait for implementing binary relation data structures +pub trait ByodsBinRel: RelIndexMerge + Default { + type T0; + type T1; + + fn contains(&self, x0: &Self::T0, x1: &Self::T1) -> bool; + + type AllIter<'a>: Iterator + where Self: 'a; + fn iter_all<'a>(&'a self) -> Self::AllIter<'a>; + fn len_estimate(&self) -> usize; + + type Ind0AllIterValsIter<'a>: Iterator + where Self: 'a; + type Ind0AllIter<'a>: Iterator)> + where Self: 'a; + fn ind0_iter_all<'a>(&'a self) -> Self::Ind0AllIter<'a>; + fn ind0_len_estimate(&self) -> usize; + + type Ind0ValsIter<'a>: Iterator + Clone + where Self: 'a; + fn ind0_index_get<'a>(&'a self, key: &Self::T0) -> Option>; + + type Ind1AllIterValsIter<'a>: Iterator + where Self: 'a; + type Ind1AllIter<'a>: Iterator)> + where Self: 'a; + fn ind1_iter_all<'a>(&'a self) -> Self::Ind1AllIter<'a>; + fn ind1_len_estimate(&self) -> usize; + + type Ind1ValsIter<'a>: Iterator + Clone + where Self: 'a; + fn ind1_index_get<'a>(&'a self, key: &Self::T1) -> Option>; + + fn insert(&mut self, x0: Self::T0, x1: Self::T1) -> bool; + + fn is_empty(&self) -> bool { self.iter_all().next().is_none() } +} + +pub struct ByodsBinRelInd0<'a, TBinRel>(&'a TBinRel); + +impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelInd0<'a, TBinRel> { + type Key = (TBinRel::T0,); + type Value = (&'a TBinRel::T1,); + + type IteratorType = std::iter::Map, fn(&TBinRel::T1) -> (&TBinRel::T1,)>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let res = self.0.ind0_index_get(&key.0)?; + let res: Self::IteratorType = res.map(|v| (v,)); + Some(res) + } + + fn len(&'a self) -> usize { self.0.ind0_len_estimate() } +} + +impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelInd0<'a, TBinRel> { + type Key = (&'a TBinRel::T0,); + type Value = (&'a TBinRel::T1,); + + type ValueIteratorType = std::iter::Map, fn(&TBinRel::T1) -> (&TBinRel::T1,)>; + type AllIteratorType = Map< + TBinRel::Ind0AllIter<'a>, + for<'aa> fn( + (&'aa TBinRel::T0, TBinRel::Ind0AllIterValsIter<'a>), + ) + -> ((&'aa TBinRel::T0,), Map, fn(&TBinRel::T1) -> (&TBinRel::T1,)>), + >; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res: Self::AllIteratorType = self.0.ind0_iter_all().map(|(k, vals_iter)| { + let new_vals_iter: Self::ValueIteratorType = vals_iter.map(|v| (v,)); + ((k,), new_vals_iter) + }); + res + } +} + +pub struct ByodsBinRelInd1<'a, TBinRel>(&'a TBinRel); + +impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelInd1<'a, TBinRel> { + type Key = (TBinRel::T1,); + type Value = (&'a TBinRel::T0,); + + type IteratorType = std::iter::Map, fn(&TBinRel::T0) -> (&TBinRel::T0,)>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let res = self.0.ind1_index_get(&key.0)?; + let res: Self::IteratorType = res.map(|v| (v,)); + Some(res) + } + + fn len(&'a self) -> usize { self.0.ind1_len_estimate() } +} + +impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelInd1<'a, TBinRel> { + type Key = (&'a TBinRel::T1,); + type Value = (&'a TBinRel::T0,); + + type ValueIteratorType = std::iter::Map, fn(&TBinRel::T0) -> (&TBinRel::T0,)>; + type AllIteratorType = Map< + TBinRel::Ind1AllIter<'a>, + for<'aa> fn( + (&'aa TBinRel::T1, TBinRel::Ind1AllIterValsIter<'a>), + ) + -> ((&'aa TBinRel::T1,), Map, fn(&TBinRel::T0) -> (&TBinRel::T0,)>), + >; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res: Self::AllIteratorType = self.0.ind1_iter_all().map(|(k, vals_iter)| { + let new_vals_iter: Self::ValueIteratorType = vals_iter.map(|v| (v,)); + ((k,), new_vals_iter) + }); + res + } +} + +pub struct ByodsBinRelInd0_1<'a, TBinRel>(&'a TBinRel); + +impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelInd0_1<'a, TBinRel> { + type Key = (TBinRel::T0, TBinRel::T1); + type Value = (); + + type IteratorType = Once<()>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + if self.0.contains(&key.0, &key.1) { Some(once(())) } else { None } + } + + fn len(&'a self) -> usize { self.0.len_estimate() } +} + +impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelInd0_1<'a, TBinRel> { + type Key = (&'a TBinRel::T0, &'a TBinRel::T1); + type Value = (); + + type ValueIteratorType = Once<()>; + type AllIteratorType = Map< + TBinRel::AllIter<'a>, + for<'aa> fn((&'aa TBinRel::T0, &'aa TBinRel::T1)) -> ((&'aa TBinRel::T0, &'aa TBinRel::T1), Once<()>), + >; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res: Self::AllIteratorType = self.0.iter_all().map(|t| (t, once(()))); + res + } +} + +impl<'a, TBinRel: ByodsBinRel> RelFullIndexRead<'a> for ByodsBinRelInd0_1<'a, TBinRel> { + type Key = (TBinRel::T0, TBinRel::T1); + + fn contains_key(&'a self, key: &Self::Key) -> bool { self.0.contains(&key.0, &key.1) } +} + +pub struct ByodsBinRelInd0_1Write<'a, TBinRel>(&'a mut TBinRel); + +impl<'a, TBinRel> RelIndexMerge for ByodsBinRelInd0_1Write<'a, TBinRel> { + fn move_index_contents(_from: &mut Self, _to: &mut Self) {} //noop +} + +impl<'a, TBinRel: ByodsBinRel> RelIndexWrite for ByodsBinRelInd0_1Write<'a, TBinRel> { + type Key = (TBinRel::T0, TBinRel::T1); + type Value = (); + + fn index_insert(&mut self, key: Self::Key, (): Self::Value) { self.0.insert(key.0, key.1); } +} + +impl<'a, TBinRel: ByodsBinRel> RelFullIndexWrite for ByodsBinRelInd0_1Write<'a, TBinRel> +where + TBinRel::T0: Clone, + TBinRel::T1: Clone, +{ + type Key = (TBinRel::T0, TBinRel::T1); + type Value = (); + + fn insert_if_not_present(&mut self, key: &Self::Key, (): Self::Value) -> bool { + self.0.insert(key.0.clone(), key.1.clone()) + } +} + +pub struct ByodsBinRelIndNone<'a, TBinRel>(&'a TBinRel); + +impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelIndNone<'a, TBinRel> { + type Key = (); + + type Value = (&'a TBinRel::T0, &'a TBinRel::T1); + type IteratorType = IteratorFromDyn<'a, Self::Value>; + + fn index_get(&'a self, (): &Self::Key) -> Option { + let res = || self.0.iter_all(); + Some(IteratorFromDyn::new(res)) + } + + fn len(&'a self) -> usize { 1 } +} + +impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelIndNone<'a, TBinRel> { + type Key = (); + type Value = (&'a TBinRel::T0, &'a TBinRel::T1); + + type ValueIteratorType = TBinRel::AllIter<'a>; + type AllIteratorType = Once<((), Self::ValueIteratorType)>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res = once(((), self.0.iter_all())); + res + } +} + +use ascent::internal::ToRelIndex; + +use crate::rel_boilerplate::NoopRelIndexWrite; +macro_rules! to_rel_ind { + ($name: ident, $key: ty, $val: ty) => {paste::paste!{ + pub struct [](PhantomData<(T0, T1)>); + + impl Default for [] { + fn default() -> Self { Self(PhantomData) } + } + + impl ToRelIndex for [] + where Rel: ByodsBinRel, + { + type RelIndex<'a> = $name<'a, Rel> where Self: 'a, Rel: 'a; + fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { $name(rel) } + + type RelIndexWrite<'a> = NoopRelIndexWrite<$key, $val> where Self: 'a, Rel: 'a; + fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { + NoopRelIndexWrite::default() + } + } + }}; +} + +to_rel_ind!(ByodsBinRelIndNone, (), (T0, T1)); +to_rel_ind!(ByodsBinRelInd0, (T0,), (T1,)); +to_rel_ind!(ByodsBinRelInd1, (T1,), (T0,)); + +pub struct ToByodsBinRelInd0_1(PhantomData<(T0, T1)>); + +impl Default for ToByodsBinRelInd0_1 { + fn default() -> Self { Self(PhantomData) } +} +impl ToRelIndex for ToByodsBinRelInd0_1 +where Rel: ByodsBinRel +{ + type RelIndex<'a> + = ByodsBinRelInd0_1<'a, Rel> + where + Self: 'a, + Rel: 'a; + fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { ByodsBinRelInd0_1(rel) } + + type RelIndexWrite<'a> + = ByodsBinRelInd0_1Write<'a, Rel> + where + Self: 'a, + Rel: 'a; + fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { ByodsBinRelInd0_1Write(rel) } +} diff --git a/byods/ascent-byods-rels/src/adaptor/bin_rel_plus_ternary_provider.rs b/byods/ascent-byods-rels/src/adaptor/bin_rel_plus_ternary_provider.rs index 9d85e32..f81eee3 100644 --- a/byods/ascent-byods-rels/src/adaptor/bin_rel_plus_ternary_provider.rs +++ b/byods/ascent-byods-rels/src/adaptor/bin_rel_plus_ternary_provider.rs @@ -1,115 +1,115 @@ -/// Re-export macros in this module for your binary relation data structure provider -/// that you wish to be a ternary relation as well - -#[doc(hidden)] -#[macro_export] -macro_rules! bin_rel_plus_ternary_provider_rel { - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, ()) => { - $crate::fake_vec::FakeVec<($col0, $col1)> - }; - - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, ()) => { - $crate::fake_vec::FakeVec<($col0, $col1, $col2)> - }; -} -pub use bin_rel_plus_ternary_provider_rel as rel; - -#[doc(hidden)] -#[macro_export] -macro_rules! bin_rel_plus_ternary_provider_full_ind { - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), $key: ty, $val: ty) => { - $crate::adaptor::bin_rel::ToByodsBinRelInd0_1<$col0, $col1> - }; - - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), $key: ty, $val: ty) => { - $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryInd0_1_2<$col0, $col1, $col2> - }; -} -pub use bin_rel_plus_ternary_provider_full_ind as rel_full_ind; - -#[doc(hidden)] -#[macro_export] -macro_rules! bin_rel_plus_ternary_provider_ind { - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [0], $key: ty, $val: ty) => { - $crate::adaptor::bin_rel::ToByodsBinRelInd0<$col0, $col1> - }; - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [1], $key: ty, $val: ty) => { - $crate::adaptor::bin_rel::ToByodsBinRelInd1<$col0, $col1> - }; - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [], $key: ty, $val: ty) => { - $crate::adaptor::bin_rel::ToByodsBinRelIndNone<$col0, $col1> - }; - - // ternary - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [], $key: ty, $val: ty) => { - $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryIndNone<$col0, $col1, $col2> - }; - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0], $key: ty, $val: ty) => { - $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryInd0<$col0, $col1, $col2> - }; - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [1], $key: ty, $val: ty) => { - $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryInd1<$col0, $col1, $col2> - }; - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [2], $key: ty, $val: ty) => { - $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryInd2<$col0, $col1, $col2> - }; - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 1], $key: ty, $val: ty) => { - $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryInd0_1<$col0, $col1, $col2> - }; - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 2], $key: ty, $val: ty) => { - $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryInd0_2<$col0, $col1, $col2> - }; - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [1, 2], $key: ty, $val: ty) => { - $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryInd1_2<$col0, $col1, $col2> - }; - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 1, 2], $key: ty, $val: ty) => { - $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryInd0_1_2<$col0, $col1, $col2> - }; -} -pub use bin_rel_plus_ternary_provider_ind as rel_ind; - -#[doc(hidden)] -#[macro_export] -macro_rules! bin_rel_plus_ternary_provider_rel_codegen { - ( $($tt: tt)* ) => { }; -} -pub use bin_rel_plus_ternary_provider_rel_codegen as rel_codegen; - -#[cfg(test)] -mod test { - #[doc(hidden)] - #[macro_export] - macro_rules! bin_rel_plus_ternary_provider_ind_common { - ($name: ident, ($col0: ty, $col1: ty), $indices: tt, ser, ()) => { - $crate::adaptor::bin_rel_provider::test::DummyRel<$col0, $col1> - }; - - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: tt, ser, ()) => { - $crate::adaptor::bin_rel_to_ternary::BinRelToTernaryWrapper< - // reverse_map_1 required: - {$crate::inds_contain!($indices, [1]) || $crate::inds_contain!($indices, [1, 2])}, - // reverse_map_2 required: - {$crate::inds_contain!($indices, [2]) || $crate::inds_contain!($indices, [1, 2])}, - $col0, $col1, $col2, - $crate::adaptor::bin_rel_provider::test::DummyRel<$col1, $col2> - > - }; - } - pub use bin_rel_plus_ternary_provider_ind_common as rel_ind_common; - - pub use super::{rel, rel_codegen, rel_full_ind, rel_ind}; - - ascent::ascent! { - #[ds(self)] - relation foo(u32, u64, u128); - - relation bar(u32, u64, u128); - - foo(*x as u32, *y as u64, *z as u128) <-- foo(y, x, z); - - bar(x, y, z) <-- foo(x, y, z), bar(x, _, z); - - bar(x, y, z) <-- foo(_, y, z), bar(x, y, z), foo(x, _, _); - - } -} \ No newline at end of file +/// Re-export macros in this module for your binary relation data structure provider +/// that you wish to be a ternary relation as well + +#[doc(hidden)] +#[macro_export] +macro_rules! bin_rel_plus_ternary_provider_rel { + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, ()) => { + $crate::fake_vec::FakeVec<($col0, $col1)> + }; + + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, ()) => { + $crate::fake_vec::FakeVec<($col0, $col1, $col2)> + }; +} +pub use bin_rel_plus_ternary_provider_rel as rel; + +#[doc(hidden)] +#[macro_export] +macro_rules! bin_rel_plus_ternary_provider_full_ind { + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), $key: ty, $val: ty) => { + $crate::adaptor::bin_rel::ToByodsBinRelInd0_1<$col0, $col1> + }; + + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), $key: ty, $val: ty) => { + $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryInd0_1_2<$col0, $col1, $col2> + }; +} +pub use bin_rel_plus_ternary_provider_full_ind as rel_full_ind; + +#[doc(hidden)] +#[macro_export] +macro_rules! bin_rel_plus_ternary_provider_ind { + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [0], $key: ty, $val: ty) => { + $crate::adaptor::bin_rel::ToByodsBinRelInd0<$col0, $col1> + }; + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [1], $key: ty, $val: ty) => { + $crate::adaptor::bin_rel::ToByodsBinRelInd1<$col0, $col1> + }; + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [], $key: ty, $val: ty) => { + $crate::adaptor::bin_rel::ToByodsBinRelIndNone<$col0, $col1> + }; + + // ternary + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [], $key: ty, $val: ty) => { + $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryIndNone<$col0, $col1, $col2> + }; + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0], $key: ty, $val: ty) => { + $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryInd0<$col0, $col1, $col2> + }; + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [1], $key: ty, $val: ty) => { + $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryInd1<$col0, $col1, $col2> + }; + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [2], $key: ty, $val: ty) => { + $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryInd2<$col0, $col1, $col2> + }; + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 1], $key: ty, $val: ty) => { + $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryInd0_1<$col0, $col1, $col2> + }; + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 2], $key: ty, $val: ty) => { + $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryInd0_2<$col0, $col1, $col2> + }; + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [1, 2], $key: ty, $val: ty) => { + $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryInd1_2<$col0, $col1, $col2> + }; + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 1, 2], $key: ty, $val: ty) => { + $crate::adaptor::bin_rel_to_ternary::ToBinRelToTernaryInd0_1_2<$col0, $col1, $col2> + }; +} +pub use bin_rel_plus_ternary_provider_ind as rel_ind; + +#[doc(hidden)] +#[macro_export] +macro_rules! bin_rel_plus_ternary_provider_rel_codegen { + ( $($tt: tt)* ) => { }; +} +pub use bin_rel_plus_ternary_provider_rel_codegen as rel_codegen; + +#[cfg(test)] +mod test { + #[doc(hidden)] + #[macro_export] + macro_rules! bin_rel_plus_ternary_provider_ind_common { + ($name: ident, ($col0: ty, $col1: ty), $indices: tt, ser, ()) => { + $crate::adaptor::bin_rel_provider::test::DummyRel<$col0, $col1> + }; + + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: tt, ser, ()) => { + $crate::adaptor::bin_rel_to_ternary::BinRelToTernaryWrapper< + // reverse_map_1 required: + {$crate::inds_contain!($indices, [1]) || $crate::inds_contain!($indices, [1, 2])}, + // reverse_map_2 required: + {$crate::inds_contain!($indices, [2]) || $crate::inds_contain!($indices, [1, 2])}, + $col0, $col1, $col2, + $crate::adaptor::bin_rel_provider::test::DummyRel<$col1, $col2> + > + }; + } + pub use bin_rel_plus_ternary_provider_ind_common as rel_ind_common; + + pub use super::{rel, rel_codegen, rel_full_ind, rel_ind}; + + ascent::ascent! { + #[ds(self)] + relation foo(u32, u64, u128); + + relation bar(u32, u64, u128); + + foo(*x as u32, *y as u64, *z as u128) <-- foo(y, x, z); + + bar(x, y, z) <-- foo(x, y, z), bar(x, _, z); + + bar(x, y, z) <-- foo(_, y, z), bar(x, y, z), foo(x, _, _); + + } +} diff --git a/byods/ascent-byods-rels/src/adaptor/bin_rel_provider.rs b/byods/ascent-byods-rels/src/adaptor/bin_rel_provider.rs index a67334b..328672b 100644 --- a/byods/ascent-byods-rels/src/adaptor/bin_rel_provider.rs +++ b/byods/ascent-byods-rels/src/adaptor/bin_rel_provider.rs @@ -1,110 +1,123 @@ -//! Re-export macros in this module for your binary relation data structure provider -//! implementd via [`ByodsBinRel`](crate::adaptor::bin_rel::ByodsBinRel) -#[doc(hidden)] -#[macro_export] -macro_rules! bin_rel_provider_rel { - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, ()) => { - $crate::fake_vec::FakeVec<($col0, $col1)> - }; -} -pub use bin_rel_provider_rel as rel; - -#[doc(hidden)] -#[macro_export] -macro_rules! bin_rel_provider_full_ind { - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), $key: ty, $val: ty) => { - $crate::adaptor::bin_rel::ToByodsBinRelInd0_1<$col0, $col1> - }; -} -pub use bin_rel_provider_full_ind as rel_full_ind; - -#[doc(hidden)] -#[macro_export] -macro_rules! bin_rel_provider_ind { - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [0], $key: ty, $val: ty) => { - $crate::adaptor::bin_rel::ToByodsBinRelInd0<$col0, $col1> - }; - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [1], $key: ty, $val: ty) => { - $crate::adaptor::bin_rel::ToByodsBinRelInd1<$col0, $col1> - }; - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [], $key: ty, $val: ty) => { - $crate::adaptor::bin_rel::ToByodsBinRelIndNone<$col0, $col1> - }; -} -pub use bin_rel_provider_ind as rel_ind; - -#[doc(hidden)] -#[macro_export] -macro_rules! bin_rel_provider_rel_codegen { - ( $($tt: tt)* ) => { }; -} -pub use bin_rel_provider_rel_codegen as rel_codegen; - -pub(crate) mod test { - use std::iter::Once; - use std::marker::PhantomData; - - use ascent::internal::RelIndexMerge; - - use crate::adaptor::bin_rel::ByodsBinRel; - - #[doc(hidden)] - #[macro_export] - macro_rules! bin_rel_provider_ind_common { - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, ()) => { - $crate::adaptor::bin_rel_provider::test::DummyRel<$col0, $col1> - }; - } - pub use bin_rel_provider_ind_common as rel_ind_common; - - pub use super::{rel, rel_codegen, rel_full_ind, rel_ind}; - - pub struct DummyRel(PhantomData<(T0, T1)>); - - impl Default for DummyRel { - fn default() -> Self { Self(Default::default()) } - } - - impl RelIndexMerge for DummyRel { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { todo!() } - } - - impl ByodsBinRel for DummyRel { - type T0 = T0; - type T1 = T1; - - fn contains(&self, _x0: &Self::T0, _x1: &Self::T1) -> bool { todo!() } - - type AllIter<'a> = Once<(&'a T0, &'a T1)> where Self: 'a; - fn iter_all<'a>(&'a self) -> Self::AllIter<'a> { todo!() } - - fn len_estimate(&self) -> usize { todo!() } - - type Ind0AllIterValsIter<'a> = Once<&'a Self::T1> where Self: 'a; - type Ind0AllIter<'a> = Once<(&'a Self::T0, Self::Ind0AllIterValsIter<'a>)> where Self: 'a; - fn ind0_iter_all<'a>(&'a self) -> Self::Ind0AllIter<'a> { todo!() } - fn ind0_len_estimate(&self) -> usize { todo!() } - - type Ind0ValsIter<'a> = Once<&'a Self::T1> where Self: 'a; - fn ind0_index_get<'a>(&'a self, _key: &Self::T0) -> Option> { todo!() } - - type Ind1AllIterValsIter<'a> = Once<&'a Self::T0> where Self: 'a; - type Ind1AllIter<'a> = Once<(&'a Self::T1, Self::Ind1AllIterValsIter<'a>)> where Self: 'a; - fn ind1_iter_all<'a>(&'a self) -> Self::Ind1AllIter<'a> { todo!() } - fn ind1_len_estimate(&self) -> usize { todo!() } - - type Ind1ValsIter<'a> = Once<&'a Self::T0> where Self: 'a; - fn ind1_index_get<'a>(&'a self, _key: &Self::T1) -> Option> { todo!() } - fn insert(&mut self, _x0: Self::T0, _x1: Self::T1) -> bool { todo!() } - } - - ascent::ascent! { - #[ds(super::test)] - relation foo(u32, usize); - - foo(*x as u32, *y as usize) <-- foo(y, x); - foo(x, y) <-- foo(x, y), foo(& (*y as u32), &(*x as usize)); - - } - -} \ No newline at end of file +//! Re-export macros in this module for your binary relation data structure provider +//! implementd via [`ByodsBinRel`](crate::adaptor::bin_rel::ByodsBinRel) +#[doc(hidden)] +#[macro_export] +macro_rules! bin_rel_provider_rel { + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, ()) => { + $crate::fake_vec::FakeVec<($col0, $col1)> + }; +} +pub use bin_rel_provider_rel as rel; + +#[doc(hidden)] +#[macro_export] +macro_rules! bin_rel_provider_full_ind { + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), $key: ty, $val: ty) => { + $crate::adaptor::bin_rel::ToByodsBinRelInd0_1<$col0, $col1> + }; +} +pub use bin_rel_provider_full_ind as rel_full_ind; + +#[doc(hidden)] +#[macro_export] +macro_rules! bin_rel_provider_ind { + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [0], $key: ty, $val: ty) => { + $crate::adaptor::bin_rel::ToByodsBinRelInd0<$col0, $col1> + }; + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [1], $key: ty, $val: ty) => { + $crate::adaptor::bin_rel::ToByodsBinRelInd1<$col0, $col1> + }; + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [], $key: ty, $val: ty) => { + $crate::adaptor::bin_rel::ToByodsBinRelIndNone<$col0, $col1> + }; +} +pub use bin_rel_provider_ind as rel_ind; + +#[doc(hidden)] +#[macro_export] +macro_rules! bin_rel_provider_rel_codegen { + ( $($tt: tt)* ) => { }; +} +pub use bin_rel_provider_rel_codegen as rel_codegen; + +pub(crate) mod test { + use std::iter::Once; + use std::marker::PhantomData; + + use ascent::internal::RelIndexMerge; + + use crate::adaptor::bin_rel::ByodsBinRel; + + #[doc(hidden)] + #[macro_export] + macro_rules! bin_rel_provider_ind_common { + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, ()) => { + $crate::adaptor::bin_rel_provider::test::DummyRel<$col0, $col1> + }; + } + pub use bin_rel_provider_ind_common as rel_ind_common; + + pub use super::{rel, rel_codegen, rel_full_ind, rel_ind}; + + pub struct DummyRel(PhantomData<(T0, T1)>); + + impl Default for DummyRel { + fn default() -> Self { Self(Default::default()) } + } + + impl RelIndexMerge for DummyRel { + fn move_index_contents(_from: &mut Self, _to: &mut Self) { todo!() } + } + + impl ByodsBinRel for DummyRel { + type T0 = T0; + type T1 = T1; + + fn contains(&self, _x0: &Self::T0, _x1: &Self::T1) -> bool { todo!() } + + type AllIter<'a> + = Once<(&'a T0, &'a T1)> + where Self: 'a; + fn iter_all<'a>(&'a self) -> Self::AllIter<'a> { todo!() } + + fn len_estimate(&self) -> usize { todo!() } + + type Ind0AllIterValsIter<'a> + = Once<&'a Self::T1> + where Self: 'a; + type Ind0AllIter<'a> + = Once<(&'a Self::T0, Self::Ind0AllIterValsIter<'a>)> + where Self: 'a; + fn ind0_iter_all<'a>(&'a self) -> Self::Ind0AllIter<'a> { todo!() } + fn ind0_len_estimate(&self) -> usize { todo!() } + + type Ind0ValsIter<'a> + = Once<&'a Self::T1> + where Self: 'a; + fn ind0_index_get<'a>(&'a self, _key: &Self::T0) -> Option> { todo!() } + + type Ind1AllIterValsIter<'a> + = Once<&'a Self::T0> + where Self: 'a; + type Ind1AllIter<'a> + = Once<(&'a Self::T1, Self::Ind1AllIterValsIter<'a>)> + where Self: 'a; + fn ind1_iter_all<'a>(&'a self) -> Self::Ind1AllIter<'a> { todo!() } + fn ind1_len_estimate(&self) -> usize { todo!() } + + type Ind1ValsIter<'a> + = Once<&'a Self::T0> + where Self: 'a; + fn ind1_index_get<'a>(&'a self, _key: &Self::T1) -> Option> { todo!() } + fn insert(&mut self, _x0: Self::T0, _x1: Self::T1) -> bool { todo!() } + } + + ascent::ascent! { + #[ds(super::test)] + relation foo(u32, usize); + + foo(*x as u32, *y as usize) <-- foo(y, x); + foo(x, y) <-- foo(x, y), foo(& (*y as u32), &(*x as usize)); + + } +} diff --git a/byods/ascent-byods-rels/src/adaptor/bin_rel_to_ternary.rs b/byods/ascent-byods-rels/src/adaptor/bin_rel_to_ternary.rs index d27a2e3..5ffe11e 100644 --- a/byods/ascent-byods-rels/src/adaptor/bin_rel_to_ternary.rs +++ b/byods/ascent-byods-rels/src/adaptor/bin_rel_to_ternary.rs @@ -1,623 +1,754 @@ - -use std::hash::BuildHasherDefault; -use std::hash::Hash; -use std::iter::Map; -use std::iter::once; -use std::marker::PhantomData; -use ascent::internal::RelFullIndexRead; -use ascent::internal::RelFullIndexWrite; -use ascent::internal::RelIndexMerge; -use ascent::internal::RelIndexWrite; -use ascent::internal::RelIndexRead; -use ascent::internal::RelIndexReadAll; -use ascent::internal::ToRelIndex; - -use hashbrown::HashMap; -use rustc_hash::FxHasher; - -use crate::iterator_from_dyn::IteratorFromDyn; -use crate::utils::AltHashSet; -use crate::utils::hash_one; - -use super::bin_rel::ByodsBinRel; - -pub struct BinRelToTernary -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - pub map: HashMap>, - pub reverse_map1: Option>, BuildHasherDefault>>, - pub reverse_map2: Option>, BuildHasherDefault>> -} - -impl RelIndexMerge for BinRelToTernary -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - fn move_index_contents(_from: &mut Self, _to: &mut Self) { - panic!("merge_delta_to_total_new_to_delta must be called instead"); - } - - fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { - let mut new_delta_map = HashMap::default(); - for (k, mut delta_trrel) in delta.map.drain() { - let mut new_trrel = new.map.remove(&k).unwrap_or_default(); - match total.map.entry(k.clone()) { - hashbrown::hash_map::Entry::Occupied(mut total_entry) => { - TBinRel::merge_delta_to_total_new_to_delta(&mut new_trrel, &mut delta_trrel, total_entry.get_mut()); - if !delta_trrel.is_empty() { - new_delta_map.insert(k, delta_trrel); - } - }, - hashbrown::hash_map::Entry::Vacant(total_vacant_entry) => { - let mut new_total = TBinRel::default(); - TBinRel::merge_delta_to_total_new_to_delta(&mut new_trrel, &mut delta_trrel, &mut new_total); - total_vacant_entry.insert(new_total); - if !delta_trrel.is_empty() { - new_delta_map.insert(k, delta_trrel); - } - }, - } - } - for (k, mut new_trrel) in new.map.drain() { - let mut new_delta = Default::default(); - match total.map.entry(k.clone()) { - hashbrown::hash_map::Entry::Occupied(mut total_entry) => { - TBinRel::merge_delta_to_total_new_to_delta(&mut new_trrel, &mut new_delta, total_entry.get_mut()); - new_delta_map.insert(k, new_delta); - }, - hashbrown::hash_map::Entry::Vacant(_) => { - TBinRel::merge_delta_to_total_new_to_delta(&mut new_trrel, &mut new_delta, &mut Default::default()); - new_delta_map.insert(k, new_delta); - }, - } - } - delta.map = new_delta_map; - - if delta.reverse_map1.is_some() { - crate::utils::move_hash_map_of_alt_hash_set_contents(delta.reverse_map1.as_mut().unwrap(), total.reverse_map1.as_mut().unwrap()); - std::mem::swap(delta.reverse_map1.as_mut().unwrap(), new.reverse_map1.as_mut().unwrap()); - } - - if delta.reverse_map2.is_some() { - crate::utils::move_hash_map_of_alt_hash_set_contents(delta.reverse_map2.as_mut().unwrap(), total.reverse_map2.as_mut().unwrap()); - std::mem::swap(delta.reverse_map2.as_mut().unwrap(), new.reverse_map2.as_mut().unwrap()); - } - } -} - - -pub struct BinRelToTernaryInd0<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary) -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -; - -impl<'a, T0, T1, T2, TBinRel> RelIndexReadAll<'a> for BinRelToTernaryInd0<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (&'a T0, ); - type Value = (&'a T1, &'a T2); - - type ValueIteratorType = Box + 'a>; - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.0.map.iter().map(|(k, v)| { - ((k, ), Box::new(v.iter_all()) as _) - })) - } -} - -impl<'a, T0, T1, T2, TBinRel> RelIndexRead<'a> for BinRelToTernaryInd0<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (T0, ); - type Value = (&'a T1, &'a T2); - - type IteratorType = IteratorFromDyn<'a, Self::Value>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let trrel = self.0.map.get(&key.0)?; - Some(IteratorFromDyn::new(|| trrel.iter_all())) - } - - fn len(&self) -> usize { - let sample_size = 4; - let sum = self.0.map.values().map(|x| x.len_estimate()).sum::(); - sum * self.0.map.len() / sample_size.min(self.0.map.len()).max(1) - } -} - - -pub struct BinRelToTernaryInd0_1<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary) -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -; - -impl<'a, T0, T1, T2, TBinRel> RelIndexReadAll<'a> for BinRelToTernaryInd0_1<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (&'a T0, &'a T1); - - type Value = (&'a T2, ); - - type ValueIteratorType = Map, fn(&T2) -> (&T2,)>; - - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.0.map.iter().flat_map(|(x0, v)| v.ind0_iter_all().map(move |(x1, x2s)| { - let iter: Self::ValueIteratorType = x2s.map(|x2| (x2, )); - ((x0, x1), iter) - }))) - } -} - -impl<'a, T0, T1, T2, TBinRel> RelIndexRead<'a> for BinRelToTernaryInd0_1<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (T0, T1); - - type Value = (&'a T2, ); - - type IteratorType = Map, fn(&T2) -> (&T2,)>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let trrel = self.0.map.get(&key.0)?; - let res: Self::IteratorType = trrel.ind0_index_get(&key.1)?.map(|x| (x, )); - Some(res) - } - - fn len(&self) -> usize { - let sample_size = 3; - let sum = self.0.map.values().take(sample_size).map(|trrel| trrel.ind0_len_estimate()).sum::(); - let map_len = self.0.map.len(); - sum * map_len / sample_size.min(map_len).max(1) - - } -} - - -pub struct BinRelToTernaryInd0_2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary) -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -; - -impl<'a, T0, T1, T2, TBinRel> RelIndexReadAll<'a> for BinRelToTernaryInd0_2<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (&'a T0, &'a T2); - - type Value = (&'a T1, ); - - type ValueIteratorType = Map, fn(&T1) -> (&T1,)>; - - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.0.map.iter().flat_map(|(x0, v)| v.ind1_iter_all().map(move |(x2, x1s)| { - let iter: Self::ValueIteratorType = x1s.map(|x2| (x2, )); - ((x0, x2), iter) - }))) - } -} - -impl<'a, T0, T1, T2, TBinRel> RelIndexRead<'a> for BinRelToTernaryInd0_2<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (T0, T2); - - type Value = (&'a T1, ); - - type IteratorType = Map, fn(&T1) -> (&T1,)>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let trrel = self.0.map.get(&key.0)?; - let res: Self::IteratorType = trrel.ind1_index_get(&key.1)?.map(|x| (x, )); - Some(res) - } - - fn len(&self) -> usize { - let sample_size = 3; - let sum = self.0.map.values().take(sample_size).map(|trrel| trrel.ind1_len_estimate()).sum::(); - let map_len = self.0.map.len(); - sum * map_len / sample_size.min(map_len).max(1) - } -} - - -pub struct BinRelToTernaryInd1<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary) -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -; - -impl<'a, T0, T1, T2, TBinRel> RelIndexReadAll<'a> for BinRelToTernaryInd1<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (&'a T1, ); - type Value = (&'a T0, &'a T2); - - type ValueIteratorType = >::IteratorType; - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.0.reverse_map1.as_ref().unwrap().keys().map(|x1| { - ((x1, ), self.get(x1).unwrap()) - })) - } -} - -impl<'a, T0, T1, T2, TBinRel> BinRelToTernaryInd1<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - fn get(&'a self, x1: &T1) -> Option<>::IteratorType> { - let (x1, x0s) = self.0.reverse_map1.as_ref().unwrap().get_key_value(x1)?; - let res = move || x0s.iter().filter_map(move |x0| { - Some(self.0.map.get(x0).unwrap().ind0_index_get(x1)?.map(move |x2| (x0, x2))) - }).flatten(); - Some(IteratorFromDyn::new(res)) - } -} - -impl<'a, T0, T1, T2, TBinRel> RelIndexRead<'a> for BinRelToTernaryInd1<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (T1, ); - - type Value = (&'a T0, &'a T2); - - type IteratorType = IteratorFromDyn<'a, Self::Value>; - - fn index_get(&'a self, (x1, ): &Self::Key) -> Option { - self.get(x1) - } - - fn len(&self) -> usize { - self.0.reverse_map1.as_ref().unwrap().len() - } -} - - -pub struct BinRelToTernaryInd2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary) -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -; - -impl<'a, T0, T1, T2, TBinRel> RelIndexReadAll<'a> for BinRelToTernaryInd2<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (&'a T2, ); - type Value = (&'a T0, &'a T1); - - type ValueIteratorType = >::IteratorType; - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.0.reverse_map2.as_ref().unwrap().keys().map(|x2| { - ((x2, ), self.get(x2).unwrap()) - })) - } -} - -impl<'a, T0, T1, T2, TBinRel> BinRelToTernaryInd2<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - fn get(&'a self, x2: &T2) -> Option<>::IteratorType> { - let (x2, x0s) = self.0.reverse_map2.as_ref().unwrap().get_key_value(x2)?; - let res = move || x0s.iter().filter_map(move |x0| { - Some(self.0.map.get(x0).unwrap().ind1_index_get(x2)?.map(move |x1| (x0, x1))) - }).flatten(); - Some(IteratorFromDyn::new(res)) - } -} - -impl<'a, T0, T1, T2, TBinRel> RelIndexRead<'a> for BinRelToTernaryInd2<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (T2, ); - type Value = (&'a T0, &'a T1); - - type IteratorType = IteratorFromDyn<'a, Self::Value>; - - fn index_get(&'a self, (x2, ): &Self::Key) -> Option { - self.get(x2) - } - - fn len(&self) -> usize { - self.0.reverse_map2.as_ref().unwrap().len() - } -} - - -pub struct BinRelToTernaryInd1_2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary) -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -; - - -impl<'a, T0, T1, T2, TBinRel> RelIndexReadAll<'a> for BinRelToTernaryInd1_2<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (&'a T1, &'a T2); - type Value = (&'a T0, ); - - type ValueIteratorType = Box + 'a>; - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.0.reverse_map1.as_ref().unwrap().iter().flat_map(move |(x1, x0s_for_x1)| { - self.0.reverse_map2.as_ref().unwrap().iter().map(move |(x2, x0s_for_x2)| { - let x0s: Self::ValueIteratorType = Box::new(x0s_for_x1.intersection(x0s_for_x2) - .filter(|&x0| self.0.map.get(x0).unwrap().contains(x1, x2)) - .map(|x0| (x0, ))); - ((x1, x2), x0s) - }) - })) - } -} - -impl<'a, T0, T1, T2, TBinRel> RelIndexRead<'a> for BinRelToTernaryInd1_2<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (T1, T2); - - type Value = (&'a T0, ); - - type IteratorType = IteratorFromDyn<'a, Self::Value>; - fn index_get(&'a self, (x1, x2): &Self::Key) -> Option { - let (x1, x1_map) = self.0.reverse_map1.as_ref().unwrap().get_key_value(x1)?; - let (x2, x2_map) = self.0.reverse_map2.as_ref().unwrap().get_key_value(x2)?; - - let res = || x1_map.intersection(x2_map) - .filter(|&x0| self.0.map.get(x0).unwrap().contains(x1, x2)) - .map(|x0| (x0, )); - Some(IteratorFromDyn::new(res)) - } - - fn len(&self) -> usize { - // TODO random estimate, could be very wrong - self.0.reverse_map1.as_ref().unwrap().len() * self.0.reverse_map2.as_ref().unwrap().len() - / ((self.0.map.len() as f32).sqrt() as usize) - } -} - - -pub struct BinRelToTernaryIndNone<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary) -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -; - -impl<'a, T0, T1, T2, TBinRel> RelIndexReadAll<'a> for BinRelToTernaryIndNone<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (); - type Value = (&'a T0, &'a T1, &'a T2); - - type ValueIteratorType = >::IteratorType; - type AllIteratorType = std::option::IntoIter<(Self::Key, Self::ValueIteratorType)>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - self.index_get(&()).map(|x| ((), x)).into_iter() - } -} - -impl<'a, T0, T1, T2, TBinRel> RelIndexRead<'a> for BinRelToTernaryIndNone<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (); - type Value = (&'a T0, &'a T1, &'a T2); - - type IteratorType = IteratorFromDyn<'a, Self::Value>; - - fn index_get(&'a self, (): &Self::Key) -> Option { - let res = || self.0.map.iter().flat_map(|(x0, rel)| { - rel.iter_all().map(move |(x1, x2)| (x0, x1, x2)) - }); - Some(IteratorFromDyn::new(res)) - } - - fn len(&self) -> usize { - 1 - } -} - -pub struct BinRelToTernaryInd0_1_2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary) -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -; - -impl<'a, T0, T1, T2, TBinRel> RelFullIndexRead<'a> for BinRelToTernaryInd0_1_2<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (T0, T1, T2); - - #[inline] - fn contains_key(&'a self, (x0, x1, x2): &Self::Key) -> bool { - match self.0.map.get(x0) { - None => false, - Some(rel) => rel.contains(x1, x2), - } - } -} - -impl<'a, T0, T1, T2, TBinRel> RelIndexReadAll<'a> for BinRelToTernaryInd0_1_2<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (&'a T0, &'a T1, &'a T2); - type Value = (); - - type ValueIteratorType = std::iter::Once<()>; - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let iter = self.0.map.iter().flat_map(|(x0, rel)| { - rel.iter_all().map(move |(x1, x2)| (x0, x1, x2)) - }); - - Box::new(iter.map(|t| (t, once(())))) - } -} - -impl<'a, T0, T1, T2, TBinRel> RelIndexRead<'a> for BinRelToTernaryInd0_1_2<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (T0, T1, T2); - type Value = (); - - type IteratorType = std::iter::Once; - - fn index_get(&'a self, (x0, x1, x2): &Self::Key) -> Option { - if self.0.map.get(x0)?.contains(x1, x2) { - Some(once(())) - } else { - None - } - } - - fn len(&self) -> usize { - let sample_size = 3; - let sum = self.0.map.values().take(sample_size).map(|rel| rel.len_estimate()).sum::(); - let map_len = self.0.map.len(); - sum * map_len / sample_size.min(map_len).max(1) - } -} - - -pub struct BinRelToTernaryInd0_1_2Write<'a, T0, T1, T2, TBinRel>(&'a mut BinRelToTernary) -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -; - -impl<'a, T0, T1, T2, TBinRel> RelIndexMerge for BinRelToTernaryInd0_1_2Write<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - fn move_index_contents(_from: &mut Self, _to: &mut Self) { } // noop -} - -impl<'a, T0, T1, T2, TBinRel> RelFullIndexWrite for BinRelToTernaryInd0_1_2Write<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (T0, T1, T2); - type Value = (); - - fn insert_if_not_present(&mut self, (x0, x1, x2): &Self::Key, (): Self::Value) -> bool { - let x0_hash = hash_one(self.0.map.hasher(), x0); - - if !self.0.map.raw_entry_mut().from_key_hashed_nocheck(x0_hash, x0) - .or_insert_with(|| (x0.clone(), TBinRel::default())) - .1.insert(x1.clone(), x2.clone()) - { - return false; - } - if let Some(reverse_map1) = self.0.reverse_map1.as_mut() { - reverse_map1.entry(x1.clone()).or_default().insert_with_hash_no_check(x0_hash, x0.clone()); - } - if let Some(reverse_map2) = self.0.reverse_map2.as_mut() { - reverse_map2.entry(x2.clone()).or_default().insert_with_hash_no_check(x0_hash, x0.clone()); - } - true - } -} - -impl<'a, T0, T1, T2, TBinRel> RelIndexWrite for BinRelToTernaryInd0_1_2Write<'a, T0, T1, T2, TBinRel> -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - type Key = (T0, T1, T2); - type Value = (); - - fn index_insert(&mut self, (x0, x1, x2): Self::Key, (): Self::Value) { - if let Some(reverse_map1) = self.0.reverse_map1.as_mut() { - reverse_map1.entry(x1.clone()).or_default().insert(x0.clone()); - } - if let Some(reverse_map2) = self.0.reverse_map2.as_mut() { - reverse_map2.entry(x2.clone()).or_default().insert(x0.clone()); - } - self.0.map.entry(x0).or_insert_with(|| TBinRel::default()).insert(x1, x2); - } -} - - - -pub struct BinRelToTernaryWrapper -(pub BinRelToTernary) -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -; - -impl RelIndexMerge for BinRelToTernaryWrapper -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - fn move_index_contents(from: &mut Self, to: &mut Self) { - RelIndexMerge::move_index_contents(&mut from.0, &mut to.0) - } - fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { - RelIndexMerge::merge_delta_to_total_new_to_delta(&mut new.0, &mut delta.0, &mut total.0) - } -} - -impl Default for BinRelToTernaryWrapper -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel -{ - fn default() -> Self { - let reverse_map1 = if HAS_REVERSE_MAP1 {Some(Default::default())} else {None}; - let reverse_map2 = if HAS_REVERSE_MAP2 {Some(Default::default())} else {None}; - Self(BinRelToTernary { map: Default::default(), reverse_map1, reverse_map2 }) - } -} - - -pub struct ToBinRelToTernaryInd0_1_2(PhantomData<(T0, T1, T2)>) -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq//, TBinRel: ByodsBinRel -; - -impl<'a, T0, T1, T2> Default for ToBinRelToTernaryInd0_1_2 -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq//, TBinRel: ByodsBinRel -{ - fn default() -> Self { Self(PhantomData) } -} - -impl -ToRelIndex> for ToBinRelToTernaryInd0_1_2 -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel, - // TWrapper: DerefMut> -{ - type RelIndex<'a> = BinRelToTernaryInd0_1_2<'a, T0, T1, T2, TBinRel> where Self:'a, BinRelToTernaryWrapper: 'a; - #[inline(always)] - fn to_rel_index<'a>(&'a self, rel: &'a BinRelToTernaryWrapper) -> Self::RelIndex<'a> { BinRelToTernaryInd0_1_2(&rel.0) } - - type RelIndexWrite<'a> = BinRelToTernaryInd0_1_2Write<'a, T0, T1, T2, TBinRel> where Self: 'a, BinRelToTernaryWrapper: 'a; - #[inline(always)] - fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut BinRelToTernaryWrapper) -> Self::RelIndexWrite<'a> { - BinRelToTernaryInd0_1_2Write(&mut rel.0) - } -} - - -use crate::rel_boilerplate::NoopRelIndexWrite; -macro_rules! to_trrel2 { - ($name: ident, $key: ty, $val: ty) => {paste::paste!{ - pub struct [](PhantomData<(T0, T1, T2)>) - where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq//, TBinRel: ByodsBinRel - ; - - - impl<'a, T0, T1, T2> Default for [] - where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq//, TBinRel: ByodsBinRel - { - fn default() -> Self { Self(PhantomData) } - } - - impl - ToRelIndex> for [] - where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel, - //TWrapper: Deref> - { - - type RelIndex<'a> = $name<'a, T0, T1, T2, TBinRel> where Self:'a, BinRelToTernaryWrapper: 'a; - #[inline(always)] - fn to_rel_index<'a>(&'a self, rel: &'a BinRelToTernaryWrapper) -> Self::RelIndex<'a> { $name(& rel.0) } - - type RelIndexWrite<'a> = NoopRelIndexWrite<$key, $val> where Self: 'a, BinRelToTernaryWrapper: 'a; - #[inline(always)] - fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut BinRelToTernaryWrapper) -> Self::RelIndexWrite<'a> { - NoopRelIndexWrite::default() - } - } - }}; -} - -to_trrel2!(BinRelToTernaryIndNone, (), (T0, T1, T2)); -to_trrel2!(BinRelToTernaryInd0, (T0, ), (T1, T2)); -to_trrel2!(BinRelToTernaryInd1, (T1, ), (T0, T2)); -to_trrel2!(BinRelToTernaryInd2, (T2, ), (T0, T1)); -to_trrel2!(BinRelToTernaryInd0_1, (T0, T1), (T2, )); -to_trrel2!(BinRelToTernaryInd0_2, (T0, T2), (T1, )); -to_trrel2!(BinRelToTernaryInd1_2, (T1, T2), (T0, )); +use std::hash::{BuildHasherDefault, Hash}; +use std::iter::{Map, once}; +use std::marker::PhantomData; + +use ascent::internal::{ + RelFullIndexRead, RelFullIndexWrite, RelIndexMerge, RelIndexRead, RelIndexReadAll, RelIndexWrite, ToRelIndex, +}; +use hashbrown::HashMap; +use rustc_hash::FxHasher; + +use super::bin_rel::ByodsBinRel; +use crate::iterator_from_dyn::IteratorFromDyn; +use crate::utils::{AltHashSet, hash_one}; + +pub struct BinRelToTernary +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + pub map: HashMap>, + pub reverse_map1: Option>, BuildHasherDefault>>, + pub reverse_map2: Option>, BuildHasherDefault>>, +} + +impl RelIndexMerge for BinRelToTernary +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + fn move_index_contents(_from: &mut Self, _to: &mut Self) { + panic!("merge_delta_to_total_new_to_delta must be called instead"); + } + + fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { + let mut new_delta_map = HashMap::default(); + for (k, mut delta_trrel) in delta.map.drain() { + let mut new_trrel = new.map.remove(&k).unwrap_or_default(); + match total.map.entry(k.clone()) { + hashbrown::hash_map::Entry::Occupied(mut total_entry) => { + TBinRel::merge_delta_to_total_new_to_delta(&mut new_trrel, &mut delta_trrel, total_entry.get_mut()); + if !delta_trrel.is_empty() { + new_delta_map.insert(k, delta_trrel); + } + }, + hashbrown::hash_map::Entry::Vacant(total_vacant_entry) => { + let mut new_total = TBinRel::default(); + TBinRel::merge_delta_to_total_new_to_delta(&mut new_trrel, &mut delta_trrel, &mut new_total); + total_vacant_entry.insert(new_total); + if !delta_trrel.is_empty() { + new_delta_map.insert(k, delta_trrel); + } + }, + } + } + for (k, mut new_trrel) in new.map.drain() { + let mut new_delta = Default::default(); + match total.map.entry(k.clone()) { + hashbrown::hash_map::Entry::Occupied(mut total_entry) => { + TBinRel::merge_delta_to_total_new_to_delta(&mut new_trrel, &mut new_delta, total_entry.get_mut()); + new_delta_map.insert(k, new_delta); + }, + hashbrown::hash_map::Entry::Vacant(_) => { + TBinRel::merge_delta_to_total_new_to_delta(&mut new_trrel, &mut new_delta, &mut Default::default()); + new_delta_map.insert(k, new_delta); + }, + } + } + delta.map = new_delta_map; + + if delta.reverse_map1.is_some() { + crate::utils::move_hash_map_of_alt_hash_set_contents( + delta.reverse_map1.as_mut().unwrap(), + total.reverse_map1.as_mut().unwrap(), + ); + std::mem::swap(delta.reverse_map1.as_mut().unwrap(), new.reverse_map1.as_mut().unwrap()); + } + + if delta.reverse_map2.is_some() { + crate::utils::move_hash_map_of_alt_hash_set_contents( + delta.reverse_map2.as_mut().unwrap(), + total.reverse_map2.as_mut().unwrap(), + ); + std::mem::swap(delta.reverse_map2.as_mut().unwrap(), new.reverse_map2.as_mut().unwrap()); + } + } +} + +pub struct BinRelToTernaryInd0<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary) +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel; + +impl<'a, T0, T1, T2, TBinRel> RelIndexReadAll<'a> for BinRelToTernaryInd0<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (&'a T0,); + type Value = (&'a T1, &'a T2); + + type ValueIteratorType = Box + 'a>; + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + Box::new(self.0.map.iter().map(|(k, v)| ((k,), Box::new(v.iter_all()) as _))) + } +} + +impl<'a, T0, T1, T2, TBinRel> RelIndexRead<'a> for BinRelToTernaryInd0<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (T0,); + type Value = (&'a T1, &'a T2); + + type IteratorType = IteratorFromDyn<'a, Self::Value>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let trrel = self.0.map.get(&key.0)?; + Some(IteratorFromDyn::new(|| trrel.iter_all())) + } + + fn len(&self) -> usize { + let sample_size = 4; + let sum = self.0.map.values().map(|x| x.len_estimate()).sum::(); + sum * self.0.map.len() / sample_size.min(self.0.map.len()).max(1) + } +} + +pub struct BinRelToTernaryInd0_1<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary) +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel; + +impl<'a, T0, T1, T2, TBinRel> RelIndexReadAll<'a> for BinRelToTernaryInd0_1<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (&'a T0, &'a T1); + + type Value = (&'a T2,); + + type ValueIteratorType = Map, fn(&T2) -> (&T2,)>; + + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + Box::new(self.0.map.iter().flat_map(|(x0, v)| { + v.ind0_iter_all().map(move |(x1, x2s)| { + let iter: Self::ValueIteratorType = x2s.map(|x2| (x2,)); + ((x0, x1), iter) + }) + })) + } +} + +impl<'a, T0, T1, T2, TBinRel> RelIndexRead<'a> for BinRelToTernaryInd0_1<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (T0, T1); + + type Value = (&'a T2,); + + type IteratorType = Map, fn(&T2) -> (&T2,)>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let trrel = self.0.map.get(&key.0)?; + let res: Self::IteratorType = trrel.ind0_index_get(&key.1)?.map(|x| (x,)); + Some(res) + } + + fn len(&self) -> usize { + let sample_size = 3; + let sum = self.0.map.values().take(sample_size).map(|trrel| trrel.ind0_len_estimate()).sum::(); + let map_len = self.0.map.len(); + sum * map_len / sample_size.min(map_len).max(1) + } +} + +pub struct BinRelToTernaryInd0_2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary) +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel; + +impl<'a, T0, T1, T2, TBinRel> RelIndexReadAll<'a> for BinRelToTernaryInd0_2<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (&'a T0, &'a T2); + + type Value = (&'a T1,); + + type ValueIteratorType = Map, fn(&T1) -> (&T1,)>; + + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + Box::new(self.0.map.iter().flat_map(|(x0, v)| { + v.ind1_iter_all().map(move |(x2, x1s)| { + let iter: Self::ValueIteratorType = x1s.map(|x2| (x2,)); + ((x0, x2), iter) + }) + })) + } +} + +impl<'a, T0, T1, T2, TBinRel> RelIndexRead<'a> for BinRelToTernaryInd0_2<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (T0, T2); + + type Value = (&'a T1,); + + type IteratorType = Map, fn(&T1) -> (&T1,)>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let trrel = self.0.map.get(&key.0)?; + let res: Self::IteratorType = trrel.ind1_index_get(&key.1)?.map(|x| (x,)); + Some(res) + } + + fn len(&self) -> usize { + let sample_size = 3; + let sum = self.0.map.values().take(sample_size).map(|trrel| trrel.ind1_len_estimate()).sum::(); + let map_len = self.0.map.len(); + sum * map_len / sample_size.min(map_len).max(1) + } +} + +pub struct BinRelToTernaryInd1<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary) +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel; + +impl<'a, T0, T1, T2, TBinRel> RelIndexReadAll<'a> for BinRelToTernaryInd1<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (&'a T1,); + type Value = (&'a T0, &'a T2); + + type ValueIteratorType = >::IteratorType; + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + Box::new(self.0.reverse_map1.as_ref().unwrap().keys().map(|x1| ((x1,), self.get(x1).unwrap()))) + } +} + +impl<'a, T0, T1, T2, TBinRel> BinRelToTernaryInd1<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + fn get(&'a self, x1: &T1) -> Option<>::IteratorType> { + let (x1, x0s) = self.0.reverse_map1.as_ref().unwrap().get_key_value(x1)?; + let res = move || { + x0s.iter() + .filter_map(move |x0| Some(self.0.map.get(x0).unwrap().ind0_index_get(x1)?.map(move |x2| (x0, x2)))) + .flatten() + }; + Some(IteratorFromDyn::new(res)) + } +} + +impl<'a, T0, T1, T2, TBinRel> RelIndexRead<'a> for BinRelToTernaryInd1<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (T1,); + + type Value = (&'a T0, &'a T2); + + type IteratorType = IteratorFromDyn<'a, Self::Value>; + + fn index_get(&'a self, (x1,): &Self::Key) -> Option { self.get(x1) } + + fn len(&self) -> usize { self.0.reverse_map1.as_ref().unwrap().len() } +} + +pub struct BinRelToTernaryInd2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary) +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel; + +impl<'a, T0, T1, T2, TBinRel> RelIndexReadAll<'a> for BinRelToTernaryInd2<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (&'a T2,); + type Value = (&'a T0, &'a T1); + + type ValueIteratorType = >::IteratorType; + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + Box::new(self.0.reverse_map2.as_ref().unwrap().keys().map(|x2| ((x2,), self.get(x2).unwrap()))) + } +} + +impl<'a, T0, T1, T2, TBinRel> BinRelToTernaryInd2<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + fn get(&'a self, x2: &T2) -> Option<>::IteratorType> { + let (x2, x0s) = self.0.reverse_map2.as_ref().unwrap().get_key_value(x2)?; + let res = move || { + x0s.iter() + .filter_map(move |x0| Some(self.0.map.get(x0).unwrap().ind1_index_get(x2)?.map(move |x1| (x0, x1)))) + .flatten() + }; + Some(IteratorFromDyn::new(res)) + } +} + +impl<'a, T0, T1, T2, TBinRel> RelIndexRead<'a> for BinRelToTernaryInd2<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (T2,); + type Value = (&'a T0, &'a T1); + + type IteratorType = IteratorFromDyn<'a, Self::Value>; + + fn index_get(&'a self, (x2,): &Self::Key) -> Option { self.get(x2) } + + fn len(&self) -> usize { self.0.reverse_map2.as_ref().unwrap().len() } +} + +pub struct BinRelToTernaryInd1_2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary) +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel; + +impl<'a, T0, T1, T2, TBinRel> RelIndexReadAll<'a> for BinRelToTernaryInd1_2<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (&'a T1, &'a T2); + type Value = (&'a T0,); + + type ValueIteratorType = Box + 'a>; + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + Box::new(self.0.reverse_map1.as_ref().unwrap().iter().flat_map(move |(x1, x0s_for_x1)| { + self.0.reverse_map2.as_ref().unwrap().iter().map(move |(x2, x0s_for_x2)| { + let x0s: Self::ValueIteratorType = Box::new( + x0s_for_x1 + .intersection(x0s_for_x2) + .filter(|&x0| self.0.map.get(x0).unwrap().contains(x1, x2)) + .map(|x0| (x0,)), + ); + ((x1, x2), x0s) + }) + })) + } +} + +impl<'a, T0, T1, T2, TBinRel> RelIndexRead<'a> for BinRelToTernaryInd1_2<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (T1, T2); + + type Value = (&'a T0,); + + type IteratorType = IteratorFromDyn<'a, Self::Value>; + fn index_get(&'a self, (x1, x2): &Self::Key) -> Option { + let (x1, x1_map) = self.0.reverse_map1.as_ref().unwrap().get_key_value(x1)?; + let (x2, x2_map) = self.0.reverse_map2.as_ref().unwrap().get_key_value(x2)?; + + let res = + || x1_map.intersection(x2_map).filter(|&x0| self.0.map.get(x0).unwrap().contains(x1, x2)).map(|x0| (x0,)); + Some(IteratorFromDyn::new(res)) + } + + fn len(&self) -> usize { + // TODO random estimate, could be very wrong + self.0.reverse_map1.as_ref().unwrap().len() * self.0.reverse_map2.as_ref().unwrap().len() + / ((self.0.map.len() as f32).sqrt() as usize) + } +} + +pub struct BinRelToTernaryIndNone<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary) +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel; + +impl<'a, T0, T1, T2, TBinRel> RelIndexReadAll<'a> for BinRelToTernaryIndNone<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (); + type Value = (&'a T0, &'a T1, &'a T2); + + type ValueIteratorType = >::IteratorType; + type AllIteratorType = std::option::IntoIter<(Self::Key, Self::ValueIteratorType)>; + + fn iter_all(&'a self) -> Self::AllIteratorType { self.index_get(&()).map(|x| ((), x)).into_iter() } +} + +impl<'a, T0, T1, T2, TBinRel> RelIndexRead<'a> for BinRelToTernaryIndNone<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (); + type Value = (&'a T0, &'a T1, &'a T2); + + type IteratorType = IteratorFromDyn<'a, Self::Value>; + + fn index_get(&'a self, (): &Self::Key) -> Option { + let res = || self.0.map.iter().flat_map(|(x0, rel)| rel.iter_all().map(move |(x1, x2)| (x0, x1, x2))); + Some(IteratorFromDyn::new(res)) + } + + fn len(&self) -> usize { 1 } +} + +pub struct BinRelToTernaryInd0_1_2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary) +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel; + +impl<'a, T0, T1, T2, TBinRel> RelFullIndexRead<'a> for BinRelToTernaryInd0_1_2<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (T0, T1, T2); + + #[inline] + fn contains_key(&'a self, (x0, x1, x2): &Self::Key) -> bool { + match self.0.map.get(x0) { + None => false, + Some(rel) => rel.contains(x1, x2), + } + } +} + +impl<'a, T0, T1, T2, TBinRel> RelIndexReadAll<'a> for BinRelToTernaryInd0_1_2<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (&'a T0, &'a T1, &'a T2); + type Value = (); + + type ValueIteratorType = std::iter::Once<()>; + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let iter = self.0.map.iter().flat_map(|(x0, rel)| rel.iter_all().map(move |(x1, x2)| (x0, x1, x2))); + + Box::new(iter.map(|t| (t, once(())))) + } +} + +impl<'a, T0, T1, T2, TBinRel> RelIndexRead<'a> for BinRelToTernaryInd0_1_2<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (T0, T1, T2); + type Value = (); + + type IteratorType = std::iter::Once; + + fn index_get(&'a self, (x0, x1, x2): &Self::Key) -> Option { + if self.0.map.get(x0)?.contains(x1, x2) { Some(once(())) } else { None } + } + + fn len(&self) -> usize { + let sample_size = 3; + let sum = self.0.map.values().take(sample_size).map(|rel| rel.len_estimate()).sum::(); + let map_len = self.0.map.len(); + sum * map_len / sample_size.min(map_len).max(1) + } +} + +pub struct BinRelToTernaryInd0_1_2Write<'a, T0, T1, T2, TBinRel>(&'a mut BinRelToTernary) +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel; + +impl<'a, T0, T1, T2, TBinRel> RelIndexMerge for BinRelToTernaryInd0_1_2Write<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + fn move_index_contents(_from: &mut Self, _to: &mut Self) {} // noop +} + +impl<'a, T0, T1, T2, TBinRel> RelFullIndexWrite for BinRelToTernaryInd0_1_2Write<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (T0, T1, T2); + type Value = (); + + fn insert_if_not_present(&mut self, (x0, x1, x2): &Self::Key, (): Self::Value) -> bool { + let x0_hash = hash_one(self.0.map.hasher(), x0); + + if !self + .0 + .map + .raw_entry_mut() + .from_key_hashed_nocheck(x0_hash, x0) + .or_insert_with(|| (x0.clone(), TBinRel::default())) + .1 + .insert(x1.clone(), x2.clone()) + { + return false; + } + if let Some(reverse_map1) = self.0.reverse_map1.as_mut() { + reverse_map1.entry(x1.clone()).or_default().insert_with_hash_no_check(x0_hash, x0.clone()); + } + if let Some(reverse_map2) = self.0.reverse_map2.as_mut() { + reverse_map2.entry(x2.clone()).or_default().insert_with_hash_no_check(x0_hash, x0.clone()); + } + true + } +} + +impl<'a, T0, T1, T2, TBinRel> RelIndexWrite for BinRelToTernaryInd0_1_2Write<'a, T0, T1, T2, TBinRel> +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + type Key = (T0, T1, T2); + type Value = (); + + fn index_insert(&mut self, (x0, x1, x2): Self::Key, (): Self::Value) { + if let Some(reverse_map1) = self.0.reverse_map1.as_mut() { + reverse_map1.entry(x1.clone()).or_default().insert(x0.clone()); + } + if let Some(reverse_map2) = self.0.reverse_map2.as_mut() { + reverse_map2.entry(x2.clone()).or_default().insert(x0.clone()); + } + self.0.map.entry(x0).or_insert_with(|| TBinRel::default()).insert(x1, x2); + } +} + +pub struct BinRelToTernaryWrapper( + pub BinRelToTernary, +) +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel; + +impl RelIndexMerge + for BinRelToTernaryWrapper +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + fn move_index_contents(from: &mut Self, to: &mut Self) { RelIndexMerge::move_index_contents(&mut from.0, &mut to.0) } + fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { + RelIndexMerge::merge_delta_to_total_new_to_delta(&mut new.0, &mut delta.0, &mut total.0) + } +} + +impl Default + for BinRelToTernaryWrapper +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, +{ + fn default() -> Self { + let reverse_map1 = if HAS_REVERSE_MAP1 { Some(Default::default()) } else { None }; + let reverse_map2 = if HAS_REVERSE_MAP2 { Some(Default::default()) } else { None }; + Self(BinRelToTernary { map: Default::default(), reverse_map1, reverse_map2 }) + } +} + +pub struct ToBinRelToTernaryInd0_1_2(PhantomData<(T0, T1, T2)>) +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq; + +impl<'a, T0, T1, T2> Default for ToBinRelToTernaryInd0_1_2 +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, //, TBinRel: ByodsBinRel +{ + fn default() -> Self { Self(PhantomData) } +} + +impl + ToRelIndex> + for ToBinRelToTernaryInd0_1_2 +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, + T2: Clone + Hash + Eq, + TBinRel: ByodsBinRel, + // TWrapper: DerefMut> +{ + type RelIndex<'a> + = BinRelToTernaryInd0_1_2<'a, T0, T1, T2, TBinRel> + where + Self: 'a, + BinRelToTernaryWrapper: 'a; + #[inline(always)] + fn to_rel_index<'a>( + &'a self, rel: &'a BinRelToTernaryWrapper, + ) -> Self::RelIndex<'a> { + BinRelToTernaryInd0_1_2(&rel.0) + } + + type RelIndexWrite<'a> + = BinRelToTernaryInd0_1_2Write<'a, T0, T1, T2, TBinRel> + where + Self: 'a, + BinRelToTernaryWrapper: 'a; + #[inline(always)] + fn to_rel_index_write<'a>( + &'a mut self, rel: &'a mut BinRelToTernaryWrapper, + ) -> Self::RelIndexWrite<'a> { + BinRelToTernaryInd0_1_2Write(&mut rel.0) + } +} + +use crate::rel_boilerplate::NoopRelIndexWrite; +macro_rules! to_trrel2 { + ($name: ident, $key: ty, $val: ty) => {paste::paste!{ + pub struct [](PhantomData<(T0, T1, T2)>) + where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq//, TBinRel: ByodsBinRel + ; + + + impl<'a, T0, T1, T2> Default for [] + where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq//, TBinRel: ByodsBinRel + { + fn default() -> Self { Self(PhantomData) } + } + + impl + ToRelIndex> for [] + where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq, T2: Clone + Hash + Eq, TBinRel: ByodsBinRel, + //TWrapper: Deref> + { + + type RelIndex<'a> = $name<'a, T0, T1, T2, TBinRel> where Self:'a, BinRelToTernaryWrapper: 'a; + #[inline(always)] + fn to_rel_index<'a>(&'a self, rel: &'a BinRelToTernaryWrapper) -> Self::RelIndex<'a> { $name(& rel.0) } + + type RelIndexWrite<'a> = NoopRelIndexWrite<$key, $val> where Self: 'a, BinRelToTernaryWrapper: 'a; + #[inline(always)] + fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut BinRelToTernaryWrapper) -> Self::RelIndexWrite<'a> { + NoopRelIndexWrite::default() + } + } + }}; +} + +to_trrel2!(BinRelToTernaryIndNone, (), (T0, T1, T2)); +to_trrel2!(BinRelToTernaryInd0, (T0,), (T1, T2)); +to_trrel2!(BinRelToTernaryInd1, (T1,), (T0, T2)); +to_trrel2!(BinRelToTernaryInd2, (T2,), (T0, T1)); +to_trrel2!(BinRelToTernaryInd0_1, (T0, T1), (T2,)); +to_trrel2!(BinRelToTernaryInd0_2, (T0, T2), (T1,)); +to_trrel2!(BinRelToTernaryInd1_2, (T1, T2), (T0,)); diff --git a/byods/ascent-byods-rels/src/adaptor/mod.rs b/byods/ascent-byods-rels/src/adaptor/mod.rs index e1cd15d..55eecd3 100644 --- a/byods/ascent-byods-rels/src/adaptor/mod.rs +++ b/byods/ascent-byods-rels/src/adaptor/mod.rs @@ -1,5 +1,5 @@ -//! Utilities for easily implementing custom data structures for binary relations -pub mod bin_rel; -pub mod bin_rel_provider; -pub mod bin_rel_to_ternary; -pub mod bin_rel_plus_ternary_provider; \ No newline at end of file +//! Utilities for easily implementing custom data structures for binary relations +pub mod bin_rel; +pub mod bin_rel_provider; +pub mod bin_rel_to_ternary; +pub mod bin_rel_plus_ternary_provider; diff --git a/byods/ascent-byods-rels/src/binary_rel.rs b/byods/ascent-byods-rels/src/binary_rel.rs index 65f0d8c..8e682c5 100644 --- a/byods/ascent-byods-rels/src/binary_rel.rs +++ b/byods/ascent-byods-rels/src/binary_rel.rs @@ -1,168 +1,173 @@ -use std::hash::{Hash, BuildHasherDefault}; - -use ascent::internal::{RelIndexRead, RelIndexReadAll}; -use hashbrown::hash_map::Iter; -use hashbrown::HashMap; -use rustc_hash::FxHasher; - -pub type MyHashSetIter<'a, T> = hashbrown::hash_set::Iter<'a, T>; -pub type MyHashSet = hashbrown::HashSet; - -pub type Map = HashMap>, BuildHasherDefault>; -pub type RevMap = HashMap, BuildHasherDefault>; - -#[derive(Clone)] -pub struct BinaryRel { - pub(crate) map: Map, - pub(crate) reverse_map: RevMap, -} - -impl Default for BinaryRel { - fn default() -> Self { - Self { - map: Default::default(), - reverse_map: Default::default(), - } - } -} - -impl BinaryRel { - - /// returns true if this tuple did not exist in the binary relation - pub fn insert(&mut self, x: T, y: T) -> bool { - if self.map.entry(x.clone()).or_default().insert(y.clone()) { - self.reverse_map.entry(y).or_default().push(x); - true - } else { - false - } - } - - /// returns true if this tuple did not exist in the binary relation - pub fn insert_by_ref(&mut self, x: &T, y: &T) -> bool { - let added = match self.map.raw_entry_mut().from_key(x) { - hashbrown::hash_map::RawEntryMut::Occupied(mut occ) => { - occ.get_mut().insert(y.clone()) - }, - hashbrown::hash_map::RawEntryMut::Vacant(vac) => { - vac.insert(x.clone(), MyHashSet::from_iter([y.clone()])); - true - }, - }; - if added { - match self.reverse_map.raw_entry_mut().from_key(y) { - hashbrown::hash_map::RawEntryMut::Occupied(mut occ) => {occ.get_mut().push(x.clone());}, - hashbrown::hash_map::RawEntryMut::Vacant(vac) => {vac.insert(y.clone(), vec![x.clone()]);}, - }; - true - } else { - false - } - // if self.map.raw_entry_mut().from_key(x)..entry(x.clone()).or_default().insert(y.clone()) { - // self.reverse_map.entry(y).or_default().push(x); - // true - // } else { - // false - // } - } - - pub fn iter_all(&self) -> impl Iterator + '_ { - self.map.iter().flat_map(|(x, x_set)| x_set.iter().map(move |y| (x, y))) - } - - #[inline] - pub fn contains(&self, x: &T, y: &T) -> bool { - self.map.get(x).map_or(false, |s| s.contains(y)) - } - - pub fn count_estimate(&self) -> usize { - let sample_size = 3; - let sum = self.map.values().take(sample_size).map(|x| x.len()).sum::(); - sum * self.map.len() / sample_size.min(self.map.len()).max(1) - } - - pub fn count_exact(&self) -> usize { - self.map.values().map(|x| x.len()).sum() - } -} - -pub struct MapRelIndexAdaptor<'a, T: Clone + Hash + Eq>(pub &'a HashMap>, BuildHasherDefault>); - -impl<'a, T: Clone + Hash + Eq + 'a> RelIndexReadAll<'a> for MapRelIndexAdaptor<'a, T> { - type Key = &'a T; - - type Value = &'a T; - - type ValueIteratorType = MyHashSetIter<'a, T>; - - type AllIteratorType = std::iter::Map>>, for<'aa> fn((&'aa T, &'a MyHashSet>)) -> (&'aa T, Self::ValueIteratorType)>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res: Self::AllIteratorType = self.0.iter().map(|(k, v)| { - let v_iter = v.iter(); - (k, v_iter) - }); - res - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for MapRelIndexAdaptor<'a, T> { - type Key = T; - - type Value = &'a T; - - type IteratorType = hashbrown::hash_set::Iter<'a, T>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let set = self.0.get(key)?; - let res = set.iter(); - Some(res) - } - - fn len(&'a self) -> usize { - self.0.len() - } -} - -pub struct RelIndexValTransformer{ - rel: T, - f: F -} - -impl<'a, T: 'a, F: 'a, V: 'a, U: 'a> RelIndexRead<'a> for RelIndexValTransformer -where T: RelIndexRead<'a, Value = V>, F: Fn(V) -> U -{ - type Key = T::Key; - type Value = U; - - type IteratorType = std::iter::Map>>, for <'aa> fn((V, &'aa RelIndexValTransformer)) -> U>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let res: Self::IteratorType = self.rel.index_get(key)?.zip(std::iter::repeat(self)).map(|(x, _self)| (_self.f)(x)); - Some(res) - } - - fn len(&'a self) -> usize { - self.rel.len() - } -} - -impl<'a, T: 'a, F: 'a, V: 'a, U: 'a> RelIndexReadAll<'a> for RelIndexValTransformer -where T: RelIndexReadAll<'a, Value = V>, F: Fn(V) -> U -{ - type Key = T::Key; - type Value = U; - - type ValueIteratorType = std::iter::Map>>, for <'aa> fn((V, &'aa RelIndexValTransformer)) -> U>; - - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res = self.rel.iter_all().map(move |(k, vals_iter)| { - let new_vals_iter: Self::ValueIteratorType = vals_iter.zip(std::iter::repeat(self)).map(|(x, _self)| (_self.f)(x)); - (k, new_vals_iter) - }); - - Box::new(res) - } -} +use std::hash::{BuildHasherDefault, Hash}; + +use ascent::internal::{RelIndexRead, RelIndexReadAll}; +use hashbrown::HashMap; +use hashbrown::hash_map::Iter; +use rustc_hash::FxHasher; + +pub type MyHashSetIter<'a, T> = hashbrown::hash_set::Iter<'a, T>; +pub type MyHashSet = hashbrown::HashSet; + +pub type Map = HashMap>, BuildHasherDefault>; +pub type RevMap = HashMap, BuildHasherDefault>; + +#[derive(Clone)] +pub struct BinaryRel { + pub(crate) map: Map, + pub(crate) reverse_map: RevMap, +} + +impl Default for BinaryRel { + fn default() -> Self { Self { map: Default::default(), reverse_map: Default::default() } } +} + +impl BinaryRel { + /// returns true if this tuple did not exist in the binary relation + pub fn insert(&mut self, x: T, y: T) -> bool { + if self.map.entry(x.clone()).or_default().insert(y.clone()) { + self.reverse_map.entry(y).or_default().push(x); + true + } else { + false + } + } + + /// returns true if this tuple did not exist in the binary relation + pub fn insert_by_ref(&mut self, x: &T, y: &T) -> bool { + let added = match self.map.raw_entry_mut().from_key(x) { + hashbrown::hash_map::RawEntryMut::Occupied(mut occ) => occ.get_mut().insert(y.clone()), + hashbrown::hash_map::RawEntryMut::Vacant(vac) => { + vac.insert(x.clone(), MyHashSet::from_iter([y.clone()])); + true + }, + }; + if added { + match self.reverse_map.raw_entry_mut().from_key(y) { + hashbrown::hash_map::RawEntryMut::Occupied(mut occ) => { + occ.get_mut().push(x.clone()); + }, + hashbrown::hash_map::RawEntryMut::Vacant(vac) => { + vac.insert(y.clone(), vec![x.clone()]); + }, + }; + true + } else { + false + } + // if self.map.raw_entry_mut().from_key(x)..entry(x.clone()).or_default().insert(y.clone()) { + // self.reverse_map.entry(y).or_default().push(x); + // true + // } else { + // false + // } + } + + pub fn iter_all(&self) -> impl Iterator + '_ { + self.map.iter().flat_map(|(x, x_set)| x_set.iter().map(move |y| (x, y))) + } + + #[inline] + pub fn contains(&self, x: &T, y: &T) -> bool { self.map.get(x).map_or(false, |s| s.contains(y)) } + + pub fn count_estimate(&self) -> usize { + let sample_size = 3; + let sum = self.map.values().take(sample_size).map(|x| x.len()).sum::(); + sum * self.map.len() / sample_size.min(self.map.len()).max(1) + } + + pub fn count_exact(&self) -> usize { self.map.values().map(|x| x.len()).sum() } +} + +pub struct MapRelIndexAdaptor<'a, T: Clone + Hash + Eq>( + pub &'a HashMap>, BuildHasherDefault>, +); + +impl<'a, T: Clone + Hash + Eq + 'a> RelIndexReadAll<'a> for MapRelIndexAdaptor<'a, T> { + type Key = &'a T; + + type Value = &'a T; + + type ValueIteratorType = MyHashSetIter<'a, T>; + + type AllIteratorType = std::iter::Map< + Iter<'a, T, MyHashSet>>, + for<'aa> fn((&'aa T, &'a MyHashSet>)) -> (&'aa T, Self::ValueIteratorType), + >; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res: Self::AllIteratorType = self.0.iter().map(|(k, v)| { + let v_iter = v.iter(); + (k, v_iter) + }); + res + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for MapRelIndexAdaptor<'a, T> { + type Key = T; + + type Value = &'a T; + + type IteratorType = hashbrown::hash_set::Iter<'a, T>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let set = self.0.get(key)?; + let res = set.iter(); + Some(res) + } + + fn len(&'a self) -> usize { self.0.len() } +} + +pub struct RelIndexValTransformer { + rel: T, + f: F, +} + +impl<'a, T: 'a, F: 'a, V: 'a, U: 'a> RelIndexRead<'a> for RelIndexValTransformer +where + T: RelIndexRead<'a, Value = V>, + F: Fn(V) -> U, +{ + type Key = T::Key; + type Value = U; + + type IteratorType = std::iter::Map< + std::iter::Zip>>, + for<'aa> fn((V, &'aa RelIndexValTransformer)) -> U, + >; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let res: Self::IteratorType = + self.rel.index_get(key)?.zip(std::iter::repeat(self)).map(|(x, _self)| (_self.f)(x)); + Some(res) + } + + fn len(&'a self) -> usize { self.rel.len() } +} + +impl<'a, T: 'a, F: 'a, V: 'a, U: 'a> RelIndexReadAll<'a> for RelIndexValTransformer +where + T: RelIndexReadAll<'a, Value = V>, + F: Fn(V) -> U, +{ + type Key = T::Key; + type Value = U; + + type ValueIteratorType = std::iter::Map< + std::iter::Zip>>, + for<'aa> fn((V, &'aa RelIndexValTransformer)) -> U, + >; + + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res = self.rel.iter_all().map(move |(k, vals_iter)| { + let new_vals_iter: Self::ValueIteratorType = + vals_iter.zip(std::iter::repeat(self)).map(|(x, _self)| (_self.f)(x)); + (k, new_vals_iter) + }); + + Box::new(res) + } +} diff --git a/byods/ascent-byods-rels/src/ceqrel_ind.rs b/byods/ascent-byods-rels/src/ceqrel_ind.rs index 8c6d4e8..1e9c5b8 100644 --- a/byods/ascent-byods-rels/src/ceqrel_ind.rs +++ b/byods/ascent-byods-rels/src/ceqrel_ind.rs @@ -1,608 +1,609 @@ -use hashbrown::HashSet; -use std::hash::{Hash, BuildHasherDefault}; -use std::iter::{FlatMap, Map, Repeat, Zip}; -use std::marker::PhantomData; -use std::mem::transmute; -use std::sync::Mutex; - -use ascent::internal::{RelFullIndexRead, RelIndexRead, RelIndexReadAll, RelIndexWrite, RelIndexMerge, RelFullIndexWrite, CRelIndexWrite, CRelFullIndexWrite, Freezable}; -use ascent::rayon::prelude::{ParallelIterator, IntoParallelRefIterator}; -use ascent::internal::{CRelIndexRead, CRelIndexReadAll}; -use ascent::internal::{ToRelIndex, ToRelIndex0}; -use rustc_hash::FxHasher; -use ascent::rayon; - -use crate::iterator_from_dyn::IteratorFromDyn; -use crate::union_find::EqRel; - -use hashbrown::hash_set::Iter as HashSetIter; - - -pub struct EqRelInd0<'a, T: Clone + Hash + Eq>(pub(crate) &'a CEqRelIndCommon); - -pub struct ToEqRelIndNone(PhantomData); -impl Freezable for ToEqRelIndNone { } -impl Default for ToEqRelIndNone { - fn default() -> Self { Self(PhantomData) } -} -impl ToRelIndex> for ToEqRelIndNone { - type RelIndex<'a> = EqRelIndNone<'a, T> where T: 'a; - fn to_rel_index<'a>(&'a self, rel: &'a CEqRelIndCommon) -> Self::RelIndex<'a> { EqRelIndNone(rel) } - - type RelIndexWrite<'a> = EqRelIndNone<'a, T> where T: 'a; - fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut CEqRelIndCommon) -> Self::RelIndexWrite<'a> { EqRelIndNone(rel) } -} - - -pub struct ToEqRelInd0(PhantomData); -impl Freezable for ToEqRelInd0 { } - -impl Default for ToEqRelInd0 { - fn default() -> Self { Self(Default::default()) } -} - -impl ToRelIndex> for ToEqRelInd0 { - type RelIndex<'a> = EqRelInd0<'a, T> where T: 'a; - fn to_rel_index<'a>(&'a self, rel: &'a CEqRelIndCommon) -> Self::RelIndex<'a> { EqRelInd0(rel) } - - type RelIndexWrite<'a> = EqRelInd0<'a, T> where T: 'a; - fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut CEqRelIndCommon) -> Self::RelIndexWrite<'a> { EqRelInd0(rel) } -} - -pub struct ToEqRelInd0_1(PhantomData); -impl Freezable for ToEqRelInd0_1 { } - - -impl Default for ToEqRelInd0_1 { - fn default() -> Self { Self(Default::default()) } -} - -pub struct EqRelInd0_1<'a, T: Clone + Hash + Eq>(&'a CEqRelIndCommon); -pub struct EqRelInd0_1Write<'a, T: Clone + Hash + Eq>(&'a mut CEqRelIndCommon); -pub struct EqRelInd0_1CWrite<'a, T: Clone + Hash + Eq>(&'a CEqRelIndCommon); - - -impl <'a, T: Clone + Hash + Eq> RelIndexWrite for EqRelInd0_1Write<'a, T> { - type Key = (T, T); - type Value = (); - - fn index_insert(&mut self, key: Self::Key, value: Self::Value) { - self.0.index_insert(key, value) - } -} - -impl <'a, T: Clone + Hash + Eq> CRelIndexWrite for EqRelInd0_1CWrite<'a, T> { - type Key = (T, T); - type Value = (); - - fn index_insert(&self, key: Self::Key, (): Self::Value) { - self.0.unwrap_unfrozen().lock().unwrap().add(key.0, key.1); - } -} - - -impl <'a, T: Clone + Hash + Eq> RelIndexMerge for EqRelInd0_1Write<'a, T> { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { /* noop */ } -} - -impl RelFullIndexWrite for CEqRelIndCommon { - type Key = (T, T); - type Value = (); - - fn insert_if_not_present(&mut self, key: &Self::Key, _v: Self::Value) -> bool { - self.unwrap_mut_unfrozen().add(key.0.clone(), key.1.clone()) - } -} - -impl CRelFullIndexWrite for CEqRelIndCommon { - type Key = (T, T); - type Value = (); - - fn insert_if_not_present(&self, key: &Self::Key, _v: Self::Value) -> bool { - self.unwrap_unfrozen().lock().unwrap().add(key.0.clone(), key.1.clone()) - } -} - - -impl <'a, T: Clone + Hash + Eq> RelFullIndexWrite for EqRelInd0_1Write<'a, T> { - type Key = as RelFullIndexWrite>::Key; - type Value = as RelFullIndexWrite>::Value; - fn insert_if_not_present(&mut self, key: &Self::Key, v: Self::Value) -> bool { - self.0.insert_if_not_present(key, v) - } -} - -impl <'a, T: Clone + Hash + Eq> CRelFullIndexWrite for EqRelInd0_1CWrite<'a, T> { - type Key = as CRelFullIndexWrite>::Key; - type Value = as CRelFullIndexWrite>::Value; - fn insert_if_not_present(&self, key: &Self::Key, v: Self::Value) -> bool { - CRelFullIndexWrite::insert_if_not_present(&self.0, key, v) - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for EqRelInd0_1<'a, T> { - type Key = as RelIndexRead<'a>>::Key; - type Value = as RelIndexRead<'a>>::Value; - type IteratorType = as RelIndexRead<'a>>::IteratorType; - - fn index_get(&'a self, key: &Self::Key) -> Option { - self.0.index_get(key) - } - - fn len(&self) -> usize { - self.0.len() - } -} - -impl<'a, T: Clone + Hash + Eq + Sync> CRelIndexRead<'a> for EqRelInd0_1<'a, T> { - type Key = as CRelIndexRead<'a>>::Key; - type Value = as CRelIndexRead<'a>>::Value; - type IteratorType = as CRelIndexRead<'a>>::IteratorType; - - fn c_index_get(&'a self, key: &Self::Key) -> Option { - self.0.c_index_get(key) - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRelInd0_1<'a, T> { - type Key = as RelIndexReadAll<'a>>::Key; - type Value = as RelIndexReadAll<'a>>::Value; - type ValueIteratorType = as RelIndexReadAll<'a>>::ValueIteratorType; - type AllIteratorType = as RelIndexReadAll<'a>>::AllIteratorType; - fn iter_all(&'a self) -> Self::AllIteratorType { - self.0.iter_all() - } -} - -impl<'a, T: Clone + Hash + Eq + Sync> CRelIndexReadAll<'a> for EqRelInd0_1<'a, T> { - type Key = as CRelIndexReadAll<'a>>::Key; - type Value = as CRelIndexReadAll<'a>>::Value; - type ValueIteratorType = as CRelIndexReadAll<'a>>::ValueIteratorType; - type AllIteratorType = as CRelIndexReadAll<'a>>::AllIteratorType; - fn c_iter_all(&'a self) -> Self::AllIteratorType { - self.0.c_iter_all() - } -} - -impl<'a, T: Clone + Hash + Eq> RelFullIndexRead<'a> for EqRelInd0_1<'a, T> { - type Key = as RelFullIndexRead<'a>>::Key; - fn contains_key(&self, key: &Self::Key) -> bool { - self.0.contains_key(key) - } -} - -impl ToRelIndex0> for ToEqRelInd0_1 { - type RelIndex<'a> = EqRelInd0_1<'a, T> where T: 'a; - fn to_rel_index<'a>(&'a self, rel: &'a CEqRelIndCommon) -> Self::RelIndex<'a> { EqRelInd0_1(rel) } - - type RelIndexWrite<'a> = EqRelInd0_1Write<'a, T> where T: 'a; - fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut CEqRelIndCommon) -> Self::RelIndexWrite<'a> { EqRelInd0_1Write(rel) } - - type CRelIndexWrite<'a> = EqRelInd0_1CWrite<'a, T> where Self: 'a, CEqRelIndCommon: 'a; - fn to_c_rel_index_write<'a>(&'a self, rel: &'a CEqRelIndCommon) -> Self::CRelIndexWrite<'a> { EqRelInd0_1CWrite(rel) } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for EqRelInd0<'a, T> { - type Key = (T,); - type Value = (&'a T,); - - type IteratorType = IteratorFromDyn<'a, (&'a T,)>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let _ = self.0.set_of_added(&key.0)?; - let key = key.clone(); - let producer = move || self.0.set_of_added(&key.0).unwrap().map(|x| (x,)); - - Some(IteratorFromDyn::new(producer)) - } - - fn len(&self) -> usize { - self.0.unwrap_frozen().combined.elem_ids.len() - } -} - -impl<'a, T: Clone + Hash + Eq + Sync> CRelIndexRead<'a> for EqRelInd0<'a, T> { - type Key = (T,); - type Value = (&'a T,); - - type IteratorType = rayon::iter::Map, fn(&T) -> (&T,)>; - - fn c_index_get(&'a self, key: &Self::Key) -> Option { - let set = self.0.c_set_of_added(&key.0)?; - let res: Self::IteratorType = set.map(|x| (x, )); - Some(res) - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRelInd0<'a, T> { - type Key = &'a (T,); - type Value = (&'a T,); - - type ValueIteratorType = Map, for<'aa> fn(&'aa T) -> (&'aa T,)>; - - type AllIteratorType = FlatMap>>, Map, Repeat>>, for<'aa> fn((&'aa T, HashSetIter<'aa, T>),) -> (&'aa (T,), Map, for<'bb> fn(&'bb T) -> (&'bb T,)>),>, for<'aa> fn(&'aa HashSet>, ) -> Map< - Zip, Repeat>>, - for<'cc> fn((&'cc T, HashSetIter<'cc, T>),) -> (&'cc (T,), Map, for<'dd> fn(&'dd T) -> (&'dd T,)>),>, >; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res: Self::AllIteratorType = self.0.unwrap_frozen().combined.sets.iter().flat_map(|s| { - s.iter().zip(std::iter::repeat(s.iter())).map(|(x, s)| (ref_to_singleton_tuple_ref(x), s.map(|x| (x,)))) - }); - res - } -} - -pub struct EqRelInd0CRelIndexReadAllIter<'a, T: Clone + Hash + Eq + Sync>(&'a EqRelPair); - -impl<'a, T: Clone + Hash + Eq + Sync + Send> ParallelIterator for EqRelInd0CRelIndexReadAllIter<'a, T> { - type Item = (&'a (T,), rayon::iter::Map, fn(&T) -> (&T,)>); - - fn drive_unindexed(self, consumer: C) -> C::Result - where C: rayon::iter::plumbing::UnindexedConsumer - { - self.0.combined.sets.par_iter().flat_map(|s| { - s.par_iter().map(|x| { - let vals_iter: rayon::iter::Map, fn(&T) -> (&T,)> - = s.par_iter().map(|x| (x,)); - (ref_to_singleton_tuple_ref(x), vals_iter) - }) - }).drive_unindexed(consumer) - } -} - -impl<'a, T: Clone + Hash + Eq + Sync + Send> CRelIndexReadAll<'a> for EqRelInd0<'a, T> { - type Key = &'a (T,); - type Value = (&'a T,); - - type AllIteratorType = EqRelInd0CRelIndexReadAllIter<'a, T>; - - type ValueIteratorType = rayon::iter::Map, fn(&T) -> (&T,)>; - - fn c_iter_all(&'a self) -> Self::AllIteratorType { - EqRelInd0CRelIndexReadAllIter(self.0.unwrap_frozen()) - } -} - - -impl<'a, T: Clone + Hash + Eq> RelIndexWrite for EqRelInd0<'a, T> { - type Key = (T,); - type Value = (T,); - fn index_insert(&mut self, _key: Self::Key, _value: Self::Value) { /* noop */ } -} - -impl<'a, T: Clone + Hash + Eq> CRelIndexWrite for EqRelInd0<'a, T> { - type Key = (T,); - type Value = (T,); - fn index_insert(&self, _key: Self::Key, _value: Self::Value) { /* noop */ } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexMerge for EqRelInd0<'a, T> { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { /* noop */ } -} - - -pub enum CEqRelIndCommon { - Unfrozen(Mutex>), - Frozen(EqRelPair) -} - -#[derive(Clone)] -pub struct EqRelPair { - pub(crate) old: EqRel, - pub(crate) combined: EqRel, -} - -impl Default for EqRelPair { - fn default() -> Self { Self { old: Default::default(), combined: Default::default() } } -} - -impl Freezable for CEqRelIndCommon { } - -impl CEqRelIndCommon { - fn unwrap_frozen(&self) -> &EqRelPair { - match self { - CEqRelIndCommon::Frozen(old) => old, - CEqRelIndCommon::Unfrozen(_) => panic!("unwrap_frozen() called on Unfrozen"), - } - } - fn unwrap_mut_frozen(&mut self) -> &mut EqRelPair { - match self { - CEqRelIndCommon::Frozen(old) => old, - CEqRelIndCommon::Unfrozen(_) => panic!("unwrap_mut_frozen() called on Unfrozen"), - } - } - fn unwrap_mut_unfrozen(&mut self) -> &mut EqRel { - match self { - CEqRelIndCommon::Unfrozen(uf) => uf.get_mut().unwrap(), - CEqRelIndCommon::Frozen(_) => panic!("unwrap_mut_unfrozen called on Frozen"), - } - } - fn unwrap_unfrozen(&self) -> &Mutex> { - match self { - CEqRelIndCommon::Unfrozen(uf) => uf, - CEqRelIndCommon::Frozen(_) => panic!("unwrap_unfrozen called on Frozen"), - } - } -} - -impl Clone for CEqRelIndCommon { - fn clone(&self) -> Self { - match self { - CEqRelIndCommon::Unfrozen(new) => CEqRelIndCommon::Unfrozen(Mutex::new(new.lock().unwrap().clone())), - CEqRelIndCommon::Frozen(old) => CEqRelIndCommon::Frozen(old.clone()), - } - } -} - -#[derive(Clone)] -pub struct AllAddedParIter<'a, T: Clone + Hash + Eq + Sync>(&'a EqRelPair); - -impl<'a, T: Clone + Hash + Eq + Sync> ParallelIterator for AllAddedParIter<'a, T> { - type Item = (&'a T, &'a T); - - fn drive_unindexed(self, consumer: C) -> C::Result - where C: ascent::rayon::iter::plumbing::UnindexedConsumer { - self.0.combined.c_iter_all().filter(|(x, y)| !self.0.old.contains(x, y)) - .drive_unindexed(consumer) - } -} - -#[derive(Clone)] -pub struct SetOfAddedParIter<'a, T: Clone + Hash + Eq + Sync>{ - set: &'a hashbrown::HashSet>, - old_set: Option<&'a HashSet>> -} - -impl<'a, T: Clone + Hash + Eq + Sync> ParallelIterator for SetOfAddedParIter<'a, T> { - type Item = &'a T; - - fn drive_unindexed(self, consumer: C) -> C::Result - where C: ascent::rayon::iter::plumbing::UnindexedConsumer { - self.set.par_iter().filter(move |y| !self.old_set.map_or(false, |os| os.contains(*y))) - .drive_unindexed(consumer) - } -} - -impl CEqRelIndCommon { - - pub(crate) fn iter_all_added(&self) -> impl Iterator { - let eqrel = self.unwrap_frozen(); - eqrel.combined.iter_all().filter(|(x, y)| !eqrel.old.contains(x, y)) - } - - pub(crate) fn c_iter_all_added(&self) -> AllAddedParIter<'_, T> where T: Sync { - let eqrel = self.unwrap_frozen(); - AllAddedParIter(eqrel) - } - - pub(crate) fn set_of_added(&self, x: &T) -> Option> { - let self_ = self.unwrap_frozen(); - let set = self_.combined.set_of(x)?; - // let old_set = self.old.set_of(x).into_iter().flatten(); - let old_set = self_.old.elem_set(x).map(|id| &self_.old.sets[id]); - Some(set.filter(move |y| !old_set.map_or(false, |os| os.contains(*y)))) - } - - pub(crate) fn c_set_of_added(&self, x: &T) -> Option> where T: Sync{ - let self_ = self.unwrap_frozen(); - let set = self_.combined.c_set_of(x)?; - let old_set = self_.old.elem_set(x).map(|id| &self_.old.sets[id]); - Some(SetOfAddedParIter{ set, old_set }) - } - - #[allow(dead_code)] - pub(crate) fn added_contains(&self, x: &T, y: &T) -> bool { - let self_ = self.unwrap_frozen(); - self_.combined.contains(x, y) && !self_.old.contains(x, y) - } - -} - -impl Default for CEqRelIndCommon { - fn default() -> Self { - Self::Frozen(Default::default()) - } -} - -impl<'a, T: Clone + Hash + Eq + 'a> RelIndexRead<'a> for CEqRelIndCommon { - type Key = &'a (T, T); - type Value = (); - - type IteratorType = std::iter::Once<()>; - - fn index_get(&'a self, (x, y): &Self::Key) -> Option { - let self_ = self.unwrap_frozen(); - if self_.combined.contains(x, y) && !self_.old.contains(x, y) { - Some(std::iter::once(())) - } else { - None - } - } - - fn len(&self) -> usize { - let self_ = self.unwrap_frozen(); - let sample_size = 3; - let sum: usize = self_.combined.sets.iter().take(sample_size).map(|s| s.len().pow(2)).sum(); - let sets_len = self_.combined.sets.len(); - sum * sets_len / sample_size.min(sets_len).max(1) - } -} - -impl<'a, T: Clone + Hash + Eq + Sync + 'a> CRelIndexRead<'a> for CEqRelIndCommon { - type Key = &'a (T, T); - type Value = (); - - type IteratorType = ascent::rayon::iter::Once<()>; - - fn c_index_get(&'a self, (x, y): &Self::Key) -> Option { - let self_ = self.unwrap_frozen(); - if self_.combined.contains(x, y) && !self_.old.contains(x, y) { - Some(ascent::rayon::iter::once(())) - } else { - None - } - } -} - - -impl<'a, T: Clone + Hash + Eq + 'a> RelIndexReadAll<'a> for CEqRelIndCommon { - type Key = (&'a T, &'a T); - type Value = (); - - type ValueIteratorType = std::iter::Once<()>; - - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.iter_all_added().map(|x| (x, std::iter::once(())))) - } -} - -impl<'a, T: Clone + Hash + Eq + Sync + 'a> CRelIndexReadAll<'a> for CEqRelIndCommon { - type Key = (&'a T, &'a T); - type Value = (); - - type ValueIteratorType = ascent::rayon::iter::Once<()>; - - type AllIteratorType = ascent::rayon::iter::Map, for<'aa, 'bb> fn((&'aa T, &'bb T)) -> ((&'aa T, &'bb T), ascent::rayon::iter::Once<()>)>; - - fn c_iter_all(&'a self) -> Self::AllIteratorType { - let res: Self::AllIteratorType - = self.c_iter_all_added().map(|x| (x, ascent::rayon::iter::once(()))); - res - } -} - -impl<'a, T: Clone + Hash + Eq> RelFullIndexRead<'a> for CEqRelIndCommon { - type Key = (T, T); - - fn contains_key(&'a self, (x, y): &Self::Key) -> bool { - let self_ = self.unwrap_frozen(); - self_.combined.contains(x, y) && !self_.old.contains(x, y) - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexWrite for CEqRelIndCommon { - type Key = (T, T); - type Value = (); - - fn index_insert(&mut self, key: Self::Key, _value: Self::Value) { - self.unwrap_mut_unfrozen().add(key.0, key.1); - } - -} - -impl<'a, T: Clone + Hash + Eq> CRelIndexWrite for CEqRelIndCommon { - type Key = (T, T); - type Value = (); - - fn index_insert(&self, key: Self::Key, _value: Self::Value) { - self.unwrap_unfrozen().lock().unwrap().add(key.0, key.1); - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexMerge for CEqRelIndCommon { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { - unimplemented!("merge_delta_to_total_new_to_delta must be used instead") - } - - fn init(new: &mut Self, _delta: &mut Self, _total: &mut Self) { - *new = Self::Unfrozen(Mutex::new(Default::default())) - } - - fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { - let total = total.unwrap_mut_frozen(); - let delta = delta.unwrap_mut_frozen(); - total.combined = delta.combined.clone(); - delta.old = total.combined.clone(); - - // delta.combined.combine(new.combined.clone()); - delta.combined.combine(std::mem::take(&mut new.unwrap_mut_unfrozen())); - } - -} - -pub struct EqRelIndNone<'a, T: Clone + Hash + Eq>(&'a CEqRelIndCommon); - -impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for EqRelIndNone<'a, T> { - type Key = (); - type Value = (&'a T, &'a T); - - type IteratorType = IteratorFromDyn<'a, (&'a T, &'a T)>; - - fn index_get(&'a self, _key: &Self::Key) -> Option { - Some(IteratorFromDyn::new(|| self.0.iter_all_added())) - } - - fn len(&self) -> usize { - 1 - } -} - -impl<'a, T: Clone + Hash + Eq + Sync> CRelIndexRead<'a> for EqRelIndNone<'a, T> { - type Key = (); - type Value = (&'a T, &'a T); - - type IteratorType = AllAddedParIter<'a, T>; - - fn c_index_get(&'a self, _key: &Self::Key) -> Option { - Some(self.0.c_iter_all_added()) - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRelIndNone<'a, T> { - type Key = (); - - type Value = (&'a T, &'a T); - - type ValueIteratorType = IteratorFromDyn<'a, (&'a T, &'a T)>; - - type AllIteratorType = std::option::IntoIter<(Self::Key, Self::ValueIteratorType)>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - self.index_get(&()).map(|iter| ((), iter)).into_iter() - } -} - -impl<'a, T: Clone + Hash + Eq + Sync> CRelIndexReadAll<'a> for EqRelIndNone<'a, T> { - type Key = (); - type Value = (&'a T, &'a T); - - type ValueIteratorType = AllAddedParIter<'a, T>; - - type AllIteratorType = ascent::rayon::iter::Once<(Self::Key, Self::ValueIteratorType)>; - - fn c_iter_all(&'a self) -> Self::AllIteratorType { - ascent::rayon::iter::once(((), self.0.c_iter_all_added())) - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexWrite for EqRelIndNone<'a, T> { - type Key = (); - type Value = (T, T); - fn index_insert(&mut self, _key: Self::Key, _value: Self::Value) { /* noop */ } -} - -impl<'a, T: Clone + Hash + Eq> CRelIndexWrite for EqRelIndNone<'a, T> { - type Key = (); - type Value = (T, T); - fn index_insert(&self, _key: Self::Key, _value: Self::Value) { /* noop */ } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexMerge for EqRelIndNone<'a, T> { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { /* noop */ } -} - -// TODO this is not safe, and not required. Get rid of it. -pub(crate) fn ref_to_singleton_tuple_ref(x: &T) -> &(T,) { unsafe { transmute(x) } } - -#[test] -fn test_ref_to_singleton_tuple_ref() { - use std::mem::size_of; - println!("size_of::>(): {}", size_of::>()); - println!("size_of::<(Vec,)>(): {}", size_of::<(Vec,)>()); - - let x = vec![1, 2, 3]; - let x2 = ref_to_singleton_tuple_ref(&x); - assert_eq!(&x, &x2.0); -} +use std::hash::{BuildHasherDefault, Hash}; +use std::iter::{FlatMap, Map, Repeat, Zip}; +use std::marker::PhantomData; +use std::mem::transmute; +use std::sync::Mutex; + +use ascent::internal::{ + CRelFullIndexWrite, CRelIndexRead, CRelIndexReadAll, CRelIndexWrite, Freezable, RelFullIndexRead, RelFullIndexWrite, + RelIndexMerge, RelIndexRead, RelIndexReadAll, RelIndexWrite, ToRelIndex, ToRelIndex0, +}; +use ascent::rayon; +use ascent::rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; +use hashbrown::HashSet; +use hashbrown::hash_set::Iter as HashSetIter; +use rustc_hash::FxHasher; + +use crate::iterator_from_dyn::IteratorFromDyn; +use crate::union_find::EqRel; + +pub struct EqRelInd0<'a, T: Clone + Hash + Eq>(pub(crate) &'a CEqRelIndCommon); + +pub struct ToEqRelIndNone(PhantomData); +impl Freezable for ToEqRelIndNone {} +impl Default for ToEqRelIndNone { + fn default() -> Self { Self(PhantomData) } +} +impl ToRelIndex> for ToEqRelIndNone { + type RelIndex<'a> + = EqRelIndNone<'a, T> + where T: 'a; + fn to_rel_index<'a>(&'a self, rel: &'a CEqRelIndCommon) -> Self::RelIndex<'a> { EqRelIndNone(rel) } + + type RelIndexWrite<'a> + = EqRelIndNone<'a, T> + where T: 'a; + fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut CEqRelIndCommon) -> Self::RelIndexWrite<'a> { + EqRelIndNone(rel) + } +} + +pub struct ToEqRelInd0(PhantomData); +impl Freezable for ToEqRelInd0 {} + +impl Default for ToEqRelInd0 { + fn default() -> Self { Self(Default::default()) } +} + +impl ToRelIndex> for ToEqRelInd0 { + type RelIndex<'a> + = EqRelInd0<'a, T> + where T: 'a; + fn to_rel_index<'a>(&'a self, rel: &'a CEqRelIndCommon) -> Self::RelIndex<'a> { EqRelInd0(rel) } + + type RelIndexWrite<'a> + = EqRelInd0<'a, T> + where T: 'a; + fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut CEqRelIndCommon) -> Self::RelIndexWrite<'a> { + EqRelInd0(rel) + } +} + +pub struct ToEqRelInd0_1(PhantomData); +impl Freezable for ToEqRelInd0_1 {} + +impl Default for ToEqRelInd0_1 { + fn default() -> Self { Self(Default::default()) } +} + +pub struct EqRelInd0_1<'a, T: Clone + Hash + Eq>(&'a CEqRelIndCommon); +pub struct EqRelInd0_1Write<'a, T: Clone + Hash + Eq>(&'a mut CEqRelIndCommon); +pub struct EqRelInd0_1CWrite<'a, T: Clone + Hash + Eq>(&'a CEqRelIndCommon); + +impl<'a, T: Clone + Hash + Eq> RelIndexWrite for EqRelInd0_1Write<'a, T> { + type Key = (T, T); + type Value = (); + + fn index_insert(&mut self, key: Self::Key, value: Self::Value) { self.0.index_insert(key, value) } +} + +impl<'a, T: Clone + Hash + Eq> CRelIndexWrite for EqRelInd0_1CWrite<'a, T> { + type Key = (T, T); + type Value = (); + + fn index_insert(&self, key: Self::Key, (): Self::Value) { + self.0.unwrap_unfrozen().lock().unwrap().add(key.0, key.1); + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexMerge for EqRelInd0_1Write<'a, T> { + fn move_index_contents(_from: &mut Self, _to: &mut Self) { /* noop */ + } +} + +impl RelFullIndexWrite for CEqRelIndCommon { + type Key = (T, T); + type Value = (); + + fn insert_if_not_present(&mut self, key: &Self::Key, _v: Self::Value) -> bool { + self.unwrap_mut_unfrozen().add(key.0.clone(), key.1.clone()) + } +} + +impl CRelFullIndexWrite for CEqRelIndCommon { + type Key = (T, T); + type Value = (); + + fn insert_if_not_present(&self, key: &Self::Key, _v: Self::Value) -> bool { + self.unwrap_unfrozen().lock().unwrap().add(key.0.clone(), key.1.clone()) + } +} + +impl<'a, T: Clone + Hash + Eq> RelFullIndexWrite for EqRelInd0_1Write<'a, T> { + type Key = as RelFullIndexWrite>::Key; + type Value = as RelFullIndexWrite>::Value; + fn insert_if_not_present(&mut self, key: &Self::Key, v: Self::Value) -> bool { self.0.insert_if_not_present(key, v) } +} + +impl<'a, T: Clone + Hash + Eq> CRelFullIndexWrite for EqRelInd0_1CWrite<'a, T> { + type Key = as CRelFullIndexWrite>::Key; + type Value = as CRelFullIndexWrite>::Value; + fn insert_if_not_present(&self, key: &Self::Key, v: Self::Value) -> bool { + CRelFullIndexWrite::insert_if_not_present(&self.0, key, v) + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for EqRelInd0_1<'a, T> { + type Key = as RelIndexRead<'a>>::Key; + type Value = as RelIndexRead<'a>>::Value; + type IteratorType = as RelIndexRead<'a>>::IteratorType; + + fn index_get(&'a self, key: &Self::Key) -> Option { self.0.index_get(key) } + + fn len(&self) -> usize { self.0.len() } +} + +impl<'a, T: Clone + Hash + Eq + Sync> CRelIndexRead<'a> for EqRelInd0_1<'a, T> { + type Key = as CRelIndexRead<'a>>::Key; + type Value = as CRelIndexRead<'a>>::Value; + type IteratorType = as CRelIndexRead<'a>>::IteratorType; + + fn c_index_get(&'a self, key: &Self::Key) -> Option { self.0.c_index_get(key) } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRelInd0_1<'a, T> { + type Key = as RelIndexReadAll<'a>>::Key; + type Value = as RelIndexReadAll<'a>>::Value; + type ValueIteratorType = as RelIndexReadAll<'a>>::ValueIteratorType; + type AllIteratorType = as RelIndexReadAll<'a>>::AllIteratorType; + fn iter_all(&'a self) -> Self::AllIteratorType { self.0.iter_all() } +} + +impl<'a, T: Clone + Hash + Eq + Sync> CRelIndexReadAll<'a> for EqRelInd0_1<'a, T> { + type Key = as CRelIndexReadAll<'a>>::Key; + type Value = as CRelIndexReadAll<'a>>::Value; + type ValueIteratorType = as CRelIndexReadAll<'a>>::ValueIteratorType; + type AllIteratorType = as CRelIndexReadAll<'a>>::AllIteratorType; + fn c_iter_all(&'a self) -> Self::AllIteratorType { self.0.c_iter_all() } +} + +impl<'a, T: Clone + Hash + Eq> RelFullIndexRead<'a> for EqRelInd0_1<'a, T> { + type Key = as RelFullIndexRead<'a>>::Key; + fn contains_key(&self, key: &Self::Key) -> bool { self.0.contains_key(key) } +} + +impl ToRelIndex0> for ToEqRelInd0_1 { + type RelIndex<'a> + = EqRelInd0_1<'a, T> + where T: 'a; + fn to_rel_index<'a>(&'a self, rel: &'a CEqRelIndCommon) -> Self::RelIndex<'a> { EqRelInd0_1(rel) } + + type RelIndexWrite<'a> + = EqRelInd0_1Write<'a, T> + where T: 'a; + fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut CEqRelIndCommon) -> Self::RelIndexWrite<'a> { + EqRelInd0_1Write(rel) + } + + type CRelIndexWrite<'a> + = EqRelInd0_1CWrite<'a, T> + where + Self: 'a, + CEqRelIndCommon: 'a; + fn to_c_rel_index_write<'a>(&'a self, rel: &'a CEqRelIndCommon) -> Self::CRelIndexWrite<'a> { + EqRelInd0_1CWrite(rel) + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for EqRelInd0<'a, T> { + type Key = (T,); + type Value = (&'a T,); + + type IteratorType = IteratorFromDyn<'a, (&'a T,)>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let _ = self.0.set_of_added(&key.0)?; + let key = key.clone(); + let producer = move || self.0.set_of_added(&key.0).unwrap().map(|x| (x,)); + + Some(IteratorFromDyn::new(producer)) + } + + fn len(&self) -> usize { self.0.unwrap_frozen().combined.elem_ids.len() } +} + +impl<'a, T: Clone + Hash + Eq + Sync> CRelIndexRead<'a> for EqRelInd0<'a, T> { + type Key = (T,); + type Value = (&'a T,); + + type IteratorType = rayon::iter::Map, fn(&T) -> (&T,)>; + + fn c_index_get(&'a self, key: &Self::Key) -> Option { + let set = self.0.c_set_of_added(&key.0)?; + let res: Self::IteratorType = set.map(|x| (x,)); + Some(res) + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRelInd0<'a, T> { + type Key = &'a (T,); + type Value = (&'a T,); + + type ValueIteratorType = Map, for<'aa> fn(&'aa T) -> (&'aa T,)>; + + type AllIteratorType = FlatMap< + std::slice::Iter<'a, HashSet>>, + Map< + Zip, Repeat>>, + for<'aa> fn( + (&'aa T, HashSetIter<'aa, T>), + ) -> (&'aa (T,), Map, for<'bb> fn(&'bb T) -> (&'bb T,)>), + >, + for<'aa> fn( + &'aa HashSet>, + ) -> Map< + Zip, Repeat>>, + for<'cc> fn( + (&'cc T, HashSetIter<'cc, T>), + ) -> (&'cc (T,), Map, for<'dd> fn(&'dd T) -> (&'dd T,)>), + >, + >; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res: Self::AllIteratorType = self.0.unwrap_frozen().combined.sets.iter().flat_map(|s| { + s.iter().zip(std::iter::repeat(s.iter())).map(|(x, s)| (ref_to_singleton_tuple_ref(x), s.map(|x| (x,)))) + }); + res + } +} + +pub struct EqRelInd0CRelIndexReadAllIter<'a, T: Clone + Hash + Eq + Sync>(&'a EqRelPair); + +impl<'a, T: Clone + Hash + Eq + Sync + Send> ParallelIterator for EqRelInd0CRelIndexReadAllIter<'a, T> { + type Item = (&'a (T,), rayon::iter::Map, fn(&T) -> (&T,)>); + + fn drive_unindexed(self, consumer: C) -> C::Result + where C: rayon::iter::plumbing::UnindexedConsumer { + self + .0 + .combined + .sets + .par_iter() + .flat_map(|s| { + s.par_iter().map(|x| { + let vals_iter: rayon::iter::Map, fn(&T) -> (&T,)> = + s.par_iter().map(|x| (x,)); + (ref_to_singleton_tuple_ref(x), vals_iter) + }) + }) + .drive_unindexed(consumer) + } +} + +impl<'a, T: Clone + Hash + Eq + Sync + Send> CRelIndexReadAll<'a> for EqRelInd0<'a, T> { + type Key = &'a (T,); + type Value = (&'a T,); + + type AllIteratorType = EqRelInd0CRelIndexReadAllIter<'a, T>; + + type ValueIteratorType = rayon::iter::Map, fn(&T) -> (&T,)>; + + fn c_iter_all(&'a self) -> Self::AllIteratorType { EqRelInd0CRelIndexReadAllIter(self.0.unwrap_frozen()) } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexWrite for EqRelInd0<'a, T> { + type Key = (T,); + type Value = (T,); + fn index_insert(&mut self, _key: Self::Key, _value: Self::Value) { /* noop */ + } +} + +impl<'a, T: Clone + Hash + Eq> CRelIndexWrite for EqRelInd0<'a, T> { + type Key = (T,); + type Value = (T,); + fn index_insert(&self, _key: Self::Key, _value: Self::Value) { /* noop */ + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexMerge for EqRelInd0<'a, T> { + fn move_index_contents(_from: &mut Self, _to: &mut Self) { /* noop */ + } +} + +pub enum CEqRelIndCommon { + Unfrozen(Mutex>), + Frozen(EqRelPair), +} + +#[derive(Clone)] +pub struct EqRelPair { + pub(crate) old: EqRel, + pub(crate) combined: EqRel, +} + +impl Default for EqRelPair { + fn default() -> Self { Self { old: Default::default(), combined: Default::default() } } +} + +impl Freezable for CEqRelIndCommon {} + +impl CEqRelIndCommon { + fn unwrap_frozen(&self) -> &EqRelPair { + match self { + CEqRelIndCommon::Frozen(old) => old, + CEqRelIndCommon::Unfrozen(_) => panic!("unwrap_frozen() called on Unfrozen"), + } + } + fn unwrap_mut_frozen(&mut self) -> &mut EqRelPair { + match self { + CEqRelIndCommon::Frozen(old) => old, + CEqRelIndCommon::Unfrozen(_) => panic!("unwrap_mut_frozen() called on Unfrozen"), + } + } + fn unwrap_mut_unfrozen(&mut self) -> &mut EqRel { + match self { + CEqRelIndCommon::Unfrozen(uf) => uf.get_mut().unwrap(), + CEqRelIndCommon::Frozen(_) => panic!("unwrap_mut_unfrozen called on Frozen"), + } + } + fn unwrap_unfrozen(&self) -> &Mutex> { + match self { + CEqRelIndCommon::Unfrozen(uf) => uf, + CEqRelIndCommon::Frozen(_) => panic!("unwrap_unfrozen called on Frozen"), + } + } +} + +impl Clone for CEqRelIndCommon { + fn clone(&self) -> Self { + match self { + CEqRelIndCommon::Unfrozen(new) => CEqRelIndCommon::Unfrozen(Mutex::new(new.lock().unwrap().clone())), + CEqRelIndCommon::Frozen(old) => CEqRelIndCommon::Frozen(old.clone()), + } + } +} + +#[derive(Clone)] +pub struct AllAddedParIter<'a, T: Clone + Hash + Eq + Sync>(&'a EqRelPair); + +impl<'a, T: Clone + Hash + Eq + Sync> ParallelIterator for AllAddedParIter<'a, T> { + type Item = (&'a T, &'a T); + + fn drive_unindexed(self, consumer: C) -> C::Result + where C: ascent::rayon::iter::plumbing::UnindexedConsumer { + self.0.combined.c_iter_all().filter(|(x, y)| !self.0.old.contains(x, y)).drive_unindexed(consumer) + } +} + +#[derive(Clone)] +pub struct SetOfAddedParIter<'a, T: Clone + Hash + Eq + Sync> { + set: &'a hashbrown::HashSet>, + old_set: Option<&'a HashSet>>, +} + +impl<'a, T: Clone + Hash + Eq + Sync> ParallelIterator for SetOfAddedParIter<'a, T> { + type Item = &'a T; + + fn drive_unindexed(self, consumer: C) -> C::Result + where C: ascent::rayon::iter::plumbing::UnindexedConsumer { + self.set.par_iter().filter(move |y| !self.old_set.map_or(false, |os| os.contains(*y))).drive_unindexed(consumer) + } +} + +impl CEqRelIndCommon { + pub(crate) fn iter_all_added(&self) -> impl Iterator { + let eqrel = self.unwrap_frozen(); + eqrel.combined.iter_all().filter(|(x, y)| !eqrel.old.contains(x, y)) + } + + pub(crate) fn c_iter_all_added(&self) -> AllAddedParIter<'_, T> + where T: Sync { + let eqrel = self.unwrap_frozen(); + AllAddedParIter(eqrel) + } + + pub(crate) fn set_of_added(&self, x: &T) -> Option> { + let self_ = self.unwrap_frozen(); + let set = self_.combined.set_of(x)?; + // let old_set = self.old.set_of(x).into_iter().flatten(); + let old_set = self_.old.elem_set(x).map(|id| &self_.old.sets[id]); + Some(set.filter(move |y| !old_set.map_or(false, |os| os.contains(*y)))) + } + + pub(crate) fn c_set_of_added(&self, x: &T) -> Option> + where T: Sync { + let self_ = self.unwrap_frozen(); + let set = self_.combined.c_set_of(x)?; + let old_set = self_.old.elem_set(x).map(|id| &self_.old.sets[id]); + Some(SetOfAddedParIter { set, old_set }) + } + + #[allow(dead_code)] + pub(crate) fn added_contains(&self, x: &T, y: &T) -> bool { + let self_ = self.unwrap_frozen(); + self_.combined.contains(x, y) && !self_.old.contains(x, y) + } +} + +impl Default for CEqRelIndCommon { + fn default() -> Self { Self::Frozen(Default::default()) } +} + +impl<'a, T: Clone + Hash + Eq + 'a> RelIndexRead<'a> for CEqRelIndCommon { + type Key = &'a (T, T); + type Value = (); + + type IteratorType = std::iter::Once<()>; + + fn index_get(&'a self, (x, y): &Self::Key) -> Option { + let self_ = self.unwrap_frozen(); + if self_.combined.contains(x, y) && !self_.old.contains(x, y) { Some(std::iter::once(())) } else { None } + } + + fn len(&self) -> usize { + let self_ = self.unwrap_frozen(); + let sample_size = 3; + let sum: usize = self_.combined.sets.iter().take(sample_size).map(|s| s.len().pow(2)).sum(); + let sets_len = self_.combined.sets.len(); + sum * sets_len / sample_size.min(sets_len).max(1) + } +} + +impl<'a, T: Clone + Hash + Eq + Sync + 'a> CRelIndexRead<'a> for CEqRelIndCommon { + type Key = &'a (T, T); + type Value = (); + + type IteratorType = ascent::rayon::iter::Once<()>; + + fn c_index_get(&'a self, (x, y): &Self::Key) -> Option { + let self_ = self.unwrap_frozen(); + if self_.combined.contains(x, y) && !self_.old.contains(x, y) { + Some(ascent::rayon::iter::once(())) + } else { + None + } + } +} + +impl<'a, T: Clone + Hash + Eq + 'a> RelIndexReadAll<'a> for CEqRelIndCommon { + type Key = (&'a T, &'a T); + type Value = (); + + type ValueIteratorType = std::iter::Once<()>; + + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { Box::new(self.iter_all_added().map(|x| (x, std::iter::once(())))) } +} + +impl<'a, T: Clone + Hash + Eq + Sync + 'a> CRelIndexReadAll<'a> for CEqRelIndCommon { + type Key = (&'a T, &'a T); + type Value = (); + + type ValueIteratorType = ascent::rayon::iter::Once<()>; + + type AllIteratorType = ascent::rayon::iter::Map< + AllAddedParIter<'a, T>, + for<'aa, 'bb> fn((&'aa T, &'bb T)) -> ((&'aa T, &'bb T), ascent::rayon::iter::Once<()>), + >; + + fn c_iter_all(&'a self) -> Self::AllIteratorType { + let res: Self::AllIteratorType = self.c_iter_all_added().map(|x| (x, ascent::rayon::iter::once(()))); + res + } +} + +impl<'a, T: Clone + Hash + Eq> RelFullIndexRead<'a> for CEqRelIndCommon { + type Key = (T, T); + + fn contains_key(&'a self, (x, y): &Self::Key) -> bool { + let self_ = self.unwrap_frozen(); + self_.combined.contains(x, y) && !self_.old.contains(x, y) + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexWrite for CEqRelIndCommon { + type Key = (T, T); + type Value = (); + + fn index_insert(&mut self, key: Self::Key, _value: Self::Value) { self.unwrap_mut_unfrozen().add(key.0, key.1); } +} + +impl<'a, T: Clone + Hash + Eq> CRelIndexWrite for CEqRelIndCommon { + type Key = (T, T); + type Value = (); + + fn index_insert(&self, key: Self::Key, _value: Self::Value) { + self.unwrap_unfrozen().lock().unwrap().add(key.0, key.1); + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexMerge for CEqRelIndCommon { + fn move_index_contents(_from: &mut Self, _to: &mut Self) { + unimplemented!("merge_delta_to_total_new_to_delta must be used instead") + } + + fn init(new: &mut Self, _delta: &mut Self, _total: &mut Self) { + *new = Self::Unfrozen(Mutex::new(Default::default())) + } + + fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { + let total = total.unwrap_mut_frozen(); + let delta = delta.unwrap_mut_frozen(); + total.combined = delta.combined.clone(); + delta.old = total.combined.clone(); + + // delta.combined.combine(new.combined.clone()); + delta.combined.combine(std::mem::take(&mut new.unwrap_mut_unfrozen())); + } +} + +pub struct EqRelIndNone<'a, T: Clone + Hash + Eq>(&'a CEqRelIndCommon); + +impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for EqRelIndNone<'a, T> { + type Key = (); + type Value = (&'a T, &'a T); + + type IteratorType = IteratorFromDyn<'a, (&'a T, &'a T)>; + + fn index_get(&'a self, _key: &Self::Key) -> Option { + Some(IteratorFromDyn::new(|| self.0.iter_all_added())) + } + + fn len(&self) -> usize { 1 } +} + +impl<'a, T: Clone + Hash + Eq + Sync> CRelIndexRead<'a> for EqRelIndNone<'a, T> { + type Key = (); + type Value = (&'a T, &'a T); + + type IteratorType = AllAddedParIter<'a, T>; + + fn c_index_get(&'a self, _key: &Self::Key) -> Option { Some(self.0.c_iter_all_added()) } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRelIndNone<'a, T> { + type Key = (); + + type Value = (&'a T, &'a T); + + type ValueIteratorType = IteratorFromDyn<'a, (&'a T, &'a T)>; + + type AllIteratorType = std::option::IntoIter<(Self::Key, Self::ValueIteratorType)>; + + fn iter_all(&'a self) -> Self::AllIteratorType { self.index_get(&()).map(|iter| ((), iter)).into_iter() } +} + +impl<'a, T: Clone + Hash + Eq + Sync> CRelIndexReadAll<'a> for EqRelIndNone<'a, T> { + type Key = (); + type Value = (&'a T, &'a T); + + type ValueIteratorType = AllAddedParIter<'a, T>; + + type AllIteratorType = ascent::rayon::iter::Once<(Self::Key, Self::ValueIteratorType)>; + + fn c_iter_all(&'a self) -> Self::AllIteratorType { ascent::rayon::iter::once(((), self.0.c_iter_all_added())) } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexWrite for EqRelIndNone<'a, T> { + type Key = (); + type Value = (T, T); + fn index_insert(&mut self, _key: Self::Key, _value: Self::Value) { /* noop */ + } +} + +impl<'a, T: Clone + Hash + Eq> CRelIndexWrite for EqRelIndNone<'a, T> { + type Key = (); + type Value = (T, T); + fn index_insert(&self, _key: Self::Key, _value: Self::Value) { /* noop */ + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexMerge for EqRelIndNone<'a, T> { + fn move_index_contents(_from: &mut Self, _to: &mut Self) { /* noop */ + } +} + +// TODO this is not safe, and not required. Get rid of it. +pub(crate) fn ref_to_singleton_tuple_ref(x: &T) -> &(T,) { unsafe { transmute(x) } } + +#[test] +fn test_ref_to_singleton_tuple_ref() { + use std::mem::size_of; + println!("size_of::>(): {}", size_of::>()); + println!("size_of::<(Vec,)>(): {}", size_of::<(Vec,)>()); + + let x = vec![1, 2, 3]; + let x2 = ref_to_singleton_tuple_ref(&x); + assert_eq!(&x, &x2.0); +} diff --git a/byods/ascent-byods-rels/src/eqrel.rs b/byods/ascent-byods-rels/src/eqrel.rs index c1750f2..6ae604f 100644 --- a/byods/ascent-byods-rels/src/eqrel.rs +++ b/byods/ascent-byods-rels/src/eqrel.rs @@ -1,56 +1,56 @@ -//! equivalence relations for Ascent - -#[doc(hidden)] -#[macro_export] -macro_rules! eqrel_rel_codegen { - ( $($tt: tt)* ) => { }; -} -pub use eqrel_rel_codegen as rel_codegen; - -#[doc(hidden)] -#[macro_export] -macro_rules! eqrel_rel { - ($name: ident, ($col1: ty, $col2: ty), $($rest: tt)*) => { - $crate::eqrel_binary::rel!(($col1, $col2), $($rest)*) - }; - ($name: ident, ($col1: ty, $col2: ty, $col3: ty), $($rest: tt)*) => { - $crate::eqrel_ternary::rel!(($col1, $col2, $col3), $($rest)*) - }; -} -pub use eqrel_rel as rel; - -#[doc(hidden)] -#[macro_export] -macro_rules! eqrel_rel_full_ind { - ($name: ident, ($col1: ty, $col2: ty), $($rest: tt)*) => { - $crate::eqrel_binary::rel_full_ind!(($col1, $col2), $($rest)*) - }; - ($name: ident, ($col1: ty, $col2: ty, $col3: ty), $($rest: tt)*) => { - $crate::eqrel_ternary::rel_full_ind!(($col1, $col2, $col3), $($rest)*) - }; -} -pub use eqrel_rel_full_ind as rel_full_ind; - -#[doc(hidden)] -#[macro_export] -macro_rules! eqrel_rel_ind { - ($name: ident, ($col1: ty, $col2: ty), $($rest: tt)*) => { - $crate::eqrel_binary::rel_ind!(($col1, $col2), $($rest)*) - }; - ($name: ident, ($col1: ty, $col2: ty, $col3: ty), $($rest: tt)*) => { - $crate::eqrel_ternary::rel_ind!(($col1, $col2, $col3), $($rest)*) - }; -} -pub use eqrel_rel_ind as rel_ind; - -#[doc(hidden)] -#[macro_export] -macro_rules! eqrel_rel_ind_common { - ($name: ident, ($col1: ty, $col2: ty), $($rest: tt)*) => { - $crate::eqrel_binary::rel_ind_common!(($col1, $col2), $($rest)*) - }; - ($name: ident, ($col1: ty, $col2: ty, $col3: ty), $($rest: tt)*) => { - $crate::eqrel_ternary::rel_ind_common!(($col1, $col2, $col3), $($rest)*) - }; -} -pub use eqrel_rel_ind_common as rel_ind_common; +//! equivalence relations for Ascent + +#[doc(hidden)] +#[macro_export] +macro_rules! eqrel_rel_codegen { + ( $($tt: tt)* ) => { }; +} +pub use eqrel_rel_codegen as rel_codegen; + +#[doc(hidden)] +#[macro_export] +macro_rules! eqrel_rel { + ($name: ident, ($col1: ty, $col2: ty), $($rest: tt)*) => { + $crate::eqrel_binary::rel!(($col1, $col2), $($rest)*) + }; + ($name: ident, ($col1: ty, $col2: ty, $col3: ty), $($rest: tt)*) => { + $crate::eqrel_ternary::rel!(($col1, $col2, $col3), $($rest)*) + }; +} +pub use eqrel_rel as rel; + +#[doc(hidden)] +#[macro_export] +macro_rules! eqrel_rel_full_ind { + ($name: ident, ($col1: ty, $col2: ty), $($rest: tt)*) => { + $crate::eqrel_binary::rel_full_ind!(($col1, $col2), $($rest)*) + }; + ($name: ident, ($col1: ty, $col2: ty, $col3: ty), $($rest: tt)*) => { + $crate::eqrel_ternary::rel_full_ind!(($col1, $col2, $col3), $($rest)*) + }; +} +pub use eqrel_rel_full_ind as rel_full_ind; + +#[doc(hidden)] +#[macro_export] +macro_rules! eqrel_rel_ind { + ($name: ident, ($col1: ty, $col2: ty), $($rest: tt)*) => { + $crate::eqrel_binary::rel_ind!(($col1, $col2), $($rest)*) + }; + ($name: ident, ($col1: ty, $col2: ty, $col3: ty), $($rest: tt)*) => { + $crate::eqrel_ternary::rel_ind!(($col1, $col2, $col3), $($rest)*) + }; +} +pub use eqrel_rel_ind as rel_ind; + +#[doc(hidden)] +#[macro_export] +macro_rules! eqrel_rel_ind_common { + ($name: ident, ($col1: ty, $col2: ty), $($rest: tt)*) => { + $crate::eqrel_binary::rel_ind_common!(($col1, $col2), $($rest)*) + }; + ($name: ident, ($col1: ty, $col2: ty, $col3: ty), $($rest: tt)*) => { + $crate::eqrel_ternary::rel_ind_common!(($col1, $col2, $col3), $($rest)*) + }; +} +pub use eqrel_rel_ind_common as rel_ind_common; diff --git a/byods/ascent-byods-rels/src/eqrel_binary.rs b/byods/ascent-byods-rels/src/eqrel_binary.rs index 171765d..38de657 100644 --- a/byods/ascent-byods-rels/src/eqrel_binary.rs +++ b/byods/ascent-byods-rels/src/eqrel_binary.rs @@ -1,76 +1,75 @@ -#[doc(hidden)] -#[macro_export] -macro_rules! eqrel_binary_rel { - (($col1: ty, $col2: ty), $indices: expr, ser, ()) => { - // ::std::vec::Vec<($col1, $col2)> - $crate::fake_vec::FakeVec<($col1, $col2)> - }; - - // par: - (($col1: ty, $col2: ty), $indices: expr, par, ()) => { - // ::std::vec::Vec<($col1, $col2)> - $crate::fake_vec::FakeVec<($col1, $col2)> - }; -} -pub use eqrel_binary_rel as rel; - -#[doc(hidden)] -#[macro_export] -macro_rules! eqrel_binary_rel_full_ind { - (($col1: ty, $col2: ty), $indices: expr, ser, (), $key: ty, $val: ty) => { - $crate::eqrel_ind::ToEqRelInd0_1<$col1> - }; - - // par: - (($col1: ty, $col2: ty), $indices: expr, par, (), $key: ty, $val: ty) => { - $crate::ceqrel_ind::ToEqRelInd0_1<$col1> - }; -} -pub use eqrel_binary_rel_full_ind as rel_full_ind; - -#[doc(hidden)] -#[macro_export] -macro_rules! eqrel_binary_rel_ind { - (($col1: ty, $col2: ty), $indices: expr, ser, (), [0], $key: ty, $val: ty) => { - $crate::eqrel_ind::ToEqRelInd0<$col1> - }; - (($col1: ty, $col2: ty), $indices: expr, ser, (), [1], $key: ty, $val: ty) => { - $crate::eqrel_ind::ToEqRelInd0<$col1> - }; - (($col1: ty, $col2: ty), $indices: expr, ser, (), [], $key: ty, $val: ty) => { - $crate::eqrel_ind::ToEqRelIndNone<$col1> - }; - - // par: - (($col1: ty, $col2: ty), $indices: expr, par, (), [0], $key: ty, $val: ty) => { - $crate::ceqrel_ind::ToEqRelInd0<$col1> - }; - (($col1: ty, $col2: ty), $indices: expr, par, (), [1], $key: ty, $val: ty) => { - $crate::ceqrel_ind::ToEqRelInd0<$col1> - }; - (($col1: ty, $col2: ty), $indices: expr, par, (), [], $key: ty, $val: ty) => { - $crate::ceqrel_ind::ToEqRelIndNone<$col1> - }; -} -pub use eqrel_binary_rel_ind as rel_ind; - -#[doc(hidden)] -#[macro_export] -macro_rules! eqrel_binary_rel_ind_common { - (($col1: ty, $col2: ty), $indices: expr, ser, ()) => { - $crate::eqrel_ind::EqRelIndCommon<$col1> - }; - - // par: - (($col1: ty, $col2: ty), $indices: expr, par, ()) => { - $crate::ceqrel_ind::CEqRelIndCommon<$col1> - }; -} -pub use eqrel_binary_rel_ind_common as rel_ind_common; - - -fn _test_macros() { - let _x: rel!((u32, u32), [[0,1], [0]], ser, ()); - let _full_ind: rel_full_ind!((u32, u32), [[0, 1], [0]], ser, (), (u32, u32), ()); - let _ind_0: rel_ind!((u32, u32), [[0, 1], [0]], ser, (), [0], (u32,) , (u32,)); -} +#[doc(hidden)] +#[macro_export] +macro_rules! eqrel_binary_rel { + (($col1: ty, $col2: ty), $indices: expr, ser, ()) => { + // ::std::vec::Vec<($col1, $col2)> + $crate::fake_vec::FakeVec<($col1, $col2)> + }; + + // par: + (($col1: ty, $col2: ty), $indices: expr, par, ()) => { + // ::std::vec::Vec<($col1, $col2)> + $crate::fake_vec::FakeVec<($col1, $col2)> + }; +} +pub use eqrel_binary_rel as rel; + +#[doc(hidden)] +#[macro_export] +macro_rules! eqrel_binary_rel_full_ind { + (($col1: ty, $col2: ty), $indices: expr, ser, (), $key: ty, $val: ty) => { + $crate::eqrel_ind::ToEqRelInd0_1<$col1> + }; + + // par: + (($col1: ty, $col2: ty), $indices: expr, par, (), $key: ty, $val: ty) => { + $crate::ceqrel_ind::ToEqRelInd0_1<$col1> + }; +} +pub use eqrel_binary_rel_full_ind as rel_full_ind; + +#[doc(hidden)] +#[macro_export] +macro_rules! eqrel_binary_rel_ind { + (($col1: ty, $col2: ty), $indices: expr, ser, (), [0], $key: ty, $val: ty) => { + $crate::eqrel_ind::ToEqRelInd0<$col1> + }; + (($col1: ty, $col2: ty), $indices: expr, ser, (), [1], $key: ty, $val: ty) => { + $crate::eqrel_ind::ToEqRelInd0<$col1> + }; + (($col1: ty, $col2: ty), $indices: expr, ser, (), [], $key: ty, $val: ty) => { + $crate::eqrel_ind::ToEqRelIndNone<$col1> + }; + + // par: + (($col1: ty, $col2: ty), $indices: expr, par, (), [0], $key: ty, $val: ty) => { + $crate::ceqrel_ind::ToEqRelInd0<$col1> + }; + (($col1: ty, $col2: ty), $indices: expr, par, (), [1], $key: ty, $val: ty) => { + $crate::ceqrel_ind::ToEqRelInd0<$col1> + }; + (($col1: ty, $col2: ty), $indices: expr, par, (), [], $key: ty, $val: ty) => { + $crate::ceqrel_ind::ToEqRelIndNone<$col1> + }; +} +pub use eqrel_binary_rel_ind as rel_ind; + +#[doc(hidden)] +#[macro_export] +macro_rules! eqrel_binary_rel_ind_common { + (($col1: ty, $col2: ty), $indices: expr, ser, ()) => { + $crate::eqrel_ind::EqRelIndCommon<$col1> + }; + + // par: + (($col1: ty, $col2: ty), $indices: expr, par, ()) => { + $crate::ceqrel_ind::CEqRelIndCommon<$col1> + }; +} +pub use eqrel_binary_rel_ind_common as rel_ind_common; + +fn _test_macros() { + let _x: rel!((u32, u32), [[0, 1], [0]], ser, ()); + let _full_ind: rel_full_ind!((u32, u32), [[0, 1], [0]], ser, (), (u32, u32), ()); + let _ind_0: rel_ind!((u32, u32), [[0, 1], [0]], ser, (), [0], (u32,), (u32,)); +} diff --git a/byods/ascent-byods-rels/src/eqrel_ind.rs b/byods/ascent-byods-rels/src/eqrel_ind.rs index 514e3b4..f3180df 100644 --- a/byods/ascent-byods-rels/src/eqrel_ind.rs +++ b/byods/ascent-byods-rels/src/eqrel_ind.rs @@ -1,364 +1,361 @@ -use hashbrown::HashSet; -use std::hash::{Hash, BuildHasherDefault}; -use std::iter::{FlatMap, Map, Repeat, Zip}; -use std::marker::PhantomData; -use std::mem::transmute; -use std::rc::Rc; - -use ascent::internal::{ - RelFullIndexRead, RelFullIndexWrite, RelIndexMerge, RelIndexRead, RelIndexReadAll, RelIndexWrite, ToRelIndex, -}; -use rustc_hash::FxHasher; - -#[cfg(test)] -use { - itertools::Itertools -}; - -use crate::iterator_from_dyn::IteratorFromDyn; -use crate::union_find::EqRel; - -use hashbrown::hash_set::Iter as HashSetIter; - - -pub struct EqRelInd0<'a, T: Clone + Hash + Eq>(pub(crate) &'a EqRelIndCommon); - -#[test] -fn test_eq_rel_ind_0_iter_all() { - let mut eq_rel_old = EqRel::default(); - for x in 1..=10 { eq_rel_old.add(1, x); }; - for x in 101..=110 { eq_rel_old.add(101, x); }; - - let mut eq_rel_new = EqRel::default(); - eq_rel_new.add(1, 110); - eq_rel_new.combine(eq_rel_old.clone()); - - let eq_rel_full_ind = EqRelIndCommon { old: Rc::new(eq_rel_old), combined: Rc::new(eq_rel_new) }; - let eq_rel_ind_0 = EqRelInd0(&eq_rel_full_ind); - let iter = eq_rel_ind_0.iter_all().map(|x| (x.0, x.1.collect_vec())).collect_vec(); - - for x in [1, 9, 103] { - let iter_at_x = &iter.iter().find(|y| y.0.0 == x).unwrap().1; - - println!("x: {}, iter_at_x: {:?}", x, iter_at_x); - assert_eq!(iter_at_x.len(), 20); - assert_eq!(iter_at_x.into_iter().map(|x| *x.0).collect::>(), (1..=10).chain(101..=110).collect()); - } -} -pub struct ToEqRelIndNone(PhantomData); -impl Default for ToEqRelIndNone { - fn default() -> Self { Self(PhantomData) } -} -impl ToRelIndex> for ToEqRelIndNone { - type RelIndex<'a> = EqRelIndNone<'a, T> where T: 'a; - fn to_rel_index<'a>(&'a self, rel: &'a EqRelIndCommon) -> Self::RelIndex<'a> { EqRelIndNone(rel) } - - type RelIndexWrite<'a> = EqRelIndNone<'a, T> where T: 'a; - fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut EqRelIndCommon) -> Self::RelIndexWrite<'a> { EqRelIndNone(rel) } -} - - -pub struct ToEqRelInd0(PhantomData); - -impl Default for ToEqRelInd0 { - fn default() -> Self { Self(Default::default()) } -} - -impl ToRelIndex> for ToEqRelInd0 { - type RelIndex<'a> = EqRelInd0<'a, T> where T: 'a; - fn to_rel_index<'a>(&'a self, rel: &'a EqRelIndCommon) -> Self::RelIndex<'a> { EqRelInd0(rel) } - - type RelIndexWrite<'a> = EqRelInd0<'a, T> where T: 'a; - fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut EqRelIndCommon) -> Self::RelIndexWrite<'a> { EqRelInd0(rel) } -} - -pub struct ToEqRelInd0_1(PhantomData); - -impl Default for ToEqRelInd0_1 { - fn default() -> Self { Self(Default::default()) } -} - -pub struct EqRelInd0_1<'a, T: Clone + Hash + Eq>(&'a EqRelIndCommon); -pub struct EqRelInd0_1Write<'a, T: Clone + Hash + Eq>(&'a mut EqRelIndCommon); - -impl <'a, T: Clone + Hash + Eq> RelIndexWrite for EqRelInd0_1Write<'a, T> { - type Key = (T, T); - type Value = (); - - fn index_insert(&mut self, key: Self::Key, value: Self::Value) { - self.0.index_insert(key, value) - } -} - -impl <'a, T: Clone + Hash + Eq> RelIndexMerge for EqRelInd0_1Write<'a, T> { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { - //noop - } -} - -impl RelFullIndexWrite for EqRelIndCommon { - type Key = (T, T); - type Value = (); - - fn insert_if_not_present(&mut self, key: &Self::Key, _v: Self::Value) -> bool { - Rc::get_mut(&mut self.combined).unwrap().add(key.0.clone(), key.1.clone()) - } -} - - -impl <'a, T: Clone + Hash + Eq> RelFullIndexWrite for EqRelInd0_1Write<'a, T> { - type Key = as RelFullIndexWrite>::Key; - type Value = as RelFullIndexWrite>::Value; - fn insert_if_not_present(&mut self, key: &Self::Key, v: Self::Value) -> bool { - self.0.insert_if_not_present(key, v) - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for EqRelInd0_1<'a, T> { - type Key = as RelIndexRead<'a>>::Key; - type Value = as RelIndexRead<'a>>::Value; - type IteratorType = as RelIndexRead<'a>>::IteratorType; - - fn index_get(&'a self, key: &Self::Key) -> Option { - self.0.index_get(key) - } - - fn len(&self) -> usize { - self.0.len() - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRelInd0_1<'a, T> { - type Key = as RelIndexReadAll<'a>>::Key; - type Value = as RelIndexReadAll<'a>>::Value; - type ValueIteratorType = as RelIndexReadAll<'a>>::ValueIteratorType; - type AllIteratorType = as RelIndexReadAll<'a>>::AllIteratorType; - fn iter_all(&'a self) -> Self::AllIteratorType { - self.0.iter_all() - } -} - -impl<'a, T: Clone + Hash + Eq> RelFullIndexRead<'a> for EqRelInd0_1<'a, T> { - type Key = as RelFullIndexRead<'a>>::Key; - fn contains_key(&self, key: &Self::Key) -> bool { - self.0.contains_key(key) - } -} - -impl ToRelIndex> for ToEqRelInd0_1 { - type RelIndex<'a> = EqRelInd0_1<'a, T> where T: 'a; - fn to_rel_index<'a>(&'a self, rel: &'a EqRelIndCommon) -> Self::RelIndex<'a> { EqRelInd0_1(rel) } - - type RelIndexWrite<'a> = EqRelInd0_1Write<'a, T> where T: 'a; - fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut EqRelIndCommon) -> Self::RelIndexWrite<'a> { EqRelInd0_1Write(rel) } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for EqRelInd0<'a, T> { - type Key = (T,); - type Value = (&'a T,); - - type IteratorType = IteratorFromDyn<'a, (&'a T,)>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let _ = self.0.set_of_added(&key.0)?; - let key = key.clone(); - let producer = move || self.0.set_of_added(&key.0).unwrap().map(|x| (x,)); - - Some(IteratorFromDyn::new(producer)) - } - - fn len(&self) -> usize { - self.0.combined.elem_ids.len() - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRelInd0<'a, T> { - type Key = &'a (T,); - type Value = (&'a T,); - - type ValueIteratorType = Map, for<'aa> fn(&'aa T) -> (&'aa T,)>; - - type AllIteratorType = FlatMap>>, Map, Repeat>>, for<'aa> fn((&'aa T, HashSetIter<'aa, T>),) -> (&'aa (T,), Map, for<'bb> fn(&'bb T) -> (&'bb T,)>),>, for<'aa> fn(&'aa HashSet>, ) -> Map< - Zip, Repeat>>, - for<'cc> fn((&'cc T, HashSetIter<'cc, T>),) -> (&'cc (T,), Map, for<'dd> fn(&'dd T) -> (&'dd T,)>),>, >; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res: Self::AllIteratorType = self.0.combined.sets.iter().flat_map(|s| { - s.iter().zip(std::iter::repeat(s.iter())).map(|(x, s)| (ref_to_singleton_tuple_ref(x), s.map(|x| (x,)))) - }); - res - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexWrite for EqRelInd0<'a, T> { - type Key = (T,); - type Value = (T,); - fn index_insert(&mut self, _key: Self::Key, _value: Self::Value) { - // noop - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexMerge for EqRelInd0<'a, T> { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { - //noop - } -} - -#[derive(Clone)] -pub struct EqRelIndCommon { - pub(crate) old: Rc>, - pub(crate) combined: Rc>, -} - -impl EqRelIndCommon { - - pub fn iter_all_added(&self) -> impl Iterator { - self.combined.iter_all().filter(|(x, y)| !self.old.contains(x, y)) - } - - pub(crate) fn set_of_added(&self, x: &T) -> Option> { - let set = self.combined.set_of(x)?; - // let old_set = self.old.set_of(x).into_iter().flatten(); - let old_set = self.old.elem_set(x).map(|id| &self.old.sets[id]); - Some(set.filter(move |y| !old_set.map_or(false, |os| os.contains(*y)))) - } - - pub(crate) fn added_contains(&self, x: &T, y: &T) -> bool { - self.combined.contains(x, y) && !self.old.contains(x, y) - } - - pub fn count_exact(&self) -> usize { - // old must be a subset of combined - self.combined.count_exact() - self.old.count_exact() - } - -} - -impl Default for EqRelIndCommon { - fn default() -> Self { - Self { old: Default::default(), combined: Default::default() } - } -} - -impl<'a, T: Clone + Hash + Eq + 'a> RelIndexRead<'a> for EqRelIndCommon { - type Key = (T, T); - type Value = (); - - type IteratorType = std::iter::Once<()>; - - fn index_get(&'a self, (x, y): &Self::Key) -> Option { - if self.combined.contains(x, y) && !self.old.contains(x, y) { - Some(std::iter::once(())) - } else { - None - } - } - - fn len(&self) -> usize { - let sample_size = 3; - let sum: usize = self.combined.sets.iter().take(sample_size).map(|s| s.len().pow(2)).sum(); - let sets_len = self.combined.sets.len(); - sum * sets_len / sample_size.min(sets_len).max(1) - } -} - -impl<'a, T: Clone + Hash + Eq + 'a> RelIndexReadAll<'a> for EqRelIndCommon { - type Key = (&'a T, &'a T); - type Value = (); - - type ValueIteratorType = std::iter::Once<()>; - - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.iter_all_added().map(|x| (x, std::iter::once(())))) - } -} - -impl<'a, T: Clone + Hash + Eq> RelFullIndexRead<'a> for EqRelIndCommon { - type Key = (T, T); - - fn contains_key(&'a self, (x, y): &Self::Key) -> bool { - self.combined.contains(x, y) && !self.old.contains(x, y) - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexWrite for EqRelIndCommon { - type Key = (T, T); - type Value = (); - - fn index_insert(&mut self, key: Self::Key, _value: Self::Value) { - Rc::get_mut(&mut self.combined).unwrap().add(key.0, key.1); - } - -} - -impl<'a, T: Clone + Hash + Eq> RelIndexMerge for EqRelIndCommon { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { - unimplemented!("merge_delta_to_total_new_to_delta must be used instead") - } - - fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { - total.combined = delta.combined.clone(); - delta.old = total.combined.clone(); - - // delta.combined.combine(new.combined.clone()); - Rc::make_mut(&mut delta.combined).combine(std::mem::take(Rc::get_mut(&mut new.combined).unwrap())); - } -} - -pub struct EqRelIndNone<'a, T: Clone + Hash + Eq>(&'a EqRelIndCommon); - -impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for EqRelIndNone<'a, T> { - type Key = (); - - type Value = (&'a T, &'a T); - - type IteratorType = IteratorFromDyn<'a, (&'a T, &'a T)>; - - fn index_get(&'a self, _key: &Self::Key) -> Option { - Some(IteratorFromDyn::new(|| self.0.iter_all_added())) - } - - fn len(&self) -> usize { - 1 - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRelIndNone<'a, T> { - type Key = (); - - type Value = (&'a T, &'a T); - - type ValueIteratorType = IteratorFromDyn<'a, (&'a T, &'a T)>; - - type AllIteratorType = std::option::IntoIter<(Self::Key, Self::ValueIteratorType)>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - self.index_get(&()).map(|iter| ((), iter)).into_iter() - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexWrite for EqRelIndNone<'a, T> { - type Key = (); - type Value = (T, T); - fn index_insert(&mut self, _key: Self::Key, _value: Self::Value) { /* noop */ } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexMerge for EqRelIndNone<'a, T> { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { /* noop */ } -} - - -pub(crate) fn ref_to_singleton_tuple_ref(x: &T) -> &(T,) { unsafe { transmute(x) } } - -#[test] -fn test_ref_to_singleton_tuple_ref() { - use std::mem::size_of; - println!("size_of::>(): {}", size_of::>()); - println!("size_of::<(Vec,)>(): {}", size_of::<(Vec,)>()); - - let x = vec![1, 2, 3]; - let x2 = ref_to_singleton_tuple_ref(&x); - assert_eq!(&x, &x2.0); -} +use std::hash::{BuildHasherDefault, Hash}; +use std::iter::{FlatMap, Map, Repeat, Zip}; +use std::marker::PhantomData; +use std::mem::transmute; +use std::rc::Rc; + +use ascent::internal::{ + RelFullIndexRead, RelFullIndexWrite, RelIndexMerge, RelIndexRead, RelIndexReadAll, RelIndexWrite, ToRelIndex, +}; +use hashbrown::HashSet; +use hashbrown::hash_set::Iter as HashSetIter; +#[cfg(test)] +use itertools::Itertools; +use rustc_hash::FxHasher; + +use crate::iterator_from_dyn::IteratorFromDyn; +use crate::union_find::EqRel; + +pub struct EqRelInd0<'a, T: Clone + Hash + Eq>(pub(crate) &'a EqRelIndCommon); + +#[test] +fn test_eq_rel_ind_0_iter_all() { + let mut eq_rel_old = EqRel::default(); + for x in 1..=10 { + eq_rel_old.add(1, x); + } + for x in 101..=110 { + eq_rel_old.add(101, x); + } + + let mut eq_rel_new = EqRel::default(); + eq_rel_new.add(1, 110); + eq_rel_new.combine(eq_rel_old.clone()); + + let eq_rel_full_ind = EqRelIndCommon { old: Rc::new(eq_rel_old), combined: Rc::new(eq_rel_new) }; + let eq_rel_ind_0 = EqRelInd0(&eq_rel_full_ind); + let iter = eq_rel_ind_0.iter_all().map(|x| (x.0, x.1.collect_vec())).collect_vec(); + + for x in [1, 9, 103] { + let iter_at_x = &iter.iter().find(|y| y.0.0 == x).unwrap().1; + + println!("x: {}, iter_at_x: {:?}", x, iter_at_x); + assert_eq!(iter_at_x.len(), 20); + assert_eq!(iter_at_x.into_iter().map(|x| *x.0).collect::>(), (1..=10).chain(101..=110).collect()); + } +} +pub struct ToEqRelIndNone(PhantomData); +impl Default for ToEqRelIndNone { + fn default() -> Self { Self(PhantomData) } +} +impl ToRelIndex> for ToEqRelIndNone { + type RelIndex<'a> + = EqRelIndNone<'a, T> + where T: 'a; + fn to_rel_index<'a>(&'a self, rel: &'a EqRelIndCommon) -> Self::RelIndex<'a> { EqRelIndNone(rel) } + + type RelIndexWrite<'a> + = EqRelIndNone<'a, T> + where T: 'a; + fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut EqRelIndCommon) -> Self::RelIndexWrite<'a> { + EqRelIndNone(rel) + } +} + +pub struct ToEqRelInd0(PhantomData); + +impl Default for ToEqRelInd0 { + fn default() -> Self { Self(Default::default()) } +} + +impl ToRelIndex> for ToEqRelInd0 { + type RelIndex<'a> + = EqRelInd0<'a, T> + where T: 'a; + fn to_rel_index<'a>(&'a self, rel: &'a EqRelIndCommon) -> Self::RelIndex<'a> { EqRelInd0(rel) } + + type RelIndexWrite<'a> + = EqRelInd0<'a, T> + where T: 'a; + fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut EqRelIndCommon) -> Self::RelIndexWrite<'a> { EqRelInd0(rel) } +} + +pub struct ToEqRelInd0_1(PhantomData); + +impl Default for ToEqRelInd0_1 { + fn default() -> Self { Self(Default::default()) } +} + +pub struct EqRelInd0_1<'a, T: Clone + Hash + Eq>(&'a EqRelIndCommon); +pub struct EqRelInd0_1Write<'a, T: Clone + Hash + Eq>(&'a mut EqRelIndCommon); + +impl<'a, T: Clone + Hash + Eq> RelIndexWrite for EqRelInd0_1Write<'a, T> { + type Key = (T, T); + type Value = (); + + fn index_insert(&mut self, key: Self::Key, value: Self::Value) { self.0.index_insert(key, value) } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexMerge for EqRelInd0_1Write<'a, T> { + fn move_index_contents(_from: &mut Self, _to: &mut Self) { + //noop + } +} + +impl RelFullIndexWrite for EqRelIndCommon { + type Key = (T, T); + type Value = (); + + fn insert_if_not_present(&mut self, key: &Self::Key, _v: Self::Value) -> bool { + Rc::get_mut(&mut self.combined).unwrap().add(key.0.clone(), key.1.clone()) + } +} + +impl<'a, T: Clone + Hash + Eq> RelFullIndexWrite for EqRelInd0_1Write<'a, T> { + type Key = as RelFullIndexWrite>::Key; + type Value = as RelFullIndexWrite>::Value; + fn insert_if_not_present(&mut self, key: &Self::Key, v: Self::Value) -> bool { self.0.insert_if_not_present(key, v) } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for EqRelInd0_1<'a, T> { + type Key = as RelIndexRead<'a>>::Key; + type Value = as RelIndexRead<'a>>::Value; + type IteratorType = as RelIndexRead<'a>>::IteratorType; + + fn index_get(&'a self, key: &Self::Key) -> Option { self.0.index_get(key) } + + fn len(&self) -> usize { self.0.len() } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRelInd0_1<'a, T> { + type Key = as RelIndexReadAll<'a>>::Key; + type Value = as RelIndexReadAll<'a>>::Value; + type ValueIteratorType = as RelIndexReadAll<'a>>::ValueIteratorType; + type AllIteratorType = as RelIndexReadAll<'a>>::AllIteratorType; + fn iter_all(&'a self) -> Self::AllIteratorType { self.0.iter_all() } +} + +impl<'a, T: Clone + Hash + Eq> RelFullIndexRead<'a> for EqRelInd0_1<'a, T> { + type Key = as RelFullIndexRead<'a>>::Key; + fn contains_key(&self, key: &Self::Key) -> bool { self.0.contains_key(key) } +} + +impl ToRelIndex> for ToEqRelInd0_1 { + type RelIndex<'a> + = EqRelInd0_1<'a, T> + where T: 'a; + fn to_rel_index<'a>(&'a self, rel: &'a EqRelIndCommon) -> Self::RelIndex<'a> { EqRelInd0_1(rel) } + + type RelIndexWrite<'a> + = EqRelInd0_1Write<'a, T> + where T: 'a; + fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut EqRelIndCommon) -> Self::RelIndexWrite<'a> { + EqRelInd0_1Write(rel) + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for EqRelInd0<'a, T> { + type Key = (T,); + type Value = (&'a T,); + + type IteratorType = IteratorFromDyn<'a, (&'a T,)>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let _ = self.0.set_of_added(&key.0)?; + let key = key.clone(); + let producer = move || self.0.set_of_added(&key.0).unwrap().map(|x| (x,)); + + Some(IteratorFromDyn::new(producer)) + } + + fn len(&self) -> usize { self.0.combined.elem_ids.len() } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRelInd0<'a, T> { + type Key = &'a (T,); + type Value = (&'a T,); + + type ValueIteratorType = Map, for<'aa> fn(&'aa T) -> (&'aa T,)>; + + type AllIteratorType = FlatMap< + std::slice::Iter<'a, HashSet>>, + Map< + Zip, Repeat>>, + for<'aa> fn( + (&'aa T, HashSetIter<'aa, T>), + ) -> (&'aa (T,), Map, for<'bb> fn(&'bb T) -> (&'bb T,)>), + >, + for<'aa> fn( + &'aa HashSet>, + ) -> Map< + Zip, Repeat>>, + for<'cc> fn( + (&'cc T, HashSetIter<'cc, T>), + ) -> (&'cc (T,), Map, for<'dd> fn(&'dd T) -> (&'dd T,)>), + >, + >; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res: Self::AllIteratorType = self.0.combined.sets.iter().flat_map(|s| { + s.iter().zip(std::iter::repeat(s.iter())).map(|(x, s)| (ref_to_singleton_tuple_ref(x), s.map(|x| (x,)))) + }); + res + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexWrite for EqRelInd0<'a, T> { + type Key = (T,); + type Value = (T,); + fn index_insert(&mut self, _key: Self::Key, _value: Self::Value) { + // noop + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexMerge for EqRelInd0<'a, T> { + fn move_index_contents(_from: &mut Self, _to: &mut Self) { + //noop + } +} + +#[derive(Clone)] +pub struct EqRelIndCommon { + pub(crate) old: Rc>, + pub(crate) combined: Rc>, +} + +impl EqRelIndCommon { + pub fn iter_all_added(&self) -> impl Iterator { + self.combined.iter_all().filter(|(x, y)| !self.old.contains(x, y)) + } + + pub(crate) fn set_of_added(&self, x: &T) -> Option> { + let set = self.combined.set_of(x)?; + // let old_set = self.old.set_of(x).into_iter().flatten(); + let old_set = self.old.elem_set(x).map(|id| &self.old.sets[id]); + Some(set.filter(move |y| !old_set.map_or(false, |os| os.contains(*y)))) + } + + pub(crate) fn added_contains(&self, x: &T, y: &T) -> bool { + self.combined.contains(x, y) && !self.old.contains(x, y) + } + + pub fn count_exact(&self) -> usize { + // old must be a subset of combined + self.combined.count_exact() - self.old.count_exact() + } +} + +impl Default for EqRelIndCommon { + fn default() -> Self { Self { old: Default::default(), combined: Default::default() } } +} + +impl<'a, T: Clone + Hash + Eq + 'a> RelIndexRead<'a> for EqRelIndCommon { + type Key = (T, T); + type Value = (); + + type IteratorType = std::iter::Once<()>; + + fn index_get(&'a self, (x, y): &Self::Key) -> Option { + if self.combined.contains(x, y) && !self.old.contains(x, y) { Some(std::iter::once(())) } else { None } + } + + fn len(&self) -> usize { + let sample_size = 3; + let sum: usize = self.combined.sets.iter().take(sample_size).map(|s| s.len().pow(2)).sum(); + let sets_len = self.combined.sets.len(); + sum * sets_len / sample_size.min(sets_len).max(1) + } +} + +impl<'a, T: Clone + Hash + Eq + 'a> RelIndexReadAll<'a> for EqRelIndCommon { + type Key = (&'a T, &'a T); + type Value = (); + + type ValueIteratorType = std::iter::Once<()>; + + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { Box::new(self.iter_all_added().map(|x| (x, std::iter::once(())))) } +} + +impl<'a, T: Clone + Hash + Eq> RelFullIndexRead<'a> for EqRelIndCommon { + type Key = (T, T); + + fn contains_key(&'a self, (x, y): &Self::Key) -> bool { self.combined.contains(x, y) && !self.old.contains(x, y) } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexWrite for EqRelIndCommon { + type Key = (T, T); + type Value = (); + + fn index_insert(&mut self, key: Self::Key, _value: Self::Value) { + Rc::get_mut(&mut self.combined).unwrap().add(key.0, key.1); + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexMerge for EqRelIndCommon { + fn move_index_contents(_from: &mut Self, _to: &mut Self) { + unimplemented!("merge_delta_to_total_new_to_delta must be used instead") + } + + fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { + total.combined = delta.combined.clone(); + delta.old = total.combined.clone(); + + // delta.combined.combine(new.combined.clone()); + Rc::make_mut(&mut delta.combined).combine(std::mem::take(Rc::get_mut(&mut new.combined).unwrap())); + } +} + +pub struct EqRelIndNone<'a, T: Clone + Hash + Eq>(&'a EqRelIndCommon); + +impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for EqRelIndNone<'a, T> { + type Key = (); + + type Value = (&'a T, &'a T); + + type IteratorType = IteratorFromDyn<'a, (&'a T, &'a T)>; + + fn index_get(&'a self, _key: &Self::Key) -> Option { + Some(IteratorFromDyn::new(|| self.0.iter_all_added())) + } + + fn len(&self) -> usize { 1 } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRelIndNone<'a, T> { + type Key = (); + + type Value = (&'a T, &'a T); + + type ValueIteratorType = IteratorFromDyn<'a, (&'a T, &'a T)>; + + type AllIteratorType = std::option::IntoIter<(Self::Key, Self::ValueIteratorType)>; + + fn iter_all(&'a self) -> Self::AllIteratorType { self.index_get(&()).map(|iter| ((), iter)).into_iter() } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexWrite for EqRelIndNone<'a, T> { + type Key = (); + type Value = (T, T); + fn index_insert(&mut self, _key: Self::Key, _value: Self::Value) { /* noop */ + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexMerge for EqRelIndNone<'a, T> { + fn move_index_contents(_from: &mut Self, _to: &mut Self) { /* noop */ + } +} + +pub(crate) fn ref_to_singleton_tuple_ref(x: &T) -> &(T,) { unsafe { transmute(x) } } + +#[test] +fn test_ref_to_singleton_tuple_ref() { + use std::mem::size_of; + println!("size_of::>(): {}", size_of::>()); + println!("size_of::<(Vec,)>(): {}", size_of::<(Vec,)>()); + + let x = vec![1, 2, 3]; + let x2 = ref_to_singleton_tuple_ref(&x); + assert_eq!(&x, &x2.0); +} diff --git a/byods/ascent-byods-rels/src/eqrel_ternary.rs b/byods/ascent-byods-rels/src/eqrel_ternary.rs index 0d35657..beaef4d 100644 --- a/byods/ascent-byods-rels/src/eqrel_ternary.rs +++ b/byods/ascent-byods-rels/src/eqrel_ternary.rs @@ -1,470 +1,513 @@ -#[doc(hidden)] -#[macro_export] -macro_rules! eqrel_ternary_rel { - (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, ()) => { - $crate::fake_vec::FakeVec<($col0, $col1, $col2)> - }; -} -pub use eqrel_ternary_rel as rel; - -#[doc(hidden)] -#[macro_export] -macro_rules! eqrel_rel_ternary_full_ind { - (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), $key: ty, $val: ty) => { - $crate::eqrel_ternary::ToEqRel2IndFull<$col0, $col1> - }; -} -pub use eqrel_rel_ternary_full_ind as rel_full_ind; - -#[doc(hidden)] -#[macro_export] -macro_rules! eqrel_ternary_rel_ind { - (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [], $key: ty, $val: ty) => { - $crate::eqrel_ternary::ToEqRel2IndNone<$col0, $col1> - }; - (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0], $key: ty, $val: ty) => { - $crate::eqrel_ternary::ToEqRel2Ind0<$col0, $col1> - }; - (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [1], $key: ty, $val: ty) => { - $crate::eqrel_ternary::ToEqRel2Ind1<$col0, $col1> - }; - (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [2], $key: ty, $val: ty) => { - $crate::eqrel_ternary::ToEqRel2Ind2<$col0, $col1> - }; - (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 1], $key: ty, $val: ty) => { - $crate::eqrel_ternary::ToEqRel2Ind0_1<$col0, $col1> - }; - (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 2], $key: ty, $val: ty) => { - $crate::eqrel_ternary::ToEqRel2Ind0_1<$col0, $col1> - }; - (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [1, 2], $key: ty, $val: ty) => { - $crate::eqrel_ternary::ToEqRel2Ind1_2<$col0, $col1> - }; - (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 1, 2], $key: ty, $val: ty) => { - $crate::eqrel_ternary::ToEqRel2IndFull<$col0, $col1> - }; -} -pub use eqrel_ternary_rel_ind as rel_ind; - -#[doc(hidden)] -#[macro_export] -macro_rules! eqrel_ternary_rel_ind_common { - (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, ()) => { - $crate::eqrel_ternary::EqRel2IndCommonWithReverse<$col0, $col1> - }; -} -pub use eqrel_ternary_rel_ind_common as rel_ind_common; - - -use ascent::internal::{RelIndexRead, RelIndexReadAll, RelIndexWrite, RelIndexMerge, RelFullIndexWrite, RelFullIndexRead}; -use ascent::internal::ToRelIndex; - -use itertools::Itertools; -use rustc_hash::FxHasher; -use crate::eqrel_ind::{EqRelIndCommon, ref_to_singleton_tuple_ref}; -use crate::iterator_from_dyn::IteratorFromDyn; -use crate::rel_boilerplate::NoopRelIndexWrite; -use std::hash::{Hash, BuildHasherDefault}; -use std::iter::Map; -use std::marker::PhantomData; - -type FxHashSet = hashbrown::hash_set::HashSet>; -type FxHashMap = hashbrown::hash_map::HashMap>; - - -fn _test_macros() { - let _x: rel!((u32, u64, u64), [[0,1], [0]], ser, ()); - let _full_ind: rel_full_ind!((u32, u64, u64), [[0, 1], [0]], ser, (), (u32, u32), ()); - let _ind_0: rel_ind!((u32, u64, u64), [[0, 1], [0]], ser, (), [0], (u32,) , (u32,)); -} - -#[derive(Clone)] -pub struct EqRel2IndWrapper(EqRel2IndCommon); -pub type EqRel2IndCommonWithReverse = EqRel2IndWrapper; -pub type EqRel2IndCommonWithoutReverse = EqRel2IndWrapper; - -impl RelIndexMerge for EqRel2IndWrapper { - fn move_index_contents(from: &mut Self, to: &mut Self) { - EqRel2IndCommon::move_index_contents(&mut from.0, &mut to.0) - } - fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { - EqRel2IndCommon::merge_delta_to_total_new_to_delta(&mut new.0, &mut delta.0, &mut total.0) - } -} - -impl Default for EqRel2IndWrapper { - fn default() -> Self { - let reverse_map = if WITH_REVERSE { Some(Default::default()) } else { None }; - Self(EqRel2IndCommon { map: Default::default(), reverse_map }) - } -} - -pub trait ToEqRel2IndCommon { - fn to_eq_rel2_ind_common(&self) -> &EqRel2IndCommon; - fn to_eq_rel2_ind_common_mut(&mut self) -> &mut EqRel2IndCommon; -} - -impl ToEqRel2IndCommon for EqRel2IndWrapper -where T0: Clone + Hash + Eq, T1: Clone + Hash + Eq -{ - fn to_eq_rel2_ind_common(&self) -> &EqRel2IndCommon { &self.0 } - fn to_eq_rel2_ind_common_mut(&mut self) -> &mut EqRel2IndCommon { &mut self.0 } -} - -#[derive(Clone)] -pub struct EqRel2IndCommon { - map: FxHashMap>, - reverse_map: Option>>, -} - -impl EqRel2IndCommon { - fn insert_into_reverse_map(&mut self, key: &(T0, T1, T1)) { - if let Some(reverse_map) = &mut self.reverse_map { - reverse_map.raw_entry_mut().from_key(&key.1).or_insert_with(|| (key.1.clone(), Default::default())).1.insert(key.0.clone()); - reverse_map.raw_entry_mut().from_key(&key.2).or_insert_with(|| (key.2.clone(), Default::default())).1.insert(key.0.clone()); - } - } -} - -impl RelFullIndexWrite for EqRel2IndCommon { - type Key = (T0, T1, T1); - type Value = (); - - fn insert_if_not_present(&mut self, key: &Self::Key, (): Self::Value) -> bool { - self.insert_into_reverse_map(key); - match self.map.entry(key.0.clone()) { - hashbrown::hash_map::Entry::Occupied(mut occ) => { - occ.get_mut().insert_if_not_present(&(key.1.clone(), key.2.clone()), ()) - }, - hashbrown::hash_map::Entry::Vacant(vac) => { - let mut eqrel = EqRelIndCommon::default(); - eqrel.index_insert((key.1.clone(), key.2.clone()), ()); - vac.insert(eqrel); - true - }, - } - } -} - -impl RelIndexWrite for EqRel2IndCommon { - type Key = (T0, T1, T1); - type Value = (); - - fn index_insert(&mut self, key: Self::Key, (): Self::Value) { - self.insert_into_reverse_map(&key); - self.map.entry(key.0).or_default().index_insert((key.1, key.2), ()); - } -} - -impl RelIndexMerge for EqRel2IndCommon { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { - unimplemented!("merge_delta_to_total_new_to_delta must be used instead") - } - - fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { - for (t0, mut delta_eqrel) in delta.map.drain() { - let mut new_eqrel = new.map.remove(&t0).unwrap_or_default(); - let total_eqrel = total.map.entry(t0).or_default(); - RelIndexMerge::merge_delta_to_total_new_to_delta(&mut new_eqrel, &mut delta_eqrel, total_eqrel); - } - for (t0, new_eqrel) in new.map.drain() { - delta.map.insert(t0, new_eqrel); - } - - // TODO not sure about this - if delta.reverse_map.is_some() { - crate::utils::move_hash_map_of_hash_set_contents(&mut delta.reverse_map.as_mut().unwrap(), total.reverse_map.as_mut().unwrap()); - crate::utils::move_hash_map_of_hash_set_contents(new.reverse_map.as_mut().unwrap(), delta.reverse_map.as_mut().unwrap()); - } - } -} - -impl EqRel2IndCommon { - pub(crate) fn iter_all_added(& self) -> impl Iterator { - self.map.iter().flat_map(|(t0, eqrel)| { - eqrel.iter_all_added().map(move |(t1, t2)| (t0, t1, t2)) - }) - } -} - -pub struct EqRel2Ind0_1<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>( - &'a EqRel2IndCommon -); - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for EqRel2Ind0_1<'a, T0, T1> { - type Key = (T0, T1); - type Value = (&'a T1, ); - - type IteratorType = IteratorFromDyn<'a, (&'a T1, )>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - - let eqrel = self.0.map.get(&key.0)?; - let _ = eqrel.set_of_added(&key.1)?; - let key_1 = key.1.clone(); - let producer = move || eqrel.set_of_added(&key_1).unwrap().map(|x| (x,)); - - Some(IteratorFromDyn::new(producer)) - } - - fn len(&self) -> usize { - let sample_size = 4; - let (count, sum) = self.0.map.values().take(sample_size).map(|eqrel| eqrel.combined.elem_ids.len()).fold((0, 0), |(c, s), x| (c + 1, s + x)); - - sum * self.0.map.len() / count.max(1) - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRel2Ind0_1<'a, T0, T1> { - type Key = (&'a T0, &'a T1); - type Value = (&'a T1, ); - - type ValueIteratorType = std::iter::Once<(&'a T1, )>; - - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.0.map.iter().flat_map(|(t0, eqrel)| { - eqrel.iter_all_added().map(move |(t1, t2)| ((t0, t1), std::iter::once((t2,)) )) - })) - } -} - -pub struct EqRel2Ind0<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>( - &'a EqRel2IndCommon -); - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRel2Ind0<'a, T0, T1> { - type Key = &'a (T0, ); - type Value = (&'a T1, &'a T1); - - type ValueIteratorType = IteratorFromDyn<'a, Self::Value>; - - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.0.map.iter().map(|(t, eqrel)| ( ref_to_singleton_tuple_ref(t), IteratorFromDyn::new(|| eqrel.iter_all_added())))) - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for EqRel2Ind0<'a, T0, T1> { - type Key = (T0, ); - type Value = (&'a T1, &'a T1); - - type IteratorType = IteratorFromDyn<'a, Self::Value>; - fn index_get(&'a self, key: &Self::Key) -> Option { - let eqrel = self.0.map.get(&key.0)?; - Some(IteratorFromDyn::new(|| eqrel.iter_all_added())) - } - - fn len(&self) -> usize { - self.0.map.len() - } -} - -pub struct EqRel2Ind1<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>( - &'a EqRel2IndCommon -); - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRel2Ind1<'a, T0, T1> { - type Key = &'a (T1, ); - type Value = (&'a T0, &'a T1); - - type ValueIteratorType = IteratorFromDyn<'a, Self::Value>; - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.0.reverse_map.as_ref().unwrap().keys().map(ref_to_singleton_tuple_ref).map(|k| (k, self.index_get(k).unwrap()))) - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for EqRel2Ind1<'a, T0, T1> { - type Key = (T1, ); - type Value = (&'a T0, &'a T1); - - type IteratorType = IteratorFromDyn<'a, Self::Value>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let t0s = self.0.reverse_map.as_ref().unwrap().get(&key.0)?; - let t1 = key.0.clone(); - let res = move || t0s.iter().zip(std::iter::repeat(t1.clone())).flat_map(move |(t0, t1)| self.0.map.get(t0).unwrap().set_of_added(&t1).into_iter().flatten().map(move |t2| (t0, t2))); - Some(IteratorFromDyn::new(res)) - } - - fn len(&self) -> usize { - self.0.reverse_map.as_ref().unwrap().len() - } -} - -pub struct EqRel2Ind1_2<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>( - &'a EqRel2IndCommon -); - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRel2Ind1_2<'a, T0, T1> { - type Key = (&'a T1, &'a T1); - type Value = (&'a T0, ); - - type ValueIteratorType = Map>, for<'aa> fn(&'aa T0) -> (&T0,)>; - - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res = self.0.reverse_map.as_ref().unwrap().iter() - .cartesian_product(self.0.reverse_map.as_ref().unwrap().iter()) - .map(|((t0, t0_set), (t1, t1_set))| { - let intersection: Self::ValueIteratorType = t0_set.intersection(t1_set).map(|x| (x, )); - ((t0, t1), intersection) - }); - Box::new(res) - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for EqRel2Ind1_2<'a, T0, T1> { - type Key = (T1, T1); - - type Value = (&'a T0, ); - - type IteratorType = IteratorFromDyn<'a, Self::Value>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let t0s = self.0.reverse_map.as_ref().unwrap().get(&key.0)?; - - let key = key.clone(); - let res = move || t0s.iter().zip(std::iter::repeat(key.clone())) - .filter(|(t0, key)| self.0.map.get(*t0).unwrap().added_contains(&key.0, &key.1)) - .map(|(t0, _key)| (t0, )); - - Some(IteratorFromDyn::new(res)) - } - - fn len(&self) -> usize { - let sample_size = 4; - let sum = self.0.map.values().take(sample_size).map(|eqrel| eqrel.len()).sum::(); - let map_len = self.0.map.len(); - sum / sample_size.min(map_len).max(1) - } -} - -pub struct ToEqRel2IndFull(PhantomData<(T0, T1)>); - -impl Default for ToEqRel2IndFull { - fn default() -> Self { Self(PhantomData) } -} - -impl ToRelIndex for ToEqRel2IndFull -where Rel: ToEqRel2IndCommon -{ - type RelIndex<'a> = EqRel2IndFull<'a, T0, T1> where Self: 'a, Rel: 'a; - fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { EqRel2IndFull(rel.to_eq_rel2_ind_common()) } - - type RelIndexWrite<'a> = &'a mut EqRel2IndCommon where Self: 'a, Rel: 'a; - fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { - rel.to_eq_rel2_ind_common_mut() - } -} - -pub struct EqRel2IndFull<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>( - &'a EqRel2IndCommon -); - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRel2IndFull<'a, T0, T1> { - type Key = (&'a T0, &'a T1, &'a T1); - - type Value = &'a (); - - type ValueIteratorType = std::iter::Once<&'a()>; - - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.0.iter_all_added().map(|t| (t, std::iter::once(&())))) - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for EqRel2IndFull<'a, T0, T1> { - type Key = (T0, T1, T1); - type Value = &'a (); - - type IteratorType = std::iter::Once<&'a ()>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - if self.contains_key(key) { Some(std::iter::once(&())) } - else { None } - } - - fn len(&self) -> usize { - let sample_size = 4; - let sum = self.0.map.values().take(sample_size).map(|eqrel| eqrel.len()).sum::(); - let map_len = self.0.map.len(); - sum * map_len / sample_size.min(map_len).max(1) - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelFullIndexRead<'a> for EqRel2IndFull<'a, T0, T1> { - type Key = (T0, T1, T1); - - fn contains_key(&'a self, (t0, t1, t2): &Self::Key) -> bool { - if let Some(eqrel) = self.0.map.get(t0) { eqrel.added_contains(t1, t2) } - else { false } - } -} - -pub struct EqRel2IndNone<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>( - &'a EqRel2IndCommon -); - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRel2IndNone<'a, T0, T1> { - type Key = &'a (); - type Value = (&'a T0, &'a T1, &'a T1); - - type ValueIteratorType = >::IteratorType; - - type AllIteratorType = std::iter::Once<(Self::Key, Self::ValueIteratorType)>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - std::iter::once((&(), self.index_get(&()).unwrap())) - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for EqRel2IndNone<'a, T0, T1> { - type Key = (); - type Value = (&'a T0, &'a T1, &'a T1); - - type IteratorType = IteratorFromDyn<'a, Self::Value>; - fn index_get(&'a self, (): &Self::Key) -> Option { - Some(IteratorFromDyn::new(|| self.0.iter_all_added())) - } - - fn len(&self) -> usize { - 1 - } -} - -macro_rules! to_eq_rel2 { - ($name: ident, $key: ty, $val: ty) => {paste::paste!{ - pub struct [](PhantomData<(T0, T1)>); - - impl Default for [] { - fn default() -> Self { Self(PhantomData) } - } - - impl ToRelIndex for [] - where Rel: ToEqRel2IndCommon - { - type RelIndex<'a> = $name<'a, T0, T1> where Self: 'a, Rel: 'a; - fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { $name(rel.to_eq_rel2_ind_common()) } - - type RelIndexWrite<'a> = NoopRelIndexWrite<$key, $val> where Self: 'a, Rel: 'a; - fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { - NoopRelIndexWrite::default() - } - } - }}; -} - -to_eq_rel2!(EqRel2IndNone, (), (T0, T1, T1)); -to_eq_rel2!(EqRel2Ind0, (T0, ), (T1, T1)); -to_eq_rel2!(EqRel2Ind0_1, (T0, T1), (T1, )); -to_eq_rel2!(EqRel2Ind1, (T1, ), (T0, T1)); -to_eq_rel2!(EqRel2Ind1_2, (T1, T1), (T0, )); -// to_eq_rel2!(EqRel2IndFull, (T0, T1, T1), ()); \ No newline at end of file +#[doc(hidden)] +#[macro_export] +macro_rules! eqrel_ternary_rel { + (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, ()) => { + $crate::fake_vec::FakeVec<($col0, $col1, $col2)> + }; +} +pub use eqrel_ternary_rel as rel; + +#[doc(hidden)] +#[macro_export] +macro_rules! eqrel_rel_ternary_full_ind { + (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), $key: ty, $val: ty) => { + $crate::eqrel_ternary::ToEqRel2IndFull<$col0, $col1> + }; +} +pub use eqrel_rel_ternary_full_ind as rel_full_ind; + +#[doc(hidden)] +#[macro_export] +macro_rules! eqrel_ternary_rel_ind { + (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [], $key: ty, $val: ty) => { + $crate::eqrel_ternary::ToEqRel2IndNone<$col0, $col1> + }; + (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0], $key: ty, $val: ty) => { + $crate::eqrel_ternary::ToEqRel2Ind0<$col0, $col1> + }; + (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [1], $key: ty, $val: ty) => { + $crate::eqrel_ternary::ToEqRel2Ind1<$col0, $col1> + }; + (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [2], $key: ty, $val: ty) => { + $crate::eqrel_ternary::ToEqRel2Ind2<$col0, $col1> + }; + (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 1], $key: ty, $val: ty) => { + $crate::eqrel_ternary::ToEqRel2Ind0_1<$col0, $col1> + }; + (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 2], $key: ty, $val: ty) => { + $crate::eqrel_ternary::ToEqRel2Ind0_1<$col0, $col1> + }; + (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [1, 2], $key: ty, $val: ty) => { + $crate::eqrel_ternary::ToEqRel2Ind1_2<$col0, $col1> + }; + (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 1, 2], $key: ty, $val: ty) => { + $crate::eqrel_ternary::ToEqRel2IndFull<$col0, $col1> + }; +} +pub use eqrel_ternary_rel_ind as rel_ind; + +#[doc(hidden)] +#[macro_export] +macro_rules! eqrel_ternary_rel_ind_common { + (($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, ()) => { + $crate::eqrel_ternary::EqRel2IndCommonWithReverse<$col0, $col1> + }; +} +use std::hash::{BuildHasherDefault, Hash}; +use std::iter::Map; +use std::marker::PhantomData; + +use ascent::internal::{ + RelFullIndexRead, RelFullIndexWrite, RelIndexMerge, RelIndexRead, RelIndexReadAll, RelIndexWrite, ToRelIndex, +}; +pub use eqrel_ternary_rel_ind_common as rel_ind_common; +use itertools::Itertools; +use rustc_hash::FxHasher; + +use crate::eqrel_ind::{EqRelIndCommon, ref_to_singleton_tuple_ref}; +use crate::iterator_from_dyn::IteratorFromDyn; +use crate::rel_boilerplate::NoopRelIndexWrite; + +type FxHashSet = hashbrown::hash_set::HashSet>; +type FxHashMap = hashbrown::hash_map::HashMap>; + +fn _test_macros() { + let _x: rel!((u32, u64, u64), [[0, 1], [0]], ser, ()); + let _full_ind: rel_full_ind!((u32, u64, u64), [[0, 1], [0]], ser, (), (u32, u32), ()); + let _ind_0: rel_ind!((u32, u64, u64), [[0, 1], [0]], ser, (), [0], (u32,), (u32,)); +} + +#[derive(Clone)] +pub struct EqRel2IndWrapper( + EqRel2IndCommon, +); +pub type EqRel2IndCommonWithReverse = EqRel2IndWrapper; +pub type EqRel2IndCommonWithoutReverse = EqRel2IndWrapper; + +impl RelIndexMerge + for EqRel2IndWrapper +{ + fn move_index_contents(from: &mut Self, to: &mut Self) { + EqRel2IndCommon::move_index_contents(&mut from.0, &mut to.0) + } + fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { + EqRel2IndCommon::merge_delta_to_total_new_to_delta(&mut new.0, &mut delta.0, &mut total.0) + } +} + +impl Default + for EqRel2IndWrapper +{ + fn default() -> Self { + let reverse_map = if WITH_REVERSE { Some(Default::default()) } else { None }; + Self(EqRel2IndCommon { map: Default::default(), reverse_map }) + } +} + +pub trait ToEqRel2IndCommon { + fn to_eq_rel2_ind_common(&self) -> &EqRel2IndCommon; + fn to_eq_rel2_ind_common_mut(&mut self) -> &mut EqRel2IndCommon; +} + +impl ToEqRel2IndCommon for EqRel2IndWrapper +where + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, +{ + fn to_eq_rel2_ind_common(&self) -> &EqRel2IndCommon { &self.0 } + fn to_eq_rel2_ind_common_mut(&mut self) -> &mut EqRel2IndCommon { &mut self.0 } +} + +#[derive(Clone)] +pub struct EqRel2IndCommon { + map: FxHashMap>, + reverse_map: Option>>, +} + +impl EqRel2IndCommon { + fn insert_into_reverse_map(&mut self, key: &(T0, T1, T1)) { + if let Some(reverse_map) = &mut self.reverse_map { + reverse_map + .raw_entry_mut() + .from_key(&key.1) + .or_insert_with(|| (key.1.clone(), Default::default())) + .1 + .insert(key.0.clone()); + reverse_map + .raw_entry_mut() + .from_key(&key.2) + .or_insert_with(|| (key.2.clone(), Default::default())) + .1 + .insert(key.0.clone()); + } + } +} + +impl RelFullIndexWrite for EqRel2IndCommon { + type Key = (T0, T1, T1); + type Value = (); + + fn insert_if_not_present(&mut self, key: &Self::Key, (): Self::Value) -> bool { + self.insert_into_reverse_map(key); + match self.map.entry(key.0.clone()) { + hashbrown::hash_map::Entry::Occupied(mut occ) => + occ.get_mut().insert_if_not_present(&(key.1.clone(), key.2.clone()), ()), + hashbrown::hash_map::Entry::Vacant(vac) => { + let mut eqrel = EqRelIndCommon::default(); + eqrel.index_insert((key.1.clone(), key.2.clone()), ()); + vac.insert(eqrel); + true + }, + } + } +} + +impl RelIndexWrite for EqRel2IndCommon { + type Key = (T0, T1, T1); + type Value = (); + + fn index_insert(&mut self, key: Self::Key, (): Self::Value) { + self.insert_into_reverse_map(&key); + self.map.entry(key.0).or_default().index_insert((key.1, key.2), ()); + } +} + +impl RelIndexMerge for EqRel2IndCommon { + fn move_index_contents(_from: &mut Self, _to: &mut Self) { + unimplemented!("merge_delta_to_total_new_to_delta must be used instead") + } + + fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { + for (t0, mut delta_eqrel) in delta.map.drain() { + let mut new_eqrel = new.map.remove(&t0).unwrap_or_default(); + let total_eqrel = total.map.entry(t0).or_default(); + RelIndexMerge::merge_delta_to_total_new_to_delta(&mut new_eqrel, &mut delta_eqrel, total_eqrel); + } + for (t0, new_eqrel) in new.map.drain() { + delta.map.insert(t0, new_eqrel); + } + + // TODO not sure about this + if delta.reverse_map.is_some() { + crate::utils::move_hash_map_of_hash_set_contents( + &mut delta.reverse_map.as_mut().unwrap(), + total.reverse_map.as_mut().unwrap(), + ); + crate::utils::move_hash_map_of_hash_set_contents( + new.reverse_map.as_mut().unwrap(), + delta.reverse_map.as_mut().unwrap(), + ); + } + } +} + +impl EqRel2IndCommon { + pub(crate) fn iter_all_added(&self) -> impl Iterator { + self.map.iter().flat_map(|(t0, eqrel)| eqrel.iter_all_added().map(move |(t1, t2)| (t0, t1, t2))) + } +} + +pub struct EqRel2Ind0_1<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>(&'a EqRel2IndCommon); + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for EqRel2Ind0_1<'a, T0, T1> { + type Key = (T0, T1); + type Value = (&'a T1,); + + type IteratorType = IteratorFromDyn<'a, (&'a T1,)>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let eqrel = self.0.map.get(&key.0)?; + let _ = eqrel.set_of_added(&key.1)?; + let key_1 = key.1.clone(); + let producer = move || eqrel.set_of_added(&key_1).unwrap().map(|x| (x,)); + + Some(IteratorFromDyn::new(producer)) + } + + fn len(&self) -> usize { + let sample_size = 4; + let (count, sum) = self + .0 + .map + .values() + .take(sample_size) + .map(|eqrel| eqrel.combined.elem_ids.len()) + .fold((0, 0), |(c, s), x| (c + 1, s + x)); + + sum * self.0.map.len() / count.max(1) + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRel2Ind0_1<'a, T0, T1> { + type Key = (&'a T0, &'a T1); + type Value = (&'a T1,); + + type ValueIteratorType = std::iter::Once<(&'a T1,)>; + + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + Box::new( + self + .0 + .map + .iter() + .flat_map(|(t0, eqrel)| eqrel.iter_all_added().map(move |(t1, t2)| ((t0, t1), std::iter::once((t2,))))), + ) + } +} + +pub struct EqRel2Ind0<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>(&'a EqRel2IndCommon); + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRel2Ind0<'a, T0, T1> { + type Key = &'a (T0,); + type Value = (&'a T1, &'a T1); + + type ValueIteratorType = IteratorFromDyn<'a, Self::Value>; + + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + Box::new( + self + .0 + .map + .iter() + .map(|(t, eqrel)| (ref_to_singleton_tuple_ref(t), IteratorFromDyn::new(|| eqrel.iter_all_added()))), + ) + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for EqRel2Ind0<'a, T0, T1> { + type Key = (T0,); + type Value = (&'a T1, &'a T1); + + type IteratorType = IteratorFromDyn<'a, Self::Value>; + fn index_get(&'a self, key: &Self::Key) -> Option { + let eqrel = self.0.map.get(&key.0)?; + Some(IteratorFromDyn::new(|| eqrel.iter_all_added())) + } + + fn len(&self) -> usize { self.0.map.len() } +} + +pub struct EqRel2Ind1<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>(&'a EqRel2IndCommon); + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRel2Ind1<'a, T0, T1> { + type Key = &'a (T1,); + type Value = (&'a T0, &'a T1); + + type ValueIteratorType = IteratorFromDyn<'a, Self::Value>; + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + Box::new( + self + .0 + .reverse_map + .as_ref() + .unwrap() + .keys() + .map(ref_to_singleton_tuple_ref) + .map(|k| (k, self.index_get(k).unwrap())), + ) + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for EqRel2Ind1<'a, T0, T1> { + type Key = (T1,); + type Value = (&'a T0, &'a T1); + + type IteratorType = IteratorFromDyn<'a, Self::Value>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let t0s = self.0.reverse_map.as_ref().unwrap().get(&key.0)?; + let t1 = key.0.clone(); + let res = move || { + t0s.iter().zip(std::iter::repeat(t1.clone())).flat_map(move |(t0, t1)| { + self.0.map.get(t0).unwrap().set_of_added(&t1).into_iter().flatten().map(move |t2| (t0, t2)) + }) + }; + Some(IteratorFromDyn::new(res)) + } + + fn len(&self) -> usize { self.0.reverse_map.as_ref().unwrap().len() } +} + +pub struct EqRel2Ind1_2<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>(&'a EqRel2IndCommon); + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRel2Ind1_2<'a, T0, T1> { + type Key = (&'a T1, &'a T1); + type Value = (&'a T0,); + + type ValueIteratorType = + Map>, for<'aa> fn(&'aa T0) -> (&T0,)>; + + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res = self + .0 + .reverse_map + .as_ref() + .unwrap() + .iter() + .cartesian_product(self.0.reverse_map.as_ref().unwrap().iter()) + .map(|((t0, t0_set), (t1, t1_set))| { + let intersection: Self::ValueIteratorType = t0_set.intersection(t1_set).map(|x| (x,)); + ((t0, t1), intersection) + }); + Box::new(res) + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for EqRel2Ind1_2<'a, T0, T1> { + type Key = (T1, T1); + + type Value = (&'a T0,); + + type IteratorType = IteratorFromDyn<'a, Self::Value>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let t0s = self.0.reverse_map.as_ref().unwrap().get(&key.0)?; + + let key = key.clone(); + let res = move || { + t0s.iter() + .zip(std::iter::repeat(key.clone())) + .filter(|(t0, key)| self.0.map.get(*t0).unwrap().added_contains(&key.0, &key.1)) + .map(|(t0, _key)| (t0,)) + }; + + Some(IteratorFromDyn::new(res)) + } + + fn len(&self) -> usize { + let sample_size = 4; + let sum = self.0.map.values().take(sample_size).map(|eqrel| eqrel.len()).sum::(); + let map_len = self.0.map.len(); + sum / sample_size.min(map_len).max(1) + } +} + +pub struct ToEqRel2IndFull(PhantomData<(T0, T1)>); + +impl Default for ToEqRel2IndFull { + fn default() -> Self { Self(PhantomData) } +} + +impl ToRelIndex for ToEqRel2IndFull +where Rel: ToEqRel2IndCommon +{ + type RelIndex<'a> + = EqRel2IndFull<'a, T0, T1> + where + Self: 'a, + Rel: 'a; + fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { EqRel2IndFull(rel.to_eq_rel2_ind_common()) } + + type RelIndexWrite<'a> + = &'a mut EqRel2IndCommon + where + Self: 'a, + Rel: 'a; + fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { + rel.to_eq_rel2_ind_common_mut() + } +} + +pub struct EqRel2IndFull<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>(&'a EqRel2IndCommon); + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRel2IndFull<'a, T0, T1> { + type Key = (&'a T0, &'a T1, &'a T1); + + type Value = &'a (); + + type ValueIteratorType = std::iter::Once<&'a ()>; + + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + Box::new(self.0.iter_all_added().map(|t| (t, std::iter::once(&())))) + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for EqRel2IndFull<'a, T0, T1> { + type Key = (T0, T1, T1); + type Value = &'a (); + + type IteratorType = std::iter::Once<&'a ()>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + if self.contains_key(key) { Some(std::iter::once(&())) } else { None } + } + + fn len(&self) -> usize { + let sample_size = 4; + let sum = self.0.map.values().take(sample_size).map(|eqrel| eqrel.len()).sum::(); + let map_len = self.0.map.len(); + sum * map_len / sample_size.min(map_len).max(1) + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelFullIndexRead<'a> for EqRel2IndFull<'a, T0, T1> { + type Key = (T0, T1, T1); + + fn contains_key(&'a self, (t0, t1, t2): &Self::Key) -> bool { + if let Some(eqrel) = self.0.map.get(t0) { eqrel.added_contains(t1, t2) } else { false } + } +} + +pub struct EqRel2IndNone<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>(&'a EqRel2IndCommon); + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for EqRel2IndNone<'a, T0, T1> { + type Key = &'a (); + type Value = (&'a T0, &'a T1, &'a T1); + + type ValueIteratorType = >::IteratorType; + + type AllIteratorType = std::iter::Once<(Self::Key, Self::ValueIteratorType)>; + + fn iter_all(&'a self) -> Self::AllIteratorType { std::iter::once((&(), self.index_get(&()).unwrap())) } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for EqRel2IndNone<'a, T0, T1> { + type Key = (); + type Value = (&'a T0, &'a T1, &'a T1); + + type IteratorType = IteratorFromDyn<'a, Self::Value>; + fn index_get(&'a self, (): &Self::Key) -> Option { + Some(IteratorFromDyn::new(|| self.0.iter_all_added())) + } + + fn len(&self) -> usize { 1 } +} + +macro_rules! to_eq_rel2 { + ($name: ident, $key: ty, $val: ty) => {paste::paste!{ + pub struct [](PhantomData<(T0, T1)>); + + impl Default for [] { + fn default() -> Self { Self(PhantomData) } + } + + impl ToRelIndex for [] + where Rel: ToEqRel2IndCommon + { + type RelIndex<'a> = $name<'a, T0, T1> where Self: 'a, Rel: 'a; + fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { $name(rel.to_eq_rel2_ind_common()) } + + type RelIndexWrite<'a> = NoopRelIndexWrite<$key, $val> where Self: 'a, Rel: 'a; + fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { + NoopRelIndexWrite::default() + } + } + }}; +} + +to_eq_rel2!(EqRel2IndNone, (), (T0, T1, T1)); +to_eq_rel2!(EqRel2Ind0, (T0,), (T1, T1)); +to_eq_rel2!(EqRel2Ind0_1, (T0, T1), (T1,)); +to_eq_rel2!(EqRel2Ind1, (T1,), (T0, T1)); +to_eq_rel2!(EqRel2Ind1_2, (T1, T1), (T0,)); +// to_eq_rel2!(EqRel2IndFull, (T0, T1, T1), ()); diff --git a/byods/ascent-byods-rels/src/fake_vec.rs b/byods/ascent-byods-rels/src/fake_vec.rs index 588d151..1206c67 100644 --- a/byods/ascent-byods-rels/src/fake_vec.rs +++ b/byods/ascent-byods-rels/src/fake_vec.rs @@ -1,31 +1,25 @@ -use std::marker::PhantomData; -use std::ops::Index; - -pub struct FakeVec { _phantom: PhantomData } - -impl Default for FakeVec { - fn default() -> Self { Self { _phantom: PhantomData } } -} - -impl FakeVec { - #[inline(always)] - pub fn push(&self, _: T) { - } - - pub fn len(&self) -> usize { - 0 - } - - pub fn iter(&self) -> std::iter::Empty<&T> { - std::iter::empty() - } -} - -impl Index for FakeVec { - type Output = T; - - fn index(&self, _index: usize) -> &Self::Output { - panic!("FakeVec is empty!") - } -} - +use std::marker::PhantomData; +use std::ops::Index; + +pub struct FakeVec { + _phantom: PhantomData, +} + +impl Default for FakeVec { + fn default() -> Self { Self { _phantom: PhantomData } } +} + +impl FakeVec { + #[inline(always)] + pub fn push(&self, _: T) {} + + pub fn len(&self) -> usize { 0 } + + pub fn iter(&self) -> std::iter::Empty<&T> { std::iter::empty() } +} + +impl Index for FakeVec { + type Output = T; + + fn index(&self, _index: usize) -> &Self::Output { panic!("FakeVec is empty!") } +} diff --git a/byods/ascent-byods-rels/src/iterator_from_dyn.rs b/byods/ascent-byods-rels/src/iterator_from_dyn.rs index ac56026..a7c639c 100644 --- a/byods/ascent-byods-rels/src/iterator_from_dyn.rs +++ b/byods/ascent-byods-rels/src/iterator_from_dyn.rs @@ -1,32 +1,27 @@ -use std::rc::Rc; - - -pub struct IteratorFromDyn<'a, T> { - iter: Box + 'a>, - producer: Rc Box + 'a> + 'a>, -} - -impl<'a, T> IteratorFromDyn<'a, T> { - pub fn from_box_clo Box + 'a> + 'a>(producer: F) -> Self { - let iter = producer(); - Self { iter, producer: Rc::new(producer) as _ } - } - - pub fn new Iter + 'a, Iter: Iterator + 'a>(producer: F) -> Self { - Self::from_box_clo(move || Box::new(producer())) - } -} -impl<'a, T> Iterator for IteratorFromDyn<'a, T> { - type Item = T; - - #[inline(always)] - fn next(&mut self) -> Option { - self.iter.next() - } -} - -impl<'a, T> Clone for IteratorFromDyn<'a, T> { - fn clone(&self) -> Self { - Self { iter: (self.producer)(), producer: self.producer.clone() } - } -} \ No newline at end of file +use std::rc::Rc; + +pub struct IteratorFromDyn<'a, T> { + iter: Box + 'a>, + producer: Rc Box + 'a> + 'a>, +} + +impl<'a, T> IteratorFromDyn<'a, T> { + pub fn from_box_clo Box + 'a> + 'a>(producer: F) -> Self { + let iter = producer(); + Self { iter, producer: Rc::new(producer) as _ } + } + + pub fn new Iter + 'a, Iter: Iterator + 'a>(producer: F) -> Self { + Self::from_box_clo(move || Box::new(producer())) + } +} +impl<'a, T> Iterator for IteratorFromDyn<'a, T> { + type Item = T; + + #[inline(always)] + fn next(&mut self) -> Option { self.iter.next() } +} + +impl<'a, T> Clone for IteratorFromDyn<'a, T> { + fn clone(&self) -> Self { Self { iter: (self.producer)(), producer: self.producer.clone() } } +} diff --git a/byods/ascent-byods-rels/src/lib.rs b/byods/ascent-byods-rels/src/lib.rs index 8891260..4e8524e 100644 --- a/byods/ascent-byods-rels/src/lib.rs +++ b/byods/ascent-byods-rels/src/lib.rs @@ -1,9 +1,9 @@ -//! data structures for [`ascent`](https://github.com/s-arash/ascent) relations, +//! data structures for [`ascent`](https://github.com/s-arash/ascent) relations, //! made possible by Ascent's [BYODS](https://dl.acm.org/doi/pdf/10.1145/3622840) feature #![cfg_attr(not(test), deny(unused_crate_dependencies))] -// See Cargo.toml for why this is needed. +// See Cargo.toml for why this is needed. use syn as _; mod union_find; diff --git a/byods/ascent-byods-rels/src/rel_boilerplate.rs b/byods/ascent-byods-rels/src/rel_boilerplate.rs index f5441e5..0ad4af6 100644 --- a/byods/ascent-byods-rels/src/rel_boilerplate.rs +++ b/byods/ascent-byods-rels/src/rel_boilerplate.rs @@ -1,21 +1,21 @@ -use std::marker::PhantomData; - -use ascent::internal::{RelIndexMerge, RelIndexWrite}; - -pub struct NoopRelIndexWrite(PhantomData<(K, V)>); - -impl Default for NoopRelIndexWrite { - #[inline(always)] - fn default() -> Self { Self(PhantomData) } -} - -impl RelIndexMerge for NoopRelIndexWrite { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { } -} - -impl RelIndexWrite for NoopRelIndexWrite { - type Key = K; - type Value = V; - #[inline(always)] - fn index_insert(&mut self, _key: Self::Key, _value: Self::Value) { } -} +use std::marker::PhantomData; + +use ascent::internal::{RelIndexMerge, RelIndexWrite}; + +pub struct NoopRelIndexWrite(PhantomData<(K, V)>); + +impl Default for NoopRelIndexWrite { + #[inline(always)] + fn default() -> Self { Self(PhantomData) } +} + +impl RelIndexMerge for NoopRelIndexWrite { + fn move_index_contents(_from: &mut Self, _to: &mut Self) {} +} + +impl RelIndexWrite for NoopRelIndexWrite { + type Key = K; + type Value = V; + #[inline(always)] + fn index_insert(&mut self, _key: Self::Key, _value: Self::Value) {} +} diff --git a/byods/ascent-byods-rels/src/test.rs b/byods/ascent-byods-rels/src/test.rs index a15b230..d04c6cf 100644 --- a/byods/ascent-byods-rels/src/test.rs +++ b/byods/ascent-byods-rels/src/test.rs @@ -1,339 +1,339 @@ -#![cfg(test)] - -use ascent::{ascent, ascent_run}; -use itertools::Itertools; - -#[test] -fn test_eq_rel2_in_ascent() { - ascent! { - struct EqRel2TestProg; - - #[ds(crate::eqrel)] - relation indexed_eq_rel(u32, u64, u64); - - relation seed(u32); - seed(x) <-- for x in 0..2; - - relation seed2(u32, u64, u64); - - seed2(a, 1, 2) <-- seed(a); - seed2(a, x + 1, y + 1) <-- seed2(a, x, y), if *y < 20; - - indexed_eq_rel(a, x, y) <-- seed2(a, x, y); - - relation indexed_eq_rel_materialized(u32, u64, u64); - indexed_eq_rel_materialized(a, x, y) <-- indexed_eq_rel(a, x, y); - } - - let mut prog = EqRel2TestProg::default(); - - prog.run(); - - println!("indexed_eq_rel_materialized len: {}", prog.indexed_eq_rel_materialized.len()); - assert_eq!(prog.indexed_eq_rel_materialized.len(), 2 * 20usize.pow(2)); -} - -#[test] -fn test_eq_rel2_in_ascent2() { - - let test_cases = vec![ - (0..2).flat_map(|a| (3..5).flat_map(move |x| (9..11).map(move |y| (a, x, y)))).collect_vec(), - vec![(0, 10, 11), (0, 12, 13), (0, 13, 14)], - vec![(0, 10, 11), (0, 12, 13), (0, 13, 14), (0, 12, 11), (1, 11, 14), (1, 12, 13), (1, 13, 14)], - ]; - for (i, seed_rel) in test_cases.into_iter().enumerate() { - println!("test {}", i); - let res = ascent_run! { - struct EqRel2TestProg; - - #[ds(crate::eqrel)] - relation indexed_eq_rel(u32, u64, u64); - - relation indexed_eq_rel_explicit(u32, u64, u64); - - indexed_eq_rel_explicit(a, x, x), indexed_eq_rel_explicit(a, y, x) <-- indexed_eq_rel_explicit(a, x, y); - indexed_eq_rel_explicit(a, x, z) <-- indexed_eq_rel_explicit(a, x, y), indexed_eq_rel_explicit(a, y, z); - - indexed_eq_rel(a, x, y), indexed_eq_rel_explicit(a, x, y) <-- for (a, x, y) in seed_rel; - - - relation indexed_eq_rel_materialized(u32, u64, u64); - indexed_eq_rel_materialized(a, x, y) <-- indexed_eq_rel(a, x, y); - - relation foo(u32, u64, u64); - foo(a, x, y) <-- for a in 0..2, for x in 8..12, for y in 10..15; - - relation test1_expected(u32, u64, u64); - relation test1_actual(u32, u64, u64); - - test1_expected(a, x, y) <-- foo(a, x, y), indexed_eq_rel_explicit(a, x, y); - test1_actual(a, x, y) <-- foo(a, x, y), indexed_eq_rel(a, x, y); - - relation bar(u32, u64); - bar(a, x) <-- for a in 0..3, for x in 10..16; - - relation test2_expected(u32, u64, u64); - relation test2_actual(u32, u64, u64); - - test2_expected(a, x, y) <-- bar(a, x), indexed_eq_rel_explicit(a, x, y); - test2_actual(a, x, y) <-- bar(a, x), indexed_eq_rel(a, x, y); - - relation test3_expected(u32, u64, u64); - relation test3_actual(u32, u64, u64); - - test3_expected(a, x, y) <-- bar(a, _), indexed_eq_rel_explicit(a, x, y); - test3_actual(a, x, y) <-- bar(a, _), indexed_eq_rel(a, x, y); - - relation test4_expected(u32, u64, u64); - relation test4_actual(u32, u64, u64); - - test4_expected(a, x, y) <-- bar(_, x), indexed_eq_rel_explicit(a, x, y); - test4_actual(a, x, y) <-- bar(_, x), indexed_eq_rel_explicit(a, x, y); - - relation test5_expected(u32, u64, u64); - relation test5_actual(u32, u64, u64); - - test5_expected(a, x, x) <-- for x in [5, 7, 9, 11, 13], indexed_eq_rel_explicit(a, x, x); - test5_actual(a, x, x) <-- for x in [5, 7, 9, 11, 13], indexed_eq_rel(a, x, x); - - relation test6_expected(u32, u64, u64); - relation test6_actual(u32, u64, u64); - - test6_expected(a, x, y) <-- for x in 8..14, indexed_eq_rel_explicit(a, x, y); - test6_actual(a, x, y) <-- for x in 8..14, indexed_eq_rel(a, x, y); - }; - - println!("indexed_eq_rel_materialized len: {}", res.indexed_eq_rel_materialized.len()); - println!("test1_actual len: {}", res.test1_actual.len()); - println!("test2_actual len: {}", res.test2_actual.len()); - println!("test3_actual len: {}", res.test3_actual.len()); - println!("test4_actual len: {}", res.test4_actual.len()); - println!("test5_actual len: {}", res.test5_actual.len()); - println!("test6_actual len: {}", res.test6_actual.len()); - - assert_eq!(res.indexed_eq_rel_materialized.len(), res.indexed_eq_rel_explicit.len()); - assert_eq!(res.test1_actual.len(), res.test1_expected.len()); - assert_eq!(res.test2_actual.len(), res.test2_expected.len()); - assert_eq!(res.test3_actual.len(), res.test3_expected.len()); - assert_eq!(res.test4_actual.len(), res.test4_expected.len()); - assert_eq!(res.test5_actual.len(), res.test5_expected.len()); - assert_eq!(res.test6_actual.len(), res.test6_expected.len()); - } -} - -#[test] -fn test_trrel_uf_in_ascent() { - let ub = 15; - let res = ascent_run! { - relation seed(u32); - seed(1); - seed(x + 1) <-- seed(x), if x < &ub; - - #[ds(crate::trrel_uf)] - relation tr(u32, u32); - - tr(x, x + 1), tr(x + 1, x) <-- seed(x); - - // included to pull all rules into the same SCC - seed(0) <-- if false, tr(1000, 2000); - }; - assert_eq!(res.__tr_ind_common.count_exact() as u32, (ub + 1).pow(2)); -} - -#[test] -fn test_trrel_uf_ternary_in_ascent() { - let ub = 15; - let res = ascent_run! { - relation seed(u32); - seed(1); - seed(x + 1) <-- seed(x), if x < &ub; - - #[ds(crate::trrel_uf)] - relation tr(u8, u32, u32); - - tr(i, x, x + 1), tr(i, x + 1, x) <-- seed(x), for i in 0..10; - - // included to pull all rules into the same SCC - seed(0) <-- if false, tr(_, 1000, 2000); - }; - assert_eq!(res.__tr_ind_common.0.map[&1].count_exact() as u32, (ub + 1).pow(2)); -} - - -#[test] -fn test_trrel1() { - - let test_cases = vec![ - vec![(1, 2), (2, 3)], - vec![(1, 2)], - (4..6).flat_map(|x| (6..9).map(move |y| (x, y))).collect(), - (0..5).map(|x| (x, x + 1)).collect() - ]; - - for (i, seed_rel) in test_cases.into_iter().enumerate() { - - let res = ascent_run! { - #[ds(crate::trrel)] - relation tr(u32, u32); - - relation tr_explicit(u32, u32); - tr_explicit(x, z) <-- tr_explicit(x, y), tr_explicit(y, z); - - tr(x, y), tr_explicit(x, y) <-- for (x, y) in seed_rel; - - tr(x, y) <-- tr(x, y); - relation tr_materialized(u32, u32); - tr_materialized(x, y) <-- tr(x, y); - - // testing New, Old variants of `TrRelIndCommon` - #[ds(crate::trrel)] - relation dummy1(u32, u32); - - #[ds(crate::trrel)] - relation dummy2(u32, u32); - - dummy2(x, y) <-- dummy1(x, y); - }; - - println!("TEST {}", i); - println!("explicit len: {}", res.tr_explicit.len()); - println!("materialized len: {}", res.tr_materialized.len()); - - assert_eq!(res.tr_explicit.len(), res.tr_materialized.len()); - } -} - -#[test] -fn test_trrel_reflexive_facts() { - - let test_cases = vec![ - vec![(1, 1), (1, 2)], - vec![(1, 1)], - (4..6).flat_map(|x| (5..7).map(move |y| (x, y))).collect(), - (0..5).map(|x| (x, x)).collect() - ]; - - for (i, seed_rel) in test_cases.into_iter().enumerate() { - - let res = ascent_run! { - #[ds(crate::trrel)] - relation tr(u32, u32); - - relation tr_explicit(u32, u32); - tr_explicit(x, z) <-- tr_explicit(x, y), tr_explicit(y, z); - - tr(x, y), tr_explicit(x, y) <-- for (x, y) in seed_rel; - - relation tr_materialized(u32, u32); - tr_materialized(x, y) <-- tr(x, y); - }; - - println!("TEST {}", i); - println!("explicit len: {}", res.tr_explicit.len()); - assert_eq!(res.tr_explicit.len(), res.tr_materialized.len()); - } -} - -#[test] -fn test_trrel2_in_ascent() { - let test_cases = vec![ - (1..4).flat_map(|a| (3..7).flat_map(move |x| (11..15).map(move |y| (a, x, y)))).collect_vec(), - (0..3).flat_map(|a| (0..15).filter(move |x| x % 3 == a as u64 % 3).flat_map(move |x| (a as u64 + 14..19).map(move |y| (a, x, y)))).collect_vec(), - (1..4).flat_map(|a| (5..10).flat_map(move |x| (12..17).map(move |y| (a, x, y)))).collect_vec(), - vec![(0, 10, 11), (0, 5, 13), (0, 6, 13), (0, 7, 14), (0, 12, 13), (0, 13, 14)], - vec![(0, 7, 8), (0, 8, 10), (0, 8, 13), (0, 9, 11), (1, 9, 14), (1, 9, 13), (1, 10, 14)] - .into_iter().flat_map(|(a, x, y)| (0..3).map(move |o| (a + o, x, y - o as u64))).collect(), - ]; - - for (i, seed_rel) in test_cases.into_iter().enumerate() { - let res = ascent_run! { - #[ds(crate::trrel)] - relation tr_indexed(u32, u64, u64); - - relation tr_indexed_explicit(u32, u64, u64); - tr_indexed_explicit(a, x, z) <-- tr_indexed_explicit(a, x, y), tr_indexed_explicit(a, y, z); - - relation empty(u32, u64); - - relation seed(u32, u64, u64) = seed_rel; - - tr_indexed(a, x, y), tr_indexed_explicit(a, x, y) <-- seed(a, x, y); - - // checking join works, should add no tuples - tr_indexed(a, x, y) <-- tr_indexed(a, x, y), empty(a, y); - - relation tr_indexed_materialized(u32, u64, u64); - - tr_indexed_materialized(a, x, y) <-- tr_indexed(a, x, y); - - relation foo(u32, u64); - foo(a, x) <-- for a in 0..1, for x in 4..12; - foo(a, x) <-- for a in 0..1, for x in (0..20).filter(|x| x % 3 == 2); - - - relation test1_actual(u32, u64); - relation test1_expected(u32, u64); - - test1_actual(a, x) <-- tr_indexed(a, x, _), foo(a, x); - test1_expected(a, x) <-- tr_indexed_explicit(a, x, _), foo(a, x); - - relation bar(u32, u64, u64); - bar(a, x, y) <-- for a in 0..1, for x in 5..6, for y in 10..12; - bar(a, x, y) <-- - for a in 1..5, - for x in (5..21).filter(|x| x % 3 == a as u64 % 3), - for y in (12..18).filter(|y| y % 5 == a as u64 % 5); - - - relation test2_actual(u32, u64, u64); - relation test2_expected(u32, u64, u64); - - test2_actual(a, x, y) <-- tr_indexed(a, x, y), bar(a, x, y); - test2_expected(a, x, y) <-- tr_indexed_explicit(a, x, y), bar(a, x, y); - - relation test3_actual(u32, u64, u64); - relation test3_expected(u32, u64, u64); - - test3_actual(a, x, y) <-- tr_indexed(a, x, y), bar(_, x, y); - test3_expected(a, x, y) <-- tr_indexed_explicit(a, x, y), bar(_, x, y); - - relation test4_actual(u32, u64, u64); - relation test4_expected(u32, u64, u64); - - test4_actual(a, x, y) <-- tr_indexed(a, x, y), bar(_, x, _); - test4_expected(a, x, y) <-- tr_indexed_explicit(a, x, y), bar(_, x, _); - - relation test5_actual(u32, u64, u64); - relation test5_expected(u32, u64, u64); - - test5_actual(a, x, y) <-- tr_indexed(a, x, y), bar(_, _, y); - test5_expected(a, x, y) <-- tr_indexed_explicit(a, x, y), bar(_, _, y); - - }; - - println!("============ TEST CASE {} ============", i); - println!("tr_indexed_explicit len: {}", res.tr_indexed_explicit.len()); - println!("test1_expected len: {}", res.test1_expected.len()); - println!("test2_expected len: {}", res.test2_expected.len()); - println!("test3_expected len: {}", res.test3_expected.len()); - println!("test4_expected len: {}", res.test4_expected.len()); - println!("test5_expected len: {}", res.test5_expected.len()); - - // println!("test4_expected: {:?}", res.test4_expected); - // println!("test4_actual: {:?}", res.test4_actual); - // use hashbrown::HashSet; - // println!("test_4 diff: {:?}", - // res.test4_expected.iter().cloned().collect::>() - // .symmetric_difference(&res.test4_actual.iter().cloned().collect::>())); - - - assert_eq!(res.test1_expected.len(), res.test1_actual.len()); - assert_eq!(res.test2_expected.len(), res.test2_actual.len()); - assert_eq!(res.test3_expected.len(), res.test3_actual.len()); - assert_eq!(res.test4_expected.len(), res.test4_actual.len()); - assert_eq!(res.test5_expected.len(), res.test5_actual.len()); - - assert_eq!(res.tr_indexed_materialized.len(), res.tr_indexed_explicit.len()); - - } -} +#![cfg(test)] + +use ascent::{ascent, ascent_run}; +use itertools::Itertools; + +#[test] +fn test_eq_rel2_in_ascent() { + ascent! { + struct EqRel2TestProg; + + #[ds(crate::eqrel)] + relation indexed_eq_rel(u32, u64, u64); + + relation seed(u32); + seed(x) <-- for x in 0..2; + + relation seed2(u32, u64, u64); + + seed2(a, 1, 2) <-- seed(a); + seed2(a, x + 1, y + 1) <-- seed2(a, x, y), if *y < 20; + + indexed_eq_rel(a, x, y) <-- seed2(a, x, y); + + relation indexed_eq_rel_materialized(u32, u64, u64); + indexed_eq_rel_materialized(a, x, y) <-- indexed_eq_rel(a, x, y); + } + + let mut prog = EqRel2TestProg::default(); + + prog.run(); + + println!("indexed_eq_rel_materialized len: {}", prog.indexed_eq_rel_materialized.len()); + assert_eq!(prog.indexed_eq_rel_materialized.len(), 2 * 20usize.pow(2)); +} + +#[test] +fn test_eq_rel2_in_ascent2() { + let test_cases = vec![ + (0..2).flat_map(|a| (3..5).flat_map(move |x| (9..11).map(move |y| (a, x, y)))).collect_vec(), + vec![(0, 10, 11), (0, 12, 13), (0, 13, 14)], + vec![(0, 10, 11), (0, 12, 13), (0, 13, 14), (0, 12, 11), (1, 11, 14), (1, 12, 13), (1, 13, 14)], + ]; + for (i, seed_rel) in test_cases.into_iter().enumerate() { + println!("test {}", i); + let res = ascent_run! { + struct EqRel2TestProg; + + #[ds(crate::eqrel)] + relation indexed_eq_rel(u32, u64, u64); + + relation indexed_eq_rel_explicit(u32, u64, u64); + + indexed_eq_rel_explicit(a, x, x), indexed_eq_rel_explicit(a, y, x) <-- indexed_eq_rel_explicit(a, x, y); + indexed_eq_rel_explicit(a, x, z) <-- indexed_eq_rel_explicit(a, x, y), indexed_eq_rel_explicit(a, y, z); + + indexed_eq_rel(a, x, y), indexed_eq_rel_explicit(a, x, y) <-- for (a, x, y) in seed_rel; + + + relation indexed_eq_rel_materialized(u32, u64, u64); + indexed_eq_rel_materialized(a, x, y) <-- indexed_eq_rel(a, x, y); + + relation foo(u32, u64, u64); + foo(a, x, y) <-- for a in 0..2, for x in 8..12, for y in 10..15; + + relation test1_expected(u32, u64, u64); + relation test1_actual(u32, u64, u64); + + test1_expected(a, x, y) <-- foo(a, x, y), indexed_eq_rel_explicit(a, x, y); + test1_actual(a, x, y) <-- foo(a, x, y), indexed_eq_rel(a, x, y); + + relation bar(u32, u64); + bar(a, x) <-- for a in 0..3, for x in 10..16; + + relation test2_expected(u32, u64, u64); + relation test2_actual(u32, u64, u64); + + test2_expected(a, x, y) <-- bar(a, x), indexed_eq_rel_explicit(a, x, y); + test2_actual(a, x, y) <-- bar(a, x), indexed_eq_rel(a, x, y); + + relation test3_expected(u32, u64, u64); + relation test3_actual(u32, u64, u64); + + test3_expected(a, x, y) <-- bar(a, _), indexed_eq_rel_explicit(a, x, y); + test3_actual(a, x, y) <-- bar(a, _), indexed_eq_rel(a, x, y); + + relation test4_expected(u32, u64, u64); + relation test4_actual(u32, u64, u64); + + test4_expected(a, x, y) <-- bar(_, x), indexed_eq_rel_explicit(a, x, y); + test4_actual(a, x, y) <-- bar(_, x), indexed_eq_rel_explicit(a, x, y); + + relation test5_expected(u32, u64, u64); + relation test5_actual(u32, u64, u64); + + test5_expected(a, x, x) <-- for x in [5, 7, 9, 11, 13], indexed_eq_rel_explicit(a, x, x); + test5_actual(a, x, x) <-- for x in [5, 7, 9, 11, 13], indexed_eq_rel(a, x, x); + + relation test6_expected(u32, u64, u64); + relation test6_actual(u32, u64, u64); + + test6_expected(a, x, y) <-- for x in 8..14, indexed_eq_rel_explicit(a, x, y); + test6_actual(a, x, y) <-- for x in 8..14, indexed_eq_rel(a, x, y); + }; + + println!("indexed_eq_rel_materialized len: {}", res.indexed_eq_rel_materialized.len()); + println!("test1_actual len: {}", res.test1_actual.len()); + println!("test2_actual len: {}", res.test2_actual.len()); + println!("test3_actual len: {}", res.test3_actual.len()); + println!("test4_actual len: {}", res.test4_actual.len()); + println!("test5_actual len: {}", res.test5_actual.len()); + println!("test6_actual len: {}", res.test6_actual.len()); + + assert_eq!(res.indexed_eq_rel_materialized.len(), res.indexed_eq_rel_explicit.len()); + assert_eq!(res.test1_actual.len(), res.test1_expected.len()); + assert_eq!(res.test2_actual.len(), res.test2_expected.len()); + assert_eq!(res.test3_actual.len(), res.test3_expected.len()); + assert_eq!(res.test4_actual.len(), res.test4_expected.len()); + assert_eq!(res.test5_actual.len(), res.test5_expected.len()); + assert_eq!(res.test6_actual.len(), res.test6_expected.len()); + } +} + +#[test] +fn test_trrel_uf_in_ascent() { + let ub = 15; + let res = ascent_run! { + relation seed(u32); + seed(1); + seed(x + 1) <-- seed(x), if x < &ub; + + #[ds(crate::trrel_uf)] + relation tr(u32, u32); + + tr(x, x + 1), tr(x + 1, x) <-- seed(x); + + // included to pull all rules into the same SCC + seed(0) <-- if false, tr(1000, 2000); + }; + assert_eq!(res.__tr_ind_common.count_exact() as u32, (ub + 1).pow(2)); +} + +#[test] +fn test_trrel_uf_ternary_in_ascent() { + let ub = 15; + let res = ascent_run! { + relation seed(u32); + seed(1); + seed(x + 1) <-- seed(x), if x < &ub; + + #[ds(crate::trrel_uf)] + relation tr(u8, u32, u32); + + tr(i, x, x + 1), tr(i, x + 1, x) <-- seed(x), for i in 0..10; + + // included to pull all rules into the same SCC + seed(0) <-- if false, tr(_, 1000, 2000); + }; + assert_eq!(res.__tr_ind_common.0.map[&1].count_exact() as u32, (ub + 1).pow(2)); +} + +#[test] +fn test_trrel1() { + let test_cases = vec![ + vec![(1, 2), (2, 3)], + vec![(1, 2)], + (4..6).flat_map(|x| (6..9).map(move |y| (x, y))).collect(), + (0..5).map(|x| (x, x + 1)).collect(), + ]; + + for (i, seed_rel) in test_cases.into_iter().enumerate() { + let res = ascent_run! { + #[ds(crate::trrel)] + relation tr(u32, u32); + + relation tr_explicit(u32, u32); + tr_explicit(x, z) <-- tr_explicit(x, y), tr_explicit(y, z); + + tr(x, y), tr_explicit(x, y) <-- for (x, y) in seed_rel; + + tr(x, y) <-- tr(x, y); + relation tr_materialized(u32, u32); + tr_materialized(x, y) <-- tr(x, y); + + // testing New, Old variants of `TrRelIndCommon` + #[ds(crate::trrel)] + relation dummy1(u32, u32); + + #[ds(crate::trrel)] + relation dummy2(u32, u32); + + dummy2(x, y) <-- dummy1(x, y); + }; + + println!("TEST {}", i); + println!("explicit len: {}", res.tr_explicit.len()); + println!("materialized len: {}", res.tr_materialized.len()); + + assert_eq!(res.tr_explicit.len(), res.tr_materialized.len()); + } +} + +#[test] +fn test_trrel_reflexive_facts() { + let test_cases = vec![ + vec![(1, 1), (1, 2)], + vec![(1, 1)], + (4..6).flat_map(|x| (5..7).map(move |y| (x, y))).collect(), + (0..5).map(|x| (x, x)).collect(), + ]; + + for (i, seed_rel) in test_cases.into_iter().enumerate() { + let res = ascent_run! { + #[ds(crate::trrel)] + relation tr(u32, u32); + + relation tr_explicit(u32, u32); + tr_explicit(x, z) <-- tr_explicit(x, y), tr_explicit(y, z); + + tr(x, y), tr_explicit(x, y) <-- for (x, y) in seed_rel; + + relation tr_materialized(u32, u32); + tr_materialized(x, y) <-- tr(x, y); + }; + + println!("TEST {}", i); + println!("explicit len: {}", res.tr_explicit.len()); + assert_eq!(res.tr_explicit.len(), res.tr_materialized.len()); + } +} + +#[test] +fn test_trrel2_in_ascent() { + let test_cases = vec![ + (1..4).flat_map(|a| (3..7).flat_map(move |x| (11..15).map(move |y| (a, x, y)))).collect_vec(), + (0..3) + .flat_map(|a| { + (0..15) + .filter(move |x| x % 3 == a as u64 % 3) + .flat_map(move |x| (a as u64 + 14..19).map(move |y| (a, x, y))) + }) + .collect_vec(), + (1..4).flat_map(|a| (5..10).flat_map(move |x| (12..17).map(move |y| (a, x, y)))).collect_vec(), + vec![(0, 10, 11), (0, 5, 13), (0, 6, 13), (0, 7, 14), (0, 12, 13), (0, 13, 14)], + vec![(0, 7, 8), (0, 8, 10), (0, 8, 13), (0, 9, 11), (1, 9, 14), (1, 9, 13), (1, 10, 14)] + .into_iter() + .flat_map(|(a, x, y)| (0..3).map(move |o| (a + o, x, y - o as u64))) + .collect(), + ]; + + for (i, seed_rel) in test_cases.into_iter().enumerate() { + let res = ascent_run! { + #[ds(crate::trrel)] + relation tr_indexed(u32, u64, u64); + + relation tr_indexed_explicit(u32, u64, u64); + tr_indexed_explicit(a, x, z) <-- tr_indexed_explicit(a, x, y), tr_indexed_explicit(a, y, z); + + relation empty(u32, u64); + + relation seed(u32, u64, u64) = seed_rel; + + tr_indexed(a, x, y), tr_indexed_explicit(a, x, y) <-- seed(a, x, y); + + // checking join works, should add no tuples + tr_indexed(a, x, y) <-- tr_indexed(a, x, y), empty(a, y); + + relation tr_indexed_materialized(u32, u64, u64); + + tr_indexed_materialized(a, x, y) <-- tr_indexed(a, x, y); + + relation foo(u32, u64); + foo(a, x) <-- for a in 0..1, for x in 4..12; + foo(a, x) <-- for a in 0..1, for x in (0..20).filter(|x| x % 3 == 2); + + + relation test1_actual(u32, u64); + relation test1_expected(u32, u64); + + test1_actual(a, x) <-- tr_indexed(a, x, _), foo(a, x); + test1_expected(a, x) <-- tr_indexed_explicit(a, x, _), foo(a, x); + + relation bar(u32, u64, u64); + bar(a, x, y) <-- for a in 0..1, for x in 5..6, for y in 10..12; + bar(a, x, y) <-- + for a in 1..5, + for x in (5..21).filter(|x| x % 3 == a as u64 % 3), + for y in (12..18).filter(|y| y % 5 == a as u64 % 5); + + + relation test2_actual(u32, u64, u64); + relation test2_expected(u32, u64, u64); + + test2_actual(a, x, y) <-- tr_indexed(a, x, y), bar(a, x, y); + test2_expected(a, x, y) <-- tr_indexed_explicit(a, x, y), bar(a, x, y); + + relation test3_actual(u32, u64, u64); + relation test3_expected(u32, u64, u64); + + test3_actual(a, x, y) <-- tr_indexed(a, x, y), bar(_, x, y); + test3_expected(a, x, y) <-- tr_indexed_explicit(a, x, y), bar(_, x, y); + + relation test4_actual(u32, u64, u64); + relation test4_expected(u32, u64, u64); + + test4_actual(a, x, y) <-- tr_indexed(a, x, y), bar(_, x, _); + test4_expected(a, x, y) <-- tr_indexed_explicit(a, x, y), bar(_, x, _); + + relation test5_actual(u32, u64, u64); + relation test5_expected(u32, u64, u64); + + test5_actual(a, x, y) <-- tr_indexed(a, x, y), bar(_, _, y); + test5_expected(a, x, y) <-- tr_indexed_explicit(a, x, y), bar(_, _, y); + + }; + + println!("============ TEST CASE {} ============", i); + println!("tr_indexed_explicit len: {}", res.tr_indexed_explicit.len()); + println!("test1_expected len: {}", res.test1_expected.len()); + println!("test2_expected len: {}", res.test2_expected.len()); + println!("test3_expected len: {}", res.test3_expected.len()); + println!("test4_expected len: {}", res.test4_expected.len()); + println!("test5_expected len: {}", res.test5_expected.len()); + + // println!("test4_expected: {:?}", res.test4_expected); + // println!("test4_actual: {:?}", res.test4_actual); + // use hashbrown::HashSet; + // println!("test_4 diff: {:?}", + // res.test4_expected.iter().cloned().collect::>() + // .symmetric_difference(&res.test4_actual.iter().cloned().collect::>())); + + assert_eq!(res.test1_expected.len(), res.test1_actual.len()); + assert_eq!(res.test2_expected.len(), res.test2_actual.len()); + assert_eq!(res.test3_expected.len(), res.test3_actual.len()); + assert_eq!(res.test4_expected.len(), res.test4_actual.len()); + assert_eq!(res.test5_expected.len(), res.test5_actual.len()); + + assert_eq!(res.tr_indexed_materialized.len(), res.tr_indexed_explicit.len()); + } +} diff --git a/byods/ascent-byods-rels/src/trrel.rs b/byods/ascent-byods-rels/src/trrel.rs index c58f40e..a100a50 100644 --- a/byods/ascent-byods-rels/src/trrel.rs +++ b/byods/ascent-byods-rels/src/trrel.rs @@ -1,139 +1,133 @@ -//! transitive relations for Ascent - -pub use crate::eqrel::rel_codegen as rel_codegen; - -#[doc(hidden)] -#[macro_export] -macro_rules! trrel_rel { - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, ()) => { - $crate::fake_vec::FakeVec<($col0, $col1)> - }; - - // ternary - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, ()) => { - $crate::fake_vec::FakeVec<($col0, $col1, $col2)> - }; -} -pub use trrel_rel as rel; - -#[doc(hidden)] -#[macro_export] -macro_rules! trrel_rel_full_ind { - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), $key: ty, $val: ty) => { - $crate::trrel_binary_ind::ToTrRelIndFull<$col0> - }; - - // ternary - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), $key: ty, $val: ty) => { - $crate::trrel_ternary_ind::ToTrRel2IndFull<$col0, $col1> - }; -} -pub use trrel_rel_full_ind as rel_full_ind; - -#[doc(hidden)] -#[macro_export] -macro_rules! trrel_rel_ind { - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [0], $key: ty, $val: ty) => { - $crate::trrel_binary_ind::ToTrRelInd0<$col0> - }; - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [1], $key: ty, $val: ty) => { - $crate::trrel_binary_ind::ToTrRelInd1<$col0> - }; - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [], $key: ty, $val: ty) => { - $crate::trrel_binary_ind::ToTrRelIndNone<$col0> - }; - - // ternary - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [], $key: ty, $val: ty) => { - $crate::trrel_ternary_ind::ToTrRel2IndNone<$col0, $col1> - }; - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0], $key: ty, $val: ty) => { - $crate::trrel_ternary_ind::ToTrRel2Ind0<$col0, $col1> - }; - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [1], $key: ty, $val: ty) => { - $crate::trrel_ternary_ind::ToTrRel2Ind1<$col0, $col1> - }; - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [2], $key: ty, $val: ty) => { - $crate::trrel_ternary_ind::ToTrRel2Ind2<$col0, $col1> - }; - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 1], $key: ty, $val: ty) => { - $crate::trrel_ternary_ind::ToTrRel2Ind0_1<$col0, $col1> - }; - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 2], $key: ty, $val: ty) => { - $crate::trrel_ternary_ind::ToTrRel2Ind0_2<$col0, $col1> - }; - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [1, 2], $key: ty, $val: ty) => { - $crate::trrel_ternary_ind::ToTrRel2Ind1_2<$col0, $col1> - }; - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 1, 2], $key: ty, $val: ty) => { - $crate::trrel_ternary_ind::ToTrRel2IndFull<$col0, $col1> - }; -} -pub use trrel_rel_ind as rel_ind; - -#[doc(hidden)] -#[macro_export] -macro_rules! trrel_rel_ind_common { - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, ()) => { - $crate::trrel_binary_ind::TrRelIndCommon<$col0> - }; - - // ternary - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: tt, ser, ()) => { - $crate::trrel_ternary_ind::TrRel2IndCommonWrapper< - // reverse_map_1 required: - {$crate::inds_contain!($indices, [1]) || $crate::inds_contain!($indices, [1, 2])}, - // reverse_map_2 required: - {$crate::inds_contain!($indices, [2]) || $crate::inds_contain!($indices, [1, 2])}, - $col0, $col1> - }; -} -pub use trrel_rel_ind_common as rel_ind_common; - -#[doc(hidden)] -#[macro_export] -macro_rules! inds_contain { - ([], $ind: tt) => { - false - }; - ([$head: tt], $ind: tt) => { - ($crate::arrs_eq!($head, $ind)) - }; - ([$head: tt, $($tail: tt),*], $ind: tt) => { - ($crate::arrs_eq!($head, $ind)) || $crate::inds_contain!([$($tail),*], $ind) - }; -} - -#[doc(hidden)] -#[macro_export] -macro_rules! arrs_eq { - ([], []) => { true }; - ([$x: expr], [$y: expr]) => { $x == $y }; - ([$x: expr, $($xs: expr),*], [$y: expr, $($ys: expr),*]) => { - $x == $y && $crate::arrs_eq!([$($xs),*], [$($ys),*]) - }; - ([$($xs: expr),*], [$($ys: expr),*]) => { false }; -} - -#[test] -fn test_arrs_eq() { - let test1 = arrs_eq!([1, 2], [1, 2]); - assert!(test1); - assert!(!arrs_eq!([1], [1, 2])); - assert!(arrs_eq!([1], [1])); - assert!(arrs_eq!([], [])); - assert!(!arrs_eq!([1, 2], [1])); -} - -#[cfg(test)] -#[allow(dead_code)] -fn _test_trrel_rel_ind_common() { - let _ind_common1: crate::trrel::rel_ind_common!(rel, (u64, u32, u32), [[], [0, 1], [0], [0, 1, 2]], ser, ()); - let _ind_common2: crate::trrel::rel_ind_common!( - rel, - (u32, u64, u64), - [[0, 1, 2], [0], [1], [0, 1]], - ser, - () - ); -} \ No newline at end of file +//! transitive relations for Ascent + +pub use crate::eqrel::rel_codegen; + +#[doc(hidden)] +#[macro_export] +macro_rules! trrel_rel { + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, ()) => { + $crate::fake_vec::FakeVec<($col0, $col1)> + }; + + // ternary + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, ()) => { + $crate::fake_vec::FakeVec<($col0, $col1, $col2)> + }; +} +pub use trrel_rel as rel; + +#[doc(hidden)] +#[macro_export] +macro_rules! trrel_rel_full_ind { + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), $key: ty, $val: ty) => { + $crate::trrel_binary_ind::ToTrRelIndFull<$col0> + }; + + // ternary + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), $key: ty, $val: ty) => { + $crate::trrel_ternary_ind::ToTrRel2IndFull<$col0, $col1> + }; +} +pub use trrel_rel_full_ind as rel_full_ind; + +#[doc(hidden)] +#[macro_export] +macro_rules! trrel_rel_ind { + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [0], $key: ty, $val: ty) => { + $crate::trrel_binary_ind::ToTrRelInd0<$col0> + }; + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [1], $key: ty, $val: ty) => { + $crate::trrel_binary_ind::ToTrRelInd1<$col0> + }; + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, (), [], $key: ty, $val: ty) => { + $crate::trrel_binary_ind::ToTrRelIndNone<$col0> + }; + + // ternary + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [], $key: ty, $val: ty) => { + $crate::trrel_ternary_ind::ToTrRel2IndNone<$col0, $col1> + }; + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0], $key: ty, $val: ty) => { + $crate::trrel_ternary_ind::ToTrRel2Ind0<$col0, $col1> + }; + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [1], $key: ty, $val: ty) => { + $crate::trrel_ternary_ind::ToTrRel2Ind1<$col0, $col1> + }; + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [2], $key: ty, $val: ty) => { + $crate::trrel_ternary_ind::ToTrRel2Ind2<$col0, $col1> + }; + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 1], $key: ty, $val: ty) => { + $crate::trrel_ternary_ind::ToTrRel2Ind0_1<$col0, $col1> + }; + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 2], $key: ty, $val: ty) => { + $crate::trrel_ternary_ind::ToTrRel2Ind0_2<$col0, $col1> + }; + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [1, 2], $key: ty, $val: ty) => { + $crate::trrel_ternary_ind::ToTrRel2Ind1_2<$col0, $col1> + }; + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: expr, ser, (), [0, 1, 2], $key: ty, $val: ty) => { + $crate::trrel_ternary_ind::ToTrRel2IndFull<$col0, $col1> + }; +} +pub use trrel_rel_ind as rel_ind; + +#[doc(hidden)] +#[macro_export] +macro_rules! trrel_rel_ind_common { + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, ()) => { + $crate::trrel_binary_ind::TrRelIndCommon<$col0> + }; + + // ternary + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: tt, ser, ()) => { + $crate::trrel_ternary_ind::TrRel2IndCommonWrapper< + // reverse_map_1 required: + {$crate::inds_contain!($indices, [1]) || $crate::inds_contain!($indices, [1, 2])}, + // reverse_map_2 required: + {$crate::inds_contain!($indices, [2]) || $crate::inds_contain!($indices, [1, 2])}, + $col0, $col1> + }; +} +pub use trrel_rel_ind_common as rel_ind_common; + +#[doc(hidden)] +#[macro_export] +macro_rules! inds_contain { + ([], $ind: tt) => { + false + }; + ([$head: tt], $ind: tt) => { + ($crate::arrs_eq!($head, $ind)) + }; + ([$head: tt, $($tail: tt),*], $ind: tt) => { + ($crate::arrs_eq!($head, $ind)) || $crate::inds_contain!([$($tail),*], $ind) + }; +} + +#[doc(hidden)] +#[macro_export] +macro_rules! arrs_eq { + ([], []) => { true }; + ([$x: expr], [$y: expr]) => { $x == $y }; + ([$x: expr, $($xs: expr),*], [$y: expr, $($ys: expr),*]) => { + $x == $y && $crate::arrs_eq!([$($xs),*], [$($ys),*]) + }; + ([$($xs: expr),*], [$($ys: expr),*]) => { false }; +} + +#[test] +fn test_arrs_eq() { + let test1 = arrs_eq!([1, 2], [1, 2]); + assert!(test1); + assert!(!arrs_eq!([1], [1, 2])); + assert!(arrs_eq!([1], [1])); + assert!(arrs_eq!([], [])); + assert!(!arrs_eq!([1, 2], [1])); +} + +#[cfg(test)] +#[allow(dead_code)] +fn _test_trrel_rel_ind_common() { + let _ind_common1: crate::trrel::rel_ind_common!(rel, (u64, u32, u32), [[], [0, 1], [0], [0, 1, 2]], ser, ()); + let _ind_common2: crate::trrel::rel_ind_common!(rel, (u32, u64, u64), [[0, 1, 2], [0], [1], [0, 1]], ser, ()); +} diff --git a/byods/ascent-byods-rels/src/trrel_binary.rs b/byods/ascent-byods-rels/src/trrel_binary.rs index f6b6c4a..6e9d441 100644 --- a/byods/ascent-byods-rels/src/trrel_binary.rs +++ b/byods/ascent-byods-rels/src/trrel_binary.rs @@ -1,103 +1,96 @@ -use std::hash::{Hash, BuildHasherDefault}; - -use hashbrown::{HashMap}; -use rustc_hash::FxHasher; - -pub type MyHashSetIter<'a, T> = hashbrown::hash_set::Iter<'a, T>; -pub type MyHashSet = hashbrown::HashSet; - -pub struct TrRel { - pub(crate) map: HashMap>, BuildHasherDefault>, - pub(crate) reverse_map: HashMap>, BuildHasherDefault>, - // pub(crate) precursor_map: HashMap>, BuildHasherDefault>, - // pub(crate) precursor_set: HashSet<(T, T), BuildHasherDefault>, - pub(crate) precursor_set: Vec<(T, T)>, - pub anti_reflexive: bool, -} - -impl Default for TrRel { - fn default() -> Self { - Self { - map: Default::default(), - reverse_map: Default::default(), - // precursor_map: Default::default(), - precursor_set: Default::default(), - anti_reflexive: true - } - } -} - - -impl TrRel { - - /// returns true if this tuple did not exist in the transitive relation - pub fn insert(&mut self, x: T, y: T) -> bool { - - // TODO is this correct? - - if x == y { - return false; - } - - if self.map.get(&x).map_or(false, |s| s.contains(&y)) { - return false; - } - - // if !self.precursor_map.entry(x.clone()).or_default().insert(y.clone()) { - // return false; - // } - // if !self.precursor_set.insert((x.clone(), y.clone())) { - // return false; - // } - self.precursor_set.push((x.clone(), y.clone())); - - let mut x_reverse_map = std::mem::take(self.reverse_map.entry(x.clone()).or_default()); - let mut y_map = std::mem::take(self.map.entry(y.clone()).or_default()); - // let y_map2 = y_map.iter().chain([&x]).map(|elem| (hash_one(self.map.hasher(), elem), elem.clone())).collect_vec(); - for x_prime in x_reverse_map.iter().chain([&x]) { - if x_prime != &y { - let x_prime_map = self.map.entry(x_prime.clone()).or_default(); - x_prime_map.extend(y_map.iter().chain([&y]).filter(|&a| a != x_prime).cloned()); - // set_extend_with_hash_no_check(x_prime_map, y_map2.iter().cloned()); - // for y_prime in y_map.iter().chain([&y]) { - // self.reverse_map.entry(y_prime.clone()).or_default().insert(x_prime.clone()); - // } - } - } - - // let x_reverse_map2 = x_reverse_map.iter().chain([&x]).map(|elem| (hash_one(self.map.hasher(), elem), elem.clone())).collect_vec(); - for y_prime in y_map.iter().chain([&y]) { - if y_prime != &x { - let y_prime_reverse_map = self.reverse_map.entry(y_prime.clone()).or_default(); - y_prime_reverse_map.extend(x_reverse_map.iter().chain([&x]).filter(|&a| a != y_prime).cloned()); - // set_extend_with_hash_no_check(y_prime_reverse_map, x_reverse_map2.iter().cloned()); - } - } - if x == y { - x_reverse_map.insert(y.clone()); - y_map.insert(x.clone()); - } - self.reverse_map.insert(x.clone(), x_reverse_map); - self.map.insert(y.clone(), y_map); - true - } - - pub fn iter_all(&self) -> impl Iterator + '_ { - self.map.iter().flat_map(|(x, x_set)| x_set.iter().map(move |y| (x, y))) - } - - #[inline] - pub fn contains(&self, x: &T, y: &T) -> bool { - self.map.get(x).map_or(false, |s| s.contains(y)) - } - - pub fn count_estimate(&self) -> usize { - let sample_size = 3; - let sum = self.map.values().take(sample_size).map(|x| x.len()).sum::(); - sum * self.map.len() / sample_size.min(self.map.len()).max(1) - } - - pub fn count_exact(&self) -> usize { - self.map.values().map(|x| x.len()).sum() - } -} +use std::hash::{BuildHasherDefault, Hash}; + +use hashbrown::HashMap; +use rustc_hash::FxHasher; + +pub type MyHashSetIter<'a, T> = hashbrown::hash_set::Iter<'a, T>; +pub type MyHashSet = hashbrown::HashSet; + +pub struct TrRel { + pub(crate) map: HashMap>, BuildHasherDefault>, + pub(crate) reverse_map: HashMap>, BuildHasherDefault>, + // pub(crate) precursor_map: HashMap>, BuildHasherDefault>, + // pub(crate) precursor_set: HashSet<(T, T), BuildHasherDefault>, + pub(crate) precursor_set: Vec<(T, T)>, + pub anti_reflexive: bool, +} + +impl Default for TrRel { + fn default() -> Self { + Self { + map: Default::default(), + reverse_map: Default::default(), + // precursor_map: Default::default(), + precursor_set: Default::default(), + anti_reflexive: true, + } + } +} + +impl TrRel { + /// returns true if this tuple did not exist in the transitive relation + pub fn insert(&mut self, x: T, y: T) -> bool { + // TODO is this correct? + + if x == y { + return false; + } + + if self.map.get(&x).map_or(false, |s| s.contains(&y)) { + return false; + } + + // if !self.precursor_map.entry(x.clone()).or_default().insert(y.clone()) { + // return false; + // } + // if !self.precursor_set.insert((x.clone(), y.clone())) { + // return false; + // } + self.precursor_set.push((x.clone(), y.clone())); + + let mut x_reverse_map = std::mem::take(self.reverse_map.entry(x.clone()).or_default()); + let mut y_map = std::mem::take(self.map.entry(y.clone()).or_default()); + // let y_map2 = y_map.iter().chain([&x]).map(|elem| (hash_one(self.map.hasher(), elem), elem.clone())).collect_vec(); + for x_prime in x_reverse_map.iter().chain([&x]) { + if x_prime != &y { + let x_prime_map = self.map.entry(x_prime.clone()).or_default(); + x_prime_map.extend(y_map.iter().chain([&y]).filter(|&a| a != x_prime).cloned()); + // set_extend_with_hash_no_check(x_prime_map, y_map2.iter().cloned()); + // for y_prime in y_map.iter().chain([&y]) { + // self.reverse_map.entry(y_prime.clone()).or_default().insert(x_prime.clone()); + // } + } + } + + // let x_reverse_map2 = x_reverse_map.iter().chain([&x]).map(|elem| (hash_one(self.map.hasher(), elem), elem.clone())).collect_vec(); + for y_prime in y_map.iter().chain([&y]) { + if y_prime != &x { + let y_prime_reverse_map = self.reverse_map.entry(y_prime.clone()).or_default(); + y_prime_reverse_map.extend(x_reverse_map.iter().chain([&x]).filter(|&a| a != y_prime).cloned()); + // set_extend_with_hash_no_check(y_prime_reverse_map, x_reverse_map2.iter().cloned()); + } + } + if x == y { + x_reverse_map.insert(y.clone()); + y_map.insert(x.clone()); + } + self.reverse_map.insert(x.clone(), x_reverse_map); + self.map.insert(y.clone(), y_map); + true + } + + pub fn iter_all(&self) -> impl Iterator + '_ { + self.map.iter().flat_map(|(x, x_set)| x_set.iter().map(move |y| (x, y))) + } + + #[inline] + pub fn contains(&self, x: &T, y: &T) -> bool { self.map.get(x).map_or(false, |s| s.contains(y)) } + + pub fn count_estimate(&self) -> usize { + let sample_size = 3; + let sum = self.map.values().take(sample_size).map(|x| x.len()).sum::(); + sum * self.map.len() / sample_size.min(self.map.len()).max(1) + } + + pub fn count_exact(&self) -> usize { self.map.values().map(|x| x.len()).sum() } +} diff --git a/byods/ascent-byods-rels/src/trrel_binary_ind.rs b/byods/ascent-byods-rels/src/trrel_binary_ind.rs index f512aa7..344681a 100644 --- a/byods/ascent-byods-rels/src/trrel_binary_ind.rs +++ b/byods/ascent-byods-rels/src/trrel_binary_ind.rs @@ -1,487 +1,504 @@ -use std::hash::{Hash, BuildHasherDefault}; -use std::iter::Map; -use std::marker::PhantomData; -use std::time::{Duration, Instant}; -use ascent::internal::{RelIndexMerge, RelIndexReadAll, RelIndexRead, RelFullIndexWrite, RelIndexWrite, RelFullIndexRead}; -use ascent::internal::ToRelIndex; -use hashbrown::HashMap; -use rustc_hash::FxHasher; -use crate::binary_rel::BinaryRel; -use crate::iterator_from_dyn::IteratorFromDyn; -use crate::rel_boilerplate::NoopRelIndexWrite; -use crate::trrel_binary::{MyHashSetIter, MyHashSet}; -use crate::utils::{move_hash_map_of_hash_set_contents_disjoint, move_hash_map_of_vec_contents}; - -// TODO do we still need two variants? -pub enum TrRelIndCommon { - New { rel: BinaryRel, anti_reflexive: bool }, - Old { rel: BinaryRel, anti_reflexive: bool }, -} - -impl TrRelIndCommon { - - pub fn anti_reflexive(&self) -> bool { - match self { - TrRelIndCommon::New{ anti_reflexive, .. } => *anti_reflexive, - TrRelIndCommon::Old{ anti_reflexive, .. } => *anti_reflexive, - } - } - - #[inline] - pub fn make_new() -> Self { Self::New { rel: Default::default(), anti_reflexive: true } } - - #[inline] - pub fn rel(&self) -> &BinaryRel { - match self { - TrRelIndCommon::New { rel, .. } => rel, - TrRelIndCommon::Old { rel, .. } => rel, - } - } - - pub fn unwrap_new_mut(&mut self) -> &mut BinaryRel { - match self { - TrRelIndCommon::New{ rel, .. } => rel, - TrRelIndCommon::Old{..} => panic!("TrRelIndCommon: unwrap_new_mut called on Old"), - } - } - pub fn unwrap_new(&self) -> &BinaryRel { - match self { - TrRelIndCommon::New{ rel, .. } => rel, - TrRelIndCommon::Old{..} => panic!("TrRelIndCommon: unwrap_new called on Old"), - } - } - - pub fn unwrap_old(&self) -> &BinaryRel { - match self { - TrRelIndCommon::Old{rel, ..} => rel, - TrRelIndCommon::New{..} => panic!("TrRelIndCommon: unwrap_old called on New"), - } - } - - #[inline] - pub fn insert(&mut self, x: T, y: T) -> bool { - let rel = self.unwrap_new_mut(); - rel.insert(x, y) - } - - #[inline] - pub fn insert_by_ref(&mut self, x: &T, y: &T) -> bool { - let rel = self.unwrap_new_mut(); - rel.insert_by_ref(x, y) - } - - pub fn is_empty(&self) -> bool { - match self { - TrRelIndCommon::New{rel, ..} => rel.map.is_empty(), - TrRelIndCommon::Old{rel, ..} => rel.map.is_empty(), - } - } - -} - -pub static mut MERGE_TIME: Duration = Duration::ZERO; -pub static mut MERGE_COUNT: usize = 0; - -pub trait ToTrRelIndCommon { - fn to_tr_rel_ind(&self) -> &TrRelIndCommon; - fn to_tr_rel_ind_mut(&mut self) -> &mut TrRelIndCommon; -} - -impl ToTrRelIndCommon for TrRelIndCommon { - fn to_tr_rel_ind(&self) -> &TrRelIndCommon { self } - fn to_tr_rel_ind_mut(&mut self) -> &mut TrRelIndCommon { self } -} - -impl Default for TrRelIndCommon { - #[inline] - fn default() -> Self {Self::Old { rel: Default::default(), anti_reflexive: true }} -} - -impl RelIndexMerge for TrRelIndCommon { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { - panic!("merge_delta_to_total_new_to_delta must be called instead.") - } - - fn init(new: &mut Self, delta: &mut Self, total: &mut Self) { - assert!(matches!(delta, Self::Old { .. })); - assert!(matches!(total, Self::Old { .. })); - *new = Self::New { rel: Default::default(), anti_reflexive: delta.anti_reflexive() }; - } - - fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { - let before = Instant::now(); - let anti_reflexive = total.anti_reflexive(); - - let mut total_rel = match total { - TrRelIndCommon::New{..} => Default::default(), - TrRelIndCommon::Old{rel, ..} => std::mem::take(rel), - }; - let mut delta_rel = match delta { - TrRelIndCommon::New{..} => Default::default(), - TrRelIndCommon::Old{rel, ..} => std::mem::take(rel), - }; - - move_hash_map_of_hash_set_contents_disjoint(&mut delta_rel.map, &mut total_rel.map); - move_hash_map_of_vec_contents(&mut delta_rel.reverse_map, &mut total_rel.reverse_map); - - let mut new_delta = BinaryRel::::default(); - - // let new_delta_prog = ascent::ascent_run! { - // struct AscentProg; - // relation delta(T, T); - // relation new(T, T); - // relation total_map(T); - // relation total_reverse_map(T); - // total_map(x) <-- for x in total_rel.map.keys(); - // total_reverse_map(x) <-- for x in total_rel.reverse_map.keys(); - - // new(x, y), delta(x, y) <-- for (x, y) in new.unwrap_new().iter(); - // delta(x, z) <-- delta(x, y), new(y, z); - // delta(x, z) <-- delta(x, y), total_map(y), for z in total_rel.map.get(y).unwrap(); - // delta(x, z) <-- delta(y, z), total_reverse_map(y), for x in total_rel.reverse_map.get(y).unwrap(); - // }; - - // new_delta.reverse_map = new_delta_prog.delta_indices_1.0.into_iter() - // .map(|(k, v)| (k.0, v.into_iter().map(|x| x.0).collect())).collect(); - - // new_delta.map = new_delta_prog.delta_indices_0.0.into_iter() - // .map(|(k, v)| (k.0, v.into_iter().map(|x| x.0).collect())).collect(); - - - type RelMap = HashMap::>, BuildHasherDefault>; - type RelRevMap = HashMap::, BuildHasherDefault>; - - let new_rel = std::mem::take(new.unwrap_new_mut()); - let new_map = new_rel.map; - let mut delta_delta_map = new_map.clone(); - let mut delta_delta_rev_map = new_rel.reverse_map; - - let mut delta_total_map = RelMap::::default(); - let mut delta_total_rev_map = RelRevMap::::default(); - - let mut delta_new_map = RelMap::::default(); - let mut delta_new_rev_map = RelRevMap::::default(); - - fn join( - target: &mut RelMap, target_rev: &mut RelRevMap, rel1: &RelMap, rel2_rev: &RelRevMap, - mut can_add: impl FnMut(&T, &T) -> bool, name: &str - ) -> bool { - let mut changed = false; - if rel1.len() < rel2_rev.len() { - for (x, x_set) in rel1.iter() { - assert!(!x_set.is_empty(), "bad join {name}. rel1 has non-empty sets"); - if let Some(x_rev_set) = rel2_rev.get(x) { - assert!(!x_set.is_empty(), "bad join {name}. rel2_rev has non-empty sets"); - for w in x_rev_set { - let entry = target.entry(w.clone()).or_default(); - for y in x_set.iter() { - if !can_add(w, y) {continue} - if entry.insert(y.clone()) { - target_rev.entry(y.clone()).or_default().push(w.clone()); - changed = true; - } - } - if entry.is_empty() { - target.remove(w); - } - } - } - } - } else { - for (x, x_rev_set) in rel2_rev.iter() { - if let Some(x_set) = rel1.get(x) { - for w in x_rev_set { - let entry = target.entry(w.clone()).or_default(); - for y in x_set.iter() { - if !can_add(w, y) {continue} - if entry.insert(y.clone()) { - target_rev.entry(y.clone()).or_default().push(w.clone()); - changed = true; - } - } - if entry.is_empty() { - target.remove(w); - } - } - } - } - } - changed - } - loop { - - let mut cached_delta_delta_map_entry_for_can_add = None; - let mut cached_delta_delta_map_x_for_can_add = None; - let mut cached_delta_total_map_entry_for_can_add = None; - let mut cached_delta_total_map_x_for_can_add = None; - let mut cached_total_map_entry_for_can_add = None; - let mut cached_total_map_x_for_can_add = None; - let mut can_add = |x: &T, y: &T| { - if anti_reflexive && x == y { return false } - { - if cached_delta_delta_map_x_for_can_add.as_ref() != Some(x) { - cached_delta_delta_map_entry_for_can_add = delta_delta_map.get(x); - cached_delta_delta_map_x_for_can_add = Some(x.clone()); - }; - } - !cached_delta_delta_map_entry_for_can_add.map_or(false, |s| s.contains(y)) && - { - if cached_delta_total_map_x_for_can_add.as_ref() != Some(x) { - cached_delta_total_map_entry_for_can_add = delta_total_map.get(x); - cached_delta_total_map_x_for_can_add = Some(x.clone()); - }; - !cached_delta_total_map_entry_for_can_add.map_or(false, |s| s.contains(y)) - } && - { - if cached_total_map_x_for_can_add.as_ref() != Some(x) { - cached_total_map_entry_for_can_add = total_rel.map.get(x); - cached_total_map_x_for_can_add = Some(x.clone()); - } - !cached_total_map_entry_for_can_add.map_or(false, |s| s.contains(y)) - } - }; - - let join1 = join(&mut delta_new_map, &mut delta_new_rev_map, &delta_delta_map, &total_rel.reverse_map, &mut can_add, "join1"); - let join2 = join(&mut delta_new_map, &mut delta_new_rev_map, &total_rel.map, &delta_delta_rev_map, &mut can_add, "join2"); - let join3 = join(&mut delta_new_map, &mut delta_new_rev_map, &new_map, &delta_delta_rev_map, &mut can_add, "join3"); - - let changed = join1 | join2 | join3; - - move_hash_map_of_hash_set_contents_disjoint(&mut delta_delta_map, &mut delta_total_map); - move_hash_map_of_vec_contents(&mut delta_delta_rev_map, &mut delta_total_rev_map); - - assert!(delta_delta_map.is_empty()); - assert!(delta_delta_rev_map.is_empty()); - - std::mem::swap(&mut delta_delta_map, &mut delta_new_map); - std::mem::swap(&mut delta_delta_rev_map, &mut delta_new_rev_map); - - if !changed { break } - } - new_delta.map = delta_total_map; - new_delta.reverse_map = delta_total_rev_map; - - *total = TrRelIndCommon::Old{ rel: total_rel, anti_reflexive }; - *delta = TrRelIndCommon::Old{ rel: new_delta, anti_reflexive }; - *new = TrRelIndCommon::New{ rel: Default::default(), anti_reflexive }; - - unsafe { - MERGE_TIME += before.elapsed(); - MERGE_COUNT += 1; - } - } -} - -pub struct TrRelInd0<'a, T: Clone + Hash + Eq>(pub(crate) &'a TrRelIndCommon); - -impl<'a, T: Clone + Hash + Eq + 'a> RelIndexReadAll<'a> for TrRelInd0<'a, T> { - type Key = (&'a T, ); - type Value = (&'a T, ); - - type ValueIteratorType = Map, fn(&T) -> (&T,)>; - - type AllIteratorType = Map>>, for<'aa> fn((&'aa T, &'aa MyHashSet>)) -> ((&'aa T,), Map, for <'bb> fn(&'bb T) -> (&'bb T,)>)>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res : Self::AllIteratorType = self.0.unwrap_old().map.iter().map(|(k, v)| ((k, ), v.iter().map(|x| (x, )))); - res - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for TrRelInd0<'a, T> { - type Key = (T, ); - type Value = (&'a T, ); - - type IteratorType = Map, fn(&T) -> (&T,)>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let set = self.0.unwrap_old().map.get(&key.0)?; - let res: Self::IteratorType = set.iter().map(|x| (x, )); - Some(res) - } - - fn len(&self) -> usize { - self.0.unwrap_old().map.len() - } -} - -pub struct TrRelInd1<'a, T: Clone + Hash + Eq>(pub(crate) &'a TrRelIndCommon); - -impl<'a, T: Clone + Hash + Eq + 'a> RelIndexReadAll<'a> for TrRelInd1<'a, T> { - type Key = (&'a T, ); - type Value = (&'a T, ); - - type ValueIteratorType = Map, fn(&T) -> (&T,)>; - - type AllIteratorType = Map>, for<'aa> fn((&'aa T, &'aa Vec)) -> ((&'aa T,), Map, for <'bb> fn(&'bb T) -> (&'bb T,)>)>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res : Self::AllIteratorType = self.0.rel().reverse_map.iter().map(|(k, v)| ((k, ), v.iter().map(|x| (x, )))); - res - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for TrRelInd1<'a, T> { - type Key = (T, ); - type Value = (&'a T, ); - - type IteratorType = Map, fn(&T) -> (&T,)>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let set = self.0.rel().reverse_map.get(&key.0)?; - let res: Self::IteratorType = set.iter().map(|x| (x, )); - Some(res) - } - - fn len(&self) -> usize { - self.0.rel().reverse_map.len() - } -} - -pub struct TrRelIndNone<'a, T: Clone + Hash + Eq>(&'a TrRelIndCommon); - -impl<'a, T: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRelIndNone<'a, T> { - type Key = (); - type Value = (&'a T, &'a T); - - type ValueIteratorType = >::IteratorType; - - type AllIteratorType = std::iter::Once<((), Self::ValueIteratorType)>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - std::iter::once(((), self.index_get(&()).unwrap())) - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for TrRelIndNone<'a, T> { - type Key = (); - type Value = (&'a T, &'a T); - - type IteratorType = IteratorFromDyn<'a, Self::Value>; - - fn index_get(&'a self, (): &Self::Key) -> Option { - println!("iterating TrRelIndNone. {} tuples", self.0.rel().map.values().map(|x| x.len()).sum::()); - let res = || self.0.rel().map.iter().flat_map(|(x, x_set)| x_set.iter().map(move |y| (x, y))); - Some(IteratorFromDyn::new(res)) - } - - fn len(&self) -> usize { - 1 - } -} - -pub struct TrRelIndFullWrite<'a, T: Clone + Hash + Eq>(&'a mut TrRelIndCommon); - -impl<'a, T: Clone + Hash + Eq> RelIndexMerge for TrRelIndFullWrite<'a, T> { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { } //noop -} - -impl<'a, T: Clone + Hash + Eq> RelFullIndexWrite for TrRelIndFullWrite<'a, T> { - type Key = (T, T); - type Value = (); - - fn insert_if_not_present(&mut self, (x, y): &Self::Key, (): Self::Value) -> bool { - self.0.insert_by_ref(x, y) - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexWrite for TrRelIndFullWrite<'a, T> { - type Key = (T, T); - type Value = (); - - fn index_insert(&mut self, key: Self::Key, (): Self::Value) { - self.0.insert(key.0, key.1); - } -} - -pub struct TrRelIndFull<'a, T: Clone + Hash + Eq>(pub(crate) &'a TrRelIndCommon); - -impl<'a, T: Clone + Hash + Eq> RelFullIndexRead<'a> for TrRelIndFull<'a, T> { - type Key = (T, T); - - fn contains_key(&'a self, key: &Self::Key) -> bool { - self.0.rel().map.get(&key.0).map_or(false, |s| s.contains(&key.1)) - } -} - - -impl<'a, T: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRelIndFull<'a, T> { - type Key = (&'a T, &'a T); - type Value = (); - - type ValueIteratorType = std::iter::Once; - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let res = self.0.rel().map.iter().flat_map(|(x, x_set)| x_set.iter().map(move |y| (x, y))) - .map(|key| (key, std::iter::once(()))); - Box::new(res) - } -} - -impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for TrRelIndFull<'a, T> { - type Key = (T, T); - type Value = (); - - type IteratorType = std::iter::Once<()>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - if self.0.rel().map.get(&key.0)?.contains(&key.1) { - Some(std::iter::once(())) - } else { - None - } - } - - fn len(&self) -> usize { - let sample_size = 3; - let sum: usize = self.0.rel().map.values().take(sample_size).map(|x| x.len()).sum(); - let map_len = self.0.rel().map.len(); - sum * map_len / sample_size.min(map_len).max(1) - } -} - -macro_rules! to_rel_ind { - ($name: ident, $key: ty, $val: ty) => {paste::paste!{ - pub struct [](PhantomData); - - impl Default for [] { - fn default() -> Self { Self(PhantomData) } - } - - impl ToRelIndex for [] - where Rel: ToTrRelIndCommon - { - type RelIndex<'a> = $name<'a, T> where Self: 'a, Rel: 'a; - fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { $name(rel.to_tr_rel_ind()) } - - type RelIndexWrite<'a> = NoopRelIndexWrite<$key, $val> where Self: 'a, Rel: 'a; - fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { - NoopRelIndexWrite::default() - } - } - }}; -} - -to_rel_ind!(TrRelIndNone, (), (T, T)); -to_rel_ind!(TrRelInd0, (T, ), (T, )); -to_rel_ind!(TrRelInd1, (T, ), (T, )); -// to_rel_ind!(TrRelIndFull, (T, T), ()); - -pub struct ToTrRelIndFull(PhantomData); - -impl Default for ToTrRelIndFull { - fn default() -> Self { Self(PhantomData) } -} -impl ToRelIndex for ToTrRelIndFull -where - Rel: ToTrRelIndCommon, -{ - type RelIndex<'a> = TrRelIndFull<'a,T>where Self:'a, Rel:'a; - fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { TrRelIndFull(rel.to_tr_rel_ind()) } - - type RelIndexWrite<'a> = TrRelIndFullWrite<'a, T> where Self:'a, Rel:'a; - fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { - TrRelIndFullWrite(rel.to_tr_rel_ind_mut()) - } -} +use std::hash::{BuildHasherDefault, Hash}; +use std::iter::Map; +use std::marker::PhantomData; +use std::time::{Duration, Instant}; + +use ascent::internal::{ + RelFullIndexRead, RelFullIndexWrite, RelIndexMerge, RelIndexRead, RelIndexReadAll, RelIndexWrite, ToRelIndex, +}; +use hashbrown::HashMap; +use rustc_hash::FxHasher; + +use crate::binary_rel::BinaryRel; +use crate::iterator_from_dyn::IteratorFromDyn; +use crate::rel_boilerplate::NoopRelIndexWrite; +use crate::trrel_binary::{MyHashSet, MyHashSetIter}; +use crate::utils::{move_hash_map_of_hash_set_contents_disjoint, move_hash_map_of_vec_contents}; + +// TODO do we still need two variants? +pub enum TrRelIndCommon { + New { rel: BinaryRel, anti_reflexive: bool }, + Old { rel: BinaryRel, anti_reflexive: bool }, +} + +impl TrRelIndCommon { + pub fn anti_reflexive(&self) -> bool { + match self { + TrRelIndCommon::New { anti_reflexive, .. } => *anti_reflexive, + TrRelIndCommon::Old { anti_reflexive, .. } => *anti_reflexive, + } + } + + #[inline] + pub fn make_new() -> Self { Self::New { rel: Default::default(), anti_reflexive: true } } + + #[inline] + pub fn rel(&self) -> &BinaryRel { + match self { + TrRelIndCommon::New { rel, .. } => rel, + TrRelIndCommon::Old { rel, .. } => rel, + } + } + + pub fn unwrap_new_mut(&mut self) -> &mut BinaryRel { + match self { + TrRelIndCommon::New { rel, .. } => rel, + TrRelIndCommon::Old { .. } => panic!("TrRelIndCommon: unwrap_new_mut called on Old"), + } + } + pub fn unwrap_new(&self) -> &BinaryRel { + match self { + TrRelIndCommon::New { rel, .. } => rel, + TrRelIndCommon::Old { .. } => panic!("TrRelIndCommon: unwrap_new called on Old"), + } + } + + pub fn unwrap_old(&self) -> &BinaryRel { + match self { + TrRelIndCommon::Old { rel, .. } => rel, + TrRelIndCommon::New { .. } => panic!("TrRelIndCommon: unwrap_old called on New"), + } + } + + #[inline] + pub fn insert(&mut self, x: T, y: T) -> bool { + let rel = self.unwrap_new_mut(); + rel.insert(x, y) + } + + #[inline] + pub fn insert_by_ref(&mut self, x: &T, y: &T) -> bool { + let rel = self.unwrap_new_mut(); + rel.insert_by_ref(x, y) + } + + pub fn is_empty(&self) -> bool { + match self { + TrRelIndCommon::New { rel, .. } => rel.map.is_empty(), + TrRelIndCommon::Old { rel, .. } => rel.map.is_empty(), + } + } +} + +pub static mut MERGE_TIME: Duration = Duration::ZERO; +pub static mut MERGE_COUNT: usize = 0; + +pub trait ToTrRelIndCommon { + fn to_tr_rel_ind(&self) -> &TrRelIndCommon; + fn to_tr_rel_ind_mut(&mut self) -> &mut TrRelIndCommon; +} + +impl ToTrRelIndCommon for TrRelIndCommon { + fn to_tr_rel_ind(&self) -> &TrRelIndCommon { self } + fn to_tr_rel_ind_mut(&mut self) -> &mut TrRelIndCommon { self } +} + +impl Default for TrRelIndCommon { + #[inline] + fn default() -> Self { Self::Old { rel: Default::default(), anti_reflexive: true } } +} + +impl RelIndexMerge for TrRelIndCommon { + fn move_index_contents(_from: &mut Self, _to: &mut Self) { + panic!("merge_delta_to_total_new_to_delta must be called instead.") + } + + fn init(new: &mut Self, delta: &mut Self, total: &mut Self) { + assert!(matches!(delta, Self::Old { .. })); + assert!(matches!(total, Self::Old { .. })); + *new = Self::New { rel: Default::default(), anti_reflexive: delta.anti_reflexive() }; + } + + fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { + let before = Instant::now(); + let anti_reflexive = total.anti_reflexive(); + + let mut total_rel = match total { + TrRelIndCommon::New { .. } => Default::default(), + TrRelIndCommon::Old { rel, .. } => std::mem::take(rel), + }; + let mut delta_rel = match delta { + TrRelIndCommon::New { .. } => Default::default(), + TrRelIndCommon::Old { rel, .. } => std::mem::take(rel), + }; + + move_hash_map_of_hash_set_contents_disjoint(&mut delta_rel.map, &mut total_rel.map); + move_hash_map_of_vec_contents(&mut delta_rel.reverse_map, &mut total_rel.reverse_map); + + let mut new_delta = BinaryRel::::default(); + + // let new_delta_prog = ascent::ascent_run! { + // struct AscentProg; + // relation delta(T, T); + // relation new(T, T); + // relation total_map(T); + // relation total_reverse_map(T); + // total_map(x) <-- for x in total_rel.map.keys(); + // total_reverse_map(x) <-- for x in total_rel.reverse_map.keys(); + + // new(x, y), delta(x, y) <-- for (x, y) in new.unwrap_new().iter(); + // delta(x, z) <-- delta(x, y), new(y, z); + // delta(x, z) <-- delta(x, y), total_map(y), for z in total_rel.map.get(y).unwrap(); + // delta(x, z) <-- delta(y, z), total_reverse_map(y), for x in total_rel.reverse_map.get(y).unwrap(); + // }; + + // new_delta.reverse_map = new_delta_prog.delta_indices_1.0.into_iter() + // .map(|(k, v)| (k.0, v.into_iter().map(|x| x.0).collect())).collect(); + + // new_delta.map = new_delta_prog.delta_indices_0.0.into_iter() + // .map(|(k, v)| (k.0, v.into_iter().map(|x| x.0).collect())).collect(); + + type RelMap = HashMap>, BuildHasherDefault>; + type RelRevMap = HashMap, BuildHasherDefault>; + + let new_rel = std::mem::take(new.unwrap_new_mut()); + let new_map = new_rel.map; + let mut delta_delta_map = new_map.clone(); + let mut delta_delta_rev_map = new_rel.reverse_map; + + let mut delta_total_map = RelMap::::default(); + let mut delta_total_rev_map = RelRevMap::::default(); + + let mut delta_new_map = RelMap::::default(); + let mut delta_new_rev_map = RelRevMap::::default(); + + fn join( + target: &mut RelMap, target_rev: &mut RelRevMap, rel1: &RelMap, rel2_rev: &RelRevMap, + mut can_add: impl FnMut(&T, &T) -> bool, name: &str, + ) -> bool { + let mut changed = false; + if rel1.len() < rel2_rev.len() { + for (x, x_set) in rel1.iter() { + assert!(!x_set.is_empty(), "bad join {name}. rel1 has non-empty sets"); + if let Some(x_rev_set) = rel2_rev.get(x) { + assert!(!x_set.is_empty(), "bad join {name}. rel2_rev has non-empty sets"); + for w in x_rev_set { + let entry = target.entry(w.clone()).or_default(); + for y in x_set.iter() { + if !can_add(w, y) { + continue + } + if entry.insert(y.clone()) { + target_rev.entry(y.clone()).or_default().push(w.clone()); + changed = true; + } + } + if entry.is_empty() { + target.remove(w); + } + } + } + } + } else { + for (x, x_rev_set) in rel2_rev.iter() { + if let Some(x_set) = rel1.get(x) { + for w in x_rev_set { + let entry = target.entry(w.clone()).or_default(); + for y in x_set.iter() { + if !can_add(w, y) { + continue + } + if entry.insert(y.clone()) { + target_rev.entry(y.clone()).or_default().push(w.clone()); + changed = true; + } + } + if entry.is_empty() { + target.remove(w); + } + } + } + } + } + changed + } + loop { + let mut cached_delta_delta_map_entry_for_can_add = None; + let mut cached_delta_delta_map_x_for_can_add = None; + let mut cached_delta_total_map_entry_for_can_add = None; + let mut cached_delta_total_map_x_for_can_add = None; + let mut cached_total_map_entry_for_can_add = None; + let mut cached_total_map_x_for_can_add = None; + let mut can_add = |x: &T, y: &T| { + if anti_reflexive && x == y { + return false + } + { + if cached_delta_delta_map_x_for_can_add.as_ref() != Some(x) { + cached_delta_delta_map_entry_for_can_add = delta_delta_map.get(x); + cached_delta_delta_map_x_for_can_add = Some(x.clone()); + }; + } + !cached_delta_delta_map_entry_for_can_add.map_or(false, |s| s.contains(y)) + && { + if cached_delta_total_map_x_for_can_add.as_ref() != Some(x) { + cached_delta_total_map_entry_for_can_add = delta_total_map.get(x); + cached_delta_total_map_x_for_can_add = Some(x.clone()); + }; + !cached_delta_total_map_entry_for_can_add.map_or(false, |s| s.contains(y)) + } + && { + if cached_total_map_x_for_can_add.as_ref() != Some(x) { + cached_total_map_entry_for_can_add = total_rel.map.get(x); + cached_total_map_x_for_can_add = Some(x.clone()); + } + !cached_total_map_entry_for_can_add.map_or(false, |s| s.contains(y)) + } + }; + + let join1 = join( + &mut delta_new_map, &mut delta_new_rev_map, &delta_delta_map, &total_rel.reverse_map, &mut can_add, "join1", + ); + let join2 = join( + &mut delta_new_map, &mut delta_new_rev_map, &total_rel.map, &delta_delta_rev_map, &mut can_add, "join2", + ); + let join3 = + join(&mut delta_new_map, &mut delta_new_rev_map, &new_map, &delta_delta_rev_map, &mut can_add, "join3"); + + let changed = join1 | join2 | join3; + + move_hash_map_of_hash_set_contents_disjoint(&mut delta_delta_map, &mut delta_total_map); + move_hash_map_of_vec_contents(&mut delta_delta_rev_map, &mut delta_total_rev_map); + + assert!(delta_delta_map.is_empty()); + assert!(delta_delta_rev_map.is_empty()); + + std::mem::swap(&mut delta_delta_map, &mut delta_new_map); + std::mem::swap(&mut delta_delta_rev_map, &mut delta_new_rev_map); + + if !changed { + break + } + } + new_delta.map = delta_total_map; + new_delta.reverse_map = delta_total_rev_map; + + *total = TrRelIndCommon::Old { rel: total_rel, anti_reflexive }; + *delta = TrRelIndCommon::Old { rel: new_delta, anti_reflexive }; + *new = TrRelIndCommon::New { rel: Default::default(), anti_reflexive }; + + unsafe { + MERGE_TIME += before.elapsed(); + MERGE_COUNT += 1; + } + } +} + +pub struct TrRelInd0<'a, T: Clone + Hash + Eq>(pub(crate) &'a TrRelIndCommon); + +impl<'a, T: Clone + Hash + Eq + 'a> RelIndexReadAll<'a> for TrRelInd0<'a, T> { + type Key = (&'a T,); + type Value = (&'a T,); + + type ValueIteratorType = Map, fn(&T) -> (&T,)>; + + type AllIteratorType = Map< + hashbrown::hash_map::Iter<'a, T, MyHashSet>>, + for<'aa> fn( + (&'aa T, &'aa MyHashSet>), + ) -> ((&'aa T,), Map, for<'bb> fn(&'bb T) -> (&'bb T,)>), + >; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res: Self::AllIteratorType = self.0.unwrap_old().map.iter().map(|(k, v)| ((k,), v.iter().map(|x| (x,)))); + res + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for TrRelInd0<'a, T> { + type Key = (T,); + type Value = (&'a T,); + + type IteratorType = Map, fn(&T) -> (&T,)>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let set = self.0.unwrap_old().map.get(&key.0)?; + let res: Self::IteratorType = set.iter().map(|x| (x,)); + Some(res) + } + + fn len(&self) -> usize { self.0.unwrap_old().map.len() } +} + +pub struct TrRelInd1<'a, T: Clone + Hash + Eq>(pub(crate) &'a TrRelIndCommon); + +impl<'a, T: Clone + Hash + Eq + 'a> RelIndexReadAll<'a> for TrRelInd1<'a, T> { + type Key = (&'a T,); + type Value = (&'a T,); + + type ValueIteratorType = Map, fn(&T) -> (&T,)>; + + type AllIteratorType = Map< + hashbrown::hash_map::Iter<'a, T, Vec>, + for<'aa> fn( + (&'aa T, &'aa Vec), + ) -> ((&'aa T,), Map, for<'bb> fn(&'bb T) -> (&'bb T,)>), + >; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res: Self::AllIteratorType = self.0.rel().reverse_map.iter().map(|(k, v)| ((k,), v.iter().map(|x| (x,)))); + res + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for TrRelInd1<'a, T> { + type Key = (T,); + type Value = (&'a T,); + + type IteratorType = Map, fn(&T) -> (&T,)>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let set = self.0.rel().reverse_map.get(&key.0)?; + let res: Self::IteratorType = set.iter().map(|x| (x,)); + Some(res) + } + + fn len(&self) -> usize { self.0.rel().reverse_map.len() } +} + +pub struct TrRelIndNone<'a, T: Clone + Hash + Eq>(&'a TrRelIndCommon); + +impl<'a, T: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRelIndNone<'a, T> { + type Key = (); + type Value = (&'a T, &'a T); + + type ValueIteratorType = >::IteratorType; + + type AllIteratorType = std::iter::Once<((), Self::ValueIteratorType)>; + + fn iter_all(&'a self) -> Self::AllIteratorType { std::iter::once(((), self.index_get(&()).unwrap())) } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for TrRelIndNone<'a, T> { + type Key = (); + type Value = (&'a T, &'a T); + + type IteratorType = IteratorFromDyn<'a, Self::Value>; + + fn index_get(&'a self, (): &Self::Key) -> Option { + println!("iterating TrRelIndNone. {} tuples", self.0.rel().map.values().map(|x| x.len()).sum::()); + let res = || self.0.rel().map.iter().flat_map(|(x, x_set)| x_set.iter().map(move |y| (x, y))); + Some(IteratorFromDyn::new(res)) + } + + fn len(&self) -> usize { 1 } +} + +pub struct TrRelIndFullWrite<'a, T: Clone + Hash + Eq>(&'a mut TrRelIndCommon); + +impl<'a, T: Clone + Hash + Eq> RelIndexMerge for TrRelIndFullWrite<'a, T> { + fn move_index_contents(_from: &mut Self, _to: &mut Self) {} //noop +} + +impl<'a, T: Clone + Hash + Eq> RelFullIndexWrite for TrRelIndFullWrite<'a, T> { + type Key = (T, T); + type Value = (); + + fn insert_if_not_present(&mut self, (x, y): &Self::Key, (): Self::Value) -> bool { self.0.insert_by_ref(x, y) } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexWrite for TrRelIndFullWrite<'a, T> { + type Key = (T, T); + type Value = (); + + fn index_insert(&mut self, key: Self::Key, (): Self::Value) { self.0.insert(key.0, key.1); } +} + +pub struct TrRelIndFull<'a, T: Clone + Hash + Eq>(pub(crate) &'a TrRelIndCommon); + +impl<'a, T: Clone + Hash + Eq> RelFullIndexRead<'a> for TrRelIndFull<'a, T> { + type Key = (T, T); + + fn contains_key(&'a self, key: &Self::Key) -> bool { + self.0.rel().map.get(&key.0).map_or(false, |s| s.contains(&key.1)) + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRelIndFull<'a, T> { + type Key = (&'a T, &'a T); + type Value = (); + + type ValueIteratorType = std::iter::Once; + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let res = self + .0 + .rel() + .map + .iter() + .flat_map(|(x, x_set)| x_set.iter().map(move |y| (x, y))) + .map(|key| (key, std::iter::once(()))); + Box::new(res) + } +} + +impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for TrRelIndFull<'a, T> { + type Key = (T, T); + type Value = (); + + type IteratorType = std::iter::Once<()>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + if self.0.rel().map.get(&key.0)?.contains(&key.1) { Some(std::iter::once(())) } else { None } + } + + fn len(&self) -> usize { + let sample_size = 3; + let sum: usize = self.0.rel().map.values().take(sample_size).map(|x| x.len()).sum(); + let map_len = self.0.rel().map.len(); + sum * map_len / sample_size.min(map_len).max(1) + } +} + +macro_rules! to_rel_ind { + ($name: ident, $key: ty, $val: ty) => {paste::paste!{ + pub struct [](PhantomData); + + impl Default for [] { + fn default() -> Self { Self(PhantomData) } + } + + impl ToRelIndex for [] + where Rel: ToTrRelIndCommon + { + type RelIndex<'a> = $name<'a, T> where Self: 'a, Rel: 'a; + fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { $name(rel.to_tr_rel_ind()) } + + type RelIndexWrite<'a> = NoopRelIndexWrite<$key, $val> where Self: 'a, Rel: 'a; + fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { + NoopRelIndexWrite::default() + } + } + }}; +} + +to_rel_ind!(TrRelIndNone, (), (T, T)); +to_rel_ind!(TrRelInd0, (T,), (T,)); +to_rel_ind!(TrRelInd1, (T,), (T,)); +// to_rel_ind!(TrRelIndFull, (T, T), ()); + +pub struct ToTrRelIndFull(PhantomData); + +impl Default for ToTrRelIndFull { + fn default() -> Self { Self(PhantomData) } +} +impl ToRelIndex for ToTrRelIndFull +where Rel: ToTrRelIndCommon +{ + type RelIndex<'a> + = TrRelIndFull<'a, T> + where + Self: 'a, + Rel: 'a; + fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { TrRelIndFull(rel.to_tr_rel_ind()) } + + type RelIndexWrite<'a> + = TrRelIndFullWrite<'a, T> + where + Self: 'a, + Rel: 'a; + fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { + TrRelIndFullWrite(rel.to_tr_rel_ind_mut()) + } +} diff --git a/byods/ascent-byods-rels/src/trrel_ternary_ind.rs b/byods/ascent-byods-rels/src/trrel_ternary_ind.rs index b496cc5..94b2f96 100644 --- a/byods/ascent-byods-rels/src/trrel_ternary_ind.rs +++ b/byods/ascent-byods-rels/src/trrel_ternary_ind.rs @@ -1,549 +1,549 @@ -use std::hash::{Hash, BuildHasherDefault}; -use std::iter::{Map, once}; - -use ascent::internal::{RelIndexRead, RelIndexReadAll, RelFullIndexRead, RelIndexWrite, RelFullIndexWrite, RelIndexMerge}; -use derive_more::{DerefMut, Deref}; -use hashbrown::HashMap; -use rustc_hash::FxHasher; - -use crate::iterator_from_dyn::IteratorFromDyn; -use crate::trrel_binary_ind::{TrRelIndCommon, TrRelInd0, TrRelInd1}; -use crate::trrel_binary::MyHashSetIter; -use crate::utils::{hash_one, AltHashSet, AltHashSetIter}; - - -#[derive(DerefMut, Deref)] -pub struct TrRel2IndCommonWrapper( - TrRel2IndCommon -); - -impl -RelIndexMerge for TrRel2IndCommonWrapper { - fn move_index_contents(from: &mut Self, to: &mut Self) { - TrRel2IndCommon::move_index_contents(&mut from.0, &mut to.0) - } - - fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { - TrRel2IndCommon::merge_delta_to_total_new_to_delta(&mut new.0, &mut delta.0, &mut total.0) - } -} - -impl -Default for TrRel2IndCommonWrapper { - fn default() -> Self { - let reverse_map1 = if HAS_REVERSE_MAP1 {Some(Default::default())} else {None}; - let reverse_map2 = if HAS_REVERSE_MAP2 {Some(Default::default())} else {None}; - Self(TrRel2IndCommon { map: Default::default(), reverse_map1, reverse_map2 }) - } -} - -type RevMapHashSet = AltHashSet; -#[allow(dead_code)] -type RevMapHashSetIter<'a, T> = AltHashSetIter<'a, T>; - -pub struct TrRel2IndCommon { - pub map: HashMap, BuildHasherDefault>, - pub reverse_map1: Option>, BuildHasherDefault>>, - pub reverse_map2: Option>, BuildHasherDefault>> -} - -impl RelIndexMerge for TrRel2IndCommon { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { - panic!("merge_delta_to_total_new_to_delta must be called instead"); - } - - fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { - let mut new_delta_map = HashMap::default(); - for (k, mut delta_trrel) in delta.map.drain() { - let mut new_trrel = new.map.remove(&k).unwrap_or_else(|| TrRelIndCommon::make_new()); - match total.map.entry(k.clone()) { - hashbrown::hash_map::Entry::Occupied(mut total_entry) => { - TrRelIndCommon::merge_delta_to_total_new_to_delta(&mut new_trrel, &mut delta_trrel, total_entry.get_mut()); - if !delta_trrel.is_empty() { - new_delta_map.insert(k, delta_trrel); - } - }, - hashbrown::hash_map::Entry::Vacant(total_vacant_entry) => { - let mut new_total = TrRelIndCommon::Old{ rel: Default::default(), anti_reflexive: delta_trrel.anti_reflexive()}; - TrRelIndCommon::merge_delta_to_total_new_to_delta(&mut new_trrel, &mut delta_trrel, &mut new_total); - total_vacant_entry.insert(new_total); - if !delta_trrel.is_empty() { - new_delta_map.insert(k, delta_trrel); - } - }, - } - } - for (k, mut new_trrel) in new.map.drain() { - let mut new_delta = Default::default(); - match total.map.entry(k.clone()) { - hashbrown::hash_map::Entry::Occupied(mut total_entry) => { - TrRelIndCommon::merge_delta_to_total_new_to_delta(&mut new_trrel, &mut new_delta, total_entry.get_mut()); - new_delta_map.insert(k, new_delta); - }, - hashbrown::hash_map::Entry::Vacant(_) => { - TrRelIndCommon::merge_delta_to_total_new_to_delta(&mut new_trrel, &mut new_delta, &mut Default::default()); - new_delta_map.insert(k, new_delta); - }, - } - } - delta.map = new_delta_map; - - if delta.reverse_map1.is_some() { - crate::utils::move_hash_map_of_alt_hash_set_contents(delta.reverse_map1.as_mut().unwrap(), total.reverse_map1.as_mut().unwrap()); - std::mem::swap(delta.reverse_map1.as_mut().unwrap(), new.reverse_map1.as_mut().unwrap()); - } - - if delta.reverse_map2.is_some() { - crate::utils::move_hash_map_of_alt_hash_set_contents(delta.reverse_map2.as_mut().unwrap(), total.reverse_map2.as_mut().unwrap()); - std::mem::swap(delta.reverse_map2.as_mut().unwrap(), new.reverse_map2.as_mut().unwrap()); - } - } -} - - - -pub struct TrRel2Ind0<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>( - &'a TrRel2IndCommon -); - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRel2Ind0<'a, T0, T1> { - type Key = (&'a T0, ); - type Value = (&'a T1, &'a T1); - - type ValueIteratorType = Box + 'a>; - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.0.map.iter().map(|(k, v)| { - ((k, ), Box::new(v.rel().iter_all()) as _) - })) - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for TrRel2Ind0<'a, T0, T1> { - type Key = (T0, ); - type Value = (&'a T1, &'a T1); - - type IteratorType = IteratorFromDyn<'a, Self::Value>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let trrel = self.0.map.get(&key.0)?; - Some(IteratorFromDyn::new(|| trrel.rel().iter_all())) - } - - fn len(&self) -> usize { - let sample_size = 4; - let sum = self.0.map.values().map(|x| x.rel().count_estimate()).sum::(); - sum * self.0.map.len() / sample_size.min(self.0.map.len()).max(1) - } -} - -pub struct TrRel2Ind0_1<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>( - &'a TrRel2IndCommon -); - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRel2Ind0_1<'a, T0, T1> { - type Key = (&'a T0, &'a T1); - - type Value = (&'a T1, ); - - type ValueIteratorType = Map, fn(&T1) -> (&T1,)>; - - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.0.map.iter().flat_map(|(x0, v)| v.rel().map.iter().map(move |(x1, x2_set)| { - let iter: Self::ValueIteratorType = x2_set.iter().map(|x2| (x2, )); - ((x0, x1), iter) - }))) - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for TrRel2Ind0_1<'a, T0, T1> { - type Key = (T0, T1); - - type Value = (&'a T1, ); - - type IteratorType = Map, fn(&T1) -> (&T1,)>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let trrel = self.0.map.get(&key.0)?; - let res: Self::IteratorType = trrel.rel().map.get(&key.1)?.iter().map(|x| (x, )); - Some(res) - } - - fn len(&self) -> usize { - let sample_size = 3; - let sum = self.0.map.values().take(sample_size).map(|trrel| TrRelInd0(&trrel).len()).sum::(); - let map_len = self.0.map.len(); - sum * map_len / sample_size.min(map_len).max(1) - - } -} - -pub struct TrRel2Ind0_2<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>( - &'a TrRel2IndCommon -); - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRel2Ind0_2<'a, T0, T1> { - type Key = (&'a T0, &'a T1); - type Value = (&'a T1, ); - - type ValueIteratorType = Map, fn(&T1) -> (&T1,)>; - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.0.map.iter().flat_map(|(x0, v)| v.rel().reverse_map.iter().map(move |(x1, x2_set)| { - let iter: Self::ValueIteratorType = x2_set.iter().map(|x2| (x2, )); - ((x0, x1), iter) - }))) - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for TrRel2Ind0_2<'a, T0, T1> { - type Key = (T0, T1); - type Value = (&'a T1, ); - - type IteratorType = Map, fn(&T1) -> (&T1,)>; - - fn index_get(&'a self, key: &Self::Key) -> Option { - let trrel = self.0.map.get(&key.0)?; - let res: Self::IteratorType = trrel.rel().reverse_map.get(&key.1)?.iter().map(|x| (x, )); - Some(res) - } - - fn len(&self) -> usize { - let sample_size = 3; - let sum = self.0.map.values().take(sample_size).map(|trrel| TrRelInd1(&trrel).len()).sum::(); - let map_len = self.0.map.len(); - sum * map_len / sample_size.min(map_len).max(1) - } -} - -pub struct TrRel2Ind1<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>( - &'a TrRel2IndCommon -); - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRel2Ind1<'a, T0, T1> { - type Key = (&'a T1, ); - type Value = (&'a T0, &'a T1); - - type ValueIteratorType = >::IteratorType; - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.0.reverse_map1.as_ref().unwrap().keys().map(|x1| { - ((x1, ), self.get(x1).unwrap()) - })) - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> TrRel2Ind1<'a, T0, T1> { - fn get(&'a self, x1: &T1) -> Option<>::IteratorType> { - let (x1, x0s) = self.0.reverse_map1.as_ref().unwrap().get_key_value(x1)?; - let res = move || x0s.iter().filter_map(move |x0| { - Some(self.0.map.get(x0).unwrap().rel().map.get(x1)?.iter().map(move |x2| (x0, x2))) - }).flatten(); - Some(IteratorFromDyn::new(res)) - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for TrRel2Ind1<'a, T0, T1> { - type Key = (T1, ); - - type Value = (&'a T0, &'a T1); - - type IteratorType = IteratorFromDyn<'a, Self::Value>; - - fn index_get(&'a self, (x1, ): &Self::Key) -> Option { - self.get(x1) - } - - fn len(&self) -> usize { - self.0.reverse_map1.as_ref().unwrap().len() - } -} - -pub struct TrRel2Ind2<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>( - &'a TrRel2IndCommon -); - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRel2Ind2<'a, T0, T1> { - type Key = (&'a T1, ); - type Value = (&'a T0, &'a T1); - - type ValueIteratorType = >::IteratorType; - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.0.reverse_map2.as_ref().unwrap().keys().map(|x2| { - ((x2, ), self.get(x2).unwrap()) - })) - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> TrRel2Ind2<'a, T0, T1> { - fn get(&'a self, x2: &T1) -> Option<>::IteratorType> { - let (x2, x0s) = self.0.reverse_map2.as_ref().unwrap().get_key_value(x2)?; - let res = move || x0s.iter().flat_map(move |x0| { - self.0.map.get(x0).unwrap().rel().reverse_map.get(x2).unwrap().iter().map(move |x1| (x0, x1)) - }); - Some(IteratorFromDyn::new(res)) - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for TrRel2Ind2<'a, T0, T1> { - type Key = (T1, ); - type Value = (&'a T0, &'a T1); - - type IteratorType = IteratorFromDyn<'a, Self::Value>; - - fn index_get(&'a self, (x2, ): &Self::Key) -> Option { - self.get(x2) - } - - fn len(&self) -> usize { - self.0.reverse_map2.as_ref().unwrap().len() - } -} - -pub struct TrRel2Ind1_2<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>( - &'a TrRel2IndCommon -); - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRel2Ind1_2<'a, T0, T1> { - type Key = (&'a T1, &'a T1); - type Value = (&'a T0, ); - - type ValueIteratorType = Box + 'a>; - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - Box::new(self.0.reverse_map1.as_ref().unwrap().iter().flat_map(move |(x1, x0s_for_x1)| { - self.0.reverse_map2.as_ref().unwrap().iter().map(move |(x2, x0s_for_x2)| { - let x0s: Self::ValueIteratorType = Box::new(x0s_for_x1.intersection(x0s_for_x2) - .filter(|&x0| self.0.map.get(x0).unwrap().rel().contains(x1, x2)) - .map(|x0| (x0, ))); - ((x1, x2), x0s) - }) - })) - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for TrRel2Ind1_2<'a, T0, T1> { - type Key = (T1, T1); - - type Value = (&'a T0, ); - - type IteratorType = IteratorFromDyn<'a, Self::Value>; - fn index_get(&'a self, (x1, x2): &Self::Key) -> Option { - let (x1, x1_map) = self.0.reverse_map1.as_ref().unwrap().get_key_value(x1)?; - let (x2, x2_map) = self.0.reverse_map2.as_ref().unwrap().get_key_value(x2)?; - - let res = || x1_map.intersection(x2_map) - .filter(|&x0| self.0.map.get(x0).unwrap().rel().contains(x1, x2)) - .map(|x0| (x0, )); - Some(IteratorFromDyn::new(res)) - } - - fn len(&self) -> usize { - // TODO random estimate, could be very wrong - self.0.reverse_map1.as_ref().unwrap().len() * self.0.reverse_map2.as_ref().unwrap().len() - / ((self.0.map.len() as f32).sqrt() as usize) - } -} - -pub struct TrRel2IndNone<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>( - &'a TrRel2IndCommon -); - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRel2IndNone<'a, T0, T1> { - type Key = (); - type Value = (&'a T0, &'a T1, &'a T1); - - type ValueIteratorType = >::IteratorType; - type AllIteratorType = std::option::IntoIter<(Self::Key, Self::ValueIteratorType)>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - self.index_get(&()).map(|x| ((), x)).into_iter() - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for TrRel2IndNone<'a, T0, T1> { - type Key = (); - type Value = (&'a T0, &'a T1, &'a T1); - - type IteratorType = IteratorFromDyn<'a, Self::Value>; - - fn index_get(&'a self, (): &Self::Key) -> Option { - let res = || self.0.map.iter().flat_map(|(x0, trrel)| { - trrel.rel().iter_all().map(move |(x1, x2)| (x0, x1, x2)) - }); - Some(IteratorFromDyn::new(res)) - } - - fn len(&self) -> usize { - 1 - } -} - -#[repr(transparent)] -pub struct TrRel2IndFull<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>( - &'a TrRel2IndCommon -); - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelFullIndexRead<'a> for TrRel2IndFull<'a, T0, T1> { - type Key = (T0, T1, T1); - - #[inline] - fn contains_key(&'a self, (x0, x1, x2): &Self::Key) -> bool { - match self.0.map.get(x0) { - None => false, - Some(rel) => rel.rel().contains(x1, x2), - } - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRel2IndFull<'a, T0, T1> { - type Key = (&'a T0, &'a T1, &'a T1); - type Value = (); - - type ValueIteratorType = std::iter::Once<()>; - type AllIteratorType = Box + 'a>; - - fn iter_all(&'a self) -> Self::AllIteratorType { - let iter = self.0.map.iter().flat_map(|(x0, trrel)| { - trrel.rel().iter_all().map(move |(x1, x2)| (x0, x1, x2)) - }); - - Box::new(iter.map(|t| (t, std::iter::once(())))) - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for TrRel2IndFull<'a, T0, T1> { - type Key = (T0, T1, T1); - type Value = (); - - type IteratorType = std::iter::Once; - - fn index_get(&'a self, (x0, x1, x2): &Self::Key) -> Option { - if self.0.map.get(x0)?.rel().contains(x1, x2) { - Some(once(())) - } else { - None - } - } - - fn len(&self) -> usize { - let sample_size = 3; - let sum = self.0.map.values().take(sample_size).map(|trrel| trrel.rel().count_estimate()).sum::(); - let map_len = self.0.map.len(); - sum * map_len / sample_size.min(map_len).max(1) - } -} - -pub struct TrRel2IndFullWrite<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>( - &'a mut TrRel2IndCommon -); - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexMerge for TrRel2IndFullWrite<'a, T0, T1> { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { } // noop -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelFullIndexWrite for TrRel2IndFullWrite<'a, T0, T1> { - type Key = (T0, T1, T1); - - type Value = (); - - fn insert_if_not_present(&mut self, (x0, x1, x2): &Self::Key, (): Self::Value) -> bool { - let x0_hash = hash_one(self.0.map.hasher(), x0); - - if !self.0.map.raw_entry_mut().from_key_hashed_nocheck(x0_hash, x0) - .or_insert_with(|| (x0.clone(), TrRelIndCommon::make_new())) - .1.insert_by_ref(x1, x2) - { - return false; - } - if let Some(reverse_map1) = self.0.reverse_map1.as_mut() { - reverse_map1.entry(x1.clone()).or_default().insert_with_hash_no_check(x0_hash, x0.clone()); - // reverse_map1.entry(x1.clone()).or_default().raw_table().find(x0_hash, |a| &a.0 == x0).unwrap().copy_from_nonoverlapping(other).insert(x0.clone()); - } - if let Some(reverse_map2) = self.0.reverse_map2.as_mut() { - reverse_map2.entry(x2.clone()).or_default().insert_with_hash_no_check(x0_hash, x0.clone()); - } - true - } -} - -impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexWrite for TrRel2IndFullWrite<'a, T0, T1> { - type Key = (T0, T1, T1); - - type Value = (); - - fn index_insert(&mut self, (x0, x1, x2): Self::Key, (): Self::Value) { - if let Some(reverse_map1) = self.0.reverse_map1.as_mut() { - reverse_map1.entry(x1.clone()).or_default().insert(x0.clone()); - } - if let Some(reverse_map2) = self.0.reverse_map2.as_mut() { - reverse_map2.entry(x2.clone()).or_default().insert(x0.clone()); - } - self.0.map.entry(x0).or_insert_with(|| TrRelIndCommon::make_new()).insert(x1, x2); - } -} - -use std::marker::PhantomData; -use ascent::internal::ToRelIndex; -use crate::rel_boilerplate::NoopRelIndexWrite; - -macro_rules! to_trrel2 { - ($name: ident, $key: ty, $val: ty) => {paste::paste!{ - pub struct [](PhantomData<(T0, T1)>); - - impl Default for [] { - fn default() -> Self { Self(PhantomData) } - } - - impl ToRelIndex for [] - where Rel: std::ops::DerefMut> - { - type RelIndex<'a> = $name<'a, T0, T1> where Self: 'a, Rel: 'a; - #[inline(always)] - fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { $name(rel.deref()) } - - type RelIndexWrite<'a> = NoopRelIndexWrite<$key, $val> where Self: 'a, Rel: 'a; - #[inline(always)] - fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { - NoopRelIndexWrite::default() - } - } - }}; -} - -to_trrel2!(TrRel2IndNone, (), (T0, T1, T1)); -to_trrel2!(TrRel2Ind0, (T0, ), (T1, T1)); -to_trrel2!(TrRel2Ind1, (T1, ), (T0, T1)); -to_trrel2!(TrRel2Ind2, (T1, ), (T0, T1)); -to_trrel2!(TrRel2Ind0_1, (T0, T1), (T1, )); -to_trrel2!(TrRel2Ind0_2, (T0, T1), (T1, )); -to_trrel2!(TrRel2Ind1_2, (T1, T1), (T0, )); -// to_trrel2!(TrRel2IndFull, (T0, T1, T1), ()); - -pub struct ToTrRel2IndFull(PhantomData<(T0, T1)>); - -impl Default for ToTrRel2IndFull { - fn default() -> Self { Self(PhantomData) } -} - -impl ToRelIndex for ToTrRel2IndFull -where Rel: std::ops::DerefMut>, -{ - type RelIndex<'a> = TrRel2IndFull<'a,T0,T1>where Self:'a,Rel:'a; - #[inline(always)] - fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { TrRel2IndFull(rel.deref()) } - - type RelIndexWrite<'a> = TrRel2IndFullWrite<'a, T0, T1>where Self:'a,Rel:'a; - #[inline(always)] - fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { - TrRel2IndFullWrite(rel.deref_mut()) - } -} +use std::hash::{BuildHasherDefault, Hash}; +use std::iter::{Map, once}; + +use ascent::internal::{ + RelFullIndexRead, RelFullIndexWrite, RelIndexMerge, RelIndexRead, RelIndexReadAll, RelIndexWrite, +}; +use derive_more::{Deref, DerefMut}; +use hashbrown::HashMap; +use rustc_hash::FxHasher; + +use crate::iterator_from_dyn::IteratorFromDyn; +use crate::trrel_binary::MyHashSetIter; +use crate::trrel_binary_ind::{TrRelInd0, TrRelInd1, TrRelIndCommon}; +use crate::utils::{AltHashSet, AltHashSetIter, hash_one}; + +#[derive(DerefMut, Deref)] +pub struct TrRel2IndCommonWrapper< + const HAS_REVERSE_MAP1: bool, + const HAS_REVERSE_MAP2: bool, + T0: Clone + Hash + Eq, + T1: Clone + Hash + Eq, +>(TrRel2IndCommon); + +impl + RelIndexMerge for TrRel2IndCommonWrapper +{ + fn move_index_contents(from: &mut Self, to: &mut Self) { + TrRel2IndCommon::move_index_contents(&mut from.0, &mut to.0) + } + + fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { + TrRel2IndCommon::merge_delta_to_total_new_to_delta(&mut new.0, &mut delta.0, &mut total.0) + } +} + +impl Default + for TrRel2IndCommonWrapper +{ + fn default() -> Self { + let reverse_map1 = if HAS_REVERSE_MAP1 { Some(Default::default()) } else { None }; + let reverse_map2 = if HAS_REVERSE_MAP2 { Some(Default::default()) } else { None }; + Self(TrRel2IndCommon { map: Default::default(), reverse_map1, reverse_map2 }) + } +} + +type RevMapHashSet = AltHashSet; +#[allow(dead_code)] +type RevMapHashSetIter<'a, T> = AltHashSetIter<'a, T>; + +pub struct TrRel2IndCommon { + pub map: HashMap, BuildHasherDefault>, + pub reverse_map1: Option>, BuildHasherDefault>>, + pub reverse_map2: Option>, BuildHasherDefault>>, +} + +impl RelIndexMerge for TrRel2IndCommon { + fn move_index_contents(_from: &mut Self, _to: &mut Self) { + panic!("merge_delta_to_total_new_to_delta must be called instead"); + } + + fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { + let mut new_delta_map = HashMap::default(); + for (k, mut delta_trrel) in delta.map.drain() { + let mut new_trrel = new.map.remove(&k).unwrap_or_else(|| TrRelIndCommon::make_new()); + match total.map.entry(k.clone()) { + hashbrown::hash_map::Entry::Occupied(mut total_entry) => { + TrRelIndCommon::merge_delta_to_total_new_to_delta( + &mut new_trrel, + &mut delta_trrel, + total_entry.get_mut(), + ); + if !delta_trrel.is_empty() { + new_delta_map.insert(k, delta_trrel); + } + }, + hashbrown::hash_map::Entry::Vacant(total_vacant_entry) => { + let mut new_total = + TrRelIndCommon::Old { rel: Default::default(), anti_reflexive: delta_trrel.anti_reflexive() }; + TrRelIndCommon::merge_delta_to_total_new_to_delta(&mut new_trrel, &mut delta_trrel, &mut new_total); + total_vacant_entry.insert(new_total); + if !delta_trrel.is_empty() { + new_delta_map.insert(k, delta_trrel); + } + }, + } + } + for (k, mut new_trrel) in new.map.drain() { + let mut new_delta = Default::default(); + match total.map.entry(k.clone()) { + hashbrown::hash_map::Entry::Occupied(mut total_entry) => { + TrRelIndCommon::merge_delta_to_total_new_to_delta(&mut new_trrel, &mut new_delta, total_entry.get_mut()); + new_delta_map.insert(k, new_delta); + }, + hashbrown::hash_map::Entry::Vacant(_) => { + TrRelIndCommon::merge_delta_to_total_new_to_delta( + &mut new_trrel, + &mut new_delta, + &mut Default::default(), + ); + new_delta_map.insert(k, new_delta); + }, + } + } + delta.map = new_delta_map; + + if delta.reverse_map1.is_some() { + crate::utils::move_hash_map_of_alt_hash_set_contents( + delta.reverse_map1.as_mut().unwrap(), + total.reverse_map1.as_mut().unwrap(), + ); + std::mem::swap(delta.reverse_map1.as_mut().unwrap(), new.reverse_map1.as_mut().unwrap()); + } + + if delta.reverse_map2.is_some() { + crate::utils::move_hash_map_of_alt_hash_set_contents( + delta.reverse_map2.as_mut().unwrap(), + total.reverse_map2.as_mut().unwrap(), + ); + std::mem::swap(delta.reverse_map2.as_mut().unwrap(), new.reverse_map2.as_mut().unwrap()); + } + } +} + +pub struct TrRel2Ind0<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>(&'a TrRel2IndCommon); + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRel2Ind0<'a, T0, T1> { + type Key = (&'a T0,); + type Value = (&'a T1, &'a T1); + + type ValueIteratorType = Box + 'a>; + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + Box::new(self.0.map.iter().map(|(k, v)| ((k,), Box::new(v.rel().iter_all()) as _))) + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for TrRel2Ind0<'a, T0, T1> { + type Key = (T0,); + type Value = (&'a T1, &'a T1); + + type IteratorType = IteratorFromDyn<'a, Self::Value>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let trrel = self.0.map.get(&key.0)?; + Some(IteratorFromDyn::new(|| trrel.rel().iter_all())) + } + + fn len(&self) -> usize { + let sample_size = 4; + let sum = self.0.map.values().map(|x| x.rel().count_estimate()).sum::(); + sum * self.0.map.len() / sample_size.min(self.0.map.len()).max(1) + } +} + +pub struct TrRel2Ind0_1<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>(&'a TrRel2IndCommon); + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRel2Ind0_1<'a, T0, T1> { + type Key = (&'a T0, &'a T1); + + type Value = (&'a T1,); + + type ValueIteratorType = Map, fn(&T1) -> (&T1,)>; + + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + Box::new(self.0.map.iter().flat_map(|(x0, v)| { + v.rel().map.iter().map(move |(x1, x2_set)| { + let iter: Self::ValueIteratorType = x2_set.iter().map(|x2| (x2,)); + ((x0, x1), iter) + }) + })) + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for TrRel2Ind0_1<'a, T0, T1> { + type Key = (T0, T1); + + type Value = (&'a T1,); + + type IteratorType = Map, fn(&T1) -> (&T1,)>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let trrel = self.0.map.get(&key.0)?; + let res: Self::IteratorType = trrel.rel().map.get(&key.1)?.iter().map(|x| (x,)); + Some(res) + } + + fn len(&self) -> usize { + let sample_size = 3; + let sum = self.0.map.values().take(sample_size).map(|trrel| TrRelInd0(&trrel).len()).sum::(); + let map_len = self.0.map.len(); + sum * map_len / sample_size.min(map_len).max(1) + } +} + +pub struct TrRel2Ind0_2<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>(&'a TrRel2IndCommon); + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRel2Ind0_2<'a, T0, T1> { + type Key = (&'a T0, &'a T1); + type Value = (&'a T1,); + + type ValueIteratorType = Map, fn(&T1) -> (&T1,)>; + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + Box::new(self.0.map.iter().flat_map(|(x0, v)| { + v.rel().reverse_map.iter().map(move |(x1, x2_set)| { + let iter: Self::ValueIteratorType = x2_set.iter().map(|x2| (x2,)); + ((x0, x1), iter) + }) + })) + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for TrRel2Ind0_2<'a, T0, T1> { + type Key = (T0, T1); + type Value = (&'a T1,); + + type IteratorType = Map, fn(&T1) -> (&T1,)>; + + fn index_get(&'a self, key: &Self::Key) -> Option { + let trrel = self.0.map.get(&key.0)?; + let res: Self::IteratorType = trrel.rel().reverse_map.get(&key.1)?.iter().map(|x| (x,)); + Some(res) + } + + fn len(&self) -> usize { + let sample_size = 3; + let sum = self.0.map.values().take(sample_size).map(|trrel| TrRelInd1(&trrel).len()).sum::(); + let map_len = self.0.map.len(); + sum * map_len / sample_size.min(map_len).max(1) + } +} + +pub struct TrRel2Ind1<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>(&'a TrRel2IndCommon); + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRel2Ind1<'a, T0, T1> { + type Key = (&'a T1,); + type Value = (&'a T0, &'a T1); + + type ValueIteratorType = >::IteratorType; + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + Box::new(self.0.reverse_map1.as_ref().unwrap().keys().map(|x1| ((x1,), self.get(x1).unwrap()))) + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> TrRel2Ind1<'a, T0, T1> { + fn get(&'a self, x1: &T1) -> Option<>::IteratorType> { + let (x1, x0s) = self.0.reverse_map1.as_ref().unwrap().get_key_value(x1)?; + let res = move || { + x0s.iter() + .filter_map(move |x0| Some(self.0.map.get(x0).unwrap().rel().map.get(x1)?.iter().map(move |x2| (x0, x2)))) + .flatten() + }; + Some(IteratorFromDyn::new(res)) + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for TrRel2Ind1<'a, T0, T1> { + type Key = (T1,); + + type Value = (&'a T0, &'a T1); + + type IteratorType = IteratorFromDyn<'a, Self::Value>; + + fn index_get(&'a self, (x1,): &Self::Key) -> Option { self.get(x1) } + + fn len(&self) -> usize { self.0.reverse_map1.as_ref().unwrap().len() } +} + +pub struct TrRel2Ind2<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>(&'a TrRel2IndCommon); + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRel2Ind2<'a, T0, T1> { + type Key = (&'a T1,); + type Value = (&'a T0, &'a T1); + + type ValueIteratorType = >::IteratorType; + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + Box::new(self.0.reverse_map2.as_ref().unwrap().keys().map(|x2| ((x2,), self.get(x2).unwrap()))) + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> TrRel2Ind2<'a, T0, T1> { + fn get(&'a self, x2: &T1) -> Option<>::IteratorType> { + let (x2, x0s) = self.0.reverse_map2.as_ref().unwrap().get_key_value(x2)?; + let res = move || { + x0s.iter().flat_map(move |x0| { + self.0.map.get(x0).unwrap().rel().reverse_map.get(x2).unwrap().iter().map(move |x1| (x0, x1)) + }) + }; + Some(IteratorFromDyn::new(res)) + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for TrRel2Ind2<'a, T0, T1> { + type Key = (T1,); + type Value = (&'a T0, &'a T1); + + type IteratorType = IteratorFromDyn<'a, Self::Value>; + + fn index_get(&'a self, (x2,): &Self::Key) -> Option { self.get(x2) } + + fn len(&self) -> usize { self.0.reverse_map2.as_ref().unwrap().len() } +} + +pub struct TrRel2Ind1_2<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>(&'a TrRel2IndCommon); + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRel2Ind1_2<'a, T0, T1> { + type Key = (&'a T1, &'a T1); + type Value = (&'a T0,); + + type ValueIteratorType = Box + 'a>; + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + Box::new(self.0.reverse_map1.as_ref().unwrap().iter().flat_map(move |(x1, x0s_for_x1)| { + self.0.reverse_map2.as_ref().unwrap().iter().map(move |(x2, x0s_for_x2)| { + let x0s: Self::ValueIteratorType = Box::new( + x0s_for_x1 + .intersection(x0s_for_x2) + .filter(|&x0| self.0.map.get(x0).unwrap().rel().contains(x1, x2)) + .map(|x0| (x0,)), + ); + ((x1, x2), x0s) + }) + })) + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for TrRel2Ind1_2<'a, T0, T1> { + type Key = (T1, T1); + + type Value = (&'a T0,); + + type IteratorType = IteratorFromDyn<'a, Self::Value>; + fn index_get(&'a self, (x1, x2): &Self::Key) -> Option { + let (x1, x1_map) = self.0.reverse_map1.as_ref().unwrap().get_key_value(x1)?; + let (x2, x2_map) = self.0.reverse_map2.as_ref().unwrap().get_key_value(x2)?; + + let res = || { + x1_map.intersection(x2_map).filter(|&x0| self.0.map.get(x0).unwrap().rel().contains(x1, x2)).map(|x0| (x0,)) + }; + Some(IteratorFromDyn::new(res)) + } + + fn len(&self) -> usize { + // TODO random estimate, could be very wrong + self.0.reverse_map1.as_ref().unwrap().len() * self.0.reverse_map2.as_ref().unwrap().len() + / ((self.0.map.len() as f32).sqrt() as usize) + } +} + +pub struct TrRel2IndNone<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>(&'a TrRel2IndCommon); + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRel2IndNone<'a, T0, T1> { + type Key = (); + type Value = (&'a T0, &'a T1, &'a T1); + + type ValueIteratorType = >::IteratorType; + type AllIteratorType = std::option::IntoIter<(Self::Key, Self::ValueIteratorType)>; + + fn iter_all(&'a self) -> Self::AllIteratorType { self.index_get(&()).map(|x| ((), x)).into_iter() } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for TrRel2IndNone<'a, T0, T1> { + type Key = (); + type Value = (&'a T0, &'a T1, &'a T1); + + type IteratorType = IteratorFromDyn<'a, Self::Value>; + + fn index_get(&'a self, (): &Self::Key) -> Option { + let res = || self.0.map.iter().flat_map(|(x0, trrel)| trrel.rel().iter_all().map(move |(x1, x2)| (x0, x1, x2))); + Some(IteratorFromDyn::new(res)) + } + + fn len(&self) -> usize { 1 } +} + +#[repr(transparent)] +pub struct TrRel2IndFull<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>(&'a TrRel2IndCommon); + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelFullIndexRead<'a> for TrRel2IndFull<'a, T0, T1> { + type Key = (T0, T1, T1); + + #[inline] + fn contains_key(&'a self, (x0, x1, x2): &Self::Key) -> bool { + match self.0.map.get(x0) { + None => false, + Some(rel) => rel.rel().contains(x1, x2), + } + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexReadAll<'a> for TrRel2IndFull<'a, T0, T1> { + type Key = (&'a T0, &'a T1, &'a T1); + type Value = (); + + type ValueIteratorType = std::iter::Once<()>; + type AllIteratorType = Box + 'a>; + + fn iter_all(&'a self) -> Self::AllIteratorType { + let iter = self.0.map.iter().flat_map(|(x0, trrel)| trrel.rel().iter_all().map(move |(x1, x2)| (x0, x1, x2))); + + Box::new(iter.map(|t| (t, std::iter::once(())))) + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexRead<'a> for TrRel2IndFull<'a, T0, T1> { + type Key = (T0, T1, T1); + type Value = (); + + type IteratorType = std::iter::Once; + + fn index_get(&'a self, (x0, x1, x2): &Self::Key) -> Option { + if self.0.map.get(x0)?.rel().contains(x1, x2) { Some(once(())) } else { None } + } + + fn len(&self) -> usize { + let sample_size = 3; + let sum = self.0.map.values().take(sample_size).map(|trrel| trrel.rel().count_estimate()).sum::(); + let map_len = self.0.map.len(); + sum * map_len / sample_size.min(map_len).max(1) + } +} + +pub struct TrRel2IndFullWrite<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq>(&'a mut TrRel2IndCommon); + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexMerge for TrRel2IndFullWrite<'a, T0, T1> { + fn move_index_contents(_from: &mut Self, _to: &mut Self) {} // noop +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelFullIndexWrite for TrRel2IndFullWrite<'a, T0, T1> { + type Key = (T0, T1, T1); + + type Value = (); + + fn insert_if_not_present(&mut self, (x0, x1, x2): &Self::Key, (): Self::Value) -> bool { + let x0_hash = hash_one(self.0.map.hasher(), x0); + + if !self + .0 + .map + .raw_entry_mut() + .from_key_hashed_nocheck(x0_hash, x0) + .or_insert_with(|| (x0.clone(), TrRelIndCommon::make_new())) + .1 + .insert_by_ref(x1, x2) + { + return false; + } + if let Some(reverse_map1) = self.0.reverse_map1.as_mut() { + reverse_map1.entry(x1.clone()).or_default().insert_with_hash_no_check(x0_hash, x0.clone()); + // reverse_map1.entry(x1.clone()).or_default().raw_table().find(x0_hash, |a| &a.0 == x0).unwrap().copy_from_nonoverlapping(other).insert(x0.clone()); + } + if let Some(reverse_map2) = self.0.reverse_map2.as_mut() { + reverse_map2.entry(x2.clone()).or_default().insert_with_hash_no_check(x0_hash, x0.clone()); + } + true + } +} + +impl<'a, T0: Clone + Hash + Eq, T1: Clone + Hash + Eq> RelIndexWrite for TrRel2IndFullWrite<'a, T0, T1> { + type Key = (T0, T1, T1); + + type Value = (); + + fn index_insert(&mut self, (x0, x1, x2): Self::Key, (): Self::Value) { + if let Some(reverse_map1) = self.0.reverse_map1.as_mut() { + reverse_map1.entry(x1.clone()).or_default().insert(x0.clone()); + } + if let Some(reverse_map2) = self.0.reverse_map2.as_mut() { + reverse_map2.entry(x2.clone()).or_default().insert(x0.clone()); + } + self.0.map.entry(x0).or_insert_with(|| TrRelIndCommon::make_new()).insert(x1, x2); + } +} + +use std::marker::PhantomData; + +use ascent::internal::ToRelIndex; + +use crate::rel_boilerplate::NoopRelIndexWrite; + +macro_rules! to_trrel2 { + ($name: ident, $key: ty, $val: ty) => {paste::paste!{ + pub struct [](PhantomData<(T0, T1)>); + + impl Default for [] { + fn default() -> Self { Self(PhantomData) } + } + + impl ToRelIndex for [] + where Rel: std::ops::DerefMut> + { + type RelIndex<'a> = $name<'a, T0, T1> where Self: 'a, Rel: 'a; + #[inline(always)] + fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { $name(rel.deref()) } + + type RelIndexWrite<'a> = NoopRelIndexWrite<$key, $val> where Self: 'a, Rel: 'a; + #[inline(always)] + fn to_rel_index_write<'a>(&'a mut self, _rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { + NoopRelIndexWrite::default() + } + } + }}; +} + +to_trrel2!(TrRel2IndNone, (), (T0, T1, T1)); +to_trrel2!(TrRel2Ind0, (T0,), (T1, T1)); +to_trrel2!(TrRel2Ind1, (T1,), (T0, T1)); +to_trrel2!(TrRel2Ind2, (T1,), (T0, T1)); +to_trrel2!(TrRel2Ind0_1, (T0, T1), (T1,)); +to_trrel2!(TrRel2Ind0_2, (T0, T1), (T1,)); +to_trrel2!(TrRel2Ind1_2, (T1, T1), (T0,)); +// to_trrel2!(TrRel2IndFull, (T0, T1, T1), ()); + +pub struct ToTrRel2IndFull(PhantomData<(T0, T1)>); + +impl Default for ToTrRel2IndFull { + fn default() -> Self { Self(PhantomData) } +} + +impl ToRelIndex for ToTrRel2IndFull +where Rel: std::ops::DerefMut> +{ + type RelIndex<'a> + = TrRel2IndFull<'a, T0, T1> + where + Self: 'a, + Rel: 'a; + #[inline(always)] + fn to_rel_index<'a>(&'a self, rel: &'a Rel) -> Self::RelIndex<'a> { TrRel2IndFull(rel.deref()) } + + type RelIndexWrite<'a> + = TrRel2IndFullWrite<'a, T0, T1> + where + Self: 'a, + Rel: 'a; + #[inline(always)] + fn to_rel_index_write<'a>(&'a mut self, rel: &'a mut Rel) -> Self::RelIndexWrite<'a> { + TrRel2IndFullWrite(rel.deref_mut()) + } +} diff --git a/byods/ascent-byods-rels/src/trrel_uf.rs b/byods/ascent-byods-rels/src/trrel_uf.rs index fffd847..975115c 100644 --- a/byods/ascent-byods-rels/src/trrel_uf.rs +++ b/byods/ascent-byods-rels/src/trrel_uf.rs @@ -1,23 +1,23 @@ -//! reflexive transitive relations for Ascent, supported by [`TrRelUnionFind`](crate::trrel_union_find::TrRelUnionFind) - -#[doc(hidden)] -#[macro_export] -macro_rules! trrel_uf_ind_common { - ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, ()) => { - $crate::trrel_union_find_binary_ind::TrRelIndCommon<$col0> - }; - - ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: tt, ser, ()) => { - $crate::adaptor::bin_rel_to_ternary::BinRelToTernaryWrapper< - // reverse_map_1 required: - {$crate::inds_contain!($indices, [1]) || $crate::inds_contain!($indices, [1, 2])}, - // reverse_map_2 required: - {$crate::inds_contain!($indices, [2]) || $crate::inds_contain!($indices, [1, 2])}, - $col0, $col1, $col2, - $crate::trrel_union_find_binary_ind::TrRelIndCommon<$col1> - > - }; -} -pub use trrel_uf_ind_common as rel_ind_common; - -pub use crate::adaptor::bin_rel_plus_ternary_provider::{rel_codegen, rel_ind, rel_full_ind, rel}; \ No newline at end of file +//! reflexive transitive relations for Ascent, supported by [`TrRelUnionFind`](crate::trrel_union_find::TrRelUnionFind) + +#[doc(hidden)] +#[macro_export] +macro_rules! trrel_uf_ind_common { + ($name: ident, ($col0: ty, $col1: ty), $indices: expr, ser, ()) => { + $crate::trrel_union_find_binary_ind::TrRelIndCommon<$col0> + }; + + ($name: ident, ($col0: ty, $col1: ty, $col2: ty), $indices: tt, ser, ()) => { + $crate::adaptor::bin_rel_to_ternary::BinRelToTernaryWrapper< + // reverse_map_1 required: + {$crate::inds_contain!($indices, [1]) || $crate::inds_contain!($indices, [1, 2])}, + // reverse_map_2 required: + {$crate::inds_contain!($indices, [2]) || $crate::inds_contain!($indices, [1, 2])}, + $col0, $col1, $col2, + $crate::trrel_union_find_binary_ind::TrRelIndCommon<$col1> + > + }; +} +pub use trrel_uf_ind_common as rel_ind_common; + +pub use crate::adaptor::bin_rel_plus_ternary_provider::{rel, rel_codegen, rel_full_ind, rel_ind}; diff --git a/byods/ascent-byods-rels/src/trrel_union_find.rs b/byods/ascent-byods-rels/src/trrel_union_find.rs index a696a23..9f8882a 100644 --- a/byods/ascent-byods-rels/src/trrel_union_find.rs +++ b/byods/ascent-byods-rels/src/trrel_union_find.rs @@ -1,12 +1,12 @@ -//! A data structure for union-find based reflexive transitive relations. -//! +//! A data structure for union-find based reflexive transitive relations. +//! //! This is the backing data strcuture for [`trrel_uf`](crate::trrel_uf) in Ascent. -use ascent::internal::Instant; -use hashbrown::{HashMap, HashSet}; use std::fmt::Debug; use std::hash::{BuildHasher, BuildHasherDefault, Hash}; use std::time::Duration; +use ascent::internal::Instant; +use hashbrown::{HashMap, HashSet}; use itertools::Itertools; use rustc_hash::FxHasher; @@ -61,10 +61,10 @@ impl TrRelUnionFind { *self.set_subsumptions.get_mut(&id).unwrap() = grandparent; self.get_dominant_id_mut(grandparent) - } + }, None => dom_id, } - } + }, None => id, } } @@ -77,7 +77,7 @@ impl TrRelUnionFind { self.set_subsumptions.insert(id, dom_id); } (dom_id, depth + 1) - } + }, None => (id, 0), } } @@ -112,7 +112,7 @@ impl TrRelUnionFind { self.sets.push(HashSet::from_iter([x.clone()])); self.elem_ids.insert(x.clone(), elem_id); (elem_id, true) - } + }, }; #[cfg(debug_assertions)] @@ -120,9 +120,7 @@ impl TrRelUnionFind { res } - pub(crate) fn add_node(&mut self, x: T) -> usize { - self.add_node_new(x).0 - } + pub(crate) fn add_node(&mut self, x: T) -> usize { self.add_node_new(x).0 } pub fn add(&mut self, x: T, y: T) -> bool { let (x_set, x_new) = self.add_node_new(x.clone()); @@ -165,10 +163,7 @@ impl TrRelUnionFind { } fn merge_multiple( - &mut self, - from: usize, - to: usize, - in_between: &hashbrown::HashSet>, + &mut self, from: usize, to: usize, in_between: &hashbrown::HashSet>, ) -> usize { let before = Instant::now(); // TODO is this right? Aren't we doing too much? @@ -322,9 +317,7 @@ impl TrRelUnionFind { } #[inline] - pub fn is_empty(&self) -> bool { - self.sets.is_empty() - } + pub fn is_empty(&self) -> bool { self.sets.is_empty() } // TODO `set_connections` and `reverse_set_connections` should be guaranteed to contain only dominant ids pub fn get_set_connections(&self, set: usize) -> Option + '_> { @@ -404,17 +397,16 @@ fn extend_set(set1: &mut HashSet = HashSet<(T, T), BuildHasherDefault>; -pub enum TrRelIndCommon { - New { rel: NewSet }, - Delta { rel: TrRelDelta }, - Total { rel: Rc> }, -} - -pub struct TrRelDelta { - set_connections: binary_rel::Map, - rev_set_connections: binary_rel::Map, - precursor: NewSet, - total: Rc>, -} - -impl Default for TrRelDelta { - fn default() -> Self { - Self { set_connections: Default::default(), rev_set_connections: Default::default(), precursor: Default::default(), total: Default::default() } - } -} - -impl TrRelDelta { - fn ind_0_get(&self, x: &T) -> Option> { - let x_set = self.total.elem_set(x)?; - let sets_of_x = self.set_connections.get(&x_set)?; - - let res = || sets_of_x.iter().flat_map(|&s| { - self.total.sets[s].iter() - }); - Some(IteratorFromDyn::new(res)) - } - fn ind_1_get(&self, x: &T) -> Option> { - let x_set = self.total.elem_set(x)?; - let sets_of_x = self.rev_set_connections.get(&x_set)?; - - let res = || sets_of_x.iter().flat_map(|&s| { - self.total.sets[s].iter() - }); - Some(IteratorFromDyn::new(res)) - } - fn ind_0_iter_all<'a>(&'a self) -> IteratorFromDyn<(&'a T, IteratorFromDyn<&'_ T>)> { - let res = || self.set_connections.iter().flat_map(|(set_id, set_connections)| { - let xs = &self.total.sets[*set_id]; - xs.iter().map(|x| { - let ys = || set_connections.iter().flat_map(|sc| { - self.total.sets[*sc].iter() - }); - (x, IteratorFromDyn::new(ys)) - }) - }); - IteratorFromDyn::new(res) - } - fn ind_1_iter_all<'a>(&'a self) -> IteratorFromDyn<(&'a T, IteratorFromDyn<&'_ T>)> { - let res = || self.rev_set_connections.iter().flat_map(|(set_id, rev_set_connections)| { - let xs = &self.total.sets[*set_id]; - xs.iter().map(|x| { - let ys = || rev_set_connections.iter().flat_map(|sc| { - self.total.sets[*sc].iter() - }); - (x, IteratorFromDyn::new(ys)) - }) - }); - IteratorFromDyn::new(res) - } - - fn contains(&self, x: &T, y: &T) -> bool { - self.ind_0_1_get(x, y).is_some() - } - fn ind_0_1_get(&self, x: &T, y: &T) -> Option> { - let x_set = self.total.elem_set(x)?; - let y_set = self.total.elem_set(y)?; - if x_set == y_set {return None} - let x_set_connections = self.set_connections.get(&x_set)?; - if x_set_connections.contains(&y_set) { - Some(std::iter::once(())) - } else { - None - } - } - - fn iter_all<'a>(&'a self) -> impl Iterator + 'a { - let res = self.set_connections.iter().flat_map(move |(x_set, y_sets)| { - self.total.sets[*x_set].iter().flat_map(move |x| { - y_sets.iter().filter(move |y_set| *y_set != x_set).flat_map(move |y_set| { - self.total.sets[*y_set].iter().map(move |y| (x, y)) - }) - }) - }); - res - } - fn is_empty(&self) -> bool { - self.precursor.is_empty() - } -} - - -impl Default for TrRelIndCommon { - #[inline] - fn default() -> Self {Self::Total { rel: Default::default()}} -} - -impl TrRelIndCommon { - - pub fn unwrap_new_mut(&mut self) -> &mut NewSet { - match self { - TrRelIndCommon::New{ rel, .. } => rel, - _ => { - assert!(self.is_empty(), "unwrap_new_mut called on non-empty non-New"); - *self = TrRelIndCommon::New { rel: Default::default() }; - self.unwrap_new_mut() - }, - } - } - - pub fn unwrap_total(&self) -> &TrRelUnionFind { - match self { - TrRelIndCommon::Total{ rel, .. } => rel, - _ => panic!("TrRelIndCommon: unwrap_total called on non-Total"), - } - } - - pub fn is_empty(&self) -> bool { - match self { - TrRelIndCommon::New{rel, ..} => rel.is_empty(), - TrRelIndCommon::Delta { rel, .. } => rel.set_connections.is_empty(), - TrRelIndCommon::Total { rel, .. } => rel.elem_ids.is_empty(), - } - } - - pub fn count_exact(&self) -> usize { - self.unwrap_total().count_exact() - } -} - - -pub static mut MERGE_TIME: Duration = Duration::ZERO; -pub static mut MERGE_DELTA_CONSTRUCTION_TIME: Duration = Duration::ZERO; -pub static mut MERGE_TOTAL_UPDATE_TIME: Duration = Duration::ZERO; -pub static mut MERGE_COUNT: usize = 0; - -impl RelIndexMerge for TrRelIndCommon { - fn move_index_contents(_from: &mut Self, _to: &mut Self) { - panic!("merge_delta_to_total_new_to_delta must be called instead.") - } - - fn init(new: &mut Self, _delta: &mut Self, _total: &mut Self) { - *new = TrRelIndCommon::New { rel: Default::default() }; - } - - fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { - let before = Instant::now(); - - - if let TrRelIndCommon::Total { .. } = delta { - assert!(total.is_empty()); - *total = std::mem::take(delta); - *delta = TrRelIndCommon::Delta { rel: TrRelDelta::default() } - } - - - let mut delta_rel = match delta { - TrRelIndCommon::Delta { rel } => std::mem::take(rel), - _ => panic!("expected Delta"), - }; - delta_rel.total = Rc::new(Default::default()); - - let mut total_rel_rc = match total { - TrRelIndCommon::Total{rel, ..} => std::mem::take(rel), - _ => panic!("expected Total") - }; - - let mut new_rel = std::mem::take(new.unwrap_new_mut()); - - // optimization for when total will be empty - if total_rel_rc.is_empty() && delta_rel.is_empty() { - let mut new_delta = TrRelUnionFind::default(); - let before_total_update = Instant::now(); - - for (x, y) in new_rel.drain() { - new_delta.add(x.clone(), y); - } - unsafe { - MERGE_TOTAL_UPDATE_TIME += before_total_update.elapsed(); - } - *delta = TrRelIndCommon::Total { rel: Rc::new(new_delta) }; - return; - } - let total_rel = Rc::get_mut(&mut total_rel_rc).unwrap(); - - - let before_total_update = Instant::now(); - for (x, y) in delta_rel.precursor.drain() { - total_rel.add(x, y); - } - unsafe { - MERGE_TOTAL_UPDATE_TIME += before_total_update.elapsed(); - } - - type RelMap = HashMap::>, BuildHasherDefault>; - - let mut new_classes_map: RelMap = Default::default(); - let mut new_classes_rev_map: RelMap = Default::default(); - for (x, y) in new_rel.iter() { - let x_id = total_rel.add_node(x.clone()); - let y_id = total_rel.add_node(y.clone()); - new_classes_map.entry(x_id.clone()).or_default().insert(y_id); - new_classes_rev_map.entry(y_id).or_default().insert(x_id); - } - // println!("merge. new_classes_map.len(): {}, new_classes_rev_map.len(): {}", new_classes_map.len(), new_classes_rev_map.len()); - - let mut delta_delta_map = new_classes_map.clone(); - let mut delta_delta_rev_map = new_classes_rev_map; - - let mut delta_total_map = RelMap::::default(); - let mut delta_total_rev_map = RelMap::::default(); - - let mut delta_new_map = RelMap::::default(); - let mut delta_new_rev_map = RelMap::::default(); - - fn join<'a, T: Clone + Hash + Eq + 'a, - Rel1: RelIndexRead<'a, Key = T, Value = &'a T> + RelIndexReadAll<'a, Key = &'a T, Value = &'a T>, - Rel2Rev: RelIndexRead<'a, Key = T, Value = &'a T> + RelIndexReadAll<'a, Key = &'a T, Value = &'a T,> - > (target: &mut RelMap, target_rev: &mut RelMap, rel1: &'a Rel1, rel2_rev: &'a Rel2Rev, - mut can_add: impl FnMut(&T, &T) -> bool, _name: &str - ) -> bool - where Rel1::ValueIteratorType: Clone - { - let mut changed = false; - if rel1.len() < rel2_rev.len() { - for (x, x_set) in rel1.iter_all() { - if let Some(x_rev_set) = rel2_rev.index_get(x) { - for w in x_rev_set { - let entry = target.entry(w.clone()).or_default(); - for y in x_set.clone() { - if !can_add(w, y) {continue} - if entry.insert(y.clone()) { - target_rev.entry(y.clone()).or_default().insert(w.clone()); - changed = true; - } - } - if entry.is_empty() { - target.remove(w); - } - } - } - } - } else { - for (x, x_rev_set) in rel2_rev.iter_all() { - if let Some(x_set) = rel1.index_get(x) { - for w in x_rev_set { - let entry = target.entry(w.clone()).or_default(); - for y in x_set.clone() { - if !can_add(w, y) {continue} - if entry.insert(y.clone()) { - target_rev.entry(y.clone()).or_default().insert(w.clone()); - changed = true; - } - } - if entry.is_empty() { - target.remove(w); - } - } - } - } - } - changed - } - let before_loop = Instant::now(); - loop { - - let mut cached_delta_delta_map_entry_for_can_add = None; - let mut cached_delta_delta_map_x_for_can_add = None; - let mut cached_delta_total_map_entry_for_can_add = None; - let mut cached_delta_total_map_x_for_can_add = None; - let mut cached_total_map_entry_for_can_add = None; - let mut cached_total_map_x_for_can_add = None; - let mut can_add = |x: &usize, y: &usize| { - { - if cached_delta_delta_map_x_for_can_add.as_ref() != Some(x) { - cached_delta_delta_map_entry_for_can_add = delta_delta_map.get(x); - cached_delta_delta_map_x_for_can_add = Some(x.clone()); - }; - } - !cached_delta_delta_map_entry_for_can_add.map_or(false, |s| s.contains(y)) && - { - if cached_delta_total_map_x_for_can_add.as_ref() != Some(x) { - cached_delta_total_map_entry_for_can_add = delta_total_map.get(x); - cached_delta_total_map_x_for_can_add = Some(x.clone()); - }; - !cached_delta_total_map_entry_for_can_add.map_or(false, |s| s.contains(y)) - } && - { - if cached_total_map_x_for_can_add.as_ref() != Some(x) { - cached_total_map_entry_for_can_add = total_rel.set_connections.get(x); - cached_total_map_x_for_can_add = Some(x.clone()); - } - !cached_total_map_entry_for_can_add.map_or(false, |s| s.contains(y)) - } - }; - - let join1 = join(&mut delta_new_map, &mut delta_new_rev_map, &MapRelIndexAdaptor(&delta_delta_map), &MapRelIndexAdaptor(&total_rel.reverse_set_connections), &mut can_add, "join1"); - let join2 = join(&mut delta_new_map, &mut delta_new_rev_map, &MapRelIndexAdaptor(&total_rel.set_connections), &MapRelIndexAdaptor(&delta_delta_rev_map), &mut can_add, "join2"); - let join3 = join(&mut delta_new_map, &mut delta_new_rev_map, &MapRelIndexAdaptor(&new_classes_map), &MapRelIndexAdaptor(&delta_delta_rev_map), &mut can_add, "join3"); - - let changed = join1 | join2 | join3; - - move_hash_map_of_hash_set_contents_disjoint(&mut delta_delta_map, &mut delta_total_map); - move_hash_map_of_hash_set_contents_disjoint(&mut delta_delta_rev_map, &mut delta_total_rev_map); - - assert!(delta_delta_map.is_empty()); - assert!(delta_delta_rev_map.is_empty()); - - std::mem::swap(&mut delta_delta_map, &mut delta_new_map); - std::mem::swap(&mut delta_delta_rev_map, &mut delta_new_rev_map); - - if !changed { break } - } - unsafe { - MERGE_DELTA_CONSTRUCTION_TIME += before_loop.elapsed(); - } - - let new_delta: TrRelDelta = TrRelDelta { - set_connections: delta_total_map, - rev_set_connections: delta_total_rev_map, - precursor: new_rel, - total: total_rel_rc.clone() - }; - *delta = TrRelIndCommon::Delta{ rel: new_delta }; - *total = TrRelIndCommon::Total { rel: total_rel_rc }; - - unsafe { - MERGE_TIME += before.elapsed(); - MERGE_COUNT += 1; - } - } -} - -impl ByodsBinRel for TrRelIndCommon { - type T0 = T; - type T1 = T; - - fn contains(&self, x0: &Self::T0, x1: &Self::T1) -> bool { - match self { - TrRelIndCommon::Delta { rel, .. } => rel.contains(x0, x1), - TrRelIndCommon::Total { rel, .. } => rel.contains(x0, x1), - TrRelIndCommon::New { .. } => panic!("unexpected New") - } - } - - type AllIter<'a> = Box + 'a> where Self: 'a; - - fn iter_all<'a>(&'a self) -> Self::AllIter<'a> { - match self { - TrRelIndCommon::Delta { rel, .. } => Box::new(rel.iter_all()), - TrRelIndCommon::Total { rel, .. } => Box::new(rel.iter_all()), - TrRelIndCommon::New { .. } => panic!("unexpected New"), - } - } - - fn len_estimate(&self) -> usize { - let sample_size = 3; - match self { - TrRelIndCommon::Delta { rel, .. } => { - let avg_set_connections = - rel.set_connections.iter().take(3).map(|(_s, sets)| sets.len()).sum::() / sample_size.min(rel.set_connections.len()).max(1); - let avg_set_size = rel.total.sets.iter().rev().take(sample_size).map(|s| s.len()).sum::() - / sample_size.min(rel.total.sets.len()).max(1); - avg_set_connections * avg_set_size - }, - TrRelIndCommon::Total { rel, .. } => { - let avg_set_connections = - rel.set_connections.iter().take(3).map(|(_s, sets)| sets.len()).sum::() / sample_size.min(rel.set_connections.len()).max(1); - let avg_set_size = rel.sets.iter().rev().take(sample_size).map(|s| s.len()).sum::() - / sample_size.min(rel.sets.len()).max(1); - avg_set_connections * avg_set_size - }, - TrRelIndCommon::New { .. } => panic!("unexpected New"), - } - } - - - type Ind0AllIterValsIter<'a> = IteratorFromDyn<'a, &'a T> where Self: 'a; - type Ind0AllIter<'a> = IteratorFromDyn<'a, (&'a T, Self::Ind0AllIterValsIter<'a>)> where Self: 'a; - - fn ind0_iter_all<'a>(&'a self) -> Self::Ind0AllIter<'a> { - match self { - TrRelIndCommon::Delta { rel, .. } => rel.ind_0_iter_all(), - TrRelIndCommon::Total { rel, .. } => { - let res = || rel.elem_ids.iter().map(|(x, set_id)| { - let set = || rel.set_of_by_set_id(x, *set_id); - (x, IteratorFromDyn::new(set)) - }); - IteratorFromDyn::new(res) - }, - TrRelIndCommon::New { .. } => panic!("unexpected New"), - } - } - - fn ind0_len_estimate(&self) -> usize { - let res = match self { - TrRelIndCommon::Delta { rel, .. } => { - let sample_size = 5; - let sum_set_size = rel.total.sets.iter().rev().take(sample_size).map(|s| s.len()).sum::(); - sum_set_size * rel.set_connections.len() / sample_size.min(rel.total.sets.len()).max(1) - }, - TrRelIndCommon::Total { rel, .. } => { - let sample_size = 3; - let sum_set_size = rel.sets.iter().rev().take(sample_size).map(|s| s.len()).sum::(); - sum_set_size * rel.set_connections.len() / sample_size.min(rel.sets.len()).max(1) - }, - TrRelIndCommon::New { .. } => panic!("unexpected New"), - }; - res - } - - type Ind0ValsIter<'a> = IteratorFromDyn<'a, &'a T> where Self: 'a; - - fn ind0_index_get<'a>(&'a self, key: &Self::T0) -> Option> { - match self { - TrRelIndCommon::Delta { rel, .. } => rel.ind_0_get(key), - TrRelIndCommon::Total { rel, .. } => { - let (key, id) = rel.elem_ids.get_key_value(key)?; - let id = rel.get_dominant_id(*id); - let res = move || rel.set_of_by_set_id(key, id); - Some(IteratorFromDyn::new(res)) - }, - TrRelIndCommon::New { .. } => panic!("unexpected New"), - } - } - - type Ind1AllIterValsIter<'a> = IteratorFromDyn<'a, &'a T> where Self: 'a; - type Ind1AllIter<'a> = IteratorFromDyn<'a, (&'a T, Self::Ind1AllIterValsIter<'a>)> where Self: 'a; - - fn ind1_iter_all<'a>(&'a self) -> Self::Ind1AllIter<'a> { - match self { - TrRelIndCommon::Delta { rel, .. } => rel.ind_1_iter_all(), - TrRelIndCommon::Total { rel, .. } => { - let res = || rel.elem_ids.iter().map(|(x, set_id)| { - let set = || rel.rev_set_of_by_set_id(x, *set_id); - (x, IteratorFromDyn::new(set)) - }); - IteratorFromDyn::new(res) - }, - TrRelIndCommon::New { .. } => panic!("unexpected New"), - } - } - - fn ind1_len_estimate(&self) -> usize { - let res = match self { - TrRelIndCommon::Delta { rel, .. } => { - let sample_size = 5; - let sum_set_size = rel.total.sets.iter().rev().take(sample_size).map(|s| s.len()).sum::(); - sum_set_size * rel.rev_set_connections.len() / sample_size.min(rel.total.sets.len()).max(1) - }, - TrRelIndCommon::Total { rel, .. } => { - let sample_size = 3; - let sum_set_size = rel.sets.iter().rev().take(sample_size).map(|s| s.len()).sum::(); - sum_set_size * rel.reverse_set_connections.len() / sample_size.min(rel.sets.len()).max(1) - }, - TrRelIndCommon::New { .. } => panic!("unexpected New"), - }; - res - } - - type Ind1ValsIter<'a> = IteratorFromDyn<'a, &'a T> where Self: 'a; - fn ind1_index_get<'a>(&'a self, key: &Self::T1) -> Option> { - match self { - TrRelIndCommon::Delta { rel, .. } => rel.ind_1_get(key), - TrRelIndCommon::Total { rel, .. } => { - let (key, id) = rel.elem_ids.get_key_value(key)?; - let id = rel.get_dominant_id(*id); - let res = move || rel.rev_set_of_by_set_id(key, id); - Some(IteratorFromDyn::new(res)) - }, - TrRelIndCommon::New { .. } => panic!("unexpected New"), - } - } - - fn insert(&mut self, x0: Self::T0, x1: Self::T1) -> bool { - self.unwrap_new_mut().insert((x0, x1)) - } -} +use core::panic; +use std::hash::{BuildHasherDefault, Hash}; +use std::rc::Rc; +use std::time::{Duration, Instant}; + +use ascent::internal::{RelIndexMerge, RelIndexRead, RelIndexReadAll}; +use hashbrown::{HashMap, HashSet}; +use rustc_hash::FxHasher; + +use crate::adaptor::bin_rel::ByodsBinRel; +use crate::binary_rel::{self, MapRelIndexAdaptor}; +use crate::iterator_from_dyn::IteratorFromDyn; +use crate::trrel_binary::MyHashSet; +use crate::trrel_union_find::TrRelUnionFind; +use crate::utils::move_hash_map_of_hash_set_contents_disjoint; + +type NewSet = HashSet<(T, T), BuildHasherDefault>; +pub enum TrRelIndCommon { + New { rel: NewSet }, + Delta { rel: TrRelDelta }, + Total { rel: Rc> }, +} + +pub struct TrRelDelta { + set_connections: binary_rel::Map, + rev_set_connections: binary_rel::Map, + precursor: NewSet, + total: Rc>, +} + +impl Default for TrRelDelta { + fn default() -> Self { + Self { + set_connections: Default::default(), + rev_set_connections: Default::default(), + precursor: Default::default(), + total: Default::default(), + } + } +} + +impl TrRelDelta { + fn ind_0_get(&self, x: &T) -> Option> { + let x_set = self.total.elem_set(x)?; + let sets_of_x = self.set_connections.get(&x_set)?; + + let res = || sets_of_x.iter().flat_map(|&s| self.total.sets[s].iter()); + Some(IteratorFromDyn::new(res)) + } + fn ind_1_get(&self, x: &T) -> Option> { + let x_set = self.total.elem_set(x)?; + let sets_of_x = self.rev_set_connections.get(&x_set)?; + + let res = || sets_of_x.iter().flat_map(|&s| self.total.sets[s].iter()); + Some(IteratorFromDyn::new(res)) + } + fn ind_0_iter_all<'a>(&'a self) -> IteratorFromDyn<(&'a T, IteratorFromDyn<&'_ T>)> { + let res = || { + self.set_connections.iter().flat_map(|(set_id, set_connections)| { + let xs = &self.total.sets[*set_id]; + xs.iter().map(|x| { + let ys = || set_connections.iter().flat_map(|sc| self.total.sets[*sc].iter()); + (x, IteratorFromDyn::new(ys)) + }) + }) + }; + IteratorFromDyn::new(res) + } + fn ind_1_iter_all<'a>(&'a self) -> IteratorFromDyn<(&'a T, IteratorFromDyn<&'_ T>)> { + let res = || { + self.rev_set_connections.iter().flat_map(|(set_id, rev_set_connections)| { + let xs = &self.total.sets[*set_id]; + xs.iter().map(|x| { + let ys = || rev_set_connections.iter().flat_map(|sc| self.total.sets[*sc].iter()); + (x, IteratorFromDyn::new(ys)) + }) + }) + }; + IteratorFromDyn::new(res) + } + + fn contains(&self, x: &T, y: &T) -> bool { self.ind_0_1_get(x, y).is_some() } + fn ind_0_1_get(&self, x: &T, y: &T) -> Option> { + let x_set = self.total.elem_set(x)?; + let y_set = self.total.elem_set(y)?; + if x_set == y_set { + return None + } + let x_set_connections = self.set_connections.get(&x_set)?; + if x_set_connections.contains(&y_set) { Some(std::iter::once(())) } else { None } + } + + fn iter_all<'a>(&'a self) -> impl Iterator + 'a { + let res = self.set_connections.iter().flat_map(move |(x_set, y_sets)| { + self.total.sets[*x_set].iter().flat_map(move |x| { + y_sets + .iter() + .filter(move |y_set| *y_set != x_set) + .flat_map(move |y_set| self.total.sets[*y_set].iter().map(move |y| (x, y))) + }) + }); + res + } + fn is_empty(&self) -> bool { self.precursor.is_empty() } +} + +impl Default for TrRelIndCommon { + #[inline] + fn default() -> Self { Self::Total { rel: Default::default() } } +} + +impl TrRelIndCommon { + pub fn unwrap_new_mut(&mut self) -> &mut NewSet { + match self { + TrRelIndCommon::New { rel, .. } => rel, + _ => { + assert!(self.is_empty(), "unwrap_new_mut called on non-empty non-New"); + *self = TrRelIndCommon::New { rel: Default::default() }; + self.unwrap_new_mut() + }, + } + } + + pub fn unwrap_total(&self) -> &TrRelUnionFind { + match self { + TrRelIndCommon::Total { rel, .. } => rel, + _ => panic!("TrRelIndCommon: unwrap_total called on non-Total"), + } + } + + pub fn is_empty(&self) -> bool { + match self { + TrRelIndCommon::New { rel, .. } => rel.is_empty(), + TrRelIndCommon::Delta { rel, .. } => rel.set_connections.is_empty(), + TrRelIndCommon::Total { rel, .. } => rel.elem_ids.is_empty(), + } + } + + pub fn count_exact(&self) -> usize { self.unwrap_total().count_exact() } +} + +pub static mut MERGE_TIME: Duration = Duration::ZERO; +pub static mut MERGE_DELTA_CONSTRUCTION_TIME: Duration = Duration::ZERO; +pub static mut MERGE_TOTAL_UPDATE_TIME: Duration = Duration::ZERO; +pub static mut MERGE_COUNT: usize = 0; + +impl RelIndexMerge for TrRelIndCommon { + fn move_index_contents(_from: &mut Self, _to: &mut Self) { + panic!("merge_delta_to_total_new_to_delta must be called instead.") + } + + fn init(new: &mut Self, _delta: &mut Self, _total: &mut Self) { + *new = TrRelIndCommon::New { rel: Default::default() }; + } + + fn merge_delta_to_total_new_to_delta(new: &mut Self, delta: &mut Self, total: &mut Self) { + let before = Instant::now(); + + if let TrRelIndCommon::Total { .. } = delta { + assert!(total.is_empty()); + *total = std::mem::take(delta); + *delta = TrRelIndCommon::Delta { rel: TrRelDelta::default() } + } + + let mut delta_rel = match delta { + TrRelIndCommon::Delta { rel } => std::mem::take(rel), + _ => panic!("expected Delta"), + }; + delta_rel.total = Rc::new(Default::default()); + + let mut total_rel_rc = match total { + TrRelIndCommon::Total { rel, .. } => std::mem::take(rel), + _ => panic!("expected Total"), + }; + + let mut new_rel = std::mem::take(new.unwrap_new_mut()); + + // optimization for when total will be empty + if total_rel_rc.is_empty() && delta_rel.is_empty() { + let mut new_delta = TrRelUnionFind::default(); + let before_total_update = Instant::now(); + + for (x, y) in new_rel.drain() { + new_delta.add(x.clone(), y); + } + unsafe { + MERGE_TOTAL_UPDATE_TIME += before_total_update.elapsed(); + } + *delta = TrRelIndCommon::Total { rel: Rc::new(new_delta) }; + return; + } + let total_rel = Rc::get_mut(&mut total_rel_rc).unwrap(); + + let before_total_update = Instant::now(); + for (x, y) in delta_rel.precursor.drain() { + total_rel.add(x, y); + } + unsafe { + MERGE_TOTAL_UPDATE_TIME += before_total_update.elapsed(); + } + + type RelMap = HashMap>, BuildHasherDefault>; + + let mut new_classes_map: RelMap = Default::default(); + let mut new_classes_rev_map: RelMap = Default::default(); + for (x, y) in new_rel.iter() { + let x_id = total_rel.add_node(x.clone()); + let y_id = total_rel.add_node(y.clone()); + new_classes_map.entry(x_id.clone()).or_default().insert(y_id); + new_classes_rev_map.entry(y_id).or_default().insert(x_id); + } + // println!("merge. new_classes_map.len(): {}, new_classes_rev_map.len(): {}", new_classes_map.len(), new_classes_rev_map.len()); + + let mut delta_delta_map = new_classes_map.clone(); + let mut delta_delta_rev_map = new_classes_rev_map; + + let mut delta_total_map = RelMap::::default(); + let mut delta_total_rev_map = RelMap::::default(); + + let mut delta_new_map = RelMap::::default(); + let mut delta_new_rev_map = RelMap::::default(); + + fn join< + 'a, + T: Clone + Hash + Eq + 'a, + Rel1: RelIndexRead<'a, Key = T, Value = &'a T> + RelIndexReadAll<'a, Key = &'a T, Value = &'a T>, + Rel2Rev: RelIndexRead<'a, Key = T, Value = &'a T> + RelIndexReadAll<'a, Key = &'a T, Value = &'a T>, + >( + target: &mut RelMap, target_rev: &mut RelMap, rel1: &'a Rel1, rel2_rev: &'a Rel2Rev, + mut can_add: impl FnMut(&T, &T) -> bool, _name: &str, + ) -> bool + where + Rel1::ValueIteratorType: Clone, + { + let mut changed = false; + if rel1.len() < rel2_rev.len() { + for (x, x_set) in rel1.iter_all() { + if let Some(x_rev_set) = rel2_rev.index_get(x) { + for w in x_rev_set { + let entry = target.entry(w.clone()).or_default(); + for y in x_set.clone() { + if !can_add(w, y) { + continue + } + if entry.insert(y.clone()) { + target_rev.entry(y.clone()).or_default().insert(w.clone()); + changed = true; + } + } + if entry.is_empty() { + target.remove(w); + } + } + } + } + } else { + for (x, x_rev_set) in rel2_rev.iter_all() { + if let Some(x_set) = rel1.index_get(x) { + for w in x_rev_set { + let entry = target.entry(w.clone()).or_default(); + for y in x_set.clone() { + if !can_add(w, y) { + continue + } + if entry.insert(y.clone()) { + target_rev.entry(y.clone()).or_default().insert(w.clone()); + changed = true; + } + } + if entry.is_empty() { + target.remove(w); + } + } + } + } + } + changed + } + let before_loop = Instant::now(); + loop { + let mut cached_delta_delta_map_entry_for_can_add = None; + let mut cached_delta_delta_map_x_for_can_add = None; + let mut cached_delta_total_map_entry_for_can_add = None; + let mut cached_delta_total_map_x_for_can_add = None; + let mut cached_total_map_entry_for_can_add = None; + let mut cached_total_map_x_for_can_add = None; + let mut can_add = |x: &usize, y: &usize| { + { + if cached_delta_delta_map_x_for_can_add.as_ref() != Some(x) { + cached_delta_delta_map_entry_for_can_add = delta_delta_map.get(x); + cached_delta_delta_map_x_for_can_add = Some(x.clone()); + }; + } + !cached_delta_delta_map_entry_for_can_add.map_or(false, |s| s.contains(y)) + && { + if cached_delta_total_map_x_for_can_add.as_ref() != Some(x) { + cached_delta_total_map_entry_for_can_add = delta_total_map.get(x); + cached_delta_total_map_x_for_can_add = Some(x.clone()); + }; + !cached_delta_total_map_entry_for_can_add.map_or(false, |s| s.contains(y)) + } + && { + if cached_total_map_x_for_can_add.as_ref() != Some(x) { + cached_total_map_entry_for_can_add = total_rel.set_connections.get(x); + cached_total_map_x_for_can_add = Some(x.clone()); + } + !cached_total_map_entry_for_can_add.map_or(false, |s| s.contains(y)) + } + }; + + let join1 = join( + &mut delta_new_map, + &mut delta_new_rev_map, + &MapRelIndexAdaptor(&delta_delta_map), + &MapRelIndexAdaptor(&total_rel.reverse_set_connections), + &mut can_add, + "join1", + ); + let join2 = join( + &mut delta_new_map, + &mut delta_new_rev_map, + &MapRelIndexAdaptor(&total_rel.set_connections), + &MapRelIndexAdaptor(&delta_delta_rev_map), + &mut can_add, + "join2", + ); + let join3 = join( + &mut delta_new_map, + &mut delta_new_rev_map, + &MapRelIndexAdaptor(&new_classes_map), + &MapRelIndexAdaptor(&delta_delta_rev_map), + &mut can_add, + "join3", + ); + + let changed = join1 | join2 | join3; + + move_hash_map_of_hash_set_contents_disjoint(&mut delta_delta_map, &mut delta_total_map); + move_hash_map_of_hash_set_contents_disjoint(&mut delta_delta_rev_map, &mut delta_total_rev_map); + + assert!(delta_delta_map.is_empty()); + assert!(delta_delta_rev_map.is_empty()); + + std::mem::swap(&mut delta_delta_map, &mut delta_new_map); + std::mem::swap(&mut delta_delta_rev_map, &mut delta_new_rev_map); + + if !changed { + break + } + } + unsafe { + MERGE_DELTA_CONSTRUCTION_TIME += before_loop.elapsed(); + } + + let new_delta: TrRelDelta = TrRelDelta { + set_connections: delta_total_map, + rev_set_connections: delta_total_rev_map, + precursor: new_rel, + total: total_rel_rc.clone(), + }; + *delta = TrRelIndCommon::Delta { rel: new_delta }; + *total = TrRelIndCommon::Total { rel: total_rel_rc }; + + unsafe { + MERGE_TIME += before.elapsed(); + MERGE_COUNT += 1; + } + } +} + +impl ByodsBinRel for TrRelIndCommon { + type T0 = T; + type T1 = T; + + fn contains(&self, x0: &Self::T0, x1: &Self::T1) -> bool { + match self { + TrRelIndCommon::Delta { rel, .. } => rel.contains(x0, x1), + TrRelIndCommon::Total { rel, .. } => rel.contains(x0, x1), + TrRelIndCommon::New { .. } => panic!("unexpected New"), + } + } + + type AllIter<'a> + = Box + 'a> + where Self: 'a; + + fn iter_all<'a>(&'a self) -> Self::AllIter<'a> { + match self { + TrRelIndCommon::Delta { rel, .. } => Box::new(rel.iter_all()), + TrRelIndCommon::Total { rel, .. } => Box::new(rel.iter_all()), + TrRelIndCommon::New { .. } => panic!("unexpected New"), + } + } + + fn len_estimate(&self) -> usize { + let sample_size = 3; + match self { + TrRelIndCommon::Delta { rel, .. } => { + let avg_set_connections = rel.set_connections.iter().take(3).map(|(_s, sets)| sets.len()).sum::() + / sample_size.min(rel.set_connections.len()).max(1); + let avg_set_size = rel.total.sets.iter().rev().take(sample_size).map(|s| s.len()).sum::() + / sample_size.min(rel.total.sets.len()).max(1); + avg_set_connections * avg_set_size + }, + TrRelIndCommon::Total { rel, .. } => { + let avg_set_connections = rel.set_connections.iter().take(3).map(|(_s, sets)| sets.len()).sum::() + / sample_size.min(rel.set_connections.len()).max(1); + let avg_set_size = rel.sets.iter().rev().take(sample_size).map(|s| s.len()).sum::() + / sample_size.min(rel.sets.len()).max(1); + avg_set_connections * avg_set_size + }, + TrRelIndCommon::New { .. } => panic!("unexpected New"), + } + } + + type Ind0AllIterValsIter<'a> + = IteratorFromDyn<'a, &'a T> + where Self: 'a; + type Ind0AllIter<'a> + = IteratorFromDyn<'a, (&'a T, Self::Ind0AllIterValsIter<'a>)> + where Self: 'a; + + fn ind0_iter_all<'a>(&'a self) -> Self::Ind0AllIter<'a> { + match self { + TrRelIndCommon::Delta { rel, .. } => rel.ind_0_iter_all(), + TrRelIndCommon::Total { rel, .. } => { + let res = || { + rel.elem_ids.iter().map(|(x, set_id)| { + let set = || rel.set_of_by_set_id(x, *set_id); + (x, IteratorFromDyn::new(set)) + }) + }; + IteratorFromDyn::new(res) + }, + TrRelIndCommon::New { .. } => panic!("unexpected New"), + } + } + + fn ind0_len_estimate(&self) -> usize { + let res = match self { + TrRelIndCommon::Delta { rel, .. } => { + let sample_size = 5; + let sum_set_size = rel.total.sets.iter().rev().take(sample_size).map(|s| s.len()).sum::(); + sum_set_size * rel.set_connections.len() / sample_size.min(rel.total.sets.len()).max(1) + }, + TrRelIndCommon::Total { rel, .. } => { + let sample_size = 3; + let sum_set_size = rel.sets.iter().rev().take(sample_size).map(|s| s.len()).sum::(); + sum_set_size * rel.set_connections.len() / sample_size.min(rel.sets.len()).max(1) + }, + TrRelIndCommon::New { .. } => panic!("unexpected New"), + }; + res + } + + type Ind0ValsIter<'a> + = IteratorFromDyn<'a, &'a T> + where Self: 'a; + + fn ind0_index_get<'a>(&'a self, key: &Self::T0) -> Option> { + match self { + TrRelIndCommon::Delta { rel, .. } => rel.ind_0_get(key), + TrRelIndCommon::Total { rel, .. } => { + let (key, id) = rel.elem_ids.get_key_value(key)?; + let id = rel.get_dominant_id(*id); + let res = move || rel.set_of_by_set_id(key, id); + Some(IteratorFromDyn::new(res)) + }, + TrRelIndCommon::New { .. } => panic!("unexpected New"), + } + } + + type Ind1AllIterValsIter<'a> + = IteratorFromDyn<'a, &'a T> + where Self: 'a; + type Ind1AllIter<'a> + = IteratorFromDyn<'a, (&'a T, Self::Ind1AllIterValsIter<'a>)> + where Self: 'a; + + fn ind1_iter_all<'a>(&'a self) -> Self::Ind1AllIter<'a> { + match self { + TrRelIndCommon::Delta { rel, .. } => rel.ind_1_iter_all(), + TrRelIndCommon::Total { rel, .. } => { + let res = || { + rel.elem_ids.iter().map(|(x, set_id)| { + let set = || rel.rev_set_of_by_set_id(x, *set_id); + (x, IteratorFromDyn::new(set)) + }) + }; + IteratorFromDyn::new(res) + }, + TrRelIndCommon::New { .. } => panic!("unexpected New"), + } + } + + fn ind1_len_estimate(&self) -> usize { + let res = match self { + TrRelIndCommon::Delta { rel, .. } => { + let sample_size = 5; + let sum_set_size = rel.total.sets.iter().rev().take(sample_size).map(|s| s.len()).sum::(); + sum_set_size * rel.rev_set_connections.len() / sample_size.min(rel.total.sets.len()).max(1) + }, + TrRelIndCommon::Total { rel, .. } => { + let sample_size = 3; + let sum_set_size = rel.sets.iter().rev().take(sample_size).map(|s| s.len()).sum::(); + sum_set_size * rel.reverse_set_connections.len() / sample_size.min(rel.sets.len()).max(1) + }, + TrRelIndCommon::New { .. } => panic!("unexpected New"), + }; + res + } + + type Ind1ValsIter<'a> + = IteratorFromDyn<'a, &'a T> + where Self: 'a; + fn ind1_index_get<'a>(&'a self, key: &Self::T1) -> Option> { + match self { + TrRelIndCommon::Delta { rel, .. } => rel.ind_1_get(key), + TrRelIndCommon::Total { rel, .. } => { + let (key, id) = rel.elem_ids.get_key_value(key)?; + let id = rel.get_dominant_id(*id); + let res = move || rel.rev_set_of_by_set_id(key, id); + Some(IteratorFromDyn::new(res)) + }, + TrRelIndCommon::New { .. } => panic!("unexpected New"), + } + } + + fn insert(&mut self, x0: Self::T0, x1: Self::T1) -> bool { self.unwrap_new_mut().insert((x0, x1)) } +} diff --git a/byods/ascent-byods-rels/src/uf.rs b/byods/ascent-byods-rels/src/uf.rs index e448fee..d155adc 100644 --- a/byods/ascent-byods-rels/src/uf.rs +++ b/byods/ascent-byods-rels/src/uf.rs @@ -13,12 +13,10 @@ use self::elems::{Elems, FindResult, Id}; /// - `Elem` IDs are determined by their index in the `Vec` /// - [`Id`] and `Rank` happen to be integers pub mod elems { - use std::{ - cell::Cell, - collections::HashSet, - fmt::Debug, - ops::{Add, Index}, - }; + use std::cell::Cell; + use std::collections::HashSet; + use std::fmt::Debug; + use std::ops::{Add, Index}; #[cfg(not(feature = "compact"))] type UfPtrType = usize; @@ -48,9 +46,7 @@ pub mod elems { impl Add for UfPtr { type Output = UfPtr; - fn add(self, rhs: UfPtrType) -> Self::Output { - UfPtr(self.0 + rhs) - } + fn add(self, rhs: UfPtrType) -> Self::Output { UfPtr(self.0 + rhs) } } #[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] @@ -62,9 +58,7 @@ pub mod elems { impl Add for Rank { type Output = Rank; - fn add(self, rhs: UfPtrType) -> Self::Output { - Rank(self.0 + rhs) - } + fn add(self, rhs: UfPtrType) -> Self::Output { Rank(self.0 + rhs) } } /// Each field is wrapped in a [`Cell`] so that the union-find can expose @@ -145,24 +139,17 @@ pub mod elems { let parent = unsafe { self.get_unchecked(parent_id) }; let grandparent_id = parent.parent.get(); if grandparent_id == parent_id { - return FindResult { - id: parent_id, - elem: parent, - }; + return FindResult { id: parent_id, elem: parent }; } // Path halving elem.parent.set(grandparent_id); self.find(grandparent_id) } - pub(super) fn get(&self, id: Id) -> Option<&Elem> { - self.0.get(id.0.to_usize()) - } + pub(super) fn get(&self, id: Id) -> Option<&Elem> { self.0.get(id.0.to_usize()) } #[inline] - pub(super) fn has(&self, id: Id) -> bool { - id.0.to_usize() < self.0.len() - } + pub(super) fn has(&self, id: Id) -> bool { id.0.to_usize() < self.0.len() } #[inline] pub(super) unsafe fn get_unchecked(&self, id: Id) -> &Elem { @@ -210,20 +197,14 @@ pub mod elems { } #[inline] - pub(super) fn len(&self) -> usize { - self.0.len() - } + pub(super) fn len(&self) -> usize { self.0.len() } #[inline] - fn next(&self) -> Id { - Id(UfPtr::from_usize(self.len())) - } + fn next(&self) -> Id { Id(UfPtr::from_usize(self.len())) } /// O(1) #[allow(dead_code)] - pub(super) fn ok_cheap(&self) -> bool { - self.len() < UfPtr::MAX.to_usize() - } + pub(super) fn ok_cheap(&self) -> bool { self.len() < UfPtr::MAX.to_usize() } /// O(n^2) #[allow(dead_code)] @@ -290,36 +271,25 @@ pub mod elems { } #[inline] - pub(super) fn _is_empty(&self) -> bool { - self.0.is_empty() - } + pub(super) fn _is_empty(&self) -> bool { self.0.is_empty() } pub(super) fn push(&mut self, value: T) -> Id { debug_assert!(self.ok_cheap()); let id = self.next(); - self.0.push(Elem { - next: Cell::new(id), - parent: Cell::new(id), - rank: Cell::new(Rank(UfPtr(0))), - value, - }); + self.0.push(Elem { next: Cell::new(id), parent: Cell::new(id), rank: Cell::new(Rank(UfPtr(0))), value }); assert!(self.len() < UfPtr::MAX.to_usize()); id } #[inline] - pub(super) fn _with_capacity(cap: usize) -> Self { - Self(Vec::with_capacity(cap)) - } + pub(super) fn _with_capacity(cap: usize) -> Self { Self(Vec::with_capacity(cap)) } } impl Index for Elems { type Output = Elem; #[inline] - fn index(&self, index: Id) -> &Self::Output { - &self.get(index).unwrap() - } + fn index(&self, index: Id) -> &Self::Output { &self.get(index).unwrap() } } /// Iterator over a single equivalence class @@ -336,11 +306,7 @@ pub mod elems { } unsafe fn new_unchecked(elems: &'a Elems, start: Id) -> Self { - Class { - elems, - start, - current: elems.get_unchecked(start), - } + Class { elems, start, current: elems.get_unchecked(start) } } } @@ -370,13 +336,7 @@ pub mod elems { } impl<'a, T: PartialEq> Classes<'a, T> { - fn new(elems: &'a Elems) -> Self { - Self { - elems, - seen: HashSet::new(), - iter: elems.0.iter(), - } - } + fn new(elems: &'a Elems) -> Self { Self { elems, seen: HashSet::new(), iter: elems.0.iter() } } } /// Guaranteed to yield the root of each class first in each iterator. @@ -437,9 +397,7 @@ impl UnionFind { self.elems.find(id).id } - pub fn find_item(&self, item: &T) -> Option { - self.find_item_internal(item).map(|r| r.id) - } + pub fn find_item(&self, item: &T) -> Option { self.find_item_internal(item).map(|r| r.id) } fn find_item_internal(&self, item: &T) -> Option> { match self.items.get(item) { @@ -450,7 +408,7 @@ impl UnionFind { let fr @ FindResult { id: root_id, .. } = unsafe { self.elems.find(id) }; id_cell.set(root_id); Some(fr) - } + }, } } @@ -467,15 +425,11 @@ impl UnionFind { } /// O(1) - fn ok_cheap(&self) -> bool { - self.elems.len() == self.items.len() - } + fn ok_cheap(&self) -> bool { self.elems.len() == self.items.len() } /// O(n^2) #[allow(dead_code)] - fn ok(&self) -> bool { - self.ok_cheap() && self.elems.ok() - } + fn ok(&self) -> bool { self.ok_cheap() && self.elems.ok() } fn push(&mut self, item: T) -> Id { debug_assert!(self.ok_cheap()); // invariant @@ -532,10 +486,8 @@ impl UnionFind { #[cfg(test)] mod tests { - use proptest::{ - prelude::{any, Strategy}, - proptest, - }; + use proptest::prelude::{Strategy, any}; + use proptest::proptest; use super::UnionFind; diff --git a/byods/ascent-byods-rels/src/union_find.rs b/byods/ascent-byods-rels/src/union_find.rs index a32d230..6e1d3bd 100644 --- a/byods/ascent-byods-rels/src/union_find.rs +++ b/byods/ascent-byods-rels/src/union_find.rs @@ -1,199 +1,206 @@ -use hashbrown::{HashMap, HashSet}; -use std::hash::{Hash, BuildHasherDefault}; -use std::iter::{FlatMap, Repeat, Zip}; - -use hashbrown::hash_set::Iter as HashSetIter; - -#[cfg(feature = "par")] -use ascent::rayon::prelude::{ParallelIterator, IntoParallelRefIterator}; -use rustc_hash::FxHasher; - -use crate::utils::merge_sets; - -#[derive(Clone, Debug)] -pub struct EqRel { - pub(crate) sets: Vec>>, - pub(crate) elem_ids: HashMap>, - pub(crate) set_subsumptions: HashMap>, -} - -impl Default for EqRel { - fn default() -> Self { - Self { sets: Default::default(), elem_ids: Default::default(), set_subsumptions: Default::default() } - } -} - -pub type IterAllIterator<'a, T> = FlatMap< - std::slice::Iter<'a, HashSet>>, - FlatMap< - Zip, Repeat>>, - Zip, Repeat<&'a T>>, - for<'aa> fn((&'aa T, HashSetIter<'aa, T>)) -> Zip, Repeat<&'aa T>>, - >, - fn(&HashSet>,) -> FlatMap, Repeat>>,Zip, Repeat<&T>>,for<'aa> fn((&'aa T, HashSetIter<'aa, T>)) -> Zip, Repeat<&'aa T>>,>,>; - - -#[cfg(feature = "par")] -pub struct IterAllParIterator<'a, T: Clone + Hash + Eq>(&'a EqRel); -#[cfg(feature = "par")] -impl<'a, T: Clone + Hash + Eq + Sync> ParallelIterator for IterAllParIterator<'a, T> { - type Item = (&'a T, &'a T); - - fn drive_unindexed(self, consumer: C) -> C::Result - where C: ascent::rayon::iter::plumbing::UnindexedConsumer - { - self.0.sets.par_iter() - .flat_map:: _, _>(|s| s.par_iter().map_with(s, |s, x| s.par_iter().map_with(x, |x, y| (*x, y))).flatten()) - .drive_unindexed(consumer) - } -} - - -impl EqRel { - fn get_dominant_id(&self, id: usize) -> usize { - match self.set_subsumptions.get(&id) { - Some(dom_id) => self.get_dominant_id(*dom_id), - None => id, - } - } - pub(crate) fn elem_set(&self, elem: &T) -> Option { self.elem_ids.get(elem).map(|id| self.get_dominant_id(*id)) } - - fn get_dominant_id_update(&mut self, id: usize) -> usize { - match self.set_subsumptions.get(&id) { - Some(&parent_id) => { - let dom_id = self.get_dominant_id_update(parent_id); - if dom_id != parent_id { - self.set_subsumptions.insert(id, dom_id); - } - dom_id - }, - None => id, - } - } - pub(crate) fn elem_set_update(&mut self, elem: &T) -> Option { - let id = self.elem_ids.get(elem)?; - Some(self.get_dominant_id_update(*id)) - } - - pub fn add(&mut self, x: T, y: T) -> bool { - let x_set = self.elem_set_update(&x); - let y_set = self.elem_set_update(&y); - match (x_set, y_set) { - (None, None) => { - let id = self.sets.len(); - self.sets.push(HashSet::from_iter([x.clone(), y.clone()])); - self.elem_ids.insert(x.clone(), id); - self.elem_ids.insert(y.clone(), id); - true - }, - (None, Some(y_set)) => { - self.sets[y_set].insert(x.clone()); - self.elem_ids.insert(x, y_set); - true - }, - (Some(x_set), None) => { - self.sets[x_set].insert(y.clone()); - self.elem_ids.insert(y, x_set); - true - }, - (Some(x_set), Some(y_set)) => { - if x_set != y_set { - let y_set_taken = std::mem::take(&mut self.sets[y_set]); - merge_sets(&mut self.sets[x_set], y_set_taken); - self.set_subsumptions.insert(y_set, x_set); - true - } else { - false - } - } - } - } - - pub fn set_of(&self, x: &T) -> Option> { - let set = self.elem_set(x)?; - let res = Some(self.sets[set].iter()); - res - } - - #[cfg(feature = "par")] - pub fn c_set_of(&self, x: &T) -> Option<&'_ hashbrown::hash_set::HashSet>> where T: Sync{ - let set = self.elem_set(x)?; - let res = Some(&self.sets[set]); - res - } - - // TODO not used - #[allow(dead_code)] - fn set_of_inc_x<'a> (&'a self, x: &'a T) -> impl Iterator { - let set = self.set_of(x); - let x_itself = if set.is_none() { Some(x) } else { None }; - set.into_iter().flatten().chain(x_itself) - } - - pub fn iter_all(&self) -> IterAllIterator<'_, T> { - let res: IterAllIterator<'_, T> = self - .sets - .iter() - .flat_map(|s| s.iter().zip(std::iter::repeat(s.iter())).flat_map(|(x, s)| s.zip(std::iter::repeat(x)))); - res - } - - #[cfg(feature = "par")] - pub fn c_iter_all<'a>(&'a self) -> IterAllParIterator<'a, T> where T: Sync - { - IterAllParIterator(self) - } - - pub fn contains(&self, x: &T, y: &T) -> bool { self.elem_set(x).map_or(false, |set| self.sets[set].contains(y)) } - - pub fn combine(&mut self, other: Self) { - for set in other.sets.into_iter() { - if set.len() == 1 { - let repr = set.into_iter().next().unwrap(); - self.add(repr.clone(), repr); - } else if set.len() > 1 { - let mut set = set.into_iter(); - let repr = set.next().unwrap(); - for x in set { - self.add(repr.clone(), x); - } - } - } - } - - pub fn count_exact(&self) -> usize { - self.sets.iter().map(|s| s.len() * s.len()).sum() - } -} - -#[test] -fn test_eq_rel() { - let mut eqrel = EqRel::::default(); - eqrel.add(1, 2); - eqrel.add(11, 12); - assert!(eqrel.contains(&1, &2)); - assert!(!eqrel.contains(&1, &12)); - eqrel.add(1, 3); - eqrel.add(13, 12); - assert!(!eqrel.contains(&2, &13)); - eqrel.add(3, 11); - assert!(eqrel.contains(&2, &13)); -} - -#[test] -fn test_eq_rel_combine() { - let mut eqrel1 = EqRel::::default(); - eqrel1.add(1, 2); - eqrel1.add(1, 3); - eqrel1.add(1, 10); - - let mut eqrel2 = EqRel::::default(); - eqrel2.add(10, 11); - eqrel2.add(11, 12); - eqrel2.add(13, 12); - - assert!(!eqrel1.contains(&1, &13)); - eqrel1.combine(eqrel2); - assert!(eqrel1.contains(&1, &13)); -} - +use std::hash::{BuildHasherDefault, Hash}; +use std::iter::{FlatMap, Repeat, Zip}; + +#[cfg(feature = "par")] +use ascent::rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; +use hashbrown::hash_set::Iter as HashSetIter; +use hashbrown::{HashMap, HashSet}; +use rustc_hash::FxHasher; + +use crate::utils::merge_sets; + +#[derive(Clone, Debug)] +pub struct EqRel { + pub(crate) sets: Vec>>, + pub(crate) elem_ids: HashMap>, + pub(crate) set_subsumptions: HashMap>, +} + +impl Default for EqRel { + fn default() -> Self { + Self { sets: Default::default(), elem_ids: Default::default(), set_subsumptions: Default::default() } + } +} + +pub type IterAllIterator<'a, T> = FlatMap< + std::slice::Iter<'a, HashSet>>, + FlatMap< + Zip, Repeat>>, + Zip, Repeat<&'a T>>, + for<'aa> fn((&'aa T, HashSetIter<'aa, T>)) -> Zip, Repeat<&'aa T>>, + >, + fn( + &HashSet>, + ) -> FlatMap< + Zip, Repeat>>, + Zip, Repeat<&T>>, + for<'aa> fn((&'aa T, HashSetIter<'aa, T>)) -> Zip, Repeat<&'aa T>>, + >, +>; + +#[cfg(feature = "par")] +pub struct IterAllParIterator<'a, T: Clone + Hash + Eq>(&'a EqRel); +#[cfg(feature = "par")] +impl<'a, T: Clone + Hash + Eq + Sync> ParallelIterator for IterAllParIterator<'a, T> { + type Item = (&'a T, &'a T); + + fn drive_unindexed(self, consumer: C) -> C::Result + where C: ascent::rayon::iter::plumbing::UnindexedConsumer { + self + .0 + .sets + .par_iter() + .flat_map:: _, _>(|s| { + s.par_iter().map_with(s, |s, x| s.par_iter().map_with(x, |x, y| (*x, y))).flatten() + }) + .drive_unindexed(consumer) + } +} + +impl EqRel { + fn get_dominant_id(&self, id: usize) -> usize { + match self.set_subsumptions.get(&id) { + Some(dom_id) => self.get_dominant_id(*dom_id), + None => id, + } + } + pub(crate) fn elem_set(&self, elem: &T) -> Option { + self.elem_ids.get(elem).map(|id| self.get_dominant_id(*id)) + } + + fn get_dominant_id_update(&mut self, id: usize) -> usize { + match self.set_subsumptions.get(&id) { + Some(&parent_id) => { + let dom_id = self.get_dominant_id_update(parent_id); + if dom_id != parent_id { + self.set_subsumptions.insert(id, dom_id); + } + dom_id + }, + None => id, + } + } + pub(crate) fn elem_set_update(&mut self, elem: &T) -> Option { + let id = self.elem_ids.get(elem)?; + Some(self.get_dominant_id_update(*id)) + } + + pub fn add(&mut self, x: T, y: T) -> bool { + let x_set = self.elem_set_update(&x); + let y_set = self.elem_set_update(&y); + match (x_set, y_set) { + (None, None) => { + let id = self.sets.len(); + self.sets.push(HashSet::from_iter([x.clone(), y.clone()])); + self.elem_ids.insert(x.clone(), id); + self.elem_ids.insert(y.clone(), id); + true + }, + (None, Some(y_set)) => { + self.sets[y_set].insert(x.clone()); + self.elem_ids.insert(x, y_set); + true + }, + (Some(x_set), None) => { + self.sets[x_set].insert(y.clone()); + self.elem_ids.insert(y, x_set); + true + }, + (Some(x_set), Some(y_set)) => + if x_set != y_set { + let y_set_taken = std::mem::take(&mut self.sets[y_set]); + merge_sets(&mut self.sets[x_set], y_set_taken); + self.set_subsumptions.insert(y_set, x_set); + true + } else { + false + }, + } + } + + pub fn set_of(&self, x: &T) -> Option> { + let set = self.elem_set(x)?; + let res = Some(self.sets[set].iter()); + res + } + + #[cfg(feature = "par")] + pub fn c_set_of(&self, x: &T) -> Option<&'_ hashbrown::hash_set::HashSet>> + where T: Sync { + let set = self.elem_set(x)?; + let res = Some(&self.sets[set]); + res + } + + // TODO not used + #[allow(dead_code)] + fn set_of_inc_x<'a>(&'a self, x: &'a T) -> impl Iterator { + let set = self.set_of(x); + let x_itself = if set.is_none() { Some(x) } else { None }; + set.into_iter().flatten().chain(x_itself) + } + + pub fn iter_all(&self) -> IterAllIterator<'_, T> { + let res: IterAllIterator<'_, T> = self + .sets + .iter() + .flat_map(|s| s.iter().zip(std::iter::repeat(s.iter())).flat_map(|(x, s)| s.zip(std::iter::repeat(x)))); + res + } + + #[cfg(feature = "par")] + pub fn c_iter_all<'a>(&'a self) -> IterAllParIterator<'a, T> + where T: Sync { + IterAllParIterator(self) + } + + pub fn contains(&self, x: &T, y: &T) -> bool { self.elem_set(x).map_or(false, |set| self.sets[set].contains(y)) } + + pub fn combine(&mut self, other: Self) { + for set in other.sets.into_iter() { + if set.len() == 1 { + let repr = set.into_iter().next().unwrap(); + self.add(repr.clone(), repr); + } else if set.len() > 1 { + let mut set = set.into_iter(); + let repr = set.next().unwrap(); + for x in set { + self.add(repr.clone(), x); + } + } + } + } + + pub fn count_exact(&self) -> usize { self.sets.iter().map(|s| s.len() * s.len()).sum() } +} + +#[test] +fn test_eq_rel() { + let mut eqrel = EqRel::::default(); + eqrel.add(1, 2); + eqrel.add(11, 12); + assert!(eqrel.contains(&1, &2)); + assert!(!eqrel.contains(&1, &12)); + eqrel.add(1, 3); + eqrel.add(13, 12); + assert!(!eqrel.contains(&2, &13)); + eqrel.add(3, 11); + assert!(eqrel.contains(&2, &13)); +} + +#[test] +fn test_eq_rel_combine() { + let mut eqrel1 = EqRel::::default(); + eqrel1.add(1, 2); + eqrel1.add(1, 3); + eqrel1.add(1, 10); + + let mut eqrel2 = EqRel::::default(); + eqrel2.add(10, 11); + eqrel2.add(11, 12); + eqrel2.add(13, 12); + + assert!(!eqrel1.contains(&1, &13)); + eqrel1.combine(eqrel2); + assert!(eqrel1.contains(&1, &13)); +} diff --git a/byods/ascent-byods-rels/src/utils.rs b/byods/ascent-byods-rels/src/utils.rs index 3a5b657..850b027 100644 --- a/byods/ascent-byods-rels/src/utils.rs +++ b/byods/ascent-byods-rels/src/utils.rs @@ -1,206 +1,226 @@ -use hashbrown; -use hashbrown::Equivalent; -use std::hash::BuildHasher; -use std::hash::Hash; -use std::hash::Hasher; -use hashbrown::HashSet; -use hashbrown::HashMap; - -use crate::iterator_from_dyn::IteratorFromDyn; -use crate::trrel_binary::MyHashSet; - -pub(crate) fn move_hash_map_of_hash_set_contents(from: &mut HashMap, S1>, to: &mut HashMap, S1>) -where K: Clone + Hash + Eq, V: Clone + Hash + Eq, S1: BuildHasher, S2: BuildHasher -{ - if from.len() > to.len() { - std::mem::swap(from, to); - } - - for (k, mut from_set) in from.drain() { - match to.entry(k) { - hashbrown::hash_map::Entry::Occupied(mut to_set) => { - if from_set.len() > to_set.get().len() { std::mem::swap(&mut from_set, to_set.get_mut()) } - to_set.get_mut().extend(from_set.drain()); - }, - hashbrown::hash_map::Entry::Vacant(to_set_vac) => { - to_set_vac.insert(from_set); - }, - } - } -} - -pub(crate) fn move_hash_map_of_vec_contents(from: &mut HashMap, S1>, to: &mut HashMap, S1>) -where K: Clone + Hash + Eq, V: Clone + Hash + Eq, S1: BuildHasher -{ - if from.len() > to.len() { - std::mem::swap(from, to); - } - - for (k, mut from_vec) in from.drain() { - match to.entry(k) { - hashbrown::hash_map::Entry::Occupied(mut to_vec) => { - if from_vec.len() > to_vec.get().len() { std::mem::swap(&mut from_vec, to_vec.get_mut()) } - to_vec.get_mut().append(&mut from_vec); - }, - hashbrown::hash_map::Entry::Vacant(to_vec_vac) => { - to_vec_vac.insert(from_vec); - }, - } - } -} - -pub(crate) fn move_hash_map_of_hash_set_contents_disjoint(from: &mut HashMap, S1>, to: &mut HashMap, S1>) -where K: Clone + Hash + Eq, V: Clone + Hash + Eq, S1: BuildHasher, S2: BuildHasher -{ - if from.len() > to.len() { - std::mem::swap(from, to); - } - - for (k, mut from_set) in from.drain() { - match to.entry(k) { - hashbrown::hash_map::Entry::Occupied(mut to_set) => { - move_hash_set_contents_disjoint(&mut from_set, to_set.get_mut()); - }, - hashbrown::hash_map::Entry::Vacant(to_set_vac) => { - to_set_vac.insert(from_set); - }, - } - } -} - -pub fn move_hash_set_contents_disjoint(from: &mut HashSet, to: &mut HashSet) { - if from.len() > to.len() { - std::mem::swap(from, to); - } - to.reserve(from.len()); - for x in from.drain() { - to.insert_unique_unchecked(x); - } -} - -pub(crate) fn move_hash_map_of_alt_hash_set_contents(from: &mut HashMap, S1>, to: &mut HashMap, S1>) -where K: Clone + Hash + Eq, V: Clone + Hash + Eq, S1: BuildHasher, S2: BuildHasher -{ - if from.len() > to.len() { - std::mem::swap(from, to); - } - - for (k, mut from_set) in from.drain() { - match to.entry(k) { - hashbrown::hash_map::Entry::Occupied(mut to_set) => { - if from_set.len() > to_set.get().len() { std::mem::swap(&mut from_set, to_set.get_mut()) } - to_set.get_mut().extend(from_set.drain()); - }, - hashbrown::hash_map::Entry::Vacant(to_set_vac) => { - to_set_vac.insert(from_set); - }, - } - } -} - -// #[allow(dead_code)] -#[inline] -pub fn hash_one(hahser: &S, x: &T) -> u64 { - let mut hasher = hahser.build_hasher(); - x.hash(&mut hasher); - hasher.finish() -} - - -pub struct AltHashSet(pub(crate) HashMap); - -impl Default for AltHashSet { - #[inline(always)] - fn default() -> Self { Self(Default::default()) } -} - -impl AltHashSet { - #[inline(always)] - pub fn contains>(&self, k: &Q) -> bool { - self.0.contains_key(k) - } - - #[inline(always)] - pub fn iter(&self) -> AltHashSetIter<'_, T> { - self.0.keys() - } - - #[inline(always)] - pub fn insert(&mut self, x: T) -> bool { - self.0.insert(x, ()).is_none() - } - - #[inline(always)] - pub fn len(&self) -> usize { self.0.len() } - - pub fn extend>(&mut self, iter: Iter) { - self.0.extend(iter.into_iter().map(|x| (x, ()))) - } - - #[inline] - pub fn insert_with_hash_no_check(&mut self, hash: u64, item: T) { - self.0.raw_entry_mut().from_key_hashed_nocheck(hash, &item).or_insert(item, ()); - } - - pub fn drain(&mut self) -> impl Iterator + '_ { - self.0.drain().map(|kv| kv.0) - } - - pub fn intersection<'a>(&'a self, other: &'a Self) -> impl Iterator + 'a{ - let (small, big) = if self.len() < other.len() {(self, other)} else {(other, self)}; - small.iter().filter(|&x| big.contains(x)) - } - -} -pub type AltHashSetIter<'a, T> = hashbrown::hash_map::Keys<'a, T, ()>; - -// TODO remove if not used -fn _set_extend_with_hash_no_check(set: &mut AltHashSet, iter: Iter) -where T: Clone + Hash + Eq, S: BuildHasher, Iter: Iterator { - set.0.reserve(iter.size_hint().0); - for (hash, item) in iter { - set.0.raw_entry_mut().from_key_hashed_nocheck(hash, &item).insert(item, ()); - } -} - -// TODO remove if not used -#[allow(dead_code)] -pub fn hash_map_hash_set_intersection<'a, K, V, S>(hm: &'a HashMap, hs: &'a MyHashSet) -> IteratorFromDyn<'a, &'a V> -where K: Clone + Hash + Eq, S: BuildHasher -{ - if hm.len() < hs.len() { - IteratorFromDyn::new(move || hm.iter().filter_map(move |(k, v)| if hs.contains(k) { Some(v) } else {None})) - } else { - IteratorFromDyn::new(|| hs.iter().filter_map(|k| hm.get(k))) - } -} - -#[derive(Clone, Eq, PartialEq, Debug, Hash)] -#[allow(dead_code)] -pub enum Either { - Left(L), - Right(R) -} - -impl Iterator for Either -where L: Iterator, R: Iterator { - type Item = T; - - #[inline] - fn next(&mut self) -> Option { - match self { - Either::Left(l) => l.next(), - Either::Right(r) => r.next(), - } - } -} - -pub(crate) fn merge_sets(set1: &mut HashSet, mut set2: HashSet) { - if set1.len() < set2.len() { - std::mem::swap(set1, &mut set2); - } - set1.extend(set2); -} - - - +use std::hash::{BuildHasher, Hash, Hasher}; + +use hashbrown; +use hashbrown::{Equivalent, HashMap, HashSet}; + +use crate::iterator_from_dyn::IteratorFromDyn; +use crate::trrel_binary::MyHashSet; + +pub(crate) fn move_hash_map_of_hash_set_contents( + from: &mut HashMap, S1>, to: &mut HashMap, S1>, +) where + K: Clone + Hash + Eq, + V: Clone + Hash + Eq, + S1: BuildHasher, + S2: BuildHasher, +{ + if from.len() > to.len() { + std::mem::swap(from, to); + } + + for (k, mut from_set) in from.drain() { + match to.entry(k) { + hashbrown::hash_map::Entry::Occupied(mut to_set) => { + if from_set.len() > to_set.get().len() { + std::mem::swap(&mut from_set, to_set.get_mut()) + } + to_set.get_mut().extend(from_set.drain()); + }, + hashbrown::hash_map::Entry::Vacant(to_set_vac) => { + to_set_vac.insert(from_set); + }, + } + } +} + +pub(crate) fn move_hash_map_of_vec_contents( + from: &mut HashMap, S1>, to: &mut HashMap, S1>, +) where + K: Clone + Hash + Eq, + V: Clone + Hash + Eq, + S1: BuildHasher, +{ + if from.len() > to.len() { + std::mem::swap(from, to); + } + + for (k, mut from_vec) in from.drain() { + match to.entry(k) { + hashbrown::hash_map::Entry::Occupied(mut to_vec) => { + if from_vec.len() > to_vec.get().len() { + std::mem::swap(&mut from_vec, to_vec.get_mut()) + } + to_vec.get_mut().append(&mut from_vec); + }, + hashbrown::hash_map::Entry::Vacant(to_vec_vac) => { + to_vec_vac.insert(from_vec); + }, + } + } +} + +pub(crate) fn move_hash_map_of_hash_set_contents_disjoint( + from: &mut HashMap, S1>, to: &mut HashMap, S1>, +) where + K: Clone + Hash + Eq, + V: Clone + Hash + Eq, + S1: BuildHasher, + S2: BuildHasher, +{ + if from.len() > to.len() { + std::mem::swap(from, to); + } + + for (k, mut from_set) in from.drain() { + match to.entry(k) { + hashbrown::hash_map::Entry::Occupied(mut to_set) => { + move_hash_set_contents_disjoint(&mut from_set, to_set.get_mut()); + }, + hashbrown::hash_map::Entry::Vacant(to_set_vac) => { + to_set_vac.insert(from_set); + }, + } + } +} + +pub fn move_hash_set_contents_disjoint(from: &mut HashSet, to: &mut HashSet) { + if from.len() > to.len() { + std::mem::swap(from, to); + } + to.reserve(from.len()); + for x in from.drain() { + to.insert_unique_unchecked(x); + } +} + +pub(crate) fn move_hash_map_of_alt_hash_set_contents( + from: &mut HashMap, S1>, to: &mut HashMap, S1>, +) where + K: Clone + Hash + Eq, + V: Clone + Hash + Eq, + S1: BuildHasher, + S2: BuildHasher, +{ + if from.len() > to.len() { + std::mem::swap(from, to); + } + + for (k, mut from_set) in from.drain() { + match to.entry(k) { + hashbrown::hash_map::Entry::Occupied(mut to_set) => { + if from_set.len() > to_set.get().len() { + std::mem::swap(&mut from_set, to_set.get_mut()) + } + to_set.get_mut().extend(from_set.drain()); + }, + hashbrown::hash_map::Entry::Vacant(to_set_vac) => { + to_set_vac.insert(from_set); + }, + } + } +} + +// #[allow(dead_code)] +#[inline] +pub fn hash_one(hahser: &S, x: &T) -> u64 { + let mut hasher = hahser.build_hasher(); + x.hash(&mut hasher); + hasher.finish() +} + +pub struct AltHashSet(pub(crate) HashMap); + +impl Default for AltHashSet { + #[inline(always)] + fn default() -> Self { Self(Default::default()) } +} + +impl AltHashSet { + #[inline(always)] + pub fn contains>(&self, k: &Q) -> bool { self.0.contains_key(k) } + + #[inline(always)] + pub fn iter(&self) -> AltHashSetIter<'_, T> { self.0.keys() } + + #[inline(always)] + pub fn insert(&mut self, x: T) -> bool { self.0.insert(x, ()).is_none() } + + #[inline(always)] + pub fn len(&self) -> usize { self.0.len() } + + pub fn extend>(&mut self, iter: Iter) { + self.0.extend(iter.into_iter().map(|x| (x, ()))) + } + + #[inline] + pub fn insert_with_hash_no_check(&mut self, hash: u64, item: T) { + self.0.raw_entry_mut().from_key_hashed_nocheck(hash, &item).or_insert(item, ()); + } + + pub fn drain(&mut self) -> impl Iterator + '_ { self.0.drain().map(|kv| kv.0) } + + pub fn intersection<'a>(&'a self, other: &'a Self) -> impl Iterator + 'a { + let (small, big) = if self.len() < other.len() { (self, other) } else { (other, self) }; + small.iter().filter(|&x| big.contains(x)) + } +} +pub type AltHashSetIter<'a, T> = hashbrown::hash_map::Keys<'a, T, ()>; + +// TODO remove if not used +fn _set_extend_with_hash_no_check(set: &mut AltHashSet, iter: Iter) +where + T: Clone + Hash + Eq, + S: BuildHasher, + Iter: Iterator, +{ + set.0.reserve(iter.size_hint().0); + for (hash, item) in iter { + set.0.raw_entry_mut().from_key_hashed_nocheck(hash, &item).insert(item, ()); + } +} + +// TODO remove if not used +#[allow(dead_code)] +pub fn hash_map_hash_set_intersection<'a, K, V, S>( + hm: &'a HashMap, hs: &'a MyHashSet, +) -> IteratorFromDyn<'a, &'a V> +where + K: Clone + Hash + Eq, + S: BuildHasher, +{ + if hm.len() < hs.len() { + IteratorFromDyn::new(move || hm.iter().filter_map(move |(k, v)| if hs.contains(k) { Some(v) } else { None })) + } else { + IteratorFromDyn::new(|| hs.iter().filter_map(|k| hm.get(k))) + } +} + +#[derive(Clone, Eq, PartialEq, Debug, Hash)] +#[allow(dead_code)] +pub enum Either { + Left(L), + Right(R), +} + +impl Iterator for Either +where + L: Iterator, + R: Iterator, +{ + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + match self { + Either::Left(l) => l.next(), + Either::Right(r) => r.next(), + } + } +} + +pub(crate) fn merge_sets(set1: &mut HashSet, mut set2: HashSet) { + if set1.len() < set2.len() { + std::mem::swap(set1, &mut set2); + } + set1.extend(set2); +} diff --git a/rustfmt.toml b/rustfmt.toml index db09c5b..36e654c 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,18 +1,26 @@ # https://github.com/rust-lang/rustfmt/blob/master/Configurations.md tab_spaces = 3 -fn_call_width = 100 -chain_width = 100 max_width = 120 +fn_call_width = 110 +use_small_heuristics = "Max" +single_line_if_else_max_width = 110 +single_line_let_else_max_width = 110 +struct_lit_width = 110 +struct_variant_width = 110 +short_array_element_width_threshold = 100 fn_single_line = true -fn_params_layout = "compressed" +fn_params_layout = "Compressed" match_arm_blocks = false newline_style = "Unix" overflow_delimited_expr = true reorder_modules = false -single_line_if_else_max_width = 100 -single_line_let_else_max_width = 100 -struct_lit_width = 100 -struct_variant_width = 100 trailing_semicolon = false unstable_features = true -use_small_heuristics = "Max" \ No newline at end of file +format_macro_bodies = false +match_block_trailing_comma = true +merge_derives = false +style_edition = "2024" +use_field_init_shorthand = true +where_single_line = true +group_imports = "StdExternalCrate" +imports_granularity = "Module"