Skip to content

Commit

Permalink
optimize a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed May 13, 2022
1 parent 616971c commit 2336683
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 80 deletions.
100 changes: 49 additions & 51 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2253,7 +2253,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
frame.currbb = _bits_findnext(W.bits, 1)::Int # next basic block
end

stoverwrite!(frame.pc_vartable, frame.bb_vartables[frame.currbb])
states = frame.bb_vartables
currstate = copy(states[frame.currbb])
while frame.currbb <= nbbs
delete!(W, frame.currbb)
frame.currpc = first(bbs[frame.currbb].stmts)
Expand All @@ -2274,7 +2275,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
@goto branch
elseif isa(stmt, GotoIfNot)
condx = stmt.cond
condt = abstract_eval_value(interp, condx, frame.pc_vartable, frame)
condt = abstract_eval_value(interp, condx, currstate, frame)
if condt === Bottom
empty!(frame.pclimitations)
@goto find_next_bb
Expand Down Expand Up @@ -2313,25 +2314,31 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
# We continue with the true branch, but process the false
# branch here.
if isa(condt, Conditional)
false_vartable = stoverwrite1!(copy(frame.pc_vartable),
conditional_changes(frame.pc_vartable, condt.elsetype, condt.var))
else_change = conditional_change(currstate, condt.elsetype, condt.var)
if else_change !== nothing
false_vartable = stoverwrite1!(copy(currstate), else_change)
else
false_vartable = currstate
end
if falsebb in analyzed_bbs
newstate = stupdate!(frame.bb_vartables[falsebb], false_vartable)
newstate = stupdate!(states[falsebb], false_vartable)
else
newstate = frame.bb_vartables[falsebb] = stupdate!(nothing, false_vartable)
newstate = stoverwrite!(states[falsebb], false_vartable)
push!(analyzed_bbs, falsebb)
end
stoverwrite1!(frame.pc_vartable,
conditional_changes(frame.pc_vartable, condt.vtype, condt.var))
then_change = conditional_change(currstate, condt.vtype, condt.var)
if then_change !== nothing
stoverwrite1!(currstate, then_change)
end
else
if falsebb in analyzed_bbs
newstate = stupdate!(frame.bb_vartables[falsebb], frame.pc_vartable)
newstate = stupdate!(states[falsebb], currstate)
else
newstate = frame.bb_vartables[falsebb] = stupdate!(nothing, frame.pc_vartable)
newstate = stoverwrite!(states[falsebb], currstate)
push!(analyzed_bbs, falsebb)
end
end
if newstate !== nothing || !was_reached(frame, first(bbs[falsebb].stmts))
if newstate !== nothing
handle_control_backedge!(frame, frame.currpc, stmt.dest)
push!(W, falsebb)
end
Expand All @@ -2340,8 +2347,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
elseif isa(stmt, ReturnNode)
bestguess = frame.bestguess
rt = abstract_eval_value(interp, stmt.val, frame.pc_vartable, frame)
rt = widenreturn(rt, bestguess, nargs, slottypes, frame.pc_vartable)
rt = abstract_eval_value(interp, stmt.val, currstate, frame)
rt = widenreturn(rt, bestguess, nargs, slottypes, currstate)
# narrow representation of bestguess slightly to prepare for tmerge with rt
if rt isa InterConditional && bestguess isa Const
let slot_id = rt.slot
Expand Down Expand Up @@ -2380,9 +2387,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
l = stmt.args[1]::Int
catchbb = block_for_inst(frame.cfg, l)
if catchbb in analyzed_bbs
newstate = stupdate!(frame.bb_vartables[catchbb], frame.pc_vartable)
newstate = stupdate!(states[catchbb], currstate)
else
newstate = frame.bb_vartables[catchbb] = stupdate!(nothing, frame.pc_vartable)
newstate = stoverwrite!(states[catchbb], currstate)
push!(analyzed_bbs, catchbb)
end
if newstate !== nothing
Expand All @@ -2395,12 +2402,12 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
# Process non control-flow statements
(; changes, type) = abstract_eval_basic_statement(interp,
stmt, frame.pc_vartable, frame)
stmt, currstate, frame)
if type === Union{}
@goto find_next_bb
end
if changes !== nothing
stoverwrite1!(frame.pc_vartable, changes)
stoverwrite1!(currstate, changes)
let cur_hand = frame.handler_at[frame.currpc], l, enter
while cur_hand != 0
enter = frame.src.code[cur_hand]::Expr
Expand All @@ -2409,8 +2416,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
# propagate new type info to exception handler
# the handling for Expr(:enter) propagates all changes from before the try/catch
# so this only needs to propagate any changes
if stupdate1!(frame.bb_vartables[exceptbb], changes) || !was_reached(frame, first(bbs[exceptbb].stmts))
push!(frame.ip, exceptbb)
if stupdate1!(states[exceptbb], changes)
push!(W, exceptbb)
end
cur_hand = frame.handler_at[cur_hand]
end
Expand All @@ -2427,56 +2434,47 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
end # for frame.currpc in frame.currpc:bbend

