Skip to content

Commit

Permalink
Merge pull request #63 from yaak-ai/async
Browse files Browse the repository at this point in the history
Add support for async
  • Loading branch information
ryan-summers authored Jun 5, 2023
2 parents 983f57a + 2b86978 commit 92c200c
Show file tree
Hide file tree
Showing 11 changed files with 221 additions and 30 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
### Added

- Add `impl_display_{states,events}` for auto-impementation of `core::fmt::Display` on `States` and/or `Events`
- Add support for async guards and actions

### Fixed

Expand Down
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ readme = "README.md"

[dependencies]
smlang-macros = { path = "macros", version = "0.6.0" }
async-trait = "0.1"

[dev-dependencies]
smol = "1"

[target.'cfg(not(target_os = "none"))'.dev-dependencies]
trybuild = "1.0"
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ See example `examples/event_with_data.rs` for a usage example.

See example `examples/guard_action_syntax.rs` for a usage-example.

### Async Guard and Action

See example `examples/async.rs` for a usage-example.

## State Machine Examples

Here are some examples of state machines converted from UML to the State Machine Language DSL. Runnable versions of each example is available in the `examples` folder.
Expand Down
85 changes: 85 additions & 0 deletions examples/async.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
//! Async guards and actions example
//!
//! An example of using async guards and actions mixed with standard ones.

#![deny(missing_docs)]

use smlang::{async_trait, statemachine};
use smol;

statemachine! {
transitions: {
*State1 + Event1 [guard1] / async action1 = State2,
State2 + Event2 [async guard2] / async action2 = State3,
State3 + Event3 / action3 = State4(bool),
}
}

/// Context with member
pub struct Context {
lock: smol::lock::RwLock<bool>,
done: bool,
}

#[async_trait]
impl StateMachineContext for Context {
fn guard1(&mut self) -> Result<(), ()> {
println!("`guard1` called from sync context");
Ok(())
}

async fn action1(&mut self) -> () {
println!("`action1` called from async context");
let mut lock = self.lock.write().await;
*lock = true;
}

async fn guard2(&mut self) -> Result<(), ()> {
println!("`guard2` called from async context");
let mut lock = self.lock.write().await;
*lock = false;
Ok(())
}

async fn action2(&mut self) -> () {
println!("`action2` called from async context");
if !*self.lock.read().await {
self.done = true;
}
}

fn action3(&mut self) -> bool {
println!("`action3` called from sync context, done = `{}`", self.done);
self.done
}
}

fn main() {
smol::block_on(async {
let mut sm = StateMachine::new(Context {
lock: smol::lock::RwLock::new(false),
done: false,
});
assert!(matches!(sm.state(), Ok(&States::State1)));

let r = sm.process_event(Events::Event1).await;
assert!(matches!(r, Ok(&States::State2)));

let r = sm.process_event(Events::Event2).await;
assert!(matches!(r, Ok(&States::State3)));

let r = sm.process_event(Events::Event3).await;
assert!(matches!(r, Ok(&States::State4(true))));

// Now all events will not give any change of state
let r = sm.process_event(Events::Event1).await;
assert!(matches!(r, Err(Error::InvalidEvent)));
assert!(matches!(sm.state(), Ok(&States::State4(_))));

let r = sm.process_event(Events::Event2).await;
assert!(matches!(r, Err(Error::InvalidEvent)));
assert!(matches!(sm.state(), Ok(&States::State4(_))));
});

// ...
}
57 changes: 43 additions & 14 deletions macros/src/codegen.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Move guards to return a Result

use crate::parser::lifetimes::Lifetimes;
use crate::parser::ParsedStateMachine;
use crate::parser::{lifetimes::Lifetimes, AsyncIdent, ParsedStateMachine};
use proc_macro2::{Literal, Span};
use quote::quote;
use std::vec::Vec;
Expand Down Expand Up @@ -239,7 +238,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
all_lifetimes.extend(&event_lifetimes);

