From 23366835112610f2bfb22afb558a48b50942f759 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Fri, 13 May 2022 09:09:47 +0900 Subject: [PATCH] optimize a bit --- base/compiler/abstractinterpretation.jl | 100 ++++++++++++------------ base/compiler/inferencestate.jl | 9 +-- base/compiler/typelattice.jl | 24 ------ 3 files changed, 53 insertions(+), 80 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index e95b39c4d6b15..4979870a3a264 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -2253,7 +2253,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) frame.currbb = _bits_findnext(W.bits, 1)::Int # next basic block end - stoverwrite!(frame.pc_vartable, frame.bb_vartables[frame.currbb]) + states = frame.bb_vartables + currstate = copy(states[frame.currbb]) while frame.currbb <= nbbs delete!(W, frame.currbb) frame.currpc = first(bbs[frame.currbb].stmts) @@ -2274,7 +2275,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) @goto branch elseif isa(stmt, GotoIfNot) condx = stmt.cond - condt = abstract_eval_value(interp, condx, frame.pc_vartable, frame) + condt = abstract_eval_value(interp, condx, currstate, frame) if condt === Bottom empty!(frame.pclimitations) @goto find_next_bb @@ -2313,25 +2314,31 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) # We continue with the true branch, but process the false # branch here. if isa(condt, Conditional) - false_vartable = stoverwrite1!(copy(frame.pc_vartable), - conditional_changes(frame.pc_vartable, condt.elsetype, condt.var)) + else_change = conditional_change(currstate, condt.elsetype, condt.var) + if else_change !== nothing + false_vartable = stoverwrite1!(copy(currstate), else_change) + else + false_vartable = currstate + end if falsebb in analyzed_bbs - newstate = stupdate!(frame.bb_vartables[falsebb], false_vartable) + newstate = stupdate!(states[falsebb], false_vartable) else - newstate = frame.bb_vartables[falsebb] = stupdate!(nothing, false_vartable) + newstate = stoverwrite!(states[falsebb], false_vartable) push!(analyzed_bbs, falsebb) end - stoverwrite1!(frame.pc_vartable, - conditional_changes(frame.pc_vartable, condt.vtype, condt.var)) + then_change = conditional_change(currstate, condt.vtype, condt.var) + if then_change !== nothing + stoverwrite1!(currstate, then_change) + end else if falsebb in analyzed_bbs - newstate = stupdate!(frame.bb_vartables[falsebb], frame.pc_vartable) + newstate = stupdate!(states[falsebb], currstate) else - newstate = frame.bb_vartables[falsebb] = stupdate!(nothing, frame.pc_vartable) + newstate = stoverwrite!(states[falsebb], currstate) push!(analyzed_bbs, falsebb) end end - if newstate !== nothing || !was_reached(frame, first(bbs[falsebb].stmts)) + if newstate !== nothing handle_control_backedge!(frame, frame.currpc, stmt.dest) push!(W, falsebb) end @@ -2340,8 +2347,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) end elseif isa(stmt, ReturnNode) bestguess = frame.bestguess - rt = abstract_eval_value(interp, stmt.val, frame.pc_vartable, frame) - rt = widenreturn(rt, bestguess, nargs, slottypes, frame.pc_vartable) + rt = abstract_eval_value(interp, stmt.val, currstate, frame) + rt = widenreturn(rt, bestguess, nargs, slottypes, currstate) # narrow representation of bestguess slightly to prepare for tmerge with rt if rt isa InterConditional && bestguess isa Const let slot_id = rt.slot @@ -2380,9 +2387,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) l = stmt.args[1]::Int catchbb = block_for_inst(frame.cfg, l) if catchbb in analyzed_bbs - newstate = stupdate!(frame.bb_vartables[catchbb], frame.pc_vartable) + newstate = stupdate!(states[catchbb], currstate) else - newstate = frame.bb_vartables[catchbb] = stupdate!(nothing, frame.pc_vartable) + newstate = stoverwrite!(states[catchbb], currstate) push!(analyzed_bbs, catchbb) end if newstate !== nothing @@ -2395,12 +2402,12 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) end # Process non control-flow statements (; changes, type) = abstract_eval_basic_statement(interp, - stmt, frame.pc_vartable, frame) + stmt, currstate, frame) if type === Union{} @goto find_next_bb end if changes !== nothing - stoverwrite1!(frame.pc_vartable, changes) + stoverwrite1!(currstate, changes) let cur_hand = frame.handler_at[frame.currpc], l, enter while cur_hand != 0 enter = frame.src.code[cur_hand]::Expr @@ -2409,8 +2416,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) # propagate new type info to exception handler # the handling for Expr(:enter) propagates all changes from before the try/catch # so this only needs to propagate any changes - if stupdate1!(frame.bb_vartables[exceptbb], changes) || !was_reached(frame, first(bbs[exceptbb].stmts)) - push!(frame.ip, exceptbb) + if stupdate1!(states[exceptbb], changes) + push!(W, exceptbb) end cur_hand = frame.handler_at[cur_hand] end @@ -2427,56 +2434,47 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) end end # for frame.currpc in frame.currpc:bbend - # Case 1: Fallthrough termination - @label fallthrough - nextbb = frame.currbb + 1 - - # Case 2: Directly branch to a different BB - @label branch - if nextbb in analyzed_bbs - newstate = stupdate!(frame.bb_vartables[nextbb], frame.pc_vartable) - else - newstate = frame.bb_vartables[nextbb] = stupdate!(nothing, frame.pc_vartable) - push!(analyzed_bbs, nextbb) - end - if newstate !== nothing || !was_reached(frame, first(bbs[nextbb].stmts)) - push!(W, nextbb) + # Case 1: Fallthrough termination + begin @label fallthrough + nextbb = frame.currbb + 1 end - @goto find_next_bb - # TODO: Restore optimization - if nextbb <= nbbs - newstate = stupdate!(frame.bb_vartables[nextbb], frame.pc_vartable) + # Case 2: Directly branch to a different BB + begin @label branch + if nextbb in analyzed_bbs + newstate = stupdate!(states[nextbb], currstate) + else + newstate = stoverwrite!(states[nextbb], currstate) + push!(analyzed_bbs, nextbb) + end if newstate !== nothing - frame.currbb = nextbb - frame.currpc = first(bbs[nextbb].stmts) - stoverwrite!(frame.pc_vartable, newstate) - continue + push!(W, nextbb) end end - # Case 3: Control flow ended along the current path (converged, return or throw) - @label find_next_bb - frame.currbb = _bits_findnext(W.bits, 1)::Int # next basic block - frame.currbb == -1 && break # the working set is empty - frame.currbb > nbbs && break + # Case 3: Control flow ended along the current path (converged, return or throw) + begin @label find_next_bb + frame.currbb = _bits_findnext(W.bits, 1)::Int # next basic block + frame.currbb == -1 && break # the working set is empty + frame.currbb > nbbs && break - frame.currpc = first(bbs[frame.currbb].stmts) - stoverwrite!(frame.pc_vartable, frame.bb_vartables[frame.currbb]) + frame.currpc = first(bbs[frame.currbb].stmts) + stoverwrite!(currstate, states[frame.currbb]) + end end # while frame.currbb <= nbbs frame.dont_work_on_me = false nothing end -function conditional_changes(changes::VarTable, @nospecialize(typ), var::SlotNumber) - oldtyp = changes[slot_id(var)].typ +function conditional_change(state::VarTable, @nospecialize(typ), var::SlotNumber) + oldtyp = state[slot_id(var)].typ # approximate test for `typ ∩ oldtyp` being better than `oldtyp` # since we probably formed these types with `typesubstract`, the comparison is likely simple if ignorelimited(typ) ⊑ ignorelimited(oldtyp) # typ is better unlimited, but we may still need to compute the tmeet with the limit "causes" since we ignored those in the comparison oldtyp isa LimitedAccuracy && (typ = tmerge(typ, LimitedAccuracy(Bottom, oldtyp.causes))) - return StateUpdate(var, VarState(typ, false), changes, true) + return StateUpdate(var, VarState(typ, false), state, true) end return nothing end diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index b713e86f543ac..1ce3ff10f2a30 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -99,7 +99,6 @@ mutable struct InferenceState ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info # TODO: Could keep this sparsely by doing structural liveness analysis ahead of time. bb_vartables::Vector{VarTable} - pc_vartable::VarTable stmt_edges::Vector{Union{Nothing, Vector{Any}}} stmt_info::Vector{Any} @@ -152,17 +151,17 @@ mutable struct InferenceState nslots = length(src.slotflags) slottypes = Vector{Any}(undef, nslots) - pc_vartable = VarTable(undef, nslots) + bb_vartable1 = VarTable(undef, nslots) bb_vartable_proto = VarTable(undef, nslots) argtypes = result.argtypes nargtypes = length(argtypes) for i in 1:nslots argtyp = (i > nargtypes) ? Bottom : argtypes[i] - pc_vartable[i] = VarState(argtyp, i > nargtypes) + bb_vartable1[i] = VarState(argtyp, i > nargtypes) bb_vartable_proto[i] = VarState(Bottom, i > nargtypes) slottypes[i] = argtyp end - bb_vartables = VarTable[i == 1 ? copy(pc_vartable) : copy(bb_vartable_proto) + bb_vartables = VarTable[i == 1 ? bb_vartable1 : copy(bb_vartable_proto) for i = 1:length(cfg.blocks)] pclimitations = IdSet{InferenceState}() @@ -193,7 +192,7 @@ mutable struct InferenceState frame = new( linfo, world, mod, sptypes, slottypes, src, cfg, - currbb, currpc, ip, was_reached, handler_at, ssavalue_uses, bb_vartables, pc_vartable, stmt_edges, stmt_info, + currbb, currpc, ip, was_reached, handler_at, ssavalue_uses, bb_vartables, stmt_edges, stmt_info, pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent, inferred, result, valid_worlds, bestguess, ipo_effects, params, restrict_abstract_call_sites, cached, diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index d2226a9a1e23f..0d24ae493b72d 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -376,27 +376,6 @@ widenwrappedconditional(typ::LimitedAccuracy) = LimitedAccuracy(widenconditional ignorelimited(@nospecialize typ) = typ ignorelimited(typ::LimitedAccuracy) = typ.typ -function stupdate!(state::Nothing, changes::StateUpdate) - newst = copy(changes.state) - changeid = slot_id(changes.var) - newst[changeid] = changes.vtype - # remove any Conditional for this slot from the vtable - # (unless this change is came from the conditional) - if !changes.conditional - for i = 1:length(newst) - newtype = newst[i] - if isa(newtype, VarState) - newtypetyp = ignorelimited(newtype.typ) - if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) == changeid - newtypetyp = widenwrappedconditional(newtype.typ) - newst[i] = VarState(newtypetyp, newtype.undef) - end - end - end - end - return newst -end - function stupdate!(state::VarTable, changes::StateUpdate) newstate = nothing changeid = slot_id(changes.var) @@ -437,8 +416,6 @@ function stupdate!(state::VarTable, changes::VarTable) return newstate end -stupdate!(::Nothing, changes::VarTable) = copy(changes) - function stupdate1!(state::VarTable, change::StateUpdate) changeid = slot_id(change.var) # remove any Conditional for this slot from the catch block vtable @@ -499,4 +476,3 @@ function stoverwrite1!(state::VarTable, change::StateUpdate) state[changeid] = newtype return state end -stoverwrite1!(state::VarTable, ::Nothing) = state