Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow trait bounds to be manually specified #138

Merged
merged 4 commits into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ rust-version = "1.63.0"
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "1.0.56", features = ['derive'] }
syn = { version = "1.0.56", features = ['derive', 'parsing'] }

[lib]
proc_macro = true
72 changes: 72 additions & 0 deletions derive/src/container_attributes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use crate::ARBITRARY_ATTRIBUTE_NAME;
use syn::{
parse::Error, punctuated::Punctuated, DeriveInput, Lit, Meta, MetaNameValue, NestedMeta, Token,
TypeParam,
};

pub struct ContainerAttributes {
/// Specify type bounds to be applied to the derived `Arbitrary` implementation instead of the
/// default inferred bounds.
///
/// ```
/// #[arbitrary(bound = "T: Default, U: Debug")]
/// ```
///
/// Multiple attributes will be combined as long as they don't conflict, e.g.
///
/// ```
/// #[arbitrary(bound = "T: Default")]
/// #[arbitrary(bound = "U: Default")]
/// ```
pub bounds: Option<Vec<Punctuated<TypeParam, Token![,]>>>,
michaelsproul marked this conversation as resolved.
Show resolved Hide resolved
}

impl ContainerAttributes {
pub fn from_derive_input(derive_input: &DeriveInput) -> Result<Self, Error> {
let mut bounds = None;

for attr in &derive_input.attrs {
if !attr.path.is_ident(ARBITRARY_ATTRIBUTE_NAME) {
continue;
}

let meta_list = match attr.parse_meta()? {
Meta::List(l) => l,
_ => {
return Err(Error::new_spanned(
attr,
format!(
"invalid `{}` attribute. expected list",
ARBITRARY_ATTRIBUTE_NAME
),
))
}
};

for nested_meta in meta_list.nested.iter() {
match nested_meta {
NestedMeta::Meta(Meta::NameValue(MetaNameValue {
path,
lit: Lit::Str(bound_str_lit),
..
})) if path.is_ident("bound") => {
bounds
.get_or_insert_with(Vec::new)
.push(bound_str_lit.parse_with(Punctuated::parse_terminated)?);
}
_ => {
return Err(Error::new_spanned(
attr,
format!(
"invalid `{}` attribute. expected `bound = \"..\"`",
ARBITRARY_ATTRIBUTE_NAME,
),
))
}
}
}
}

Ok(Self { bounds })
}
}
4 changes: 1 addition & 3 deletions derive/src/field_attributes.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use crate::ARBITRARY_ATTRIBUTE_NAME;
use proc_macro2::{Group, Span, TokenStream, TokenTree};
use quote::quote;
use syn::{spanned::Spanned, *};

/// Used to filter out necessary field attribute and within error messages.
static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";

/// Determines how a value for a field should be constructed.
#[cfg_attr(test, derive(Debug))]
pub enum FieldConstructor {
Expand Down
59 changes: 57 additions & 2 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::*;

mod container_attributes;
mod field_attributes;
use container_attributes::ContainerAttributes;
use field_attributes::{determine_field_constructor, FieldConstructor};

static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";

#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
Expand All @@ -18,6 +21,8 @@ pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStr
}

fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
let container_attrs = ContainerAttributes::from_derive_input(&input)?;

let (lifetime_without_bounds, lifetime_with_bounds) =
build_arbitrary_lifetime(input.generics.clone());

Expand All @@ -30,8 +35,13 @@ fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
let size_hint_method = gen_size_hint_method(&input)?;
let name = input.ident;
// Add a bound `T: Arbitrary` to every type parameter T.
let generics = add_trait_bounds(input.generics, lifetime_without_bounds.clone());

// Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
let generics = apply_trait_bounds(
input.generics,
lifetime_without_bounds.clone(),
&container_attrs,
)?;

// Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90)
let mut generics_with_lifetime = generics.clone();
Expand Down Expand Up @@ -77,6 +87,51 @@ fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeDef, LifetimeDef) {
(lifetime_without_bounds, lifetime_with_bounds)
}

