Skip to content

Commit

Permalink
fix cumsum ignoring first element returned in call to promote_op (#53461
Browse files Browse the repository at this point in the history
)

Include an option to have `promote_op` throw an error rather than return
`Union{}`. We're not changing the default behaviour, but in some cases
returning `Union{}` will result in less clear error messages.

Closes #53438

@LilithHafner may have comments
  • Loading branch information
rofinn authored Apr 15, 2024
1 parent 3dd07fd commit 9de150c
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 8 deletions.
54 changes: 46 additions & 8 deletions base/accumulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ function accumulate_pairwise!(op::Op, result::AbstractVector, v::AbstractVector)
end

function accumulate_pairwise(op, v::AbstractVector{T}) where T
out = similar(v, promote_op(op, T, T))
out = similar(v, _accumulate_promote_op(op, v))
return accumulate_pairwise!(op, out, v)
end

Expand Down Expand Up @@ -111,8 +111,8 @@ julia> cumsum(a, dims=2)
widening happens and integer overflow results in `Int8[100, -128]`.
"""
function cumsum(A::AbstractArray{T}; dims::Integer) where T
out = similar(A, promote_op(add_sum, T, T))
cumsum!(out, A, dims=dims)
out = similar(A, _accumulate_promote_op(add_sum, A))
return cumsum!(out, A, dims=dims)
end

"""
Expand Down Expand Up @@ -280,14 +280,13 @@ function accumulate(op, A; dims::Union{Nothing,Integer}=nothing, kw...)
# This branch takes care of the cases not handled by `_accumulate!`.
return collect(Iterators.accumulate(op, A; kw...))
end

nt = values(kw)
if isempty(kw)
out = similar(A, promote_op(op, eltype(A), eltype(A)))
elseif keys(nt) === (:init,)
out = similar(A, promote_op(op, typeof(nt.init), eltype(A)))
else
if !(isempty(kw) || keys(nt) === (:init,))
throw(ArgumentError("accumulate does not support the keyword arguments $(setdiff(keys(nt), (:init,)))"))
end

out = similar(A, _accumulate_promote_op(op, A; kw...))
accumulate!(op, out, A; dims=dims, kw...)
end

Expand Down Expand Up @@ -442,3 +441,42 @@ function _accumulate1!(op, B, v1, A::AbstractVector, dim::Integer)
end
return B
end

# Internal function used to identify the widest possible eltype required for accumulate results
function _accumulate_promote_op(op, v; init=nothing)
# Nested mock functions used to infer the widest necessary eltype
# NOTE: We are just passing this to promote_op for inference and should never be run.

# Initialization function used to identify initial type of `r`
# NOTE: reduce_first may have a different return type than calling `op`
function f(op, v, init)
val = first(something(iterate(v)))
return isnothing(init) ? Base.reduce_first(op, val) : op(init, val)
end

# Infer iteration type independent of the initialization type
# If `op` fails then this will return `Union{}` as `k` will be undefined.
# Returning `Union{}` is desirable as it won't break the `promote_type` call in the
# outer scope below
function g(op, v, r)
local k
for val in v
k = op(r, val)
end
return k
end

# Finally loop again with the two types promoted together
# If the `op` fails and reduce_first was used then then this will still just
# return the initial type, allowing the `op` to error during execution.
function h(op, v, r)
for val in v
r = op(r, val)
end
return r
end

R = Base.promote_op(f, typeof(op), typeof(v), typeof(init))
K = Base.promote_op(g, typeof(op), typeof(v), R)
return Base.promote_op(h, typeof(op), typeof(v), Base.promote_type(R, K))
end
27 changes: 27 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2859,6 +2859,33 @@ end
@inferred accumulate(*, String[])
@test accumulate(*, ['a' 'b'; 'c' 'd'], dims=1) == ["a" "b"; "ac" "bd"]
@test accumulate(*, ['a' 'b'; 'c' 'd'], dims=2) == ["a" "ab"; "c" "cd"]

# #53438
v = [(1, 2), (3, 4)]
@test_throws MethodError accumulate(+, v)
@test_throws MethodError cumsum(v)
@test_throws MethodError cumprod(v)
@test_throws MethodError accumulate(+, v; init=(0, 0))
@test_throws MethodError accumulate(+, v; dims=1, init=(0, 0))

# Some checks to ensure we're identifying the widest needed eltype
# as identified in PR 53461
@testset "Base._accumulate_promote_op" begin
# A somewhat contrived example where each call to `foo`
# will return a different type
foo(x::Bool, y::Int)::Int = x + y
foo(x::Int, y::Int)::Float64 = x + y
foo(x::Float64, y::Int)::ComplexF64 = x + y * im
foo(x::ComplexF64, y::Int)::String = string(x, "+", y)

v = collect(1:5)
@test Base._accumulate_promote_op(foo, v; init=true) === Base._accumulate_promote_op(foo, v) == Union{Float64, String, ComplexF64}
@test Base._accumulate_promote_op(/, v) === Base._accumulate_promote_op(/, v; init=0) == Float64
@test Base._accumulate_promote_op(+, v) === Base._accumulate_promote_op(+, v; init=0) === Int
@test Base._accumulate_promote_op(+, v; init=0.0) === Float64
@test Base._accumulate_promote_op(+, Union{Int, Missing}[v...]) === Union{Int, Missing}
@test Base._accumulate_promote_op(+, Union{Int, Nothing}[v...]) === Union{Int, Nothing}
end
end

struct F21666{T <: Base.ArithmeticStyle}
Expand Down

0 comments on commit 9de150c

Please sign in to comment.