Skip to content

Commit

Permalink
performing type inference inside closures
Browse files Browse the repository at this point in the history
this closes issue #175
  • Loading branch information
JeffBezanson committed Aug 31, 2011
1 parent c0755e1 commit efec328
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 27 deletions.
64 changes: 47 additions & 17 deletions j/inference.j
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ t_func[apply_type] = (1, Inf, apply_type_tfunc)
# other: apply
function builtin_tfunction(f, args::Tuple, argtypes::Tuple)
function builtin_tfunction(f::ANY, args::ANY, argtypes::ANY)
tf = get(t_func::IdTable, f, false)
if is(tf,false)
# struct constructor
Expand All @@ -357,6 +357,7 @@ function builtin_tfunction(f, args::Tuple, argtypes::Tuple)
# unknown/unhandled builtin
return Any
end
tf = tf::(Real, Real, Function)
if !(tf[1] <= length(argtypes) <= tf[2])
# wrong # of args
return None
Expand Down Expand Up @@ -812,13 +813,13 @@ f_argnames(ast) =
is_rest_arg(arg) = (isa(arg,Expr) && is(arg.head,symbol("::")) &&
ccall(:jl_is_rest_arg,Int32,(Any,), arg) != 0)

function typeinf_task(caller)
result = ()
while true
(caller, args) = yieldto(caller, result)
result = typeinf_ext_(args...)
end
end
# function typeinf_task(caller)
# result = ()
# while true
# (caller, args) = yieldto(caller, result)
# result = typeinf_ext_(args...)
# end
# end

#Inference_Task = Task(typeinf_task, 2097152)
#yieldto(Inference_Task, current_task())
Expand All @@ -845,6 +846,8 @@ typeinf(linfo,atypes,sparams,copy) = typeinf(linfo,atypes,sparams,copy,linfo)

abstract RecPending{T}

isRecPending(t) = isa(t, AbstractKind) && is(t.name, RecPending.name)

# def is the original unspecialized version of a method. we aggregate all
# saved type inference data there.
function typeinf(linfo::LambdaStaticData,atypes::Tuple,sparams::Tuple, cop, def)
Expand All @@ -861,7 +864,7 @@ function typeinf(linfo::LambdaStaticData,atypes::Tuple,sparams::Tuple, cop, def)
# here instead of returning, and update the cache, until the new
# inferred type equals the cached type (fixed point)
rt = ast_rettype(tf[2])
if isa(rt,AbstractKind) && is(rt.name,RecPending.name)
if isRecPending(rt)
curtype = rt.parameters[1]
redo = true
ast = tf[2]
Expand Down Expand Up @@ -1049,7 +1052,6 @@ function typeinf(linfo::LambdaStaticData,atypes::Tuple,sparams::Tuple, cop, def)
end
end
end
inference_stack = inference_stack.prev
#print("\n",ast,"\n")
#print("==> ", frame.result,"\n")
if redo && typeseq(curtype, frame.result)
Expand All @@ -1064,7 +1066,9 @@ function typeinf(linfo::LambdaStaticData,atypes::Tuple,sparams::Tuple, cop, def)
if !rec
fulltree.args[3] = inlining_pass(fulltree.args[3], s[1])
tuple_elim_pass(fulltree)
linfo.inferred = true
end
inference_stack = inference_stack.prev
return (fulltree, frame.result)
end

Expand All @@ -1080,7 +1084,7 @@ function record_var_type(e::Symbol, t, decls)
end
end

function eval_annotate(e::Expr, vtypes, sv, decls)
function eval_annotate(e::Expr, vtypes, sv, decls, clo)
head = e.head
if is(head,:quote) || is(head,:top) || is(head,:goto) ||
is(head,:static_typeof) || is(head,:line)
Expand All @@ -1098,38 +1102,44 @@ function eval_annotate(e::Expr, vtypes, sv, decls)
else
e.args[1] = SymbolNode(s, abstract_eval(s, vtypes, sv))
end
e.args[2] = eval_annotate(e.args[2], vtypes, sv, decls)
e.args[2] = eval_annotate(e.args[2], vtypes, sv, decls, clo)
# TODO: if this def does not reach any uses, maybe don't do this
record_var_type(s, exprtype(e.args[2]), decls)
return e
end
for i=1:length(e.args)
e.args[i] = eval_annotate(e.args[i], vtypes, sv, decls)
e.args[i] = eval_annotate(e.args[i], vtypes, sv, decls, clo)
end
e
end

