Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix #46778, precompile() for abstract but compileable signatures #47259

Merged
merged 1 commit into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions contrib/generate_precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,10 +401,6 @@ function generate_precompile_statements()
end
# println(ps)
ps = Core.eval(PrecompileStagingArea, ps)
# XXX: precompile doesn't currently handle overloaded nospecialize arguments very well.
# Skipping them avoids the warning.
ms = length(ps) == 1 ? Base._methods_by_ftype(ps[1], 1, Base.get_world_counter()) : Base.methods(ps...)
ms isa Vector || continue
precompile(ps...)
n_succeeded += 1
print("\rExecuting precompile statements... $n_succeeded/$(length(statements))")
Expand Down
125 changes: 100 additions & 25 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -2255,6 +2255,39 @@ JL_DLLEXPORT jl_value_t *jl_normalize_to_compilable_sig(jl_methtable_t *mt, jl_t
return is_compileable ? (jl_value_t*)tt : jl_nothing;
}

// return a MethodInstance for a compileable method_match
jl_method_instance_t *jl_method_match_to_mi(jl_method_match_t *match, size_t world, size_t min_valid, size_t max_valid, int mt_cache)
{
jl_method_t *m = match->method;
jl_svec_t *env = match->sparams;
jl_tupletype_t *ti = match->spec_types;
jl_method_instance_t *mi = NULL;
if (jl_is_datatype(ti)) {
jl_methtable_t *mt = jl_method_get_table(m);
if ((jl_value_t*)mt != jl_nothing) {
// get the specialization without caching it
if (mt_cache && ((jl_datatype_t*)ti)->isdispatchtuple) {
// Since we also use this presence in the cache
// to trigger compilation when producing `.ji` files,
// inject it there now if we think it will be
// used via dispatch later (e.g. because it was hinted via a call to `precompile`)
JL_LOCK(&mt->writelock);
mi = cache_method(mt, &mt->cache, (jl_value_t*)mt, ti, m, world, min_valid, max_valid, env);
JL_UNLOCK(&mt->writelock);
}
else {
jl_value_t *tt = jl_normalize_to_compilable_sig(mt, ti, env, m);
JL_GC_PUSH1(&tt);
if (tt != jl_nothing) {
mi = jl_specializations_get_linfo(m, (jl_value_t*)tt, env);
}
JL_GC_POP();
}
}
}
return mi;
}

