Skip to content

Commit

Permalink
fix #46778, precompile() for abstract but compileable signatures (#47259
Browse files Browse the repository at this point in the history
)
  • Loading branch information
JeffBezanson authored Nov 16, 2022
1 parent 9b3f5c3 commit fe81138
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 29 deletions.
4 changes: 0 additions & 4 deletions contrib/generate_precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,10 +401,6 @@ function generate_precompile_statements()
end
# println(ps)
ps = Core.eval(PrecompileStagingArea, ps)
# XXX: precompile doesn't currently handle overloaded nospecialize arguments very well.
# Skipping them avoids the warning.
ms = length(ps) == 1 ? Base._methods_by_ftype(ps[1], 1, Base.get_world_counter()) : Base.methods(ps...)
ms isa Vector || continue
precompile(ps...)
n_succeeded += 1
print("\rExecuting precompile statements... $n_succeeded/$(length(statements))")
Expand Down
125 changes: 100 additions & 25 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -2255,6 +2255,39 @@ JL_DLLEXPORT jl_value_t *jl_normalize_to_compilable_sig(jl_methtable_t *mt, jl_t
return is_compileable ? (jl_value_t*)tt : jl_nothing;
}

// return a MethodInstance for a compileable method_match
jl_method_instance_t *jl_method_match_to_mi(jl_method_match_t *match, size_t world, size_t min_valid, size_t max_valid, int mt_cache)
{
jl_method_t *m = match->method;
jl_svec_t *env = match->sparams;
jl_tupletype_t *ti = match->spec_types;
jl_method_instance_t *mi = NULL;
if (jl_is_datatype(ti)) {
jl_methtable_t *mt = jl_method_get_table(m);
if ((jl_value_t*)mt != jl_nothing) {
// get the specialization without caching it
if (mt_cache && ((jl_datatype_t*)ti)->isdispatchtuple) {
// Since we also use this presence in the cache
// to trigger compilation when producing `.ji` files,
// inject it there now if we think it will be
// used via dispatch later (e.g. because it was hinted via a call to `precompile`)
JL_LOCK(&mt->writelock);
mi = cache_method(mt, &mt->cache, (jl_value_t*)mt, ti, m, world, min_valid, max_valid, env);
JL_UNLOCK(&mt->writelock);
}
else {
jl_value_t *tt = jl_normalize_to_compilable_sig(mt, ti, env, m);
JL_GC_PUSH1(&tt);
if (tt != jl_nothing) {
mi = jl_specializations_get_linfo(m, (jl_value_t*)tt, env);
}
JL_GC_POP();
}
}
}
return mi;
}

