diff --git a/src/codegen.cpp b/src/codegen.cpp index 69cf58a2c814f..770115df8ddf0 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -568,6 +568,10 @@ static Value *global_binding_pointer(jl_module_t *m, jl_sym_t *s, static Value *emit_checked_var(Value *bp, jl_sym_t *name, jl_codectx_t *ctx, bool isvol=false); static bool might_need_root(jl_value_t *ex); static Value *emit_condition(jl_value_t *cond, const std::string &msg, jl_codectx_t *ctx); +static Value *emit_call_function_object(jl_function_t *f, Value *theF, + Value *theFptr, bool specialized, + jl_value_t **args, size_t nargs, + jl_codectx_t *ctx); // NoopType static Type *NoopType; @@ -2346,6 +2350,52 @@ static Value *emit_known_call(jl_value_t *ff, jl_value_t **args, size_t nargs, } } } + else if (f->fptr == &jl_f_invoke && nargs >= 2 && + expr_type(args[1], ctx) == (jl_value_t*)jl_function_type) { + // Get function + rt1 = static_eval(args[1], ctx); + f = (jl_function_t*)rt1; + if (!f || !jl_is_gf(f)) { + JL_GC_POP(); + return NULL; + } + // Get types + jl_value_t *type_types = expr_type(args[2], ctx); + rt2 = type_types; + // Only accept Type{Tuple{...}} + if (!jl_is_type_type(type_types) || + !jl_is_tuple_type(jl_tparam0(type_types))) { + JL_GC_POP(); + return NULL; + } + rt2 = jl_tparam0(type_types); + // Get types of real arguments + jl_svec_t *aty = call_arg_types(&args[3], nargs - 2, ctx); + if (!aty) { + JL_GC_POP(); + return NULL; + } + rt3 = (jl_value_t*)aty; + rt3 = (jl_value_t*)jl_apply_tuple_type(aty); + jl_function_t *mfunc = + jl_gf_invoke_get_specialization(f, (jl_tupletype_t*)rt2, + (jl_tupletype_t*)rt3); + JL_GC_POP(); + if (mfunc == NULL) { + return NULL; + } + if (!is_constant(args[2], ctx)) + emit_expr(args[2], ctx); + assert(mfunc->linfo->functionObject != NULL); + Value *_theF = literal_pointer_val((jl_value_t*)mfunc); + Value *_theFptr = emit_nthptr_recast( + _theF, + (ssize_t)(offsetof(jl_function_t, fptr) / sizeof(void*)), + tbaa_func, + jl_pfptr_llvmt); + return emit_call_function_object(mfunc, _theF, _theFptr, true, + args + 2, nargs - 2, ctx); + } // TODO: other known builtins JL_GC_POP(); return NULL; diff --git a/src/gf.c b/src/gf.c index c8d080889e9cd..8d73c544ab97f 100644 --- a/src/gf.c +++ b/src/gf.c @@ -1782,6 +1782,85 @@ DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_function_t *gf, jl_datatype_t *type return (jl_value_t*)m; } +static jl_function_t* +invoke_specialize(jl_methlist_t *m, jl_methtable_t *mt, jl_tupletype_t *tt) +{ + jl_svec_t *tpenv = jl_emptysvec; + jl_tupletype_t *newsig = NULL; + JL_GC_PUSH2(&tpenv, &newsig); + + if (m->invokes == (void*)jl_nothing) { + m->invokes = new_method_table(mt->name); + gc_wb(m, m->invokes); + update_max_args(m->invokes, tt); + // this private method table has just this one definition + jl_method_list_insert(&m->invokes->defs, m->sig, m->func, + m->tvars, 0, 0, (jl_value_t*)m->invokes); + } + + newsig = m->sig; + + if (m->tvars != jl_emptysvec) { + jl_value_t *ti = lookup_match((jl_value_t*)tt, (jl_value_t*)m->sig, + &tpenv, m->tvars); + assert(ti != (jl_value_t*)jl_bottom_type); + (void)ti; + // don't bother computing this if no arguments are tuples + for (size_t i = 0;i < jl_nparams(tt);i++) { + if (jl_is_tuple_type(jl_tparam(tt, i))) { + newsig = (jl_tupletype_t*)jl_instantiate_type_with( + (jl_value_t*)m->sig, + jl_svec_data(tpenv), + jl_svec_len(tpenv) / 2); + break; + } + } + } + jl_function_t *mfunc = + cache_method(m->invokes, tt, m->func, newsig, tpenv, m->isstaged); + JL_GC_POP(); + return mfunc; +} + +// compile-time method lookup +jl_function_t* +jl_gf_invoke_get_specialization(jl_function_t *gf, jl_tupletype_t *types, + jl_tupletype_t *tt) +{ + assert(jl_is_gf(gf)); + jl_methtable_t *mt = jl_gf_mtable(gf); + jl_methlist_t *m = (jl_methlist_t*)jl_gf_invoke_lookup(gf, types); + size_t i; + + if ((jl_value_t*)m == jl_nothing) { + return NULL; + } + + // now we have found the matching definition. + // next look for or create a specialization of this definition. + + jl_function_t *mfunc; + if (m->invokes == (void*)jl_nothing) + mfunc = jl_bottom_func; + else + mfunc = jl_method_table_assoc_exact_by_type(m->invokes, tt); + if (mfunc != jl_bottom_func) { + if (mfunc->linfo == NULL || mfunc->linfo->inInference || + mfunc->linfo->inCompile) { + return NULL; + } + } else { + mfunc = invoke_specialize(m, mt, tt); + } + if (mfunc->linfo->functionObject == NULL) { + if (mfunc->fptr != &jl_trampoline) { + return NULL; + } + jl_compile(mfunc); + } + return mfunc; +} + // invoke() // this does method dispatch with a set of types to match other than the // types of the actual arguments. this means it sometimes does NOT call the @@ -1828,38 +1907,9 @@ jl_value_t *jl_gf_invoke(jl_function_t *gf, jl_tupletype_t *types, } } else { - jl_svec_t *tpenv=jl_emptysvec; - jl_tupletype_t *newsig=NULL; - jl_tupletype_t *tt=NULL; - JL_GC_PUSH3(&tpenv, &newsig, &tt); - tt = arg_type_tuple(args, nargs); - if (m->invokes == (void*)jl_nothing) { - m->invokes = new_method_table(mt->name); - gc_wb(m, m->invokes); - update_max_args(m->invokes, tt); - // this private method table has just this one definition - jl_method_list_insert(&m->invokes->defs,m->sig,m->func,m->tvars,0,0,(jl_value_t*)m->invokes); - } - - newsig = m->sig; - - if (m->tvars != jl_emptysvec) { - jl_value_t *ti = - lookup_match((jl_value_t*)tt, (jl_value_t*)m->sig, &tpenv, m->tvars); - assert(ti != (jl_value_t*)jl_bottom_type); - (void)ti; - // don't bother computing this if no arguments are tuples - for(i=0; i < jl_nparams(tt); i++) { - if (jl_is_tuple_type(jl_tparam(tt,i))) - break; - } - if (i < jl_nparams(tt)) { - newsig = (jl_tupletype_t*)jl_instantiate_type_with((jl_value_t*)m->sig, - jl_svec_data(tpenv), - jl_svec_len(tpenv)/2); - } - } - mfunc = cache_method(m->invokes, tt, m->func, newsig, tpenv, m->isstaged); + jl_tupletype_t *tt = arg_type_tuple(args, nargs); + JL_GC_PUSH1(&tt); + mfunc = invoke_specialize(m, mt, tt); JL_GC_POP(); } diff --git a/src/julia_internal.h b/src/julia_internal.h index 4dc90f14b3619..9b1a63001584b 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -127,6 +127,9 @@ DLLEXPORT void jl_read_sonames(void); jl_lambda_info_t *jl_add_static_parameters(jl_lambda_info_t *l, jl_svec_t *sp); jl_function_t *jl_get_specialization(jl_function_t *f, jl_tupletype_t *types); jl_function_t *jl_module_get_initializer(jl_module_t *m); +jl_function_t *jl_gf_invoke_get_specialization(jl_function_t *gf, + jl_tupletype_t *types, + jl_tupletype_t *tt); void jl_generate_fptr(jl_function_t *f); void jl_fptr_to_llvm(void *fptr, jl_lambda_info_t *lam, int specsig);