# Case 1: Fallthrough termination
@label fallthrough
nextbb = frame.currbb + 1

# Case 2: Directly branch to a different BB
@label branch
if nextbb in analyzed_bbs
newstate = stupdate!(frame.bb_vartables[nextbb], frame.pc_vartable)
else
newstate = frame.bb_vartables[nextbb] = stupdate!(nothing, frame.pc_vartable)
push!(analyzed_bbs, nextbb)
end
if newstate !== nothing || !was_reached(frame, first(bbs[nextbb].stmts))
push!(W, nextbb)
# Case 1: Fallthrough termination
begin @label fallthrough
nextbb = frame.currbb + 1
end
@goto find_next_bb

# TODO: Restore optimization
if nextbb <= nbbs
newstate = stupdate!(frame.bb_vartables[nextbb], frame.pc_vartable)
# Case 2: Directly branch to a different BB
begin @label branch
if nextbb in analyzed_bbs
newstate = stupdate!(states[nextbb], currstate)
else
newstate = stoverwrite!(states[nextbb], currstate)
push!(analyzed_bbs, nextbb)
end
if newstate !== nothing
frame.currbb = nextbb
frame.currpc = first(bbs[nextbb].stmts)
stoverwrite!(frame.pc_vartable, newstate)
continue
push!(W, nextbb)
end
end

# Case 3: Control flow ended along the current path (converged, return or throw)
@label find_next_bb
frame.currbb = _bits_findnext(W.bits, 1)::Int # next basic block
frame.currbb == -1 && break # the working set is empty
frame.currbb > nbbs && break
# Case 3: Control flow ended along the current path (converged, return or throw)
begin @label find_next_bb
frame.currbb = _bits_findnext(W.bits, 1)::Int # next basic block
frame.currbb == -1 && break # the working set is empty
frame.currbb > nbbs && break

frame.currpc = first(bbs[frame.currbb].stmts)
stoverwrite!(frame.pc_vartable, frame.bb_vartables[frame.currbb])
frame.currpc = first(bbs[frame.currbb].stmts)
stoverwrite!(currstate, states[frame.currbb])
end
end # while frame.currbb <= nbbs

frame.dont_work_on_me = false
nothing
end

