Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inline invoke #9642

Closed
wants to merge 14 commits into from
Closed
189 changes: 168 additions & 21 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -787,30 +787,29 @@ function abstract_call_gf(f, fargs, argtypes, e)
end

function invoke_tfunc(f, types, argtypes)
argtypes = typeintersect(types,limit_tuple_type(argtypes))
if is(argtypes,Bottom)
argtypes = typeintersect(types, limit_tuple_type(argtypes))
if is(argtypes, Bottom)
return Bottom
end
applicable = _methods(f, types, -1)
if isempty(applicable)
meth = try
ccall(:jl_gf_invoke_lookup, Any, (Any, Any), f, types)
end
if is(meth, nothing)
return Any
end
for (m::Tuple) in applicable
local linfo
try
linfo = func_for_method(m[3],types,m[2])
catch
return Any
end
if typeseq(m[1],types)
tvars = m[2][1:2:end]
(ti, env) = ccall(:jl_match_method, Any, (Any,Any,Any),
argtypes, m[1], tvars)::(Any,Any)
(_tree,rt) = typeinf(linfo, ti, env, linfo)
return rt
end
(ti, env) = ccall(:jl_match_method, Any, (Any, Any, Any),
argtypes, meth.sig, meth.tvars)::(Any, Any)
if !isa(ti, Tuple)
return Any
end
return Any
local linfo
try
linfo = func_for_method(meth, types, env)
catch
return Any
end
(_tree, rt) = typeinf(linfo, ti, env, linfo)
return rt
end

