From 4b7dc0c9d2705fc37f500177d5c1c06196b3d962 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Mon, 3 Feb 2020 22:28:53 -0800 Subject: [PATCH 1/3] Implement `accumulate` and friends for Tuple --- base/accumulate.jl | 31 +++++++++++++++++++++++++++---- test/tuple.jl | 11 +++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/base/accumulate.jl b/base/accumulate.jl index c5c4e83b3b0e3..8d5a5898d7939 100644 --- a/base/accumulate.jl +++ b/base/accumulate.jl @@ -92,12 +92,15 @@ function cumsum(A::AbstractArray{T}; dims::Integer) where T end """ - cumsum(x::AbstractVector) + cumsum(itr::Union{AbstractVector,Tuple}) -Cumulative sum a vector. See also [`cumsum!`](@ref) +Cumulative sum an iterator. See also [`cumsum!`](@ref) to use a preallocated output array, both for performance and to control the precision of the output (e.g. to avoid overflow). +!!! compat "Julia 1.5" + `cumsum` on a tuple requires at least Julia 1.5. + # Examples ```jldoctest julia> cumsum([1, 1, 1]) @@ -111,9 +114,13 @@ julia> cumsum([fill(1, 2) for i in 1:3]) [1, 1] [2, 2] [3, 3] + +julia> cumsum((1, 1, 1)) +(1, 2, 3) ``` """ cumsum(x::AbstractVector) = cumsum(x, dims=1) +cumsum(itr) = accumulate(add_sum, itr) """ @@ -163,12 +170,15 @@ function cumprod(A::AbstractArray; dims::Integer) end """ - cumprod(x::AbstractVector) + cumprod(itr::Union{AbstractVector,Tuple}) -Cumulative product of a vector. See also +Cumulative product of an iterator. See also [`cumprod!`](@ref) to use a preallocated output array, both for performance and to control the precision of the output (e.g. to avoid overflow). +!!! compat "Julia 1.5" + `cumprod` on a tuple requires at least Julia 1.5. + # Examples ```jldoctest julia> cumprod(fill(1//2, 3)) @@ -182,9 +192,13 @@ julia> cumprod([fill(1//3, 2, 2) for i in 1:3]) [1//3 1//3; 1//3 1//3] [2//9 2//9; 2//9 2//9] [4//27 4//27; 4//27 4//27] + +julia> cumprod((1, 2, 1)) +(1, 2, 2) ``` """ cumprod(x::AbstractVector) = cumprod(x, dims=1) +cumprod(itr) = accumulate(mul_prod, itr) """ @@ -247,6 +261,15 @@ function accumulate(op, A; dims::Union{Nothing,Integer}=nothing, kw...) accumulate!(op, out, A; dims=dims, kw...) end +function accumulate(op, xs::Tuple; init = _InitialValue()) + rf = BottomRF(op) + ys, = foldl(xs; init = ((), init)) do (ys, acc), x + acc = rf(acc, x) + (ys..., acc), acc + end + return ys +end + """ accumulate!(op, B, A; [dims], [init]) diff --git a/test/tuple.jl b/test/tuple.jl index 06658819cb9c2..e7e6e3a7d343e 100644 --- a/test/tuple.jl +++ b/test/tuple.jl @@ -368,6 +368,17 @@ end end end +@testset "accumulate" begin + @test @inferred(cumsum(())) == () + @test @inferred(cumsum((1, 2, 3))) == (1, 3, 6) + @test @inferred(cumprod((1, 2, 3))) == (1, 2, 6) + @test @inferred(accumulate(+, (1, 2, 3); init=10)) == (11, 13, 16) + op(::Nothing, ::Any) = missing + op(::Missing, ::Any) = nothing + @test @inferred(accumulate(op, (1, 2, 3, 4); init = nothing)) === + (missing, nothing, missing, nothing) +end + @testset "ntuple" begin nttest1(x::NTuple{n, Int}) where {n} = n @test nttest1(()) == 0 From 8da5715d62f08cf816bea7211b1c87176df8e58f Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Tue, 4 Feb 2020 00:23:32 -0800 Subject: [PATCH 2/3] Add NEWS --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index 9280a30ea42ac..e18c7fcd15540 100644 --- a/NEWS.md +++ b/NEWS.md @@ -57,6 +57,7 @@ New library features * `isapprox` (or `≈`) now has a one-argument "curried" method `isapprox(x)` which returns a function, like `isequal` (or `==`)` ([#32305]). * `Ref{NTuple{N,T}}` can be passed to `Ptr{T}`/`Ref{T}` `ccall` signatures ([#34199]) +* `accumulate`, `cumsum`, and `cumprod` now support `Tuple` ([#34654]). Standard library changes From 039196f0ebe466421f09ec7b1603f5ee2488a9d0 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Tue, 25 Feb 2020 21:41:55 -0800 Subject: [PATCH 3/3] Directly call Base.afoldl to be slightly more compiler-friendly --- base/accumulate.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base/accumulate.jl b/base/accumulate.jl index 8d5a5898d7939..8e959df71b28a 100644 --- a/base/accumulate.jl +++ b/base/accumulate.jl @@ -263,7 +263,7 @@ end function accumulate(op, xs::Tuple; init = _InitialValue()) rf = BottomRF(op) - ys, = foldl(xs; init = ((), init)) do (ys, acc), x + ys, = afoldl(((), init), xs...) do (ys, acc), x acc = rf(acc, x) (ys..., acc), acc end