function conditional_changes(changes::VarTable, @nospecialize(typ), var::SlotNumber)
oldtyp = changes[slot_id(var)].typ
function conditional_change(state::VarTable, @nospecialize(typ), var::SlotNumber)
oldtyp = state[slot_id(var)].typ
# approximate test for `typ ∩ oldtyp` being better than `oldtyp`
# since we probably formed these types with `typesubstract`, the comparison is likely simple
if ignorelimited(typ) ignorelimited(oldtyp)
# typ is better unlimited, but we may still need to compute the tmeet with the limit "causes" since we ignored those in the comparison
oldtyp isa LimitedAccuracy && (typ = tmerge(typ, LimitedAccuracy(Bottom, oldtyp.causes)))
return StateUpdate(var, VarState(typ, false), changes, true)
return StateUpdate(var, VarState(typ, false), state, true)
end
return nothing
end
Expand Down
9 changes: 4 additions & 5 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ mutable struct InferenceState
ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
bb_vartables::Vector{VarTable}
pc_vartable::VarTable
stmt_edges::Vector{Union{Nothing, Vector{Any}}}
stmt_info::Vector{Any}

Expand Down Expand Up @@ -152,17 +151,17 @@ mutable struct InferenceState

nslots = length(src.slotflags)
slottypes = Vector{Any}(undef, nslots)
pc_vartable = VarTable(undef, nslots)
bb_vartable1 = VarTable(undef, nslots)
bb_vartable_proto = VarTable(undef, nslots)
argtypes = result.argtypes
nargtypes = length(argtypes)
for i in 1:nslots
argtyp = (i > nargtypes) ? Bottom : argtypes[i]
pc_vartable[i] = VarState(argtyp, i > nargtypes)
bb_vartable1[i] = VarState(argtyp, i > nargtypes)
bb_vartable_proto[i] = VarState(Bottom, i > nargtypes)
slottypes[i] = argtyp
end
bb_vartables = VarTable[i == 1 ? copy(pc_vartable) : copy(bb_vartable_proto)
bb_vartables = VarTable[i == 1 ? bb_vartable1 : copy(bb_vartable_proto)
for i = 1:length(cfg.blocks)]

pclimitations = IdSet{InferenceState}()
Expand Down Expand Up @@ -193,7 +192,7 @@ mutable struct InferenceState

frame = new(
linfo, world, mod, sptypes, slottypes, src, cfg,
currbb, currpc, ip, was_reached, handler_at, ssavalue_uses, bb_vartables, pc_vartable, stmt_edges, stmt_info,
currbb, currpc, ip, was_reached, handler_at, ssavalue_uses, bb_vartables, stmt_edges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent, inferred,
result, valid_worlds, bestguess, ipo_effects,
params, restrict_abstract_call_sites, cached,
Expand Down
24 changes: 0 additions & 24 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,27 +376,6 @@ widenwrappedconditional(typ::LimitedAccuracy) = LimitedAccuracy(widenconditional
ignorelimited(@nospecialize typ) = typ
ignorelimited(typ::LimitedAccuracy) = typ.typ

function stupdate!(state::Nothing, changes::StateUpdate)
newst = copy(changes.state)
changeid = slot_id(changes.var)
newst[changeid] = changes.vtype
# remove any Conditional for this slot from the vtable
# (unless this change is came from the conditional)
if !changes.conditional
for i = 1:length(newst)
newtype = newst[i]
if isa(newtype, VarState)
newtypetyp = ignorelimited(newtype.typ)
if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) == changeid
newtypetyp = widenwrappedconditional(newtype.typ)
newst[i] = VarState(newtypetyp, newtype.undef)
end
end
end
end
return newst
end

function stupdate!(state::VarTable, changes::StateUpdate)
newstate = nothing
changeid = slot_id(changes.var)
Expand Down Expand Up @@ -437,8 +416,6 @@ function stupdate!(state::VarTable, changes::VarTable)
return newstate
end

stupdate!(::Nothing, changes::VarTable) = copy(changes)

function stupdate1!(state::VarTable, change::StateUpdate)
changeid = slot_id(change.var)
# remove any Conditional for this slot from the catch block vtable
Expand Down Expand Up @@ -499,4 +476,3 @@ function stoverwrite1!(state::VarTable, change::StateUpdate)
state[changeid] = newtype
return state
end
stoverwrite1!(state::VarTable, ::Nothing) = state

0 comments on commit 2336683

Please sign in to comment.