diff --git a/compiler/rustc_hir_analysis/src/check/wfcheck.rs b/compiler/rustc_hir_analysis/src/check/wfcheck.rs index 66c3904af963b..5743f086f89b4 100644 --- a/compiler/rustc_hir_analysis/src/check/wfcheck.rs +++ b/compiler/rustc_hir_analysis/src/check/wfcheck.rs @@ -1599,7 +1599,7 @@ fn check_return_position_impl_trait_in_trait_bounds<'tcx>( { for arg in fn_output.walk() { if let ty::GenericArgKind::Type(ty) = arg.unpack() - && let ty::Alias(ty::Projection, proj) = ty.kind() + && let ty::Alias(ty::Opaque, proj) = ty.kind() && tcx.def_kind(proj.def_id) == DefKind::ImplTraitPlaceholder && tcx.impl_trait_in_trait_parent(proj.def_id) == fn_def_id.to_def_id() { diff --git a/compiler/rustc_middle/src/ty/util.rs b/compiler/rustc_middle/src/ty/util.rs index a34ee1a99a178..ca46cf29919f8 100644 --- a/compiler/rustc_middle/src/ty/util.rs +++ b/compiler/rustc_middle/src/ty/util.rs @@ -4,7 +4,7 @@ use crate::middle::codegen_fn_attrs::CodegenFnAttrFlags; use crate::mir; use crate::ty::layout::IntegerExt; use crate::ty::{ - self, ir::TypeFolder, DefIdTree, FallibleTypeFolder, Ty, TyCtxt, TypeFoldable, + self, ir::TypeFolder, DefIdTree, FallibleTypeFolder, ToPredicate, Ty, TyCtxt, TypeFoldable, TypeSuperFoldable, }; use crate::ty::{GenericArgKind, SubstsRef}; @@ -865,6 +865,26 @@ impl<'tcx> TypeFolder> for OpaqueTypeExpander<'tcx> { } t } + + fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> { + if let ty::PredicateKind::Clause(clause) = p.kind().skip_binder() + && let ty::Clause::Projection(projection_pred) = clause + { + p.kind() + .rebind(ty::ProjectionPredicate { + projection_ty: projection_pred.projection_ty.fold_with(self), + // Don't fold the term on the RHS of the projection predicate. + // This is because for default trait methods with RPITITs, we + // install a `NormalizesTo(Projection(RPITIT) -> Opaque(RPITIT))` + // predicate, which would trivially cause a cycle when we do + // anything that requires `ParamEnv::with_reveal_all_normalized`. + term: projection_pred.term, + }) + .to_predicate(self.tcx) + } else { + p.super_fold_with(self) + } + } } impl<'tcx> Ty<'tcx> { diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs index 9b3249e58e8db..1c66fb257ebb5 100644 --- a/compiler/rustc_trait_selection/src/traits/project.rs +++ b/compiler/rustc_trait_selection/src/traits/project.rs @@ -90,15 +90,7 @@ enum ProjectionCandidate<'tcx> { /// From an "impl" (or a "pseudo-impl" returned by select) Select(Selection<'tcx>), - ImplTraitInTrait(ImplTraitInTraitCandidate<'tcx>), -} - -#[derive(PartialEq, Eq, Debug)] -enum ImplTraitInTraitCandidate<'tcx> { - // The `impl Trait` from a trait function's default body - Trait, - // A concrete type provided from a trait's `impl Trait` from an impl - Impl(ImplSourceUserDefinedData<'tcx, PredicateObligation<'tcx>>), + ImplTraitInTrait(ImplSourceUserDefinedData<'tcx, PredicateObligation<'tcx>>), } enum ProjectionCandidateSet<'tcx> { @@ -1292,17 +1284,6 @@ fn assemble_candidate_for_impl_trait_in_trait<'cx, 'tcx>( let tcx = selcx.tcx(); if tcx.def_kind(obligation.predicate.def_id) == DefKind::ImplTraitPlaceholder { let trait_fn_def_id = tcx.impl_trait_in_trait_parent(obligation.predicate.def_id); - // If we are trying to project an RPITIT with trait's default `Self` parameter, - // then we must be within a default trait body. - if obligation.predicate.self_ty() - == ty::InternalSubsts::identity_for_item(tcx, obligation.predicate.def_id).type_at(0) - && tcx.associated_item(trait_fn_def_id).defaultness(tcx).has_value() - { - candidate_set.push_candidate(ProjectionCandidate::ImplTraitInTrait( - ImplTraitInTraitCandidate::Trait, - )); - return; - } let trait_def_id = tcx.parent(trait_fn_def_id); let trait_substs = @@ -1313,9 +1294,7 @@ fn assemble_candidate_for_impl_trait_in_trait<'cx, 'tcx>( let _ = selcx.infcx.commit_if_ok(|_| { match selcx.select(&obligation.with(tcx, trait_predicate)) { Ok(Some(super::ImplSource::UserDefined(data))) => { - candidate_set.push_candidate(ProjectionCandidate::ImplTraitInTrait( - ImplTraitInTraitCandidate::Impl(data), - )); + candidate_set.push_candidate(ProjectionCandidate::ImplTraitInTrait(data)); Ok(()) } Ok(None) => { @@ -1777,18 +1756,9 @@ fn confirm_candidate<'cx, 'tcx>( ProjectionCandidate::Select(impl_source) => { confirm_select_candidate(selcx, obligation, impl_source) } - ProjectionCandidate::ImplTraitInTrait(ImplTraitInTraitCandidate::Impl(data)) => { + ProjectionCandidate::ImplTraitInTrait(data) => { confirm_impl_trait_in_trait_candidate(selcx, obligation, data) } - // If we're projecting an RPITIT for a default trait body, that's just - // the same def-id, but as an opaque type (with regular RPIT semantics). - ProjectionCandidate::ImplTraitInTrait(ImplTraitInTraitCandidate::Trait) => Progress { - term: selcx - .tcx() - .mk_opaque(obligation.predicate.def_id, obligation.predicate.substs) - .into(), - obligations: vec![], - }, }; // When checking for cycle during evaluation, we compare predicates with diff --git a/compiler/rustc_ty_utils/src/ty.rs b/compiler/rustc_ty_utils/src/ty.rs index 2c50b766d21f6..f1af0073e4da9 100644 --- a/compiler/rustc_ty_utils/src/ty.rs +++ b/compiler/rustc_ty_utils/src/ty.rs @@ -1,8 +1,12 @@ -use rustc_data_structures::fx::FxIndexSet; +use rustc_data_structures::fx::{FxHashSet, FxIndexSet}; use rustc_hir as hir; +use rustc_hir::def::DefKind; use rustc_index::bit_set::BitSet; +#[cfg(not(bootstrap))] +use rustc_middle::ty::ir::TypeVisitable; use rustc_middle::ty::{ - self, Binder, EarlyBinder, Predicate, PredicateKind, ToPredicate, Ty, TyCtxt, + self, ir::TypeVisitor, Binder, EarlyBinder, Predicate, PredicateKind, ToPredicate, Ty, TyCtxt, + TypeSuperVisitable, }; use rustc_session::config::TraitSolver; use rustc_span::def_id::{DefId, CRATE_DEF_ID}; @@ -136,6 +140,19 @@ fn param_env(tcx: TyCtxt<'_>, def_id: DefId) -> ty::ParamEnv<'_> { predicates.extend(environment); } + if tcx.def_kind(def_id) == DefKind::AssocFn + && tcx.associated_item(def_id).container == ty::AssocItemContainer::TraitContainer + { + let sig = tcx.fn_sig(def_id).subst_identity(); + sig.visit_with(&mut ImplTraitInTraitFinder { + tcx, + fn_def_id: def_id, + bound_vars: sig.bound_vars(), + predicates: &mut predicates, + seen: FxHashSet::default(), + }); + } + let local_did = def_id.as_local(); let hir_id = local_did.map(|def_id| tcx.hir().local_def_id_to_hir_id(def_id)); @@ -222,6 +239,46 @@ fn param_env(tcx: TyCtxt<'_>, def_id: DefId) -> ty::ParamEnv<'_> { traits::normalize_param_env_or_error(tcx, unnormalized_env, cause) } +/// Walk through a function type, gathering all RPITITs and installing a +/// `NormalizesTo(Projection(RPITIT) -> Opaque(RPITIT))` predicate into the +/// predicates list. This allows us to observe that an RPITIT projects to +/// its corresponding opaque within the body of a default-body trait method. +struct ImplTraitInTraitFinder<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + predicates: &'a mut Vec>, + fn_def_id: DefId, + bound_vars: &'tcx ty::List, + seen: FxHashSet, +} + +impl<'tcx> TypeVisitor> for ImplTraitInTraitFinder<'_, 'tcx> { + fn visit_ty(&mut self, ty: Ty<'tcx>) -> std::ops::ControlFlow { + if let ty::Alias(ty::Projection, alias_ty) = *ty.kind() + && self.tcx.def_kind(alias_ty.def_id) == DefKind::ImplTraitPlaceholder + && self.tcx.impl_trait_in_trait_parent(alias_ty.def_id) == self.fn_def_id + && self.seen.insert(alias_ty.def_id) + { + self.predicates.push( + ty::Binder::bind_with_vars( + ty::ProjectionPredicate { + projection_ty: alias_ty, + term: self.tcx.mk_alias(ty::Opaque, alias_ty).into(), + }, + self.bound_vars, + ) + .to_predicate(self.tcx), + ); + + for bound in self.tcx.item_bounds(alias_ty.def_id).subst_iter(self.tcx, alias_ty.substs) + { + bound.visit_with(self); + } + } + + ty.super_visit_with(self) + } +} + /// Elaborate the environment. /// /// Collect a list of `Predicate`'s used for building the `ParamEnv`. Adds `TypeWellFormedFromEnv`'s diff --git a/tests/ui/async-await/in-trait/async-default-fn-overridden.rs b/tests/ui/async-await/in-trait/async-default-fn-overridden.rs new file mode 100644 index 0000000000000..0fd1a2703db99 --- /dev/null +++ b/tests/ui/async-await/in-trait/async-default-fn-overridden.rs @@ -0,0 +1,66 @@ +// run-pass +// edition:2021 + +#![feature(async_fn_in_trait)] +//~^ WARN the feature `async_fn_in_trait` is incomplete and may not be safe to use + +use std::future::Future; + +trait AsyncTrait { + async fn default_impl() { + assert!(false); + } + + async fn call_default_impl() { + Self::default_impl().await + } +} + +struct AsyncType; + +impl AsyncTrait for AsyncType { + async fn default_impl() { + // :) + } +} + +async fn async_main() { + // Should not assert false + AsyncType::call_default_impl().await; +} + +// ------------------------------------------------------------------------- // +// Implementation Details Below... + +use std::pin::Pin; +use std::task::*; + +pub fn noop_waker() -> Waker { + let raw = RawWaker::new(std::ptr::null(), &NOOP_WAKER_VTABLE); + + // SAFETY: the contracts for RawWaker and RawWakerVTable are upheld + unsafe { Waker::from_raw(raw) } +} + +const NOOP_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop); + +unsafe fn noop_clone(_p: *const ()) -> RawWaker { + RawWaker::new(std::ptr::null(), &NOOP_WAKER_VTABLE) +} + +unsafe fn noop(_p: *const ()) {} + +fn main() { + let mut fut = async_main(); + + // Poll loop, just to test the future... + let waker = noop_waker(); + let ctx = &mut Context::from_waker(&waker); + + loop { + match unsafe { Pin::new_unchecked(&mut fut).poll(ctx) } { + Poll::Pending => {} + Poll::Ready(()) => break, + } + } +} diff --git a/tests/ui/async-await/in-trait/async-default-fn-overridden.stderr b/tests/ui/async-await/in-trait/async-default-fn-overridden.stderr new file mode 100644 index 0000000000000..61a826258d09f --- /dev/null +++ b/tests/ui/async-await/in-trait/async-default-fn-overridden.stderr @@ -0,0 +1,11 @@ +warning: the feature `async_fn_in_trait` is incomplete and may not be safe to use and/or cause compiler crashes + --> $DIR/async-default-fn-overridden.rs:4:12 + | +LL | #![feature(async_fn_in_trait)] + | ^^^^^^^^^^^^^^^^^ + | + = note: see issue #91611 for more information + = note: `#[warn(incomplete_features)]` on by default + +warning: 1 warning emitted +