Skip to content

Commit

Permalink
RFC: Add inference->optimize analysis forwarding mechanism
Browse files Browse the repository at this point in the history
This change attempts to be a solution to the generalized problem
encountered in #36169. In short, we do a whole bunch of analysis
during inference to figure out the final type of an expression,
but sometimes, we may need intermediate results that were
computed along the way. So far, we don't really have a great
place to put those results, so we end up having to re-compute
them during the optimization phase. That's what #36169 did,
but is clearly not a scalable solution.

I encountered the exact same issue while working on a new AD
compiler plugin, that needs to do a whole bunch of work during
inference to determine what to do (e.g. call a primitive, recurse,
or increase the derivative level), and optimizations need to have
access to this information.

This PR adds an additional `info` field to CodeInfo and IRCode
that can be used to forward this kind of information. As a proof
of concept, it forwards method match info from inference to
inlining (we do already cache these, so there's little performance
gain from this per se - it's more to exercise the infrastructure).

The plan is to do an alternative fix to #36169 on top of this
as the next step, but I figured I'd open it up for discussion first.
  • Loading branch information
Keno committed Jul 2, 2020
1 parent f192ead commit 06f5158
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 20 deletions.
2 changes: 2 additions & 0 deletions base/compiler/ssair/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
insert!(code, idx, Expr(:code_coverage_effect))
insert!(ci.codelocs, idx, codeloc)
insert!(ci.ssavaluetypes, idx, Nothing)
insert!(ci.stmtinfo, idx, nothing)
changemap[oldidx] += 1
if oldidx < length(labelmap)
labelmap[oldidx + 1] += 1
Expand All @@ -61,6 +62,7 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
insert!(code, idx + 1, ReturnNode())
insert!(ci.codelocs, idx + 1, ci.codelocs[idx])
insert!(ci.ssavaluetypes, idx + 1, Union{})
insert!(ci.stmtinfo, idx, nothing)
if oldidx < length(changemap)
changemap[oldidx + 1] += 1
coverage && (labelmap[oldidx + 1] += 1)
Expand Down
45 changes: 34 additions & 11 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,22 @@ function process_simple!(ir::IRCode, idx::Int, params::OptimizationParams, world
return (sig, invoke_data)
end

# This is not currently called in the regular course, but may be needed
# if we ever want to re-run inlining again later in the pass pipeline after
# additional type information was discovered.
function recompute_method_matches(atype, sv)
# Regular case: Retrieve matching methods from cache (or compute them)
# World age does not need to be taken into account in the cache
# because it is forwarded from type inference through `sv.params`
# in the case that the cache is nonempty, so it should be unchanged
# The max number of methods should be the same as in inference most
# of the time, and should not affect correctness otherwise.
(meth, min_valid, max_valid) =
matching_methods(atype, sv.matching_methods_cache, sv.params.MAX_METHODS, sv.world)
update_valid_age!(min_valid, max_valid, sv)
MethodMatchInfo(meth)
end

function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)
# todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie)
todo = Any[]
Expand All @@ -1015,6 +1031,7 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)

stmt = ir.stmts[idx][:inst]
calltype = ir.stmts[idx][:type]
info = ir.stmts[idx][:info]
(sig, invoke_data) = r

# Ok, now figure out what method to call
Expand All @@ -1025,12 +1042,22 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)

nu = countunionsplit(sig.atypes)
if nu == 1 || nu > sv.params.MAX_UNION_SPLITTING
if !isa(info, MethodMatchInfo)
info = nothing
end
infos = [info]
splits = Any[sig.atype]
else
splits = Any[]
for union_sig in UnionSplitSignature(sig.atypes)
push!(splits, argtypes_to_type(union_sig))
end
if !isa(info, UnionSplitInfo)
infos = fill(nothing, length(splits))
else
@assert length(info.matches) == length(splits)
infos = info.matches
end
end

cases = Pair{Any, Any}[]
Expand All @@ -1039,15 +1066,13 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)
too_many = false
local meth
local fully_covered = true
for atype in splits
# Regular case: Retrieve matching methods from cache (or compute them)
# World age does not need to be taken into account in the cache
# because it is forwarded from type inference through `sv.params`
# in the case that the cache is nonempty, so it should be unchanged
# The max number of methods should be the same as in inference most
# of the time, and should not affect correctness otherwise.
(meth, min_valid, max_valid) =
matching_methods(atype, sv.matching_methods_cache, sv.params.MAX_METHODS, sv.world)
for i in 1:length(splits)
atype = splits[i]
info = infos[i]
if info === nothing
info = recompute_method_matches(atype, sv)
end
meth = info.applicable
if meth === false
# Too many applicable methods
too_many = true
Expand All @@ -1064,8 +1089,6 @@ function assemble_inline_todo!(ir::IRCode, sv::OptimizationState)
else
only_method = false
end
update_valid_age!(min_valid, max_valid, sv)

