diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 99af3f5b..083b9f50 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -51,7 +51,7 @@ use syn::{ Variant, }; -#[proc_macro_derive(TypeInfo)] +#[proc_macro_derive(TypeInfo, attributes(scale_info))] pub fn type_info(input: TokenStream) -> TokenStream { match generate(input.into()) { Ok(output) => output.into(), @@ -66,21 +66,31 @@ fn generate(input: TokenStream2) -> Result { } fn generate_type(input: TokenStream2) -> Result { - let ast: DeriveInput = syn::parse2(input.clone())?; + let mut ast: DeriveInput = syn::parse2(input.clone())?; + + utils::check_attributes(&ast)?; let scale_info = crate_name_ident("scale-info")?; let parity_scale_codec = crate_name_ident("parity-scale-codec")?; let ident = &ast.ident; + let where_clause = if let Some(custom_bounds) = utils::custom_trait_bounds(&ast.attrs) + { + let where_clause = ast.generics.make_where_clause(); + where_clause.predicates.extend(custom_bounds); + where_clause.clone() + } else { + trait_bounds::make_where_clause( + ident, + &ast.generics, + &ast.data, + &scale_info, + &parity_scale_codec, + )? + }; + let (impl_generics, ty_generics, _) = ast.generics.split_for_impl(); - let where_clause = trait_bounds::make_where_clause( - ident, - &ast.generics, - &ast.data, - &scale_info, - &parity_scale_codec, - )?; let generic_type_ids = ast.generics.type_params().map(|ty| { let ty_ident = &ty.ident; diff --git a/derive/src/utils.rs b/derive/src/utils.rs index 32c0c2f4..5b54447c 100644 --- a/derive/src/utils.rs +++ b/derive/src/utils.rs @@ -23,10 +23,14 @@ use alloc::{ use proc_macro2::TokenStream; use quote::quote; use syn::{ + parse::Parse, parse_quote, + punctuated::Punctuated, spanned::Spanned, + token, AttrStyle, Attribute, + DeriveInput, Lit, Meta, NestedMeta, @@ -55,6 +59,41 @@ pub fn get_doc_literals(attrs: &[syn::Attribute]) -> Vec { .collect() } +/// Trait bounds. +pub type TraitBounds = Punctuated; + +/// Parse `name(T: Bound, N: Bound)` as a custom trait bound. +struct CustomTraitBound { + _name: N, + _paren_token: token::Paren, + bounds: TraitBounds, +} + +impl Parse for CustomTraitBound { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let content; + let _name = input.parse()?; + let _paren_token = syn::parenthesized!(content in input); + let bounds = content.parse_terminated(syn::WherePredicate::parse)?; + Ok(Self { + _name, + _paren_token, + bounds, + }) + } +} + +syn::custom_keyword!(bounds); + +/// Look for a `#[scale_info(bounds(…))]`in the given attributes. +/// +/// If found, use the given trait bounds when deriving the `TypeInfo` trait. +pub fn custom_trait_bounds(attrs: &[Attribute]) -> Option { + scale_info_meta_item(attrs.iter(), |meta: CustomTraitBound| { + Some(meta.bounds) + }) +} + /// Look for a `#[codec(index = $int)]` attribute on a variant. If no attribute /// is found, fall back to the discriminant or just the variant index. pub fn variant_index(v: &Variant, i: usize) -> TokenStream { @@ -77,7 +116,7 @@ pub fn maybe_index(variant: &Variant) -> Option { .iter() .filter(|attr| attr.style == AttrStyle::Outer); - find_meta_item(outer_attrs, |meta| { + codec_meta_item(outer_attrs, |meta| { if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta { if nv.path.is_ident("index") { if let Lit::Int(ref v) = nv.lit { @@ -99,7 +138,7 @@ pub fn is_compact(field: &syn::Field) -> bool { .attrs .iter() .filter(|attr| attr.style == AttrStyle::Outer); - find_meta_item(outer_attrs, |meta| { + codec_meta_item(outer_attrs, |meta| { if let NestedMeta::Meta(Meta::Path(ref path)) = meta { if path.is_ident("compact") { return Some(()) @@ -113,7 +152,7 @@ pub fn is_compact(field: &syn::Field) -> bool { /// Look for a `#[codec(skip)]` in the given attributes. pub fn should_skip(attrs: &[Attribute]) -> bool { - find_meta_item(attrs.iter(), |meta| { + codec_meta_item(attrs.iter(), |meta| { if let NestedMeta::Meta(Meta::Path(ref path)) = meta { if path.is_ident("skip") { return Some(path.span()) @@ -125,22 +164,57 @@ pub fn should_skip(attrs: &[Attribute]) -> bool { .is_some() } -fn find_meta_item<'a, F, R, I>(itr: I, pred: F) -> Option +fn codec_meta_item<'a, F, R, I, M>(itr: I, pred: F) -> Option where - F: Fn(&NestedMeta) -> Option + Clone, + F: FnMut(M) -> Option + Clone, I: Iterator, + M: Parse, { - itr.filter_map(|attr| { - if attr.path.is_ident("codec") { - if let Meta::List(ref meta_list) = attr - .parse_meta() - .expect("scale-info: Bad index in `#[codec(index = …)]`, see `parity-scale-codec` error") - { - return meta_list.nested.iter().filter_map(pred.clone()).next() - } - } + find_meta_item("codec", itr, pred) +} - None +fn scale_info_meta_item<'a, F, R, I, M>(itr: I, pred: F) -> Option +where + F: FnMut(M) -> Option + Clone, + I: Iterator, + M: Parse, +{ + find_meta_item("scale_info", itr, pred) +} + +fn find_meta_item<'a, F, R, I, M>(kind: &str, mut itr: I, mut pred: F) -> Option +where + F: FnMut(M) -> Option + Clone, + I: Iterator, + M: Parse, +{ + itr.find_map(|attr| { + attr.path + .is_ident(kind) + .then(|| pred(attr.parse_args().ok()?)) + .flatten() }) - .next() +} + +/// Ensure attributes are correctly applied. This *must* be called before using +/// any of the attribute finder methods or the macro may panic if it encounters +/// misapplied attributes. +/// `#[scale_info(bounds())]` is the only accepted attribute. +pub fn check_attributes(input: &DeriveInput) -> syn::Result<()> { + for attr in &input.attrs { + check_top_attribute(attr)?; + } + Ok(()) +} + +// Only `#[scale_info(bounds())]` is a valid top attribute. +fn check_top_attribute(attr: &Attribute) -> syn::Result<()> { + if attr.path.is_ident("scale_info") { + match attr.parse_args::>() { + Ok(_) => Ok(()), + Err(e) => Err(syn::Error::new(attr.span(), format!("Invalid attribute: {:?}. Only `#[scale_info(bounds(…))]` is a valid top attribute", e))) + } + } else { + Ok(()) + } } diff --git a/test_suite/tests/derive.rs b/test_suite/tests/derive.rs index 81836ba0..cb99fa55 100644 --- a/test_suite/tests/derive.rs +++ b/test_suite/tests/derive.rs @@ -617,9 +617,5 @@ fn ui_tests() { t.compile_fail("tests/ui/fail_unions.rs"); t.compile_fail("tests/ui/fail_use_codec_attrs_without_deriving_encode.rs"); t.compile_fail("tests/ui/fail_with_invalid_codec_attrs.rs"); - t.pass("tests/ui/pass_with_valid_codec_attrs.rs"); - t.pass("tests/ui/pass_non_static_lifetime.rs"); - t.pass("tests/ui/pass_self_referential.rs"); - t.pass("tests/ui/pass_basic_generic_type.rs"); - t.pass("tests/ui/pass_complex_generic_self_referential_type.rs"); + t.pass("tests/ui/pass_*"); } diff --git a/test_suite/tests/ui/fail_with_invalid_codec_attrs.stderr b/test_suite/tests/ui/fail_with_invalid_codec_attrs.stderr index 96a1a344..de1bd402 100644 --- a/test_suite/tests/ui/fail_with_invalid_codec_attrs.stderr +++ b/test_suite/tests/ui/fail_with_invalid_codec_attrs.stderr @@ -16,24 +16,8 @@ error: expected literal 18 | #[codec(index = a)] | ^ -error: proc-macro derive panicked - --> $DIR/fail_with_invalid_codec_attrs.rs:16:18 - | -16 | #[derive(Encode, TypeInfo)] - | ^^^^^^^^ - | - = help: message: scale-info: Bad index in `#[codec(index = …)]`, see `parity-scale-codec` error: Error("expected literal") - error: expected literal --> $DIR/fail_with_invalid_codec_attrs.rs:24:25 | 24 | #[codec(encode_as = u8, compact)] | ^^ - -error: proc-macro derive panicked - --> $DIR/fail_with_invalid_codec_attrs.rs:22:18 - | -22 | #[derive(Encode, TypeInfo)] - | ^^^^^^^^ - | - = help: message: scale-info: Bad index in `#[codec(index = …)]`, see `parity-scale-codec` error: Error("expected literal") diff --git a/test_suite/tests/ui/pass_custom_bounds.rs b/test_suite/tests/ui/pass_custom_bounds.rs new file mode 100644 index 00000000..8c053420 --- /dev/null +++ b/test_suite/tests/ui/pass_custom_bounds.rs @@ -0,0 +1,25 @@ +use scale_info::TypeInfo; +use core::marker::PhantomData; + +#[allow(unused)] +#[derive(TypeInfo)] +#[scale_info(bounds(T: Default + TypeInfo + 'static, N: TypeInfo + 'static))] +struct Hey { + ciao: Greet, + ho: N, +} + +#[derive(TypeInfo)] +#[scale_info(bounds(T: TypeInfo + 'static))] +struct Greet { + marker: PhantomData, +} + +#[derive(TypeInfo, Default)] +struct SomeType; + +fn assert_type_info() {} + +fn main() { + assert_type_info::>(); +} \ No newline at end of file diff --git a/test_suite/tests/ui/pass_custom_bounds_fix_overflow.rs b/test_suite/tests/ui/pass_custom_bounds_fix_overflow.rs new file mode 100644 index 00000000..7a088ccf --- /dev/null +++ b/test_suite/tests/ui/pass_custom_bounds_fix_overflow.rs @@ -0,0 +1,22 @@ +use scale_info::TypeInfo; + +#[allow(unused)] +#[derive(TypeInfo)] +// Without this we get `overflow evaluating the requirement `Vec>: TypeInfo``. +// The custom bounds replace the auto generated bounds. +#[scale_info(bounds(T: TypeInfo + 'static))] +struct A { + a: Vec>, + b: Vec>, + marker: core::marker::PhantomData, +} + +#[allow(unused)] +#[derive(TypeInfo)] +struct B(A); + +fn assert_type_info() {} + +fn main() { + assert_type_info::>(); +} \ No newline at end of file