From 016c4cf213800015f95d9124be0732bf50d8ebbd Mon Sep 17 00:00:00 2001 From: oblique Date: Sat, 12 Nov 2022 23:26:23 +0200 Subject: [PATCH] Fix lifetime issues Fixes #57 --- CHANGELOG.md | 2 + macros/src/codegen.rs | 129 +++++++++++++-------------------- macros/src/parser/data.rs | 42 ++--------- macros/src/parser/event.rs | 1 + macros/src/parser/lifetimes.rs | 110 ++++++++++++++++++++++++++++ macros/src/parser/mod.rs | 6 +- tests/test.rs | 47 ++++++++++++ 7 files changed, 218 insertions(+), 119 deletions(-) create mode 100644 macros/src/parser/lifetimes.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e2469c..fcc22ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Fixed +- Fixes multiple issues with lifetimes ([issue-57](https://github.com/korken89/smlang-rs/issues/57), [issue-58](https://github.com/korken89/smlang-rs/pull/58)) + ### Changed - [breaking] Actions now take owned values diff --git a/macros/src/codegen.rs b/macros/src/codegen.rs index bf2d813..853b527 100644 --- a/macros/src/codegen.rs +++ b/macros/src/codegen.rs @@ -1,10 +1,9 @@ // Move guards to return a Result -use crate::parser::data::Lifetimes; +use crate::parser::lifetimes::Lifetimes; use crate::parser::ParsedStateMachine; -use proc_macro2; -use proc_macro2::Span; -use quote::quote; +use proc_macro2::{Span, TokenStream}; +use quote::{quote, ToTokens}; use std::vec::Vec; use syn::{punctuated::Punctuated, token::Paren, Type, TypeTuple}; @@ -154,6 +153,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { .iter() .map(|(name, _)| { let state_data = match sm.state_data.data_types.get(state_name) { + Some(Type::Reference(_)) => quote! { state_data }, Some(_) => quote! { &state_data }, None => quote! {}, }; @@ -217,59 +217,37 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { for (state, value) in transitions.iter() { // create the state data token stream let state_data = match sm.state_data.data_types.get(state) { + Some(st @ Type::Reference(_)) => quote! { state_data: #st, }, Some(st) => quote! { state_data: &#st, }, None => quote! {}, }; value.iter().for_each(|(event, value)| { + // get input state lifetimes + let in_state_lifetimes = sm.state_data.lifetimes.get(&value.in_state.to_string()).cloned().unwrap_or_default(); // get output state lifetimes - let state_lifetimes = if let Some(lifetimes) = sm.state_data.lifetimes.get(&value.out_state.to_string()) { - lifetimes.clone() - } else { - Lifetimes::new() - }; - - // get the event lifetimes - let mut lifetimes = if let Some(lifetimes) = sm.event_data.lifetimes.get(event) { - lifetimes.clone() - } else { - Lifetimes::new() - }; - - // combine the state data and event data lifetimes - lifetimes.append(&mut state_lifetimes.clone()); + let out_state_lifetimes = sm.state_data.lifetimes.get(&value.out_state.to_string()).cloned().unwrap_or_default(); + + // get event lifetimes + let event_lifetimes = sm.event_data.lifetimes.get(event).cloned().unwrap_or_default(); + + // combine all lifetimes + let mut all_lifetimes = Lifetimes::new(); + all_lifetimes.extend(&in_state_lifetimes); + all_lifetimes.extend(&out_state_lifetimes); + all_lifetimes.extend(&event_lifetimes); // Create the guard traits for user implementation if let Some(guard) = &value.guard { - let guard_with_lifetimes = if let Some(lifetimes) = sm.event_data.lifetimes.get(event) { - let lifetimes = &lifetimes; - quote! { - #guard<#(#lifetimes),*> - } - } else { - quote! { - #guard - } - }; - let event_data = match sm.event_data.data_types.get(event) { - Some(et) => match et { - Type::Reference(_) => { - quote! { event_data: #et } - } - _ => { - quote! { event_data: &#et } - } - }, - None => { - quote! {} - } + Some(et @ Type::Reference(_)) => quote! { event_data: #et }, + Some(et) => quote! { event_data: &#et }, + None => quote! {}, }; let guard_error = if sm.custom_guard_error { quote! { Self::GuardError } - } else { quote! { () } }; @@ -279,7 +257,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { guard_set.push(guard.clone()); guard_list.extend(quote! { #[allow(missing_docs)] - fn #guard_with_lifetimes(&mut self, #temporary_context #state_data #event_data) -> Result<(), #guard_error>; + fn #guard <#all_lifetimes> (&mut self, #temporary_context #state_data #event_data) -> Result<(), #guard_error>; }); } } @@ -300,16 +278,6 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { }) }; - let action_with_lifetimes = if lifetimes.is_empty() { - quote! { - #action - } - } else { - quote! { - #action<#(#lifetimes),*> - } - }; - let state_data = match sm.state_data.data_types.get(state) { Some(st) => { quote! { state_data: #st, } @@ -332,7 +300,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { action_set.push(action.clone()); action_list.extend(quote! { #[allow(missing_docs)] - fn #action_with_lifetimes(&mut self, #temporary_context #state_data #event_data) -> #return_type; + fn #action <#all_lifetimes> (&mut self, #temporary_context #state_data #event_data) -> #return_type; }); } } @@ -427,21 +395,11 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { }, }; - // create token-streams for state data lifetimes - let state_lifetimes_code = if sm.state_data.lifetimes.is_empty() { - quote! {} - } else { - let state_lifetimes = &sm.state_data.all_lifetimes; - quote! {#(#state_lifetimes),* ,} - }; + let state_lifetimes = &sm.state_data.all_lifetimes; + let event_lifetimes = &sm.event_data.all_lifetimes; - // create token-streams for event data lifetimes - let event_lifetimes_code = if sm.event_data.lifetimes.is_empty() { - quote! {} - } else { - let event_lifetimes = &sm.event_data.all_lifetimes; - quote! {#(#event_lifetimes),* ,} - }; + // lifetimes that exists in Events but not in States + let event_unique_lifetimes = event_lifetimes - state_lifetimes; let guard_error = if sm.custom_guard_error { quote! { @@ -472,10 +430,10 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { /// List of auto-generated states. #[allow(missing_docs)] - pub enum States <#state_lifetimes_code> { #(#state_list),* } + pub enum States <#state_lifetimes> { #(#state_list),* } /// Manually define PartialEq for States based on variant only to address issue-#21 - impl<#state_lifetimes_code> PartialEq for States <#state_lifetimes_code> { + impl<#state_lifetimes> PartialEq for States <#state_lifetimes> { fn eq(&self, other: &Self) -> bool { use core::mem::discriminant; discriminant(self) == discriminant(other) @@ -484,10 +442,10 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { /// List of auto-generated events. #[allow(missing_docs)] - pub enum Events <#event_lifetimes_code> { #(#event_list),* } + pub enum Events <#event_lifetimes> { #(#event_list),* } /// Manually define PartialEq for Events based on variant only to address issue-#21 - impl<#event_lifetimes_code> PartialEq for Events <#event_lifetimes_code> { + impl<#event_lifetimes> PartialEq for Events <#event_lifetimes> { fn eq(&self, other: &Self) -> bool { use core::mem::discriminant; discriminant(self) == discriminant(other) @@ -509,19 +467,19 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { } /// State machine structure definition. - pub struct StateMachine<#state_lifetimes_code T: StateMachineContext> { - state: Option>, + pub struct StateMachine<#state_lifetimes T: StateMachineContext> { + state: Option>, context: T } - impl<#state_lifetimes_code T: StateMachineContext> StateMachine<#state_lifetimes_code T> { + impl<#state_lifetimes T: StateMachineContext> StateMachine<#state_lifetimes T> { /// Creates a new state machine with the specified starting state. #[inline(always)] #new_sm_code /// Creates a new state machine with an initial state. #[inline(always)] - pub const fn new_with_state(context: T, initial_state: States <#state_lifetimes_code>) -> Self { + pub const fn new_with_state(context: T, initial_state: States <#state_lifetimes>) -> Self { StateMachine { state: Some(initial_state), context @@ -530,7 +488,7 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { /// Returns the current state. #[inline(always)] - pub fn state(&self) -> Result<&States, #error_type> { + pub fn state(&self) -> Result<&States <#state_lifetimes>, #error_type> { self.state.as_ref().ok_or_else(|| Error::Poisoned) } @@ -550,7 +508,11 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { /// /// It will return `Ok(&NextState)` if the transition was successful, or `Err(Error)` /// if there was an error in the transition. - pub fn process_event(&mut self, #temporary_context mut event: Events) -> Result<&States, #error_type> { + pub fn process_event <#event_unique_lifetimes> ( + &mut self, + #temporary_context + mut event: Events <#event_lifetimes> + ) -> Result<&States <#state_lifetimes>, #error_type> { match self.state.take().ok_or_else(|| Error::Poisoned)? { #(States::#in_states => match event { #(Events::#events => { @@ -572,3 +534,14 @@ pub fn generate_code(sm: &ParsedStateMachine) -> proc_macro2::TokenStream { } } } + +impl ToTokens for Lifetimes { + fn to_tokens(&self, tokens: &mut TokenStream) { + if self.is_empty() { + return; + } + + let lifetimes = self.as_slice(); + tokens.extend(quote! { #(#lifetimes),* ,}); + } +} diff --git a/macros/src/parser/data.rs b/macros/src/parser/data.rs index 682276e..894139a 100644 --- a/macros/src/parser/data.rs +++ b/macros/src/parser/data.rs @@ -1,40 +1,8 @@ +use crate::parser::lifetimes::Lifetimes; use std::collections::HashMap; -use syn::{parse, spanned::Spanned, GenericArgument, Lifetime, PathArguments, Type}; +use syn::{parse, spanned::Spanned, Type}; pub type DataTypes = HashMap; -pub type Lifetimes = Vec; - -// helper function for extracting a vector of lifetimes from a Type -fn get_lifetimes(data_type: &Type) -> Result { - let mut lifetimes = Lifetimes::new(); - match data_type { - Type::Reference(tr) => { - if let Some(lifetime) = &tr.lifetime { - lifetimes.push(lifetime.clone()); - } else { - return Err(parse::Error::new( - data_type.span(), - "This event's data lifetime is not defined, consider adding a lifetime.", - )); - } - Ok(lifetimes) - } - Type::Path(tp) => { - let punct = &tp.path.segments; - for p in punct.iter() { - if let PathArguments::AngleBracketed(abga) = &p.arguments { - for arg in &abga.args { - if let GenericArgument::Lifetime(lifetime) = &arg { - lifetimes.push(lifetime.clone()); - } - } - } - } - Ok(lifetimes) - } - _ => Ok(lifetimes), - } -} #[derive(Debug)] pub struct DataDefinitions { @@ -55,15 +23,15 @@ impl DataDefinitions { // helper function for adding a new data type to a data descriptions struct fn add(&mut self, key: String, data_type: Type) -> Result<(), parse::Error> { // retrieve any lifetimes used in this data-type - let mut lifetimes = get_lifetimes(&data_type)?; + let lifetimes = Lifetimes::from_type(&data_type)?; // add the data to the collection self.data_types.insert(key.clone(), data_type); // if any new lifetimes were used in the type definition, we add those as well if !lifetimes.is_empty() { - self.lifetimes.insert(key, lifetimes.clone()); - self.all_lifetimes.append(&mut lifetimes); + self.all_lifetimes.extend(&lifetimes); + self.lifetimes.insert(key, lifetimes); } Ok(()) } diff --git a/macros/src/parser/event.rs b/macros/src/parser/event.rs index 33a67bf..55c804a 100644 --- a/macros/src/parser/event.rs +++ b/macros/src/parser/event.rs @@ -8,6 +8,7 @@ pub struct Event { #[derive(Debug)] pub struct EventMapping { + pub in_state: Ident, pub event: Ident, pub guard: Option, pub action: Option, diff --git a/macros/src/parser/lifetimes.rs b/macros/src/parser/lifetimes.rs new file mode 100644 index 0000000..36742de --- /dev/null +++ b/macros/src/parser/lifetimes.rs @@ -0,0 +1,110 @@ +use std::ops::Sub; +use syn::{parse, spanned::Spanned, GenericArgument, Lifetime, PathArguments, Type}; + +#[derive(Default, Debug, Clone)] +pub struct Lifetimes { + lifetimes: Vec, +} + +impl Lifetimes { + pub fn new() -> Lifetimes { + Lifetimes { + lifetimes: Vec::new(), + } + } + + pub fn from_type(data_type: &Type) -> Result { + let mut lifetimes = Lifetimes::new(); + get_lifetimes(data_type, &mut lifetimes)?; + Ok(lifetimes) + } + + pub fn insert(&mut self, lifetime: &Lifetime) { + if !self.lifetimes.contains(lifetime) { + self.lifetimes.push(lifetime.to_owned()); + } + } + + pub fn extend(&mut self, other: &Lifetimes) { + for lifetime in other.lifetimes.iter() { + self.insert(lifetime); + } + } + + pub fn is_empty(&self) -> bool { + self.lifetimes.is_empty() + } + + pub fn as_slice(&self) -> &[Lifetime] { + &self.lifetimes[..] + } +} + +impl Sub<&Lifetimes> for Lifetimes { + type Output = Lifetimes; + + fn sub(mut self, rhs: &Lifetimes) -> Lifetimes { + self.lifetimes.retain(|lt| !rhs.lifetimes.contains(lt)); + self + } +} + +impl Sub for Lifetimes { + type Output = Lifetimes; + + fn sub(self, rhs: Lifetimes) -> Lifetimes { + self.sub(&rhs) + } +} + +impl Sub<&Lifetimes> for &Lifetimes { + type Output = Lifetimes; + + fn sub(self, rhs: &Lifetimes) -> Lifetimes { + self.to_owned().sub(rhs) + } +} + +impl Sub for &Lifetimes { + type Output = Lifetimes; + + fn sub(self, rhs: Lifetimes) -> Lifetimes { + self.to_owned().sub(&rhs) + } +} + +// helper function for extracting lifetimes from a Type +fn get_lifetimes(data_type: &Type, lifetimes: &mut Lifetimes) -> Result<(), parse::Error> { + match data_type { + Type::Reference(tr) => { + if let Some(lifetime) = &tr.lifetime { + lifetimes.insert(lifetime); + } else { + return Err(parse::Error::new( + data_type.span(), + "This event's data lifetime is not defined, consider adding a lifetime.", + )); + } + } + Type::Path(tp) => { + let punct = &tp.path.segments; + for p in punct.iter() { + if let PathArguments::AngleBracketed(abga) = &p.arguments { + for arg in &abga.args { + if let GenericArgument::Lifetime(lifetime) = &arg { + lifetimes.insert(lifetime); + } + } + } + } + } + Type::Tuple(tuple) => { + for elem in tuple.elems.iter() { + get_lifetimes(elem, lifetimes)?; + } + } + _ => {} + } + + Ok(()) +} diff --git a/macros/src/parser/mod.rs b/macros/src/parser/mod.rs index ddfffe5..d2b8f52 100644 --- a/macros/src/parser/mod.rs +++ b/macros/src/parser/mod.rs @@ -1,6 +1,7 @@ pub mod data; pub mod event; pub mod input_state; +pub mod lifetimes; pub mod output_state; pub mod state_machine; pub mod transition; @@ -42,6 +43,7 @@ fn add_transition( if !p.contains_key(&transition.event.ident.to_string()) { let mapping = EventMapping { + in_state: transition.in_state.ident.clone(), event: transition.event.ident.clone(), guard: transition.guard.clone(), action: transition.action.clone(), @@ -137,10 +139,6 @@ impl ParsedStateMachine { states_events_mapping.insert(transition.out_state.ident.to_string(), HashMap::new()); } - // Remove duplicate lifetimes - state_data.all_lifetimes.dedup(); - event_data.all_lifetimes.dedup(); - for transition in sm.transitions.iter() { // if input state is a wildcard, we need to add this transition for all states if transition.in_state.wildcard { diff --git a/tests/test.rs b/tests/test.rs index 88c81e3..055fb71 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -28,3 +28,50 @@ fn wildcard_after_input_state() { sm.process_event(Events::Event1).unwrap(); assert!(matches!(sm.state(), Ok(&States::Fault))); } + +#[test] +fn multiple_lifetimes() { + pub struct X; + pub struct Y; + pub struct Z; + + statemachine! { + transitions: { + *State1 + Event1(&'a X) [guard1] / action1 = State2(&'a X), + State2(&'a X) + Event2(&'b Y) [guard2] / action2 = State3((&'a X, &'b Y)), + State4 + Event(&'c Z) [guard3] / action3 = State5, + } + } + + struct Context; + + impl StateMachineContext for Context { + fn guard1<'a>(&mut self, _event_data: &'a X) -> Result<(), ()> { + Ok(()) + } + + fn guard2<'a, 'b>(&mut self, _state_data: &'a X, _event_data: &'b Y) -> Result<(), ()> { + Ok(()) + } + + fn guard3<'c>(&mut self, _event_data: &'c Z) -> Result<(), ()> { + Ok(()) + } + + fn action1<'a>(&mut self, event_data: &'a X) -> &'a X { + event_data + } + + fn action2<'a, 'b>(&mut self, state_data: &'a X, event_data: &'b Y) -> (&'a X, &'b Y) { + (state_data, event_data) + } + + fn action3<'c>(&mut self, _event_data: &'c Z) {} + } + + #[allow(dead_code)] + struct WrappedStates<'a, 'b>(States<'a, 'b>); + + #[allow(dead_code)] + struct WrappedEvents<'a, 'b, 'c>(Events<'a, 'b, 'c>); +}