Skip to content

Commit

Permalink
Rollup merge of #108203 - compiler-errors:rpitit-fix-defaults-2, r=ja…
Browse files Browse the repository at this point in the history
…ckh726

Fix RPITITs in default trait methods (by assuming projection predicates in param-env)

Instead of having special projection logic that allows us to turn `ProjectionTy(RPITIT, [Self#0, ...])` into `OpaqueTy(RPITIT, [Self#0, ...])`, we can instead augment the param-env of default trait method bodies to assume these as projection predicates. This should allow us to only project where we're allowed to!

In order to make this work without introducing a bunch of cycle errors, we additionally tweak the `OpaqueTypeExpander` used by `ParamEnv::with_reveal_all_normalized` to not normalize the right-hand side of projection predicates. This should be fine, because if we use the projection predicate to normalize some other projection type, we'll continue to normalize the opaque that it gets projected to.

This also makes it possible to support default trait methods with RPITITs in an associated-type based RPITIT lowering strategy without too much extra effort.

Fixes #107002
Alternative to #108142
  • Loading branch information
GuillaumeGomez authored Feb 19, 2023
2 parents 243dcd0 + 3e57b20 commit d2aef58
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 37 deletions.
2 changes: 1 addition & 1 deletion compiler/rustc_hir_analysis/src/check/wfcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1599,7 +1599,7 @@ fn check_return_position_impl_trait_in_trait_bounds<'tcx>(
{
for arg in fn_output.walk() {
if let ty::GenericArgKind::Type(ty) = arg.unpack()
&& let ty::Alias(ty::Projection, proj) = ty.kind()
&& let ty::Alias(ty::Opaque, proj) = ty.kind()
&& tcx.def_kind(proj.def_id) == DefKind::ImplTraitPlaceholder
&& tcx.impl_trait_in_trait_parent(proj.def_id) == fn_def_id.to_def_id()
{
Expand Down
22 changes: 21 additions & 1 deletion compiler/rustc_middle/src/ty/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::middle::codegen_fn_attrs::CodegenFnAttrFlags;
use crate::mir;
use crate::ty::layout::IntegerExt;
use crate::ty::{
self, ir::TypeFolder, DefIdTree, FallibleTypeFolder, Ty, TyCtxt, TypeFoldable,
self, ir::TypeFolder, DefIdTree, FallibleTypeFolder, ToPredicate, Ty, TyCtxt, TypeFoldable,
TypeSuperFoldable,
};
use crate::ty::{GenericArgKind, SubstsRef};
Expand Down Expand Up @@ -865,6 +865,26 @@ impl<'tcx> TypeFolder<TyCtxt<'tcx>> for OpaqueTypeExpander<'tcx> {
}
t
}

fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
if let ty::PredicateKind::Clause(clause) = p.kind().skip_binder()
&& let ty::Clause::Projection(projection_pred) = clause
{
p.kind()
.rebind(ty::ProjectionPredicate {
projection_ty: projection_pred.projection_ty.fold_with(self),
// Don't fold the term on the RHS of the projection predicate.
// This is because for default trait methods with RPITITs, we
// install a `NormalizesTo(Projection(RPITIT) -> Opaque(RPITIT))`
// predicate, which would trivially cause a cycle when we do
// anything that requires `ParamEnv::with_reveal_all_normalized`.
term: projection_pred.term,
})
.to_predicate(self.tcx)
} else {
p.super_fold_with(self)
}
}
}

