Skip to content

Commit

Permalink
Add new top-level attribute `scale_info(bounds(T: SomeTrait + OtherTr…
Browse files Browse the repository at this point in the history
…ait))` (#88)

* Add new top-level attribute `scale_info(bounds(T: SomeTrait + OtherTrait))`

* Fmt

* cleanup

* Run pass tests as wildcard

* Add ui pass test for fixing custom bounds overflow

* Fmt

* Move custom bounds test to ui test

* Satisfy clippy

* Fix UI test

Co-authored-by: Andrew Jones <[email protected]>
  • Loading branch information
dvdplm and ascjones committed Jun 14, 2021
1 parent ce199ab commit db886fd
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 46 deletions.
28 changes: 19 additions & 9 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ use syn::{
Variant,
};

#[proc_macro_derive(TypeInfo)]
#[proc_macro_derive(TypeInfo, attributes(scale_info))]
pub fn type_info(input: TokenStream) -> TokenStream {
match generate(input.into()) {
Ok(output) => output.into(),
Expand All @@ -66,21 +66,31 @@ fn generate(input: TokenStream2) -> Result<TokenStream2> {
}

fn generate_type(input: TokenStream2) -> Result<TokenStream2> {
let ast: DeriveInput = syn::parse2(input.clone())?;
let mut ast: DeriveInput = syn::parse2(input.clone())?;

utils::check_attributes(&ast)?;

let scale_info = crate_name_ident("scale-info")?;
let parity_scale_codec = crate_name_ident("parity-scale-codec")?;

let ident = &ast.ident;

let where_clause = if let Some(custom_bounds) = utils::custom_trait_bounds(&ast.attrs)
{
let where_clause = ast.generics.make_where_clause();
where_clause.predicates.extend(custom_bounds);
where_clause.clone()
} else {
trait_bounds::make_where_clause(
ident,
&ast.generics,
&ast.data,
&scale_info,
&parity_scale_codec,
)?
};

let (impl_generics, ty_generics, _) = ast.generics.split_for_impl();
let where_clause = trait_bounds::make_where_clause(
ident,
&ast.generics,
&ast.data,
&scale_info,
&parity_scale_codec,
)?;

let generic_type_ids = ast.generics.type_params().map(|ty| {
let ty_ident = &ty.ident;
Expand Down
106 changes: 90 additions & 16 deletions derive/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@ use alloc::{
use proc_macro2::TokenStream;
use quote::quote;
use syn::{
parse::Parse,
parse_quote,
punctuated::Punctuated,
spanned::Spanned,
token,
AttrStyle,
Attribute,
DeriveInput,
Lit,
Meta,
NestedMeta,
Expand Down Expand Up @@ -55,6 +59,41 @@ pub fn get_doc_literals(attrs: &[syn::Attribute]) -> Vec<syn::Lit> {
.collect()
}

/// Trait bounds.
pub type TraitBounds = Punctuated<syn::WherePredicate, token::Comma>;

/// Parse `name(T: Bound, N: Bound)` as a custom trait bound.
struct CustomTraitBound<N> {
_name: N,
_paren_token: token::Paren,
bounds: TraitBounds,
}

impl<N: Parse> Parse for CustomTraitBound<N> {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let content;
let _name = input.parse()?;
let _paren_token = syn::parenthesized!(content in input);
let bounds = content.parse_terminated(syn::WherePredicate::parse)?;
Ok(Self {
_name,
_paren_token,
bounds,
})
}
}

syn::custom_keyword!(bounds);

/// Look for a `#[scale_info(bounds(…))]`in the given attributes.
///
/// If found, use the given trait bounds when deriving the `TypeInfo` trait.
pub fn custom_trait_bounds(attrs: &[Attribute]) -> Option<TraitBounds> {
scale_info_meta_item(attrs.iter(), |meta: CustomTraitBound<bounds>| {
Some(meta.bounds)
})
}

/// Look for a `#[codec(index = $int)]` attribute on a variant. If no attribute
/// is found, fall back to the discriminant or just the variant index.
pub fn variant_index(v: &Variant, i: usize) -> TokenStream {
Expand All @@ -77,7 +116,7 @@ pub fn maybe_index(variant: &Variant) -> Option<u8> {
.iter()
.filter(|attr| attr.style == AttrStyle::Outer);

find_meta_item(outer_attrs, |meta| {
codec_meta_item(outer_attrs, |meta| {
if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta {
if nv.path.is_ident("index") {
if let Lit::Int(ref v) = nv.lit {
Expand All @@ -99,7 +138,7 @@ pub fn is_compact(field: &syn::Field) -> bool {
.attrs
.iter()
.filter(|attr| attr.style == AttrStyle::Outer);
find_meta_item(outer_attrs, |meta| {
codec_meta_item(outer_attrs, |meta| {
if let NestedMeta::Meta(Meta::Path(ref path)) = meta {
if path.is_ident("compact") {
return Some(())
Expand All @@ -113,7 +152,7 @@ pub fn is_compact(field: &syn::Field) -> bool {

/// Look for a `#[codec(skip)]` in the given attributes.
pub fn should_skip(attrs: &[Attribute]) -> bool {
find_meta_item(attrs.iter(), |meta| {
codec_meta_item(attrs.iter(), |meta| {
if let NestedMeta::Meta(Meta::Path(ref path)) = meta {
if path.is_ident("skip") {
return Some(path.span())
Expand All @@ -125,22 +164,57 @@ pub fn should_skip(attrs: &[Attribute]) -> bool {
.is_some()
}

fn find_meta_item<'a, F, R, I>(itr: I, pred: F) -> Option<R>
fn codec_meta_item<'a, F, R, I, M>(itr: I, pred: F) -> Option<R>
where
F: Fn(&NestedMeta) -> Option<R> + Clone,
F: FnMut(M) -> Option<R> + Clone,
I: Iterator<Item = &'a Attribute>,
M: Parse,
{
itr.filter_map(|attr| {
if attr.path.is_ident("codec") {
if let Meta::List(ref meta_list) = attr
.parse_meta()
.expect("scale-info: Bad index in `#[codec(index = …)]`, see `parity-scale-codec` error")
{
return meta_list.nested.iter().filter_map(pred.clone()).next()
}
}
find_meta_item("codec", itr, pred)
}

None
fn scale_info_meta_item<'a, F, R, I, M>(itr: I, pred: F) -> Option<R>
where
F: FnMut(M) -> Option<R> + Clone,
I: Iterator<Item = &'a Attribute>,
M: Parse,
{
find_meta_item("scale_info", itr, pred)
}

fn find_meta_item<'a, F, R, I, M>(kind: &str, mut itr: I, mut pred: F) -> Option<R>
where
F: FnMut(M) -> Option<R> + Clone,
I: Iterator<Item = &'a Attribute>,
M: Parse,
{
itr.find_map(|attr| {
attr.path
.is_ident(kind)
.then(|| pred(attr.parse_args().ok()?))
.flatten()
})
.next()
}

/// Ensure attributes are correctly applied. This *must* be called before using
/// any of the attribute finder methods or the macro may panic if it encounters
/// misapplied attributes.
/// `#[scale_info(bounds())]` is the only accepted attribute.
pub fn check_attributes(input: &DeriveInput) -> syn::Result<()> {
for attr in &input.attrs {
check_top_attribute(attr)?;
}
Ok(())
}

// Only `#[scale_info(bounds())]` is a valid top attribute.
fn check_top_attribute(attr: &Attribute) -> syn::Result<()> {
if attr.path.is_ident("scale_info") {
match attr.parse_args::<CustomTraitBound<bounds>>() {
Ok(_) => Ok(()),
Err(e) => Err(syn::Error::new(attr.span(), format!("Invalid attribute: {:?}. Only `#[scale_info(bounds(…))]` is a valid top attribute", e)))
}
} else {
Ok(())
}
}
6 changes: 1 addition & 5 deletions test_suite/tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,9 +617,5 @@ fn ui_tests() {
t.compile_fail("tests/ui/fail_unions.rs");
t.compile_fail("tests/ui/fail_use_codec_attrs_without_deriving_encode.rs");
t.compile_fail("tests/ui/fail_with_invalid_codec_attrs.rs");
t.pass("tests/ui/pass_with_valid_codec_attrs.rs");
t.pass("tests/ui/pass_non_static_lifetime.rs");
t.pass("tests/ui/pass_self_referential.rs");
t.pass("tests/ui/pass_basic_generic_type.rs");
t.pass("tests/ui/pass_complex_generic_self_referential_type.rs");
t.pass("tests/ui/pass_*");
}
16 changes: 0 additions & 16 deletions test_suite/tests/ui/fail_with_invalid_codec_attrs.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,8 @@ error: expected literal
18 | #[codec(index = a)]
| ^

error: proc-macro derive panicked
--> $DIR/fail_with_invalid_codec_attrs.rs:16:18
|
16 | #[derive(Encode, TypeInfo)]
| ^^^^^^^^
|
= help: message: scale-info: Bad index in `#[codec(index = …)]`, see `parity-scale-codec` error: Error("expected literal")

error: expected literal
--> $DIR/fail_with_invalid_codec_attrs.rs:24:25
|
24 | #[codec(encode_as = u8, compact)]
| ^^

error: proc-macro derive panicked
--> $DIR/fail_with_invalid_codec_attrs.rs:22:18
|
22 | #[derive(Encode, TypeInfo)]
| ^^^^^^^^
|
= help: message: scale-info: Bad index in `#[codec(index = …)]`, see `parity-scale-codec` error: Error("expected literal")
25 changes: 25 additions & 0 deletions test_suite/tests/ui/pass_custom_bounds.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use scale_info::TypeInfo;
use core::marker::PhantomData;

#[allow(unused)]
#[derive(TypeInfo)]
#[scale_info(bounds(T: Default + TypeInfo + 'static, N: TypeInfo + 'static))]
struct Hey<T, N> {
ciao: Greet<T>,
ho: N,
}

#[derive(TypeInfo)]
#[scale_info(bounds(T: TypeInfo + 'static))]
struct Greet<T> {
marker: PhantomData<T>,
}

#[derive(TypeInfo, Default)]
struct SomeType;

fn assert_type_info<T: TypeInfo + 'static>() {}

fn main() {
assert_type_info::<Hey<SomeType, u16>>();
}
22 changes: 22 additions & 0 deletions test_suite/tests/ui/pass_custom_bounds_fix_overflow.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use scale_info::TypeInfo;

#[allow(unused)]
#[derive(TypeInfo)]
// Without this we get `overflow evaluating the requirement `Vec<B<()>>: TypeInfo``.
// The custom bounds replace the auto generated bounds.
#[scale_info(bounds(T: TypeInfo + 'static))]
struct A<T> {
a: Vec<B<T>>,
b: Vec<B<()>>,
marker: core::marker::PhantomData<T>,
}

#[allow(unused)]
#[derive(TypeInfo)]
struct B<T>(A<T>);

fn assert_type_info<T: TypeInfo + 'static>() {}

fn main() {
assert_type_info::<A<bool>>();
}

0 comments on commit db886fd

Please sign in to comment.