From 9de150c3b3289d5608614f68fd743cb409a24824 Mon Sep 17 00:00:00 2001 From: Rory Finnegan Date: Mon, 15 Apr 2024 06:22:04 -0700 Subject: [PATCH] fix cumsum ignoring first element returned in call to promote_op (#53461) Include an option to have `promote_op` throw an error rather than return `Union{}`. We're not changing the default behaviour, but in some cases returning `Union{}` will result in less clear error messages. Closes #53438 @LilithHafner may have comments --- base/accumulate.jl | 54 +++++++++++++++++++++++++++++++++++++++------- test/arrayops.jl | 27 +++++++++++++++++++++++ 2 files changed, 73 insertions(+), 8 deletions(-) diff --git a/base/accumulate.jl b/base/accumulate.jl index a2d8a1d368d86..2748a4da481fa 100644 --- a/base/accumulate.jl +++ b/base/accumulate.jl @@ -33,7 +33,7 @@ function accumulate_pairwise!(op::Op, result::AbstractVector, v::AbstractVector) end function accumulate_pairwise(op, v::AbstractVector{T}) where T - out = similar(v, promote_op(op, T, T)) + out = similar(v, _accumulate_promote_op(op, v)) return accumulate_pairwise!(op, out, v) end @@ -111,8 +111,8 @@ julia> cumsum(a, dims=2) widening happens and integer overflow results in `Int8[100, -128]`. """ function cumsum(A::AbstractArray{T}; dims::Integer) where T - out = similar(A, promote_op(add_sum, T, T)) - cumsum!(out, A, dims=dims) + out = similar(A, _accumulate_promote_op(add_sum, A)) + return cumsum!(out, A, dims=dims) end """ @@ -280,14 +280,13 @@ function accumulate(op, A; dims::Union{Nothing,Integer}=nothing, kw...) # This branch takes care of the cases not handled by `_accumulate!`. return collect(Iterators.accumulate(op, A; kw...)) end + nt = values(kw) - if isempty(kw) - out = similar(A, promote_op(op, eltype(A), eltype(A))) - elseif keys(nt) === (:init,) - out = similar(A, promote_op(op, typeof(nt.init), eltype(A))) - else + if !(isempty(kw) || keys(nt) === (:init,)) throw(ArgumentError("accumulate does not support the keyword arguments $(setdiff(keys(nt), (:init,)))")) end + + out = similar(A, _accumulate_promote_op(op, A; kw...)) accumulate!(op, out, A; dims=dims, kw...) end @@ -442,3 +441,42 @@ function _accumulate1!(op, B, v1, A::AbstractVector, dim::Integer) end return B end + +# Internal function used to identify the widest possible eltype required for accumulate results +function _accumulate_promote_op(op, v; init=nothing) + # Nested mock functions used to infer the widest necessary eltype + # NOTE: We are just passing this to promote_op for inference and should never be run. + + # Initialization function used to identify initial type of `r` + # NOTE: reduce_first may have a different return type than calling `op` + function f(op, v, init) + val = first(something(iterate(v))) + return isnothing(init) ? Base.reduce_first(op, val) : op(init, val) + end + + # Infer iteration type independent of the initialization type + # If `op` fails then this will return `Union{}` as `k` will be undefined. + # Returning `Union{}` is desirable as it won't break the `promote_type` call in the + # outer scope below + function g(op, v, r) + local k + for val in v + k = op(r, val) + end + return k + end + + # Finally loop again with the two types promoted together + # If the `op` fails and reduce_first was used then then this will still just + # return the initial type, allowing the `op` to error during execution. + function h(op, v, r) + for val in v + r = op(r, val) + end + return r + end + + R = Base.promote_op(f, typeof(op), typeof(v), typeof(init)) + K = Base.promote_op(g, typeof(op), typeof(v), R) + return Base.promote_op(h, typeof(op), typeof(v), Base.promote_type(R, K)) +end diff --git a/test/arrayops.jl b/test/arrayops.jl index 1c36453a6adae..b64d08264e2d1 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -2859,6 +2859,33 @@ end @inferred accumulate(*, String[]) @test accumulate(*, ['a' 'b'; 'c' 'd'], dims=1) == ["a" "b"; "ac" "bd"] @test accumulate(*, ['a' 'b'; 'c' 'd'], dims=2) == ["a" "ab"; "c" "cd"] + + # #53438 + v = [(1, 2), (3, 4)] + @test_throws MethodError accumulate(+, v) + @test_throws MethodError cumsum(v) + @test_throws MethodError cumprod(v) + @test_throws MethodError accumulate(+, v; init=(0, 0)) + @test_throws MethodError accumulate(+, v; dims=1, init=(0, 0)) + + # Some checks to ensure we're identifying the widest needed eltype + # as identified in PR 53461 + @testset "Base._accumulate_promote_op" begin + # A somewhat contrived example where each call to `foo` + # will return a different type + foo(x::Bool, y::Int)::Int = x + y + foo(x::Int, y::Int)::Float64 = x + y + foo(x::Float64, y::Int)::ComplexF64 = x + y * im + foo(x::ComplexF64, y::Int)::String = string(x, "+", y) + + v = collect(1:5) + @test Base._accumulate_promote_op(foo, v; init=true) === Base._accumulate_promote_op(foo, v) == Union{Float64, String, ComplexF64} + @test Base._accumulate_promote_op(/, v) === Base._accumulate_promote_op(/, v; init=0) == Float64 + @test Base._accumulate_promote_op(+, v) === Base._accumulate_promote_op(+, v; init=0) === Int + @test Base._accumulate_promote_op(+, v; init=0.0) === Float64 + @test Base._accumulate_promote_op(+, Union{Int, Missing}[v...]) === Union{Int, Missing} + @test Base._accumulate_promote_op(+, Union{Int, Nothing}[v...]) === Union{Int, Nothing} + end end struct F21666{T <: Base.ArithmeticStyle}