Skip to content

Commit

Permalink
Merge pull request #138 from michaelsproul/specify-bounds
Browse files Browse the repository at this point in the history
Allow trait bounds to be manually specified
  • Loading branch information
fitzgen authored Mar 10, 2023
2 parents 061ca86 + a5a9527 commit c397cc2
Show file tree
Hide file tree
Showing 6 changed files with 324 additions and 6 deletions.
2 changes: 1 addition & 1 deletion derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ rust-version = "1.63.0"
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "1.0.56", features = ['derive'] }
syn = { version = "1.0.56", features = ['derive', 'parsing'] }

[lib]
proc_macro = true
72 changes: 72 additions & 0 deletions derive/src/container_attributes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use crate::ARBITRARY_ATTRIBUTE_NAME;
use syn::{
parse::Error, punctuated::Punctuated, DeriveInput, Lit, Meta, MetaNameValue, NestedMeta, Token,
TypeParam,
};

pub struct ContainerAttributes {
/// Specify type bounds to be applied to the derived `Arbitrary` implementation instead of the
/// default inferred bounds.
///
/// ```ignore
/// #[arbitrary(bound = "T: Default, U: Debug")]
/// ```
///
/// Multiple attributes will be combined as long as they don't conflict, e.g.
///
/// ```ignore
/// #[arbitrary(bound = "T: Default")]
/// #[arbitrary(bound = "U: Default")]
/// ```
pub bounds: Option<Vec<Punctuated<TypeParam, Token![,]>>>,
}

impl ContainerAttributes {
pub fn from_derive_input(derive_input: &DeriveInput) -> Result<Self, Error> {
let mut bounds = None;

for attr in &derive_input.attrs {
if !attr.path.is_ident(ARBITRARY_ATTRIBUTE_NAME) {
continue;
}

let meta_list = match attr.parse_meta()? {
Meta::List(l) => l,
_ => {
return Err(Error::new_spanned(
attr,
format!(
"invalid `{}` attribute. expected list",
ARBITRARY_ATTRIBUTE_NAME
),
))
}
};

for nested_meta in meta_list.nested.iter() {
match nested_meta {
NestedMeta::Meta(Meta::NameValue(MetaNameValue {
path,
lit: Lit::Str(bound_str_lit),
..
})) if path.is_ident("bound") => {
bounds
.get_or_insert_with(Vec::new)
.push(bound_str_lit.parse_with(Punctuated::parse_terminated)?);
}
_ => {
return Err(Error::new_spanned(
attr,
format!(
"invalid `{}` attribute. expected `bound = \"..\"`",
ARBITRARY_ATTRIBUTE_NAME,
),
))
}
}
}
}

Ok(Self { bounds })
}
}
4 changes: 1 addition & 3 deletions derive/src/field_attributes.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use crate::ARBITRARY_ATTRIBUTE_NAME;
use proc_macro2::{Group, Span, TokenStream, TokenTree};
use quote::quote;
use syn::{spanned::Spanned, *};

/// Used to filter out necessary field attribute and within error messages.
static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";

/// Determines how a value for a field should be constructed.
#[cfg_attr(test, derive(Debug))]
pub enum FieldConstructor {
Expand Down
59 changes: 57 additions & 2 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::*;

mod container_attributes;
mod field_attributes;
use container_attributes::ContainerAttributes;
use field_attributes::{determine_field_constructor, FieldConstructor};

static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";

#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
Expand All @@ -18,6 +21,8 @@ pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStr
}

fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
let container_attrs = ContainerAttributes::from_derive_input(&input)?;

let (lifetime_without_bounds, lifetime_with_bounds) =
build_arbitrary_lifetime(input.generics.clone());

Expand All @@ -30,8 +35,13 @@ fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
let size_hint_method = gen_size_hint_method(&input)?;
let name = input.ident;
// Add a bound `T: Arbitrary` to every type parameter T.
let generics = add_trait_bounds(input.generics, lifetime_without_bounds.clone());

// Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
let generics = apply_trait_bounds(
input.generics,
lifetime_without_bounds.clone(),
&container_attrs,
)?;

// Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90)
let mut generics_with_lifetime = generics.clone();
Expand Down Expand Up @@ -77,6 +87,51 @@ fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeDef, LifetimeDef) {
(lifetime_without_bounds, lifetime_with_bounds)
}

fn apply_trait_bounds(
mut generics: Generics,
lifetime: LifetimeDef,
container_attrs: &ContainerAttributes,
) -> Result<Generics> {
// If user-supplied bounds exist, apply them to their matching type parameters.
if let Some(config_bounds) = &container_attrs.bounds {
let mut config_bounds_applied = 0;
for param in generics.params.iter_mut() {
if let GenericParam::Type(type_param) = param {
if let Some(replacement) = config_bounds
.iter()
.flatten()
.find(|p| p.ident == type_param.ident)
{
*type_param = replacement.clone();
config_bounds_applied += 1;
} else {
// If no user-supplied bounds exist for this type, delete the original bounds.
// This mimics serde.
type_param.bounds = Default::default();
type_param.default = None;
}
}
}
let config_bounds_supplied = config_bounds
.iter()
.map(|bounds| bounds.len())
.sum::<usize>();
if config_bounds_applied != config_bounds_supplied {
return Err(Error::new(
Span::call_site(),
format!(
"invalid `{}` attribute. too many bounds, only {} out of {} are applicable",
ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied,
),
));
}
Ok(generics)
} else {
// Otherwise, inject a `T: Arbitrary` bound for every parameter.
Ok(add_trait_bounds(generics, lifetime))
}
}