// Create the guard traits for user implementation
if let Some(guard) = &value.guard {
if let Some(AsyncIdent {ident: guard, is_async}) = &value.guard {
let event_data = match sm.event_data.data_types.get(event) {
Some(et @ Type::Reference(_)) => quote! { event_data: #et },
Some(et) => quote! { event_data: &#et },
Expand All @@ -255,15 +254,24 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
// Only add the guard if it hasn't been added before
if !guard_set.iter().any(|g| g == guard) {
guard_set.push(guard.clone());
let is_async = match is_async {
true => quote!{ async },
false => quote!{ },
};
guard_list.extend(quote! {
#[allow(missing_docs)]
fn #guard <#all_lifetimes> (&mut self, #temporary_context #state_data #event_data) -> Result<(), #guard_error>;
#is_async fn #guard <#all_lifetimes> (&mut self, #temporary_context #state_data #event_data) -> Result<(), #guard_error>;
});
}
}

// Create the action traits for user implementation
if let Some(action) = &value.action {
if let Some(AsyncIdent {ident: action, is_async}) = &value.action {
let is_async = match is_async {
true => quote!{ async },
false => quote!{ },
};

let return_type = if let Some(output_data) =
sm.state_data.data_types.get(&value.out_state.to_string())
{
Expand Down Expand Up @@ -300,7 +308,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
action_set.push(action.clone());
action_list.extend(quote! {
#[allow(missing_docs)]
fn #action <#all_lifetimes> (&mut self, #temporary_context #state_data #event_data) -> #return_type;
#is_async fn #action <#all_lifetimes> (&mut self, #temporary_context #state_data #event_data) -> #return_type;
});
}
}
Expand All @@ -316,6 +324,8 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
}
};

let mut sm_is_async = false;

// Create the code blocks inside the switch cases
let code_blocks: Vec<Vec<_>> = guards
.iter()
Expand All @@ -334,28 +344,40 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
.zip(out_states.iter().zip(guard_action_parameters.iter().zip(guard_action_ref_parameters.iter()))),
)
.map(|(guard, (action, (out_state, (g_a_param, g_a_ref_param))))| {
if let Some(g) = guard {
if let Some(a) = action {
if let Some(AsyncIdent {ident: g, is_async: is_g_async}) = guard {
let guard_await = match is_g_async {
true => { sm_is_async = true; quote! { .await } },
false => quote! { },
};
if let Some(AsyncIdent {ident: a, is_async: is_a_async}) = action {
let action_await = match is_a_async {
true => { sm_is_async = true; quote! { .await } },
false => quote! { },
};
quote! {
if let Err(e) = self.context.#g(#temporary_context_call #g_a_ref_param) {
if let Err(e) = self.context.#g(#temporary_context_call #g_a_ref_param) #guard_await {
self.state = Some(States::#in_state);
return Err(Error::GuardFailed(e));
}
let _data = self.context.#a(#temporary_context_call #g_a_param);
let _data = self.context.#a(#temporary_context_call #g_a_param) #action_await;
self.state = Some(States::#out_state);
}
} else {
quote! {
if let Err(e) = self.context.#g(#temporary_context_call #g_a_ref_param) {
if let Err(e) = self.context.#g(#temporary_context_call #g_a_ref_param) #guard_await {
self.state = Some(States::#in_state);
return Err(Error::GuardFailed(e));
}
self.state = Some(States::#out_state);
}
}
} else if let Some(a) = action {
} else if let Some(AsyncIdent {ident: a, is_async: is_a_async}) = action {
let action_await = match is_a_async {
true => { sm_is_async = true; quote! { .await } },
false => quote! { },
};
quote! {
let _data = self.context.#a(#temporary_context_call #g_a_param);
let _data = self.context.#a(#temporary_context_call #g_a_param) #action_await;
self.state = Some(States::#out_state);
}
} else {
Expand Down Expand Up @@ -454,6 +476,12 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
quote! {}
};

let (is_async, is_async_trait) = if sm_is_async {
(quote! { async }, quote! { #[smlang::async_trait] })
} else {
(quote! {}, quote! {})
};

let error_type = if sm.custom_guard_error {
quote! {
Error<<T as StateMachineContext>::GuardError>
Expand All @@ -466,6 +494,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
quote! {
/// This trait outlines the guards and actions that need to be implemented for the state
/// machine.
#is_async_trait
pub trait StateMachineContext {
#guard_error
#guard_list
Expand Down Expand Up @@ -556,7 +585,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream {
///
/// It will return `Ok(&NextState)` if the transition was successful, or `Err(Error)`
/// if there was an error in the transition.
pub fn process_event <#event_unique_lifetimes> (
pub #is_async fn process_event <#event_unique_lifetimes> (
&mut self,
#temporary_context
mut event: Events <#event_lifetimes>
Expand Down
5 changes: 3 additions & 2 deletions macros/src/parser/event.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::parser::AsyncIdent;
use syn::{parenthesized, parse, spanned::Spanned, token, Ident, Token, Type};

#[derive(Debug, Clone)]
Expand All @@ -10,8 +11,8 @@ pub struct Event {
pub struct EventMapping {
pub in_state: Ident,
pub event: Ident,
pub guard: Option<Ident>,
pub action: Option<Ident>,
pub guard: Option<AsyncIdent>,
pub action: Option<AsyncIdent>,
pub out_state: Ident,
}

Expand Down
6 changes: 6 additions & 0 deletions macros/src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ use transition::StateTransition;

pub type TransitionMap = HashMap<String, HashMap<String, EventMapping>>;

#[derive(Debug, Clone)]
pub struct AsyncIdent {
pub ident: Ident,
pub is_async: bool,
}

#[derive(Debug)]
pub struct ParsedStateMachine {
pub temporary_context_type: Option<Type>,
Expand Down
21 changes: 15 additions & 6 deletions macros/src/parser/transition.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
use super::event::Event;
use super::input_state::InputState;
use super::output_state::OutputState;
use super::AsyncIdent;
use syn::{bracketed, parse, token, Ident, Token};

#[derive(Debug)]
pub struct StateTransition {
pub in_state: InputState,
pub event: Event,
pub guard: Option<Ident>,
pub action: Option<Ident>,
pub guard: Option<AsyncIdent>,
pub action: Option<AsyncIdent>,
pub out_state: OutputState,
}

#[derive(Debug)]
pub struct StateTransitions {
pub in_states: Vec<InputState>,
pub event: Event,
pub guard: Option<Ident>,
pub action: Option<Ident>,
pub guard: Option<AsyncIdent>,
pub action: Option<AsyncIdent>,
pub out_state: OutputState,
}

Expand Down Expand Up @@ -52,16 +53,24 @@ impl parse::Parse for StateTransitions {
let guard = if input.peek(token::Bracket) {
let content;
bracketed!(content in input);
let is_async = content.parse::<token::Async>().is_ok();
let guard: Ident = content.parse()?;
Some(guard)
Some(AsyncIdent {
ident: guard,
is_async,
})
} else {
None
};

// Possible action
let action = if input.parse::<Token![/]>().is_ok() {
let is_async = input.parse::<token::Async>().is_ok();
let action: Ident = input.parse()?;
Some(action)
Some(AsyncIdent {
ident: action,
is_async,
})
} else {
None
};
Expand Down
Loading

0 comments on commit 92c200c

Please sign in to comment.