Skip to content

Commit

Permalink
Use argextype for builtin / intrinsic check
Browse files Browse the repository at this point in the history
  • Loading branch information
topolarity committed Oct 9, 2024
1 parent 99a22ec commit 739ac62
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 38 deletions.
1 change: 1 addition & 0 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3113,6 +3113,7 @@ end
abstract_eval_ssavalue(s::SSAValue, sv::InferenceState) = abstract_eval_ssavalue(s, sv.ssavaluetypes)

function abstract_eval_ssavalue(s::SSAValue, ssavaluetypes::Vector{Any})
(1 s.id length(ssavaluetypes)) || throw(InvalidIRError())
typ = ssavaluetypes[s.id]
if typ === NOT_FOUND
return Bottom
Expand Down
25 changes: 21 additions & 4 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -411,26 +411,35 @@ function argextype(@nospecialize(x), compact::IncrementalCompact, sptypes::Vecto
isa(x, AnySSAValue) && return types(compact)[x]
return argextype(x, compact, sptypes, compact.ir.argtypes)
end
argextype(@nospecialize(x), src::CodeInfo, sptypes::Vector{VarState}) = argextype(x, src, sptypes, src.slottypes::Vector{Any})
function argextype(@nospecialize(x), src::CodeInfo, sptypes::Vector{VarState})
return argextype(x, src, sptypes, src.slottypes::Union{Vector{Any},Nothing})
end
function argextype(
@nospecialize(x), src::Union{IRCode,IncrementalCompact,CodeInfo},
sptypes::Vector{VarState}, slottypes::Vector{Any})
sptypes::Vector{VarState}, slottypes::Union{Vector{Any},Nothing})
if isa(x, Expr)
if x.head === :static_parameter
return sptypes[x.args[1]::Int].typ
idx = x.args[1]::Int
(1 idx length(sptypes)) || throw(InvalidIRError())
return sptypes[idx].typ
elseif x.head === :boundscheck
return Bool
elseif x.head === :copyast
length(x.args) == 0 && throw(InvalidIRError())
return argextype(x.args[1], src, sptypes, slottypes)
end
Core.println("argextype called on Expr with head ", x.head,
" which is not valid for IR in argument-position.")
@assert false
elseif isa(x, SlotNumber)
slottypes === nothing && return Any
(1 x.id length(slottypes)) || throw(InvalidIRError())
return slottypes[x.id]
elseif isa(x, SSAValue)
return abstract_eval_ssavalue(x, src)
elseif isa(x, Argument)
slottypes === nothing && return Any
(1 x.n length(slottypes)) || throw(InvalidIRError())
return slottypes[x.n]
elseif isa(x, QuoteNode)
return Const(x.value)
Expand All @@ -444,7 +453,15 @@ function argextype(
return Const(x)
end
end
abstract_eval_ssavalue(s::SSAValue, src::CodeInfo) = abstract_eval_ssavalue(s, src.ssavaluetypes::Vector{Any})
function abstract_eval_ssavalue(s::SSAValue, src::CodeInfo)
ssavaluetypes = src.ssavaluetypes
if ssavaluetypes isa Int
(1 s.id ssavaluetypes) || throw(InvalidIRError())
return Any
else
return abstract_eval_ssavalue(s, ssavaluetypes::Vector{Any})
end
end
abstract_eval_ssavalue(s::SSAValue, src::Union{IRCode,IncrementalCompact}) = types(src)[s]

"""
Expand Down
36 changes: 28 additions & 8 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ Instruction(is::InstructionStream) = Instruction(is, add_new_idx!(is))
fldarray = getfield(getfield(node, :data), fld)
fldidx = getfield(node, :idx)
(fld === :line) && return (fldarray[3fldidx-2], fldarray[3fldidx-1], fldarray[3fldidx-0])
(1 fldidx length(fldarray)) || throw(InvalidIRError())
return fldarray[fldidx]
end
@inline function setindex!(node::Instruction, @nospecialize(val), fld::Symbol)
Expand Down Expand Up @@ -481,11 +482,16 @@ function block_for_inst(ir::IRCode, inst::Int)
end

function getindex(ir::IRCode, s::SSAValue)
id = s.id
(id 1) || throw(InvalidIRError())
nstmts = length(ir.stmts)
if s.id <= nstmts
return ir.stmts[s.id]
if id <= nstmts
return ir.stmts[id]
else
return ir.new_nodes.stmts[s.id - nstmts]
id -= nstmts
stmts = ir.new_nodes.stmts
(id length(stmts)) || throw(InvalidIRError())
return stmts[id]
end
end

Expand Down Expand Up @@ -801,12 +807,13 @@ end
types(ir::Union{IRCode, IncrementalCompact}) = TypesView(ir)

function getindex(compact::IncrementalCompact, ssa::SSAValue)
@assert ssa.id < compact.result_idx
(1 ssa.id compact.result_idx) || throw(InvalidIRError())
return compact.result[ssa.id]
end

function getindex(compact::IncrementalCompact, ssa::OldSSAValue)
id = ssa.id
(id 1) || throw(InvalidIRError())
if id < compact.idx
new_idx = compact.ssa_rename[id]::Int
return compact.result[new_idx]
Expand All @@ -818,12 +825,15 @@ function getindex(compact::IncrementalCompact, ssa::OldSSAValue)
return compact.ir.new_nodes.stmts[id]
end
id -= length(compact.ir.new_nodes)
(id length(compact.pending_nodes.stmts)) || throw(InvalidIRError())
return compact.pending_nodes.stmts[id]
end

function getindex(compact::IncrementalCompact, ssa::NewSSAValue)
if ssa.id < 0
return compact.new_new_nodes.stmts[-ssa.id]
stmts = compact.new_new_nodes.stmts
(-ssa.id length(stmts)) || throw(InvalidIRError())
return stmts[-ssa.id]
else
return compact[SSAValue(ssa.id)]
end
Expand Down Expand Up @@ -1069,6 +1079,7 @@ function getindex(view::TypesView, v::OldSSAValue)
id = v.id
ir = view.ir.ir
stmts = ir.stmts
(id 1) || throw(InvalidIRError())
if id <= length(stmts)
return stmts[id][:type]
end
Expand All @@ -1077,7 +1088,9 @@ function getindex(view::TypesView, v::OldSSAValue)
return ir.new_nodes.stmts[id][:type]
end
id -= length(ir.new_nodes)
return view.ir.pending_nodes.stmts[id][:type]
stmts = view.ir.pending_nodes.stmts
(id length(stmts)) || throw(InvalidIRError())
return stmts[id][:type]
end

function kill_current_use!(compact::IncrementalCompact, @nospecialize(val))
Expand Down Expand Up @@ -1204,20 +1217,27 @@ end

getindex(view::TypesView, idx::SSAValue) = getindex(view, idx.id)
function getindex(view::TypesView, idx::Int)
(idx 1) || throw(InvalidIRError())
if isa(view.ir, IncrementalCompact) && idx < view.ir.result_idx
return view.ir.result[idx][:type]
elseif isa(view.ir, IncrementalCompact) && view.ir.renamed_new_nodes
if idx <= length(view.ir.result)
return view.ir.result[idx][:type]
else
return view.ir.new_new_nodes.stmts[idx - length(view.ir.result)][:type]
idx -= length(view.ir.result)
stmts = view.ir.new_new_nodes.stmts
(idx length(stmts)) || throw(InvalidIRError())
return stmts[idx][:type]
end
else
ir = isa(view.ir, IncrementalCompact) ? view.ir.ir : view.ir
if idx <= length(ir.stmts)
return ir.stmts[idx][:type]
else
return ir.new_nodes.stmts[idx - length(ir.stmts)][:type]
idx -= length(ir.stmts)
stmts = ir.new_nodes.stmts
(idx length(stmts)) || throw(InvalidIRError())
return stmts[idx][:type]
end
end
end
Expand Down
62 changes: 38 additions & 24 deletions base/compiler/ssair/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ end

import Base: show_unquoted
using Base: printstyled, with_output_color, prec_decl, @invoke
using Core.Compiler: VarState, InvalidIRError

function Base.show(io::IO, cfg::CFG)
print(io, "CFG with $(length(cfg.blocks)) blocks:")
Expand All @@ -31,7 +32,8 @@ function Base.show(io::IO, cfg::CFG)
end
end

function print_stmt(io::IO, idx::Int, @nospecialize(stmt), used::BitSet, maxlength_idx::Int, color::Bool, show_type::Bool)
function print_stmt(io::IO, idx::Int, @nospecialize(stmt), code::Union{IRCode,CodeInfo,IncrementalCompact},
sptypes::Vector{VarState}, used::BitSet, maxlength_idx::Int, color::Bool, show_type::Bool)
if idx in used
idx_s = string(idx)
pad = " "^(maxlength_idx - length(idx_s) + 1)
Expand Down Expand Up @@ -67,21 +69,29 @@ function print_stmt(io::IO, idx::Int, @nospecialize(stmt), used::BitSet, maxleng
join(io, (print_arg(i) for i = 3:length(stmt.args)), ", ")
print(io, ")")
elseif isexpr(stmt, :call) && length(stmt.args) >= 1
arg1 = stmt.args[1]
if arg1 isa GlobalRef && isdefined(arg1.mod, arg1.name)
arg1 = getfield(arg1.mod, arg1.name)
ft = try
Core.Compiler.argextype(stmt.args[1], code, sptypes)
catch err
!(err isa InvalidIRError) && rethrow()
nothing
end
if isa(arg1, Core.IntrinsicFunction)
f = Core.Compiler.singleton_type(ft)
if isa(f, Core.IntrinsicFunction)
printstyled(io, "intrinsic "; color = :light_black)
elseif isa(arg1, Core.Builtin)
if (arg1 === Core._apply_iterate || arg1 === Core._apply_pure ||
arg1 === Core._call_in_world || arg1 === Core._call_in_world_total ||
arg1 === Core._call_latest)
elseif isa(f, Core.Builtin)
if (f === Core._apply_iterate || f === Core._apply_pure ||
f === Core._call_in_world || f === Core._call_in_world_total ||
f === Core._call_latest)
# These apply-like builtins are effectively dynamic calls
printstyled(io, "dynamic builtin "; color = :yellow)
else
printstyled(io, "builtin "; color = :light_black)
end
elseif f === nothing
# This should only happen when, e.g., printing a call that targets
# an out-of-bounds SSAValue or similar
# (i.e. under normal circumstances, dead code)
printstyled(io, "unknown "; color = :light_black)
else
printstyled(io, "dynamic "; color = :yellow)
end
Expand Down Expand Up @@ -648,13 +658,13 @@ end
# at index `idx`. This function is repeatedly called until it returns `nothing`.
# to iterate nodes that are to be inserted after the statement, set `attach_after=true`.
function show_ir_stmt(io::IO, code::Union{IRCode, CodeInfo, IncrementalCompact}, idx::Int, config::IRShowConfig,
used::BitSet, cfg::CFG, bb_idx::Int; pop_new_node! = Returns(nothing), only_after::Bool=false)
sptypes::Vector{VarState}, used::BitSet, cfg::CFG, bb_idx::Int; pop_new_node! = Returns(nothing), only_after::Bool=false)
return show_ir_stmt(io, code, idx, config.line_info_preprinter, config.line_info_postprinter,
used, cfg, bb_idx; pop_new_node!, only_after, config.bb_color)
sptypes, used, cfg, bb_idx; pop_new_node!, only_after, config.bb_color)
end

function show_ir_stmt(io::IO, code::Union{IRCode, CodeInfo, IncrementalCompact}, idx::Int, line_info_preprinter, line_info_postprinter,
used::BitSet, cfg::CFG, bb_idx::Int; pop_new_node! = Returns(nothing), only_after::Bool=false, bb_color=:light_black)
sptypes::Vector{VarState}, used::BitSet, cfg::CFG, bb_idx::Int; pop_new_node! = Returns(nothing), only_after::Bool=false, bb_color=:light_black)
stmt = _stmt(code, idx)
type = _type(code, idx)
max_bb_idx_size = length(string(length(cfg.blocks)))
Expand Down Expand Up @@ -713,7 +723,7 @@ function show_ir_stmt(io::IO, code::Union{IRCode, CodeInfo, IncrementalCompact},
show_type = should_print_ssa_type(new_node_inst)
let maxlength_idx=maxlength_idx, show_type=show_type
with_output_color(:green, io) do io′
print_stmt(io′, node_idx, new_node_inst, used, maxlength_idx, false, show_type)
print_stmt(io′, node_idx, new_node_inst, code, sptypes, used, maxlength_idx, false, show_type)
end
end

Expand Down Expand Up @@ -742,7 +752,7 @@ function show_ir_stmt(io::IO, code::Union{IRCode, CodeInfo, IncrementalCompact},
stmt = statement_indices_to_labels(stmt, cfg)
end
show_type = type !== nothing && should_print_ssa_type(stmt)
print_stmt(io, idx, stmt, used, maxlength_idx, true, show_type)
print_stmt(io, idx, stmt, code, sptypes, used, maxlength_idx, true, show_type)
if type !== nothing # ignore types for pre-inference code
if type === UNDEF
# This is an error, but can happen if passes don't update their type information
Expand Down Expand Up @@ -901,10 +911,10 @@ end
default_config(code::CodeInfo) = IRShowConfig(statementidx_lineinfo_printer(code))

function show_ir_stmts(io::IO, ir::Union{IRCode, CodeInfo, IncrementalCompact}, inds, config::IRShowConfig,
used::BitSet, cfg::CFG, bb_idx::Int; pop_new_node! = Returns(nothing))
sptypes::Vector{VarState}, used::BitSet, cfg::CFG, bb_idx::Int; pop_new_node! = Returns(nothing))
for idx in inds
if config.should_print_stmt(ir, idx, used)
bb_idx = show_ir_stmt(io, ir, idx, config, used, cfg, bb_idx; pop_new_node!)
bb_idx = show_ir_stmt(io, ir, idx, config, sptypes, used, cfg, bb_idx; pop_new_node!)
elseif bb_idx <= length(cfg.blocks) && idx == cfg.blocks[bb_idx].stmts.stop
bb_idx += 1
end
Expand All @@ -924,7 +934,7 @@ function show_ir(io::IO, ir::IRCode, config::IRShowConfig=default_config(ir);
cfg = ir.cfg
maxssaid = length(ir.stmts) + Core.Compiler.length(ir.new_nodes)
let io = IOContext(io, :maxssaid=>maxssaid)
show_ir_stmts(io, ir, 1:length(ir.stmts), config, used, cfg, 1; pop_new_node!)
show_ir_stmts(io, ir, 1:length(ir.stmts), config, ir.sptypes, used, cfg, 1; pop_new_node!)
end
finish_show_ir(io, cfg, config)
end
Expand All @@ -933,8 +943,12 @@ function show_ir(io::IO, ci::CodeInfo, config::IRShowConfig=default_config(ci);
pop_new_node! = Returns(nothing))
used = stmts_used(io, ci)
cfg = compute_basic_blocks(ci.code)
parent = ci.parent
sptypes = if parent isa MethodInstance
Core.Compiler.sptypes_from_meth_instance(parent)
else Core.Compiler.EMPTY_SPTYPES end
let io = IOContext(io, :maxssaid=>length(ci.code))
show_ir_stmts(io, ci, 1:length(ci.code), config, used, cfg, 1; pop_new_node!)
show_ir_stmts(io, ci, 1:length(ci.code), config, sptypes, used, cfg, 1; pop_new_node!)
end
finish_show_ir(io, cfg, config)
end
Expand Down Expand Up @@ -983,8 +997,8 @@ function show_ir(io::IO, compact::IncrementalCompact, config::IRShowConfig=defau
pop_new_node! = new_nodes_iter(compact)
maxssaid = length(compact.result) + Core.Compiler.length(compact.new_new_nodes)
bb_idx = let io = IOContext(io, :maxssaid=>maxssaid)
show_ir_stmts(io, compact, 1:compact.result_idx-1, config, used_compacted,
compact_cfg, 1; pop_new_node!)
show_ir_stmts(io, compact, 1:compact.result_idx-1, config, compact.ir.sptypes,
used_compacted, compact_cfg, 1; pop_new_node!)
end


Expand Down Expand Up @@ -1015,13 +1029,13 @@ function show_ir(io::IO, compact::IncrementalCompact, config::IRShowConfig=defau
let io = IOContext(io, :maxssaid=>maxssaid)
# first show any new nodes to be attached after the last compacted statement
if compact.idx > 1
show_ir_stmt(io, compact.ir, compact.idx-1, config, used_uncompacted,
uncompacted_cfg, bb_idx; pop_new_node!, only_after=true)
show_ir_stmt(io, compact.ir, compact.idx-1, config, compact.ir.sptypes,
used_uncompacted, uncompacted_cfg, bb_idx; pop_new_node!, only_after=true)
end

# then show the actual uncompacted IR
show_ir_stmts(io, compact.ir, compact.idx:length(stmts), config, used_uncompacted,
uncompacted_cfg, bb_idx; pop_new_node!)
show_ir_stmts(io, compact.ir, compact.idx:length(stmts), config, compact.ir.sptypes,
used_uncompacted, uncompacted_cfg, bb_idx; pop_new_node!)
end

finish_show_ir(io, uncompacted_cfg, config)
Expand Down
2 changes: 2 additions & 0 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ the following methods to satisfy the `AbstractInterpreter` API requirement:

abstract type AbstractLattice end

struct InvalidIRError <: Exception end

struct ArgInfo
fargs::Union{Nothing,Vector{Any}}
argtypes::Vector{Any}
Expand Down
5 changes: 4 additions & 1 deletion stdlib/InteractiveUtils/test/highlighting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ end
c = Base.text_colors[Base.warn_color()]
InteractiveUtils.highlighting[:warntype] = false
code_warntype(IOContext(io, :color => true), f, Tuple{Int64})
@test !occursin(c, String(take!(io)))
@test !any([
occursin("Body", line) && occursin(c, line)
for line in split(String(take!(io)), "\n")
])
InteractiveUtils.highlighting[:warntype] = true
code_warntype(IOContext(io, :color => true), f, Tuple{Int64})
@test occursin(c, String(take!(io)))
Expand Down
2 changes: 1 addition & 1 deletion test/compiler/EscapeAnalysis/EAUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ function print_with_info(preprint, postprint, io::IO, ir::IRCode, source::Bool)
bb_idx_prev = bb_idx = 1
for idx = 1:length(ir.stmts)
preprint(io, idx)
bb_idx = Base.IRShow.show_ir_stmt(io, ir, idx, line_info_preprinter, line_info_postprinter, used, ir.cfg, bb_idx)
bb_idx = Base.IRShow.show_ir_stmt(io, ir, idx, line_info_preprinter, line_info_postprinter, ir.sptypes, used, ir.cfg, bb_idx)
postprint(io, idx, bb_idx != bb_idx_prev)
bb_idx_prev = bb_idx
end
Expand Down

0 comments on commit 739ac62

Please sign in to comment.