function to_tuple_of_Types(t::ANY)
Expand Down Expand Up @@ -895,6 +894,7 @@ function abstract_call(f, fargs, argtypes, vtypes, sv::StaticVarInfo, e)
if !is(af,false) && (af=_ieval(af);isgeneric(af))
sig = argtypes[2]
if isa(sig,Tuple) && all(isType, sig)
e.head = :call1
sig = map(t->t.parameters[1], sig)
return invoke_tfunc(af, sig, argtypes[3:end])
end
Expand Down Expand Up @@ -2249,13 +2249,155 @@ function inlineable(f::ANY, e::Expr, atypes::Tuple, sv::StaticVarInfo, enclosing
if isa(f,IntrinsicFunction)
return NF
end
if is(f, invoke)
return inlineable_invoke(f, e, atypes, sv, enclosing_ast, argexprs)
end

return inlineable_gf(f, e, atypes, sv, enclosing_ast, argexprs)
end

# Unwrap a single `Type{T}` (the argument is the type `Type{T}` itself, as
# found in an inferred argument-type tuple) and hand back the wrapped `T`.
get_invoke_types{T<:Top}(ts::Type{Type{T}}) = T

# Tuple method: unwrap every element with `get_invoke_types` and return the
# resulting tuple, asserted to be a `Type`. An element that no method of
# `get_invoke_types` accepts makes this throw.
function get_invoke_types(ts::Tuple)
    unwrapped = map(get_invoke_types, ts)
    return unwrapped::Type
end

function inlineable_invoke(f::Function, e::Expr, atypes, sv,
enclosing_ast, argexprs)
f = isconstantfunc(argexprs[1], sv)
if is(f, false)
return NF
end
f = _ieval(f)
local invoke_types
try
invoke_types = get_invoke_types(atypes[2])
catch
return NF
end
if length(atypes) != (length(invoke_types) + 2)
return NF
end

fexpr = argexprs[1]
# Always evaluate the types expression and let codegen eliminate
# unnecessary code for now.
stmts = Any[argexprs[2]]
argexprs = argexprs[3:end]
atypes = atypes[3:end]

# Special case when invoke is equivalent with direct call.
if isleaftype(invoke_types) && invoke_types == atypes
new_e = Expr(:call, fexpr, argexprs...)
new_e.typ = e.typ
res = inlineable_gf(f, new_e, invoke_types, sv,
enclosing_ast, argexprs)
if isa(res, Tuple)
if isa(res[2], Array)
append!(stmts, res[2])
end
return (res[1], stmts)
else
return (new_e, stmts)
end
end

meth = try
ccall(:jl_gf_invoke_lookup, Any, (Any, Any), f, invoke_types)
end
if is(meth, nothing)
return NF
end

# TODO: pre-evaluation is only necessary when a type check is needed and
# when the arguments have side effects
check_stmts = []
atypes_l = Type[atypes...]
err_label = genlabel(sv)
after_err_label = genlabel(sv)
for i in 1:length(atypes_l)
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when inlining arguments, you need to go in reverse order to preserve the order-of-execution (see https://github.com/yuyichao/julia/blob/inline-invoke/base/inference.jl#L2770-L2771 for an example)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is what I'm trying to do here. I've seen that part but I didn't really understand why. Aren't the arguments evaluated in the order they appear in the code?

Also, the generated code seems to be correct here, and that's why I didn't bother too much before sending the PR

julia> @noinline function get_next()
           global counter
           counter = counter + 1
           return counter
       end
get_next (generic function with 1 method)

julia> f(a, b) = (a, b)
f (generic function with 1 method)

julia> g() = invoke(f, (Integer, Integer), get_next(), get_next())
g (generic function with 1 method)

julia> @code_typed g()
1-element Array{Any,1}:
 :($(Expr(:lambda, Any[], Any[Any[:_var0,:_var1],Any[Any[:_var0,Any,18],Any[:_var1,Any,18]],Any[]], :(begin  # none, line 1:
        _var0 = get_next()::Any
        _var1 = get_next()::Any
        unless (isa)(_var0,Integer)::Bool goto 1
        unless (isa)(_var1,Integer)::Bool goto 1
        goto 2
        1: 
        (error)("invoke: argument type error")::Union()
        2: 
        return (top(tuple))(_var0::Integer,_var1::Integer)::(Integer,Integer)
    end::(Integer,Integer)))))

julia> g()
ERROR: counter not defined
 in get_next at ./no file:3
 in g at ./none:1

julia> counter = 0
0

julia> g()
(1,2)

Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm waiting for a build now, but what if you change the second argument to invoke to (Any, Integer)?

Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, i wasn't paying attention to the fact that you are always copying the argument to a temporary variable. there's no need to do that if the argument is effect_free / affect_free (the difference linguistically is subtle: effect-free means that it does not cause an effect on surrounding code (e.g. pure), whereas affect-free means it is not affected by surrounding code (e.g. immutable))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type check is done after all arguments are evaluated, so changing to (Any, Integer) does not affect the evaluation of the arguments at all (which is the same semantics as calling the invoke function).

(Actually, I think I might be missing the case where evaluating the second argument to invoke has side effects)....

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooops, forgot to paste the output................

julia> @noinline function get_next()
           global counter
           counter = counter + 1
           return counter
       end
get_next (generic function with 1 method)

julia> counter = 0
0

julia> f(a, b) = (a, b)
f (generic function with 1 method)

julia> g() = invoke(f, (Any, Integer), get_next(), get_next())
g (generic function with 1 method)

julia> @code_typed g()
1-element Array{Any,1}:
g() :($(Expr(:lambda, Any[], Any[Any[:_var0,:_var1],Any[Any[:_var0,Any,18],Any[:_var1,Any,18]],Any[]], :(begin  # none, line 1:
        _var0 = get_next()::Any
        _var1 = get_next()::Any
        unless (isa)(_var1,Integer)::Bool goto 1
        goto 2
        1: 
        (error)("invoke: argument type error")::Union()
        2: 
        return (top(tuple))(_var0::Any,_var1::Integer)::(Any,Integer)
    end::(Any,Integer)))))

julia> g()
(1,2)

arg_name = unique_name(enclosing_ast)
tmp_ex = :($arg_name = $(argexprs[i]))
tmp_ex.typ = atypes_l[i]
add_variable(enclosing_ast, arg_name, atypes_l[i], true)
push!(stmts, tmp_ex)
if !issubtype(atypes_l[i], invoke_types[i])
atypes_l[i] = typeintersect(atypes_l[i], invoke_types[i])
check_type = :($(TopNode(:isa))($arg_name, $(invoke_types[i])))
check_type.typ = Bool
push!(check_stmts, Expr(:gotoifnot, check_type, err_label.label))
end
argexprs[i] = SymbolNode(arg_name, atypes_l[i])
end
if !isempty(check_stmts)
append!(stmts, check_stmts)
push!(stmts, gn(after_err_label))
push!(stmts, err_label)
check_error = :($(TopNode(:error))("invoke: argument type error"))
check_error.typ = None
push!(stmts, check_error)
push!(stmts, after_err_label)
end
atypes = tuple(atypes_l...)

match_meth = _match_method(meth, atypes)
new_e = Expr(:call, fexpr, argexprs...)
new_e.typ = e.typ
if length(match_meth) == 1
match_meth = match_meth[1]::Tuple
# Try inlining
res = inlineable_meth(f, match_meth, new_e, atypes, sv, enclosing_ast,
argexprs, true)
if isa(res, Tuple)
if isa(res[2], Array)
append!(stmts, res[2])
end
return (res[1], stmts)
end
end
return NF
end

# Entry point: start the recursive worker below with a mutable `Any` copy of
# the argument-type tuple `t`, the index of its last element, and an empty
# result list. Returns whatever the worker accumulates.
function _match_method(m::ANY, t::ANY)
    argts = Any[(t::Tuple)...]
    return _match_method(m, argts, length(t::Tuple), [])
end
# Recursive worker. Walks the argument types in `t` from index `i` down to 1;
# whenever the type at the current index is a `UnionType`, each of its
# components is substituted in turn and the remaining indices are processed
# for that substitution (the original entry is restored afterwards).
# Once `i` reaches 0, the current contents of `t` are matched against
# `m.sig` / `m.tvars` with `jl_match_method`, and `(ti, env, m)` is appended
# to `matching` when the returned `ti` is a Tuple.
# NOTE(review): a non-Tuple `ti` presumably means "no match" -- confirm
# against jl_match_method.
# Always returns `matching`.
function _match_method(m::ANY, t::Array, i, matching::Array{Any, 1})
    if i != 0
        cur = t[i]
        if !isa(cur, UnionType)
            # Nothing to expand at this position; move on to the next one.
            return _match_method(m, t, i - 1, matching)
        end
        for alt in (cur::UnionType).types
            t[i] = alt
            _match_method(m, t, i - 1, matching)
        end
        # Put the union back so callers further up see `t` unchanged.
        t[i] = cur
        return matching
    end
    (ti, env) = ccall(:jl_match_method, Any, (Any, Any, Any),
                      tuple(t...), m.sig, m.tvars)::(Any, Any)
    if isa(ti, Tuple)
        push!(matching, tuple(ti, env, m))
    end
    return matching
end

# Attempt to inline an ordinary call to the generic function `f`.
#
# Looks up the methods of `f` applicable to `atypes` (the lookup is asked for
# at most one result -- NOTE(review): confirm the meaning of the third
# `_methods` argument) and gives up with `NF` unless exactly one match comes
# back. Otherwise the match is forwarded to `inlineable_meth`, whose result is
# returned unchanged.
function inlineable_gf(f::Function, e::Expr, atypes, sv, enclosing_ast,
                       argexprs)
    meth = _methods(f, atypes, 1)
    if meth === false || length(meth) != 1
        return NF
    end
    # `inlineable_meth` indexes its `meth` argument as a full match tuple
    # (it reads meth[2] and meth[3]), so pass the whole match, exactly as
    # `inlineable_invoke` does. Passing `matched[1]` would hand over only
    # the first component of the match.
    matched = meth[1]::Tuple
    return inlineable_meth(f, matched, e, atypes, sv, enclosing_ast, argexprs)
end

function inlineable_meth(f::Function, meth::Tuple, e::Expr, atypes, sv,
enclosing_ast, argexprs, is_invoke::Bool=false)
# NOTE: when is_invoke is true, this function shouldn't do or generate any
# code that does method lookup based on argument types. More specifically,
# arguments f and e (e.args[1]) must be used with care.
local linfo
try
linfo = func_for_method(meth[3],atypes,meth[2])
Expand All @@ -2267,6 +2409,7 @@ function inlineable(f::ANY, e::Expr, atypes::Tuple, sv::StaticVarInfo, enclosing
## growing due to recursion.
## It might be helpful for some things, but turns out not to be
## necessary to get max performance from recursive varargs functions.
## NOTE: Need to adapt with is_invoke if the following code is re-enabled.
# if length(atypes) > MAX_TUPLETYPE_LEN
# # check call stack to see if this argument list is growing
# st = inference_stack
Expand Down Expand Up @@ -2346,7 +2489,10 @@ function inlineable(f::ANY, e::Expr, atypes::Tuple, sv::StaticVarInfo, enclosing
cost /= 4
end
if !inline_worthy(body, cost)
if incompletematch
# incompletematch shouldn't happen for invoke because of the type check
# generated before entering inlineable_meth but just in case something
# changes in the future
if incompletematch && !is_invoke
# inline a typeassert-based call-site, rather than a
# full generic lookup, using the inliner to handle
# all the fiddly details
Expand Down Expand Up @@ -2704,6 +2850,7 @@ function inlineable(f::ANY, e::Expr, atypes::Tuple, sv::StaticVarInfo, enclosing
end
return (expr, stmts)
end

# The inlining incomplete matches optimization currently
# doesn't work on Tuples of TypeVars
const inline_incompletematch_allowed = false
Expand Down
2 changes: 1 addition & 1 deletion deps/openspecfun
Submodule openspecfun updated from 381db9 to f3036c
2 changes: 1 addition & 1 deletion src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,7 @@ static void eval_decl_types(jl_array_t *vi, jl_value_t *ast, jl_tuple_t *spenv)
assert(jl_array_len(v) > 1);
jl_value_t *ty = jl_static_eval(jl_cellref(v,1), NULL, jl_current_module,
(jl_value_t*)spenv, (jl_expr_t*)ast, 1, 1);
if (ty != NULL && (jl_is_type(ty) || jl_is_typevar(ty))) {
if (ty != NULL && jl_is_type(ty)) {
jl_cellset(v, 1, ty);
}
else {
Expand Down
2 changes: 1 addition & 1 deletion src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ JL_CALLABLE(jl_f_union)
size_t i;
jl_tuple_t *argt = jl_alloc_tuple_uninit(nargs);
for(i=0; i < nargs; i++) {
if (!jl_is_type(args[i]) && !jl_is_typevar(args[i])) {
if (!jl_is_type(args[i])) {
jl_error("invalid union type");
}
else {
Expand Down
3 changes: 2 additions & 1 deletion src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1108,8 +1108,9 @@ static jl_value_t *expr_type(jl_value_t *e, jl_codectx_t *ctx)
return (jl_value_t*)jl_any_type;
}
type_of_constant:
if (jl_is_datatype(e) || jl_is_uniontype(e) || jl_is_typector(e))
if (jl_is_datatype(e) || jl_is_uniontype(e) || jl_is_typector(e)) {
return (jl_value_t*)jl_wrap_Type(e);
}
return (jl_value_t*)jl_typeof(e);
}

Expand Down
86 changes: 85 additions & 1 deletion src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,13 @@ static Value *global_binding_pointer(jl_module_t *m, jl_sym_t *s,
jl_binding_t **pbnd, bool assign);
static Value *emit_checked_var(Value *bp, jl_sym_t *name, jl_codectx_t *ctx, bool isvol=false);
static bool might_need_root(jl_value_t *ex);
static Value *emit_condition(jl_value_t *cond, const std::string &msg, jl_codectx_t *ctx);
static Value *emit_condition(jl_value_t *cond, const std::string &msg,
jl_codectx_t *ctx);
static Value *emit_call_function_object(jl_function_t *f, Value *theF,
Value *theFptr,
bool specialized,
jl_value_t **args, size_t nargs,
jl_codectx_t *ctx);

// NoopType
static Type *NoopType;
Expand Down Expand Up @@ -1878,6 +1884,39 @@ static Value *emit_f_is(jl_value_t *rt1, jl_value_t *rt2,
return answer;
}

// Recover the tuple of signature types from `a`, the statically known type
// of the `types` argument of an `invoke` call.
//
// Accepted shapes of `a`:
//   * Type{T} where T is a tuple      -> *ptemp = T
//   * a tuple whose elements are all Type{...}, optionally ending in a
//     vararg whose parameter is itself a Type{...}
//                                     -> *ptemp = freshly allocated tuple of
//                                        the unwrapped element types
// Returns 1 on success and 0 otherwise. On failure part-way through a tuple
// *ptemp is reset to NULL; in the other failure paths it is left untouched.
static int
get_invoke_types(jl_value_t *a, jl_value_t **ptemp)
{
    if (jl_is_type_type(a)) {
        jl_value_t *sig = jl_tparam0(a);
        if (!jl_is_tuple(sig))
            return 0;
        *ptemp = sig;
        return 1;
    }
    if (!jl_is_tuple(a))
        return 0;

    jl_tuple_t *src = (jl_tuple_t*)a;
    int n = jl_tuple_len(src);
    jl_tuple_t *sig = jl_alloc_tuple(n);
    // Store the new tuple in the caller's slot before filling it in.
    // NOTE(review): presumably the slot is GC-rooted by the caller, which
    // keeps `sig` alive across jl_wrap_vararg below -- confirm.
    *ptemp = (jl_value_t*)sig;
    for (int i = 0; i < n; i++) {
        jl_value_t *el = jl_tupleref(src, i);
        if (jl_is_type_type(el)) {
            jl_tupleset(sig, i, jl_tparam0(el));
            continue;
        }
        // A vararg element is only accepted in the final position.
        if (i == n - 1 && jl_is_vararg_type(el)) {
            jl_value_t *var_type = jl_tparam0(el);
            if (jl_is_type_type(var_type)) {
                jl_tupleset(sig, i, jl_wrap_vararg(jl_tparam0(var_type)));
                continue;
            }
        }
        *ptemp = NULL;
        return 0;
    }
    return 1;
}

static Value *emit_known_call(jl_value_t *ff, jl_value_t **args, size_t nargs,
jl_codectx_t *ctx,
Value **theFptr, jl_function_t **theF,
Expand Down Expand Up @@ -2496,6 +2535,51 @@ static Value *emit_known_call(jl_value_t *ff, jl_value_t **args, size_t nargs,
}
}
}
else if (f->fptr == &jl_f_invoke && nargs >= 2 &&
expr_type(args[1], ctx) == (jl_value_t*)jl_function_type) {
f = (jl_function_t*)static_eval(args[1], ctx, true);
rt1 = (jl_value_t*)f;
if (!f || !jl_is_gf(f)) {
JL_GC_POP();
return NULL;
}
// Hack since expr_type doesn't support tuple types.
rt3 = static_eval(args[2], ctx, true);
if (rt3) {
if (!jl_is_tuple(rt3)) {
JL_GC_POP();
return NULL;
}
} else {
jl_value_t *type_types = expr_type(args[2], ctx);
rt2 = type_types;
if (!get_invoke_types(type_types, &rt3)) {
JL_GC_POP();
return NULL;
}
rt2 = NULL;
}
jl_tuple_t *tt = call_arg_types(&args[3], nargs - 2, ctx);
if (tt == NULL) {
JL_GC_POP();
return NULL;
}
rt2 = (jl_value_t*)tt;
jl_function_t *mfunc =
jl_gf_invoke_get_specialization(f, (jl_tuple_t*)rt3, tt);
JL_GC_POP();
if (mfunc == NULL) {
return NULL;
}
assert(mfunc->linfo->functionObject != NULL);
jl_cstyle_compile(mfunc);
emit_expr(args[2], ctx);
Value *_theF = literal_pointer_val((jl_value_t*)mfunc);
Value *_theFptr = (Value*)mfunc->linfo->functionObject;
Value *result = emit_call_function_object(mfunc, _theF, _theFptr, true,
args + 2, nargs - 2, ctx);
return result;
}
// TODO: other known builtins
JL_GC_POP();
return NULL;
Expand Down
Loading