Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only put Display-like bounds on type variables #387

Merged
merged 10 commits into from
Jul 25, 2024
114 changes: 8 additions & 106 deletions impl/src/fmt/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ use crate::utils::{
Either, Spanning,
};

use super::{trait_name_to_attribute_name, ContainerAttributes, FmtAttribute};
use super::{
trait_name_to_attribute_name, ContainerAttributes, ContainsGenericsExt as _,
FmtAttribute,
};

/// Expands a [`fmt::Debug`] derive macro.
///
Expand All @@ -24,15 +27,15 @@ pub fn expand(input: &syn::DeriveInput, _: &str) -> syn::Result<TokenStream> {
.unwrap_or_default();
let ident = &input.ident;

let type_params: Vec<_> = input
let type_params = input
.generics
.params
.iter()
.filter_map(|p| match p {
syn::GenericParam::Type(t) => Some(&t.ident),
syn::GenericParam::Const(..) | syn::GenericParam::Lifetime(..) => None,
})
.collect();
.collect::<Vec<_>>();

let (bounds, body) = match &input.data {
syn::Data::Struct(s) => {
Expand Down Expand Up @@ -355,7 +358,7 @@ impl<'a> Expansion<'a> {
if let Some(fmt) = self.attr.fmt.as_ref() {
out.extend(fmt.bounded_types(self.fields).filter_map(
|(ty, trait_name)| {
if !self.contains_generic_param(ty) {
if !ty.contains_generics(self.type_params) {
return None;
}

Expand All @@ -369,7 +372,7 @@ impl<'a> Expansion<'a> {
self.fields.iter().try_fold(out, |mut out, field| {
let ty = &field.ty;

if !self.contains_generic_param(ty) {
if !ty.contains_generics(self.type_params) {
return Ok(out);
}

Expand All @@ -392,105 +395,4 @@ impl<'a> Expansion<'a> {
})
}
}

/// Checks whether the provided [`syn::Path`] contains any of these [`Expansion::type_params`].
fn path_contains_generic_param(&self, path: &syn::Path) -> bool {
path.segments
.iter()
.any(|segment| match &segment.arguments {
syn::PathArguments::None => false,
syn::PathArguments::AngleBracketed(
syn::AngleBracketedGenericArguments { args, .. },
) => args.iter().any(|generic| match generic {
syn::GenericArgument::Type(ty)
| syn::GenericArgument::AssocType(syn::AssocType { ty, .. }) => {
self.contains_generic_param(ty)
}

syn::GenericArgument::Lifetime(_)
| syn::GenericArgument::Const(_)
| syn::GenericArgument::AssocConst(_)
| syn::GenericArgument::Constraint(_) => false,
_ => unimplemented!(
"syntax is not supported by `derive_more`, please report a bug",
),
}),
syn::PathArguments::Parenthesized(
syn::ParenthesizedGenericArguments { inputs, output, .. },
) => {
inputs.iter().any(|ty| self.contains_generic_param(ty))
|| match output {
syn::ReturnType::Default => false,
syn::ReturnType::Type(_, ty) => {
self.contains_generic_param(ty)
}
}
}
})
}

/// Checks whether the provided [`syn::Type`] contains any of these [`Expansion::type_params`].
fn contains_generic_param(&self, ty: &syn::Type) -> bool {
if self.type_params.is_empty() {
return false;
}
match ty {
syn::Type::Path(syn::TypePath { qself, path }) => {
if let Some(qself) = qself {
if self.contains_generic_param(&qself.ty) {
return true;
}
}

if let Some(ident) = path.get_ident() {
self.type_params.iter().any(|param| *param == ident)
} else {
self.path_contains_generic_param(path)
}
}

syn::Type::Array(syn::TypeArray { elem, .. })
| syn::Type::Group(syn::TypeGroup { elem, .. })
| syn::Type::Paren(syn::TypeParen { elem, .. })
| syn::Type::Ptr(syn::TypePtr { elem, .. })
| syn::Type::Reference(syn::TypeReference { elem, .. })
| syn::Type::Slice(syn::TypeSlice { elem, .. }) => {
self.contains_generic_param(elem)
}

syn::Type::BareFn(syn::TypeBareFn { inputs, output, .. }) => {
inputs
.iter()
.any(|arg| self.contains_generic_param(&arg.ty))
|| match output {
syn::ReturnType::Default => false,
syn::ReturnType::Type(_, ty) => self.contains_generic_param(ty),
}
}
syn::Type::Tuple(syn::TypeTuple { elems, .. }) => {
elems.iter().any(|ty| self.contains_generic_param(ty))
}

syn::Type::ImplTrait(_) => false,
syn::Type::Infer(_) => false,
syn::Type::Macro(_) => false,
syn::Type::Never(_) => false,
syn::Type::TraitObject(syn::TypeTraitObject { bounds, .. }) => {
bounds.iter().any(|bound| match bound {
syn::TypeParamBound::Trait(syn::TraitBound { path, .. }) => {
self.path_contains_generic_param(path)
}
syn::TypeParamBound::Lifetime(_) => false,
syn::TypeParamBound::Verbatim(_) => false,
_ => unimplemented!(
"syntax is not supported by `derive_more`, please report a bug",
),
})
}
syn::Type::Verbatim(_) => false,
_ => unimplemented!(
"syntax is not supported by `derive_more`, please report a bug",
),
}
}
}
53 changes: 41 additions & 12 deletions impl/src/fmt/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ use syn::{ext::IdentExt as _, parse_quote, spanned::Spanned as _};

use crate::utils::{attr::ParseMultiple as _, Spanning};

use super::{trait_name_to_attribute_name, ContainerAttributes, FmtAttribute};
use super::{
trait_name_to_attribute_name, ContainerAttributes, ContainsGenericsExt as _,
FmtAttribute,
};

/// Expands a [`fmt::Display`]-like derive macro.
///
Expand All @@ -32,7 +35,17 @@ pub fn expand(input: &syn::DeriveInput, trait_name: &str) -> syn::Result<TokenSt
let trait_ident = format_ident!("{trait_name}");
let ident = &input.ident;

let ctx = (&attrs, ident, &trait_ident, &attr_name);
let type_params = input
.generics
.params
.iter()
.filter_map(|p| match p {
syn::GenericParam::Type(t) => Some(&t.ident),
syn::GenericParam::Const(..) | syn::GenericParam::Lifetime(..) => None,
})
.collect::<Vec<_>>();

let ctx: ExpansionCtx = (&attrs, &type_params, ident, &trait_ident, &attr_name);
let (bounds, body) = match &input.data {
syn::Data::Struct(s) => expand_struct(s, ctx),
syn::Data::Enum(e) => expand_enum(e, ctx),
Expand Down Expand Up @@ -62,13 +75,15 @@ pub fn expand(input: &syn::DeriveInput, trait_name: &str) -> syn::Result<TokenSt

/// Type alias for an expansion context:
/// - [`ContainerAttributes`].
/// - Type parameters. Slice of [`syn::Ident`].
/// - Struct/enum/union [`syn::Ident`].
/// - Derived trait [`syn::Ident`].
/// - Attribute name [`syn::Ident`].
///
/// [`syn::Ident`]: struct@syn::Ident
type ExpansionCtx<'a> = (
&'a ContainerAttributes,
&'a [&'a syn::Ident],
&'a syn::Ident,
&'a syn::Ident,
&'a syn::Ident,
Expand All @@ -77,12 +92,13 @@ type ExpansionCtx<'a> = (
/// Expands a [`fmt::Display`]-like derive macro for the provided struct.
fn expand_struct(
s: &syn::DataStruct,
(attrs, ident, trait_ident, _): ExpansionCtx<'_>,
(attrs, type_params, ident, trait_ident, _): ExpansionCtx<'_>,
) -> syn::Result<(Vec<syn::WherePredicate>, TokenStream)> {
let s = Expansion {
shared_attr: None,
attrs,
fields: &s.fields,
type_params,
trait_ident,
ident,
};
Expand Down Expand Up @@ -111,7 +127,7 @@ fn expand_struct(
/// Expands a [`fmt`]-like derive macro for the provided enum.
fn expand_enum(
e: &syn::DataEnum,
(container_attrs, _, trait_ident, attr_name): ExpansionCtx<'_>,
(container_attrs, type_params, _, trait_ident, attr_name): ExpansionCtx<'_>,
) -> syn::Result<(Vec<syn::WherePredicate>, TokenStream)> {
if let Some(shared_fmt) = &container_attrs.fmt {
if shared_fmt
Expand Down Expand Up @@ -153,6 +169,7 @@ fn expand_enum(
shared_attr: container_attrs.fmt.as_ref(),
attrs: &attrs,
fields: &variant.fields,
type_params,
trait_ident,
ident,
};
Expand Down Expand Up @@ -190,7 +207,7 @@ fn expand_enum(
/// Expands a [`fmt::Display`]-like derive macro for the provided union.
fn expand_union(
u: &syn::DataUnion,
(attrs, _, _, attr_name): ExpansionCtx<'_>,
(attrs, _, _, _, attr_name): ExpansionCtx<'_>,
) -> syn::Result<(Vec<syn::WherePredicate>, TokenStream)> {
let fmt = &attrs.fmt.as_ref().ok_or_else(|| {
syn::Error::new(
Expand Down Expand Up @@ -227,6 +244,9 @@ struct Expansion<'a> {
/// Struct or enum [`syn::Fields`].
fields: &'a syn::Fields,

/// Type parameters in this struct or enum.
type_params: &'a [&'a syn::Ident],

/// [`fmt`] trait [`syn::Ident`].
///
/// [`syn::Ident`]: struct@syn::Ident
Expand Down Expand Up @@ -338,28 +358,37 @@ impl<'a> Expansion<'a> {
if let Some(fmt) = &self.attrs.fmt {
bounds.extend(
fmt.bounded_types(self.fields)
.map(|(ty, trait_name)| {
.filter_map(|(ty, trait_name)| {
if !ty.contains_generics(self.type_params) {
return None;
}
let trait_ident = format_ident!("{trait_name}");

parse_quote! { #ty: derive_more::core::fmt::#trait_ident }
Some(parse_quote! { #ty: derive_more::core::fmt::#trait_ident })
})
.chain(self.attrs.bounds.0.clone()),
);
} else {
bounds.extend(self.fields.iter().next().map(|f| {
bounds.extend(self.fields.iter().next().and_then(|f| {
let ty = &f.ty;
if !ty.contains_generics(self.type_params) {
return None;
}
let trait_ident = &self.trait_ident;
parse_quote! { #ty: derive_more::core::fmt::#trait_ident }
}))
Some(parse_quote! { #ty: derive_more::core::fmt::#trait_ident })
}));
};
}

if let Some(shared_fmt) = &self.shared_attr {
bounds.extend(shared_fmt.bounded_types(self.fields).map(
bounds.extend(shared_fmt.bounded_types(self.fields).filter_map(
|(ty, trait_name)| {
if !ty.contains_generics(self.type_params) {
return None;
}
let trait_ident = format_ident!("{trait_name}");

parse_quote! { #ty: derive_more::core::fmt::#trait_ident }
Some(parse_quote! { #ty: derive_more::core::fmt::#trait_ident })
},
));
}
Expand Down
Loading