Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add collection of mean, var and std for distributions #250

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MeasureTheory"
uuid = "eadaa1a4-d27c-401d-8699-e962e1bbc33b"
authors = ["Chad Scherrer <[email protected]> and contributors"]
version = "0.19.0"
version = "0.19.1"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down Expand Up @@ -53,7 +53,9 @@ ForwardDiff = "0.10"
IfElse = "0.1"
Infinities = "0.1"
InverseFunctions = "0.1"
InteractiveUtils = "<0.0.1, 1"
KeywordCalls = "0.2"
LinearAlgebra = "<0.0.1, 1"
LogExpFunctions = "0.3.3"
MLStyle = "0.4"
MacroTools = "0.5"
Expand All @@ -62,19 +64,15 @@ MeasureBase = "0.14"
NamedTupleTools = "0.13, 0.14"
PositiveFactorizations = "0.2"
PrettyPrinting = "0.3, 0.4"
Random = "<0.0.1, 1"
Reexport = "1"
SpecialFunctions = "1, 2"
Static = "0.8"
StaticArraysCore = "1"
Statistics = "<0.0.1, 1"
StatsBase = "0.34"
StatsFuns = "0.9, 1"
TransformVariables = "0.8"
Tricks = "0.1"
julia = "1.6"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Aqua"]
2 changes: 1 addition & 1 deletion src/MeasureTheory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ using MeasureBase: BoundedInts, BoundedReals, CountingBase, IntegerDomain, Integ
using MeasureBase: weightedmeasure, restrict
using MeasureBase: AbstractTransitionKernel

import Statistics: mean, var, std
import Statistics: mean, cov, var, std

import MeasureBase: likelihoodof
export likelihoodof
Expand Down
5 changes: 5 additions & 0 deletions src/combinators/product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
x
end

for f in (:mean, :var, :std)
@eval $f(d::ProductMeasure) = map($f, marginals(d))

Check warning on line 27 in src/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/combinators/product.jl#L27

Added line #L27 was not covered by tests
@eval $f(d::For) = map($f, marginals(d))
end

# # e.g. set(Normal(μ=2)^5, params, randn(5))
# function Accessors.set(
# d::ProductMeasure{A},
Expand Down
25 changes: 25 additions & 0 deletions src/parameterized/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,28 @@ end
@inline function proxy(d::MvNormal{(:μ, :Λ),Tuple{T,C}}) where {T,C<:Cholesky}
affine((μ = d.μ, λ = d.Λ.L), Normal()^supportdim(d))
end