fn apply_trait_bounds(
mut generics: Generics,
lifetime: LifetimeDef,
container_attrs: &ContainerAttributes,
) -> Result<Generics> {
// If user-supplied bounds exist, apply them to their matching type parameters.
if let Some(config_bounds) = &container_attrs.bounds {
let mut config_bounds_applied = 0;
for param in generics.params.iter_mut() {
if let GenericParam::Type(type_param) = param {
if let Some(replacement) = config_bounds
.iter()
.flatten()
.find(|p| p.ident == type_param.ident)
{
*type_param = replacement.clone();
config_bounds_applied += 1;
} else {
// If no user-supplied bounds exist for this type, delete the original bounds.
// This mimics serde.
type_param.bounds = Default::default();
type_param.default = None;
}
}
}
let config_bounds_supplied = config_bounds
.iter()
.map(|bounds| bounds.len())
.sum::<usize>();
if config_bounds_applied != config_bounds_supplied {
return Err(Error::new(
Span::call_site(),
format!(
"invalid `{}` attribute. too many bounds, only {} out of {} are applicable",
ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied,
),
));
}
Ok(generics)
} else {
// Otherwise, inject a `T: Arbitrary` bound for every parameter.
Ok(add_trait_bounds(generics, lifetime))
}
}

// Add a bound `T: Arbitrary` to every type parameter T.
fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics {
for param in generics.params.iter_mut() {
Expand Down
51 changes: 51 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1338,5 +1338,56 @@ mod test {
/// x: i32,
/// }
/// ```
///
/// Multiple conflicting bounds at the container-level:
/// ```compile_fail
/// #[derive(::arbitrary::Arbitrary)]
/// #[arbitrary(bound = "T: Default")]
/// #[arbitrary(bound = "T: Default")]
/// struct Point<T: Default> {
/// #[arbitrary(default)]
/// x: T,
/// }
/// ```
///
/// Multiple conflicting bounds in a single bound attribute:
/// ```compile_fail
/// #[derive(::arbitrary::Arbitrary)]
/// #[arbitrary(bound = "T: Default, T: Default")]
/// struct Point<T: Default> {
/// #[arbitrary(default)]
/// x: T,
/// }
/// ```
///
/// Multiple conflicting bounds in multiple bound attributes:
/// ```compile_fail
/// #[derive(::arbitrary::Arbitrary)]
/// #[arbitrary(bound = "T: Default", bound = "T: Default")]
/// struct Point<T: Default> {
/// #[arbitrary(default)]
/// x: T,
/// }
/// ```
///
/// Too many bounds supplied:
/// ```compile_fail
/// #[derive(::arbitrary::Arbitrary)]
/// #[arbitrary(bound = "T: Default")]
/// struct Point {
/// x: i32,
/// }
/// ```
///
/// Too many bounds supplied across multiple attributes:
/// ```compile_fail
/// #[derive(::arbitrary::Arbitrary)]
/// #[arbitrary(bound = "T: Default")]
/// #[arbitrary(bound = "U: Default")]
/// struct Point<T: Default> {
/// #[arbitrary(default)]
/// x: T,
/// }
/// ```
#[cfg(all(doctest, feature = "derive"))]
pub struct CompileFailTests;
142 changes: 142 additions & 0 deletions tests/bound.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#![cfg(feature = "derive")]

use arbitrary::{Arbitrary, Unstructured};

fn arbitrary_from<'a, T: Arbitrary<'a>>(input: &'a [u8]) -> T {
let mut buf = Unstructured::new(input);
T::arbitrary(&mut buf).expect("can create arbitrary instance OK")
}

/// This wrapper trait *implies* `Arbitrary`, but the compiler isn't smart enough to work that out
/// so when using this wrapper we *must* opt-out of the auto-generated `T: Arbitrary` bounds.
pub trait WrapperTrait: for<'a> Arbitrary<'a> {}

impl WrapperTrait for u32 {}

#[derive(Arbitrary)]
#[arbitrary(bound = "T: WrapperTrait")]
struct GenericSingleBound<T: WrapperTrait> {
t: T,
}

#[test]
fn single_bound() {
let v: GenericSingleBound<u32> = arbitrary_from(&[0, 0, 0, 0]);
assert_eq!(v.t, 0);
}

#[derive(Arbitrary)]
#[arbitrary(bound = "T: WrapperTrait, U: WrapperTrait")]
struct GenericMultipleBoundsSingleAttribute<T: WrapperTrait, U: WrapperTrait> {
t: T,
u: U,
}

#[test]
fn multiple_bounds_single_attribute() {
let v: GenericMultipleBoundsSingleAttribute<u32, u32> =
arbitrary_from(&[1, 0, 0, 0, 2, 0, 0, 0]);
assert_eq!(v.t, 1);
assert_eq!(v.u, 2);
}

#[derive(Arbitrary)]
#[arbitrary(bound = "T: WrapperTrait")]
#[arbitrary(bound = "U: Default")]
struct GenericMultipleArbitraryAttributes<T: WrapperTrait, U: Default> {
t: T,
#[arbitrary(default)]
u: U,
}

#[test]
fn multiple_arbitrary_attributes() {
let v: GenericMultipleArbitraryAttributes<u32, u32> = arbitrary_from(&[1, 0, 0, 0]);
assert_eq!(v.t, 1);
assert_eq!(v.u, 0);
}

#[derive(Arbitrary)]
#[arbitrary(bound = "T: WrapperTrait", bound = "U: Default")]
struct GenericMultipleBoundAttributes<T: WrapperTrait, U: Default> {
t: T,
#[arbitrary(default)]
u: U,
}

#[test]
fn multiple_bound_attributes() {
let v: GenericMultipleBoundAttributes<u32, u32> = arbitrary_from(&[1, 0, 0, 0]);
assert_eq!(v.t, 1);
assert_eq!(v.u, 0);
}

#[derive(Arbitrary)]
#[arbitrary(bound = "T: WrapperTrait", bound = "U: Default")]
#[arbitrary(bound = "V: WrapperTrait, W: Default")]
struct GenericMultipleArbitraryAndBoundAttributes<
T: WrapperTrait,
U: Default,
V: WrapperTrait,
W: Default,
> {
t: T,
#[arbitrary(default)]
u: U,
v: V,
#[arbitrary(default)]
w: W,
}

#[test]
fn multiple_arbitrary_and_bound_attributes() {
let v: GenericMultipleArbitraryAndBoundAttributes<u32, u32, u32, u32> =
arbitrary_from(&[1, 0, 0, 0, 2, 0, 0, 0]);
assert_eq!(v.t, 1);
assert_eq!(v.u, 0);
assert_eq!(v.v, 2);
assert_eq!(v.w, 0);
}

#[derive(Arbitrary)]
#[arbitrary(bound = "T: Default")]
struct GenericDefault<T: Default> {
#[arbitrary(default)]
x: T,
}

#[test]
fn default_bound() {
// We can write a generic func without any `Arbitrary` bound.
fn generic_default<T: Default>() -> GenericDefault<T> {
arbitrary_from(&[])
}

assert_eq!(generic_default::<u64>().x, 0);
assert_eq!(generic_default::<String>().x, String::new());
assert_eq!(generic_default::<Vec<u8>>().x, Vec::new());
}

#[derive(Arbitrary)]
#[arbitrary()]
struct EmptyArbitraryAttribute {
t: u32,
}

#[test]
fn empty_arbitrary_attribute() {
let v: EmptyArbitraryAttribute = arbitrary_from(&[1, 0, 0, 0]);
assert_eq!(v.t, 1);
}

#[derive(Arbitrary)]
#[arbitrary(bound = "")]
struct EmptyBoundAttribute {
t: u32,
}

#[test]
fn empty_bound_attribute() {
let v: EmptyBoundAttribute = arbitrary_from(&[1, 0, 0, 0]);
assert_eq!(v.t, 1);
}