Skip to content

Commit

Permalink
Propagate Expectation around binop typeck code to construct more prec…
Browse files Browse the repository at this point in the history
…ise trait obligations for binops.
  • Loading branch information
willcrichton committed Jul 16, 2022
1 parent 8c1cc82 commit e5bb7d8
Show file tree
Hide file tree
Showing 12 changed files with 194 additions and 42 deletions.
3 changes: 2 additions & 1 deletion compiler/rustc_middle/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub mod util;
use crate::infer::canonical::Canonical;
use crate::ty::abstract_const::NotConstEvaluatable;
use crate::ty::subst::SubstsRef;
use crate::ty::{self, AdtKind, Ty, TyCtxt};
use crate::ty::{self, AdtKind, Predicate, Ty, TyCtxt};

use rustc_data_structures::sync::Lrc;
use rustc_errors::{Applicability, Diagnostic};
Expand Down Expand Up @@ -414,6 +414,7 @@ pub enum ObligationCauseCode<'tcx> {
BinOp {
rhs_span: Option<Span>,
is_lit: bool,
output_pred: Option<Predicate<'tcx>>,
},
}

Expand Down
18 changes: 18 additions & 0 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,24 @@ impl<'tcx> Predicate<'tcx> {
}
}

pub fn to_opt_poly_projection_pred(self) -> Option<PolyProjectionPredicate<'tcx>> {
let predicate = self.kind();
match predicate.skip_binder() {
PredicateKind::Projection(t) => Some(predicate.rebind(t)),
PredicateKind::Trait(..)
| PredicateKind::Subtype(..)
| PredicateKind::Coerce(..)
| PredicateKind::RegionOutlives(..)
| PredicateKind::WellFormed(..)
| PredicateKind::ObjectSafe(..)
| PredicateKind::ClosureKind(..)
| PredicateKind::TypeOutlives(..)
| PredicateKind::ConstEvaluatable(..)
| PredicateKind::ConstEquate(..)
| PredicateKind::TypeWellFormedFromEnv(..) => None,
}
}

pub fn to_opt_type_outlives(self) -> Option<PolyTypeOutlivesPredicate<'tcx>> {
let predicate = self.kind();
match predicate.skip_binder() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
self.suggest_restricting_param_bound(
&mut err,
trait_predicate,
None,
obligation.cause.body_id,
);
} else if !suggested {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ use rustc_middle::hir::map;
use rustc_middle::ty::{
self, suggest_arbitrary_trait_bound, suggest_constraining_type_param, AdtKind, DefIdTree,
GeneratorDiagnosticData, GeneratorInteriorTypeCause, Infer, InferTy, IsSuggestable,
ToPredicate, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitable,
ProjectionPredicate, ToPredicate, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable,
TypeVisitable,
};
use rustc_middle::ty::{TypeAndMut, TypeckResults};
use rustc_session::Limit;
Expand Down Expand Up @@ -172,6 +173,7 @@ pub trait InferCtxtExt<'tcx> {
&self,
err: &mut Diagnostic,
trait_pred: ty::PolyTraitPredicate<'tcx>,
proj_pred: Option<ty::PolyProjectionPredicate<'tcx>>,
body_id: hir::HirId,
);

