Skip to content

Commit

Permalink
Solve ProjectionPredicate during normalization (#830)
Browse files Browse the repository at this point in the history
* add test

* fix #829 hopefully

* Update crates/flux-middle/src/rty/projections.rs

Co-authored-by: Nico Lehmann <[email protected]>

* change todo to tracked-span-bug
  • Loading branch information
ranjitjhala authored Oct 1, 2024
1 parent 9695996 commit b7f0858
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 12 deletions.
2 changes: 1 addition & 1 deletion crates/flux-driver/src/collector/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use rustc_errors::ErrorGuaranteed;
use rustc_hir::{
self as hir,
def::DefKind,
def_id::{DefId, LocalDefId, CRATE_DEF_ID},
def_id::{LocalDefId, CRATE_DEF_ID},
EnumDef, ImplItemKind, Item, ItemKind, OwnerId, VariantData, CRATE_OWNER_ID,
};
use rustc_middle::ty::TyCtxt;
Expand Down
115 changes: 109 additions & 6 deletions crates/flux-middle/src/rty/projections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ use rustc_trait_selection::traits::SelectionContext;

use super::{
fold::{FallibleTypeFolder, TypeFoldable, TypeSuperFoldable},
AliasKind, AliasReft, AliasTy, BaseTy, Binder, Clause, ClauseKind, Const, Expr, ExprKind,
GenericArg, ProjectionPredicate, RefineArgs, Region, SubsetTy, Ty, TyKind,
subst::{GenericsSubstDelegate, GenericsSubstFolder},
AliasKind, AliasReft, AliasTy, BaseTy, Binder, Clause, ClauseKind, Const, EarlyBinder, Expr,
ExprKind, GenericArg, ProjectionPredicate, RefineArgs, Region, SubsetTy, Ty, TyKind,
};
use crate::{
global_env::GlobalEnv,
Expand Down Expand Up @@ -136,7 +137,65 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
Ok((true, ty))
}

fn confirm_candidate(&self, candidate: Candidate, obligation: &AliasTy) -> QueryResult<Ty> {
fn find_resolved_predicates(
&self,
subst: &mut TVarSubst,
preds: Vec<EarlyBinder<ProjectionPredicate>>,
) -> (Vec<ProjectionPredicate>, Vec<EarlyBinder<ProjectionPredicate>>) {
let mut resolved = vec![];
let mut unresolved = vec![];
for pred in preds {
let term = pred.clone().skip_binder().term;
let alias_ty = pred.clone().map(|p| p.projection_ty);
match subst.instantiate_partial(alias_ty) {
Some(projection_ty) => {
let pred = ProjectionPredicate { projection_ty, term };
resolved.push(pred);
}
None => unresolved.push(pred.clone()),
}
}
(resolved, unresolved)
}

// See issue-829*.rs for an example of what this function is for.
fn resolve_projection_predicates(
&mut self,
subst: &mut TVarSubst,
impl_def_id: DefId,
) -> QueryResult {
let mut projection_preds: Vec<_> = self
.genv
.predicates_of(impl_def_id)?
.skip_binder()
.predicates
.iter()
.filter_map(|pred| {
if let ClauseKind::Projection(pred) = pred.kind_skipping_binder() {
Some(EarlyBinder(pred.clone()))
} else {
None
}
})
.collect();

while !projection_preds.is_empty() {
let (resolved, unresolved) = self.find_resolved_predicates(subst, projection_preds);

if resolved.is_empty() {
break; // failed: there is some unresolved projection pred!
}
for p in resolved {
let obligation = &p.projection_ty;
let (_, ty) = self.normalize_projection_ty(obligation)?;
subst.tys(&p.term, &ty);
}
projection_preds = unresolved;
}
Ok(())
}

fn confirm_candidate(&mut self, candidate: Candidate, obligation: &AliasTy) -> QueryResult<Ty> {
match candidate {
Candidate::ParamEnv(pred) | Candidate::TraitDef(pred) => Ok(pred.term),
Candidate::UserDefinedImpl(impl_def_id) => {
Expand All @@ -145,9 +204,9 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
// and the id of a rust impl block
// impl<T, A: Allocator> Iterator for IntoIter<T, A>

// 1. Match the self type of the rust impl block and the flux self type of the obligation
// 1. MATCH the self type of the rust impl block and the flux self type of the obligation
// to infer a substitution
// IntoIter<{v. i32[v] | v > 0}, Global> against IntoIter<T, A>
// IntoIter<{v. i32[v] | v > 0}, Global> MATCH IntoIter<T, A>
// => {T -> {v. i32[v] | v > 0}, A -> Global}

let impl_trait_ref = self
Expand All @@ -162,9 +221,13 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
for (a, b) in iter::zip(&impl_trait_ref.args, &obligation.args) {
subst.generic_args(a, b);
}

// 2. Gather the ProjectionPredicates and solve them see issue-808.rs
self.resolve_projection_predicates(&mut subst, impl_def_id)?;

let args = subst.finish(self.tcx(), generics);

// 2. Get the associated type in the impl block and apply the substitution to it
// 3. Get the associated type in the impl block and apply the substitution to it
let assoc_type_id = self
.tcx()
.associated_items(impl_def_id)
Expand Down Expand Up @@ -316,11 +379,51 @@ struct TVarSubst {
args: Vec<Option<GenericArg>>,
}

impl GenericsSubstDelegate for &TVarSubst {
type Error = ();

fn ty_for_param(&mut self, param_ty: rustc_middle::ty::ParamTy) -> Result<Ty, Self::Error> {
match self.args.get(param_ty.index as usize) {
Some(Some(GenericArg::Ty(ty))) => Ok(ty.clone()),
Some(None) => Err(()),
arg => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
}
}

fn sort_for_param(
&mut self,
_param_ty: rustc_middle::ty::ParamTy,
) -> Result<super::Sort, Self::Error> {
tracked_span_bug!()
}

fn ctor_for_param(&mut self, _param_ty: rustc_middle::ty::ParamTy) -> super::SubsetTyCtor {
tracked_span_bug!()
}

fn region_for_param(&mut self, _ebr: rustc_middle::ty::EarlyParamRegion) -> Region {
tracked_span_bug!()
}

fn expr_for_param_const(&self, _param_const: rustc_middle::ty::ParamConst) -> Expr {
tracked_span_bug!()
}

fn const_for_param(&mut self, _param: &Const) -> Const {
tracked_span_bug!()
}
}

impl TVarSubst {
fn new(generics: &rustc_middle::ty::Generics) -> Self {
Self { args: vec![None; generics.count()] }
}

fn instantiate_partial<T: TypeFoldable>(&mut self, pred: EarlyBinder<T>) -> Option<T> {
let mut folder = GenericsSubstFolder::new(&*self, &[]);
pred.skip_binder().try_fold_with(&mut folder).ok()
}

fn finish<'tcx>(
self,
tcx: TyCtxt<'tcx>,
Expand Down
10 changes: 5 additions & 5 deletions crates/flux-middle/src/rty/subst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ pub trait GenericsSubstDelegate {
type Error = !;

fn sort_for_param(&mut self, param_ty: ParamTy) -> Result<Sort, Self::Error>;
fn ty_for_param(&mut self, param_ty: ParamTy) -> Ty;
fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, Self::Error>;
fn ctor_for_param(&mut self, param_ty: ParamTy) -> SubsetTyCtor;
fn region_for_param(&mut self, ebr: EarlyParamRegion) -> Region;
fn expr_for_param_const(&self, param_const: ParamConst) -> Expr;
Expand All @@ -358,9 +358,9 @@ impl<'a, 'tcx> GenericsSubstDelegate for GenericArgsDelegate<'a, 'tcx> {
}
}

fn ty_for_param(&mut self, param_ty: ParamTy) -> Ty {
fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, !> {
match self.0.get(param_ty.index as usize) {
Some(GenericArg::Ty(ty)) => ty.clone(),
Some(GenericArg::Ty(ty)) => Ok(ty.clone()),
Some(arg) => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
None => tracked_span_bug!("type parameter out of range {param_ty:?}"),
}
Expand Down Expand Up @@ -433,7 +433,7 @@ where
(self.sort_for_param)(param_ty)
}

fn ty_for_param(&mut self, param_ty: ParamTy) -> Ty {
fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, E> {
bug!("unexpected type param {param_ty:?}");
}

Expand Down Expand Up @@ -497,7 +497,7 @@ impl<D: GenericsSubstDelegate> FallibleTypeFolder for GenericsSubstFolder<'_, D>

fn try_fold_ty(&mut self, ty: &Ty) -> Result<Ty, D::Error> {
match ty.kind() {
TyKind::Param(param_ty) => Ok(self.delegate.ty_for_param(*param_ty)),
TyKind::Param(param_ty) => self.delegate.ty_for_param(*param_ty),
TyKind::Indexed(BaseTy::Param(param_ty), idx) => {
let idx = idx.try_fold_with(self)?;
Ok(self
Expand Down
24 changes: 24 additions & 0 deletions tests/tests/pos/surface/issue-829.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
trait Trait1 {
type Assoc1;
}

impl Trait1 for i32 {
type Assoc1 = bool;
}

trait Trait2 {
type Assoc2;
}

struct S<T> {
fld: T,
}

impl<T1, T2> Trait2 for S<T2>
where
T2: Trait1<Assoc1 = T1>,
{
type Assoc2 = T1;
}

fn test(x: <S<i32> as Trait2>::Assoc2) {}
25 changes: 25 additions & 0 deletions tests/tests/pos/surface/issue-829b.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
trait Trait1 {
type Assoc1;
}

impl Trait1 for i32 {
type Assoc1 = i32;
}

trait Trait2 {
type Assoc2;
}

struct S<T> {
fld: T,
}

impl<T1, T2, T3> Trait2 for S<T1>
where
T2: Trait1<Assoc1 = T3>,
T1: Trait1<Assoc1 = T2>,
{
type Assoc2 = T1;
}

fn test(x: <S<i32> as Trait2>::Assoc2) {}

0 comments on commit b7f0858

Please sign in to comment.