Skip to content

Commit

Permalink
Added entry and exit functions
Browse files Browse the repository at this point in the history
By using `>` and `<` we can now define entry functions and exit
functions on states. These should be defined once and will apply to all
transitions interacting with either entering or exiting the state.

Signed-off-by: Martin Broers <[email protected]>
  • Loading branch information
Martin Broers committed Jun 27, 2024
1 parent 5a751a0 commit 2785f10
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 23 deletions.
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,27 @@ statemachine!{

This example is available in `ex3.rs`.

### Using entry and exit functions in transitions

DSL implementation:

```rust
statemachine!{
transitions: {
*State1 + Event1 = State2,
State2 < exit_state_2 + Event2 = State1,
State1 > enter_state_3 + Event3 = State3,
State2 + Event3 = State3,
}
}
```
For all transitions entering State3, the function `enter_state_3` will be
called. For all transitions exiting State2, the function `exit_state_2` will be
called, in the right order, so first the `exit` function prior to the `entry`
function.

An example is available in `on_entry_on_exit`.

## Helpers

### Auto-derive certain traits for states and events
Expand Down
49 changes: 38 additions & 11 deletions examples/on_entry_on_exit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,65 @@ use smlang::statemachine;

statemachine! {
name: OnEntryExample,
generate_entry_exit_states: true,
transitions: {
*D0 + ToD1 = D1,
D1 + ToD2 = D2,
*D0 > exit_d0 + ToD1 = D1,
D0 + ToD3 = D3,
D1 < enter_d1 + ToD2 = D2,
D2 + ToD1 = D1,
D1 + ToD0 = D0,
},
generate_entry_exit_states: true,
}

/// Context
pub struct Context {
exited_d0: bool,
entered_d1: bool,
exited_d0: i32,
entered_d1: i32,
}

impl OnEntryExampleStateMachineContext for Context {
fn on_exit_d0(&mut self) {
self.exited_d0 = true;
self.exited_d0 += 1;
}
fn on_entry_d1(&mut self) {
self.entered_d1 = true;
self.entered_d1 += 1;
}
}

fn main() {
let mut sm = OnEntryExampleStateMachine::new(Context {
exited_d0: false,
entered_d1: false,
exited_d0: 0,
entered_d1: 0,
});

// first event starts the dominos
let _ = sm.process_event(OnEntryExampleEvents::ToD1).unwrap();

assert!(matches!(sm.state(), Ok(&OnEntryExampleStates::D1)));
assert!(sm.context().exited_d0);
assert!(sm.context().entered_d1);
assert_eq!(sm.context().exited_d0, 1);
assert_eq!(sm.context().entered_d1, 1);

let _ = sm.process_event(OnEntryExampleEvents::ToD2).unwrap();

assert!(matches!(sm.state(), Ok(&OnEntryExampleStates::D2)));
assert_eq!(sm.context().exited_d0, 1);
assert_eq!(sm.context().entered_d1, 1);

let _ = sm.process_event(OnEntryExampleEvents::ToD1).unwrap();

assert!(matches!(sm.state(), Ok(&OnEntryExampleStates::D1)));
assert_eq!(sm.context().exited_d0, 1);
assert_eq!(sm.context().entered_d1, 2);

let _ = sm.process_event(OnEntryExampleEvents::ToD0).unwrap();

assert!(matches!(sm.state(), Ok(&OnEntryExampleStates::D0)));
assert_eq!(sm.context().exited_d0, 1);
assert_eq!(sm.context().entered_d1, 2);

let _ = sm.process_event(OnEntryExampleEvents::ToD3).unwrap();

assert!(matches!(sm.state(), Ok(&OnEntryExampleStates::D3)));
assert_eq!(sm.context().exited_d0, 2);
assert_eq!(sm.context().entered_d1, 2);
}
32 changes: 23 additions & 9 deletions macros/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
let generate_entry_exit_states = sm.generate_entry_exit_states;
let generate_transition_callback = sm.generate_transition_callback;

let entry_fns = &sm.entry_functions;
let exit_fns = &sm.exit_functions;

// Get only the unique states
let mut state_list: Vec<_> = sm.states.values().collect();
state_list.sort_by_key(|state| state.to_string());
Expand Down Expand Up @@ -114,11 +117,6 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
})
.collect();

