From efec32867226f007c7a47b87ff9f5f8da595f621 Mon Sep 17 00:00:00 2001
From: Jeff Bezanson
Date: Wed, 31 Aug 2011 00:33:05 -0400
Subject: [PATCH] performing type inference inside closures

this closes issue #175
---
 j/inference.j      | 64 ++++++++++++++++++++++++++++++++++------------
 j/serialize.j      |  3 +++
 src/alloc.c        |  2 +-
 src/ast.c          |  4 +++
 src/codegen.cpp    |  2 +-
 src/dump.c         |  4 +--
 src/gf.c           |  2 +-
 src/intrinsics.cpp |  3 +++
 src/jltypes.c      | 10 +++++---
 src/julia.h        |  2 +-
 10 files changed, 69 insertions(+), 27 deletions(-)

diff --git a/j/inference.j b/j/inference.j
index 6222966aaaf46..7f073fc20caaf 100644
--- a/j/inference.j
+++ b/j/inference.j
@@ -347,7 +347,7 @@ t_func[apply_type] = (1, Inf, apply_type_tfunc)

 # other: apply

-function builtin_tfunction(f, args::Tuple, argtypes::Tuple)
+function builtin_tfunction(f::ANY, args::ANY, argtypes::ANY)
     tf = get(t_func::IdTable, f, false)
     if is(tf,false)
         # struct constructor
@@ -357,6 +357,7 @@ function builtin_tfunction(f, args::Tuple, argtypes::Tuple)
         # unknown/unhandled builtin
         return Any
     end
+    tf = tf::(Real, Real, Function)
     if !(tf[1] <= length(argtypes) <= tf[2])
         # wrong # of args
         return None
@@ -812,13 +813,13 @@ f_argnames(ast) =
 is_rest_arg(arg) = (isa(arg,Expr) && is(arg.head,symbol("::")) &&
                     ccall(:jl_is_rest_arg,Int32,(Any,), arg) != 0)

-function typeinf_task(caller)
-    result = ()
-    while true
-        (caller, args) = yieldto(caller, result)
-        result = typeinf_ext_(args...)
-    end
-end
+# function typeinf_task(caller)
+#     result = ()
+#     while true
+#         (caller, args) = yieldto(caller, result)
+#         result = typeinf_ext_(args...)
+#     end
+# end

 #Inference_Task = Task(typeinf_task, 2097152)
 #yieldto(Inference_Task, current_task())
@@ -845,6 +846,8 @@ typeinf(linfo,atypes,sparams,copy) = typeinf(linfo,atypes,sparams,copy,linfo)

 abstract RecPending{T}

+isRecPending(t) = isa(t, AbstractKind) && is(t.name, RecPending.name)
+
 # def is the original unspecialized version of a method. we aggregate all
 # saved type inference data there.
 function typeinf(linfo::LambdaStaticData,atypes::Tuple,sparams::Tuple, cop, def)
@@ -861,7 +864,7 @@ function typeinf(linfo::LambdaStaticData,atypes::Tuple,sparams::Tuple, cop, def)
                 # here instead of returning, and update the cache, until the new
                 # inferred type equals the cached type (fixed point)
                 rt = ast_rettype(tf[2])
-                if isa(rt,AbstractKind) && is(rt.name,RecPending.name)
+                if isRecPending(rt)
                     curtype = rt.parameters[1]
                     redo = true
                     ast = tf[2]
@@ -1049,7 +1052,6 @@ function typeinf(linfo::LambdaStaticData,atypes::Tuple,sparams::Tuple, cop, def)
             end
         end
     end
-    inference_stack = inference_stack.prev
     #print("\n",ast,"\n")
     #print("==> ", frame.result,"\n")
     if redo && typeseq(curtype, frame.result)
@@ -1064,7 +1066,9 @@ function typeinf(linfo::LambdaStaticData,atypes::Tuple,sparams::Tuple, cop, def)
     if !rec
         fulltree.args[3] = inlining_pass(fulltree.args[3], s[1])
         tuple_elim_pass(fulltree)
+        linfo.inferred = true
     end
+    inference_stack = inference_stack.prev
     return (fulltree, frame.result)
 end

@@ -1080,7 +1084,7 @@ function record_var_type(e::Symbol, t, decls)
         end
     end
 end
-function eval_annotate(e::Expr, vtypes, sv, decls)
+function eval_annotate(e::Expr, vtypes, sv, decls, clo)
     head = e.head
     if is(head,:quote) || is(head,:top) || is(head,:goto) ||
        is(head,:static_typeof) || is(head,:line)
@@ -1098,38 +1102,44 @@ function eval_annotate(e::Expr, vtypes, sv, decls)
         else
             e.args[1] = SymbolNode(s, abstract_eval(s, vtypes, sv))
         end
-        e.args[2] = eval_annotate(e.args[2], vtypes, sv, decls)
+        e.args[2] = eval_annotate(e.args[2], vtypes, sv, decls, clo)
         # TODO: if this def does not reach any uses, maybe don't do this
         record_var_type(s, exprtype(e.args[2]), decls)
         return e
     end
     for i=1:length(e.args)
-        e.args[i] = eval_annotate(e.args[i], vtypes, sv, decls)
+        e.args[i] = eval_annotate(e.args[i], vtypes, sv, decls, clo)
     end
     e
 end

-function eval_annotate(e::Symbol, vtypes, sv, decls)
+function eval_annotate(e::Symbol, vtypes, sv, decls, clo)
     t = abstract_eval(e, vtypes, sv)
     record_var_type(e, t, decls)
     SymbolNode(e, t)
 end

-function eval_annotate(e::SymbolNode, vtypes, sv, decls)
+function eval_annotate(e::SymbolNode, vtypes, sv, decls, clo)
     t = abstract_eval(e.name, vtypes, sv)
     record_var_type(e.name, t, decls)
     e.typ = t
     e
 end

-eval_annotate(s, vtypes, sv, decls) = s
+eval_annotate(s, vtypes, sv, decls, clo) = s
+
+function eval_annotate(l::LambdaStaticData, vtypes, sv, decls, clo)
+    push(clo, l)
+    l
+end

 # annotate types of all symbols in AST
 function type_annotate(ast::Expr, states::Array, sv, rettype, vnames)
     decls = idtable()
+    closures = {}
     body = ast.args[3].args
     for i=1:length(body)
-        body[i] = eval_annotate(body[i], states[i], sv, decls)
+        body[i] = eval_annotate(body[i], states[i], sv, decls, closures)
     end

     ast.args[3].typ = rettype
@@ -1140,6 +1150,26 @@ function type_annotate(ast::Expr, states::Array, sv, rettype, vnames)
             vi[2] = decls[vi[1]]
         end
     end
+
+    # do inference on inner functions
+    if isRecPending(rettype)
+        return ast
+    end
+
+    for li = closures
+        if !li.inferred
+            a = li.ast
+            # pass on declarations of captured vars
+            vinf = a.args[2].args[3]
+            for vi = vinf
+                if has(decls,vi[1])
+                    vi[2] = decls[vi[1]]
+                end
+            end
+            typeinf(li, NTuple{length(a.args[1]), Any}, li.sparams, false, li)
+        end
+    end
+
     ast
 end
diff --git a/j/serialize.j b/j/serialize.j
index bfdb416b80c92..a077f934b3502 100644
--- a/j/serialize.j
+++ b/j/serialize.j
@@ -146,6 +146,7 @@ function serialize(s, f::Function)
         serialize(s, lambda_number(linfo))
         serialize(s, linfo.ast)
         serialize(s, linfo.sparams)
+        serialize(s, linfo.inferred)
         serialize(s, env)
     end
 end
@@ -238,6 +239,7 @@ function deserialize_function(s)
     lnumber = force(deserialize(s))
     ast = deserialize(s)
     sparams = deserialize(s)
+    infr = force(deserialize(s))
     env = deserialize(s)
     if has(known_lambda_data, lnumber)
         linfo = known_lambda_data[lnumber]
@@ -249,6 +251,7 @@ function deserialize_function(s)
     function ()
         linfo = ccall(:jl_new_lambda_info, Any, (Any, Any),
                       force(ast), force(sparams))
+        linfo.inferred = infr
         known_lambda_data[lnumber] = linfo
         ccall(:jl_new_closure_internal, Any, (Any, Any),
               linfo, force(env))::Function
diff --git a/src/alloc.c b/src/alloc.c
index 427545b14028f..c81f9c099d661 100644
--- a/src/alloc.c
+++ b/src/alloc.c
@@ -290,7 +290,7 @@ jl_lambda_info_t *jl_new_lambda_info(jl_value_t *ast, jl_tuple_t *sparams)
     li->roots = jl_null;
     li->functionObject = NULL;
     li->specTypes = NULL;
-    li->inferred = 0;
+    li->inferred = jl_false;
     li->inInference = 0;
     li->inCompile = 0;
     li->unspecialized = NULL;
diff --git a/src/ast.c b/src/ast.c
index 9b364fee6ad22..0da375a421f12 100644
--- a/src/ast.c
+++ b/src/ast.c
@@ -554,9 +554,13 @@ static jl_value_t *copy_ast(jl_value_t *expr, jl_tuple_t *sp)
 {
     if (jl_is_lambda_info(expr)) {
         jl_lambda_info_t *li = (jl_lambda_info_t*)expr;
+        /*
         if (sp == jl_null && li->ast &&
             jl_lam_capt((jl_expr_t*)li->ast)->length == 0)
             return expr;
+        */
+        // TODO: avoid if above condition is true and decls have already
+        // been evaluated.
         li = jl_add_static_parameters(li, sp);
         jl_specialize_ast(li);
         return (jl_value_t*)li;
diff --git a/src/codegen.cpp b/src/codegen.cpp
index 85c470a4bf6be..6f538b8ac4503 100644
--- a/src/codegen.cpp
+++ b/src/codegen.cpp
@@ -1352,7 +1352,7 @@ static bool store_unboxed_p(char *name, jl_codectx_t *ctx)
     jl_value_t *jt = (*ctx->declTypes)[name];
     // only store a variable unboxed if type inference has run, which
     // checks that the variable is not referenced undefined.
-    return (ctx->linfo->inferred && jl_is_bits_type(jt) &&
+    return (ctx->linfo->inferred==jl_true && jl_is_bits_type(jt) &&
             jl_is_leaf_type(jt) &&
             // don't unbox intrinsics, since inference depends on their having
             // stable addresses for table lookup.
diff --git a/src/dump.c b/src/dump.c
index f2410b62ed0ec..6278e6356f72a 100644
--- a/src/dump.c
+++ b/src/dump.c
@@ -312,7 +312,7 @@ void jl_serialize_value_(ios_t *s, jl_value_t *v)
         jl_serialize_value(s, (jl_value_t*)li->name);
         jl_serialize_value(s, (jl_value_t*)li->specTypes);
         jl_serialize_value(s, (jl_value_t*)li->specializations);
-        write_int8(s, li->inferred);
+        jl_serialize_value(s, (jl_value_t*)li->inferred);
     }
     else if (jl_typeis(v, jl_methtable_type)) {
         writetag(s, jl_methtable_type);
@@ -645,7 +645,7 @@ jl_value_t *jl_deserialize_value(ios_t *s)
         li->name = (jl_sym_t*)jl_deserialize_value(s);
         li->specTypes = jl_deserialize_value(s);
         li->specializations = (jl_tuple_t*)jl_deserialize_value(s);
-        li->inferred = read_int8(s);
+        li->inferred = jl_deserialize_value(s);

         li->fptr = NULL;
         li->roots = jl_null;
diff --git a/src/gf.c b/src/gf.c
index f8704093bb6c3..3a01291874625 100644
--- a/src/gf.c
+++ b/src/gf.c
@@ -594,7 +594,7 @@ static jl_function_t *cache_method(jl_methtable_t *mt, jl_tuple_t *type,
 #ifdef ENABLE_INFERENCE
             jl_value_t *newast = jl_apply(jl_typeinf_func, fargs, 5);
             newmeth->linfo->ast = jl_tupleref(newast, 0);
-            newmeth->linfo->inferred = 1;
+            newmeth->linfo->inferred = jl_true;
 #endif
             newmeth->linfo->inInference = 0;
         }
diff --git a/src/intrinsics.cpp b/src/intrinsics.cpp
index 617df07046681..a3fa898174edd 100644
--- a/src/intrinsics.cpp
+++ b/src/intrinsics.cpp
@@ -269,6 +269,9 @@ static Value *boxed(Value *v)
     if (jb == jl_uint32_type) return builder.CreateCall(box_uint32_func, v);
     if (jb == jl_uint64_type) return builder.CreateCall(box_uint64_func, v);
     if (jl_is_bits_type(jt)) {
+        if (v->getType()->isPointerTy()) {
+            v = builder.CreatePtrToInt(v, T_size);
+        }
         int nb = jl_bitstype_nbits(jt);
         if (nb == 8)
             return builder.CreateCall2(box8_func, literal_pointer_val(jt), v);
diff --git a/src/jltypes.c b/src/jltypes.c
index a28fda4096557..1337933b13f82 100644
--- a/src/jltypes.c
+++ b/src/jltypes.c
@@ -2470,18 +2470,20 @@ void jl_init_types()
     jl_lambda_info_type =
         jl_new_struct_type(jl_symbol("LambdaStaticData"),
                            jl_any_type, jl_null,
-                           jl_tuple(8, jl_symbol("ast"), jl_symbol("sparams"),
+                           jl_tuple(9, jl_symbol("ast"), jl_symbol("sparams"),
                                     jl_symbol("tfunc"), jl_symbol("name"),
                                     /*
                                     jl_symbol("roots"), jl_symbol("specTypes"),
                                     jl_symbol("unspecialized"),
                                     jl_symbol("specializations")*/
                                     jl_symbol(""), jl_symbol(""),
-                                    jl_symbol(""), jl_symbol("")),
-                           jl_tuple(8, jl_expr_type, jl_tuple_type,
+                                    jl_symbol(""), jl_symbol(""),
+                                    jl_symbol("inferred")),
+                           jl_tuple(9, jl_expr_type, jl_tuple_type,
                                     jl_any_type, jl_sym_type,
                                     jl_any_type, jl_tuple_type,
-                                    jl_function_type, jl_tuple_type));
+                                    jl_function_type, jl_tuple_type,
+                                    jl_bool_type));
     jl_lambda_info_type->fptr = jl_f_no_function;

     jl_box_type =
diff --git a/src/julia.h b/src/julia.h
index b8b072facbdd5..804d0a4751186 100644
--- a/src/julia.h
+++ b/src/julia.h
@@ -86,11 +86,11 @@ typedef struct _jl_lambda_info_t {
     struct _jl_function_t *unspecialized;
     // pairlist of all lambda infos with code generated from this one
     jl_tuple_t *specializations;
+    jl_value_t *inferred;

     // hidden fields:
     jl_fptr_t fptr;
     void *functionObject;
-    uptrint_t inferred;
     // flag telling if inference is running on this function
     // used to avoid infinite recursion
     uptrint_t inInference;
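
A minimal sketch of the closure pattern this commit is about (illustrative only, not part of the patch, and written in present-day Julia syntax): `increment` is an inner function whose lambda previously stayed uninferred; with this change, typeinf is also run on the inner LambdaStaticData and the declared type of the captured variable is passed along, so the closure body can be inferred rather than treated as fully dynamic.

# Illustrative sketch -- not part of the patch above.
# An inner function capturing a typed local; the commit makes type
# inference recurse into the closure instead of leaving it uninferred.
function make_counter(start::Int)
    count = start
    increment() = (count += 1; count)   # closure over `count`
    return increment
end

c = make_counter(0)
println(c())   # prints 1
println(c())   # prints 2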