Expand Down Expand Up @@ -457,6 +459,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
&self,
mut err: &mut Diagnostic,
trait_pred: ty::PolyTraitPredicate<'tcx>,
proj_pred: Option<ty::PolyProjectionPredicate<'tcx>>,
body_id: hir::HirId,
) {
let trait_pred = self.resolve_numeric_literals_with_default(trait_pred);
Expand Down Expand Up @@ -589,9 +592,16 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
}
// Missing generic type parameter bound.
let param_name = self_ty.to_string();
let constraint = with_no_trimmed_paths!(
let mut constraint = with_no_trimmed_paths!(
trait_pred.print_modifiers_and_trait_path().to_string()
);

if let Some(proj_pred) = proj_pred {
let ProjectionPredicate { projection_ty, term } = proj_pred.skip_binder();
let item = self.tcx.associated_item(projection_ty.item_def_id);
constraint.push_str(&format!("<{}={}>", item.name, term.to_string()));
}

if suggest_constraining_type_param(
self.tcx,
generics,
Expand Down Expand Up @@ -2825,7 +2835,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
trait_ref: &ty::PolyTraitRef<'tcx>,
) {
let rhs_span = match obligation.cause.code() {
ObligationCauseCode::BinOp { rhs_span: Some(span), is_lit } if *is_lit => span,
ObligationCauseCode::BinOp { rhs_span: Some(span), is_lit, .. } if *is_lit => span,
_ => return,
};
match (
Expand Down
10 changes: 6 additions & 4 deletions compiler/rustc_typeck/src/check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,13 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
match expr.kind {
ExprKind::Box(subexpr) => self.check_expr_box(subexpr, expected),
ExprKind::Lit(ref lit) => self.check_lit(&lit, expected),
ExprKind::Binary(op, lhs, rhs) => self.check_binop(expr, op, lhs, rhs),
ExprKind::Binary(op, lhs, rhs) => self.check_binop(expr, op, lhs, rhs, expected),
ExprKind::Assign(lhs, rhs, span) => {
self.check_expr_assign(expr, expected, lhs, rhs, span)
}
ExprKind::AssignOp(op, lhs, rhs) => self.check_binop_assign(expr, op, lhs, rhs),
ExprKind::AssignOp(op, lhs, rhs) => {
self.check_binop_assign(expr, op, lhs, rhs, expected)
}
ExprKind::Unary(unop, oprnd) => self.check_expr_unary(unop, oprnd, expected, expr),
ExprKind::AddrOf(kind, mutbl, oprnd) => {
self.check_expr_addr_of(kind, mutbl, oprnd, expected, expr)
Expand Down Expand Up @@ -404,14 +406,14 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
}
}
hir::UnOp::Not => {
let result = self.check_user_unop(expr, oprnd_t, unop);
let result = self.check_user_unop(expr, oprnd_t, unop, expected_inner);
// If it's builtin, we can reuse the type, this helps inference.
if !(oprnd_t.is_integral() || *oprnd_t.kind() == ty::Bool) {
oprnd_t = result;
}
}
hir::UnOp::Neg => {
let result = self.check_user_unop(expr, oprnd_t, unop);
let result = self.check_user_unop(expr, oprnd_t, unop, expected_inner);
// If it's builtin, we can reuse the type, this helps inference.
if !oprnd_t.is_numeric() {
oprnd_t = result;
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_typeck/src/check/fn_ctxt/_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
rhs_span: opt_input_expr.map(|expr| expr.span),
is_lit: opt_input_expr
.map_or(false, |expr| matches!(expr.kind, ExprKind::Lit(_))),
output_pred: None,
},
),
self.param_env,
Expand Down
30 changes: 27 additions & 3 deletions compiler/rustc_typeck/src/check/method/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ mod suggest;
pub use self::suggest::SelfSource;
pub use self::MethodError::*;

use crate::check::FnCtxt;
use crate::check::{Expectation, FnCtxt};
use crate::ObligationCause;
use rustc_data_structures::sync::Lrc;
use rustc_errors::{Applicability, Diagnostic};
Expand All @@ -20,8 +20,10 @@ use rustc_hir::def_id::DefId;
use rustc_infer::infer::{self, InferOk};
use rustc_middle::ty::subst::Subst;
use rustc_middle::ty::subst::{InternalSubsts, SubstsRef};
use rustc_middle::ty::{self, ToPredicate, Ty, TypeVisitable};
use rustc_middle::ty::{DefIdTree, GenericParamDefKind};
use rustc_middle::ty::{
self, AssocKind, DefIdTree, GenericParamDefKind, ProjectionPredicate, ProjectionTy, Term,
ToPredicate, Ty, TypeVisitable,
};
use rustc_span::symbol::Ident;
use rustc_span::Span;
use rustc_trait_selection::traits;
Expand Down Expand Up @@ -318,6 +320,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
self_ty: Ty<'tcx>,
opt_input_type: Option<Ty<'tcx>>,
opt_input_expr: Option<&'tcx hir::Expr<'tcx>>,
expected: Expectation<'tcx>,
) -> (traits::Obligation<'tcx, ty::Predicate<'tcx>>, &'tcx ty::List<ty::subst::GenericArg<'tcx>>)
{
// Construct a trait-reference `self_ty : Trait<input_tys>`
Expand All @@ -339,6 +342,23 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {

// Construct an obligation
let poly_trait_ref = ty::Binder::dummy(trait_ref);
let opt_output_ty =
expected.only_has_type(self).and_then(|ty| (!ty.needs_infer()).then(|| ty));
let opt_output_assoc_item = self.tcx.associated_items(trait_def_id).find_by_name_and_kind(
self.tcx,
Ident::from_str("Output"),
AssocKind::Type,
trait_def_id,
);
let output_pred =
opt_output_ty.zip(opt_output_assoc_item).map(|(output_ty, output_assoc_item)| {
ty::Binder::dummy(ty::PredicateKind::Projection(ProjectionPredicate {
projection_ty: ProjectionTy { substs, item_def_id: output_assoc_item.def_id },
term: Term::Ty(output_ty),
}))
.to_predicate(self.tcx)
});

(
traits::Obligation::new(
traits::ObligationCause::new(
Expand All @@ -348,6 +368,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
rhs_span: opt_input_expr.map(|expr| expr.span),
is_lit: opt_input_expr
.map_or(false, |expr| matches!(expr.kind, hir::ExprKind::Lit(_))),
output_pred,
},
),
self.param_env,
Expand Down Expand Up @@ -397,13 +418,15 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
self_ty: Ty<'tcx>,
opt_input_type: Option<Ty<'tcx>>,
opt_input_expr: Option<&'tcx hir::Expr<'tcx>>,
expected: Expectation<'tcx>,
) -> Option<InferOk<'tcx, MethodCallee<'tcx>>> {
let (obligation, substs) = self.obligation_for_op_method(
span,
trait_def_id,
self_ty,
opt_input_type,
opt_input_expr,
expected,
);
self.construct_obligation_for_trait(
span,
Expand Down Expand Up @@ -505,6 +528,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
rhs_span: opt_input_expr.map(|expr| expr.span),
is_lit: opt_input_expr
.map_or(false, |expr| matches!(expr.kind, hir::ExprKind::Lit(_))),
output_pred: None,
},
)
} else {
Expand Down
Loading

0 comments on commit e5bb7d8

Please sign in to comment.