diff --git a/src/moments.jl b/src/moments.jl index c5d0ae5c0..c973d8b8d 100644 --- a/src/moments.jl +++ b/src/moments.jl @@ -42,7 +42,7 @@ function var(v::RealArray, w::AbstractWeights; mean=nothing, corrected::DepBool=nothing) corrected = depcheck(:var, corrected) - if mean == nothing + if mean === nothing varm(v, w, Statistics.mean(v, w); corrected=corrected) else varm(v, w, mean; corrected=corrected) @@ -65,7 +65,7 @@ function var!(R::AbstractArray, A::RealArray, w::AbstractWeights, dims::Int; if mean == 0 varm!(R, A, w, Base.reducedim_initarray(A, dims, 0, eltype(R)), dims; corrected=corrected) - elseif mean == nothing + elseif mean === nothing varm!(R, A, w, Statistics.mean(A, w, dims=dims), dims; corrected=corrected) else # check size of mean @@ -85,15 +85,22 @@ end function varm(A::RealArray, w::AbstractWeights, M::RealArray, dim::Int; corrected::DepBool=nothing) corrected = depcheck(:varm, corrected) - varm!(similar(A, Float64, Base.reduced_indices(axes(A), dim)), A, w, M, + TA, TM, Tw = eltype(A), eltype(M), eltype(w) + T = typeof( (zero(Tw) * (zero(TA) - zero(TM)))^2 ) + varm!(similar(A, T, Base.reduced_indices(axes(A), dim)), A, w, M, dim; corrected=corrected) end function var(A::RealArray, w::AbstractWeights, dim::Int; mean=nothing, corrected::DepBool=nothing) corrected = depcheck(:var, corrected) - var!(similar(A, Float64, Base.reduced_indices(axes(A), dim)), A, w, dim; - mean=mean, corrected=corrected) + # doing this here instead of in var! + # allows better type stability for the returned array + M = Statistics.mean(A, w, dims=dim) + TA, TM, Tw = eltype(A), eltype(M), eltype(w) + T = typeof( (zero(Tw) * (zero(TA) - zero(TM)))^2 ) + var!(similar(A, T, Base.reduced_indices(axes(A), dim)), A, w, dim; + mean=M, corrected=corrected) end ## std @@ -223,7 +230,7 @@ end ##### General central moment function _moment2(v::RealArray, m::Real; corrected=false) n = length(v) - s = 0.0 + s = (zero(eltype(v)) - zero(m))^2 for i = 1:n @inbounds z = v[i] - m s += z * z @@ -233,7 +240,7 @@ end function _moment2(v::RealArray, wv::AbstractWeights, m::Real; corrected=false) n = length(v) - s = 0.0 + s = zero(eltype(wv)) * (zero(eltype(v)) - zero(m))^2 for i = 1:n @inbounds z = v[i] - m @inbounds s += (z * z) * wv[i] @@ -244,7 +251,7 @@ end function _moment3(v::RealArray, m::Real) n = length(v) - s = 0.0 + s = (zero(eltype(v)) - zero(m))^3 for i = 1:n @inbounds z = v[i] - m s += z * z * z @@ -254,7 +261,7 @@ end function _moment3(v::RealArray, wv::AbstractWeights, m::Real) n = length(v) - s = 0.0 + s = zero(eltype(wv)) * (zero(eltype(v)) - zero(m))^3 for i = 1:n @inbounds z = v[i] - m @inbounds s += (z * z * z) * wv[i] @@ -264,7 +271,7 @@ end function _moment4(v::RealArray, m::Real) n = length(v) - s = 0.0 + s = (zero(eltype(v)) - zero(m))^4 for i = 1:n @inbounds z = v[i] - m s += abs2(z * z) @@ -274,7 +281,7 @@ end function _moment4(v::RealArray, wv::AbstractWeights, m::Real) n = length(v) - s = 0.0 + s = zero(eltype(wv)) * (zero(eltype(v)) - zero(m))^4 for i = 1:n @inbounds z = v[i] - m @inbounds s += abs2(z * z) * wv[i] @@ -284,7 +291,7 @@ end function _momentk(v::RealArray, k::Int, m::Real) n = length(v) - s = 0.0 + s = zero(eltype(v)) - zero(m) for i = 1:n @inbounds z = v[i] - m s += (z ^ k) @@ -294,7 +301,7 @@ end function _momentk(v::RealArray, k::Int, wv::AbstractWeights, m::Real) n = length(v) - s = 0.0 + s = zero(eltype(wv)) * (zero(eltype(v)) - zero(m))^k for i = 1:n @inbounds z = v[i] - m @inbounds s += (z ^ k) * wv[i] @@ -341,8 +348,9 @@ specifying a weighting vector `wv` and a center `m`. """ function skewness(v::RealArray, m::Real) n = length(v) - cm2 = 0.0 # empirical 2nd centered moment (variance) - cm3 = 0.0 # empirical 3rd centered moment + T = typeof( zero(eltype(v)) - zero(m) ) + cm2 = zero(T)^2 # empirical 2nd centered moment (variance) + cm3 = zero(T)^3 # empirical 3rd centered moment for i = 1:n @inbounds z = v[i] - m z2 = z * z @@ -358,8 +366,9 @@ end function skewness(v::RealArray, wv::AbstractWeights, m::Real) n = length(v) length(wv) == n || throw(DimensionMismatch("Inconsistent array lengths.")) - cm2 = 0.0 # empirical 2nd centered moment (variance) - cm3 = 0.0 # empirical 3rd centered moment + T = typeof(zero(eltype(v)) - zero(m)) + cm2 = zero(eltype(wv)) * zero(T)^2 # empirical 2nd centered moment (variance) + cm3 = zero(eltype(wv)) * zero(T)^3 # empirical 3rd centered moment @inbounds for i = 1:n x_i = v[i] @@ -388,8 +397,9 @@ specifying a weighting vector `wv` and a center `m`. """ function kurtosis(v::RealArray, m::Real) n = length(v) - cm2 = 0.0 # empirical 2nd centered moment (variance) - cm4 = 0.0 # empirical 4th centered moment + T = typeof( zero(eltype(v)) - zero(m) ) + cm2 = zero(T)^2 # empirical 2nd centered moment (variance) + cm4 = zero(T)^4 # empirical 4th centered moment for i = 1:n @inbounds z = v[i] - m z2 = z * z @@ -398,14 +408,15 @@ function kurtosis(v::RealArray, m::Real) end cm4 /= n cm2 /= n - return (cm4 / (cm2 * cm2)) - 3.0 + return (cm4 / (cm2 * cm2)) - 3 end function kurtosis(v::RealArray, wv::AbstractWeights, m::Real) n = length(v) length(wv) == n || throw(DimensionMismatch("Inconsistent array lengths.")) - cm2 = 0.0 # empirical 2nd centered moment (variance) - cm4 = 0.0 # empirical 4th centered moment + T = typeof(zero(eltype(v)) - zero(m)) + cm2 = zero(eltype(wv)) * zero(T)^2 # empirical 2nd centered moment (variance) + cm4 = zero(eltype(wv)) * zero(T)^4 # empirical 4th centered moment @inbounds for i = 1 : n x_i = v[i] @@ -419,7 +430,7 @@ function kurtosis(v::RealArray, wv::AbstractWeights, m::Real) sw = sum(wv) cm4 /= sw cm2 /= sw - return (cm4 / (cm2 * cm2)) - 3.0 + return (cm4 / (cm2 * cm2)) - 3 end kurtosis(v::RealArray) = kurtosis(v, mean(v)) diff --git a/test/moments.jl b/test/moments.jl index 97fda44ac..09fc53bc2 100644 --- a/test/moments.jl +++ b/test/moments.jl @@ -1,4 +1,5 @@ using StatsBase +using Statistics using Test @testset "StatsBase.Moments" begin @@ -278,4 +279,19 @@ end @test moment(x, 5, w) ≈ sum((x2 .- 4).^5) / 5 end +@testset "Preservation of eltypes in moments" begin + xs = Float16[1, 2, 3, 4, 5]; + ws = AnalyticWeights(Float16[1, 1, 1, 1, 1]); + @test typeof(std(xs)) === Float16 + @test typeof(var(xs)) === Float16 + @test typeof(mean(xs, ws)) === Float16 + @test typeof(std(xs, ws)) === Float16 + @test typeof(var(xs, ws)) === Float16 + @test typeof(skewness(xs, ws)) === Float16 + @test typeof(kurtosis(xs, ws)) === Float16 + for i in 1:5 + @test typeof(moment(xs, i, ws)) === Float16 + end +end + end # @testset "StatsBase.Moments"