diff --git a/base/compiler/types.jl b/base/compiler/types.jl index 642a7ac551662..53a363a2280e4 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -373,20 +373,17 @@ end function NativeInterpreter(world::UInt = get_world_counter(); inf_params::InferenceParams = InferenceParams(), opt_params::OptimizationParams = OptimizationParams()) + curr_max_world = get_world_counter() # Sometimes the caller is lazy and passes typemax(UInt). # we cap it to the current world age for correctness if world == typemax(UInt) - world = get_world_counter() + world = curr_max_world end - # If they didn't pass typemax(UInt) but passed something more subtly # incorrect, fail out loudly. - @assert world <= get_world_counter() - + @assert world <= curr_max_world method_table = CachedMethodTable(InternalMethodTable(world)) - inf_cache = Vector{InferenceResult}() # Initially empty cache - return NativeInterpreter(world, method_table, inf_cache, inf_params, opt_params) end diff --git a/src/gf.c b/src/gf.c index 20179157ff836..fdf61f440908b 100644 --- a/src/gf.c +++ b/src/gf.c @@ -444,8 +444,11 @@ STATIC_INLINE jl_value_t *_jl_rettype_inferred(jl_value_t *owner, jl_method_inst if (jl_atomic_load_relaxed(&codeinst->min_world) <= min_world && max_world <= jl_atomic_load_relaxed(&codeinst->max_world) && jl_egal(codeinst->owner, owner)) { - jl_value_t *code = jl_atomic_load_relaxed(&codeinst->inferred); - if (code && (code == jl_nothing || jl_ir_flag_inferred(code))) + jl_value_t *inferred = jl_atomic_load_relaxed(&codeinst->inferred); + if (inferred && ((inferred == jl_nothing) || ( + // allow whatever code instance external abstract interpreter produced + // since `jl_ir_flag_inferred` is specific to the native interpreter + codeinst->owner != jl_nothing || jl_ir_flag_inferred(inferred)))) return (jl_value_t*)codeinst; } codeinst = jl_atomic_load_relaxed(&codeinst->next); diff --git a/test/compiler/AbstractInterpreter.jl b/test/compiler/AbstractInterpreter.jl index c0b320009b8ec..2068997c77c82 100644 --- a/test/compiler/AbstractInterpreter.jl +++ b/test/compiler/AbstractInterpreter.jl @@ -399,7 +399,6 @@ end Core.eval(Core.Compiler, quote f(;a=1) = a end) @test_throws MethodError Core.Compiler.f(;b=2) - # Custom lookup function # ====================== @@ -469,3 +468,35 @@ let # generate cache @test occursin("j_sin_", s) @test !occursin("j_cos_", s) end + +# custom inferred data +# ==================== + +@newinterp CustomDataInterp +struct CustomDataInterpToken end +CC.cache_owner(::CustomDataInterp) = CustomDataInterpToken() +struct CustomData + inferred + CustomData(@nospecialize inferred) = new(inferred) +end +function CC.transform_result_for_cache(interp::CustomDataInterp, + mi::Core.MethodInstance, valid_worlds::CC.WorldRange, result::CC.InferenceResult) + inferred_result = @invoke CC.transform_result_for_cache(interp::CC.AbstractInterpreter, + mi::Core.MethodInstance, valid_worlds::CC.WorldRange, result::CC.InferenceResult) + return CustomData(inferred_result) +end +function CC.inlining_policy(interp::CustomDataInterp, @nospecialize(src), + @nospecialize(info::CC.CallInfo), stmt_flag::UInt32) + if src isa CustomData + src = src.inferred + end + return @invoke CC.inlining_policy(interp::CC.AbstractInterpreter, src::Any, + info::CC.CallInfo, stmt_flag::UInt32) +end +let src = code_typed((Int,); interp=CustomDataInterp()) do x + return sin(x) + cos(x) + end |> only |> first + @test count(isinvoke(:sin), src.code) == 1 + @test count(isinvoke(:cos), src.code) == 1 + @test count(isinvoke(:+), src.code) == 0 +end