// compile-time method lookup
jl_method_instance_t *jl_get_specialization1(jl_tupletype_t *types JL_PROPAGATES_ROOT, size_t world, size_t *min_valid, size_t *max_valid, int mt_cache)
{
Expand All @@ -2274,36 +2307,78 @@ jl_method_instance_t *jl_get_specialization1(jl_tupletype_t *types JL_PROPAGATES
*max_valid = max_valid2;
if (matches == jl_false || jl_array_len(matches) != 1 || ambig)
return NULL;
jl_value_t *tt = NULL;
JL_GC_PUSH2(&matches, &tt);
JL_GC_PUSH1(&matches);
jl_method_match_t *match = (jl_method_match_t*)jl_array_ptr_ref(matches, 0);
jl_method_t *m = match->method;
jl_svec_t *env = match->sparams;
jl_tupletype_t *ti = match->spec_types;
jl_method_instance_t *nf = NULL;
if (jl_is_datatype(ti)) {
jl_methtable_t *mt = jl_method_get_table(m);
if ((jl_value_t*)mt != jl_nothing) {
// get the specialization without caching it
if (mt_cache && ((jl_datatype_t*)ti)->isdispatchtuple) {
// Since we also use this presence in the cache
// to trigger compilation when producing `.ji` files,
// inject it there now if we think it will be
// used via dispatch later (e.g. because it was hinted via a call to `precompile`)
JL_LOCK(&mt->writelock);
nf = cache_method(mt, &mt->cache, (jl_value_t*)mt, ti, m, world, min_valid2, max_valid2, env);
JL_UNLOCK(&mt->writelock);
}
else {
tt = jl_normalize_to_compilable_sig(mt, ti, env, m);
if (tt != jl_nothing) {
nf = jl_specializations_get_linfo(m, (jl_value_t*)tt, env);
jl_method_instance_t *mi = jl_method_match_to_mi(match, world, min_valid2, max_valid2, mt_cache);
JL_GC_POP();
return mi;
}

// Get a MethodInstance for a precompile() call. This uses a special kind of lookup that
// tries to find a method for which the requested signature is compileable.
jl_method_instance_t *jl_get_compile_hint_specialization(jl_tupletype_t *types JL_PROPAGATES_ROOT, size_t world, size_t *min_valid, size_t *max_valid, int mt_cache)
{
if (jl_has_free_typevars((jl_value_t*)types))
return NULL; // don't poison the cache due to a malformed query
if (!jl_has_concrete_subtype((jl_value_t*)types))
return NULL;

size_t min_valid2 = 1;
size_t max_valid2 = ~(size_t)0;
int ambig = 0;
jl_value_t *matches = jl_matching_methods(types, jl_nothing, -1, 0, world, &min_valid2, &max_valid2, &ambig);
if (*min_valid < min_valid2)
*min_valid = min_valid2;
if (*max_valid > max_valid2)
*max_valid = max_valid2;
size_t i, n = jl_array_len(matches);
if (n == 0)
return NULL;
JL_GC_PUSH1(&matches);
jl_method_match_t *match = NULL;
if (n == 1) {
match = (jl_method_match_t*)jl_array_ptr_ref(matches, 0);
}
else {
// first, select methods for which `types` is compileable
size_t count = 0;
for (i = 0; i < n; i++) {
jl_method_match_t *match1 = (jl_method_match_t*)jl_array_ptr_ref(matches, i);
if (jl_isa_compileable_sig(types, match1->method))
jl_array_ptr_set(matches, count++, (jl_value_t*)match1);
}
jl_array_del_end((jl_array_t*)matches, n - count);
n = count;
// now remove methods that are more specific than others in the list.
// this is because the intent of precompiling e.g. f(::DataType) is to
// compile that exact method if it exists, and not lots of f(::Type{X}) methods
int exclude;
count = 0;
for (i = 0; i < n; i++) {
jl_method_match_t *match1 = (jl_method_match_t*)jl_array_ptr_ref(matches, i);
exclude = 0;
for (size_t j = n-1; j > i; j--) { // more general methods maybe more likely to be at end
jl_method_match_t *match2 = (jl_method_match_t*)jl_array_ptr_ref(matches, j);
if (jl_type_morespecific(match1->method->sig, match2->method->sig)) {
exclude = 1;
break;
}
}
if (!exclude)
jl_array_ptr_set(matches, count++, (jl_value_t*)match1);
if (count > 1)
break;
}
// at this point if there are 0 matches left we found nothing, or if there are
// more than one the request is ambiguous and we ignore it.
if (count == 1)
match = (jl_method_match_t*)jl_array_ptr_ref(matches, 0);
}
jl_method_instance_t *mi = NULL;
if (match != NULL)
mi = jl_method_match_to_mi(match, world, min_valid2, max_valid2, mt_cache);
JL_GC_POP();
return nf;
return mi;
}

static void _generate_from_hint(jl_method_instance_t *mi, size_t world)
Expand Down Expand Up @@ -2370,7 +2445,7 @@ JL_DLLEXPORT int jl_compile_hint(jl_tupletype_t *types)
size_t world = jl_atomic_load_acquire(&jl_world_counter);
size_t min_valid = 0;
size_t max_valid = ~(size_t)0;
jl_method_instance_t *mi = jl_get_specialization1(types, world, &min_valid, &max_valid, 1);
jl_method_instance_t *mi = jl_get_compile_hint_specialization(types, world, &min_valid, &max_valid, 1);
if (mi == NULL)
return 0;
JL_GC_PROMISE_ROOTED(mi);
Expand Down
7 changes: 7 additions & 0 deletions test/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1556,3 +1556,10 @@ end

empty!(Base.DEPOT_PATH)
append!(Base.DEPOT_PATH, original_depot_path)

@testset "issue 46778" begin
f46778(::Any, ::Type{Int}) = 1
f46778(::Any, ::DataType) = 2
@test precompile(Tuple{typeof(f46778), Int, DataType})
@test which(f46778, Tuple{Any,DataType}).specializations[1].cache.invoke != C_NULL
end

0 comments on commit fe81138

Please sign in to comment.