for match in meth::Vector{Any}
(metharg, methsp, method) = (match[1]::Type, match[2]::SimpleVector, match[3]::Method)
# TODO: This could be better
Expand Down
8 changes: 2 additions & 6 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,7 @@ length(is::InstructionStream) = length(is.inst)
isempty(is::InstructionStream) = isempty(is.inst)
function add!(is::InstructionStream)
ninst = length(is) + 1
resize!(is.inst, ninst)
resize!(is.type, ninst)
resize!(is.info, ninst)
resize!(is.line, ninst)
resize!(is.flag, ninst)
resize!(is, ninst)
return ninst
end
#function copy(is::InstructionStream) # unused
Expand Down Expand Up @@ -226,7 +222,7 @@ end
function setindex!(is::InstructionStream, newval::Instruction, idx::Int)
is.inst[idx] = newval[:inst]
is.type[idx] = newval[:type]
#is.info[idx] = newval[:info]
is.info[idx] = newval[:info]
is.line[idx] = newval[:line]
is.flag[idx] = newval[:flag]
return is
Expand Down
5 changes: 3 additions & 2 deletions base/compiler/ssair/legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@ function replace_code_newstyle!(ci::CodeInfo, ir::IRCode, nargs::Int)
# All but the first `nargs` slots will now be unused
resize!(ci.slotflags, nargs + 1)
stmts = ir.stmts
ci.code, ci.ssavaluetypes, ci.codelocs, ci.ssaflags, ci.linetable =
stmts.inst, stmts.type, stmts.line, stmts.flag, ir.linetable
ci.code, ci.ssavaluetypes, ci.stmtinfo, ci.codelocs, ci.ssaflags, ci.linetable =
stmts.inst, stmts.type, stmts.info, stmts.line, stmts.flag, ir.linetable
for metanode in ir.meta
push!(ci.code, metanode)
push!(ci.codelocs, 1)
push!(ci.ssavaluetypes, Any)
push!(ci.stmtinfo, nothing)
push!(ci.ssaflags, 0x00)
end
# Translate BB Edges to statement edges
Expand Down
2 changes: 2 additions & 0 deletions base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ function strip_trailing_junk!(ci::CodeInfo, code::Vector{Any}, flags::Vector{UIn
resize!(code, i)
resize!(ci.ssavaluetypes, i)
resize!(ci.codelocs, i)
resize!(ci.stmtinfo, i)
resize!(flags, i)
break
end
Expand All @@ -202,6 +203,7 @@ function strip_trailing_junk!(ci::CodeInfo, code::Vector{Any}, flags::Vector{UIn
push!(code, ReturnNode())
push!(ci.ssavaluetypes, Union{})
push!(ci.codelocs, 0)
push!(ci.stmtinfo, nothing)
push!(flags, 0x00)
end
nothing
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
struct MethodMatchInfo
applicable::Vector{Any}
applicable::Any
end

struct UnionSplitInfo
Expand Down
1 change: 1 addition & 0 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ function type_annotate!(sv::InferenceState)
deleteat!(states, i)
deleteat!(src.ssavaluetypes, i)
deleteat!(src.codelocs, i)
deleteat!(src.stmtinfo, i)
nexpr -= 1
if oldidx < length(changemap)
changemap[oldidx + 1] = -1
Expand Down
1 change: 1 addition & 0 deletions test/worlds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ function equal(ci1::Core.CodeInfo, ci2::Core.CodeInfo)
return ci1.code == ci2.code &&
ci1.codelocs == ci2.codelocs &&
ci1.ssavaluetypes == ci2.ssavaluetypes &&
ci1.stmtinfo == ci2.stmtinfo &&
ci1.ssaflags == ci2.ssaflags &&
ci1.method_for_inference_limit_heuristics == ci2.method_for_inference_limit_heuristics &&
ci1.linetable == ci2.linetable &&
Expand Down

0 comments on commit 06f5158

Please sign in to comment.