From 68be5217869f8e224aa9d16a81c80136e5832346 Mon Sep 17 00:00:00 2001 From: k88hudson-cfa Date: Fri, 11 Oct 2024 10:36:12 -0700 Subject: [PATCH] Add derived properties --- examples/basic-infection/people.rs | 3 +- src/context.rs | 42 ++-- src/people.rs | 318 ++++++++++++++++++++++++++++- 3 files changed, 334 insertions(+), 29 deletions(-) diff --git a/examples/basic-infection/people.rs b/examples/basic-infection/people.rs index a30bea4..528f27a 100644 --- a/examples/basic-infection/people.rs +++ b/examples/basic-infection/people.rs @@ -1,4 +1,4 @@ -use ixa::context::Context; +use ixa::context::{Context, Event}; use ixa::define_data_plugin; use std::collections::HashMap; @@ -18,6 +18,7 @@ pub struct InfectionStatusEvent { pub updated_status: InfectionStatus, pub person_id: usize, } +impl Event for InfectionStatusEvent {} pub trait ContextPeopleExt { fn create_person(&mut self); diff --git a/src/context.rs b/src/context.rs index 65e0ea1..4793806 100644 --- a/src/context.rs +++ b/src/context.rs @@ -16,6 +16,11 @@ type Callback = dyn FnOnce(&mut Context); /// A handler for an event type `E` type EventHandler = dyn Fn(&mut Context, E); +pub trait Event { + /// Called every time `context.subscribe_to_event` is called with this event + fn on_subscribe(_context: &mut Context) {} +} + /// A manager for the state of a discrete-event simulation /// /// Provides core simulation services including @@ -71,7 +76,7 @@ impl Context { /// Handlers will be called upon event emission in order of subscription as /// queued `Callback`s with the appropriate event. #[allow(clippy::missing_panics_doc)] - pub fn subscribe_to_event( + pub fn subscribe_to_event( &mut self, handler: impl Fn(&mut Context, E) + 'static, ) { @@ -81,6 +86,7 @@ impl Context { .or_insert_with(|| Box::>>>::default()); let handler_vec: &mut Vec>> = handler_vec.downcast_mut().unwrap(); handler_vec.push(Rc::new(handler)); + E::on_subscribe(self); } /// Emit and event of type E to be handled by registered receivers @@ -88,7 +94,7 @@ impl Context { /// Receivers will handle events in the order that they have subscribed and /// are queued as callbacks #[allow(clippy::missing_panics_doc)] - pub fn emit_event(&mut self, event: E) { + pub fn emit_event(&mut self, event: E) { // Destructure to obtain event handlers and plan queue let Context { event_handlers, @@ -414,14 +420,16 @@ mod tests { } #[derive(Copy, Clone)] - struct Event { + struct Event1 { pub data: usize, } + impl Event for Event1 {} #[derive(Copy, Clone)] struct Event2 { pub data: usize, } + impl Event for Event2 {} #[test] fn simple_event() { @@ -429,11 +437,11 @@ mod tests { let obs_data = Rc::new(RefCell::new(0)); let obs_data_clone = Rc::clone(&obs_data); - context.subscribe_to_event::(move |_, event| { + context.subscribe_to_event::(move |_, event| { *obs_data_clone.borrow_mut() = event.data; }); - context.emit_event(Event { data: 1 }); + context.emit_event(Event1 { data: 1 }); context.execute(); assert_eq!(*obs_data.borrow(), 1); } @@ -444,12 +452,12 @@ mod tests { let obs_data = Rc::new(RefCell::new(0)); let obs_data_clone = Rc::clone(&obs_data); - context.subscribe_to_event::(move |_, event| { + context.subscribe_to_event::(move |_, event| { *obs_data_clone.borrow_mut() += event.data; }); - context.emit_event(Event { data: 1 }); - context.emit_event(Event { data: 2 }); + context.emit_event(Event1 { data: 1 }); + context.emit_event(Event1 { data: 2 }); context.execute(); // Both of these should have been received. @@ -464,13 +472,13 @@ mod tests { let obs_data2 = Rc::new(RefCell::new(0)); let obs_data2_clone = Rc::clone(&obs_data2); - context.subscribe_to_event::(move |_, event| { + context.subscribe_to_event::(move |_, event| { *obs_data1_clone.borrow_mut() = event.data; }); - context.subscribe_to_event::(move |_, event| { + context.subscribe_to_event::(move |_, event| { *obs_data2_clone.borrow_mut() = event.data; }); - context.emit_event(Event { data: 1 }); + context.emit_event(Event1 { data: 1 }); context.execute(); assert_eq!(*obs_data1.borrow(), 1); assert_eq!(*obs_data2.borrow(), 1); @@ -484,13 +492,13 @@ mod tests { let obs_data2 = Rc::new(RefCell::new(0)); let obs_data2_clone = Rc::clone(&obs_data2); - context.subscribe_to_event::(move |_, event| { + context.subscribe_to_event::(move |_, event| { *obs_data1_clone.borrow_mut() = event.data; }); context.subscribe_to_event::(move |_, event| { *obs_data2_clone.borrow_mut() = event.data; }); - context.emit_event(Event { data: 1 }); + context.emit_event(Event1 { data: 1 }); context.emit_event(Event2 { data: 2 }); context.execute(); assert_eq!(*obs_data1.borrow(), 1); @@ -503,8 +511,8 @@ mod tests { let obs_data = Rc::new(RefCell::new(0)); let obs_data_clone = Rc::clone(&obs_data); - context.emit_event(Event { data: 1 }); - context.subscribe_to_event::(move |_, event| { + context.emit_event(Event1 { data: 1 }); + context.subscribe_to_event::(move |_, event| { *obs_data_clone.borrow_mut() = event.data; }); @@ -545,10 +553,10 @@ mod tests { let mut context = Context::new(); let obs_data = Rc::new(RefCell::new(0)); let obs_data_clone = Rc::clone(&obs_data); - context.subscribe_to_event::(move |_, event| { + context.subscribe_to_event::(move |_, event| { *obs_data_clone.borrow_mut() = event.data; }); - context.emit_event(Event { data: 1 }); + context.emit_event(Event1 { data: 1 }); context.shutdown(); context.execute(); assert_eq!(*obs_data.borrow(), 0); diff --git a/src/people.rs b/src/people.rs index c54e418..38c88ca 100644 --- a/src/people.rs +++ b/src/people.rs @@ -1,9 +1,12 @@ -use crate::{context::Context, define_data_plugin}; use serde::{Deserialize, Serialize}; +use crate::{ + context::{Context, Event}, + define_data_plugin, +}; use std::{ any::{Any, TypeId}, cell::{RefCell, RefMut}, - collections::HashMap, + collections::{HashMap, HashSet}, fmt, }; @@ -13,6 +16,8 @@ use std::{ struct PeopleData { current_population: usize, properties_map: RefCell>>, + registered_derived_properties: RefCell>, + dependency_map: RefCell>>>, } define_data_plugin!( @@ -20,7 +25,9 @@ define_data_plugin!( PeopleData, PeopleData { current_population: 0, - properties_map: RefCell::new(HashMap::new()) + properties_map: RefCell::new(HashMap::new()), + registered_derived_properties: RefCell::new(HashSet::new()), + dependency_map: RefCell::new(HashMap::new()) } ); @@ -51,7 +58,54 @@ impl fmt::Debug for PersonId { // They may be defined with the define_person_property! macro. pub trait PersonProperty: Copy { type Value: Copy; - fn initialize(context: &Context, person_id: PersonId) -> Self::Value; + #[must_use] + fn is_derived() -> bool { + false + } + #[must_use] + fn dependencies() -> Vec { + panic!("Dependencies not implemented"); + } + fn compute(context: &Context, person_id: PersonId) -> Self::Value; + fn get_instance() -> Self; +} + +type ContextCallback = dyn FnOnce(&mut Context); + +// The purpose of this trait is to allow us to store a vector of different PersonProperties +// in an object safe way, and use them to emit change events. +pub trait PersonPropertyHolder { + // Adds a callback to callback_vec which can later be called with a new context + // to emit a change event + fn add_event_callback( + &self, + context: &mut Context, + person: PersonId, + callback_vec: &mut Vec>, + ); +} + +impl PersonPropertyHolder for T +where + T: PersonProperty + 'static, +{ + fn add_event_callback( + &self, + context: &mut Context, + person: PersonId, + callback_vec: &mut Vec>, + ) { + let previous = context.get_person_property(person, T::get_instance()); + callback_vec.push(Box::new(move |ctx| { + let current = ctx.get_person_property(person, T::get_instance()); + let change_event: PersonPropertyChangeEvent = PersonPropertyChangeEvent { + person_id: person, + current, + previous, + }; + ctx.emit_event(change_event); + })); + } } /// Defines a person property with the following parameters: @@ -67,12 +121,15 @@ macro_rules! define_person_property { pub struct $person_property; impl $crate::people::PersonProperty for $person_property { type Value = $value; - fn initialize( + fn compute( _context: &$crate::context::Context, _person: $crate::people::PersonId, ) -> Self::Value { $initialize(_context, _person) } + fn get_instance() -> Self { + $person_property + } } }; ($person_property:ident, $value:ty) => { @@ -95,6 +152,38 @@ macro_rules! define_person_property_with_default { }; } +/// Defines a derived person property with the following parameters: +/// * `$person_property`: A name for the identifier type of the property +/// * `$value`: The type of the property's value +/// * `[$($dependency),+]`: A list of person properties the derived property depends on +/// * $calculate: A closure that takes the values of each dependency and returns the derived value +#[macro_export] +macro_rules! define_derived_property { + ($derived_property:ident, $value:ty, [$($dependency:ident),+], |$($param:ident),+| $derive_fn:expr) => { + #[derive(Copy, Clone)] + pub struct $derived_property; + + impl $crate::people::PersonProperty for $derived_property { + type Value = $value; + + fn compute(context: &$crate::context::Context, person_id: $crate::people::PersonId) -> Self::Value { + #[allow(unused_parens)] + let ($($param),+) = ( + $(context.get_person_property(person_id, $dependency)),+ + ); + (|$($param),+| $derive_fn)($($param),+) + } + fn is_derived() -> bool { true } + fn dependencies() -> Vec { + vec![$(std::any::TypeId::of::<$dependency>()),+] + } + fn get_instance() -> Self { + $derived_property + } + } + }; +} + pub use define_person_property; impl PeopleData { @@ -152,6 +241,7 @@ impl PeopleData { pub struct PersonCreatedEvent { pub person_id: PersonId, } +impl Event for PersonCreatedEvent {} // Emitted when a person property is updated // These should not be emitted outside this module @@ -162,6 +252,13 @@ pub struct PersonPropertyChangeEvent { pub current: T::Value, pub previous: T::Value, } +impl Event for PersonPropertyChangeEvent { + fn on_subscribe(context: &mut Context) { + if T::is_derived() { + context.register_property::(); + } + } +} pub trait ContextPeopleExt { /// Returns the current population size @@ -179,6 +276,8 @@ pub trait ContextPeopleExt { _property: T, ) -> T::Value; + fn register_property(&mut self); + /// Given a `PersonId`, initialize the value of a defined person property. /// Once the the value is set using this API, any initializer will /// not run. @@ -216,6 +315,27 @@ impl ContextPeopleExt for Context { person_id } + fn register_property(&mut self) { + let data_container = self.get_data_container(PeoplePlugin) + .expect("PeoplePlugin is not initialized; make sure you add a person before accessing properties"); + if !data_container + .registered_derived_properties + .borrow() + .contains(&TypeId::of::()) + { + let dependencies = T::dependencies(); + for dependency in dependencies { + let mut property_dependencies = data_container.dependency_map.borrow_mut(); + let deps = property_dependencies.entry(dependency).or_default(); + deps.push(Box::new(T::get_instance())); + } + data_container + .registered_derived_properties + .borrow_mut() + .insert(TypeId::of::()); + } + } + fn get_person_property( &self, person_id: PersonId, @@ -224,13 +344,17 @@ impl ContextPeopleExt for Context { let data_container = self.get_data_container(PeoplePlugin) .expect("PeoplePlugin is not initialized; make sure you add a person before accessing properties"); + if T::is_derived() { + return T::compute(self, person_id); + } + // Attempt to retrieve the existing value if let Some(value) = *data_container.get_person_property_ref(person_id, property) { return value; } // Initialize the property. This does not fire a change event - let initialized_value = T::initialize(self, person_id); + let initialized_value = T::compute(self, person_id); data_container.set_person_property(person_id, property, initialized_value); initialized_value @@ -242,6 +366,7 @@ impl ContextPeopleExt for Context { property: T, value: T::Value, ) { + assert!(!T::is_derived(), "Cannot initialize a derived property"); let data_container = self.get_data_container(PeoplePlugin) .expect("PeoplePlugin is not initialized; make sure you add a person before accessing properties"); @@ -257,26 +382,56 @@ impl ContextPeopleExt for Context { property: T, value: T::Value, ) { + assert!(!T::is_derived(), "Cannot set a derived property"); let data_container = self.get_data_container(PeoplePlugin) .expect("PeoplePlugin is not initialized; make sure you add a person before accessing properties"); - let current_value = *data_container.get_person_property_ref(person_id, property); - let previous_value = match current_value { - Some(current_value) => current_value, + let current_cached_value = *data_container.get_person_property_ref(person_id, property); + let previous_value = match current_cached_value { + Some(value) => value, None => { - let initialize_value = T::initialize(self, person_id); + let initialize_value = T::compute(self, person_id); data_container.set_person_property(person_id, property, initialize_value); initialize_value } }; + // Temporarily remove dependency properties since we need mutable references + // to self during callback execution + let deps_temp = { + let data_container = self.get_data_container(PeoplePlugin).unwrap(); + let mut dependencies = data_container.dependency_map.borrow_mut(); + dependencies.get_mut(&TypeId::of::()).map(std::mem::take) + }; + + let mut dependency_event_callbacks = Vec::new(); + if let Some(mut deps) = deps_temp { + // If there are dependencies, set up a bunch of callbacks with the + // current value + for dep in &mut deps { + dep.add_event_callback(self, person_id, &mut dependency_event_callbacks); + } + + // Put the dependency list back in + let data_container = self.get_data_container(PeoplePlugin).unwrap(); + let mut dependencies = data_container.dependency_map.borrow_mut(); + dependencies.insert(TypeId::of::(), deps); + } + + // Update the main property and send a change event + let data_container = self.get_data_container(PeoplePlugin).unwrap(); + data_container.set_person_property(person_id, property, value); let change_event: PersonPropertyChangeEvent = PersonPropertyChangeEvent { person_id, current: value, previous: previous_value, }; - data_container.set_person_property(person_id, property, value); self.emit_event(change_event); + + // If there are dependency callbacks, call them with the updated value + for callback in dependency_event_callbacks { + callback(self); + } } fn get_person_id(&self, person_id: usize) -> PersonId { @@ -295,6 +450,19 @@ mod test { use std::{cell::RefCell, rc::Rc}; define_person_property!(Age, u8); + #[derive(Copy, Clone, Debug, PartialEq, Eq)] + pub enum AgeGroupType { + Child, + Adult, + } + define_derived_property!(AgeGroup, AgeGroupType, [Age], |age| { + if age < 18 { + AgeGroupType::Child + } else { + AgeGroupType::Adult + } + }); + #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub enum RiskCategory { High, @@ -310,6 +478,15 @@ mod test { 0 } }); + define_derived_property!(TrailRunner, bool, [IsRunner, Age], |is_runner, age| { + is_runner && age > 29 + }); + define_derived_property!( + UltraRunner, + bool, + [TrailRunner, Age], + |trail_runner, age| { trail_runner && age > 39 } + ); #[test] fn observe_person_addition() { @@ -553,4 +730,123 @@ mod test { context.add_person(); context.get_person_id(1); } + + #[test] + fn get_person_property_returns_correct_value() { + let mut context = Context::new(); + let person = context.add_person(); + context.initialize_person_property(person, Age, 10); + assert_eq!( + context.get_person_property(person, AgeGroup), + AgeGroupType::Child + ); + } + + #[test] + fn get_person_property_changes_correctly() { + let mut context = Context::new(); + let person = context.add_person(); + context.initialize_person_property(person, Age, 17); + assert_eq!( + context.get_person_property(person, AgeGroup), + AgeGroupType::Child + ); + context.set_person_property(person, Age, 18); + assert_eq!( + context.get_person_property(person, AgeGroup), + AgeGroupType::Adult + ); + } + #[test] + fn get_person_property_change_event() { + let mut context = Context::new(); + let person = context.add_person(); + context.initialize_person_property(person, Age, 17); + + let flag = Rc::new(RefCell::new(false)); + let flag_clone = flag.clone(); + context.subscribe_to_event( + move |_context, event: PersonPropertyChangeEvent| { + assert_eq!(event.person_id.id, 0); + assert_eq!(event.previous, AgeGroupType::Child); + assert_eq!(event.current, AgeGroupType::Adult); + *flag_clone.borrow_mut() = true; + }, + ); + context.set_person_property(person, Age, 18); + context.execute(); + assert!(*flag.borrow()); + } + + #[test] + fn get_derived_property_multiple_deps() { + let mut context = Context::new(); + let person = context.add_person(); + context.initialize_person_property(person, Age, 29); + context.initialize_person_property(person, IsRunner, true); + + let flag = Rc::new(RefCell::new(false)); + let flag_clone = flag.clone(); + context.subscribe_to_event( + move |_context, event: PersonPropertyChangeEvent| { + assert_eq!(event.person_id.id, 0); + assert!(!event.previous); + assert!(event.current); + *flag_clone.borrow_mut() = true; + }, + ); + context.set_person_property(person, Age, 30); + context.execute(); + assert!(*flag.borrow()); + } + + #[test] + fn register_derived_only_once() { + let mut context = Context::new(); + let person = context.add_person(); + context.initialize_person_property(person, Age, 29); + context.initialize_person_property(person, IsRunner, true); + + let flag = Rc::new(RefCell::new(0)); + let flag_clone = flag.clone(); + context.subscribe_to_event( + move |_context, _event: PersonPropertyChangeEvent| { + *flag_clone.borrow_mut() += 1; + }, + ); + context.subscribe_to_event( + move |_context, _event: PersonPropertyChangeEvent| { + // Make sure that we don't register multiple times + }, + ); + context.set_person_property(person, Age, 30); + context.execute(); + assert_eq!(*flag.borrow(), 1); + } + + // TODO(ryl8@cdc.gov): Nested derived properties don't currently work; we should + // probably either fix this or disallow it. + #[ignore] + #[test] + fn get_derived_property_dependent_on_another_derived() { + let mut context = Context::new(); + let person = context.add_person(); + context.initialize_person_property(person, Age, 40); + context.initialize_person_property(person, IsRunner, false); + + let flag = Rc::new(RefCell::new(false)); + let flag_clone = flag.clone(); + assert!(!context.get_person_property(person, UltraRunner)); + context.subscribe_to_event( + move |_context, event: PersonPropertyChangeEvent| { + assert_eq!(event.person_id.id, 0); + assert!(!event.previous); + assert!(event.current); + *flag_clone.borrow_mut() = true; + }, + ); + context.set_person_property(person, IsRunner, true); + context.execute(); + assert!(*flag.borrow()); + } }