Skip to content

Commit

Permalink
inference: concretize invoke callsite correctly (#46743)
Browse files Browse the repository at this point in the history
It turns out that previously we didn't concretize `invoke` callsite
correctly, as we didn't take into account whether the call being
concretized is `invoke`-d or not, e.g.:
```
julia> invoke_concretized2(a::Int) = a > 0 ? :int : nothing
invoke_concretized2 (generic function with 1 method)

julia> invoke_concretized2(a::Integer) = a > 0 ? :integer : nothing
invoke_concretized2 (generic function with 2 methods)

julia> let
           Base.Experimental.@force_compile
           Base.@invoke invoke_concretized2(42::Integer)
       end
:int # this should return `:integer` instead
```

This commit fixes that up by propagating information `invoke`-d callsite
to `concrete_eval_call`. Now we should be able to pass the following test cases:
```julia
invoke_concretized1(a::Int) = a > 0 ? :int : nothing
invoke_concretized1(a::Integer) = a > 0 ? "integer" : nothing
@test Base.infer_effects((Int,)) do a
    @invoke invoke_concretized1(a::Integer)
end |> Core.Compiler.is_foldable
@test Base.return_types() do
    @invoke invoke_concretized1(42::Integer)
end |> only === String

invoke_concretized2(a::Int) = a > 0 ? :int : nothing
invoke_concretized2(a::Integer) = a > 0 ? :integer : nothing
@test Base.infer_effects((Int,)) do a
    @invoke invoke_concretized2(a::Integer)
end |> Core.Compiler.is_foldable
@test let
    Base.Experimental.@force_compile
    @invoke invoke_concretized2(42::Integer)
end === :integer
```
  • Loading branch information
aviatesk authored Sep 14, 2022
1 parent 7cbea73 commit 8455c8f
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 14 deletions.
37 changes: 23 additions & 14 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -819,17 +819,24 @@ function collect_const_args(argtypes::Vector{Any}, start::Int)
end for i = start:length(argtypes) ]
end

function invoke_signature(invokesig::Vector{Any})
ft, argtyps = widenconst(invokesig[2]), instanceof_tfunc(widenconst(invokesig[3]))[1]
return rewrap_unionall(Tuple{ft, unwrap_unionall(argtyps).parameters...}, argtyps)
struct InvokeCall
types # ::Type
lookupsig # ::Type
InvokeCall(@nospecialize(types), @nospecialize(lookupsig)) = new(types, lookupsig)
end