# Statistics dispatch
for N in ((:μ,), (:σ,), (:λ,), (:Σ,), (:Λ,), (:μ, :σ), (:μ, :λ), (:μ, :Σ), (:μ, :Λ))
expr = Expr(:block)
if first(N) == :μ
push!(expr.args, :(mean(d::MvNormal{$N}) = d.μ))
else
push!(expr.args, :(mean(d::MvNormal{$N,Tuple{T}}) where {T} = zeros(eltype(T), supportdim(d))))
end
cov_var = last(N)
push!(expr.args, :(var(d::MvNormal{$N}) = diag(cov(d))))
push!(expr.args, :(std(d::MvNormal{$N}) = sqrt.(diag(cov(d)))))
if cov_var == :μ
push!(expr.args, :(cov(d::MvNormal{$N, Tuple{T}}) where {T} = I(supportdim(d)...)))
elseif cov_var == :σ
push!(expr.args, :(cov(d::MvNormal{$N}) = d.σ * d.σ'))
elseif cov_var == :λ
push!(expr.args, :(cov(d::MvNormal{$N}) = inv(d.λ' * d.λ)))
elseif cov_var == :Σ
push!(expr.args, :(cov(d::MvNormal{$N}) = Matrix(d.Σ)))
elseif cov_var == :Λ
push!(expr.args, :(cov(d::MvNormal{$N}) = inv(d.Λ')))
end
eval(expr)
end
19 changes: 19 additions & 0 deletions src/parameterized/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,25 @@
end
end

for N in ((:μ,), (:σ,), (:λ,), (:μ, :σ), (:μ, :λ))
expr = Expr(:block)
if first(N) == :μ
push!(expr.args, :(mean(d::Normal{$N}) = d.μ))
else
push!(expr.args, :(mean(d::Normal{$N,Tuple{T}}) where {T} = zero(T)))

Check warning on line 207 in src/parameterized/normal.jl

View check run for this annotation

Codecov / codecov/patch

src/parameterized/normal.jl#L207

Added line #L207 was not covered by tests
end
cov_var = last(N)
push!(expr.args, :(var(d::Normal{$N}) = abs2(std(d))))
if cov_var == :μ
push!(expr.args, :(std(d::Normal{$N, Tuple{T}}) where {T} = one(T)))
elseif cov_var == :σ
push!(expr.args, :(std(d::Normal{$N}) = d.σ))
elseif cov_var == :λ
push!(expr.args, :(std(d::Normal{$N}) = inv(d.λ)))
end
eval(expr)
end

MeasureBase.transport_origin(::Normal) = StdNormal()

MeasureBase.to_origin(::Normal{()}, y) = y
Expand Down
4 changes: 4 additions & 0 deletions src/parameterized/poisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ function Base.rand(rng::AbstractRNG, T::Type, d::Poisson{(:logλ,)})
rand(rng, Dists.Poisson(exp(d.logλ)))
end

mean(d::Poisson{(:λ,)}) = d.λ
std(d::Poisson{(:λ,)}) = sqrt(d.λ)
var(d::Poisson{(:λ,)}) = d.λ

@inline function insupport(::Poisson, x)
isinteger(x) && x ≥ 0
end
12 changes: 12 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DynamicIterators = "6c76993d-992e-5bf1-9e63-34920a5a5a38"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MeasureBase = "fa1605e6-acd5-459c-a1e6-7e635759db14"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
57 changes: 56 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test
using StatsFuns
using Base.Iterators: take
using Statistics
using Random
using LinearAlgebra
# using DynamicIterators: trace, TimeLift
Expand All @@ -23,7 +24,7 @@ using IfElse
# # detect_ambiguities_options...,
# )

Aqua.test_all(MeasureBase; ambiguities = false)
# Aqua.test_all(MeasureBase; ambiguities = false)

function draw2(μ)
x = rand(μ)
Expand Down Expand Up @@ -411,6 +412,33 @@ end
x = rand(d)
@test logdensityof(d, x) ≈ logdensityof(Dists.MvNormal(Σ), x)
@test logdensityof(MvNormal(zeros(3), σ), x) ≈ logdensityof(d, x)

D = 3
μ = randn(D)
zero_μ = zeros(D)
σ = LowerTriangular(randn(D, D))
Σ = σ * σ'
λ = inv(σ)
Λ = inv(Σ)

d = MvNormal(;μ)
@test mean(d) == μ
@test cov(d) == I(D)
@testset "Cov param: $(string(cov_param))" for (cov_param, val) in [(:σ, σ), (:Σ, Σ), (:λ, λ), (:Λ, Λ)]
@eval begin
# Mean is not given
d = MvNormal(;$cov_param=$val)
@test mean(d) == $zero_μ
@test cov(d) ≈ $Σ
@test std(d) ≈ sqrt.(diag($Σ))
@test var(d) ≈ diag($Σ)
d = MvNormal(;μ=$μ, $cov_param=$val)
@test mean(d) == $μ
@test cov(d) ≈ $Σ
@test std(d) ≈ sqrt.(diag($Σ))
@test var(d) ≈ diag($Σ)
end
end
end

@testset "NegativeBinomial" begin
Expand All @@ -419,10 +447,34 @@ end

@testset "Normal" begin
@test repro(Normal, (:μ, :σ))
μ = randn()
σ = rand()
λ = inv(σ)
d = Normal(;μ)
@test mean(d) == μ
@test var(d) == one(μ)
@test std(d) == one(μ)
@testset "std param : $(string(std_param))" for (std_param, val) in [(:σ, σ), (:λ, λ)]
@eval begin
d = Normal(;μ=$μ)
@test mean(d) == $μ
@test var(d) == one($μ)
@test std(d) == one($μ)
d = Normal(;μ=$μ, $std_param=$val)
@test mean(d) == $μ
@test var(d) ≈ abs2($σ)
@test std(d) ≈ $σ
end
end
end

@testset "Poisson" begin
@test repro(Poisson, (:λ,))
λ = rand()
d = Poisson(;λ)
@test mean(d) == λ
@test var(d) == λ
@test std(d) == sqrt(λ)
end

@testset "StudentT" begin
Expand All @@ -441,6 +493,9 @@ end
x = Vector{Int16}(undef, 10)
@test rand!(d, x) isa Vector
@test rand(d) isa Vector
@test mean(d) == mean.(collect(marginals(d)))
@test std(d) == std.(collect(marginals(d)))
@test var(d) == var.(collect(marginals(d)))

@testset "Indexed by Generator" begin
d = For((j^2 for j in 1:10)) do i
Expand Down
Loading