Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reflection: move signature union-splitting logic under the control of inference #22144

Merged
merged 2 commits into from
May 31, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 87 additions & 27 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1253,8 +1253,34 @@ end

#### recursing into expression ####

# take a Tuple where one or more parameters are Unions
# and return an array such that those Unions are removed
# and `Union{return...} == ty`
function switchtupleunion(ty::ANY)
tparams = (unwrap_unionall(ty)::DataType).parameters
return _switchtupleunion(Any[tparams...], length(tparams), [], ty)
end

function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, origt::ANY)
if i == 0
tpl = rewrap_unionall(Tuple{t...}, origt)
push!(tunion, tpl)
else
ti = t[i]
if isa(ti, Union)
for ty in uniontypes(ti::Union)
t[i] = ty
_switchtupleunion(t, i - 1, tunion, origt)
end
t[i] = ti
else
_switchtupleunion(t, i - 1, tunion, origt)
end
end
return tunion
end

function abstract_call_gf_by_type(f::ANY, atype::ANY, sv::InferenceState)
tm = _topmod(sv)
# don't consider more than N methods. this trades off between
# compiler performance and generated code performance.
# typically, considering many methods means spending lots of time
Expand Down Expand Up @@ -1282,23 +1308,64 @@ function abstract_call_gf_by_type(f::ANY, atype::ANY, sv::InferenceState)
end
min_valid = UInt[typemin(UInt)]
max_valid = UInt[typemax(UInt)]
splitunions = 1 < countunionsplit(argtypes) <= sv.params.MAX_UNION_SPLITTING
if splitunions
splitsigs = switchtupleunion(argtype)
applicable = Any[]
for sig_n in splitsigs
xapplicable = _methods_by_ftype(sig_n, sv.params.MAX_METHODS, sv.params.world, min_valid, max_valid)
xapplicable === false && return Any
append!(applicable, xapplicable)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth to keep track of how many different methods the entries in applicable actually refer to and bail out if that is more than sv.params.MAX_METHODS, too?

Copy link
Sponsor Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably, for now I'm just moving the existing logic rather than trying to improve on it much

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense.

else
applicable = _methods_by_ftype(argtype, sv.params.MAX_METHODS, sv.params.world, min_valid, max_valid)
rettype = Bottom
if applicable === false
# this means too many methods matched
return Any
end
end
applicable = applicable::Array{Any,1}
napplicable = length(applicable)
fullmatch = false
for (m::SimpleVector) in applicable
sig = m[1]
sigtuple = unwrap_unionall(sig)::DataType
method = m[3]::Method
sparams = m[2]::SimpleVector
recomputesvec = false
rettype = Bottom
for i in 1:napplicable
match = applicable[i]::SimpleVector
method = match[3]::Method
if !fullmatch && (argtype <: method.sig)
fullmatch = true
end
sig = match[1]
sigtuple = unwrap_unionall(sig)::DataType
splitunions = false
# TODO: splitunions = 1 < countunionsplit(sigtuple.parameters) * napplicable <= sv.params.MAX_UNION_SPLITTING
# currently this triggers a bug in inference recursion detection
if splitunions
splitsigs = switchtupleunion(sig)
for sig_n in splitsigs
rt = abstract_call_method(method, f, sig_n, svec(), sv)
rettype = tmerge(rettype, rt)
rettype === Any && break
end
rettype === Any && break
else
rt = abstract_call_method(method, f, sig, match[2]::SimpleVector, sv)
rettype = tmerge(rettype, rt)
rettype === Any && break
end
end
if !(fullmatch || rettype === Any)
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
add_mt_backedge(ftname.mt, argtype, sv)
update_valid_age!(min_valid[1], max_valid[1], sv)
end
#print("=> ", rettype, "\n")
return rettype
end

function abstract_call_method(method::Method, f::ANY, sig::ANY, sparams::SimpleVector, sv::InferenceState)
sigtuple = unwrap_unionall(sig)::DataType
recomputesvec = false

# limit argument type tuple growth
msig = unwrap_unionall(method.sig)
Expand Down Expand Up @@ -1367,6 +1434,7 @@ function abstract_call_gf_by_type(f::ANY, atype::ANY, sv::InferenceState)
# for example, given function f(T, Any...), limit to 3 arguments
# instead of the default (MAX_TUPLETYPE_LEN)
if limitlength
tm = _topmod(sv)
if !istopfunction(tm, f, :promote_typeof)
fst = sigtuple.parameters[lsig + 1]
allsame = true
Expand All @@ -1388,30 +1456,17 @@ function abstract_call_gf_by_type(f::ANY, atype::ANY, sv::InferenceState)
end

# if sig changed, may need to recompute the sparams environment
if recomputesvec && !isempty(sparams)
if isa(method.sig, UnionAll) && (recomputesvec || isempty(sparams))
recomputed = ccall(:jl_env_from_type_intersection, Ref{SimpleVector}, (Any, Any), sig, method.sig)
sig = recomputed[1]
if !isa(unwrap_unionall(sig), DataType) # probably Union{}
rettype = Any
break
return Any
end
sparams = recomputed[2]::SimpleVector
end
rt, edge = typeinf_edge(method, sig, sparams, sv)
edge !== nothing && add_backedge!(edge::MethodInstance, sv)
rettype = tmerge(rettype, rt)
if rettype === Any
break
end
end
if !(fullmatch || rettype === Any)
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
add_mt_backedge(ftname.mt, argtype, sv)
update_valid_age!(min_valid[1], max_valid[1], sv)
end
#print("=> ", rettype, "\n")
return rettype
return rt
end