function eval_annotate(e::Symbol, vtypes, sv, decls)
function eval_annotate(e::Symbol, vtypes, sv, decls, clo)
t = abstract_eval(e, vtypes, sv)
record_var_type(e, t, decls)
SymbolNode(e, t)
end

function eval_annotate(e::SymbolNode, vtypes, sv, decls)
function eval_annotate(e::SymbolNode, vtypes, sv, decls, clo)
t = abstract_eval(e.name, vtypes, sv)
record_var_type(e.name, t, decls)
e.typ = t
e
end

eval_annotate(s, vtypes, sv, decls) = s
eval_annotate(s, vtypes, sv, decls, clo) = s

function eval_annotate(l::LambdaStaticData, vtypes, sv, decls, clo)
push(clo, l)
l
end

# annotate types of all symbols in AST
function type_annotate(ast::Expr, states::Array, sv, rettype, vnames)
decls = idtable()
closures = {}
body = ast.args[3].args
for i=1:length(body)
body[i] = eval_annotate(body[i], states[i], sv, decls)
body[i] = eval_annotate(body[i], states[i], sv, decls, closures)
end
ast.args[3].typ = rettype

Expand All @@ -1140,6 +1150,26 @@ function type_annotate(ast::Expr, states::Array, sv, rettype, vnames)
vi[2] = decls[vi[1]]
end
end

# do inference on inner functions
if isRecPending(rettype)
return ast
end

for li = closures
if !li.inferred
a = li.ast
# pass on declarations of captured vars
vinf = a.args[2].args[3]
for vi = vinf
if has(decls,vi[1])
vi[2] = decls[vi[1]]
end
end
typeinf(li, NTuple{length(a.args[1]), Any}, li.sparams, false, li)
end
end

ast
end

