From ba5b4eb01292be56c57aefa18cdd302dfb0b65d9 Mon Sep 17 00:00:00 2001 From: Ding Xiang Fei Date: Sat, 13 Jul 2024 20:14:22 +0800 Subject: [PATCH] derive(SmartPointer): rewrite bounds in where and generic bounds --- .../src/deriving/smart_ptr.rs | 219 +++++++++++++++++- .../smart-pointer-bounds-issue-127647.rs | 78 +++++++ 2 files changed, 286 insertions(+), 11 deletions(-) create mode 100644 tests/ui/deriving/smart-pointer-bounds-issue-127647.rs diff --git a/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs b/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs index bbc7cd3962720..0c2777a105db4 100644 --- a/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs +++ b/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs @@ -1,21 +1,29 @@ use std::mem::swap; use ast::HasAttrs; +use rustc_ast::mut_visit::MutVisitor; use rustc_ast::{ self as ast, GenericArg, GenericBound, GenericParamKind, ItemKind, MetaItem, TraitBoundModifiers, VariantData, }; use rustc_attr as attr; +use rustc_data_structures::flat_map_in_place::FlatMapInPlace; use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_span::symbol::{sym, Ident}; -use rustc_span::Span; +use rustc_span::{Span, Symbol}; use smallvec::{smallvec, SmallVec}; use thin_vec::{thin_vec, ThinVec}; +type AstTy = ast::ptr::P; + macro_rules! path { ($span:expr, $($part:ident)::*) => { vec![$(Ident::new(sym::$part, $span),)*] } } +macro_rules! symbols { + ($($part:ident)::*) => { [$(sym::$part),*] } +} + pub fn expand_deriving_smart_ptr( cx: &ExtCtxt<'_>, span: Span, @@ -143,8 +151,11 @@ pub fn expand_deriving_smart_ptr( // Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it. let mut impl_generics = generics.clone(); + let pointee_ty_ident = generics.params[pointee_param_idx].ident; + let mut self_bounds; { let p = &mut impl_generics.params[pointee_param_idx]; + self_bounds = p.bounds.clone(); let arg = GenericArg::Type(s_ty.clone()); let unsize = cx.path_all(span, true, path!(span, core::marker::Unsize), vec![arg]); p.bounds.push(cx.trait_bound(unsize, false)); @@ -152,22 +163,208 @@ pub fn expand_deriving_smart_ptr( swap(&mut p.attrs, &mut attrs); p.attrs = attrs.into_iter().filter(|attr| !attr.has_name(sym::pointee)).collect(); } + // We should not set default values to constant generic parameters + // and commute bounds that indirectly involves `#[pointee]`. + for (params, orig_params) in impl_generics.params[pointee_param_idx + 1..] + .iter_mut() + .zip(&generics.params[pointee_param_idx + 1..]) + { + if let ast::GenericParamKind::Const { default, .. } = &mut params.kind { + *default = None; + } + for bound in &orig_params.bounds { + let mut bound = bound.clone(); + let mut substitution = TypeSubstitution { + from_name: pointee_ty_ident.name, + to_ty: &s_ty, + rewritten: false, + }; + substitution.visit_param_bound(&mut bound); + if substitution.rewritten { + params.bounds.push(bound); + } + } + } // Add the `__S: ?Sized` extra parameter to the impl block. + // We should also commute the bounds from `#[pointee]` to `__S` as required by `Unsize<__S>`. let sized = cx.path_global(span, path!(span, core::marker::Sized)); - let bound = GenericBound::Trait( - cx.poly_trait_ref(span, sized), - TraitBoundModifiers { - polarity: ast::BoundPolarity::Maybe(span), - constness: ast::BoundConstness::Never, - asyncness: ast::BoundAsyncness::Normal, - }, - ); - let extra_param = cx.typaram(span, Ident::new(sym::__S, span), vec![bound], None); - impl_generics.params.push(extra_param); + if self_bounds.iter().all(|bound| { + if let GenericBound::Trait( + trait_ref, + TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. }, + ) = bound + { + !is_sized_marker(&trait_ref.trait_ref.path) + } else { + false + } + }) { + self_bounds.push(GenericBound::Trait( + cx.poly_trait_ref(span, sized), + TraitBoundModifiers { + polarity: ast::BoundPolarity::Maybe(span), + constness: ast::BoundConstness::Never, + asyncness: ast::BoundAsyncness::Normal, + }, + )); + } + { + let mut substitution = + TypeSubstitution { from_name: pointee_ty_ident.name, to_ty: &s_ty, rewritten: false }; + for bound in &mut self_bounds { + substitution.visit_param_bound(bound); + } + } + + // We should also commute the where bounds from `#[pointee]` to `__S` + // as well as any bound that indirectly involves the `#[pointee]` type. + for bound in &generics.where_clause.predicates { + if let ast::WherePredicate::BoundPredicate(bound) = bound { + let bound_on_pointee = bound + .bounded_ty + .kind + .is_simple_path() + .map_or(false, |name| name == pointee_ty_ident.name); + + let bounds: Vec<_> = bound + .bounds + .iter() + .filter(|bound| { + if let GenericBound::Trait( + trait_ref, + TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. }, + ) = bound + { + !bound_on_pointee || !is_sized_marker(&trait_ref.trait_ref.path) + } else { + true + } + }) + .cloned() + .collect(); + let mut substitution = TypeSubstitution { + from_name: pointee_ty_ident.name, + to_ty: &s_ty, + rewritten: bounds.len() != bound.bounds.len(), + }; + let mut predicate = ast::WherePredicate::BoundPredicate(ast::WhereBoundPredicate { + span: bound.span, + bound_generic_params: bound.bound_generic_params.clone(), + bounded_ty: bound.bounded_ty.clone(), + bounds, + }); + substitution.visit_where_predicate(&mut predicate); + if substitution.rewritten { + impl_generics.where_clause.predicates.push(predicate); + } + } + } + + let extra_param = cx.typaram(span, Ident::new(sym::__S, span), self_bounds, None); + impl_generics.params.insert(pointee_param_idx + 1, extra_param); // Add the impl blocks for `DispatchFromDyn` and `CoerceUnsized`. let gen_args = vec![GenericArg::Type(alt_self_type.clone())]; add_impl_block(impl_generics.clone(), sym::DispatchFromDyn, gen_args.clone()); add_impl_block(impl_generics.clone(), sym::CoerceUnsized, gen_args.clone()); } + +fn is_sized_marker(path: &ast::Path) -> bool { + const CORE_UNSIZE: [Symbol; 3] = symbols!(core::marker::Sized); + const STD_UNSIZE: [Symbol; 3] = symbols!(std::marker::Sized); + if path.segments.len() == 3 { + path.segments.iter().zip(CORE_UNSIZE).all(|(segment, symbol)| segment.ident.name == symbol) + || path + .segments + .iter() + .zip(STD_UNSIZE) + .all(|(segment, symbol)| segment.ident.name == symbol) + } else { + *path == sym::Sized + } +} + +struct TypeSubstitution<'a> { + from_name: Symbol, + to_ty: &'a AstTy, + rewritten: bool, +} + +impl<'a> ast::mut_visit::MutVisitor for TypeSubstitution<'a> { + fn visit_ty(&mut self, ty: &mut AstTy) { + if let Some(name) = ty.kind.is_simple_path() + && name == self.from_name + { + *ty = self.to_ty.clone(); + self.rewritten = true; + return; + } + match &mut ty.kind { + ast::TyKind::Slice(_) + | ast::TyKind::Array(_, _) + | ast::TyKind::Ptr(_) + | ast::TyKind::Ref(_, _) + | ast::TyKind::BareFn(_) + | ast::TyKind::Never + | ast::TyKind::Tup(_) + | ast::TyKind::AnonStruct(_, _) + | ast::TyKind::AnonUnion(_, _) + | ast::TyKind::Path(_, _) + | ast::TyKind::TraitObject(_, _) + | ast::TyKind::ImplTrait(_, _) + | ast::TyKind::Paren(_) + | ast::TyKind::Typeof(_) + | ast::TyKind::Infer + | ast::TyKind::MacCall(_) + | ast::TyKind::Pat(_, _) => ast::mut_visit::noop_visit_ty(ty, self), + ast::TyKind::ImplicitSelf + | ast::TyKind::CVarArgs + | ast::TyKind::Dummy + | ast::TyKind::Err(_) => {} + } + } + + fn visit_param_bound(&mut self, bound: &mut GenericBound) { + match bound { + GenericBound::Trait(trait_ref, _) => { + if trait_ref + .bound_generic_params + .iter() + .any(|param| param.ident.name == self.from_name) + { + return; + } + self.visit_poly_trait_ref(trait_ref); + } + + GenericBound::Use(args, _span) => { + for arg in args { + self.visit_precise_capturing_arg(arg); + } + } + GenericBound::Outlives(_) => {} + } + } + + fn visit_where_predicate(&mut self, where_predicate: &mut ast::WherePredicate) { + match where_predicate { + rustc_ast::WherePredicate::BoundPredicate(bound) => { + if bound.bound_generic_params.iter().any(|param| param.ident.name == self.from_name) + { + // Name is shadowed so we must skip the rest + return; + } + bound + .bound_generic_params + .flat_map_in_place(|param| self.flat_map_generic_param(param)); + self.visit_ty(&mut bound.bounded_ty); + for bound in &mut bound.bounds { + self.visit_param_bound(bound) + } + } + rustc_ast::WherePredicate::RegionPredicate(_) + | rustc_ast::WherePredicate::EqPredicate(_) => {} + } + } +} diff --git a/tests/ui/deriving/smart-pointer-bounds-issue-127647.rs b/tests/ui/deriving/smart-pointer-bounds-issue-127647.rs new file mode 100644 index 0000000000000..0d4f765aade2c --- /dev/null +++ b/tests/ui/deriving/smart-pointer-bounds-issue-127647.rs @@ -0,0 +1,78 @@ +//@ check-pass + +#![feature(derive_smart_pointer)] + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr<'a, #[pointee] T: OnDrop + ?Sized, X> { + data: &'a mut T, + x: core::marker::PhantomData, +} + +pub trait OnDrop { + fn on_drop(&mut self); +} + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr2<'a, #[pointee] T: ?Sized, X> +where + T: OnDrop, +{ + data: &'a mut T, + x: core::marker::PhantomData, +} + +pub trait MyTrait {} + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr3<'a, #[pointee] T: ?Sized, X> +where + T: MyTrait, +{ + data: &'a mut T, + x: core::marker::PhantomData, +} + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr4<'a, #[pointee] T: MyTrait + ?Sized, X> { + data: &'a mut T, + x: core::marker::PhantomData, +} + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr5<'a, #[pointee] T: ?Sized, X> +where + Ptr5Companion: MyTrait, + Ptr5Companion2: MyTrait, +{ + data: &'a mut T, + x: core::marker::PhantomData, +} + +pub struct Ptr5Companion(core::marker::PhantomData); +pub struct Ptr5Companion2; + +#[derive(core::marker::SmartPointer)] +#[repr(transparent)] +pub struct Ptr6<'a, #[pointee] T: ?Sized, X: MyTrait> { + data: &'a mut T, + x: core::marker::PhantomData, +} + +// a reduced example from https://lore.kernel.org/all/20240402-linked-list-v1-1-b1c59ba7ae3b@google.com/ +#[repr(transparent)] +#[derive(core::marker::SmartPointer)] +pub struct ListArc<#[pointee] T, const ID: u64 = 0> +where + T: ListArcSafe + ?Sized, +{ + arc: *const T, +} + +pub trait ListArcSafe {} + +fn main() {}