# determine whether `ex` abstractly evals to constant `c`
Expand Down Expand Up @@ -1562,6 +1617,9 @@ function abstract_apply(aft::ANY, fargs::Vector{Any}, aargtypes::Vector{Any}, vt
return res
end

# TODO: this function is a very buggy and poor model of the return_type function
# since abstract_call_gf_by_type is a very inaccurate model of _method and of typeinf_type,
# while this assumes that it is a precisely accurate and exact model of both
function return_type_tfunc(argtypes::ANY, vtypes::VarTable, sv::InferenceState)
if length(argtypes) == 3
tt = argtypes[3]
Expand Down Expand Up @@ -2112,8 +2170,10 @@ function issubconditional(a::Conditional, b::Conditional)
end

function ⊑(a::ANY, b::ANY)
a === NF && return true
b === NF && return false
(a === NF || b === Any) && return true
(a === Any || b === NF) && return false
a === Union{} && return true
b === Union{} && return false
if isa(a, Conditional)
if isa(b, Conditional)
return issubconditional(a, b)
Expand Down Expand Up @@ -3483,7 +3543,7 @@ function is_self_quoting(x::ANY)
return isa(x,Number) || isa(x,AbstractString) || isa(x,Tuple) || isa(x,Type)
end

function countunionsplit(atypes::Vector{Any})
function countunionsplit(atypes)
nu = 1
for ti in atypes
if isa(ti, Union)
Expand Down
37 changes: 0 additions & 37 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -507,46 +507,9 @@ function _methods_by_ftype(t::ANY, lim::Int, world::UInt)
return _methods_by_ftype(t, lim, world, UInt[typemin(UInt)], UInt[typemax(UInt)])
end
function _methods_by_ftype(t::ANY, lim::Int, world::UInt, min::Array{UInt,1}, max::Array{UInt,1})
tp = unwrap_unionall(t).parameters::SimpleVector
nu = 1
for ti in tp
if isa(ti, Union)
nu *= unionlen(ti::Union)
end
end
if 1 < nu <= 64
return _methods_by_ftype(Any[tp...], t, length(tp), lim, [], world, min, max)
end
# XXX: the following can return incorrect answers that the above branch would have corrected
return ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}), t, lim, 0, world, min, max)
end

function _methods_by_ftype(t::Array, origt::ANY, i, lim::Integer, matching::Array{Any,1},
world::UInt, min::Array{UInt,1}, max::Array{UInt,1})
if i == 0
world = typemax(UInt)
new = ccall(:jl_matching_methods, Any, (Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}),
rewrap_unionall(Tuple{t...}, origt), lim, 0, world, min, max)
new === false && return false
append!(matching, new::Array{Any,1})
else
ti = t[i]
if isa(ti, Union)
for ty in uniontypes(ti::Union)
t[i] = ty
if _methods_by_ftype(t, origt, i - 1, lim, matching, world, min, max) === false
t[i] = ti
return false
end
end
t[i] = ti
else
return _methods_by_ftype(t, origt, i - 1, lim, matching, world, min, max)
end
end
return matching
end

# high-level, more convenient method lookup functions

# type for reflecting and pretty-printing a subset of methods
Expand Down
11 changes: 7 additions & 4 deletions base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,20 @@ reshape(parent::AbstractArray, dims::Int...) = reshape(parent, dims)
reshape(parent::AbstractArray, dims::Union{Int,Colon}...) = reshape(parent, dims)
reshape(parent::AbstractArray, dims::Tuple{Vararg{Union{Int,Colon}}}) = _reshape(parent, _reshape_uncolon(parent, dims))
@inline function _reshape_uncolon(A, dims)
pre, post = _split_at_colon((), dims...)
pre = _before_colon(dims...)
post = _after_colon(dims...)
if any(d -> d isa Colon, post)
throw(DimensionMismatch("new dimensions $(dims) may have at most one omitted dimension specified by Colon()"))
end
sz, remainder = divrem(length(A), prod(pre)*prod(post))
remainder == 0 || _throw_reshape_colon_dimmismatch(A, dims)
(pre..., sz, post...)
end
@inline _split_at_colon(pre, dim::Any, tail...) = _split_at_colon((pre..., dim), tail...)
@inline _split_at_colon(pre, ::Colon, tail...) = (pre, tail)
_throw_reshape_colon_dimmismatch(A, dims) =
@inline _before_colon(dim::Any, tail...) = (dim, _before_colon(tail...)...)
@inline _before_colon(dim::Colon, tail...) = ()
@inline _after_colon(dim::Any, tail...) = _after_colon(tail...)
@inline _after_colon(dim::Colon, tail...) = tail
@noinline _throw_reshape_colon_dimmismatch(A, dims) =
throw(DimensionMismatch("array size $(length(A)) must be divisible by the product of the new dimensions $dims"))

reshape(parent::AbstractArray{T,N}, ndims::Type{Val{N}}) where {T,N} = parent
Expand Down