impl<'tcx> Ty<'tcx> {
Expand Down
36 changes: 3 additions & 33 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,7 @@ enum ProjectionCandidate<'tcx> {
/// From an "impl" (or a "pseudo-impl" returned by select)
Select(Selection<'tcx>),

ImplTraitInTrait(ImplTraitInTraitCandidate<'tcx>),
}

#[derive(PartialEq, Eq, Debug)]
enum ImplTraitInTraitCandidate<'tcx> {
// The `impl Trait` from a trait function's default body
Trait,
// A concrete type provided from a trait's `impl Trait` from an impl
Impl(ImplSourceUserDefinedData<'tcx, PredicateObligation<'tcx>>),
ImplTraitInTrait(ImplSourceUserDefinedData<'tcx, PredicateObligation<'tcx>>),
}

enum ProjectionCandidateSet<'tcx> {
Expand Down Expand Up @@ -1292,17 +1284,6 @@ fn assemble_candidate_for_impl_trait_in_trait<'cx, 'tcx>(
let tcx = selcx.tcx();
if tcx.def_kind(obligation.predicate.def_id) == DefKind::ImplTraitPlaceholder {
let trait_fn_def_id = tcx.impl_trait_in_trait_parent(obligation.predicate.def_id);
// If we are trying to project an RPITIT with trait's default `Self` parameter,
// then we must be within a default trait body.
if obligation.predicate.self_ty()
== ty::InternalSubsts::identity_for_item(tcx, obligation.predicate.def_id).type_at(0)
&& tcx.associated_item(trait_fn_def_id).defaultness(tcx).has_value()
{
candidate_set.push_candidate(ProjectionCandidate::ImplTraitInTrait(
ImplTraitInTraitCandidate::Trait,
));
return;
}

let trait_def_id = tcx.parent(trait_fn_def_id);
let trait_substs =
Expand All @@ -1313,9 +1294,7 @@ fn assemble_candidate_for_impl_trait_in_trait<'cx, 'tcx>(
let _ = selcx.infcx.commit_if_ok(|_| {
match selcx.select(&obligation.with(tcx, trait_predicate)) {
Ok(Some(super::ImplSource::UserDefined(data))) => {
candidate_set.push_candidate(ProjectionCandidate::ImplTraitInTrait(
ImplTraitInTraitCandidate::Impl(data),
));
candidate_set.push_candidate(ProjectionCandidate::ImplTraitInTrait(data));
Ok(())
}
Ok(None) => {
Expand Down Expand Up @@ -1777,18 +1756,9 @@ fn confirm_candidate<'cx, 'tcx>(
ProjectionCandidate::Select(impl_source) => {
confirm_select_candidate(selcx, obligation, impl_source)
}
ProjectionCandidate::ImplTraitInTrait(ImplTraitInTraitCandidate::Impl(data)) => {
ProjectionCandidate::ImplTraitInTrait(data) => {
confirm_impl_trait_in_trait_candidate(selcx, obligation, data)
}
// If we're projecting an RPITIT for a default trait body, that's just
// the same def-id, but as an opaque type (with regular RPIT semantics).
ProjectionCandidate::ImplTraitInTrait(ImplTraitInTraitCandidate::Trait) => Progress {
term: selcx
.tcx()
.mk_opaque(obligation.predicate.def_id, obligation.predicate.substs)
.into(),
obligations: vec![],
},
};

// When checking for cycle during evaluation, we compare predicates with
Expand Down
61 changes: 59 additions & 2 deletions compiler/rustc_ty_utils/src/ty.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use rustc_data_structures::fx::FxIndexSet;
use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
use rustc_hir as hir;
use rustc_hir::def::DefKind;
use rustc_index::bit_set::BitSet;
#[cfg(not(bootstrap))]
use rustc_middle::ty::ir::TypeVisitable;
use rustc_middle::ty::{
self, Binder, EarlyBinder, Predicate, PredicateKind, ToPredicate, Ty, TyCtxt,
self, ir::TypeVisitor, Binder, EarlyBinder, Predicate, PredicateKind, ToPredicate, Ty, TyCtxt,
TypeSuperVisitable,
};
use rustc_session::config::TraitSolver;
use rustc_span::def_id::{DefId, CRATE_DEF_ID};
Expand Down Expand Up @@ -136,6 +140,19 @@ fn param_env(tcx: TyCtxt<'_>, def_id: DefId) -> ty::ParamEnv<'_> {
predicates.extend(environment);
}

if tcx.def_kind(def_id) == DefKind::AssocFn
&& tcx.associated_item(def_id).container == ty::AssocItemContainer::TraitContainer
{
let sig = tcx.fn_sig(def_id).subst_identity();
sig.visit_with(&mut ImplTraitInTraitFinder {
tcx,
fn_def_id: def_id,
bound_vars: sig.bound_vars(),
predicates: &mut predicates,
seen: FxHashSet::default(),
});
}

let local_did = def_id.as_local();
let hir_id = local_did.map(|def_id| tcx.hir().local_def_id_to_hir_id(def_id));

Expand Down Expand Up @@ -222,6 +239,46 @@ fn param_env(tcx: TyCtxt<'_>, def_id: DefId) -> ty::ParamEnv<'_> {
traits::normalize_param_env_or_error(tcx, unnormalized_env, cause)
}

/// Walk through a function type, gathering all RPITITs and installing a
/// `NormalizesTo(Projection(RPITIT) -> Opaque(RPITIT))` predicate into the
/// predicates list. This allows us to observe that an RPITIT projects to
/// its corresponding opaque within the body of a default-body trait method.
struct ImplTraitInTraitFinder<'a, 'tcx> {
tcx: TyCtxt<'tcx>,
predicates: &'a mut Vec<Predicate<'tcx>>,
fn_def_id: DefId,
bound_vars: &'tcx ty::List<ty::BoundVariableKind>,
seen: FxHashSet<DefId>,
}

impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitFinder<'_, 'tcx> {
fn visit_ty(&mut self, ty: Ty<'tcx>) -> std::ops::ControlFlow<Self::BreakTy> {
if let ty::Alias(ty::Projection, alias_ty) = *ty.kind()
&& self.tcx.def_kind(alias_ty.def_id) == DefKind::ImplTraitPlaceholder
&& self.tcx.impl_trait_in_trait_parent(alias_ty.def_id) == self.fn_def_id
&& self.seen.insert(alias_ty.def_id)
{
self.predicates.push(
ty::Binder::bind_with_vars(
ty::ProjectionPredicate {
projection_ty: alias_ty,
term: self.tcx.mk_alias(ty::Opaque, alias_ty).into(),
},
self.bound_vars,
)
.to_predicate(self.tcx),
);

for bound in self.tcx.item_bounds(alias_ty.def_id).subst_iter(self.tcx, alias_ty.substs)
{
bound.visit_with(self);
}
}

ty.super_visit_with(self)
}
}

/// Elaborate the environment.
///
/// Collect a list of `Predicate`'s used for building the `ParamEnv`. Adds `TypeWellFormedFromEnv`'s
Expand Down
66 changes: 66 additions & 0 deletions tests/ui/async-await/in-trait/async-default-fn-overridden.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// run-pass
// edition:2021

#![feature(async_fn_in_trait)]
//~^ WARN the feature `async_fn_in_trait` is incomplete and may not be safe to use

use std::future::Future;

trait AsyncTrait {
async fn default_impl() {
assert!(false);
}

async fn call_default_impl() {
Self::default_impl().await
}
}

struct AsyncType;

impl AsyncTrait for AsyncType {
async fn default_impl() {
// :)
}
}

async fn async_main() {
// Should not assert false
AsyncType::call_default_impl().await;
}

// ------------------------------------------------------------------------- //
// Implementation Details Below...

use std::pin::Pin;
use std::task::*;

pub fn noop_waker() -> Waker {
let raw = RawWaker::new(std::ptr::null(), &NOOP_WAKER_VTABLE);

// SAFETY: the contracts for RawWaker and RawWakerVTable are upheld
unsafe { Waker::from_raw(raw) }
}

const NOOP_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);

unsafe fn noop_clone(_p: *const ()) -> RawWaker {
RawWaker::new(std::ptr::null(), &NOOP_WAKER_VTABLE)
}

unsafe fn noop(_p: *const ()) {}

fn main() {
let mut fut = async_main();

// Poll loop, just to test the future...
let waker = noop_waker();
let ctx = &mut Context::from_waker(&waker);

loop {
match unsafe { Pin::new_unchecked(&mut fut).poll(ctx) } {
Poll::Pending => {}
Poll::Ready(()) => break,
}
}
}
11 changes: 11 additions & 0 deletions tests/ui/async-await/in-trait/async-default-fn-overridden.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
warning: the feature `async_fn_in_trait` is incomplete and may not be safe to use and/or cause compiler crashes
--> $DIR/async-default-fn-overridden.rs:4:12
|
LL | #![feature(async_fn_in_trait)]
| ^^^^^^^^^^^^^^^^^
|
= note: see issue #91611 <https://github.com/rust-lang/rust/issues/91611> for more information
= note: `#[warn(incomplete_features)]` on by default

warning: 1 warning emitted

0 comments on commit d2aef58

Please sign in to comment.