Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate iteration info to optimizer #36684

Merged
merged 1 commit into from
Jul 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 49 additions & 29 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
push!(fullmatch, thisfullmatch)
end
end
info = UnionSplitInfo(splitsigs, infos)
info = UnionSplitInfo(infos)
else
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
if mt === nothing
Expand Down Expand Up @@ -505,13 +505,13 @@ end
# returns an array of types
function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft), @nospecialize(typ), vtypes::VarTable, sv::InferenceState)
if isa(typ, PartialStruct) && typ.typ.name === Tuple.name
return typ.fields
return typ.fields, nothing
end

if isa(typ, Const)
val = typ.val
if isa(val, SimpleVector) || isa(val, Tuple)
return Any[ Const(val[i]) for i in 1:length(val) ] # avoid making a tuple Generator here!
return Any[ Const(val[i]) for i in 1:length(val) ], nothing # avoid making a tuple Generator here!
end
end

Expand All @@ -529,27 +529,27 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
if isa(tti, Union)
utis = uniontypes(tti)
if _any(t -> !isa(t, DataType) || !(t <: Tuple) || !isknownlength(t), utis)
return Any[Vararg{Any}]
return Any[Vararg{Any}], nothing
end
result = Any[rewrap_unionall(p, tti0) for p in utis[1].parameters]
for t in utis[2:end]
if length(t.parameters) != length(result)
return Any[Vararg{Any}]
return Any[Vararg{Any}], nothing
end
for j in 1:length(t.parameters)
result[j] = tmerge(result[j], rewrap_unionall(t.parameters[j], tti0))
end
end
return result
return result, nothing
elseif tti0 <: Tuple
if isa(tti0, DataType)
if isvatuple(tti0) && length(tti0.parameters) == 1
return Any[Vararg{unwrapva(tti0.parameters[1])}]
return Any[Vararg{unwrapva(tti0.parameters[1])}], nothing
else
return Any[ p for p in tti0.parameters ]
return Any[ p for p in tti0.parameters ], nothing
end
elseif !isa(tti, DataType)
return Any[Vararg{Any}]
return Any[Vararg{Any}], nothing
else
len = length(tti.parameters)
last = tti.parameters[len]
Expand All @@ -558,12 +558,12 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
if va
elts[len] = Vararg{elts[len]}
end
return elts
return elts, nothing
end
elseif tti0 === SimpleVector || tti0 === Any
return Any[Vararg{Any}]
return Any[Vararg{Any}], nothing
elseif tti0 <: Array
return Any[Vararg{eltype(tti0)}]
return Any[Vararg{eltype(tti0)}], nothing
else
return abstract_iteration(interp, itft, typ, vtypes, sv)
end
Expand All @@ -572,30 +572,34 @@ end
# simulate iteration protocol on container type up to fixpoint
function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @nospecialize(itertype), vtypes::VarTable, sv::InferenceState)
if !isdefined(Main, :Base) || !isdefined(Main.Base, :iterate) || !isconst(Main.Base, :iterate)
return Any[Vararg{Any}]
return Any[Vararg{Any}], nothing
end
if itft === nothing
iteratef = getfield(Main.Base, :iterate)
itft = Const(iteratef)
elseif isa(itft, Const)
iteratef = itft.val
else
return Any[Vararg{Any}]
return Any[Vararg{Any}], nothing
end
@assert !isvarargtype(itertype)
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[itft, itertype], vtypes, sv).rt
call = abstract_call_known(interp, iteratef, nothing, Any[itft, itertype], vtypes, sv)
stateordonet = call.rt
info = call.info
# Return Bottom if this is not an iterator.
# WARNING: Changes to the iteration protocol must be reflected here,
# this is not just an optimization.
stateordonet === Bottom && return Any[Bottom]
stateordonet === Bottom && return Any[Bottom], AbstractIterationInfo(CallMeta[CallMeta(Bottom, info)])
valtype = statetype = Bottom
ret = Any[]
calls = CallMeta[call]

