From f3d32f2f0cd03885686470c48250bd6773c1b9aa Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Mon, 5 Feb 2024 18:33:41 +0000 Subject: [PATCH 1/7] Flatten confirmation logic --- .../src/solve/assembly/structural_traits.rs | 105 +++++++------ .../src/traits/project.rs | 138 +++++++----------- 2 files changed, 107 insertions(+), 136 deletions(-) diff --git a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs index d02578c484649..4fd9a29c0b2e3 100644 --- a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs +++ b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs @@ -318,34 +318,27 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc self_ty: Ty<'tcx>, goal_kind: ty::ClosureKind, env_region: ty::Region<'tcx>, -) -> Result< - (ty::Binder<'tcx, (Ty<'tcx>, Ty<'tcx>, Ty<'tcx>)>, Option>), - NoSolution, -> { +) -> Result<(ty::Binder<'tcx, (Ty<'tcx>, Ty<'tcx>, Ty<'tcx>)>, Vec>), NoSolution> +{ match *self_ty.kind() { ty::CoroutineClosure(def_id, args) => { let args = args.as_coroutine_closure(); let kind_ty = args.kind_ty(); - - if let Some(closure_kind) = kind_ty.to_opt_closure_kind() { + let sig = args.coroutine_closure_sig().skip_binder(); + let mut nested = vec![]; + let coroutine_ty = if let Some(closure_kind) = kind_ty.to_opt_closure_kind() { if !closure_kind.extends(goal_kind) { return Err(NoSolution); } - Ok(( - args.coroutine_closure_sig().map_bound(|sig| { - let coroutine_ty = sig.to_coroutine_given_kind_and_upvars( - tcx, - args.parent_args(), - tcx.coroutine_for_closure(def_id), - goal_kind, - env_region, - args.tupled_upvars_ty(), - args.coroutine_captures_by_ref_ty(), - ); - (sig.tupled_inputs_ty, sig.return_ty, coroutine_ty) - }), - None, - )) + sig.to_coroutine_given_kind_and_upvars( + tcx, + args.parent_args(), + tcx.coroutine_for_closure(def_id), + goal_kind, + env_region, + args.tupled_upvars_ty(), + args.coroutine_captures_by_ref_ty(), + ) } else { let async_fn_kind_trait_def_id = tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); @@ -362,39 +355,43 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc // the goal kind <= the closure kind. As a projection `AsyncFnKindHelper::Upvars` // will project to the right upvars for the generator, appending the inputs and // coroutine upvars respecting the closure kind. - Ok(( - args.coroutine_closure_sig().map_bound(|sig| { - let tupled_upvars_ty = Ty::new_projection( - tcx, - upvars_projection_def_id, - [ - ty::GenericArg::from(kind_ty), - Ty::from_closure_kind(tcx, goal_kind).into(), - env_region.into(), - sig.tupled_inputs_ty.into(), - args.tupled_upvars_ty().into(), - args.coroutine_captures_by_ref_ty().into(), - ], - ); - let coroutine_ty = sig.to_coroutine( - tcx, - args.parent_args(), - Ty::from_closure_kind(tcx, goal_kind), - tcx.coroutine_for_closure(def_id), - tupled_upvars_ty, - ); - (sig.tupled_inputs_ty, sig.return_ty, coroutine_ty) - }), - Some( - ty::TraitRef::new( - tcx, - async_fn_kind_trait_def_id, - [kind_ty, Ty::from_closure_kind(tcx, goal_kind)], - ) - .to_predicate(tcx), - ), - )) - } + nested.push( + ty::TraitRef::new( + tcx, + async_fn_kind_trait_def_id, + [kind_ty, Ty::from_closure_kind(tcx, goal_kind)], + ) + .to_predicate(tcx), + ); + let tupled_upvars_ty = Ty::new_projection( + tcx, + upvars_projection_def_id, + [ + ty::GenericArg::from(kind_ty), + Ty::from_closure_kind(tcx, goal_kind).into(), + env_region.into(), + sig.tupled_inputs_ty.into(), + args.tupled_upvars_ty().into(), + args.coroutine_captures_by_ref_ty().into(), + ], + ); + sig.to_coroutine( + tcx, + args.parent_args(), + Ty::from_closure_kind(tcx, goal_kind), + tcx.coroutine_for_closure(def_id), + tupled_upvars_ty, + ) + }; + + Ok(( + args.coroutine_closure_sig().rebind(( + sig.tupled_inputs_ty, + sig.return_ty, + coroutine_ty, + )), + nested, + )) } ty::FnDef(..) | ty::FnPtr(..) | ty::Closure(..) => Err(NoSolution), diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs index 955c81eee6be3..88c28761d25f4 100644 --- a/compiler/rustc_trait_selection/src/traits/project.rs +++ b/compiler/rustc_trait_selection/src/traits/project.rs @@ -2446,8 +2446,9 @@ fn confirm_callable_candidate<'cx, 'tcx>( fn confirm_async_closure_candidate<'cx, 'tcx>( selcx: &mut SelectionContext<'cx, 'tcx>, obligation: &ProjectionTyObligation<'tcx>, - mut nested: Vec>, + nested: Vec>, ) -> Progress<'tcx> { + let tcx = selcx.tcx(); let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty()); let ty::CoroutineClosure(def_id, args) = *self_ty.kind() else { unreachable!( @@ -2456,76 +2457,48 @@ fn confirm_async_closure_candidate<'cx, 'tcx>( }; let args = args.as_coroutine_closure(); let kind_ty = args.kind_ty(); + let sig = args.coroutine_closure_sig().skip_binder(); - let tcx = selcx.tcx(); let goal_kind = tcx.async_fn_trait_kind_from_def_id(obligation.predicate.trait_def_id(tcx)).unwrap(); - - let async_fn_kind_helper_trait_def_id = - tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); - nested.push(obligation.with( - tcx, - ty::TraitRef::new( - tcx, - async_fn_kind_helper_trait_def_id, - [kind_ty, Ty::from_closure_kind(tcx, goal_kind)], - ), - )); - let env_region = match goal_kind { ty::ClosureKind::Fn | ty::ClosureKind::FnMut => obligation.predicate.args.region_at(2), ty::ClosureKind::FnOnce => tcx.lifetimes.re_static, }; - let upvars_projection_def_id = tcx - .associated_items(async_fn_kind_helper_trait_def_id) - .filter_by_name_unhygienic(sym::Upvars) - .next() - .unwrap() - .def_id; - - // FIXME(async_closures): Confirmation is kind of a mess here. Ideally, - // we'd short-circuit when we know that the goal_kind >= closure_kind, and not - // register a nested predicate or create a new projection ty here. But I'm too - // lazy to make this more efficient atm, and we can always tweak it later, - // since all this does is make the solver do more work. - // - // The code duplication due to the different length args is kind of weird, too. - // - // See the logic in `structural_traits` in the new solver to understand a bit - // more clearly how this *should* look. - let poly_cache_entry = args.coroutine_closure_sig().map_bound(|sig| { - let (projection_ty, term) = match tcx.item_name(obligation.predicate.def_id) { - sym::CallOnceFuture => { - let tupled_upvars_ty = Ty::new_projection( - tcx, - upvars_projection_def_id, - [ - ty::GenericArg::from(kind_ty), - Ty::from_closure_kind(tcx, goal_kind).into(), - env_region.into(), - sig.tupled_inputs_ty.into(), - args.tupled_upvars_ty().into(), - args.coroutine_captures_by_ref_ty().into(), - ], - ); - let coroutine_ty = sig.to_coroutine( + let item_name = tcx.item_name(obligation.predicate.def_id); + let term = match item_name { + sym::CallOnceFuture | sym::CallMutFuture | sym::CallFuture => { + if let Some(closure_kind) = kind_ty.to_opt_closure_kind() { + if !closure_kind.extends(goal_kind) { + bug!("we should not be confirming if the closure kind is not met"); + } + sig.to_coroutine_given_kind_and_upvars( tcx, args.parent_args(), - Ty::from_closure_kind(tcx, goal_kind), tcx.coroutine_for_closure(def_id), - tupled_upvars_ty, - ); - ( - ty::AliasTy::new( - tcx, - obligation.predicate.def_id, - [self_ty, sig.tupled_inputs_ty], - ), - coroutine_ty.into(), + goal_kind, + env_region, + args.tupled_upvars_ty(), + args.coroutine_captures_by_ref_ty(), ) - } - sym::CallMutFuture | sym::CallFuture => { + } else { + let async_fn_kind_trait_def_id = + tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); + let upvars_projection_def_id = tcx + .associated_items(async_fn_kind_trait_def_id) + .filter_by_name_unhygienic(sym::Upvars) + .next() + .unwrap() + .def_id; + // When we don't know the closure kind (and therefore also the closure's upvars, + // which are computed at the same time), we must delay the computation of the + // generator's upvars. We do this using the `AsyncFnKindHelper`, which as a trait + // goal functions similarly to the old `ClosureKind` predicate, and ensures that + // the goal kind <= the closure kind. As a projection `AsyncFnKindHelper::Upvars` + // will project to the right upvars for the generator, appending the inputs and + // coroutine upvars respecting the closure kind. + // N.B. No need to register a `AsyncFnKindHelper` goal here, it's already in `nested`. let tupled_upvars_ty = Ty::new_projection( tcx, upvars_projection_def_id, @@ -2538,37 +2511,38 @@ fn confirm_async_closure_candidate<'cx, 'tcx>( args.coroutine_captures_by_ref_ty().into(), ], ); - let coroutine_ty = sig.to_coroutine( + sig.to_coroutine( tcx, args.parent_args(), Ty::from_closure_kind(tcx, goal_kind), tcx.coroutine_for_closure(def_id), tupled_upvars_ty, - ); - ( - ty::AliasTy::new( - tcx, - obligation.predicate.def_id, - [ - ty::GenericArg::from(self_ty), - sig.tupled_inputs_ty.into(), - env_region.into(), - ], - ), - coroutine_ty.into(), ) } - sym::Output => ( - ty::AliasTy::new(tcx, obligation.predicate.def_id, [self_ty, sig.tupled_inputs_ty]), - sig.return_ty.into(), - ), - name => bug!("no such associated type: {name}"), - }; - ty::ProjectionPredicate { projection_ty, term } - }); + } + sym::Output => sig.return_ty, + name => bug!("no such associated type: {name}"), + }; + let projection_ty = match item_name { + sym::CallOnceFuture | sym::Output => { + ty::AliasTy::new(tcx, obligation.predicate.def_id, [self_ty, sig.tupled_inputs_ty]) + } + sym::CallMutFuture | sym::CallFuture => ty::AliasTy::new( + tcx, + obligation.predicate.def_id, + [ty::GenericArg::from(self_ty), sig.tupled_inputs_ty.into(), env_region.into()], + ), + name => bug!("no such associated type: {name}"), + }; - confirm_param_env_candidate(selcx, obligation, poly_cache_entry, true) - .with_addl_obligations(nested) + confirm_param_env_candidate( + selcx, + obligation, + args.coroutine_closure_sig() + .rebind(ty::ProjectionPredicate { projection_ty, term: term.into() }), + true, + ) + .with_addl_obligations(nested) } fn confirm_async_fn_kind_helper_candidate<'cx, 'tcx>( From 0dd40786b555c04afa52b9d0c789a29dbd4e3dd2 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Wed, 31 Jan 2024 18:22:16 +0000 Subject: [PATCH 2/7] Harmonize blanket implementations for AsyncFn* traits --- library/alloc/src/boxed.rs | 29 +++++++++++++++ library/alloc/src/lib.rs | 1 + library/core/src/ops/async_function.rs | 51 +++++++++++++++++++------- 3 files changed, 67 insertions(+), 14 deletions(-) diff --git a/library/alloc/src/boxed.rs b/library/alloc/src/boxed.rs index 953041b8c202b..92600b8e5bddb 100644 --- a/library/alloc/src/boxed.rs +++ b/library/alloc/src/boxed.rs @@ -159,6 +159,7 @@ use core::iter::FusedIterator; use core::marker::Tuple; use core::marker::Unsize; use core::mem::{self, SizedTypeProperties}; +use core::ops::{AsyncFn, AsyncFnMut, AsyncFnOnce}; use core::ops::{ CoerceUnsized, Coroutine, CoroutineState, Deref, DerefMut, DispatchFromDyn, Receiver, }; @@ -2030,6 +2031,34 @@ impl + ?Sized, A: Allocator> Fn for Box { } } +#[unstable(feature = "async_fn_traits", issue = "none")] +impl + ?Sized, A: Allocator> AsyncFnOnce for Box { + type Output = F::Output; + type CallOnceFuture = F::CallOnceFuture; + + extern "rust-call" fn async_call_once(self, args: Args) -> Self::CallOnceFuture { + F::async_call_once(*self, args) + } +} + +#[unstable(feature = "async_fn_traits", issue = "none")] +impl + ?Sized, A: Allocator> AsyncFnMut for Box { + type CallMutFuture<'a> = F::CallMutFuture<'a> where Self: 'a; + + extern "rust-call" fn async_call_mut(&mut self, args: Args) -> Self::CallMutFuture<'_> { + F::async_call_mut(self, args) + } +} + +#[unstable(feature = "async_fn_traits", issue = "none")] +impl + ?Sized, A: Allocator> AsyncFn for Box { + type CallFuture<'a> = F::CallFuture<'a> where Self: 'a; + + extern "rust-call" fn async_call(&self, args: Args) -> Self::CallFuture<'_> { + F::async_call(self, args) + } +} + #[unstable(feature = "coerce_unsized", issue = "18598")] impl, U: ?Sized, A: Allocator> CoerceUnsized> for Box {} diff --git a/library/alloc/src/lib.rs b/library/alloc/src/lib.rs index 96d43e11dc6b1..3341b564d1f65 100644 --- a/library/alloc/src/lib.rs +++ b/library/alloc/src/lib.rs @@ -106,6 +106,7 @@ #![feature(array_windows)] #![feature(ascii_char)] #![feature(assert_matches)] +#![feature(async_fn_traits)] #![feature(async_iterator)] #![feature(coerce_unsized)] #![feature(const_align_of_val)] diff --git a/library/core/src/ops/async_function.rs b/library/core/src/ops/async_function.rs index efbe9d164c3a7..19b1220f05ec2 100644 --- a/library/core/src/ops/async_function.rs +++ b/library/core/src/ops/async_function.rs @@ -65,44 +65,67 @@ pub trait AsyncFnOnce { mod impls { use super::{AsyncFn, AsyncFnMut, AsyncFnOnce}; - use crate::future::Future; use crate::marker::Tuple; #[unstable(feature = "async_fn_traits", issue = "none")] - impl, A: Tuple> AsyncFn for F + impl AsyncFn for &F where - >::Output: Future, + F: AsyncFn, { - type CallFuture<'a> = >::Output where Self: 'a; + type CallFuture<'a> = F::CallFuture<'a> where Self: 'a; extern "rust-call" fn async_call(&self, args: A) -> Self::CallFuture<'_> { - self.call(args) + F::async_call(*self, args) } } #[unstable(feature = "async_fn_traits", issue = "none")] - impl, A: Tuple> AsyncFnMut for F + impl AsyncFnMut for &F where - >::Output: Future, + F: AsyncFn, { - type CallMutFuture<'a> = >::Output where Self: 'a; + type CallMutFuture<'a> = F::CallFuture<'a> where Self: 'a; extern "rust-call" fn async_call_mut(&mut self, args: A) -> Self::CallMutFuture<'_> { - self.call_mut(args) + F::async_call(*self, args) } } #[unstable(feature = "async_fn_traits", issue = "none")] - impl, A: Tuple> AsyncFnOnce for F + impl<'a, A: Tuple, F: ?Sized> AsyncFnOnce for &'a F where - >::Output: Future, + F: AsyncFn, { - type CallOnceFuture = >::Output; + type Output = F::Output; + type CallOnceFuture = F::CallFuture<'a>; - type Output = <>::Output as Future>::Output; + extern "rust-call" fn async_call_once(self, args: A) -> Self::CallOnceFuture { + F::async_call(self, args) + } + } + + #[unstable(feature = "async_fn_traits", issue = "none")] + impl AsyncFnMut for &mut F + where + F: AsyncFnMut, + { + type CallMutFuture<'a> = F::CallMutFuture<'a> where Self: 'a; + + extern "rust-call" fn async_call_mut(&mut self, args: A) -> Self::CallMutFuture<'_> { + F::async_call_mut(*self, args) + } + } + + #[unstable(feature = "async_fn_traits", issue = "none")] + impl<'a, A: Tuple, F: ?Sized> AsyncFnOnce for &'a mut F + where + F: AsyncFnMut, + { + type Output = F::Output; + type CallOnceFuture = F::CallMutFuture<'a>; extern "rust-call" fn async_call_once(self, args: A) -> Self::CallOnceFuture { - self.call_once(args) + F::async_call_mut(self, args) } } } From 08af64e96be28c3680d6e8c96d437a560d3a9ae3 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Mon, 5 Feb 2024 19:17:18 +0000 Subject: [PATCH 3/7] Regular closures now built-in impls for AsyncFn* --- .../src/solve/assembly/structural_traits.rs | 73 +++++- .../src/traits/project.rs | 221 ++++++++++++------ .../src/traits/select/candidate_assembly.rs | 14 +- .../src/traits/select/confirmation.rs | 100 +++++--- compiler/rustc_ty_utils/src/instance.rs | 13 ++ tests/ui/async-await/async-fn/simple.rs | 2 +- tests/ui/did_you_mean/bad-assoc-ty.stderr | 9 +- 7 files changed, 318 insertions(+), 114 deletions(-) diff --git a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs index 4fd9a29c0b2e3..4b95d26f9f8b7 100644 --- a/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs +++ b/compiler/rustc_trait_selection/src/solve/assembly/structural_traits.rs @@ -394,7 +394,78 @@ pub(in crate::solve) fn extract_tupled_inputs_and_output_from_async_callable<'tc )) } - ty::FnDef(..) | ty::FnPtr(..) | ty::Closure(..) => Err(NoSolution), + ty::FnDef(..) | ty::FnPtr(..) => { + let bound_sig = self_ty.fn_sig(tcx); + let sig = bound_sig.skip_binder(); + let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None); + // `FnDef` and `FnPtr` only implement `AsyncFn*` when their + // return type implements `Future`. + let nested = vec![ + bound_sig + .rebind(ty::TraitRef::new(tcx, future_trait_def_id, [sig.output()])) + .to_predicate(tcx), + ]; + let future_output_def_id = tcx + .associated_items(future_trait_def_id) + .filter_by_name_unhygienic(sym::Output) + .next() + .unwrap() + .def_id; + let future_output_ty = Ty::new_projection(tcx, future_output_def_id, [sig.output()]); + Ok(( + bound_sig.rebind((Ty::new_tup(tcx, sig.inputs()), sig.output(), future_output_ty)), + nested, + )) + } + ty::Closure(_, args) => { + let args = args.as_closure(); + let bound_sig = args.sig(); + let sig = bound_sig.skip_binder(); + let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None); + // `Closure`s only implement `AsyncFn*` when their return type + // implements `Future`. + let mut nested = vec![ + bound_sig + .rebind(ty::TraitRef::new(tcx, future_trait_def_id, [sig.output()])) + .to_predicate(tcx), + ]; + + // Additionally, we need to check that the closure kind + // is still compatible. + let kind_ty = args.kind_ty(); + if let Some(closure_kind) = kind_ty.to_opt_closure_kind() { + if !closure_kind.extends(goal_kind) { + return Err(NoSolution); + } + } else { + let async_fn_kind_trait_def_id = + tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); + // When we don't know the closure kind (and therefore also the closure's upvars, + // which are computed at the same time), we must delay the computation of the + // generator's upvars. We do this using the `AsyncFnKindHelper`, which as a trait + // goal functions similarly to the old `ClosureKind` predicate, and ensures that + // the goal kind <= the closure kind. As a projection `AsyncFnKindHelper::Upvars` + // will project to the right upvars for the generator, appending the inputs and + // coroutine upvars respecting the closure kind. + nested.push( + ty::TraitRef::new( + tcx, + async_fn_kind_trait_def_id, + [kind_ty, Ty::from_closure_kind(tcx, goal_kind)], + ) + .to_predicate(tcx), + ); + } + + let future_output_def_id = tcx + .associated_items(future_trait_def_id) + .filter_by_name_unhygienic(sym::Output) + .next() + .unwrap() + .def_id; + let future_output_ty = Ty::new_projection(tcx, future_output_def_id, [sig.output()]); + Ok((bound_sig.rebind((sig.inputs()[0], sig.output(), future_output_ty)), nested)) + } ty::Bool | ty::Char diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs index 88c28761d25f4..f45a20ccd325a 100644 --- a/compiler/rustc_trait_selection/src/traits/project.rs +++ b/compiler/rustc_trait_selection/src/traits/project.rs @@ -2450,14 +2450,6 @@ fn confirm_async_closure_candidate<'cx, 'tcx>( ) -> Progress<'tcx> { let tcx = selcx.tcx(); let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty()); - let ty::CoroutineClosure(def_id, args) = *self_ty.kind() else { - unreachable!( - "expected coroutine-closure self type for coroutine-closure candidate, found {self_ty}" - ) - }; - let args = args.as_coroutine_closure(); - let kind_ty = args.kind_ty(); - let sig = args.coroutine_closure_sig().skip_binder(); let goal_kind = tcx.async_fn_trait_kind_from_def_id(obligation.predicate.trait_def_id(tcx)).unwrap(); @@ -2465,84 +2457,163 @@ fn confirm_async_closure_candidate<'cx, 'tcx>( ty::ClosureKind::Fn | ty::ClosureKind::FnMut => obligation.predicate.args.region_at(2), ty::ClosureKind::FnOnce => tcx.lifetimes.re_static, }; - let item_name = tcx.item_name(obligation.predicate.def_id); - let term = match item_name { - sym::CallOnceFuture | sym::CallMutFuture | sym::CallFuture => { - if let Some(closure_kind) = kind_ty.to_opt_closure_kind() { - if !closure_kind.extends(goal_kind) { - bug!("we should not be confirming if the closure kind is not met"); + + let poly_cache_entry = match *self_ty.kind() { + ty::CoroutineClosure(def_id, args) => { + let args = args.as_coroutine_closure(); + let kind_ty = args.kind_ty(); + let sig = args.coroutine_closure_sig().skip_binder(); + + let term = match item_name { + sym::CallOnceFuture | sym::CallMutFuture | sym::CallFuture => { + if let Some(closure_kind) = kind_ty.to_opt_closure_kind() { + if !closure_kind.extends(goal_kind) { + bug!("we should not be confirming if the closure kind is not met"); + } + sig.to_coroutine_given_kind_and_upvars( + tcx, + args.parent_args(), + tcx.coroutine_for_closure(def_id), + goal_kind, + env_region, + args.tupled_upvars_ty(), + args.coroutine_captures_by_ref_ty(), + ) + } else { + let async_fn_kind_trait_def_id = + tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); + let upvars_projection_def_id = tcx + .associated_items(async_fn_kind_trait_def_id) + .filter_by_name_unhygienic(sym::Upvars) + .next() + .unwrap() + .def_id; + // When we don't know the closure kind (and therefore also the closure's upvars, + // which are computed at the same time), we must delay the computation of the + // generator's upvars. We do this using the `AsyncFnKindHelper`, which as a trait + // goal functions similarly to the old `ClosureKind` predicate, and ensures that + // the goal kind <= the closure kind. As a projection `AsyncFnKindHelper::Upvars` + // will project to the right upvars for the generator, appending the inputs and + // coroutine upvars respecting the closure kind. + // N.B. No need to register a `AsyncFnKindHelper` goal here, it's already in `nested`. + let tupled_upvars_ty = Ty::new_projection( + tcx, + upvars_projection_def_id, + [ + ty::GenericArg::from(kind_ty), + Ty::from_closure_kind(tcx, goal_kind).into(), + env_region.into(), + sig.tupled_inputs_ty.into(), + args.tupled_upvars_ty().into(), + args.coroutine_captures_by_ref_ty().into(), + ], + ); + sig.to_coroutine( + tcx, + args.parent_args(), + Ty::from_closure_kind(tcx, goal_kind), + tcx.coroutine_for_closure(def_id), + tupled_upvars_ty, + ) + } } - sig.to_coroutine_given_kind_and_upvars( + sym::Output => sig.return_ty, + name => bug!("no such associated type: {name}"), + }; + let projection_ty = match item_name { + sym::CallOnceFuture | sym::Output => ty::AliasTy::new( tcx, - args.parent_args(), - tcx.coroutine_for_closure(def_id), - goal_kind, - env_region, - args.tupled_upvars_ty(), - args.coroutine_captures_by_ref_ty(), - ) - } else { - let async_fn_kind_trait_def_id = - tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); - let upvars_projection_def_id = tcx - .associated_items(async_fn_kind_trait_def_id) - .filter_by_name_unhygienic(sym::Upvars) - .next() - .unwrap() - .def_id; - // When we don't know the closure kind (and therefore also the closure's upvars, - // which are computed at the same time), we must delay the computation of the - // generator's upvars. We do this using the `AsyncFnKindHelper`, which as a trait - // goal functions similarly to the old `ClosureKind` predicate, and ensures that - // the goal kind <= the closure kind. As a projection `AsyncFnKindHelper::Upvars` - // will project to the right upvars for the generator, appending the inputs and - // coroutine upvars respecting the closure kind. - // N.B. No need to register a `AsyncFnKindHelper` goal here, it's already in `nested`. - let tupled_upvars_ty = Ty::new_projection( + obligation.predicate.def_id, + [self_ty, sig.tupled_inputs_ty], + ), + sym::CallMutFuture | sym::CallFuture => ty::AliasTy::new( tcx, - upvars_projection_def_id, + obligation.predicate.def_id, + [ty::GenericArg::from(self_ty), sig.tupled_inputs_ty.into(), env_region.into()], + ), + name => bug!("no such associated type: {name}"), + }; + + args.coroutine_closure_sig() + .rebind(ty::ProjectionPredicate { projection_ty, term: term.into() }) + } + ty::FnDef(..) | ty::FnPtr(..) => { + let bound_sig = self_ty.fn_sig(tcx); + let sig = bound_sig.skip_binder(); + + let term = match item_name { + sym::CallOnceFuture | sym::CallMutFuture | sym::CallFuture => sig.output(), + sym::Output => { + let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None); + let future_output_def_id = tcx + .associated_items(future_trait_def_id) + .filter_by_name_unhygienic(sym::Output) + .next() + .unwrap() + .def_id; + Ty::new_projection(tcx, future_output_def_id, [sig.output()]) + } + name => bug!("no such associated type: {name}"), + }; + let projection_ty = match item_name { + sym::CallOnceFuture | sym::Output => ty::AliasTy::new( + tcx, + obligation.predicate.def_id, + [self_ty, Ty::new_tup(tcx, sig.inputs())], + ), + sym::CallMutFuture | sym::CallFuture => ty::AliasTy::new( + tcx, + obligation.predicate.def_id, [ - ty::GenericArg::from(kind_ty), - Ty::from_closure_kind(tcx, goal_kind).into(), + ty::GenericArg::from(self_ty), + Ty::new_tup(tcx, sig.inputs()).into(), env_region.into(), - sig.tupled_inputs_ty.into(), - args.tupled_upvars_ty().into(), - args.coroutine_captures_by_ref_ty().into(), ], - ); - sig.to_coroutine( - tcx, - args.parent_args(), - Ty::from_closure_kind(tcx, goal_kind), - tcx.coroutine_for_closure(def_id), - tupled_upvars_ty, - ) - } + ), + name => bug!("no such associated type: {name}"), + }; + + bound_sig.rebind(ty::ProjectionPredicate { projection_ty, term: term.into() }) } - sym::Output => sig.return_ty, - name => bug!("no such associated type: {name}"), - }; - let projection_ty = match item_name { - sym::CallOnceFuture | sym::Output => { - ty::AliasTy::new(tcx, obligation.predicate.def_id, [self_ty, sig.tupled_inputs_ty]) + ty::Closure(_, args) => { + let args = args.as_closure(); + let bound_sig = args.sig(); + let sig = bound_sig.skip_binder(); + + let term = match item_name { + sym::CallOnceFuture | sym::CallMutFuture | sym::CallFuture => sig.output(), + sym::Output => { + let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None); + let future_output_def_id = tcx + .associated_items(future_trait_def_id) + .filter_by_name_unhygienic(sym::Output) + .next() + .unwrap() + .def_id; + Ty::new_projection(tcx, future_output_def_id, [sig.output()]) + } + name => bug!("no such associated type: {name}"), + }; + let projection_ty = match item_name { + sym::CallOnceFuture | sym::Output => { + ty::AliasTy::new(tcx, obligation.predicate.def_id, [self_ty, sig.inputs()[0]]) + } + sym::CallMutFuture | sym::CallFuture => ty::AliasTy::new( + tcx, + obligation.predicate.def_id, + [ty::GenericArg::from(self_ty), sig.inputs()[0].into(), env_region.into()], + ), + name => bug!("no such associated type: {name}"), + }; + + bound_sig.rebind(ty::ProjectionPredicate { projection_ty, term: term.into() }) } - sym::CallMutFuture | sym::CallFuture => ty::AliasTy::new( - tcx, - obligation.predicate.def_id, - [ty::GenericArg::from(self_ty), sig.tupled_inputs_ty.into(), env_region.into()], - ), - name => bug!("no such associated type: {name}"), + _ => bug!("expected callable type for AsyncFn candidate"), }; - confirm_param_env_candidate( - selcx, - obligation, - args.coroutine_closure_sig() - .rebind(ty::ProjectionPredicate { projection_ty, term: term.into() }), - true, - ) - .with_addl_obligations(nested) + confirm_param_env_candidate(selcx, obligation, poly_cache_entry, true) + .with_addl_obligations(nested) } fn confirm_async_fn_kind_helper_candidate<'cx, 'tcx>( diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs index 2258e7961038b..34dc85537140a 100644 --- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs +++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs @@ -361,8 +361,18 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { } candidates.vec.push(AsyncClosureCandidate); } - ty::Infer(ty::TyVar(_)) => { - candidates.ambiguous = true; + // Closures and fn pointers implement `AsyncFn*` if their return types + // implement `Future`, which is checked later. + ty::Closure(_, args) => { + if let Some(closure_kind) = args.as_closure().kind_ty().to_opt_closure_kind() + && !closure_kind.extends(goal_kind) + { + return; + } + candidates.vec.push(AsyncClosureCandidate); + } + ty::FnDef(..) | ty::FnPtr(..) => { + candidates.vec.push(AsyncClosureCandidate); } _ => {} } diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs index c9d06b0f67521..4284516954922 100644 --- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs +++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs @@ -883,40 +883,86 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { &mut self, obligation: &PolyTraitObligation<'tcx>, ) -> Result>, SelectionError<'tcx>> { - // Okay to skip binder because the args on closure types never - // touch bound regions, they just capture the in-scope - // type/region parameters. + let tcx = self.tcx(); let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder()); - let ty::CoroutineClosure(closure_def_id, args) = *self_ty.kind() else { - bug!("async closure candidate for non-coroutine-closure {:?}", obligation); - }; - let trait_ref = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| { - ty::TraitRef::new( - self.tcx(), - obligation.predicate.def_id(), - [self_ty, sig.tupled_inputs_ty], - ) - }); + let mut nested = vec![]; + let (trait_ref, kind_ty) = match *self_ty.kind() { + ty::CoroutineClosure(_, args) => { + let args = args.as_coroutine_closure(); + let trait_ref = args.coroutine_closure_sig().map_bound(|sig| { + ty::TraitRef::new( + self.tcx(), + obligation.predicate.def_id(), + [self_ty, sig.tupled_inputs_ty], + ) + }); + (trait_ref, args.kind_ty()) + } + ty::FnDef(..) | ty::FnPtr(..) => { + let sig = self_ty.fn_sig(tcx); + let trait_ref = sig.map_bound(|sig| { + ty::TraitRef::new( + self.tcx(), + obligation.predicate.def_id(), + [self_ty, Ty::new_tup(tcx, sig.inputs())], + ) + }); + // We must additionally check that the return type impls `Future`. + let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None); + nested.push(obligation.with( + tcx, + sig.map_bound(|sig| { + ty::TraitRef::new(tcx, future_trait_def_id, [sig.output()]) + }), + )); + (trait_ref, Ty::from_closure_kind(tcx, ty::ClosureKind::Fn)) + } + ty::Closure(_, args) => { + let sig = args.as_closure().sig(); + let trait_ref = sig.map_bound(|sig| { + ty::TraitRef::new( + self.tcx(), + obligation.predicate.def_id(), + [self_ty, sig.inputs()[0]], + ) + }); + // We must additionally check that the return type impls `Future`. + let future_trait_def_id = tcx.require_lang_item(LangItem::Future, None); + nested.push(obligation.with( + tcx, + sig.map_bound(|sig| { + ty::TraitRef::new(tcx, future_trait_def_id, [sig.output()]) + }), + )); + (trait_ref, Ty::from_closure_kind(tcx, ty::ClosureKind::Fn)) + } + _ => bug!("expected callable type for AsyncFn candidate"), + }; - let mut nested = self.confirm_poly_trait_refs(obligation, trait_ref)?; + nested.extend(self.confirm_poly_trait_refs(obligation, trait_ref)?); let goal_kind = self.tcx().async_fn_trait_kind_from_def_id(obligation.predicate.def_id()).unwrap(); - nested.push(obligation.with( - self.tcx(), - ty::TraitRef::from_lang_item( - self.tcx(), - LangItem::AsyncFnKindHelper, - obligation.cause.span, - [ - args.as_coroutine_closure().kind_ty(), - Ty::from_closure_kind(self.tcx(), goal_kind), - ], - ), - )); - debug!(?closure_def_id, ?trait_ref, ?nested, "confirm closure candidate obligations"); + // If we have not yet determiend the `ClosureKind` of the closure or coroutine-closure, + // then additionally register an `AsyncFnKindHelper` goal which will fail if the kind + // is constrained to an insufficient type later on. + if let Some(closure_kind) = self.infcx.shallow_resolve(kind_ty).to_opt_closure_kind() { + if !closure_kind.extends(goal_kind) { + return Err(SelectionError::Unimplemented); + } + } else { + nested.push(obligation.with( + self.tcx(), + ty::TraitRef::from_lang_item( + self.tcx(), + LangItem::AsyncFnKindHelper, + obligation.cause.span, + [kind_ty, Ty::from_closure_kind(self.tcx(), goal_kind)], + ), + )); + } Ok(nested) } diff --git a/compiler/rustc_ty_utils/src/instance.rs b/compiler/rustc_ty_utils/src/instance.rs index 9faad10dd14df..bcc7c98ed6995 100644 --- a/compiler/rustc_ty_utils/src/instance.rs +++ b/compiler/rustc_ty_utils/src/instance.rs @@ -306,6 +306,19 @@ fn resolve_associated_item<'tcx>( Some(Instance::new(coroutine_closure_def_id, args)) } } + ty::Closure(closure_def_id, args) => { + let trait_closure_kind = tcx.fn_trait_kind_from_def_id(trait_id).unwrap(); + Some(Instance::resolve_closure( + tcx, + closure_def_id, + args, + trait_closure_kind, + )) + } + ty::FnDef(..) | ty::FnPtr(..) => Some(Instance { + def: ty::InstanceDef::FnPtrShim(trait_item_id, rcvr_args.type_at(0)), + args: rcvr_args, + }), _ => bug!( "no built-in definition for `{trait_ref}::{}` for non-lending-closure type", tcx.item_name(trait_item_id) diff --git a/tests/ui/async-await/async-fn/simple.rs b/tests/ui/async-await/async-fn/simple.rs index 99a5d56a3093b..172ede7098a2b 100644 --- a/tests/ui/async-await/async-fn/simple.rs +++ b/tests/ui/async-await/async-fn/simple.rs @@ -1,5 +1,5 @@ // edition: 2021 -// check-pass +// build-pass #![feature(async_fn_traits)] diff --git a/tests/ui/did_you_mean/bad-assoc-ty.stderr b/tests/ui/did_you_mean/bad-assoc-ty.stderr index eed01267224d3..3c474d19d1d05 100644 --- a/tests/ui/did_you_mean/bad-assoc-ty.stderr +++ b/tests/ui/did_you_mean/bad-assoc-ty.stderr @@ -191,14 +191,7 @@ error[E0223]: ambiguous associated type --> $DIR/bad-assoc-ty.rs:33:10 | LL | type H = Fn(u8) -> (u8)::Output; - | ^^^^^^^^^^^^^^^^^^^^^^ - | -help: use fully-qualified syntax - | -LL | type H = <(dyn Fn(u8) -> u8 + 'static) as AsyncFnOnce>::Output; - | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -LL | type H = <(dyn Fn(u8) -> u8 + 'static) as IntoFuture>::Output; - | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + | ^^^^^^^^^^^^^^^^^^^^^^ help: use fully-qualified syntax: `<(dyn Fn(u8) -> u8 + 'static) as IntoFuture>::Output` error[E0223]: ambiguous associated type --> $DIR/bad-assoc-ty.rs:39:19 From b8c93f1223695217cbabc1f3f1e428c358bb4e7a Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Mon, 5 Feb 2024 19:59:05 +0000 Subject: [PATCH 4/7] Coroutine closures implement regular Fn traits, when possible --- compiler/rustc_hir_typeck/src/closure.rs | 17 +++-- .../src/traits/project.rs | 74 ++++++++++++++++++- .../src/traits/select/candidate_assembly.rs | 25 +++++++ .../src/traits/select/confirmation.rs | 26 ++++--- compiler/rustc_ty_utils/src/instance.rs | 18 +++++ 5 files changed, 142 insertions(+), 18 deletions(-) diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs index a985fa201d071..5bdd9412d0e51 100644 --- a/compiler/rustc_hir_typeck/src/closure.rs +++ b/compiler/rustc_hir_typeck/src/closure.rs @@ -56,11 +56,18 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { // It's always helpful for inference if we know the kind of // closure sooner rather than later, so first examine the expected // type, and see if can glean a closure kind from there. - let (expected_sig, expected_kind) = match expected.to_option(self) { - Some(ty) => { - self.deduce_closure_signature(self.try_structurally_resolve_type(expr_span, ty)) - } - None => (None, None), + let (expected_sig, expected_kind) = match closure.kind { + hir::ClosureKind::Closure => match expected.to_option(self) { + Some(ty) => { + self.deduce_closure_signature(self.try_structurally_resolve_type(expr_span, ty)) + } + None => (None, None), + }, + // We don't want to deduce a signature from `Fn` bounds for coroutines + // or coroutine-closures, because the former does not implement `Fn` + // ever, and the latter's signature doesn't correspond to the coroutine + // type that it returns. + hir::ClosureKind::Coroutine(_) | hir::ClosureKind::CoroutineClosure(_) => (None, None), }; let ClosureSignatures { bound_sig, mut liberated_sig } = diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs index f45a20ccd325a..0dc11d785c460 100644 --- a/compiler/rustc_trait_selection/src/traits/project.rs +++ b/compiler/rustc_trait_selection/src/traits/project.rs @@ -2074,7 +2074,9 @@ fn confirm_select_candidate<'cx, 'tcx>( } else if lang_items.async_iterator_trait() == Some(trait_def_id) { confirm_async_iterator_candidate(selcx, obligation, data) } else if selcx.tcx().fn_trait_kind_from_def_id(trait_def_id).is_some() { - if obligation.predicate.self_ty().is_closure() { + if obligation.predicate.self_ty().is_closure() + || obligation.predicate.self_ty().is_coroutine_closure() + { confirm_closure_candidate(selcx, obligation, data) } else { confirm_fn_pointer_candidate(selcx, obligation, data) @@ -2386,11 +2388,75 @@ fn confirm_closure_candidate<'cx, 'tcx>( obligation: &ProjectionTyObligation<'tcx>, nested: Vec>, ) -> Progress<'tcx> { + let tcx = selcx.tcx(); let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty()); - let ty::Closure(_, args) = self_ty.kind() else { - unreachable!("expected closure self type for closure candidate, found {self_ty}") + let closure_sig = match *self_ty.kind() { + ty::Closure(_, args) => args.as_closure().sig(), + + // Construct a "normal" `FnOnce` signature for coroutine-closure. This is + // basically duplicated with the `AsyncFnOnce::CallOnce` confirmation, but + // I didn't see a good way to unify those. + ty::CoroutineClosure(def_id, args) => { + let args = args.as_coroutine_closure(); + let kind_ty = args.kind_ty(); + args.coroutine_closure_sig().map_bound(|sig| { + // If we know the kind and upvars, use that directly. + // Otherwise, defer to `AsyncFnKindHelper::Upvars` to delay + // the projection, like the `AsyncFn*` traits do. + let output_ty = if let Some(_) = kind_ty.to_opt_closure_kind() { + sig.to_coroutine_given_kind_and_upvars( + tcx, + args.parent_args(), + tcx.coroutine_for_closure(def_id), + ty::ClosureKind::FnOnce, + tcx.lifetimes.re_static, + args.tupled_upvars_ty(), + args.coroutine_captures_by_ref_ty(), + ) + } else { + let async_fn_kind_trait_def_id = + tcx.require_lang_item(LangItem::AsyncFnKindHelper, None); + let upvars_projection_def_id = tcx + .associated_items(async_fn_kind_trait_def_id) + .filter_by_name_unhygienic(sym::Upvars) + .next() + .unwrap() + .def_id; + let tupled_upvars_ty = Ty::new_projection( + tcx, + upvars_projection_def_id, + [ + ty::GenericArg::from(kind_ty), + Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce).into(), + tcx.lifetimes.re_static.into(), + sig.tupled_inputs_ty.into(), + args.tupled_upvars_ty().into(), + args.coroutine_captures_by_ref_ty().into(), + ], + ); + sig.to_coroutine( + tcx, + args.parent_args(), + Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce), + tcx.coroutine_for_closure(def_id), + tupled_upvars_ty, + ) + }; + tcx.mk_fn_sig( + [sig.tupled_inputs_ty], + output_ty, + sig.c_variadic, + sig.unsafety, + sig.abi, + ) + }) + } + + _ => { + unreachable!("expected closure self type for closure candidate, found {self_ty}"); + } }; - let closure_sig = args.as_closure().sig(); + let Normalized { value: closure_sig, obligations } = normalize_with_depth( selcx, obligation.param_env, diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs index 34dc85537140a..a82acc3ba0549 100644 --- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs +++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs @@ -332,6 +332,31 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { } } } + ty::CoroutineClosure(def_id, args) => { + let is_const = self.tcx().is_const_fn_raw(def_id); + match self.infcx.closure_kind(self_ty) { + Some(closure_kind) => { + let no_borrows = self + .infcx + .shallow_resolve(args.as_coroutine_closure().tupled_upvars_ty()) + .tuple_fields() + .is_empty(); + if no_borrows && closure_kind.extends(kind) { + candidates.vec.push(ClosureCandidate { is_const }); + } else if kind == ty::ClosureKind::FnOnce { + candidates.vec.push(ClosureCandidate { is_const }); + } + } + None => { + if kind == ty::ClosureKind::FnOnce { + candidates.vec.push(ClosureCandidate { is_const }); + } else { + // This stays ambiguous until kind+upvars are determined. + candidates.ambiguous = true; + } + } + } + } ty::Infer(ty::TyVar(_)) => { debug!("assemble_unboxed_closure_candidates: ambiguous self-type"); candidates.ambiguous = true; diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs index 4284516954922..f2dc4b1be739a 100644 --- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs +++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs @@ -865,17 +865,25 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { // touch bound regions, they just capture the in-scope // type/region parameters. let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder()); - let ty::Closure(closure_def_id, args) = *self_ty.kind() else { - bug!("closure candidate for non-closure {:?}", obligation); + let trait_ref = match *self_ty.kind() { + ty::Closure(_, args) => { + self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_) + } + ty::CoroutineClosure(_, args) => { + args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| { + ty::TraitRef::new( + self.tcx(), + obligation.predicate.def_id(), + [self_ty, sig.tupled_inputs_ty], + ) + }) + } + _ => { + bug!("closure candidate for non-closure {:?}", obligation); + } }; - let trait_ref = - self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_); - let nested = self.confirm_poly_trait_refs(obligation, trait_ref)?; - - debug!(?closure_def_id, ?trait_ref, ?nested, "confirm closure candidate obligations"); - - Ok(nested) + self.confirm_poly_trait_refs(obligation, trait_ref) } #[instrument(skip(self), level = "debug")] diff --git a/compiler/rustc_ty_utils/src/instance.rs b/compiler/rustc_ty_utils/src/instance.rs index bcc7c98ed6995..eae80199ce568 100644 --- a/compiler/rustc_ty_utils/src/instance.rs +++ b/compiler/rustc_ty_utils/src/instance.rs @@ -278,6 +278,24 @@ fn resolve_associated_item<'tcx>( def: ty::InstanceDef::FnPtrShim(trait_item_id, rcvr_args.type_at(0)), args: rcvr_args, }), + ty::CoroutineClosure(coroutine_closure_def_id, args) => { + // When a coroutine-closure implements the `Fn` traits, then it + // always dispatches to the `FnOnce` implementation. This is to + // ensure that the `closure_kind` of the resulting closure is in + // sync with the built-in trait implementations (since all of the + // implementations return `FnOnce::Output`). + if ty::ClosureKind::FnOnce == args.as_coroutine_closure().kind() { + Some(Instance::new(coroutine_closure_def_id, args)) + } else { + Some(Instance { + def: ty::InstanceDef::ConstructCoroutineInClosureShim { + coroutine_closure_def_id, + target_kind: ty::ClosureKind::FnOnce, + }, + args, + }) + } + } _ => bug!( "no built-in definition for `{trait_ref}::{}` for non-fn type", tcx.item_name(trait_item_id) From 3bb384aad6e7f61a0b4b8c604206e78ffa418df4 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Thu, 8 Feb 2024 15:46:00 +0000 Subject: [PATCH 5/7] Prefer AsyncFn* over Fn* for coroutine-closures --- compiler/rustc_hir_typeck/src/callee.rs | 51 ++++++++++++------- .../src/traits/select/candidate_assembly.rs | 18 +++++-- .../async-await/async-closures/is-not-fn.rs | 5 +- .../async-closures/is-not-fn.stderr | 10 ++-- tests/ui/async-await/async-fn/dyn-pos.stderr | 24 +++++++++ 5 files changed, 79 insertions(+), 29 deletions(-) diff --git a/compiler/rustc_hir_typeck/src/callee.rs b/compiler/rustc_hir_typeck/src/callee.rs index fbe6f454dbc92..bed0d80787fbc 100644 --- a/compiler/rustc_hir_typeck/src/callee.rs +++ b/compiler/rustc_hir_typeck/src/callee.rs @@ -260,23 +260,40 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { adjusted_ty: Ty<'tcx>, opt_arg_exprs: Option<&'tcx [hir::Expr<'tcx>]>, ) -> Option<(Option>, MethodCallee<'tcx>)> { + // HACK(async_closures): For async closures, prefer `AsyncFn*` + // over `Fn*`, since all async closures implement `FnOnce`, but + // choosing that over `AsyncFn`/`AsyncFnMut` would be more restrictive. + // For other callables, just prefer `Fn*` for perf reasons. + // + // The order of trait choices here is not that big of a deal, + // since it just guides inference (and our choice of autoref). + // Though in the future, I'd like typeck to choose: + // `Fn > AsyncFn > FnMut > AsyncFnMut > FnOnce > AsyncFnOnce` + // ...or *ideally*, we just have `LendingFn`/`LendingFnMut`, which + // would naturally unify these two trait hierarchies in the most + // general way. + let call_trait_choices = if self.shallow_resolve(adjusted_ty).is_coroutine_closure() { + [ + (self.tcx.lang_items().async_fn_trait(), sym::async_call, true), + (self.tcx.lang_items().async_fn_mut_trait(), sym::async_call_mut, true), + (self.tcx.lang_items().async_fn_once_trait(), sym::async_call_once, false), + (self.tcx.lang_items().fn_trait(), sym::call, true), + (self.tcx.lang_items().fn_mut_trait(), sym::call_mut, true), + (self.tcx.lang_items().fn_once_trait(), sym::call_once, false), + ] + } else { + [ + (self.tcx.lang_items().fn_trait(), sym::call, true), + (self.tcx.lang_items().fn_mut_trait(), sym::call_mut, true), + (self.tcx.lang_items().fn_once_trait(), sym::call_once, false), + (self.tcx.lang_items().async_fn_trait(), sym::async_call, true), + (self.tcx.lang_items().async_fn_mut_trait(), sym::async_call_mut, true), + (self.tcx.lang_items().async_fn_once_trait(), sym::async_call_once, false), + ] + }; + // Try the options that are least restrictive on the caller first. - for (opt_trait_def_id, method_name, borrow) in [ - (self.tcx.lang_items().fn_trait(), Ident::with_dummy_span(sym::call), true), - (self.tcx.lang_items().fn_mut_trait(), Ident::with_dummy_span(sym::call_mut), true), - (self.tcx.lang_items().fn_once_trait(), Ident::with_dummy_span(sym::call_once), false), - (self.tcx.lang_items().async_fn_trait(), Ident::with_dummy_span(sym::async_call), true), - ( - self.tcx.lang_items().async_fn_mut_trait(), - Ident::with_dummy_span(sym::async_call_mut), - true, - ), - ( - self.tcx.lang_items().async_fn_once_trait(), - Ident::with_dummy_span(sym::async_call_once), - false, - ), - ] { + for (opt_trait_def_id, method_name, borrow) in call_trait_choices { let Some(trait_def_id) = opt_trait_def_id else { continue }; let opt_input_type = opt_arg_exprs.map(|arg_exprs| { @@ -293,7 +310,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { if let Some(ok) = self.lookup_method_in_trait( self.misc(call_expr.span), - method_name, + Ident::with_dummy_span(method_name), trait_def_id, adjusted_ty, opt_input_type.as_ref().map(slice::from_ref), diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs index a82acc3ba0549..eb4b3b7a62ea4 100644 --- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs +++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs @@ -336,11 +336,23 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { let is_const = self.tcx().is_const_fn_raw(def_id); match self.infcx.closure_kind(self_ty) { Some(closure_kind) => { - let no_borrows = self + let no_borrows = match self .infcx .shallow_resolve(args.as_coroutine_closure().tupled_upvars_ty()) - .tuple_fields() - .is_empty(); + .kind() + { + ty::Tuple(tys) => tys.is_empty(), + ty::Error(_) => false, + _ => bug!("tuple_fields called on non-tuple"), + }; + // A coroutine-closure implements `FnOnce` *always*, since it may + // always be called once. It additionally implements `Fn`/`FnMut` + // only if it has no upvars (therefore no borrows from the closure + // that would need to be represented with a lifetime) and if the + // closure kind permits it. + // FIXME(async_closures): Actually, it could also implement `Fn`/`FnMut` + // if it takes all of its upvars by copy, and none by ref. This would + // require us to record a bit more information during upvar analysis. if no_borrows && closure_kind.extends(kind) { candidates.vec.push(ClosureCandidate { is_const }); } else if kind == ty::ClosureKind::FnOnce { diff --git a/tests/ui/async-await/async-closures/is-not-fn.rs b/tests/ui/async-await/async-closures/is-not-fn.rs index 94c8e8563bd9e..40b0febbf0692 100644 --- a/tests/ui/async-await/async-closures/is-not-fn.rs +++ b/tests/ui/async-await/async-closures/is-not-fn.rs @@ -5,8 +5,5 @@ fn main() { fn needs_fn(x: impl FnOnce()) {} needs_fn(async || {}); - //~^ ERROR expected a `FnOnce()` closure, found `{coroutine-closure@ - // FIXME(async_closures): This should explain in more detail how async fns don't - // implement the regular `Fn` traits. Or maybe we should just fix it and make them - // when there are no upvars or whatever. + //~^ ERROR expected `{coroutine-closure@is-not-fn.rs:7:14}` to be a closure that returns `()` } diff --git a/tests/ui/async-await/async-closures/is-not-fn.stderr b/tests/ui/async-await/async-closures/is-not-fn.stderr index 12da4b1fc6fb7..6169cee85fd32 100644 --- a/tests/ui/async-await/async-closures/is-not-fn.stderr +++ b/tests/ui/async-await/async-closures/is-not-fn.stderr @@ -1,13 +1,13 @@ -error[E0277]: expected a `FnOnce()` closure, found `{coroutine-closure@$DIR/is-not-fn.rs:7:14: 7:22}` +error[E0271]: expected `{coroutine-closure@is-not-fn.rs:7:14}` to be a closure that returns `()`, but it returns `{async closure body@$DIR/is-not-fn.rs:7:23: 7:25}` --> $DIR/is-not-fn.rs:7:14 | LL | needs_fn(async || {}); - | -------- ^^^^^^^^^^^ expected an `FnOnce()` closure, found `{coroutine-closure@$DIR/is-not-fn.rs:7:14: 7:22}` + | -------- ^^^^^^^^^^^ expected `()`, found `async` closure body | | | required by a bound introduced by this call | - = help: the trait `FnOnce<()>` is not implemented for `{coroutine-closure@$DIR/is-not-fn.rs:7:14: 7:22}` - = note: wrap the `{coroutine-closure@$DIR/is-not-fn.rs:7:14: 7:22}` in a closure with no arguments: `|| { /* code */ }` + = note: expected unit type `()` + found `async` closure body `{async closure body@$DIR/is-not-fn.rs:7:23: 7:25}` note: required by a bound in `needs_fn` --> $DIR/is-not-fn.rs:6:25 | @@ -16,4 +16,4 @@ LL | fn needs_fn(x: impl FnOnce()) {} error: aborting due to 1 previous error -For more information about this error, try `rustc --explain E0277`. +For more information about this error, try `rustc --explain E0271`. diff --git a/tests/ui/async-await/async-fn/dyn-pos.stderr b/tests/ui/async-await/async-fn/dyn-pos.stderr index c93235265160b..488c5d06938f0 100644 --- a/tests/ui/async-await/async-fn/dyn-pos.stderr +++ b/tests/ui/async-await/async-fn/dyn-pos.stderr @@ -8,6 +8,9 @@ note: for a trait to be "object safe" it needs to allow building a vtable to all --> $SRC_DIR/core/src/ops/async_function.rs:LL:COL | = note: the trait cannot be made into an object because it contains the generic associated type `CallFuture` + = help: the following types implement the trait, consider defining an enum where each variant holds one of these types, implementing `AsyncFn` for this new enum and using it instead: + &F + std::boxed::Box error[E0038]: the trait `AsyncFnMut` cannot be made into an object --> $DIR/dyn-pos.rs:5:16 @@ -19,6 +22,10 @@ note: for a trait to be "object safe" it needs to allow building a vtable to all --> $SRC_DIR/core/src/ops/async_function.rs:LL:COL | = note: the trait cannot be made into an object because it contains the generic associated type `CallMutFuture` + = help: the following types implement the trait, consider defining an enum where each variant holds one of these types, implementing `AsyncFnMut` for this new enum and using it instead: + &F + &mut F + std::boxed::Box error[E0038]: the trait `AsyncFn` cannot be made into an object --> $DIR/dyn-pos.rs:5:16 @@ -30,6 +37,9 @@ note: for a trait to be "object safe" it needs to allow building a vtable to all --> $SRC_DIR/core/src/ops/async_function.rs:LL:COL | = note: the trait cannot be made into an object because it contains the generic associated type `CallFuture` + = help: the following types implement the trait, consider defining an enum where each variant holds one of these types, implementing `AsyncFn` for this new enum and using it instead: + &F + std::boxed::Box = note: duplicate diagnostic emitted due to `-Z deduplicate-diagnostics=no` error[E0038]: the trait `AsyncFnMut` cannot be made into an object @@ -42,6 +52,10 @@ note: for a trait to be "object safe" it needs to allow building a vtable to all --> $SRC_DIR/core/src/ops/async_function.rs:LL:COL | = note: the trait cannot be made into an object because it contains the generic associated type `CallMutFuture` + = help: the following types implement the trait, consider defining an enum where each variant holds one of these types, implementing `AsyncFnMut` for this new enum and using it instead: + &F + &mut F + std::boxed::Box = note: duplicate diagnostic emitted due to `-Z deduplicate-diagnostics=no` error[E0038]: the trait `AsyncFn` cannot be made into an object @@ -54,6 +68,9 @@ note: for a trait to be "object safe" it needs to allow building a vtable to all --> $SRC_DIR/core/src/ops/async_function.rs:LL:COL | = note: the trait cannot be made into an object because it contains the generic associated type `CallFuture` + = help: the following types implement the trait, consider defining an enum where each variant holds one of these types, implementing `AsyncFn` for this new enum and using it instead: + &F + std::boxed::Box = note: duplicate diagnostic emitted due to `-Z deduplicate-diagnostics=no` error[E0038]: the trait `AsyncFnMut` cannot be made into an object @@ -66,6 +83,10 @@ note: for a trait to be "object safe" it needs to allow building a vtable to all --> $SRC_DIR/core/src/ops/async_function.rs:LL:COL | = note: the trait cannot be made into an object because it contains the generic associated type `CallMutFuture` + = help: the following types implement the trait, consider defining an enum where each variant holds one of these types, implementing `AsyncFnMut` for this new enum and using it instead: + &F + &mut F + std::boxed::Box = note: duplicate diagnostic emitted due to `-Z deduplicate-diagnostics=no` error[E0038]: the trait `AsyncFn` cannot be made into an object @@ -81,6 +102,9 @@ note: for a trait to be "object safe" it needs to allow building a vtable to all ::: $SRC_DIR/core/src/ops/async_function.rs:LL:COL | = note: the trait cannot be made into an object because it contains the generic associated type `CallMutFuture` + = help: the following types implement the trait, consider defining an enum where each variant holds one of these types, implementing `AsyncFn` for this new enum and using it instead: + &F + std::boxed::Box error: aborting due to 7 previous errors From 9322882adeb232af46ecd1400ce2af4a96347425 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Tue, 6 Feb 2024 20:51:56 +0000 Subject: [PATCH 6/7] Add a couple more tests --- tests/ui/async-await/async-closures/once.rs | 22 +++++++++++++++++++++ tests/ui/async-await/async-closures/refd.rs | 18 +++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 tests/ui/async-await/async-closures/once.rs create mode 100644 tests/ui/async-await/async-closures/refd.rs diff --git a/tests/ui/async-await/async-closures/once.rs b/tests/ui/async-await/async-closures/once.rs new file mode 100644 index 0000000000000..a1c56c5de6afd --- /dev/null +++ b/tests/ui/async-await/async-closures/once.rs @@ -0,0 +1,22 @@ +// aux-build:block-on.rs +// edition:2021 +// build-pass + +#![feature(async_closure)] + +use std::future::Future; + +extern crate block_on; + +struct NoCopy; + +fn main() { + block_on::block_on(async { + async fn call_once(x: impl Fn(&'static str) -> F) -> F::Output { + x("hello, world").await + } + call_once(async |x: &'static str| { + println!("hello, {x}"); + }).await + }); +} diff --git a/tests/ui/async-await/async-closures/refd.rs b/tests/ui/async-await/async-closures/refd.rs new file mode 100644 index 0000000000000..7c61ff2d9bd87 --- /dev/null +++ b/tests/ui/async-await/async-closures/refd.rs @@ -0,0 +1,18 @@ +// aux-build:block-on.rs +// edition:2021 +// build-pass + +// check that `&{async-closure}` implements `AsyncFn`. + +#![feature(async_closure)] + +extern crate block_on; + +struct NoCopy; + +fn main() { + block_on::block_on(async { + async fn call_once(x: impl async Fn()) { x().await } + call_once(&async || {}).await + }); +} From 540be28f6c2571e7be3ab3936b62635fa0d3caf3 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Thu, 8 Feb 2024 18:56:52 +0000 Subject: [PATCH 7/7] sort suggestions for object diagnostic --- compiler/rustc_infer/src/traits/error_reporting/mod.rs | 3 ++- .../ui/generic-associated-types/gat-in-trait-path.base.stderr | 2 +- tests/ui/generic-associated-types/issue-79422.base.stderr | 4 ++-- tests/ui/wf/wf-unsafe-trait-obj-match.stderr | 4 ++-- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/compiler/rustc_infer/src/traits/error_reporting/mod.rs b/compiler/rustc_infer/src/traits/error_reporting/mod.rs index eabc1b953af1c..1ceb245dcc7a6 100644 --- a/compiler/rustc_infer/src/traits/error_reporting/mod.rs +++ b/compiler/rustc_infer/src/traits/error_reporting/mod.rs @@ -178,12 +178,13 @@ pub fn report_object_safety_error<'tcx>( ))); } impls => { - let types = impls + let mut types = impls .iter() .map(|t| { with_no_trimmed_paths!(format!(" {}", tcx.type_of(*t).instantiate_identity(),)) }) .collect::>(); + types.sort(); err.help(format!( "the following types implement the trait, consider defining an enum where each \ variant holds one of these types, implementing `{}` for this new enum and using \ diff --git a/tests/ui/generic-associated-types/gat-in-trait-path.base.stderr b/tests/ui/generic-associated-types/gat-in-trait-path.base.stderr index bd3728cec8c82..e05c83ebc762d 100644 --- a/tests/ui/generic-associated-types/gat-in-trait-path.base.stderr +++ b/tests/ui/generic-associated-types/gat-in-trait-path.base.stderr @@ -13,8 +13,8 @@ LL | type A<'a> where Self: 'a; | ^ ...because it contains the generic associated type `A` = help: consider moving `A` to another trait = help: the following types implement the trait, consider defining an enum where each variant holds one of these types, implementing `Foo` for this new enum and using it instead: - Fooy Fooer + Fooy error: aborting due to 1 previous error diff --git a/tests/ui/generic-associated-types/issue-79422.base.stderr b/tests/ui/generic-associated-types/issue-79422.base.stderr index bcc6382cf7cd2..7f58f82570220 100644 --- a/tests/ui/generic-associated-types/issue-79422.base.stderr +++ b/tests/ui/generic-associated-types/issue-79422.base.stderr @@ -29,8 +29,8 @@ LL | type VRefCont<'a>: RefCont<'a, V> where Self: 'a; | ^^^^^^^^ ...because it contains the generic associated type `VRefCont` = help: consider moving `VRefCont` to another trait = help: the following types implement the trait, consider defining an enum where each variant holds one of these types, implementing `MapLike` for this new enum and using it instead: - std::collections::BTreeMap Source + std::collections::BTreeMap error[E0038]: the trait `MapLike` cannot be made into an object --> $DIR/issue-79422.rs:44:13 @@ -47,8 +47,8 @@ LL | type VRefCont<'a>: RefCont<'a, V> where Self: 'a; | ^^^^^^^^ ...because it contains the generic associated type `VRefCont` = help: consider moving `VRefCont` to another trait = help: the following types implement the trait, consider defining an enum where each variant holds one of these types, implementing `MapLike` for this new enum and using it instead: - std::collections::BTreeMap Source + std::collections::BTreeMap = note: required for the cast from `Box>` to `Box + 'static)>>` error: aborting due to 3 previous errors diff --git a/tests/ui/wf/wf-unsafe-trait-obj-match.stderr b/tests/ui/wf/wf-unsafe-trait-obj-match.stderr index a0279774abebf..3b53f55ffdc1d 100644 --- a/tests/ui/wf/wf-unsafe-trait-obj-match.stderr +++ b/tests/ui/wf/wf-unsafe-trait-obj-match.stderr @@ -30,8 +30,8 @@ LL | trait Trait: Sized {} | | | this trait cannot be made into an object... = help: the following types implement the trait, consider defining an enum where each variant holds one of these types, implementing `Trait` for this new enum and using it instead: - S R + S = note: required for the cast from `&S` to `&dyn Trait` error[E0038]: the trait `Trait` cannot be made into an object @@ -52,8 +52,8 @@ LL | trait Trait: Sized {} | | | this trait cannot be made into an object... = help: the following types implement the trait, consider defining an enum where each variant holds one of these types, implementing `Trait` for this new enum and using it instead: - S R + S = note: required for the cast from `&R` to `&dyn Trait` error: aborting due to 3 previous errors