From 986e20d5bb5df4274b390b4148aab0058081a241 Mon Sep 17 00:00:00 2001 From: Virginia Senioria <91khr@users.noreply.github.com> Date: Wed, 25 Sep 2024 07:34:53 +0000 Subject: [PATCH] Fixed diagnostics for coroutines with () as input. --- .../traits/fulfillment_errors.rs | 72 +++++++++---------- .../arg-count-mismatch-on-unit-input.rs | 11 +++ .../arg-count-mismatch-on-unit-input.stderr | 15 ++++ 3 files changed, 61 insertions(+), 37 deletions(-) create mode 100644 tests/ui/coroutine/arg-count-mismatch-on-unit-input.rs create mode 100644 tests/ui/coroutine/arg-count-mismatch-on-unit-input.stderr diff --git a/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs b/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs index 19e2679ae4da7..969f2528836d4 100644 --- a/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs +++ b/compiler/rustc_trait_selection/src/error_reporting/traits/fulfillment_errors.rs @@ -2635,49 +2635,47 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> { // This shouldn't be common unless manually implementing one of the // traits manually, but don't make it more confusing when it does // happen. - Ok( - if Some(expected_trait_ref.def_id) != self.tcx.lang_items().coroutine_trait() - && not_tupled - { - self.report_and_explain_type_error( - TypeTrace::trait_refs( - &obligation.cause, - true, - expected_trait_ref, - found_trait_ref, - ), - ty::error::TypeError::Mismatch, - ) - } else if found.len() == expected.len() { - self.report_closure_arg_mismatch( - span, - found_span, - found_trait_ref, - expected_trait_ref, - obligation.cause.code(), - found_node, - obligation.param_env, - ) - } else { - let (closure_span, closure_arg_span, found) = found_did - .and_then(|did| { - let node = self.tcx.hir().get_if_local(did)?; - let (found_span, closure_arg_span, found) = - self.get_fn_like_arguments(node)?; - Some((Some(found_span), closure_arg_span, found)) - }) - .unwrap_or((found_span, None, found)); - - self.report_arg_count_mismatch( + if Some(expected_trait_ref.def_id) != self.tcx.lang_items().coroutine_trait() && not_tupled + { + return Ok(self.report_and_explain_type_error( + TypeTrace::trait_refs(&obligation.cause, true, expected_trait_ref, found_trait_ref), + ty::error::TypeError::Mismatch, + )); + } + if found.len() != expected.len() { + let (closure_span, closure_arg_span, found) = found_did + .and_then(|did| { + let node = self.tcx.hir().get_if_local(did)?; + let (found_span, closure_arg_span, found) = self.get_fn_like_arguments(node)?; + Some((Some(found_span), closure_arg_span, found)) + }) + .unwrap_or((found_span, None, found)); + + // If the coroutine take a single () as its argument, + // the trait argument would found the coroutine take 0 arguments, + // but get_fn_like_arguments would give 1 argument. + // This would result in "Expected to take 1 argument, but it takes 1 argument". + // Check again to avoid this. + if found.len() != expected.len() { + return Ok(self.report_arg_count_mismatch( span, closure_span, expected, found, found_trait_ty.is_closure(), closure_arg_span, - ) - }, - ) + )); + } + } + Ok(self.report_closure_arg_mismatch( + span, + found_span, + found_trait_ref, + expected_trait_ref, + obligation.cause.code(), + found_node, + obligation.param_env, + )) } /// Given some node representing a fn-like thing in the HIR map, diff --git a/tests/ui/coroutine/arg-count-mismatch-on-unit-input.rs b/tests/ui/coroutine/arg-count-mismatch-on-unit-input.rs new file mode 100644 index 0000000000000..448c7100df657 --- /dev/null +++ b/tests/ui/coroutine/arg-count-mismatch-on-unit-input.rs @@ -0,0 +1,11 @@ +#![feature(coroutines, coroutine_trait, stmt_expr_attributes)] + +use std::ops::Coroutine; + +fn foo() -> impl Coroutine { + //~^ ERROR type mismatch in coroutine arguments + #[coroutine] + |_: ()| {} +} + +fn main() { } diff --git a/tests/ui/coroutine/arg-count-mismatch-on-unit-input.stderr b/tests/ui/coroutine/arg-count-mismatch-on-unit-input.stderr new file mode 100644 index 0000000000000..c7d6507fd7940 --- /dev/null +++ b/tests/ui/coroutine/arg-count-mismatch-on-unit-input.stderr @@ -0,0 +1,15 @@ +error[E0631]: type mismatch in coroutine arguments + --> $DIR/arg-count-mismatch-on-unit-input.rs:5:13 + | +LL | fn foo() -> impl Coroutine { + | ^^^^^^^^^^^^^^^^^^ expected due to this +... +LL | |_: ()| {} + | ------- found signature defined here + | + = note: expected coroutine signature `fn(u8) -> _` + found coroutine signature `fn(()) -> _` + +error: aborting due to 1 previous error + +For more information about this error, try `rustc --explain E0631`.