Skip to content

Commit

Permalink
Rollup merge of #130866 - compiler-errors:dyn-instantiate-binder, r=lcnr
Browse files Browse the repository at this point in the history
Allow instantiating object trait binder when upcasting

This PR fixes two bugs (that probably need an FCP).

### We use equality rather than subtyping for upcasting dyn conversions

This code should be valid:

```rust
#![feature(trait_upcasting)]

trait Foo: for<'h> Bar<'h> {}
trait Bar<'a> {}

fn foo(x: &dyn Foo) {
    let y: &dyn Bar<'static> = x;
}
```
But instead:

```
error[E0308]: mismatched types
 --> src/lib.rs:7:32
  |
7 |     let y: &dyn Bar<'static> = x;
  |                                ^ one type is more general than the other
  |
  = note: expected existential trait ref `for<'h> Bar<'h>`
             found existential trait ref `Bar<'_>`
```

And so should this:

```rust
#![feature(trait_upcasting)]

fn foo(x: &dyn for<'h> Fn(&'h ())) {
    let y: &dyn FnOnce(&'static ()) = x;
}
```

But instead:

```
error[E0308]: mismatched types
 --> src/lib.rs:4:39
  |
4 |     let y: &dyn FnOnce(&'static ()) = x;
  |                                       ^ one type is more general than the other
  |
  = note: expected existential trait ref `for<'h> FnOnce<(&'h (),)>`
             found existential trait ref `FnOnce<(&(),)>`
```

Specifically, both of these fail because we use *equality* when comparing the supertrait to the *target* of the unsize goal. For the first example, since our supertrait is `for<'h> Bar<'h>` but our target is `Bar<'static>`, there's a higher-ranked type mismatch even though we *should* be able to instantiate that supertrait binder when upcasting. Similarly for the second example.

### New solver uses equality rather than subtyping for no-op (i.e. non-upcasting) dyn conversions

This code should be valid in the new solver, like it is with the old solver:

```rust
// -Znext-solver

fn foo<'a>(x: &mut for<'h> dyn Fn(&'h ())) {
   let _: &mut dyn Fn(&'a ()) = x;
}
```

But instead:

```
error: lifetime may not live long enough
 --> <source>:2:11
  |
1 | fn foo<'a>(x: &mut dyn for<'h> Fn(&'h ())) {
  |        -- lifetime `'a` defined here
2 |    let _: &mut dyn Fn(&'a ()) = x;
  |           ^^^^^^^^^^^^^^^^^^^ type annotation requires that `'a` must outlive `'static`
  |
  = note: requirement occurs because of a mutable reference to `dyn Fn(&())`
```

Specifically, this fails because we try to coerce `&mut dyn for<'h> Fn(&'h ())` to `&mut dyn Fn(&'a ())`, which registers an `dyn for<'h> Fn(&'h ()): dyn Fn(&'a ())` goal. This fails because the new solver uses *equating* rather than *subtyping* in `Unsize` goals.

This is *mostly* not a problem... You may wonder why the same code passes on the new solver for immutable references:

```
// -Znext-solver

fn foo<'a>(x: &dyn Fn(&())) {
   let _: &dyn Fn(&'a ()) = x; // works
}
```

That's because in this case, we first try to coerce via `Unsize`, but due to the leak check the goal fails. Then, later in coercion, we fall back to a simple subtyping operation, which *does* work.

Since `&T` is covariant over `T`, but `&mut T` is invariant, that's where the discrepancy between these two examples crops up.

---

r? lcnr or reassign :D
  • Loading branch information
matthiaskrgr authored Sep 28, 2024
2 parents 6e9db86 + d753aba commit 4e510da
Show file tree
Hide file tree
Showing 10 changed files with 228 additions and 201 deletions.
164 changes: 54 additions & 110 deletions compiler/rustc_infer/src/infer/at.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,7 @@ impl<'tcx> InferCtxt<'tcx> {
}

pub trait ToTrace<'tcx>: Relate<TyCtxt<'tcx>> + Copy {
fn to_trace(
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
b: Self,
) -> TypeTrace<'tcx>;
fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx>;
}

impl<'a, 'tcx> At<'a, 'tcx> {
Expand All @@ -116,7 +111,7 @@ impl<'a, 'tcx> At<'a, 'tcx> {
{
let mut fields = CombineFields::new(
self.infcx,
ToTrace::to_trace(self.cause, true, expected, actual),
ToTrace::to_trace(self.cause, expected, actual),
self.param_env,
define_opaque_types,
);
Expand All @@ -136,7 +131,7 @@ impl<'a, 'tcx> At<'a, 'tcx> {
{
let mut fields = CombineFields::new(
self.infcx,
ToTrace::to_trace(self.cause, true, expected, actual),
ToTrace::to_trace(self.cause, expected, actual),
self.param_env,
define_opaque_types,
);
Expand All @@ -154,12 +149,26 @@ impl<'a, 'tcx> At<'a, 'tcx> {
where
T: ToTrace<'tcx>,
{
let mut fields = CombineFields::new(
self.infcx,
ToTrace::to_trace(self.cause, true, expected, actual),
self.param_env,
self.eq_trace(
define_opaque_types,
);
ToTrace::to_trace(self.cause, expected, actual),
expected,
actual,
)
}

/// Makes `expected == actual`.
pub fn eq_trace<T>(
self,
define_opaque_types: DefineOpaqueTypes,
trace: TypeTrace<'tcx>,
expected: T,
actual: T,
) -> InferResult<'tcx, ()>
where
T: Relate<TyCtxt<'tcx>>,
{
let mut fields = CombineFields::new(self.infcx, trace, self.param_env, define_opaque_types);
fields.equate(StructurallyRelateAliases::No).relate(expected, actual)?;
Ok(InferOk {
value: (),
Expand Down Expand Up @@ -192,7 +201,7 @@ impl<'a, 'tcx> At<'a, 'tcx> {
assert!(self.infcx.next_trait_solver());
let mut fields = CombineFields::new(
self.infcx,
ToTrace::to_trace(self.cause, true, expected, actual),
ToTrace::to_trace(self.cause, expected, actual),
self.param_env,
DefineOpaqueTypes::Yes,
);
Expand Down Expand Up @@ -284,7 +293,7 @@ impl<'a, 'tcx> At<'a, 'tcx> {
{
let mut fields = CombineFields::new(
self.infcx,
ToTrace::to_trace(self.cause, true, expected, actual),
ToTrace::to_trace(self.cause, expected, actual),
self.param_env,
define_opaque_types,
);
Expand All @@ -306,7 +315,7 @@ impl<'a, 'tcx> At<'a, 'tcx> {
{
let mut fields = CombineFields::new(
self.infcx,
ToTrace::to_trace(self.cause, true, expected, actual),
ToTrace::to_trace(self.cause, expected, actual),
self.param_env,
define_opaque_types,
);
Expand All @@ -316,18 +325,13 @@ impl<'a, 'tcx> At<'a, 'tcx> {
}

impl<'tcx> ToTrace<'tcx> for ImplSubject<'tcx> {
fn to_trace(
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
b: Self,
) -> TypeTrace<'tcx> {
fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
match (a, b) {
(ImplSubject::Trait(trait_ref_a), ImplSubject::Trait(trait_ref_b)) => {
ToTrace::to_trace(cause, a_is_expected, trait_ref_a, trait_ref_b)
ToTrace::to_trace(cause, trait_ref_a, trait_ref_b)
}
(ImplSubject::Inherent(ty_a), ImplSubject::Inherent(ty_b)) => {
ToTrace::to_trace(cause, a_is_expected, ty_a, ty_b)
ToTrace::to_trace(cause, ty_a, ty_b)
}
(ImplSubject::Trait(_), ImplSubject::Inherent(_))
| (ImplSubject::Inherent(_), ImplSubject::Trait(_)) => {
Expand All @@ -338,65 +342,45 @@ impl<'tcx> ToTrace<'tcx> for ImplSubject<'tcx> {
}

impl<'tcx> ToTrace<'tcx> for Ty<'tcx> {
fn to_trace(
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
b: Self,
) -> TypeTrace<'tcx> {
fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
TypeTrace {
cause: cause.clone(),
values: ValuePairs::Terms(ExpectedFound::new(a_is_expected, a.into(), b.into())),
values: ValuePairs::Terms(ExpectedFound::new(true, a.into(), b.into())),
}
}
}

impl<'tcx> ToTrace<'tcx> for ty::Region<'tcx> {
fn to_trace(
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
b: Self,
) -> TypeTrace<'tcx> {
fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
TypeTrace {
cause: cause.clone(),
values: ValuePairs::Regions(ExpectedFound::new(a_is_expected, a, b)),
values: ValuePairs::Regions(ExpectedFound::new(true, a, b)),
}
}
}

impl<'tcx> ToTrace<'tcx> for Const<'tcx> {
fn to_trace(
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
b: Self,
) -> TypeTrace<'tcx> {
fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
TypeTrace {
cause: cause.clone(),
values: ValuePairs::Terms(ExpectedFound::new(a_is_expected, a.into(), b.into())),
values: ValuePairs::Terms(ExpectedFound::new(true, a.into(), b.into())),
}
}
}

impl<'tcx> ToTrace<'tcx> for ty::GenericArg<'tcx> {
fn to_trace(
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
b: Self,
) -> TypeTrace<'tcx> {
fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
TypeTrace {
cause: cause.clone(),
values: match (a.unpack(), b.unpack()) {
(GenericArgKind::Lifetime(a), GenericArgKind::Lifetime(b)) => {
ValuePairs::Regions(ExpectedFound::new(a_is_expected, a, b))
ValuePairs::Regions(ExpectedFound::new(true, a, b))
}
(GenericArgKind::Type(a), GenericArgKind::Type(b)) => {
ValuePairs::Terms(ExpectedFound::new(a_is_expected, a.into(), b.into()))
ValuePairs::Terms(ExpectedFound::new(true, a.into(), b.into()))
}
(GenericArgKind::Const(a), GenericArgKind::Const(b)) => {
ValuePairs::Terms(ExpectedFound::new(a_is_expected, a.into(), b.into()))
ValuePairs::Terms(ExpectedFound::new(true, a.into(), b.into()))
}

(
Expand All @@ -419,72 +403,47 @@ impl<'tcx> ToTrace<'tcx> for ty::GenericArg<'tcx> {
}

impl<'tcx> ToTrace<'tcx> for ty::Term<'tcx> {
fn to_trace(
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
b: Self,
) -> TypeTrace<'tcx> {
fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
TypeTrace {
cause: cause.clone(),
values: ValuePairs::Terms(ExpectedFound::new(a_is_expected, a, b)),
values: ValuePairs::Terms(ExpectedFound::new(true, a, b)),
}
}
}

impl<'tcx> ToTrace<'tcx> for ty::TraitRef<'tcx> {
fn to_trace(
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
b: Self,
) -> TypeTrace<'tcx> {
fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
TypeTrace {
cause: cause.clone(),
values: ValuePairs::TraitRefs(ExpectedFound::new(a_is_expected, a, b)),
values: ValuePairs::TraitRefs(ExpectedFound::new(true, a, b)),
}
}
}

impl<'tcx> ToTrace<'tcx> for ty::AliasTy<'tcx> {
fn to_trace(
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
b: Self,
) -> TypeTrace<'tcx> {
fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
TypeTrace {
cause: cause.clone(),
values: ValuePairs::Aliases(ExpectedFound::new(a_is_expected, a.into(), b.into())),
values: ValuePairs::Aliases(ExpectedFound::new(true, a.into(), b.into())),
}
}
}

impl<'tcx> ToTrace<'tcx> for ty::AliasTerm<'tcx> {
fn to_trace(
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
b: Self,
) -> TypeTrace<'tcx> {
fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
TypeTrace {
cause: cause.clone(),
values: ValuePairs::Aliases(ExpectedFound::new(a_is_expected, a, b)),
values: ValuePairs::Aliases(ExpectedFound::new(true, a, b)),
}
}
}

impl<'tcx> ToTrace<'tcx> for ty::FnSig<'tcx> {
fn to_trace(
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
b: Self,
) -> TypeTrace<'tcx> {
fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
TypeTrace {
cause: cause.clone(),
values: ValuePairs::PolySigs(ExpectedFound::new(
a_is_expected,
true,
ty::Binder::dummy(a),
ty::Binder::dummy(b),
)),
Expand All @@ -493,43 +452,28 @@ impl<'tcx> ToTrace<'tcx> for ty::FnSig<'tcx> {
}

impl<'tcx> ToTrace<'tcx> for ty::PolyFnSig<'tcx> {
fn to_trace(
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
b: Self,
) -> TypeTrace<'tcx> {
fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
TypeTrace {
cause: cause.clone(),
values: ValuePairs::PolySigs(ExpectedFound::new(a_is_expected, a, b)),
values: ValuePairs::PolySigs(ExpectedFound::new(true, a, b)),
}
}
}

impl<'tcx> ToTrace<'tcx> for ty::PolyExistentialTraitRef<'tcx> {
fn to_trace(
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
b: Self,
) -> TypeTrace<'tcx> {
fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
TypeTrace {
cause: cause.clone(),
values: ValuePairs::ExistentialTraitRef(ExpectedFound::new(a_is_expected, a, b)),
values: ValuePairs::ExistentialTraitRef(ExpectedFound::new(true, a, b)),
}
}
}

impl<'tcx> ToTrace<'tcx> for ty::PolyExistentialProjection<'tcx> {
fn to_trace(
cause: &ObligationCause<'tcx>,
a_is_expected: bool,
a: Self,
b: Self,
) -> TypeTrace<'tcx> {
fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
TypeTrace {
cause: cause.clone(),
values: ValuePairs::ExistentialProjection(ExpectedFound::new(a_is_expected, a, b)),
values: ValuePairs::ExistentialProjection(ExpectedFound::new(true, a, b)),
}
}
}
16 changes: 9 additions & 7 deletions compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,10 +448,10 @@ where
}
}
} else {
self.delegate.enter_forall(kind, |kind| {
let goal = goal.with(self.cx(), ty::Binder::dummy(kind));
self.add_goal(GoalSource::InstantiateHigherRanked, goal);
self.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
self.enter_forall(kind, |ecx, kind| {
let goal = goal.with(ecx.cx(), ty::Binder::dummy(kind));
ecx.add_goal(GoalSource::InstantiateHigherRanked, goal);
ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
})
}
}
Expand Down Expand Up @@ -840,12 +840,14 @@ where
self.delegate.instantiate_binder_with_infer(value)
}

/// `enter_forall`, but takes `&mut self` and passes it back through the
/// callback since it can't be aliased during the call.
pub(super) fn enter_forall<T: TypeFoldable<I> + Copy, U>(
&self,
&mut self,
value: ty::Binder<I, T>,
f: impl FnOnce(T) -> U,
f: impl FnOnce(&mut Self, T) -> U,
) -> U {
self.delegate.enter_forall(value, f)
self.delegate.enter_forall(value, |value| f(self, value))
}

pub(super) fn resolve_vars_if_possible<T>(&self, value: T) -> T
Expand Down
Loading

0 comments on commit 4e510da

Please sign in to comment.