// compile-time method lookup
jl_method_instance_t *jl_get_specialization1(jl_tupletype_t *types JL_PROPAGATES_ROOT, size_t world, size_t *min_valid, size_t *max_valid, int mt_cache)
{
Expand All @@ -2274,36 +2307,78 @@ jl_method_instance_t *jl_get_specialization1(jl_tupletype_t *types JL_PROPAGATES
*max_valid = max_valid2;
if (matches == jl_false || jl_array_len(matches) != 1 || ambig)
return NULL;
jl_value_t *tt = NULL;
JL_GC_PUSH2(&matches, &tt);
JL_GC_PUSH1(&matches);
jl_method_match_t *match = (jl_method_match_t*)jl_array_ptr_ref(matches, 0);
jl_method_t *m = match->method;
jl_svec_t *env = match->sparams;
jl_tupletype_t *ti = match->spec_types;
jl_method_instance_t *nf = NULL;
if (jl_is_datatype(ti)) {
jl_methtable_t *mt = jl_method_get_table(m);
if ((jl_value_t*)mt != jl_nothing) {
// get the specialization without caching it
if (mt_cache && ((jl_datatype_t*)ti)->isdispatchtuple) {
// Since we also use this presence in the cache
// to trigger compilation when producing `.ji` files,
// inject it there now if we think it will be
// used via dispatch later (e.g. because it was hinted via a call to `precompile`)
JL_LOCK(&mt->writelock);
nf = cache_method(mt, &mt->cache, (jl_value_t*)mt, ti, m, world, min_valid2, max_valid2, env);
JL_UNLOCK(&mt->writelock);
}
else {
tt = jl_normalize_to_compilable_sig(mt, ti, env, m);
if (tt != jl_nothing) {
nf = jl_specializations_get_linfo(m, (jl_value_t*)tt, env);
jl_method_instance_t *mi = jl_method_match_to_mi(match, world, min_valid2, max_valid2, mt_cache);
JL_GC_POP();
return mi;
}

// Get a MethodInstance for a precompile() call. This uses a special kind of lookup that
// tries to find a method for which the requested signature is compileable.
jl_method_instance_t *jl_get_compile_hint_specialization(jl_tupletype_t *types JL_PROPAGATES_ROOT, size_t world, size_t *min_valid, size_t *max_valid, int mt_cache)
{
if (jl_has_free_typevars((jl_value_t*)types))
return NULL; // don't poison the cache due to a malformed query
if (!jl_has_concrete_subtype((jl_value_t*)types))
return NULL;

size_t min_valid2 = 1;
size_t max_valid2 = ~(size_t)0;
int ambig = 0;
jl_value_t *matches = jl_matching_methods(types, jl_nothing, -1, 0, world, &min_valid2, &max_valid2, &ambig);
if (*min_valid < min_valid2)
*min_valid = min_valid2;
if (*max_valid > max_valid2)
*max_valid = max_valid2;
size_t i, n = jl_array_len(matches);
if (n == 0)
return NULL;
JL_GC_PUSH1(&matches);
jl_method_match_t *match = NULL;
if (n == 1) {
match = (jl_method_match_t*)jl_array_ptr_ref(matches, 0);
}
else {
// first, select methods for which `types` is compileable
size_t count = 0;
for (i = 0; i < n; i++) {
jl_method_match_t *match1 = (jl_method_match_t*)jl_array_ptr_ref(matches, i);
if (jl_isa_compileable_sig(types, match1->method))
jl_array_ptr_set(matches, count++, (jl_value_t*)match1);
}
jl_array_del_end((jl_array_t*)matches, n - count);
n = count;
// now remove methods that are more specific than others in the list.
// this is because the intent of precompiling e.g. f(::DataType) is to
// compile that exact method if it exists, and not lots of f(::Type{X}) methods
int exclude;
count = 0;
for (i = 0; i < n; i++) {
jl_method_match_t *match1 = (jl_method_match_t*)jl_array_ptr_ref(matches, i);
exclude = 0;
for (size_t j = n-1; j > i; j--) { // more general methods maybe more likely to be at end
jl_method_match_t *match2 = (jl_method_match_t*)jl_array_ptr_ref(matches, j);
if (jl_type_morespecific(match1->method->sig, match2->method->sig)) {
exclude = 1;
break;
}
}
if (!exclude)
jl_array_ptr_set(matches, count++, (jl_value_t*)match1);
if (count > 1)
break;
}
// at this point if there are 0 matches left we found nothing, or if there are
// more than one the request is ambiguous and we ignore it.
if (count == 1)
match = (jl_method_match_t*)jl_array_ptr_ref(matches, 0);
}
jl_method_instance_t *mi = NULL;
if (match != NULL)
mi = jl_method_match_to_mi(match, world, min_valid2, max_valid2, mt_cache);
JL_GC_POP();
return nf;
return mi;
}

static void _generate_from_hint(jl_method_instance_t *mi, size_t world)
Expand Down Expand Up @@ -2370,7 +2445,7 @@ JL_DLLEXPORT int jl_compile_hint(jl_tupletype_t *types)
size_t world = jl_atomic_load_acquire(&jl_world_counter);
size_t min_valid = 0;
size_t max_valid = ~(size_t)0;
jl_method_instance_t *mi = jl_get_specialization1(types, world, &min_valid, &max_valid, 1);
jl_method_instance_t *mi = jl_get_compile_hint_specialization(types, world, &min_valid, &max_valid, 1);
if (mi == NULL)
return 0;
JL_GC_PROMISE_ROOTED(mi);
Expand Down
7 changes: 7 additions & 0 deletions test/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1556,3 +1556,10 @@ end

empty!(Base.DEPOT_PATH)
append!(Base.DEPOT_PATH, original_depot_path)

@testset "issue 46778" begin
f46778(::Any, ::Type{Int}) = 1
f46778(::Any, ::DataType) = 2
@test precompile(Tuple{typeof(f46778), Int, DataType})
@test which(f46778, Tuple{Any,DataType}).specializations[1].cache.invoke != C_NULL
end