// println!("sm: {:#?}", sm);
// println!("in_states: {:#?}", in_states);
// println!("events: {:#?}", events);
// println!("transitions: {:#?}", transitions);

// Map guards, actions and output states into code blocks
let guards: Vec<Vec<_>> = transitions
.values()
Expand Down Expand Up @@ -257,6 +255,21 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
let mut action_list = proc_macro2::TokenStream::new();

let mut entry_list = proc_macro2::TokenStream::new();

let mut entries_exits = proc_macro2::TokenStream::new();
for ident in entry_fns.values() {
entries_exits.extend(quote! {
#[allow(missing_docs)]
fn #ident(&mut self){}
});
}
for ident in exit_fns.values() {
entries_exits.extend(quote! {
#[allow(missing_docs)]
fn #ident(&mut self){}
});
}

for (state, event_mappings) in transitions.iter() {
// create the state data token stream
let state_data = match sm.state_data.data_types.get(state) {
Expand Down Expand Up @@ -497,20 +510,20 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
#guard_result
self.context.log_guard(stringify!(#guard_expression), &guard_result);
if guard_result.map_err(#error_type_name::GuardFailed)? {
#entry_exit_states
#action_code
let out_state = #states_type_name::#out_state;
self.context.log_state_change(&out_state);
#entry_exit_states
self.state = Some(out_state);
return self.state()
}
}
} else { // Unguarded transition
quote!{
#action_code
#entry_exit_states
let out_state = #states_type_name::#out_state;
self.context.log_state_change(&out_state);
#entry_exit_states
self.state = Some(out_state);
return self.state();
}
Expand Down Expand Up @@ -603,6 +616,9 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
#guard_error
#guard_list
#action_list
#entry_list
#entries_exits


/// Called at the beginning of a state machine's `process_event()`. No-op by
/// default but can be overridden in implementations of a state machine's
Expand All @@ -624,8 +640,6 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
/// of a state machine's `StateMachineContext` trait.
fn log_state_change(&self, new_state: & #states_type_name) {}

#entry_list

#[allow(missing_docs)]
fn transition_callback(&self, new_state: &Option<#states_type_name>) {}
}
Expand Down
3 changes: 3 additions & 0 deletions macros/src/parser/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ pub struct Transition {
pub guard: Option<GuardExpression>,
pub action: Option<AsyncIdent>,
pub out_state: Ident,

pub entry_fn: Option<Ident>,
pub exit_fn: Option<Ident>,
}

impl parse::Parse for Event {
Expand Down
58 changes: 57 additions & 1 deletion macros/src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ use syn::{parse, Ident, Type};
use transition::StateTransition;
pub type TransitionMap = HashMap<String, HashMap<String, EventMapping>>;

#[derive(Debug, Clone)]
pub struct EntryIdent {
pub ident: Ident,
pub state: Vec<InputState>,
pub is_async: bool,
}

#[derive(Debug, Clone)]
pub struct AsyncIdent {
pub ident: Ident,
Expand Down Expand Up @@ -59,18 +66,25 @@ pub struct ParsedStateMachine {

pub generate_entry_exit_states: bool,
pub generate_transition_callback: bool,
pub entry_functions: HashMap<Ident, Ident>,
pub exit_functions: HashMap<Ident, Ident>,
}

// helper function for adding a transition to a transition event map
fn add_transition(
transition: &StateTransition,
transition_map: &mut TransitionMap,
state_data: &DataDefinitions,
entry_fns: &HashMap<Ident, Ident>,
exit_fns: &HashMap<Ident, Ident>,
) -> Result<(), parse::Error> {
let p = transition_map
.get_mut(&transition.in_state.ident.to_string())
.unwrap();

let entry_fn = entry_fns.get(&transition.in_state.ident.clone());
let exit_fn = exit_fns.get(&transition.in_state.ident.clone());

match p.entry(transition.event.ident.to_string()) {
hash_map::Entry::Vacant(entry) => {
let mapping = EventMapping {
Expand All @@ -80,6 +94,8 @@ fn add_transition(
guard: transition.guard.clone(),
action: transition.action.clone(),
out_state: transition.out_state.ident.clone(),
entry_fn: entry_fn.cloned(),
exit_fn: exit_fn.cloned(),
}],
};
entry.insert(mapping);
Expand All @@ -90,6 +106,8 @@ fn add_transition(
guard: transition.guard.clone(),
action: transition.action.clone(),
out_state: transition.out_state.ident.clone(),
entry_fn: entry_fn.cloned(),
exit_fn: exit_fn.cloned(),
});
}
}
Expand Down Expand Up @@ -138,6 +156,33 @@ impl ParsedStateMachine {
let mut event_data = DataDefinitions::new();
let mut states_events_mapping = TransitionMap::new();

let mut states_with_exit_function = HashMap::new();
let mut states_with_entry_function = HashMap::new();

fn add_entry(map: &mut HashMap<Ident, Ident>, vec: &Vec<EntryIdent>) -> parse::Result<()> {
for identifier in vec {
for input_state in &identifier.state {
if let Some(existing_identifier) =
map.insert(input_state.ident.clone(), identifier.ident.clone())
{
if identifier.ident != existing_identifier {
println!(
"entry_state: {:?}, state.ident: {:?}",
identifier.ident, input_state.ident
);
return Err(parse::Error::new(
Span::call_site(),
"Different entry or exit functions defined for state",
));
}
}
}
}
Ok(())
}
add_entry(&mut states_with_entry_function, &sm.entries)?;
add_entry(&mut states_with_exit_function, &sm.exits)?;

for transition in sm.transitions.iter() {
// Collect states
let in_state_name = transition.in_state.ident.to_string();
Expand Down Expand Up @@ -203,6 +248,8 @@ impl ParsedStateMachine {
&wildcard_transition,
&mut states_events_mapping,
&state_data,
&states_with_entry_function,
&states_with_exit_function,
)?;

transition_added = true;
Expand All @@ -217,7 +264,13 @@ impl ParsedStateMachine {
));
}
} else {
add_transition(transition, &mut states_events_mapping, &state_data)?;
add_transition(
transition,
&mut states_events_mapping,
&state_data,
&states_with_entry_function,
&states_with_exit_function,
)?;
}
}

Expand All @@ -235,6 +288,9 @@ impl ParsedStateMachine {
states_events_mapping,
generate_entry_exit_states: sm.generate_entry_exit_states,
generate_transition_callback: sm.generate_transition_callback,

entry_functions: states_with_entry_function,
exit_functions: states_with_exit_function,
})
}
}
15 changes: 14 additions & 1 deletion macros/src/parser/state_machine.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use super::transition::{StateTransition, StateTransitions};
use super::{
transition::{StateTransition, StateTransitions},
EntryIdent,
};
use syn::{braced, bracketed, parse, spanned::Spanned, token, Ident, Token, Type};

#[derive(Debug)]
Expand All @@ -11,6 +14,8 @@ pub struct StateMachine {
pub derive_events: Vec<Ident>,
pub generate_entry_exit_states: bool,
pub generate_transition_callback: bool,
pub entries: Vec<EntryIdent>,
pub exits: Vec<EntryIdent>,
}

impl StateMachine {
Expand All @@ -24,6 +29,8 @@ impl StateMachine {
derive_events: Vec::new(),
generate_entry_exit_states: false,
generate_transition_callback: false,
entries: Vec::new(),
exits: Vec::new(),
}
}

Expand All @@ -38,6 +45,12 @@ impl StateMachine {
};
self.transitions.push(transition);
}
if let Some(entry) = transitions.entry {
self.entries.push(entry);
}
if let Some(exit) = transitions.exit {
self.exits.push(exit);
}
}
}

Expand Down
Loading

0 comments on commit 2785f10

Please sign in to comment.