Skip to content

Commit

Permalink
optimize: revise inlining costs (#51599)
Browse files Browse the repository at this point in the history
Add a bonus for Intrinsics called with mostly constant arguments. We
know that simple expressions like `x*1 + 0` will get optimized later by
LLVM, and also likely fold into other expressions, so try to reflect
that in the cost estimated earlier. Additionally rebalance some of the
other costs to more accurately reflect what they take in assembly.
  • Loading branch information
vtjnash authored Oct 6, 2023
1 parent f919e8f commit 0ab032a
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 34 deletions.
40 changes: 34 additions & 6 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
nothrow = _builtin_nothrow(𝕃ₒ, f, argtypes, rt)
return (true, nothrow, nothrow)
end
if f === Intrinsics.cglobal
if f === Intrinsics.cglobal || f === Intrinsics.llvmcall
# TODO: these are not yet linearized
return (false, false, false)
end
Expand Down Expand Up @@ -1031,11 +1031,36 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
f = singleton_type(ftyp)
if isa(f, IntrinsicFunction)
iidx = Int(reinterpret(Int32, f::IntrinsicFunction)) + 1
if !isassigned(T_IFUNC_COST, iidx)
# unknown/unhandled intrinsic
return params.inline_nonleaf_penalty
if isassigned(T_IFUNC, iidx)
minarg, maxarg, = T_IFUNC[iidx]
nargs = length(ex.args)
if minarg + 1 <= nargs <= maxarg + 1
# With mostly constant arguments, all Intrinsics tend to become very cheap
# and are likely to combine with the operations around them,
# so reduce their cost by half.
cost = T_IFUNC_COST[iidx]
if cost == 0 || nargs < 3 ||
(f === Intrinsics.cglobal || f === Intrinsics.llvmcall) # these hold malformed IR, so argextype will crash on them
return cost
end
aty2 = widenconditional(argextype(ex.args[2], src, sptypes))
nconst = Int(aty2 isa Const)
for i = 3:nargs
aty = widenconditional(argextype(ex.args[i], src, sptypes))
if widenconst(aty) != widenconst(aty2)
nconst = 0
break
end
nconst += aty isa Const
end
if nconst + 2 >= nargs
cost = (cost - 1) Γ· 2
end
return cost
end
end
return T_IFUNC_COST[iidx]
# unknown/unhandled intrinsic
return params.inline_nonleaf_penalty
end
if isa(f, Builtin) && f !== invoke
# The efficiency of operations like a[i] and s.b
Expand All @@ -1046,9 +1071,12 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
# tuple iteration/destructuring makes that impossible
# return plus_saturate(argcost, isknowntype(extyp) ? 1 : params.inline_nonleaf_penalty)
return 0
elseif (f === Core.arrayref || f === Core.const_arrayref || f === Core.arrayset) && length(ex.args) >= 3
elseif (f === Core.arrayref || f === Core.const_arrayref) && length(ex.args) >= 3
atyp = argextype(ex.args[3], src, sptypes)
return isknowntype(atyp) ? 4 : error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty
elseif f === Core.arrayset && length(ex.args) >= 3
atyp = argextype(ex.args[2], src, sptypes)
return isknowntype(atyp) ? 8 : error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty
elseif f === typeassert && isconstType(widenconst(argextype(ex.args[3], src, sptypes)))
return 1
end
Expand Down
52 changes: 26 additions & 26 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,10 @@ end
@nospecs conversion_tfunc(𝕃::AbstractLattice, t, x) = conversion_tfunc(widenlattice(𝕃), t, x)
@nospecs conversion_tfunc(::JLTypeLattice, t, x) = instanceof_tfunc(t, true)[1]

add_tfunc(bitcast, 2, 2, bitcast_tfunc, 1)
add_tfunc(sext_int, 2, 2, conversion_tfunc, 1)
add_tfunc(zext_int, 2, 2, conversion_tfunc, 1)
add_tfunc(trunc_int, 2, 2, conversion_tfunc, 1)
add_tfunc(bitcast, 2, 2, bitcast_tfunc, 0)
add_tfunc(sext_int, 2, 2, conversion_tfunc, 0)
add_tfunc(zext_int, 2, 2, conversion_tfunc, 0)
add_tfunc(trunc_int, 2, 2, conversion_tfunc, 0)
add_tfunc(fptoui, 2, 2, conversion_tfunc, 1)
add_tfunc(fptosi, 2, 2, conversion_tfunc, 1)
add_tfunc(uitofp, 2, 2, conversion_tfunc, 1)
Expand All @@ -170,30 +170,30 @@ add_tfunc(fpext, 2, 2, conversion_tfunc, 1)
@nospecs math_tfunc(𝕃::AbstractLattice, args...) = math_tfunc(widenlattice(𝕃), args...)
@nospecs math_tfunc(::JLTypeLattice, x, xs...) = widenconst(x)

add_tfunc(neg_int, 1, 1, math_tfunc, 1)
add_tfunc(neg_int, 1, 1, math_tfunc, 0)
add_tfunc(add_int, 2, 2, math_tfunc, 1)
add_tfunc(sub_int, 2, 2, math_tfunc, 1)
add_tfunc(mul_int, 2, 2, math_tfunc, 4)
add_tfunc(sdiv_int, 2, 2, math_tfunc, 30)
add_tfunc(udiv_int, 2, 2, math_tfunc, 30)
add_tfunc(srem_int, 2, 2, math_tfunc, 30)
add_tfunc(urem_int, 2, 2, math_tfunc, 30)
add_tfunc(mul_int, 2, 2, math_tfunc, 3)
add_tfunc(sdiv_int, 2, 2, math_tfunc, 20)
add_tfunc(udiv_int, 2, 2, math_tfunc, 20)
add_tfunc(srem_int, 2, 2, math_tfunc, 20)
add_tfunc(urem_int, 2, 2, math_tfunc, 20)
add_tfunc(add_ptr, 2, 2, math_tfunc, 1)
add_tfunc(sub_ptr, 2, 2, math_tfunc, 1)
add_tfunc(neg_float, 1, 1, math_tfunc, 1)
add_tfunc(add_float, 2, 2, math_tfunc, 1)
add_tfunc(sub_float, 2, 2, math_tfunc, 1)
add_tfunc(mul_float, 2, 2, math_tfunc, 4)
add_tfunc(div_float, 2, 2, math_tfunc, 4)
add_tfunc(fma_float, 3, 3, math_tfunc, 5)
add_tfunc(muladd_float, 3, 3, math_tfunc, 5)
add_tfunc(add_float, 2, 2, math_tfunc, 2)
add_tfunc(sub_float, 2, 2, math_tfunc, 2)
add_tfunc(mul_float, 2, 2, math_tfunc, 8)
add_tfunc(div_float, 2, 2, math_tfunc, 10)
add_tfunc(fma_float, 3, 3, math_tfunc, 8)
add_tfunc(muladd_float, 3, 3, math_tfunc, 8)

# fast arithmetic
add_tfunc(neg_float_fast, 1, 1, math_tfunc, 1)
add_tfunc(add_float_fast, 2, 2, math_tfunc, 1)
add_tfunc(sub_float_fast, 2, 2, math_tfunc, 1)
add_tfunc(mul_float_fast, 2, 2, math_tfunc, 2)
add_tfunc(div_float_fast, 2, 2, math_tfunc, 2)
add_tfunc(add_float_fast, 2, 2, math_tfunc, 2)
add_tfunc(sub_float_fast, 2, 2, math_tfunc, 2)
add_tfunc(mul_float_fast, 2, 2, math_tfunc, 8)
add_tfunc(div_float_fast, 2, 2, math_tfunc, 10)

# bitwise operators
# -----------------
Expand Down Expand Up @@ -280,12 +280,12 @@ add_tfunc(le_float_fast, 2, 2, cmp_tfunc, 1)
@nospecs chk_tfunc(𝕃::AbstractLattice, x, y) = chk_tfunc(widenlattice(𝕃), x, y)
@nospecs chk_tfunc(::JLTypeLattice, x, y) = Tuple{widenconst(x), Bool}

add_tfunc(checked_sadd_int, 2, 2, chk_tfunc, 10)
add_tfunc(checked_uadd_int, 2, 2, chk_tfunc, 10)
add_tfunc(checked_ssub_int, 2, 2, chk_tfunc, 10)
add_tfunc(checked_usub_int, 2, 2, chk_tfunc, 10)
add_tfunc(checked_smul_int, 2, 2, chk_tfunc, 10)
add_tfunc(checked_umul_int, 2, 2, chk_tfunc, 10)
add_tfunc(checked_sadd_int, 2, 2, chk_tfunc, 2)
add_tfunc(checked_uadd_int, 2, 2, chk_tfunc, 2)
add_tfunc(checked_ssub_int, 2, 2, chk_tfunc, 2)
add_tfunc(checked_usub_int, 2, 2, chk_tfunc, 2)
add_tfunc(checked_smul_int, 2, 2, chk_tfunc, 5)
add_tfunc(checked_umul_int, 2, 2, chk_tfunc, 5)

# other, misc
# -----------
Expand Down
4 changes: 2 additions & 2 deletions doc/src/devdocs/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ Each statement gets analyzed for its total cost in a function called
as follows:
```jldoctest; filter=r"tuple.jl:\d+"
julia> Base.print_statement_costs(stdout, map, (typeof(sqrt), Tuple{Int},)) # map(sqrt, (2,))
map(f, t::Tuple{Any}) @ Base tuple.jl:291
map(f, t::Tuple{Any}) @ Base tuple.jl:281
0 1 ─ %1 = $(Expr(:boundscheck, true))::Bool
0 β”‚ %2 = Base.getfield(_3, 1, %1)::Int64
1 β”‚ %3 = Base.sitofp(Float64, %2)::Float64
2 β”‚ %4 = Base.lt_float(%3, 0.0)::Bool
0 β”‚ %4 = Base.lt_float(%3, 0.0)::Bool
0 └── goto #3 if not %4
0 2 ─ invoke Base.Math.throw_complex_domainerror(:sqrt::Symbol, %3::Float64)::Union{}
0 └── unreachable
Expand Down

0 comments on commit 0ab032a

Please sign in to comment.