# Try to unroll the iteration up to MAX_TUPLE_SPLAT, which covers any finite
# length iterators, or interesting prefix
while true
stateordonet_widened = widenconst(stateordonet)
if stateordonet_widened === Nothing
return ret
return ret, AbstractIterationInfo(calls)
end
if Nothing <: stateordonet_widened || length(ret) >= InferenceParams(interp).MAX_TUPLE_SPLAT
break
Expand All @@ -607,12 +611,14 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
# If there's no new information in this statetype, don't bother continuing,
# the iterator won't be finite.
if nstatetype statetype
return Any[Bottom]
return Any[Bottom], nothing
end
valtype = getfield_tfunc(stateordonet, Const(1))
push!(ret, valtype)
statetype = nstatetype
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv).rt
call = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], vtypes, sv)
stateordonet = call.rt
push!(calls, call)
end
# From here on, we start asking for results on the widened types, rather than
# the precise (potentially const) state type
Expand All @@ -629,15 +635,15 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
if nounion.parameters[1] <: valtype && nounion.parameters[2] <: statetype
if typeintersect(stateordonet, Nothing) === Union{}
# Reached a fixpoint, but Nothing is not possible => iterator is infinite or failing
return Any[Bottom]
return Any[Bottom], nothing
end
break
end
valtype = tmerge(valtype, nounion.parameters[1])
statetype = tmerge(statetype, nounion.parameters[2])
end
push!(ret, Vararg{valtype})
return ret
return ret, nothing
end

# do apply(af, fargs...), where af is a function value
Expand All @@ -656,13 +662,15 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
nargs = length(aargtypes)
splitunions = 1 < countunionsplit(aargtypes) <= InferenceParams(interp).MAX_APPLY_UNION_ENUM
ctypes = Any[Any[aft]]
infos = [Union{Nothing, AbstractIterationInfo}[]]
for i = 1:nargs
ctypes´ = []
infos′ = []
for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]])
if !isvarargtype(ti)
cti = precise_container_type(interp, itft, ti, vtypes, sv)
cti, info = precise_container_type(interp, itft, ti, vtypes, sv)
else
cti = precise_container_type(interp, itft, unwrapva(ti), vtypes, sv)
cti, info = precise_container_type(interp, itft, unwrapva(ti), vtypes, sv)
# We can't represent a repeating sequence of the same types,
# so tmerge everything together to get one type that represents
# everything.
Expand All @@ -678,19 +686,29 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
if _any(t -> t === Bottom, cti)
continue
end
for ct in ctypes
for j = 1:length(ctypes)
ct = ctypes[j]
if isvarargtype(ct[end])
# This is vararg, we're not gonna be able to do any inling,
# drop the info
info = nothing

tail = tuple_tail_elem(unwrapva(ct[end]), cti)
push!(ctypes´, push!(ct[1:(end - 1)], tail))
else
push!(ctypes´, append!(ct[:], cti))
end
push!(infos′, push!(copy(infos[j]), info))
end
end
ctypes = ctypes´
infos = infos′
end
local info = nothing
for ct in ctypes
retinfos = ApplyCallInfo[]
retinfo = UnionSplitApplyCallInfo(retinfos)
for i = 1:length(ctypes)
ct = ctypes[i]
arginfo = infos[i]
lct = length(ct)
# truncate argument list at the first Vararg
for i = 1:lct-1
Expand All @@ -701,15 +719,17 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
end
end
call = abstract_call(interp, nothing, ct, vtypes, sv, max_methods)
info = call.info
push!(retinfos, ApplyCallInfo(call.info, arginfo))
res = tmerge(res, call.rt)
if res === Any
# No point carrying forward the info, we're not gonna inline it anyway
retinfo = nothing
break
end
end
# TODO: Add a special info type to capture all the iteration info.
# For now, only propagate info if we don't also union-split the iteration
return CallMeta(res, length(ctypes) == 1 ? info : false)
return CallMeta(res, retinfo)
end

function is_method_pure(method::Method, @nospecialize(sig), sparams::SimpleVector)
Expand Down Expand Up @@ -779,7 +799,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
end
rt = builtin_tfunction(interp, f, argtypes[2:end], sv)
if f === getfield && isa(fargs, Vector{Any}) && la == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] Tuple
cti = precise_container_type(interp, nothing, argtypes[2], vtypes, sv)
cti, _ = precise_container_type(interp, nothing, argtypes[2], vtypes, sv)
idx = argtypes[3].val
if 1 <= idx <= length(cti)
rt = unwrapva(cti[idx])
Expand Down
Loading