// Add a bound `T: Arbitrary` to every type parameter T.
fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics {
for param in generics.params.iter_mut() {
Expand Down
51 changes: 51 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1338,5 +1338,56 @@ mod test {
/// x: i32,
/// }
/// ```
///
/// Multiple conflicting bounds at the container-level:
/// ```compile_fail
/// #[derive(::arbitrary::Arbitrary)]
/// #[arbitrary(bound = "T: Default")]
/// #[arbitrary(bound = "T: Default")]
/// struct Point<T: Default> {
/// #[arbitrary(default)]
/// x: T,
/// }
/// ```
///
/// Multiple conflicting bounds in a single bound attribute:
/// ```compile_fail
/// #[derive(::arbitrary::Arbitrary)]
/// #[arbitrary(bound = "T: Default, T: Default")]
/// struct Point<T: Default> {
/// #[arbitrary(default)]
/// x: T,
/// }
/// ```
///
/// Multiple conflicting bounds in multiple bound attributes:
/// ```compile_fail
/// #[derive(::arbitrary::Arbitrary)]
/// #[arbitrary(bound = "T: Default", bound = "T: Default")]
/// struct Point<T: Default> {
/// #[arbitrary(default)]
/// x: T,
/// }
/// ```
///
/// Too many bounds supplied:
/// ```compile_fail
/// #[derive(::arbitrary::Arbitrary)]
/// #[arbitrary(bound = "T: Default")]
/// struct Point {
/// x: i32,
/// }
/// ```
///
/// Too many bounds supplied across multiple attributes:
/// ```compile_fail
/// #[derive(::arbitrary::Arbitrary)]
/// #[arbitrary(bound = "T: Default")]
/// #[arbitrary(bound = "U: Default")]
/// struct Point<T: Default> {
/// #[arbitrary(default)]
/// x: T,
/// }
/// ```
#[cfg(all(doctest, feature = "derive"))]
pub struct CompileFailTests;
142 changes: 142 additions & 0 deletions tests/bound.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#![cfg(feature = "derive")]

use arbitrary::{Arbitrary, Unstructured};

fn arbitrary_from<'a, T: Arbitrary<'a>>(input: &'a [u8]) -> T {
let mut buf = Unstructured::new(input);
T::arbitrary(&mut buf).expect("can create arbitrary instance OK")
}

/// This wrapper trait *implies* `Arbitrary`, but the compiler isn't smart enough to work that out
/// so when using this wrapper we *must* opt-out of the auto-generated `T: Arbitrary` bounds.
pub trait WrapperTrait: for<'a> Arbitrary<'a> {}

impl WrapperTrait for u32 {}

#[derive(Arbitrary)]
#[arbitrary(bound = "T: WrapperTrait")]
struct GenericSingleBound<T: WrapperTrait> {
t: T,
}

#[test]
fn single_bound() {
let v: GenericSingleBound<u32> = arbitrary_from(&[0, 0, 0, 0]);
assert_eq!(v.t, 0);
}

#[derive(Arbitrary)]
#[arbitrary(bound = "T: WrapperTrait, U: WrapperTrait")]
struct GenericMultipleBoundsSingleAttribute<T: WrapperTrait, U: WrapperTrait> {
t: T,
u: U,
}

#[test]
fn multiple_bounds_single_attribute() {
let v: GenericMultipleBoundsSingleAttribute<u32, u32> =
arbitrary_from(&[1, 0, 0, 0, 2, 0, 0, 0]);
assert_eq!(v.t, 1);
assert_eq!(v.u, 2);
}

#[derive(Arbitrary)]
#[arbitrary(bound = "T: WrapperTrait")]
#[arbitrary(bound = "U: Default")]
struct GenericMultipleArbitraryAttributes<T: WrapperTrait, U: Default> {
t: T,
#[arbitrary(default)]
u: U,
}

#[test]
fn multiple_arbitrary_attributes() {
let v: GenericMultipleArbitraryAttributes<u32, u32> = arbitrary_from(&[1, 0, 0, 0]);
assert_eq!(v.t, 1);
assert_eq!(v.u, 0);
}

#[derive(Arbitrary)]
#[arbitrary(bound = "T: WrapperTrait", bound = "U: Default")]
struct GenericMultipleBoundAttributes<T: WrapperTrait, U: Default> {
t: T,
#[arbitrary(default)]
u: U,
}

#[test]
fn multiple_bound_attributes() {
let v: GenericMultipleBoundAttributes<u32, u32> = arbitrary_from(&[1, 0, 0, 0]);
assert_eq!(v.t, 1);
assert_eq!(v.u, 0);
}

#[derive(Arbitrary)]
#[arbitrary(bound = "T: WrapperTrait", bound = "U: Default")]
#[arbitrary(bound = "V: WrapperTrait, W: Default")]
struct GenericMultipleArbitraryAndBoundAttributes<
T: WrapperTrait,
U: Default,
V: WrapperTrait,
W: Default,
> {
t: T,
#[arbitrary(default)]
u: U,
v: V,
#[arbitrary(default)]
w: W,
}

#[test]
fn multiple_arbitrary_and_bound_attributes() {
let v: GenericMultipleArbitraryAndBoundAttributes<u32, u32, u32, u32> =
arbitrary_from(&[1, 0, 0, 0, 2, 0, 0, 0]);
assert_eq!(v.t, 1);
assert_eq!(v.u, 0);
assert_eq!(v.v, 2);
assert_eq!(v.w, 0);
}

#[derive(Arbitrary)]
#[arbitrary(bound = "T: Default")]
struct GenericDefault<T: Default> {
#[arbitrary(default)]
x: T,
}

#[test]
fn default_bound() {
// We can write a generic func without any `Arbitrary` bound.
fn generic_default<T: Default>() -> GenericDefault<T> {
arbitrary_from(&[])
}

assert_eq!(generic_default::<u64>().x, 0);
assert_eq!(generic_default::<String>().x, String::new());
assert_eq!(generic_default::<Vec<u8>>().x, Vec::new());
}

#[derive(Arbitrary)]
#[arbitrary()]
struct EmptyArbitraryAttribute {
t: u32,
}

#[test]
fn empty_arbitrary_attribute() {
let v: EmptyArbitraryAttribute = arbitrary_from(&[1, 0, 0, 0]);
assert_eq!(v.t, 1);
}

#[derive(Arbitrary)]
#[arbitrary(bound = "")]
struct EmptyBoundAttribute {
t: u32,
}

#[test]
fn empty_bound_attribute() {
let v: EmptyBoundAttribute = arbitrary_from(&[1, 0, 0, 0]);
assert_eq!(v.t, 1);
}

0 comments on commit c397cc2

Please sign in to comment.