diff --git a/src/statistics.jl b/src/statistics.jl index c49566f900..933666efa3 100644 --- a/src/statistics.jl +++ b/src/statistics.jl @@ -1,7 +1,10 @@ using Statistics Statistics._var(A::CuArray, corrected::Bool, mean, dims) = - sum((A .- something(mean, Statistics.mean(A, dims=dims))).^2, dims=dims)/(prod(size(A)[[dims...]])-corrected) + sum((A .- something(mean, Statistics.mean(A, dims=dims))).^2, dims=dims)/(prod(size(A)[dims])-corrected) + +Statistics._var(A::CuArray, corrected::Bool, mean, dim::Integer) = + sum((A .- something(mean, Statistics.mean(A, dims=dim))).^2, dims=dim)/(size(A)[dim]-corrected) Statistics._var(A::CuArray, corrected::Bool, mean, ::Colon) = sum((A .- something(mean, Statistics.mean(A))).^2)/(length(A)-corrected) diff --git a/test/statistics.jl b/test/statistics.jl index ca10310b2a..1091e305ff 100644 --- a/test/statistics.jl +++ b/test/statistics.jl @@ -17,6 +17,8 @@ end @testset "mean" begin @test testf(mean, rand(2,2)) @test testf(mean, rand(2,2); dims=2) + @test testf(mean, rand(2,2,2); dims=[1,3]) @test testf(x->mean(sin, x), rand(2,2)) @test testf(x->mean(sin, x; dims=2), rand(2,2)) + @test testf(x->mean(sin, x; dims=[1,3]), rand(2,2,2)) end