diff --git a/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl b/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl
index 14442e9ab..532747869 100644
--- a/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl
+++ b/EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl
@@ -8,13 +8,13 @@ using ..EpiAwareBase
 using DataFramesMeta: DataFrame, @rename!
 using DynamicPPL: Model, fix, condition, @submodel, @model
 using MCMCChains: Chains
-using Random: AbstractRNG
+using Random: AbstractRNG, randexp, GLOBAL_RNG

 using Tables: rowtable
 using Distributions, DocStringExtensions, QuadGK, Statistics, Turing

 #Export Structures
-export HalfNormal, DirectSample
+export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial

 #Export functions
 export scan, spread_draws, censored_pmf, get_param_array, prefix_submodel
@@ -32,5 +32,7 @@ include("turing-methods.jl")
 include("DirectSample.jl")
 include("post-inference.jl")
 include("get_param_array.jl")
+include("SafePoisson.jl")
+include("SafeNegativeBinomial.jl")

 end
diff --git a/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl b/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl
new file mode 100644
index 000000000..3a32de966
--- /dev/null
+++ b/EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl
@@ -0,0 +1,137 @@
+@doc raw"
+Create a negative binomial distribution with the specified mean that avoids `InexactError`
+when the mean is too large.
+
+# Parameterisation:
+We are using a mean and cluster factorization of the negative binomial distribution such
+that the variance to mean relationship is:
+
+```math
+\sigma^2 = \mu + \alpha^2 \mu^2
+```
+
+The reason for this parameterisation is that at sufficiently large mean values (i.e. `μ > 1 / α^2`)
+`α` is approximately equal to the relative standard deviation of the distribution, e.g. if
+`α = 0.05` we expect typical fluctuations of samples from the negative binomial to be about 5%
+of the mean once the mean is notably larger than `1 / α^2`. Otherwise, we expect approximately
+Poisson noise. In our opinion, this parameterisation is useful because it makes it easier to
+reason about priors for `α`.
+
+# Arguments:
+
+- `r`: The number of successes, although this can be extended to a continuous number.
+- `p`: Success probability.
+
+# Returns:
+
+- A `SafeNegativeBinomial` distribution with the specified mean.
+
+# Examples:
+
+```jldoctest SafeNegativeBinomial
+using EpiAware, Distributions
+
+bigμ = exp(48.0) #Large value of μ
+σ² = bigμ + 0.05 * bigμ^2 #Large variance
+
+# We can calculate the success probability from the mean to variance relationship
+p = bigμ / σ²
+r = bigμ * p / (1 - p)
+d = SafeNegativeBinomial(r, p)
+# output
+EpiAware.EpiAwareUtils.SafeNegativeBinomial{Float64}(r=20.0, p=2.85032816548187e-20)
+```
+
+```jldoctest SafeNegativeBinomial
+cdf(d, 100)
+# output
+0.0
+```
+
+```jldoctest SafeNegativeBinomial
+logpdf(d, 100)
+# output
+-850.1397180331871
+```
+
+```jldoctest SafeNegativeBinomial
+mean(d)
+# output
+7.016735912097631e20
+```
+
+```jldoctest SafeNegativeBinomial
+var(d)
+# output
+2.4617291430060293e40
+```
+"
+struct SafeNegativeBinomial{T <: Real} <: DiscreteUnivariateDistribution
+    r::T
+    p::T
+
+    function SafeNegativeBinomial{T}(r::T, p::T) where {T <: Real}
+        return new{T}(r, p)
+    end
+end
+
+#Outer constructors make AD work
+function SafeNegativeBinomial(r::T, p::T) where {T <: Real}
+    return SafeNegativeBinomial{T}(r, p)
+end
+
+SafeNegativeBinomial(r::Real, p::Real) = SafeNegativeBinomial(promote(r, p)...)
+
+# helper function
+_negbin(d::SafeNegativeBinomial) = NegativeBinomial(d.r, d.p; check_args = false)
+
+### Support
+
+Base.minimum(d::SafeNegativeBinomial) = 0
+Base.maximum(d::SafeNegativeBinomial) = Inf
+Distributions.insupport(d::SafeNegativeBinomial, x::Integer) = x >= 0
+
+#### Parameters
+
+Distributions.params(d::SafeNegativeBinomial) = _negbin(d) |> params
+Distributions.partype(::SafeNegativeBinomial{T}) where {T} = T
+
+Distributions.succprob(d::SafeNegativeBinomial) = _negbin(d).p
+Distributions.failprob(d::SafeNegativeBinomial{T}) where {T} = one(T) - _negbin(d).p
+
+#### Statistics
+
+Distributions.mean(d::SafeNegativeBinomial) = _negbin(d) |> mean
+Distributions.var(d::SafeNegativeBinomial) = _negbin(d) |> var
+Distributions.std(d::SafeNegativeBinomial) = _negbin(d) |> std
+Distributions.skewness(d::SafeNegativeBinomial) = _negbin(d) |> skewness
+Distributions.kurtosis(d::SafeNegativeBinomial) = _negbin(d) |> kurtosis
+Distributions.mode(d::SafeNegativeBinomial) = _negbin(d) |> mode
+function Distributions.kldivergence(p::SafeNegativeBinomial, q::SafeNegativeBinomial)
+    kldivergence(_negbin(p), _negbin(q))
+end
+
+#### Evaluation & Sampling
+
+Distributions.logpdf(d::SafeNegativeBinomial, k::Real) = logpdf(_negbin(d), k)
+
+Distributions.cdf(d::SafeNegativeBinomial, x::Real) = cdf(_negbin(d), x)
+Distributions.ccdf(d::SafeNegativeBinomial, x::Real) = ccdf(_negbin(d), x)
+Distributions.logcdf(d::SafeNegativeBinomial, x::Real) = logcdf(_negbin(d), x)
+Distributions.logccdf(d::SafeNegativeBinomial, x::Real) = logccdf(_negbin(d), x)
+Distributions.quantile(d::SafeNegativeBinomial, q::Real) = quantile(_negbin(d), q)
+Distributions.cquantile(d::SafeNegativeBinomial, q::Real) = cquantile(_negbin(d), q)
+Distributions.invlogcdf(d::SafeNegativeBinomial, lq::Real) = invlogcdf(_negbin(d), lq)
+Distributions.invlogccdf(d::SafeNegativeBinomial, lq::Real) = invlogccdf(_negbin(d), lq)
+
+## sampling
+function Base.rand(rng::AbstractRNG, d::SafeNegativeBinomial)
+    if isone(d.p)
+        return 0
+    else
+        return rand(rng, SafePoisson(rand(rng, Gamma(d.r, (1 - d.p) / d.p))))
+    end
+end
+
+Distributions.mgf(d::SafeNegativeBinomial, t::Real) = mgf(_negbin(d), t)
+Distributions.cgf(d::SafeNegativeBinomial, t) = cgf(_negbin(d), t)
+Distributions.cf(d::SafeNegativeBinomial, t::Real) = cf(_negbin(d), t)
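As a worked illustration of the mean–cluster parameterisation above (an editor's sketch, not part of the patch; it uses the `σ² = μ + α μ²` convention that appears in `NegativeBinomialMeanClust` and the tests further down), the `(r, p)` pair derived from a target mean `μ` recovers the stated mean and variance:

```julia
using Distributions

μ, α = 10.0, 0.05
σ² = μ + α * μ^2       # target mean–variance relationship
p = μ / σ²             # success probability
r = μ * p / (1 - p)    # number of successes

d = NegativeBinomial(r, p)  # the same parameters SafeNegativeBinomial wraps
mean(d) ≈ μ   # true
var(d) ≈ σ²   # true
```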
diff --git a/EpiAware/src/EpiAwareUtils/SafePoisson.jl b/EpiAware/src/EpiAwareUtils/SafePoisson.jl
new file mode 100644
index 000000000..57a541325
--- /dev/null
+++ b/EpiAware/src/EpiAwareUtils/SafePoisson.jl
@@ -0,0 +1,264 @@
+@doc raw"
+Create a Poisson distribution with the specified mean that avoids `InexactError`
+when the mean is too large.
+
+# Arguments:
+
+- `λ`: The mean of the Poisson distribution.
+
+# Returns:
+
+- A `SafePoisson` distribution with the specified mean.
+
+# Examples:
+
+```jldoctest SafePoisson
+using EpiAware, Distributions
+
+bigλ = exp(48.0) #Large value of λ
+d = SafePoisson(bigλ)
+# output
+EpiAware.EpiAwareUtils.SafePoisson{Float64}(λ=7.016735912097631e20)
+```
+
+```jldoctest SafePoisson
+cdf(d, 2)
+# output
+0.0
+```
+
+```jldoctest SafePoisson
+logpdf(d, 100)
+# output
+-7.016735912097631e20
+```
+
+```jldoctest SafePoisson
+mean(d)
+# output
+7.016735912097631e20
+```
+
+```jldoctest SafePoisson
+var(d)
+# output
+7.016735912097631e20
+```
+"
+struct SafePoisson{T <: Real} <: DiscreteUnivariateDistribution
+    λ::T
+
+    SafePoisson{T}(λ::Real) where {T <: Real} = new{T}(λ)
+    SafePoisson(λ::Real) = SafePoisson{eltype(λ)}(λ)
+end
+
+# Default outer constructor
+SafePoisson() = SafePoisson{Float64}(1.0)
+
+# helper functions
+_poisson(d::SafePoisson) = Poisson(d.λ; check_args = false)
+
+# inefficient but safe floor function to integer, which can handle large values of x
+function _safe_int_floor(x::Real)
+    Tf = typeof(x)
+    if (Tf(typemin(Int)) - one(Tf)) < x < (Tf(typemax(Int)) + one(Tf))
+        return floor(Int, x)
+    else
+        return floor(BigInt, x)
+    end
+end
+
+function _safe_int_round(x::Real)
+    Tf = typeof(x)
+    if (Tf(typemin(Int)) - one(Tf)) < x < (Tf(typemax(Int)) + one(Tf))
+        return round(Int, x)
+    else
+        return round(BigInt, x)
+    end
+end
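The `_safe_int_floor`/`_safe_int_round` helpers are what keep sampling safe at extreme rates: values representable as an `Int` take the fast path, while anything larger falls back to `BigInt` rather than throwing. A quick illustrative check (not part of the patch):

```julia
floor(Int, 1.0e18)     # fits in Int64: returns 1000000000000000000
floor(Int, 1.0e30)     # throws InexactError, since typemax(Int) ≈ 9.22e18
floor(BigInt, 1.0e30)  # the fallback path: returns a BigInt
```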
+
+### Parameters
+
+Distributions.params(d::SafePoisson) = _poisson(d) |> params
+Distributions.partype(::SafePoisson{T}) where {T} = T
+Distributions.rate(d::SafePoisson) = d.λ
+
+### Statistics
+
+Distributions.mean(d::SafePoisson) = d.λ
+Distributions.mode(d::SafePoisson) = _safe_int_floor(d.λ)
+Distributions.var(d::SafePoisson) = d.λ
+Distributions.skewness(d::SafePoisson) = one(typeof(d.λ)) / sqrt(d.λ)
+Distributions.kurtosis(d::SafePoisson) = one(typeof(d.λ)) / d.λ
+
+function Distributions.entropy(d::SafePoisson{T}) where {T <: Real}
+    entropy(_poisson(d))
+end
+
+function Distributions.kldivergence(p::SafePoisson, q::SafePoisson)
+    kldivergence(_poisson(p), _poisson(q))
+end
+
+### Evaluation
+
+Distributions.mgf(d::SafePoisson, t::Real) = mgf(_poisson(d), t)
+Distributions.cgf(d::SafePoisson, t) = cgf(_poisson(d), t)
+Distributions.cf(d::SafePoisson, t::Real) = cf(_poisson(d), t)
+Distributions.logpdf(d::SafePoisson, x::Integer) = logpdf(_poisson(d), x)
+Distributions.pdf(d::SafePoisson, x::Integer) = pdf(_poisson(d), x)
+Distributions.cdf(d::SafePoisson, x::Integer) = cdf(_poisson(d), x)
+Distributions.ccdf(d::SafePoisson, x::Integer) = ccdf(_poisson(d), x)
+Distributions.quantile(d::SafePoisson, q::Real) = quantile(_poisson(d), q)
+
+### Support
+
+Base.minimum(d::SafePoisson) = 0
+Base.maximum(d::SafePoisson) = Inf
+Distributions.insupport(d::SafePoisson, x::Integer) = x >= 0
+
+### Sampling
+### Taken from PoissonRandom.jl https://github.com/SciML/PoissonRandom.jl/blob/master/src/PoissonRandom.jl
+
+count_rand(λ) = count_rand(GLOBAL_RNG, λ)
+function count_rand(rng::AbstractRNG, λ)
+    n = 0
+    c = randexp(rng)
+    while c < λ
+        n += 1
+        c += randexp(rng)
+    end
+    return n
+end
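`count_rand` is the classic inter-arrival construction: the number of unit-rate exponential arrivals occurring before time `λ` is Poisson(`λ`)-distributed, which is cheap when `λ` is small (expected cost grows linearly in `λ`). A moment check of the idea (illustrative only; `count_rand` is an internal helper, so the module path below is an assumption):

```julia
using EpiAware, Random, Statistics

rng = Xoshiro(1234)
# count_rand is not exported; fully-qualified path assumed for illustration
xs = [EpiAware.EpiAwareUtils.count_rand(rng, 4.0) for _ in 1:100_000]
mean(xs), var(xs)  # both ≈ 4.0, as expected for Poisson(4.0)
```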
+
+# Algorithm from:
+#
+#   J.H. Ahrens, U. Dieter (1982)
+#   "Computer Generation of Poisson Deviates from Modified Normal Distributions"
+#   ACM Transactions on Mathematical Software, 8(2):163-179
+#
+# For λ sufficiently large (i.e. >= 10.0)
+#
+ad_rand(λ) = ad_rand(GLOBAL_RNG, λ)
+function ad_rand(rng::AbstractRNG, λ)
+    s = sqrt(λ)
+    d = 6.0 * λ^2
+    L = _safe_int_floor(λ - 1.1484)
+    # Step N
+    G = λ + s * randn(rng)
+
+    if G >= 0.0
+        K = _safe_int_floor(G)
+        # Step I
+        if K >= L
+            return K
+        end
+
+        # Step S
+        U = rand(rng)
+        if d * U >= (λ - K)^3
+            return K
+        end
+
+        # Step P
+        px, py, fx, fy = procf(λ, K, s)
+
+        # Step Q
+        if fy * (1 - U) <= py * exp(px - fx)
+            return K
+        end
+    end
+
+    while true
+        # Step E
+        E = randexp(rng)
+        U = 2.0 * rand(rng) - 1.0
+        T = 1.8 + copysign(E, U)
+        if T <= -0.6744
+            continue
+        end
+
+        K = _safe_int_floor(λ + s * T)
+        px, py, fx, fy = procf(λ, K, s)
+        c = 0.1069 / λ
+
+        # Step H
+        @fastmath if c * abs(U) <= py * exp(px + E) - fy * exp(fx + E)
+            return K
+        end
+    end
+end
+
+# log(1+x)-x
+# accurate ~2ulps for -0.227 < x < 0.315
+function log1pmx_kernel(x::Float64)
+    r = x / (x + 2.0)
+    t = r * r
+    w = @evalpoly(t,
+        6.66666666666666667e-1, # 2/3
+        4.00000000000000000e-1, # 2/5
+        2.85714285714285714e-1, # 2/7
+        2.22222222222222222e-1, # 2/9
+        1.81818181818181818e-1, # 2/11
+        1.53846153846153846e-1, # 2/13
+        1.33333333333333333e-1, # 2/15
+        1.17647058823529412e-1) # 2/17
+    hxsq = 0.5 * x * x
+    r * (hxsq + w * t) - hxsq
+end
+
+# use naive calculation or range reduction outside kernel range.
+# accurate ~2ulps for all x
+function log1pmx(x::Float64)
+    if !(-0.7 < x < 0.9)
+        return log1p(x) - x
+    elseif x > 0.315
+        u = (x - 0.5) / 1.5
+        return log1pmx_kernel(u) - 9.45348918918356180e-2 - 0.5 * u
+    elseif x > -0.227
+        return log1pmx_kernel(x)
+    elseif x > -0.4
+        u = (x + 0.25) / 0.75
+        return log1pmx_kernel(u) - 3.76820724517809274e-2 + 0.25 * u
+    elseif x > -0.6
+        u = (x + 0.5) * 2.0
+        return log1pmx_kernel(u) - 1.93147180559945309e-1 + 0.5 * u
+    else
+        u = (x + 0.625) / 0.375
+        return log1pmx_kernel(u) - 3.55829253011726237e-1 + 0.625 * u
+    end
+end
+
+# Procedure F; K::Integer (rather than K::Int) so the BigInt values returned by
+# _safe_int_floor at extreme λ still dispatch here
+function procf(λ, K::Integer, s::Float64)
+    # can be pre-computed, but does not seem to affect performance
+    ω = 0.3989422804014327 / s
+    b1 = 0.041666666666666664 / λ
+    b2 = 0.3 * b1 * b1
+    c3 = 0.14285714285714285 * b1 * b2
+    c2 = b2 - 15.0 * c3
+    c1 = b1 - 6.0 * b2 + 45.0 * c3
+    c0 = 1.0 - b1 + 3.0 * b2 - 15.0 * c3
+
+    if K < 10
+        px = -float(λ)
+        py = λ^K / factorial(K)
+    else
+        δ = 0.08333333333333333 / K
+        δ -= 4.8 * δ^3
+        V = (λ - K) / K
+        px = K * log1pmx(V) - δ # avoids need for table
+        py = 0.3989422804014327 / sqrt(K)
+    end
+    X = (K - λ + 0.5) / s
+    X2 = X^2
+    fx = -0.5 * X2 # missing negation in pseudo-algorithm, but appears in fortran code.
+    fy = ω * (((c3 * X2 + c2) * X2 + c1) * X2 + c0)
+    return px, py, fx, fy
+end
+
+pois_rand(λ) = pois_rand(GLOBAL_RNG, λ)
+pois_rand(rng::AbstractRNG, λ) = λ < 6 ? count_rand(rng, λ) : ad_rand(rng, λ)
+
+function Base.rand(rng::AbstractRNG, d::SafePoisson)
+    pois_rand(rng, d.λ)
+end
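`SafeNegativeBinomial`'s `rand` (earlier in this patch) composes with `SafePoisson` via the Gamma–Poisson mixture representation: if `Λ ~ Gamma(r, (1 - p) / p)` and `X | Λ ~ Poisson(Λ)`, then `X ~ NegativeBinomial(r, p)`. A moment check of that identity, written with plain Distributions.jl (illustrative sketch, not part of the patch):

```julia
using Distributions, Random, Statistics

rng = Xoshiro(42)
r, p = 20.0, 0.4
# mixture draw: Poisson with a Gamma-distributed rate
mix_draw(rng) = rand(rng, Poisson(rand(rng, Gamma(r, (1 - p) / p))))
xs = [mix_draw(rng) for _ in 1:100_000]
mean(xs), mean(NegativeBinomial(r, p))  # both ≈ 30.0
var(xs), var(NegativeBinomial(r, p))    # both ≈ 75.0
```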
diff --git a/EpiAware/src/EpiInference/NUTSampler.jl b/EpiAware/src/EpiInference/NUTSampler.jl
index c3c69e926..fcf99f24a 100644
--- a/EpiAware/src/EpiInference/NUTSampler.jl
+++ b/EpiAware/src/EpiInference/NUTSampler.jl
@@ -24,6 +24,8 @@ The `NUTSampler` struct represents using the No-U-Turn Sampler (NUTS) to sample
     ndraws::Int
     "The metric type to use for the HMC sampler."
     metricT::M = DiagEuclideanMetric
+    "Number of adaptation steps; `-1` uses the Turing NUTS default, which is half the number of draws."
+    nadapts::Int = -1
 end

 @doc raw"
@@ -51,6 +53,7 @@ function _apply_nuts(model, method, prev_result; kwargs...)
         method.mcmc_parallel,
         method.ndraws ÷ method.nchains,
         method.nchains;
+        nadapts = method.nadapts,
         kwargs...)
 end
@@ -69,5 +72,6 @@ function _apply_nuts(model, method, prev_result::PathfinderResult; kwargs...)
         method.ndraws ÷ method.nchains,
         method.nchains;
         init_params = init_params,
+        nadapts = method.nadapts,
         kwargs...)
 end
diff --git a/EpiAware/src/EpiObsModels/utils.jl b/EpiAware/src/EpiObsModels/utils.jl
index ccb0227f6..33a46949f 100644
--- a/EpiAware/src/EpiObsModels/utils.jl
+++ b/EpiAware/src/EpiObsModels/utils.jl
@@ -40,8 +40,8 @@ A `NegativeBinomial` distribution object.
 """
 function NegativeBinomialMeanClust(μ, α)
     μ² = μ^2
-    ex_σ² = α * μ²
-    p = μ / (μ + ex_σ²)
-    r = μ² / ex_σ²
-    return NegativeBinomial(r, p)
+    σ² = μ + α * μ²
+    p = μ / σ²
+    r = 1 / α
+    return SafeNegativeBinomial(r, p)
 end
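The rewritten `NegativeBinomialMeanClust` keeps the old mean–variance contract while returning the overflow-safe type: with `σ² = μ + α μ²` the excess variance is `μ² / r`, so `r = 1 / α`. A round-trip check (illustrative; assumes the internal helper is reachable at this path):

```julia
using EpiAware, Distributions

μ, α = 10.0, 0.05
d = EpiAware.EpiObsModels.NegativeBinomialMeanClust(μ, α)  # internal helper; path assumed
mean(d) ≈ μ            # true
var(d) ≈ μ + α * μ^2   # true (= 15.0)
```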
"Check gradients can be evaluated for logpdf of SafeNegativeBinomial" begin + using Distributions, ReverseDiff, FiniteDifferences, ForwardDiff + μ = exp(48.0) #Large value of λ + α = 0.05 + # calculate the r, p parameters + σ² = μ + α * μ^2 + p = μ / σ² + r = μ * p / (1 - p) + + # Make a helper function for grad calls + f(x) = SafeNegativeBinomial(exp(x[1]), p) |> nb -> logpdf(nb, 100) + g_fin_diff = grad(central_fdm(5, 1), f, [log(r)])[1] + + # Compiled ReverseDiff version + input = randn(1) + const f_tape = ReverseDiff.GradientTape(f, input) + const compiled_f_tape = ReverseDiff.compile(f_tape) + cfg = ReverseDiff.GradientConfig(input) + g_rvd = ReverseDiff.gradient(f, [log(r)], cfg) + + # ForwardDiff version + g_fd = ForwardDiff.gradient(f, [log(r)]) + + @test g_fin_diff ≈ g_rvd + @test g_fin_diff ≈ g_fd +end diff --git a/EpiAware/test/EpiAwareUtils/SafePoisson.jl b/EpiAware/test/EpiAwareUtils/SafePoisson.jl new file mode 100644 index 000000000..f8fa89cad --- /dev/null +++ b/EpiAware/test/EpiAwareUtils/SafePoisson.jl @@ -0,0 +1,83 @@ +@testitem "Testing SafePoisson Constructor " begin + λ = 10.0 + dist = SafePoisson(λ) + @test typeof(dist) <: SafePoisson +end + +@testitem "Check distribution properties of SafePoisson" begin + using Distributions, HypothesisTests, StatsBase + λ = 10.0 + dist = SafePoisson(λ) + #Check Distributions.jl mean function + @test mean(dist) ≈ λ + n = 100_000 + samples = [rand(dist) for _ in 1:n] + #Check mean from direct sampling of Distributions version and ANOVA and Variance F test comparisons + direct_samples = rand(Poisson(λ), n) + mean_pval = OneWayANOVATest(samples, direct_samples) |> pvalue + @test mean_pval > 1e-6 #Very unlikely to fail if the model is correctly implemented + var_pval = VarianceFTest(samples, direct_samples) |> pvalue + @test var_pval > 1e-6 #Very unlikely to fail if the model is correctly implemented + # Check that the variance is closer than 6 std of estimator to the direct samples + # very unlikely failure if the model is correctly implemented + @test abs(var(dist) - var(direct_samples)) < 6 * var(Poisson(λ))^2 * sqrt(2 / n) + + @testset "Check quantiles" begin + for q in [0.1, 0.25, 0.5, 0.75, 0.9] + @test isapprox(quantile(dist, q), quantile(direct_samples, q), atol = 0.1) + end + end + + @testset "Check support boundaries" begin + @test minimum(dist) == 0 + @test maximum(dist) == Inf + end + + @testset "Check logpdf against Distributions" begin + for x in 0:10:100 + @test isapprox(logpdf(dist, x), + logpdf(Poisson(λ), x), atol = 0.1) + end + end + + @testset "Check CDF" begin + x = 0:10:100 + @test isapprox(cdf(dist, x), ecdf(direct_samples)(x), atol = 0.05) + end +end + +@testitem "Testing safety of rand call for SafePoisson at large values" begin + using Distributions + bigλ = exp(48.0) #Large value of λ + dist = SafePoisson(bigλ) + @testset "Large value of mean samples a BigInt with SafePoisson" begin + @test rand(dist) isa BigInt + end + @testset "Large value of mean sample failure with Poisson" begin + _dist = Poisson(dist.λ) + @test_throws InexactError rand(_dist) + end +end + +@testitem "Check gradients can be evaluated for logpdf of SafePoisson" begin + using Distributions, ReverseDiff, FiniteDifferences, ForwardDiff + log_μ = 48.0 #Plausible large value to hit with a log scale random walk over a number of time steps + α = 0.05 + + # Make a helper function for grad calls + f(x) = SafePoisson(exp(x[1])) |> poi -> logpdf(poi, 100) + g_fin_diff = grad(central_fdm(5, 1), f, [log_μ])[1] + + # Compiled ReverseDiff version + input 
diff --git a/EpiAware/test/EpiAwareUtils/SafePoisson.jl b/EpiAware/test/EpiAwareUtils/SafePoisson.jl
new file mode 100644
index 000000000..f8fa89cad
--- /dev/null
+++ b/EpiAware/test/EpiAwareUtils/SafePoisson.jl
@@ -0,0 +1,83 @@
+@testitem "Testing SafePoisson Constructor" begin
+    λ = 10.0
+    dist = SafePoisson(λ)
+    @test typeof(dist) <: SafePoisson
+end
+
+@testitem "Check distribution properties of SafePoisson" begin
+    using Distributions, HypothesisTests, StatsBase
+    λ = 10.0
+    dist = SafePoisson(λ)
+    #Check Distributions.jl mean function
+    @test mean(dist) ≈ λ
+    n = 100_000
+    samples = [rand(dist) for _ in 1:n]
+    #Check mean from direct sampling of Distributions version and ANOVA and Variance F test comparisons
+    direct_samples = rand(Poisson(λ), n)
+    mean_pval = OneWayANOVATest(samples, direct_samples) |> pvalue
+    @test mean_pval > 1e-6 #Very unlikely to fail if the model is correctly implemented
+    var_pval = VarianceFTest(samples, direct_samples) |> pvalue
+    @test var_pval > 1e-6 #Very unlikely to fail if the model is correctly implemented
+    # Check that the variance is closer than 6 std of estimator to the direct samples;
+    # very unlikely to fail if the model is correctly implemented
+    @test abs(var(dist) - var(direct_samples)) < 6 * var(Poisson(λ))^2 * sqrt(2 / n)
+
+    @testset "Check quantiles" begin
+        for q in [0.1, 0.25, 0.5, 0.75, 0.9]
+            @test isapprox(quantile(dist, q), quantile(direct_samples, q), atol = 0.1)
+        end
+    end
+
+    @testset "Check support boundaries" begin
+        @test minimum(dist) == 0
+        @test maximum(dist) == Inf
+    end
+
+    @testset "Check logpdf against Distributions" begin
+        for x in 0:10:100
+            @test isapprox(logpdf(dist, x),
+                logpdf(Poisson(λ), x), atol = 0.1)
+        end
+    end
+
+    @testset "Check CDF" begin
+        x = 0:10:100
+        @test isapprox(cdf(dist, x), ecdf(direct_samples)(x), atol = 0.05)
+    end
+end
+
+@testitem "Testing safety of rand call for SafePoisson at large values" begin
+    using Distributions
+    bigλ = exp(48.0) #Large value of λ
+    dist = SafePoisson(bigλ)
+    @testset "Large value of mean samples a BigInt with SafePoisson" begin
+        @test rand(dist) isa BigInt
+    end
+    @testset "Large value of mean sample failure with Poisson" begin
+        _dist = Poisson(dist.λ)
+        @test_throws InexactError rand(_dist)
+    end
+end
+
+@testitem "Check gradients can be evaluated for logpdf of SafePoisson" begin
+    using Distributions, ReverseDiff, FiniteDifferences, ForwardDiff
+    log_μ = 48.0 #Plausible large value to hit with a log scale random walk over a number of time steps
+    α = 0.05
+
+    # Make a helper function for grad calls
+    f(x) = SafePoisson(exp(x[1])) |> poi -> logpdf(poi, 100)
+    g_fin_diff = grad(central_fdm(5, 1), f, [log_μ])[1]
+
+    # Compiled ReverseDiff version
+    input = randn(1)
+    const f_tape = ReverseDiff.GradientTape(f, input)
+    const compiled_f_tape = ReverseDiff.compile(f_tape)
+    cfg = ReverseDiff.GradientConfig(input)
+    g_rvd = ReverseDiff.gradient(f, [log_μ], cfg)
+
+    # ForwardDiff version
+    g_fd = ForwardDiff.gradient(f, [log_μ])
+
+    @test g_fin_diff ≈ g_rvd
+    @test g_fin_diff ≈ g_fd
+end
diff --git a/EpiAware/test/Project.toml b/EpiAware/test/Project.toml
index f944a2595..95298f7be 100644
--- a/EpiAware/test/Project.toml
+++ b/EpiAware/test/Project.toml
@@ -5,12 +5,15 @@ DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"