Skip to content

Commit

Permalink
Rollup merge of rust-lang#40025 - est31:master, r=eddyb
Browse files Browse the repository at this point in the history
Implement non-capturing closure to fn coercion

Implements non capturing closure coercion ([RFC 1558](https://github.com/rust-lang/rfcs/blob/master/text/1558-closure-to-fn-coercion.md)).

cc tracking issue rust-lang#39817
  • Loading branch information
eddyb authored Feb 25, 2017
2 parents 32af4ce + 77f131d commit c3075f3
Show file tree
Hide file tree
Showing 22 changed files with 275 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/librustc/middle/expr_use_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,7 @@ impl<'a, 'gcx, 'tcx> ExprUseVisitor<'a, 'gcx, 'tcx> {
adjustment::Adjust::NeverToAny |
adjustment::Adjust::ReifyFnPointer |
adjustment::Adjust::UnsafeFnPointer |
adjustment::Adjust::ClosureFnPointer |
adjustment::Adjust::MutToConstPointer => {
// Creating a closure/fn-pointer or unsizing consumes
// the input and stores it into the resulting rvalue.
Expand Down
1 change: 1 addition & 0 deletions src/librustc/middle/mem_categorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ impl<'a, 'gcx, 'tcx> MemCategorizationContext<'a, 'gcx, 'tcx> {
adjustment::Adjust::NeverToAny |
adjustment::Adjust::ReifyFnPointer |
adjustment::Adjust::UnsafeFnPointer |
adjustment::Adjust::ClosureFnPointer |
adjustment::Adjust::MutToConstPointer |
adjustment::Adjust::DerefRef {..} => {
debug!("cat_expr({:?}): {:?}",
Expand Down
3 changes: 3 additions & 0 deletions src/librustc/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,9 @@ pub enum CastKind {
/// Convert unique, zero-sized type for a fn to fn()
ReifyFnPointer,

/// Convert non capturing closure to fn()
ClosureFnPointer,

/// Convert safe fn() to unsafe fn()
UnsafeFnPointer,

Expand Down
4 changes: 4 additions & 0 deletions src/librustc/ty/adjustment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ pub enum Adjust<'tcx> {
/// Go from a safe fn pointer to an unsafe fn pointer.
UnsafeFnPointer,

// Go from a non-capturing closure to an fn pointer.
ClosureFnPointer,

/// Go from a mut raw pointer to a const raw pointer.
MutToConstPointer,

Expand Down Expand Up @@ -120,6 +123,7 @@ impl<'tcx> Adjustment<'tcx> {

Adjust::ReifyFnPointer |
Adjust::UnsafeFnPointer |
Adjust::ClosureFnPointer |
Adjust::MutToConstPointer |
Adjust::DerefRef {..} => false,
}
Expand Down
1 change: 1 addition & 0 deletions src/librustc_mir/build/expr/as_lvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
ExprKind::Use { .. } |
ExprKind::NeverToAny { .. } |
ExprKind::ReifyFnPointer { .. } |
ExprKind::ClosureFnPointer { .. } |
ExprKind::UnsafeFnPointer { .. } |
ExprKind::Unsize { .. } |
ExprKind::Repeat { .. } |
Expand Down
4 changes: 4 additions & 0 deletions src/librustc_mir/build/expr/as_rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
let source = unpack!(block = this.as_operand(block, source));
block.and(Rvalue::Cast(CastKind::UnsafeFnPointer, source, expr.ty))
}
ExprKind::ClosureFnPointer { source } => {
let source = unpack!(block = this.as_operand(block, source));
block.and(Rvalue::Cast(CastKind::ClosureFnPointer, source, expr.ty))
}
ExprKind::Unsize { source } => {
let source = unpack!(block = this.as_operand(block, source));
block.and(Rvalue::Cast(CastKind::Unsize, source, expr.ty))
Expand Down
1 change: 1 addition & 0 deletions src/librustc_mir/build/expr/category.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ impl Category {
ExprKind::Cast { .. } |
ExprKind::Use { .. } |
ExprKind::ReifyFnPointer { .. } |
ExprKind::ClosureFnPointer { .. } |
ExprKind::UnsafeFnPointer { .. } |
ExprKind::Unsize { .. } |
ExprKind::Repeat { .. } |
Expand Down
1 change: 1 addition & 0 deletions src/librustc_mir/build/expr/into.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ impl<'a, 'gcx, 'tcx> Builder<'a, 'gcx, 'tcx> {
ExprKind::Cast { .. } |
ExprKind::Use { .. } |
ExprKind::ReifyFnPointer { .. } |
ExprKind::ClosureFnPointer { .. } |
ExprKind::UnsafeFnPointer { .. } |
ExprKind::Unsize { .. } |
ExprKind::Repeat { .. } |
Expand Down
9 changes: 9 additions & 0 deletions src/librustc_mir/hair/cx/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ impl<'tcx> Mirror<'tcx> for &'tcx hir::Expr {
kind: ExprKind::UnsafeFnPointer { source: expr.to_ref() },
};
}
Some((ty::adjustment::Adjust::ClosureFnPointer, adjusted_ty)) => {
expr = Expr {
temp_lifetime: temp_lifetime,
temp_lifetime_was_shrunk: was_shrunk,
ty: adjusted_ty,
span: self.span,
kind: ExprKind::ClosureFnPointer { source: expr.to_ref() },
};
}
Some((ty::adjustment::Adjust::NeverToAny, adjusted_ty)) => {
expr = Expr {
temp_lifetime: temp_lifetime,
Expand Down
3 changes: 3 additions & 0 deletions src/librustc_mir/hair/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ pub enum ExprKind<'tcx> {
ReifyFnPointer {
source: ExprRef<'tcx>,
},
ClosureFnPointer {
source: ExprRef<'tcx>,
},
UnsafeFnPointer {
source: ExprRef<'tcx>,
},
Expand Down
1 change: 1 addition & 0 deletions src/librustc_mir/transform/qualify_consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ impl<'a, 'tcx> Visitor<'tcx> for Qualifier<'a, 'tcx, 'tcx> {
Rvalue::CheckedBinaryOp(..) |
Rvalue::Cast(CastKind::ReifyFnPointer, ..) |
Rvalue::Cast(CastKind::UnsafeFnPointer, ..) |
Rvalue::Cast(CastKind::ClosureFnPointer, ..) |
Rvalue::Cast(CastKind::Unsize, ..) => {}

Rvalue::Len(_) => {
Expand Down
1 change: 1 addition & 0 deletions src/librustc_passes/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ fn check_adjustments<'a, 'tcx>(v: &mut CheckCrateVisitor<'a, 'tcx>, e: &hir::Exp
Some(Adjust::NeverToAny) |
Some(Adjust::ReifyFnPointer) |
Some(Adjust::UnsafeFnPointer) |
Some(Adjust::ClosureFnPointer) |
Some(Adjust::MutToConstPointer) => {}

Some(Adjust::DerefRef { autoderefs, .. }) => {
Expand Down
2 changes: 1 addition & 1 deletion src/librustc_trans/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}

assert!(fn_ty.kind() == llvm::TypeKind::Function,
"builder::{} not passed a function", typ);
"builder::{} not passed a function, but {:?}", typ, fn_ty);

let param_tys = fn_ty.func_params();

Expand Down
14 changes: 14 additions & 0 deletions src/librustc_trans/collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,20 @@ impl<'a, 'tcx> MirVisitor<'tcx> for MirNeighborCollector<'a, 'tcx> {
self.output);
}
}
mir::Rvalue::Cast(mir::CastKind::ClosureFnPointer, ref operand, _) => {
let source_ty = operand.ty(self.mir, self.scx.tcx());
match source_ty.sty {
ty::TyClosure(def_id, substs) => {
let closure_trans_item =
create_fn_trans_item(self.scx,
def_id,
substs.substs,
self.param_substs);
self.output.push(closure_trans_item);
}
_ => bug!(),
}
}
mir::Rvalue::Box(..) => {
let exchange_malloc_fn_def_id =
self.scx
Expand Down
23 changes: 22 additions & 1 deletion src/librustc_trans/mir/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use rustc::mir;
use rustc::mir::tcx::LvalueTy;
use rustc::ty::{self, layout, Ty, TyCtxt, TypeFoldable};
use rustc::ty::cast::{CastTy, IntTy};
use rustc::ty::subst::Substs;
use rustc::ty::subst::{Kind, Substs};
use rustc_data_structures::indexed_vec::{Idx, IndexVec};
use {abi, adt, base, Disr, machine};
use callee::Callee;
Expand Down Expand Up @@ -578,6 +578,27 @@ impl<'a, 'tcx> MirConstContext<'a, 'tcx> {
}
}
}
mir::CastKind::ClosureFnPointer => {
match operand.ty.sty {
ty::TyClosure(def_id, substs) => {
// Get the def_id for FnOnce::call_once
let fn_once = tcx.lang_items.fn_once_trait().unwrap();
let call_once = tcx
.global_tcx().associated_items(fn_once)
.find(|it| it.kind == ty::AssociatedKind::Method)
.unwrap().def_id;
// Now create its substs [Closure, Tuple]
let input = tcx.closure_type(def_id, substs).sig.input(0);
let substs = tcx.mk_substs([operand.ty, input.skip_binder()]
.iter().cloned().map(Kind::from));
Callee::def(self.ccx, call_once, substs)
.reify(self.ccx)
}
_ => {
bug!("{} cannot be cast to a fn ptr", operand.ty)
}
}
}
mir::CastKind::UnsafeFnPointer => {
// this is a no-op at the LLVM level
operand.llval
Expand Down
23 changes: 23 additions & 0 deletions src/librustc_trans/mir/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use llvm::{self, ValueRef};
use rustc::ty::{self, Ty};
use rustc::ty::cast::{CastTy, IntTy};
use rustc::ty::layout::Layout;
use rustc::ty::subst::Kind;
use rustc::mir::tcx::LvalueTy;
use rustc::mir;
use middle::lang_items::ExchangeMallocFnLangItem;
Expand Down Expand Up @@ -190,6 +191,28 @@ impl<'a, 'tcx> MirContext<'a, 'tcx> {
}
}
}
mir::CastKind::ClosureFnPointer => {
match operand.ty.sty {
ty::TyClosure(def_id, substs) => {
// Get the def_id for FnOnce::call_once
let fn_once = bcx.tcx().lang_items.fn_once_trait().unwrap();
let call_once = bcx.tcx()
.global_tcx().associated_items(fn_once)
.find(|it| it.kind == ty::AssociatedKind::Method)
.unwrap().def_id;
// Now create its substs [Closure, Tuple]
let input = bcx.tcx().closure_type(def_id, substs).sig.input(0);
let substs = bcx.tcx().mk_substs([operand.ty, input.skip_binder()]
.iter().cloned().map(Kind::from));
OperandValue::Immediate(
Callee::def(bcx.ccx, call_once, substs)
.reify(bcx.ccx))
}
_ => {
bug!("{} cannot be cast to a fn ptr", operand.ty)
}
}
}
mir::CastKind::UnsafeFnPointer => {
// this is a no-op at the LLVM level
operand.val
Expand Down
65 changes: 64 additions & 1 deletion src/librustc_typeck/check/coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,17 @@
use check::FnCtxt;

use rustc::hir;
use rustc::hir::def_id::DefId;
use rustc::infer::{Coercion, InferOk, TypeTrace};
use rustc::traits::{self, ObligationCause, ObligationCauseCode};
use rustc::ty::adjustment::{Adjustment, Adjust, AutoBorrow};
use rustc::ty::{self, LvaluePreference, TypeAndMut, Ty};
use rustc::ty::{self, LvaluePreference, TypeAndMut,
Ty, ClosureSubsts};
use rustc::ty::fold::TypeFoldable;
use rustc::ty::error::TypeError;
use rustc::ty::relate::RelateResult;
use syntax::abi;
use syntax::feature_gate;
use util::common::indent;

use std::cell::RefCell;
Expand Down Expand Up @@ -196,6 +200,11 @@ impl<'f, 'gcx, 'tcx> Coerce<'f, 'gcx, 'tcx> {
// unsafe qualifier.
self.coerce_from_fn_pointer(a, a_f, b)
}
ty::TyClosure(def_id_a, substs_a) => {
// Non-capturing closures are coercible to
// function pointers
self.coerce_closure_to_fn(a, def_id_a, substs_a, b)
}
_ => {
// Otherwise, just use unification rules.
self.unify_and_identity(a, b)
Expand Down Expand Up @@ -551,6 +560,60 @@ impl<'f, 'gcx, 'tcx> Coerce<'f, 'gcx, 'tcx> {
}
}

fn coerce_closure_to_fn(&self,
a: Ty<'tcx>,
def_id_a: DefId,
substs_a: ClosureSubsts<'tcx>,
b: Ty<'tcx>)
-> CoerceResult<'tcx> {
//! Attempts to coerce from the type of a non-capturing closure
//! into a function pointer.
//!

let b = self.shallow_resolve(b);

let node_id_a = self.tcx.hir.as_local_node_id(def_id_a).unwrap();
match b.sty {
ty::TyFnPtr(_) if self.tcx.with_freevars(node_id_a, |v| v.is_empty()) => {
if !self.tcx.sess.features.borrow().closure_to_fn_coercion {
feature_gate::emit_feature_err(&self.tcx.sess.parse_sess,
"closure_to_fn_coercion",
self.cause.span,
feature_gate::GateIssue::Language,
feature_gate::CLOSURE_TO_FN_COERCION);
return self.unify_and_identity(a, b);
}
// We coerce the closure, which has fn type
// `extern "rust-call" fn((arg0,arg1,...)) -> _`
// to
// `fn(arg0,arg1,...) -> _`
let sig = self.closure_type(def_id_a, substs_a).sig;
let converted_sig = sig.map_bound(|s| {
let params_iter = match s.inputs()[0].sty {
ty::TyTuple(params, _) => {
params.into_iter().cloned()
}
_ => bug!(),
};
self.tcx.mk_fn_sig(params_iter,
s.output(),
s.variadic)
});
let fn_ty = self.tcx.mk_bare_fn(ty::BareFnTy {
unsafety: hir::Unsafety::Normal,
abi: abi::Abi::Rust,
sig: converted_sig,
});
let pointer_ty = self.tcx.mk_fn_ptr(&fn_ty);
debug!("coerce_closure_to_fn(a={:?}, b={:?}, pty={:?})",
a, b, pointer_ty);
self.unify_and_identity(pointer_ty, b)
.map(|(ty, _)| (ty, Adjust::ClosureFnPointer))
}
_ => self.unify_and_identity(a, b),
}
}

fn coerce_unsafe_ptr(&self,
a: Ty<'tcx>,
b: Ty<'tcx>,
Expand Down
4 changes: 4 additions & 0 deletions src/librustc_typeck/check/writeback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,10 @@ impl<'cx, 'gcx, 'tcx> WritebackCx<'cx, 'gcx, 'tcx> {
adjustment::Adjust::MutToConstPointer
}

adjustment::Adjust::ClosureFnPointer => {
adjustment::Adjust::ClosureFnPointer
}

adjustment::Adjust::UnsafeFnPointer => {
adjustment::Adjust::UnsafeFnPointer
}
Expand Down
7 changes: 7 additions & 0 deletions src/libsyntax/feature_gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,10 @@ declare_features! (
// `extern "msp430-interrupt" fn()`
(active, abi_msp430_interrupt, "1.16.0", Some(38487)),

// Used to identify crates that contain sanitizer runtimes
// rustc internal
(active, closure_to_fn_coercion, "1.17.0", Some(39817)),

// Used to identify crates that contain sanitizer runtimes
// rustc internal
(active, sanitizer_runtime, "1.17.0", None),
Expand Down Expand Up @@ -982,6 +986,9 @@ pub const EXPLAIN_DERIVE_UNDERSCORE: &'static str =
pub const EXPLAIN_PLACEMENT_IN: &'static str =
"placement-in expression syntax is experimental and subject to change.";

pub const CLOSURE_TO_FN_COERCION: &'static str =
"non-capturing closure to fn coercion is experimental";

struct PostExpansionVisitor<'a> {
context: &'a Context<'a>,
}
Expand Down
24 changes: 24 additions & 0 deletions src/test/compile-fail/closure-no-fn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright 2017 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

// Ensure that capturing closures are never coerced to fns
// Especially interesting as non-capturing closures can be.

fn main() {
let mut a = 0u8;
let foo: fn(u8) -> u8 = |v: u8| { a += v; a };
//~^ ERROR mismatched types
let b = 0u8;
let bar: fn() -> u8 = || { b };
//~^ ERROR mismatched types
let baz: fn() -> u8 = || { b } as fn() -> u8;
//~^ ERROR mismatched types
//~^^ ERROR non-scalar cast
}
Loading

0 comments on commit c3075f3

Please sign in to comment.