Merge branch '20230708-fsaad-quantiles'
fsaad committed Aug 7, 2023
2 parents 9f121a1 + 5d3aa16 commit e81ce37
Showing 7 changed files with 207 additions and 56 deletions.
1 change: 1 addition & 0 deletions docs/src/api.md
@@ -133,6 +133,7 @@ AutoGP.mcmc_parameters!
AutoGP.predict
AutoGP.predict_proba
AutoGP.predict_mvn
AutoGP.predict_quantile
AutoGP.log_marginal_likelihood_estimate
AutoGP.particle_weights
AutoGP.effective_sample_size
103 changes: 74 additions & 29 deletions docs/src/tutorials/iclaims.ipynb

Large diffs are not rendered by default.

75 changes: 52 additions & 23 deletions docs/src/tutorials/iclaims.md
@@ -88,18 +88,20 @@ AutoGP.fit_smc!(model; schedule=schedule, n_mcmc=50, n_hmc=10, shuffle=true, ada
```

Running SMC round 69/343
weights [3.38e-32, 4.11e-20, 7.26e-29, 1.02e-44, 4.74e-57, 9.10e-04, 9.99e-01, 2.71e-17]
Particle Weights: [3.38e-32, 4.11e-20, 7.26e-29, 1.02e-44, 4.74e-57, 9.10e-04, 9.99e-01, 2.71e-17]
Particle ESS: 0.1252276603207894
resampled true
accepted MCMC[4/50] HMC[40/40]
accepted MCMC[6/50] HMC[48/50]
accepted MCMC[7/50] HMC[61/63]
accepted MCMC[9/50] HMC[73/77]
accepted MCMC[8/50] HMC[62/64]
accepted MCMC[7/50] HMC[61/63]
accepted MCMC[7/50] HMC[68/69]
accepted MCMC[9/50] HMC[73/77]
accepted MCMC[12/50] HMC[92/97]
accepted MCMC[14/50] HMC[114/118]
Running SMC round 138/343
weights [1.10e-01, 1.39e-01, 1.98e-01, 1.71e-01, 1.66e-01, 1.78e-01, 1.80e-02, 2.08e-02]
Particle Weights: [1.10e-01, 1.39e-01, 1.98e-01, 1.71e-01, 1.66e-01, 1.78e-01, 1.80e-02, 2.08e-02]
Particle ESS: 0.7836676673226702
resampled true
accepted MCMC[2/50] HMC[3/5]
accepted MCMC[3/50] HMC[15/17]
@@ -110,7 +112,8 @@ AutoGP.fit_smc!(model; schedule=schedule, n_mcmc=50, n_hmc=10, shuffle=true, ada
accepted MCMC[12/50] HMC[20/32]
accepted MCMC[16/50] HMC[62/73]
Running SMC round 207/343
weights [1.21e-18, 1.72e-18, 1.00e+00, 3.00e-20, 1.48e-12, 2.59e-17, 8.86e-18, 1.49e-17]
Particle Weights: [1.21e-18, 1.72e-18, 1.00e+00, 3.00e-20, 1.48e-12, 2.59e-17, 8.86e-18, 1.49e-17]
Particle ESS: 0.12500000000036948
resampled true
accepted MCMC[8/50] HMC[0/8]
accepted MCMC[10/50] HMC[0/10]
@@ -121,26 +124,28 @@ AutoGP.fit_smc!(model; schedule=schedule, n_mcmc=50, n_hmc=10, shuffle=true, ada
accepted MCMC[13/50] HMC[0/13]
accepted MCMC[16/50] HMC[2/18]
Running SMC round 276/343
weights [1.52e-01, 1.63e-04, 1.51e-01, 4.79e-01, 1.04e-01, 9.47e-02, 1.76e-02, 1.36e-03]
Particle Weights: [1.52e-01, 1.63e-04, 1.51e-01, 4.79e-01, 1.04e-01, 9.47e-02, 1.76e-02, 1.36e-03]
Particle ESS: 0.42322820857724425
resampled true
accepted MCMC[7/50] HMC[0/7]
accepted MCMC[12/50] HMC[0/12]
accepted MCMC[11/50] HMC[0/11]
accepted MCMC[12/50] HMC[0/12]
accepted MCMC[11/50] HMC[1/12]
accepted MCMC[12/50] HMC[0/12]
accepted MCMC[12/50] HMC[0/12]
accepted MCMC[13/50] HMC[1/14]
accepted MCMC[19/50] HMC[0/19]
accepted MCMC[20/50] HMC[0/20]
Running SMC round 343/343
weights [4.25e-03, 3.87e-04, 5.39e-03, 5.37e-03, 2.16e-04, 5.35e-01, 4.40e-01, 9.31e-03]
Particle Weights: [4.25e-03, 3.87e-04, 5.39e-03, 5.37e-03, 2.16e-04, 5.35e-01, 4.40e-01, 9.31e-03]
Particle ESS: 0.2603461961652077
accepted MCMC[10/50] HMC[0/10]
accepted MCMC[10/50] HMC[0/10]
accepted MCMC[12/50] HMC[0/12]
accepted MCMC[13/50] HMC[1/14]
accepted MCMC[14/50] HMC[0/14]
accepted MCMC[13/50] HMC[1/14]
accepted MCMC[16/50] HMC[0/16]
accepted MCMC[17/50] HMC[0/17]
accepted MCMC[14/50] HMC[1/15]
accepted MCMC[17/50] HMC[0/17]


Plotting the forecasts from each particle reveals the structural uncertainty: 7/8 particles have inferred a periodic component ([`AutoGP.GP.Periodic`](@ref)) with an additive linear trend ([`AutoGP.GP.Linear`](@ref)), while 1/8 has inferred a sum of a periodic kernel and a gamma-exponential kernel ([`AutoGP.GP.GammaExponential`](@ref)), which is stationary but not "smooth" (formally, not mean-square differentiable).
@@ -188,7 +193,7 @@ for (w, k) in zip(AutoGP.particle_weights(model), AutoGP.covariance_kernels(mode
end
```

Particle weight 0.004250523793423254
Particle weight 0.004250523793201452



@@ -198,7 +203,7 @@ end



Particle weight 0.0003867955857232888
Particle weight 0.00038679558572478394



@@ -208,7 +213,7 @@ end



Particle weight 0.0053919770385446346
Particle weight 0.0053919770385277765



@@ -218,7 +223,7 @@ end



Particle weight 0.005373167226504197
Particle weight 0.005373167226484039



@@ -228,7 +233,7 @@ end



Particle weight 0.0002161576411907359
Particle weight 0.00021615764118894197



@@ -238,7 +243,7 @@ end



Particle weight 0.5354267607568092
Particle weight 0.5354267607554092



@@ -248,7 +253,7 @@ end



Particle weight 0.43964210411237775
Particle weight 0.43964210411447696



@@ -258,7 +263,7 @@ end



Particle weight 0.009312513845423434
Particle weight 0.00931251384499995



@@ -268,6 +273,30 @@ end



We can also query overall quantiles of the predictive distribution at the new index points using [`AutoGP.predict_quantile`](@ref).


```julia
# Obtain overall quantiles.
quantiles_lo, = AutoGP.predict_quantile.(Ref(model), Ref(ds_query), .025, tol=1e-6)
quantiles_md, = AutoGP.predict_quantile.(Ref(model), Ref(ds_query), .50, tol=1e-6)
quantiles_hi, = AutoGP.predict_quantile.(Ref(model), Ref(ds_query), .975, tol=1e-6)

# Plot the combined predictions.
fig, ax = plt.subplots()
ax.scatter(df_train.ds, df_train.y, marker=".", color="k", label="Observed Data")
ax.scatter(df_test.ds, df_test.y, marker=".", color="r", label="Test Data")
ax.plot(ds_query, quantiles_md, color="k", linewidth=1)
ax.fill_between(ds_query, quantiles_lo, quantiles_hi, color="tab:blue", alpha=.5)
fig.set_size_inches((20, 10))
```



![png](iclaims_files/iclaims_14_0.png)



!!! note

    Mean forecasts, quantile forecasts, and probability density values obtained via [`AutoGP.predict`](@ref) and [`AutoGP.predict_proba`](@ref) are all in the transformed (log) space. Only quantile forecasts can be transformed back to the direct space via `exp`; mean forecasts and probability densities must instead be converted using the [`Distributions.MvLogNormal`](https://juliastats.org/Distributions.jl/stable/multivariate/#Distributions.MvLogNormal) constructor, as demonstrated below.
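A minimal sketch of this conversion, assuming the `model` and `ds_query` objects defined earlier in the tutorial; the names `mvn` and `log_mvn` are chosen to match the plotting code further below, whose defining cell is not shown in this diff:

```julia
import Distributions

# Predictive mixture in the transformed (log) space.
mvn = AutoGP.predict_mvn(model, ds_query)

# Quantile forecasts convert to the direct space by exponentiation.
quantiles_md_direct = exp.(quantiles_md)

# Mean forecasts convert by wrapping each per-particle Gaussian as a
# log-normal and re-mixing with the same particle weights.
log_components = Distributions.MvLogNormal.(Distributions.components(mvn))
log_mvn = Distributions.MixtureModel(log_components, Distributions.probs(mvn))
mean_direct = Distributions.mean(log_mvn)
```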
@@ -291,18 +320,18 @@ ax.legend()



![png](iclaims_files/iclaims_15_0.png)
![png](iclaims_files/iclaims_17_0.png)






PyObject <matplotlib.legend.Legend object at 0x7f3598672c80>
PyObject <matplotlib.legend.Legend object at 0x7f5b068d02e0>



The difference between the blue and black curves is too small to observe on the scale above; let us plot the bias that arises from doing a naive transformation.
The difference between the blue and black curves is too small to observe on the scale above; let us plot the bias that arises from doing a naive transformation of the predictive mean.


```julia
@@ -312,6 +341,6 @@ ax.plot(ds_query, Distributions.mean(log_mvn) - exp.(Distributions.mean(mvn)));



![png](iclaims_files/iclaims_17_0.png)
![png](iclaims_files/iclaims_19_0.png)


Binary file added docs/src/tutorials/iclaims_files/iclaims_14_0.png
Binary file modified docs/src/tutorials/iclaims_files/iclaims_17_0.png
Binary file added docs/src/tutorials/iclaims_files/iclaims_19_0.png
84 changes: 80 additions & 4 deletions src/api.jl
@@ -21,6 +21,10 @@ import Gen

using Match

using Distributions: MixtureModel
using Distributions: MvNormal
using Distributions: Normal

"""
    seed!(seed)

Set the random seed of the global random number generator.
@@ -450,14 +454,17 @@ with weights [`particle_weights`](@ref)`(model)`. These objects can be retrieved
and
[`Distributions.probs`](https://juliastats.org/Distributions.jl/stable/mixture/#Distributions.probs-Tuple{AbstractMixtureModel}), respectively.
"""
function predict_mvn(model::GPModel, ds::IndexType; noise_pred::Union{Nothing,Float64}=nothing)
function predict_mvn(
model::GPModel,
ds::IndexType;
noise_pred::Union{Nothing,Float64}=nothing)
if !(eltype(ds) <: eltype(model.ds))
error("Invalid time $(ds), expected $(eltype(model.ds))")
end
ds_numeric = Transforms.apply(model.ds_transform, to_numeric(ds))
n_particles = num_particles(model)
weights = particle_weights(model)
distributions = Vector{Distributions.MvNormal}(undef, n_particles)
distributions = Vector{MvNormal}(undef, n_particles)
Threads.@threads for i=1:n_particles
dist = Inference.predict_mvn(
model.pf_state.traces[i],
Expand All @@ -469,9 +476,78 @@ function predict_mvn(model::GPModel, ds::IndexType; noise_pred::Union{Nothing,Fl
# The same is not true for if model.y_transform is log, in which
# case we would need to return Distributions.MvLogNormal.
mu, cov = Transforms.unapply_mean_var(model.y_transform, mu, cov)
distributions[i] = Distributions.MvNormal(mu, cov)
distributions[i] = MvNormal(mu, cov)
end
return MixtureModel(distributions, weights)
end
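
For illustration only (not part of this commit): a sketch of consuming the mixture returned by `predict_mvn`, assuming `model` is a fitted `GPModel` and `ds` is a vector of valid index points.

```julia
mvn = predict_mvn(model, ds)
ws = Distributions.probs(mvn)        # mixture weights (the particle weights)
cs = Distributions.components(mvn)   # one MvNormal per particle
mu = Distributions.mean(mvn)         # overall predictive mean at each index point
```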

"""
    (x::Vector, success::Bool) = predict_quantile(
        model::GPModel, ds::IndexType, q::Real;
        noise_pred::Union{Nothing,Float64}=nothing, tol=1e-5, max_iter=1e6)

Evaluates the inverse cumulative distribution function (CDF) of the
multivariate Gaussian mixture model returned by [`predict_mvn`](@ref) at
`q` (between 0 and 1, exclusive) separately for each dimension. The
returned vector `x` has the same length as the index points `ds`.

# Note

The inverse CDF is numerically estimated using a binary search algorithm.
The keyword arguments `tol` and `max_iter` correspond to the desired
absolute tolerance of the estimate and the maximum number of binary search
iterations, respectively. The returned Boolean variable `success` indicates
whether the returned value `x` has been located to within the specified
error tolerance.

# See also

- [`predict_mvn`](@ref)
"""
function predict_quantile(
model::GPModel,
ds::IndexType,
q::Real;
noise_pred::Union{Nothing,Float64}=nothing,
tol=1e-5,
max_iter=1e6)
(0 < q < 1) || error("Quantile must be in (0,1).")
mvn = predict_mvn(model, ds; noise_pred=noise_pred)
components = Distributions.components(mvn)
means = hcat(Distributions.mean.(components)...)
vars = hcat(Distributions.var.(components)...)
mixtures = [Normal.(m, sqrt.(v)) for (m, v) in zip(eachrow(means), eachrow(vars))]
weights = Distributions.probs(mvn)
mixture = MixtureModel.(mixtures, Ref(weights))
@assert length(mixture) == length(ds)
@assert all(length.(mixtures) .== num_particles(model))
x = zeros(length(mixture))
iter = 0
x_max = repeat([Inf], length(mixture))
x_min = repeat([-Inf], length(mixture))
success = false
while iter < max_iter
epsilon = @. Distributions.cdf(mixture, x) - q
if all(abs.(epsilon) .< tol)
success = true
break
end
x_max = ifelse.(epsilon .> 0, x, x_max)
x_min = ifelse.(epsilon .< 0, x, x_min)
x_hi = min.(x_max, (@. 2^sign(x)*x + (x == 0)))
x_lo = max.(x_min, (@. 2^-sign(x)*x - (x == 0)))
x_hi_mid = Distributions.mean([x, x_hi])
x_lo_mid = Distributions.mean([x, x_lo])
x = ifelse.(
abs.(epsilon) .< tol,
x,
ifelse.(
epsilon .< 0,
x_hi_mid,
x_lo_mid))
iter += 1
end
return Distributions.MixtureModel(distributions, weights)
return (x, success)
end
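
For illustration only (not part of this commit): a hedged usage sketch of the new function, assuming a fitted `model` and valid index points `ds` as above.

```julia
# Median forecast; `success` reports whether the binary search met `tol`.
x, success = predict_quantile(model, ds, 0.5; tol=1e-6, max_iter=10_000)
success || @warn "inverse-CDF search did not reach the requested tolerance"
```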

"""
