Skip to content

Commit

Permalink
Barebones function subtyping (#800)
Browse files Browse the repository at this point in the history
* fndef00 works
* merge done 
* fmt
* inline btys

---------

Co-authored-by: Nico Lehmann <[email protected]>
  • Loading branch information
ranjitjhala and nilehmann authored Sep 17, 2024
1 parent 38ba4e0 commit eaf52f0
Show file tree
Hide file tree
Showing 12 changed files with 223 additions and 24 deletions.
16 changes: 10 additions & 6 deletions crates/flux-infer/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,20 +494,24 @@ impl Sub {
debug_assert_eq!(uint_ty_a, uint_ty_b);
Ok(())
}
(BaseTy::Adt(adt_a, args_a), BaseTy::Adt(adt_b, args_b)) => {
debug_assert_eq!(adt_a.did(), adt_b.did());
debug_assert_eq!(args_a.len(), args_b.len());
let variances = infcx.genv.variances_of(adt_a.did());
for (variance, ty_a, ty_b) in izip!(variances, args_a.iter(), args_b.iter()) {
(BaseTy::Adt(a_adt, a_args), BaseTy::Adt(b_adt, b_args)) => {
debug_assert_eq!(a_adt.did(), b_adt.did());
debug_assert_eq!(a_args.len(), b_args.len());
let variances = infcx.genv.variances_of(a_adt.did());
for (variance, ty_a, ty_b) in izip!(variances, a_args.iter(), b_args.iter()) {
self.generic_args(infcx, *variance, ty_a, ty_b)?;
}
Ok(())
}
(BaseTy::FnDef(a_def_id, a_args), BaseTy::FnDef(b_def_id, b_args)) => {
debug_assert_eq!(a_def_id, b_def_id);
assert_eq!(a_args, b_args);
Ok(())
}
(BaseTy::Float(float_ty1), BaseTy::Float(float_ty2)) => {
debug_assert_eq!(float_ty1, float_ty2);
Ok(())
}

(BaseTy::Slice(ty_a), BaseTy::Slice(ty_b)) => self.tys(infcx, ty_a, ty_b),
(BaseTy::Ref(_, ty_a, Mutability::Mut), BaseTy::Ref(_, ty_b, Mutability::Mut)) => {
self.tys(infcx, ty_a, ty_b)?;
Expand Down
2 changes: 2 additions & 0 deletions crates/flux-middle/src/rty/fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,7 @@ impl TypeSuperVisitable for BaseTy {
fn super_visit_with<V: TypeVisitor>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
match self {
BaseTy::Adt(_, args) => args.visit_with(visitor),
BaseTy::FnDef(_, args) => args.visit_with(visitor),
BaseTy::Slice(ty) => ty.visit_with(visitor),
BaseTy::RawPtr(ty, _) => ty.visit_with(visitor),
BaseTy::Ref(_, ty, _) => ty.visit_with(visitor),
Expand Down Expand Up @@ -748,6 +749,7 @@ impl TypeSuperFoldable for BaseTy {
fn try_super_fold_with<F: FallibleTypeFolder>(&self, folder: &mut F) -> Result<Self, F::Error> {
let bty = match self {
BaseTy::Adt(adt_def, args) => BaseTy::adt(adt_def.clone(), args.try_fold_with(folder)?),
BaseTy::FnDef(def_id, args) => BaseTy::fn_def(*def_id, args.try_fold_with(folder)?),
BaseTy::Slice(ty) => BaseTy::Slice(ty.try_fold_with(folder)?),
BaseTy::RawPtr(ty, mu) => BaseTy::RawPtr(ty.try_fold_with(folder)?, *mu),
BaseTy::Ref(re, ty, mutbl) => {
Expand Down
24 changes: 24 additions & 0 deletions crates/flux-middle/src/rty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,20 @@ pub struct FnTraitPredicate {
}

impl FnTraitPredicate {
pub fn fndef_poly_sig(&self) -> PolyFnSig {
let inputs = self.tupled_args.expect_tuple().iter().cloned().collect();

let fn_sig = FnSig::new(
Safety::Safe,
abi::Abi::Rust,
List::empty(),
inputs,
Binder::new(FnOutput::new(self.output.clone(), vec![]), List::empty()),
);

PolyFnSig::new(fn_sig, List::empty())
}

pub fn to_poly_fn_sig(&self, closure_id: DefId, tys: List<Ty>) -> PolyFnSig {
let mut vars = vec![];

Expand Down Expand Up @@ -1262,6 +1276,7 @@ pub enum BaseTy {
RawPtr(Ty, Mutability),
Ref(Region, Ty, Mutability),
FnPtr(PolyFnSig),
FnDef(DefId, GenericArgs),
Tuple(List<Ty>),
Array(Ty, Const),
Never,
Expand All @@ -1276,6 +1291,10 @@ impl BaseTy {
BaseTy::Adt(adt_def, args.into())
}

pub fn fn_def(def_id: DefId, args: impl Into<GenericArgs>) -> BaseTy {
BaseTy::FnDef(def_id, args.into())
}

pub fn from_primitive_str(s: &str) -> Option<BaseTy> {
match s {
"i8" => Some(BaseTy::Int(IntTy::I8)),
Expand Down Expand Up @@ -1423,6 +1442,7 @@ impl BaseTy {
| BaseTy::RawPtr(..)
| BaseTy::Ref(..)
| BaseTy::FnPtr(..)
| BaseTy::FnDef(..)
| BaseTy::Tuple(_)
| BaseTy::Array(_, _)
| BaseTy::Closure(_, _)
Expand Down Expand Up @@ -1461,6 +1481,10 @@ impl<'tcx> ToRustc<'tcx> for BaseTy {
let args = args.to_rustc(tcx);
ty::Ty::new_adt(tcx, adt_def, args)
}
BaseTy::FnDef(def_id, args) => {
let args = args.to_rustc(tcx);
ty::Ty::new_fn_def(tcx, *def_id, args)
}
BaseTy::Float(f) => ty::Ty::new_float(tcx, *f),
BaseTy::RawPtr(ty, mutbl) => ty::Ty::new_ptr(tcx, ty.to_rustc(tcx), *mutbl),
BaseTy::Ref(re, ty, mutbl) => {
Expand Down
3 changes: 3 additions & 0 deletions crates/flux-middle/src/rty/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,9 @@ impl Pretty for BaseTy {
}
Ok(())
}
BaseTy::FnDef(def_id, args) => {
w!("FnDef({:?}[{:?}])", def_id, join!(", ", args))
}
BaseTy::Param(param) => w!("{}", ^param),
BaseTy::Float(float_ty) => w!("{}", ^float_ty.name_str()),
BaseTy::Slice(ty) => w!("[{:?}]", ty),
Expand Down
4 changes: 4 additions & 0 deletions crates/flux-middle/src/rty/refining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,10 @@ impl<'genv, 'tcx> Refiner<'genv, 'tcx> {
let args = self.refine_generic_args(adt_def.did(), args)?;
rty::BaseTy::adt(adt_def, args)
}
ty::TyKind::FnDef(def_id, args) => {
let args = self.refine_generic_args(*def_id, args)?;
rty::BaseTy::fn_def(*def_id, args)
}
ty::TyKind::Alias(alias_kind, alias_ty) => {
let kind = Self::refine_alias_kind(alias_kind);
let alias_ty = self.as_default().refine_alias_ty(alias_kind, alias_ty)?;
Expand Down
151 changes: 133 additions & 18 deletions crates/flux-refineck/src/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
let evars_sol = infcx.pop_scope().with_span(span)?;
infcx.replace_evars(&evars_sol);

self.check_closure_clauses(infcx, infcx.snapshot(), &obligations)
self.check_closure_clauses(infcx, infcx.snapshot(), &obligations, span)
}

#[expect(clippy::too_many_arguments)]
Expand Down Expand Up @@ -616,7 +616,7 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
Ensures::Pred(e) => infcx.assume_pred(e),
}
}
self.check_closure_clauses(infcx, snapshot, &clauses)?;
self.check_closure_clauses(infcx, snapshot, &clauses, span)?;

Ok(output.ret)
}
Expand All @@ -638,27 +638,141 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
)
}

/// The function `check_oblig_fn_def` does a function subtyping check between
/// the sub-type (T_f) corresponding to the type of `def_id` @ `args` and the
/// super-type (T_g) corresponding to the `oblig_sig`. This subtyping is handled
/// as akin to the code
///
/// T_f := (S1,...,Sn) -> S
/// T_g := (T1,...,Tn) -> T
/// T_f <: T_g
///
/// fn g(x1:T1,...,xn:Tn) -> T {
/// f(x1,...,xn)
/// }
/// TODO: copy rules from SLACK.
fn check_oblig_fn_def(
&mut self,
infcx: &mut InferCtxt<'_, 'genv, 'tcx>,
def_id: &DefId,
generic_args: &[GenericArg],
oblig_sig: rty::PolyFnSig,
span: Span,
) -> Result {
let mut infcx = infcx.at(span);
let genv = self.genv;
let tcx = genv.tcx();
let fn_def_sig = self.genv.fn_sig(*def_id).with_span(span)?;

let oblig_sig = oblig_sig
.replace_bound_vars(|_| rty::ReErased, |sort, _| infcx.define_vars(sort))
.normalize_projections(genv, infcx.region_infcx, infcx.def_id)
.with_span(span)?;

// 1. Unpack `T_g` input types
let actuals = oblig_sig
.inputs()
.iter()
.map(|ty| infcx.unpack(ty))
.collect_vec();

// 2. Fresh names for `T_f` refine-params / Instantiate fn_def_sig and normalize it
infcx.push_scope();
let refine_args = infcx.instantiate_refine_args(*def_id).with_span(span)?;
let fn_def_sig = fn_def_sig
.instantiate(tcx, generic_args, &refine_args)
.replace_bound_vars(|_| rty::ReErased, |sort, mode| infcx.fresh_infer_var(sort, mode))
.normalize_projections(genv, infcx.region_infcx, infcx.def_id)
.with_span(span)?;

// 3. INPUT subtyping (g-input <: f-input)
// TODO: Check requires predicates (?)
// for requires in fn_def_sig.requires() {
// at.check_pred(requires, ConstrReason::Call);
// }
assert!(fn_def_sig.requires().is_empty()); // TODO
for (actual, formal) in iter::zip(actuals, fn_def_sig.inputs()) {
let (formal, pred) = formal.unconstr();
infcx.check_pred(&pred, ConstrReason::Call);
// see: TODO(pack-closure)
match (actual.kind(), formal.kind()) {
(TyKind::Ptr(PtrKind::Mut(_), _), _) => {
bug!("Not yet handled: FnDef subtyping with Ptr");
}
_ => {
infcx
.subtyping(&actual, &formal, ConstrReason::Call)
.with_span(span)?;
}
}
}

// // TODO: (RJ) this doesn't feel right to me... can one even *have* clauses here?
// let clauses = self
// .genv
// .predicates_of(*def_id)
// .with_span(span)?
// .predicates()
// .instantiate(self.genv.tcx(), &generic_args, &refine_args);
// at.check_non_closure_clauses(&clauses, ConstrReason::Call)
// .with_span(span)?;

// 4. Plug in the EVAR solution / replace evars
let evars_sol = infcx.pop_scope().with_span(span)?;
infcx.replace_evars(&evars_sol);
let output = fn_def_sig
.output()
.replace_evars(&evars_sol)
.replace_bound_refts_with(|sort, _, _| infcx.define_vars(sort));

// 5. OUTPUT subtyping (f_out <: g_out)
// RJ: new `at` to avoid borrowing errors...!
// let mut at = infcx.at(span);
let oblig_output = oblig_sig
.output()
.replace_bound_refts_with(|sort, mode, _| infcx.fresh_infer_var(sort, mode));
infcx
.subtyping(&output.ret, &oblig_output.ret, ConstrReason::Ret)
.with_span(span)?;
assert!(output.ensures.is_empty()); // TODO
assert!(oblig_output.ensures.is_empty()); // TODO
Ok(())
}

fn check_oblig_fn_trait_pred(
&mut self,
infcx: &mut InferCtxt<'_, 'genv, 'tcx>,
snapshot: &Snapshot,
fn_trait_pred: FnTraitPredicate,
span: Span,
) -> Result {
if let Some(BaseTy::Closure(closure_id, tys)) =
fn_trait_pred.self_ty.as_bty_skipping_existentials()
{
let span = self.genv.tcx().def_span(closure_id);
let body = self.genv.mir(closure_id.expect_local()).with_span(span)?;
Checker::run(
infcx.change_item(closure_id.expect_local(), &body.infcx, snapshot),
closure_id.expect_local(),
self.inherited.reborrow(),
fn_trait_pred.to_poly_fn_sig(*closure_id, tys.clone()),
)?;
} else {
// TODO: When we allow refining closure/fn at the surface level, we would need to do
// actual function subtyping here, but for now, we can skip as all the relevant types
// are unrefined. See issue-767.rs
let self_ty = fn_trait_pred.self_ty.as_bty_skipping_existentials();
match self_ty {
Some(BaseTy::Closure(closure_id, tys)) => {
let span = self.genv.tcx().def_span(closure_id);
let body = self.genv.mir(closure_id.expect_local()).with_span(span)?;
Checker::run(
infcx.change_item(closure_id.expect_local(), &body.infcx, snapshot),
closure_id.expect_local(),
self.inherited.reborrow(),
fn_trait_pred.to_poly_fn_sig(*closure_id, tys.clone()),
)?;
}
Some(BaseTy::FnDef(def_id, args)) => {
// Generates "function subtyping" obligations between the (super-type) `oblig_sig` in the `fn_trait_pred`
// and the (sub-type) corresponding to the signature of `def_id + args`.
// See `tests/neg/surface/fndef00.rs`
let oblig_sig = fn_trait_pred
.fndef_poly_sig()
.normalize_projections(infcx.genv, infcx.region_infcx, infcx.def_id)
.with_span(self.body.span())?;
self.check_oblig_fn_def(infcx, def_id, args, oblig_sig, span)?;
}
_ => {
// TODO: When we allow refining closure/fn at the surface level, we would need to do some function subtyping here,
// but for now, we can skip as all the relevant types are unrefined.
// See issue-767.rs
}
}
Ok(())
}
Expand All @@ -668,11 +782,12 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
infcx: &mut InferCtxt<'_, 'genv, 'tcx>,
snapshot: Snapshot,
clauses: &[Clause],
span: Span,
) -> Result {
for clause in clauses {
match clause.kind_skipping_binder() {
rty::ClauseKind::FnTrait(fn_trait_pred) => {
self.check_oblig_fn_trait_pred(infcx, &snapshot, fn_trait_pred)?;
self.check_oblig_fn_trait_pred(infcx, &snapshot, fn_trait_pred, span)?;
}
rty::ClauseKind::CoroutineOblig(gen_pred) => {
self.check_oblig_generator_pred(infcx, &snapshot, gen_pred)?;
Expand Down
8 changes: 8 additions & 0 deletions crates/flux-refineck/src/type_env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,14 @@ impl BasicBlockEnvShape {
);
BaseTy::adt(adt_def.clone(), substs)
}
BaseTy::FnDef(def_id, args) => {
let args = List::from_vec(
args.iter()
.map(|arg| Self::pack_generic_arg(scope, arg))
.collect(),
);
BaseTy::fn_def(*def_id, args)
}
BaseTy::Tuple(tys) => {
let tys = tys
.iter()
Expand Down
4 changes: 4 additions & 0 deletions crates/flux-rustc-bridge/src/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,10 @@ impl<'tcx> Lower<'tcx> for rustc_ty::Ty<'tcx> {
let args = args.lower(tcx)?;
Ok(Ty::mk_adt(adt_def.lower(tcx), args))
}
rustc_ty::FnDef(def_id, args) => {
let args = args.lower(tcx)?;
Ok(Ty::mk_fn_def(*def_id, args))
}
rustc_ty::Never => Ok(Ty::mk_never()),
rustc_ty::Str => Ok(Ty::mk_str()),
rustc_ty::Char => Ok(Ty::mk_char()),
Expand Down
12 changes: 12 additions & 0 deletions crates/flux-rustc-bridge/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ pub enum TyKind {
Uint(UintTy),
Slice(Ty),
FnPtr(PolyFnSig),
FnDef(DefId, GenericArgs),
Closure(DefId, GenericArgs),
Coroutine(DefId, GenericArgs),
CoroutineWitness(DefId, GenericArgs),
Expand Down Expand Up @@ -663,6 +664,10 @@ impl Ty {
TyKind::Closure(def_id, args.into()).intern()
}

pub fn mk_fn_def(def_id: DefId, args: impl Into<GenericArgs>) -> Ty {
TyKind::FnDef(def_id, args.into()).intern()
}

pub fn mk_coroutine(def_id: DefId, args: impl Into<GenericArgs>) -> Ty {
TyKind::Coroutine(def_id, args.into()).intern()
}
Expand Down Expand Up @@ -788,6 +793,10 @@ impl<'tcx> ToRustc<'tcx> for Ty {
let args = tcx.mk_args_from_iter(args.iter().map(|arg| arg.to_rustc(tcx)));
rustc_ty::TyKind::Adt(adt_def, args)
}
TyKind::FnDef(def_id, args) => {
let args = tcx.mk_args_from_iter(args.iter().map(|arg| arg.to_rustc(tcx)));
rustc_ty::TyKind::FnDef(*def_id, args)
}
TyKind::Array(ty, len) => {
let ty = ty.to_rustc(tcx);
let len = len.to_rustc(tcx);
Expand Down Expand Up @@ -913,6 +922,9 @@ impl fmt::Debug for Ty {
}
Ok(())
}
TyKind::FnDef(def_id, args) => {
write!(f, "FnDef({:?}[{:?}])", def_id, args.iter().format(", "))
}
TyKind::Bool => write!(f, "bool"),
TyKind::Str => write!(f, "str"),
TyKind::Char => write!(f, "char"),
Expand Down
1 change: 1 addition & 0 deletions crates/flux-rustc-bridge/src/ty/subst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ impl Subst for Ty {
fn subst(&self, args: &[GenericArg]) -> Ty {
match self.kind() {
TyKind::Adt(adt_def, args2) => Ty::mk_adt(adt_def.clone(), args2.subst(args)),
TyKind::FnDef(def_id, args2) => Ty::mk_fn_def(*def_id, args2.subst(args)),
TyKind::Array(ty, len) => Ty::mk_array(ty.subst(args), len.clone()),
TyKind::Ref(re, ty, mutbl) => Ty::mk_ref(*re, ty.subst(args), *mutbl),
TyKind::Tuple(tys) => Ty::mk_tuple(tys.subst(args)),
Expand Down
Loading

0 comments on commit eaf52f0

Please sign in to comment.