Expand Down
3 changes: 3 additions & 0 deletions j/serialize.j
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ function serialize(s, f::Function)
serialize(s, lambda_number(linfo))
serialize(s, linfo.ast)
serialize(s, linfo.sparams)
serialize(s, linfo.inferred)
serialize(s, env)
end
end
Expand Down Expand Up @@ -238,6 +239,7 @@ function deserialize_function(s)
lnumber = force(deserialize(s))
ast = deserialize(s)
sparams = deserialize(s)
infr = force(deserialize(s))
env = deserialize(s)
if has(known_lambda_data, lnumber)
linfo = known_lambda_data[lnumber]
Expand All @@ -249,6 +251,7 @@ function deserialize_function(s)
function ()
linfo = ccall(:jl_new_lambda_info, Any, (Any, Any),
force(ast), force(sparams))
linfo.inferred = infr
known_lambda_data[lnumber] = linfo
ccall(:jl_new_closure_internal, Any, (Any, Any),
linfo, force(env))::Function
Expand Down
2 changes: 1 addition & 1 deletion src/alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ jl_lambda_info_t *jl_new_lambda_info(jl_value_t *ast, jl_tuple_t *sparams)
li->roots = jl_null;
li->functionObject = NULL;
li->specTypes = NULL;
li->inferred = 0;
li->inferred = jl_false;
li->inInference = 0;
li->inCompile = 0;
li->unspecialized = NULL;
Expand Down
4 changes: 4 additions & 0 deletions src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -554,9 +554,13 @@ static jl_value_t *copy_ast(jl_value_t *expr, jl_tuple_t *sp)
{
if (jl_is_lambda_info(expr)) {
jl_lambda_info_t *li = (jl_lambda_info_t*)expr;
/*
if (sp == jl_null && li->ast &&
jl_lam_capt((jl_expr_t*)li->ast)->length == 0)
return expr;
*/
// TODO: avoid if above condition is true and decls have already
// been evaluated.
li = jl_add_static_parameters(li, sp);
jl_specialize_ast(li);
return (jl_value_t*)li;
Expand Down
2 changes: 1 addition & 1 deletion src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,7 @@ static bool store_unboxed_p(char *name, jl_codectx_t *ctx)
jl_value_t *jt = (*ctx->declTypes)[name];
// only store a variable unboxed if type inference has run, which
// checks that the variable is not referenced undefined.
return (ctx->linfo->inferred && jl_is_bits_type(jt) &&
return (ctx->linfo->inferred==jl_true && jl_is_bits_type(jt) &&
jl_is_leaf_type(jt) &&
// don't unbox intrinsics, since inference depends on their having
// stable addresses for table lookup.
Expand Down
4 changes: 2 additions & 2 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ void jl_serialize_value_(ios_t *s, jl_value_t *v)
jl_serialize_value(s, (jl_value_t*)li->name);
jl_serialize_value(s, (jl_value_t*)li->specTypes);
jl_serialize_value(s, (jl_value_t*)li->specializations);
write_int8(s, li->inferred);
jl_serialize_value(s, (jl_value_t*)li->inferred);
}
else if (jl_typeis(v, jl_methtable_type)) {
writetag(s, jl_methtable_type);
Expand Down Expand Up @@ -645,7 +645,7 @@ jl_value_t *jl_deserialize_value(ios_t *s)
li->name = (jl_sym_t*)jl_deserialize_value(s);
li->specTypes = jl_deserialize_value(s);
li->specializations = (jl_tuple_t*)jl_deserialize_value(s);
li->inferred = read_int8(s);
li->inferred = jl_deserialize_value(s);

li->fptr = NULL;
li->roots = jl_null;
Expand Down
2 changes: 1 addition & 1 deletion src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ static jl_function_t *cache_method(jl_methtable_t *mt, jl_tuple_t *type,
#ifdef ENABLE_INFERENCE
jl_value_t *newast = jl_apply(jl_typeinf_func, fargs, 5);
newmeth->linfo->ast = jl_tupleref(newast, 0);
newmeth->linfo->inferred = 1;
newmeth->linfo->inferred = jl_true;
#endif
newmeth->linfo->inInference = 0;
}
Expand Down
3 changes: 3 additions & 0 deletions src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ static Value *boxed(Value *v)
if (jb == jl_uint32_type) return builder.CreateCall(box_uint32_func, v);
if (jb == jl_uint64_type) return builder.CreateCall(box_uint64_func, v);
if (jl_is_bits_type(jt)) {
if (v->getType()->isPointerTy()) {
v = builder.CreatePtrToInt(v, T_size);
}
int nb = jl_bitstype_nbits(jt);
if (nb == 8)
return builder.CreateCall2(box8_func, literal_pointer_val(jt), v);
Expand Down
10 changes: 6 additions & 4 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2470,18 +2470,20 @@ void jl_init_types()
jl_lambda_info_type =
jl_new_struct_type(jl_symbol("LambdaStaticData"),
jl_any_type, jl_null,
jl_tuple(8, jl_symbol("ast"), jl_symbol("sparams"),
jl_tuple(9, jl_symbol("ast"), jl_symbol("sparams"),
jl_symbol("tfunc"), jl_symbol("name"),
/*
jl_symbol("roots"), jl_symbol("specTypes"),
jl_symbol("unspecialized"),
jl_symbol("specializations")*/
jl_symbol(""), jl_symbol(""),
jl_symbol(""), jl_symbol("")),
jl_tuple(8, jl_expr_type, jl_tuple_type,
jl_symbol(""), jl_symbol(""),
jl_symbol("inferred")),
jl_tuple(9, jl_expr_type, jl_tuple_type,
jl_any_type, jl_sym_type,
jl_any_type, jl_tuple_type,
jl_function_type, jl_tuple_type));
jl_function_type, jl_tuple_type,
jl_bool_type));
jl_lambda_info_type->fptr = jl_f_no_function;

jl_box_type =
Expand Down
2 changes: 1 addition & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ typedef struct _jl_lambda_info_t {
struct _jl_function_t *unspecialized;
// pairlist of all lambda infos with code generated from this one
jl_tuple_t *specializations;
jl_value_t *inferred;

// hidden fields:
jl_fptr_t fptr;
void *functionObject;
uptrint_t inferred;
// flag telling if inference is running on this function
// used to avoid infinite recursion
uptrint_t inInference;
Expand Down

0 comments on commit efec328

Please sign in to comment.