function concrete_eval_call(interp::AbstractInterpreter,
@nospecialize(f), result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState)
@nospecialize(f), result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState,
invokecall::Union{Nothing,InvokeCall}=nothing)
eligible = concrete_eval_eligible(interp, f, result, arginfo, sv)
eligible === nothing && return false
if eligible
args = collect_const_args(arginfo, #=start=#2)
if invokecall !== nothing
# this call should be `invoke`d, rewrite `args` back now
pushfirst!(args, f, invokecall.types)
f = invoke
end
world = get_world_counter(interp)
value = try
Core._call_in_world_total(world, f, args...)
Expand Down Expand Up @@ -874,15 +881,15 @@ struct ConstCallResults
new(rt, const_result, effects)
end

function abstract_call_method_with_const_args(interp::AbstractInterpreter, result::MethodCallResult,
@nospecialize(f), arginfo::ArgInfo, match::MethodMatch,
sv::InferenceState, @nospecialize(invoketypes=nothing))
function abstract_call_method_with_const_args(interp::AbstractInterpreter,
result::MethodCallResult, @nospecialize(f), arginfo::ArgInfo, match::MethodMatch,
sv::InferenceState, invokecall::Union{Nothing,InvokeCall}=nothing)
if !const_prop_enabled(interp, sv, match)
return nothing
end
res = concrete_eval_call(interp, f, result, arginfo, sv)
res = concrete_eval_call(interp, f, result, arginfo, sv, invokecall)
if isa(res, ConstCallResults)
add_backedge!(res.const_result.mi, sv, invoketypes)
add_backedge!(res.const_result.mi, sv, invokecall === nothing ? nothing : invokecall.lookupsig)
return res
end
mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv)
Expand Down Expand Up @@ -1674,18 +1681,18 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
nargtype isa DataType || return CallMeta(Any, Effects(), false) # other cases are not implemented below
isdispatchelem(ft) || return CallMeta(Any, Effects(), false) # check that we might not have a subtype of `ft` at runtime, before doing supertype lookup below
ft = ft::DataType
types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)::Type
lookupsig = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)::Type
nargtype = Tuple{ft, nargtype.parameters...}
argtype = Tuple{ft, argtype.parameters...}
match, valid_worlds, overlayed = findsup(types, method_table(interp))
match, valid_worlds, overlayed = findsup(lookupsig, method_table(interp))
match === nothing && return CallMeta(Any, Effects(), false)
update_valid_age!(sv, valid_worlds)
method = match.method
tienv = ccall(:jl_type_intersection_with_env, Any, (Any, Any), nargtype, method.sig)::SimpleVector
ti = tienv[1]; env = tienv[2]::SimpleVector
result = abstract_call_method(interp, method, ti, env, false, sv)
(; rt, edge, effects) = result
edge !== nothing && add_backedge!(edge::MethodInstance, sv, types)
edge !== nothing && add_backedge!(edge::MethodInstance, sv, lookupsig)
match = MethodMatch(ti, env, method, argtype <: method.sig)
res = nothing
sig = match.spec_types
Expand All @@ -1697,8 +1704,10 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
# t, a = ti.parameters[i], argtypes′[i]
# argtypes′[i] = t ⊑ a ? t : a
# end
const_call_result = abstract_call_method_with_const_args(interp, result,
overlayed ? nothing : singleton_type(ft′), arginfo, match, sv, types)
f = overlayed ? nothing : singleton_type(ft′)
invokecall = InvokeCall(types, lookupsig)
const_call_result = abstract_call_method_with_const_args(interp,
result, f, arginfo, match, sv, invokecall)
const_result = nothing
if const_call_result !== nothing
if (typeinf_lattice(interp), const_call_result.rt, rt)
Expand Down
5 changes: 5 additions & 0 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,11 @@ function inline_invoke!(
return nothing
end

function invoke_signature(invokesig::Vector{Any})
ft, argtyps = widenconst(invokesig[2]), instanceof_tfunc(widenconst(invokesig[3]))[1]
return rewrap_unionall(Tuple{ft, unwrap_unionall(argtyps).parameters...}, argtyps)
end

function narrow_opaque_closure!(ir::IRCode, stmt::Expr, @nospecialize(info), state::InliningState)
if isa(info, OpaqueClosureCreateInfo)
lbt = argextype(stmt.args[2], ir)
Expand Down
21 changes: 21 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4211,3 +4211,24 @@ function isa_kindtype(T::Type{<:AbstractVector})
return nothing
end
@test only(Base.return_types(isa_kindtype)) === Union{Nothing,Symbol}

invoke_concretized1(a::Int) = a > 0 ? :int : nothing
invoke_concretized1(a::Integer) = a > 0 ? "integer" : nothing
# check if `invoke(invoke_concretized1, Tuple{Integer}, ::Int)` is foldable
@test Base.infer_effects((Int,)) do a
@invoke invoke_concretized1(a::Integer)
end |> Core.Compiler.is_foldable
@test Base.return_types() do
@invoke invoke_concretized1(42::Integer)
end |> only === String

invoke_concretized2(a::Int) = a > 0 ? :int : nothing
invoke_concretized2(a::Integer) = a > 0 ? :integer : nothing
# check if `invoke(invoke_concretized2, Tuple{Integer}, ::Int)` is foldable
@test Base.infer_effects((Int,)) do a
@invoke invoke_concretized2(a::Integer)
end |> Core.Compiler.is_foldable
@test let
Base.Experimental.@force_compile
@invoke invoke_concretized2(42::Integer)
end === :integer

0 comments on commit 8455c8f

Please sign in to comment.