From 2785f10fb724cca4835ddd8c7a0186751967b149 Mon Sep 17 00:00:00 2001 From: Martin Broers Date: Thu, 7 Mar 2024 12:09:03 +0100 Subject: [PATCH] Added entry and exit functions By using `>` and `<` we can now define entry functions and exit functions on states. These should be defined once and will apply to all transitions interacting with either entering or exiting the state. Signed-off-by: Martin Broers --- README.md | 21 +++++++++++ examples/on_entry_on_exit.rs | 49 +++++++++++++++++++------ macros/src/codegen.rs | 32 ++++++++++++----- macros/src/parser/event.rs | 3 ++ macros/src/parser/mod.rs | 58 +++++++++++++++++++++++++++++- macros/src/parser/state_machine.rs | 15 +++++++- macros/src/parser/transition.rs | 32 ++++++++++++++++- 7 files changed, 187 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 55f72e9..84073df 100644 --- a/README.md +++ b/README.md @@ -260,6 +260,27 @@ statemachine!{ This example is available in `ex3.rs`. +### Using entry and exit functions in transitions + +DSL implementation: + +```rust +statemachine!{ + transitions: { + *State1 + Event1 = State2, + State2 < exit_state_2 + Event2 = State1, + State1 > enter_state_3 + Event3 = State3, + State2 + Event3 = State3, + } +} +``` +For all transitions entering State3, the function `enter_state_3` will be +called. For all transitions exiting State2, the function `exit_state_2` will be +called, in the right order, so first the `exit` function prior to the `entry` +function. + +An example is available in `on_entry_on_exit`. + ## Helpers ### Auto-derive certain traits for states and events diff --git a/examples/on_entry_on_exit.rs b/examples/on_entry_on_exit.rs index 745aa9f..f01a564 100644 --- a/examples/on_entry_on_exit.rs +++ b/examples/on_entry_on_exit.rs @@ -6,38 +6,65 @@ use smlang::statemachine; statemachine! { name: OnEntryExample, + generate_entry_exit_states: true, transitions: { - *D0 + ToD1 = D1, - D1 + ToD2 = D2, + *D0 > exit_d0 + ToD1 = D1, + D0 + ToD3 = D3, + D1 < enter_d1 + ToD2 = D2, + D2 + ToD1 = D1, + D1 + ToD0 = D0, }, - generate_entry_exit_states: true, } /// Context pub struct Context { - exited_d0: bool, - entered_d1: bool, + exited_d0: i32, + entered_d1: i32, } impl OnEntryExampleStateMachineContext for Context { fn on_exit_d0(&mut self) { - self.exited_d0 = true; + self.exited_d0 += 1; } fn on_entry_d1(&mut self) { - self.entered_d1 = true; + self.entered_d1 += 1; } } fn main() { let mut sm = OnEntryExampleStateMachine::new(Context { - exited_d0: false, - entered_d1: false, + exited_d0: 0, + entered_d1: 0, }); // first event starts the dominos let _ = sm.process_event(OnEntryExampleEvents::ToD1).unwrap(); assert!(matches!(sm.state(), Ok(&OnEntryExampleStates::D1))); - assert!(sm.context().exited_d0); - assert!(sm.context().entered_d1); + assert_eq!(sm.context().exited_d0, 1); + assert_eq!(sm.context().entered_d1, 1); + + let _ = sm.process_event(OnEntryExampleEvents::ToD2).unwrap(); + + assert!(matches!(sm.state(), Ok(&OnEntryExampleStates::D2))); + assert_eq!(sm.context().exited_d0, 1); + assert_eq!(sm.context().entered_d1, 1); + + let _ = sm.process_event(OnEntryExampleEvents::ToD1).unwrap(); + + assert!(matches!(sm.state(), Ok(&OnEntryExampleStates::D1))); + assert_eq!(sm.context().exited_d0, 1); + assert_eq!(sm.context().entered_d1, 2); + + let _ = sm.process_event(OnEntryExampleEvents::ToD0).unwrap(); + + assert!(matches!(sm.state(), Ok(&OnEntryExampleStates::D0))); + assert_eq!(sm.context().exited_d0, 1); + assert_eq!(sm.context().entered_d1, 2); + + let _ = sm.process_event(OnEntryExampleEvents::ToD3).unwrap(); + + assert!(matches!(sm.state(), Ok(&OnEntryExampleStates::D3))); + assert_eq!(sm.context().exited_d0, 2); + assert_eq!(sm.context().entered_d1, 2); } diff --git a/macros/src/codegen.rs b/macros/src/codegen.rs index d1e3883..895fba4 100644 --- a/macros/src/codegen.rs +++ b/macros/src/codegen.rs @@ -22,6 +22,9 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { let generate_entry_exit_states = sm.generate_entry_exit_states; let generate_transition_callback = sm.generate_transition_callback; + let entry_fns = &sm.entry_functions; + let exit_fns = &sm.exit_functions; + // Get only the unique states let mut state_list: Vec<_> = sm.states.values().collect(); state_list.sort_by_key(|state| state.to_string()); @@ -114,11 +117,6 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { }) .collect(); - // println!("sm: {:#?}", sm); - // println!("in_states: {:#?}", in_states); - // println!("events: {:#?}", events); - // println!("transitions: {:#?}", transitions); - // Map guards, actions and output states into code blocks let guards: Vec> = transitions .values() @@ -257,6 +255,21 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { let mut action_list = proc_macro2::TokenStream::new(); let mut entry_list = proc_macro2::TokenStream::new(); + + let mut entries_exits = proc_macro2::TokenStream::new(); + for ident in entry_fns.values() { + entries_exits.extend(quote! { + #[allow(missing_docs)] + fn #ident(&mut self){} + }); + } + for ident in exit_fns.values() { + entries_exits.extend(quote! { + #[allow(missing_docs)] + fn #ident(&mut self){} + }); + } + for (state, event_mappings) in transitions.iter() { // create the state data token stream let state_data = match sm.state_data.data_types.get(state) { @@ -497,10 +510,10 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { #guard_result self.context.log_guard(stringify!(#guard_expression), &guard_result); if guard_result.map_err(#error_type_name::GuardFailed)? { - #entry_exit_states #action_code let out_state = #states_type_name::#out_state; self.context.log_state_change(&out_state); + #entry_exit_states self.state = Some(out_state); return self.state() } @@ -508,9 +521,9 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { } else { // Unguarded transition quote!{ #action_code - #entry_exit_states let out_state = #states_type_name::#out_state; self.context.log_state_change(&out_state); + #entry_exit_states self.state = Some(out_state); return self.state(); } @@ -603,6 +616,9 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { #guard_error #guard_list #action_list + #entry_list + #entries_exits + /// Called at the beginning of a state machine's `process_event()`. No-op by /// default but can be overridden in implementations of a state machine's @@ -624,8 +640,6 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { /// of a state machine's `StateMachineContext` trait. fn log_state_change(&self, new_state: & #states_type_name) {} - #entry_list - #[allow(missing_docs)] fn transition_callback(&self, new_state: &Option<#states_type_name>) {} } diff --git a/macros/src/parser/event.rs b/macros/src/parser/event.rs index 3f06e77..d90d898 100644 --- a/macros/src/parser/event.rs +++ b/macros/src/parser/event.rs @@ -20,6 +20,9 @@ pub struct Transition { pub guard: Option, pub action: Option, pub out_state: Ident, + + pub entry_fn: Option, + pub exit_fn: Option, } impl parse::Parse for Event { diff --git a/macros/src/parser/mod.rs b/macros/src/parser/mod.rs index a4c2ae2..b244038 100644 --- a/macros/src/parser/mod.rs +++ b/macros/src/parser/mod.rs @@ -20,6 +20,13 @@ use syn::{parse, Ident, Type}; use transition::StateTransition; pub type TransitionMap = HashMap>; +#[derive(Debug, Clone)] +pub struct EntryIdent { + pub ident: Ident, + pub state: Vec, + pub is_async: bool, +} + #[derive(Debug, Clone)] pub struct AsyncIdent { pub ident: Ident, @@ -59,6 +66,8 @@ pub struct ParsedStateMachine { pub generate_entry_exit_states: bool, pub generate_transition_callback: bool, + pub entry_functions: HashMap, + pub exit_functions: HashMap, } // helper function for adding a transition to a transition event map @@ -66,11 +75,16 @@ fn add_transition( transition: &StateTransition, transition_map: &mut TransitionMap, state_data: &DataDefinitions, + entry_fns: &HashMap, + exit_fns: &HashMap, ) -> Result<(), parse::Error> { let p = transition_map .get_mut(&transition.in_state.ident.to_string()) .unwrap(); + let entry_fn = entry_fns.get(&transition.in_state.ident.clone()); + let exit_fn = exit_fns.get(&transition.in_state.ident.clone()); + match p.entry(transition.event.ident.to_string()) { hash_map::Entry::Vacant(entry) => { let mapping = EventMapping { @@ -80,6 +94,8 @@ fn add_transition( guard: transition.guard.clone(), action: transition.action.clone(), out_state: transition.out_state.ident.clone(), + entry_fn: entry_fn.cloned(), + exit_fn: exit_fn.cloned(), }], }; entry.insert(mapping); @@ -90,6 +106,8 @@ fn add_transition( guard: transition.guard.clone(), action: transition.action.clone(), out_state: transition.out_state.ident.clone(), + entry_fn: entry_fn.cloned(), + exit_fn: exit_fn.cloned(), }); } } @@ -138,6 +156,33 @@ impl ParsedStateMachine { let mut event_data = DataDefinitions::new(); let mut states_events_mapping = TransitionMap::new(); + let mut states_with_exit_function = HashMap::new(); + let mut states_with_entry_function = HashMap::new(); + + fn add_entry(map: &mut HashMap, vec: &Vec) -> parse::Result<()> { + for identifier in vec { + for input_state in &identifier.state { + if let Some(existing_identifier) = + map.insert(input_state.ident.clone(), identifier.ident.clone()) + { + if identifier.ident != existing_identifier { + println!( + "entry_state: {:?}, state.ident: {:?}", + identifier.ident, input_state.ident + ); + return Err(parse::Error::new( + Span::call_site(), + "Different entry or exit functions defined for state", + )); + } + } + } + } + Ok(()) + } + add_entry(&mut states_with_entry_function, &sm.entries)?; + add_entry(&mut states_with_exit_function, &sm.exits)?; + for transition in sm.transitions.iter() { // Collect states let in_state_name = transition.in_state.ident.to_string(); @@ -203,6 +248,8 @@ impl ParsedStateMachine { &wildcard_transition, &mut states_events_mapping, &state_data, + &states_with_entry_function, + &states_with_exit_function, )?; transition_added = true; @@ -217,7 +264,13 @@ impl ParsedStateMachine { )); } } else { - add_transition(transition, &mut states_events_mapping, &state_data)?; + add_transition( + transition, + &mut states_events_mapping, + &state_data, + &states_with_entry_function, + &states_with_exit_function, + )?; } } @@ -235,6 +288,9 @@ impl ParsedStateMachine { states_events_mapping, generate_entry_exit_states: sm.generate_entry_exit_states, generate_transition_callback: sm.generate_transition_callback, + + entry_functions: states_with_entry_function, + exit_functions: states_with_exit_function, }) } } diff --git a/macros/src/parser/state_machine.rs b/macros/src/parser/state_machine.rs index 8dc5674..48f9e9c 100644 --- a/macros/src/parser/state_machine.rs +++ b/macros/src/parser/state_machine.rs @@ -1,4 +1,7 @@ -use super::transition::{StateTransition, StateTransitions}; +use super::{ + transition::{StateTransition, StateTransitions}, + EntryIdent, +}; use syn::{braced, bracketed, parse, spanned::Spanned, token, Ident, Token, Type}; #[derive(Debug)] @@ -11,6 +14,8 @@ pub struct StateMachine { pub derive_events: Vec, pub generate_entry_exit_states: bool, pub generate_transition_callback: bool, + pub entries: Vec, + pub exits: Vec, } impl StateMachine { @@ -24,6 +29,8 @@ impl StateMachine { derive_events: Vec::new(), generate_entry_exit_states: false, generate_transition_callback: false, + entries: Vec::new(), + exits: Vec::new(), } } @@ -38,6 +45,12 @@ impl StateMachine { }; self.transitions.push(transition); } + if let Some(entry) = transitions.entry { + self.entries.push(entry); + } + if let Some(exit) = transitions.exit { + self.exits.push(exit); + } } } diff --git a/macros/src/parser/transition.rs b/macros/src/parser/transition.rs index fe5a248..efe2ef1 100644 --- a/macros/src/parser/transition.rs +++ b/macros/src/parser/transition.rs @@ -1,7 +1,7 @@ -use super::event::Event; use super::input_state::InputState; use super::output_state::OutputState; use super::AsyncIdent; +use super::{event::Event, EntryIdent}; use proc_macro2::TokenStream; use quote::quote; use std::fmt; @@ -22,6 +22,8 @@ pub struct StateTransitions { pub event: Event, pub guard: Option, pub action: Option, + pub entry: Option, + pub exit: Option, pub out_state: OutputState, } @@ -49,6 +51,32 @@ impl parse::Parse for StateTransitions { } } + // Possible extry function + let entry = if input.parse::().is_ok() { + let is_async = input.parse::().is_ok(); + let entry_function: Ident = input.parse()?; + Some(EntryIdent { + ident: entry_function, + state: in_states.clone(), + is_async, + }) + } else { + None + }; + let exit = if input.parse::]>().is_ok() { + let is_async = input.parse::().is_ok(); + let exit_function: Ident = match input.parse() { + Ok(v) => v, + Err(e) => panic!("Could not parse exit token: {:?}", e), + }; + Some(EntryIdent { + ident: exit_function, + state: in_states.clone(), + is_async, + }) + } else { + None + }; // Event let event: Event = input.parse()?; @@ -81,6 +109,8 @@ impl parse::Parse for StateTransitions { guard, action, out_state, + entry, + exit, }) } }