diff --git a/Manifest.toml b/Manifest.toml deleted file mode 100644 index 8a1a33db..00000000 --- a/Manifest.toml +++ /dev/null @@ -1,164 +0,0 @@ -[[Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[CSTParser]] -deps = ["LibGit2", "Test", "Tokenize"] -git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56" -uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" -version = "0.5.2" - -[[Compat]] -deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "2.1.0" - -[[DataStructures]] -deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] -git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.15.0" - -[[Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[DelimitedFiles]] -deps = ["Mmap"] -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" - -[[Distributed]] -deps = ["LinearAlgebra", "Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays", "Test"] -git-tree-sha1 = "2def0123a4f3572234405b0e3d80bfe5d3e1a2a4" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.5.0" - -[[InplaceOps]] -deps = ["LinearAlgebra", "Test"] -git-tree-sha1 = "50b41d59e7164ab6fda65e71049fee9d890731ff" -uuid = "505f98c9-085e-5b2c-8e89-488be7bf1f34" -version = "0.3.0" - -[[InteractiveUtils]] -deps = ["LinearAlgebra", "Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[LazyArrays]] -deps = ["FillArrays", "LinearAlgebra", "MacroTools", "StaticArrays", "Test"] -git-tree-sha1 = "4439e840fe68cbcde806fcc625d05166227a56a5" -uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02" -version = "0.8.0" - -[[LibGit2]] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[LinearAlgebra]] -deps = ["Libdl"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[MacroTools]] -deps = ["CSTParser", "Compat", "DataStructures", "Test"] -git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.0" - -[[Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[Missings]] -deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"] -git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.0" - -[[Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[OrderedCollections]] -deps = ["Random", "Serialization", "Test"] -git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.1.0" - -[[Pkg]] -deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" - -[[Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[Random]] -deps = ["Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" - -[[Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - -[[Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[SortingAlgorithms]] -deps = ["DataStructures", "Random", "Test"] -git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "0.3.1" - -[[SparseArrays]] -deps = ["LinearAlgebra", "Random"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[StaticArrays]] -deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"] -git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "0.10.3" - -[[Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[StatsBase]] -deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"] -git-tree-sha1 = "435707791dc85a67d98d671c1c3fcf1b20b00f94" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.29.0" - -[[Test]] -deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[Tokenize]] -deps = ["Printf", "Test"] -git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8" -uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" -version = "0.5.3" - -[[UUIDs]] -deps = ["Random"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" diff --git a/Project.toml b/Project.toml index 8fff6357..7ec8f73d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,5 @@ name = "AdvancedHMC" uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" -authors = ["The Turing Team"] version = "0.1.5" [deps] @@ -10,6 +9,8 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" +Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" [extras] Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index ee141055..462cdfea 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -6,6 +6,8 @@ using Statistics: mean, var, middle using LinearAlgebra: Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky using LazyArrays: BroadcastArray using Random: GLOBAL_RNG, AbstractRNG +using Parameters: @unpack +using ArgCheck: @argcheck import StatsBase: sample diff --git a/src/adaptation/Adaptation.jl b/src/adaptation/Adaptation.jl index 26575fa0..777ba742 100644 --- a/src/adaptation/Adaptation.jl +++ b/src/adaptation/Adaptation.jl @@ -1,6 +1,7 @@ module Adaptation -import Base: string +import Base: string, rand +using Random: GLOBAL_RNG, AbstractRNG using LinearAlgebra: Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky import LinearAlgebra, Statistics using ..AdvancedHMC: DEBUG diff --git a/src/adaptation/precond.jl b/src/adaptation/precond.jl index f5caf4aa..ee7f9e1a 100644 --- a/src/adaptation/precond.jl +++ b/src/adaptation/precond.jl @@ -282,6 +282,35 @@ function Base.getproperty(dem::DenseEuclideanMetric, d::Symbol) return d === :dim ? size(getfield(dem, :M⁻¹), 1) : getfield(dem, d) end +# `rand` functions for `metric` types. +function Base.rand( + rng::AbstractRNG, + metric::UnitEuclideanMetric +) + r = randn(rng, metric.dim) + return r +end + +function Base.rand( + rng::AbstractRNG, + metric::DiagEuclideanMetric +) + r = randn(rng, metric.dim) + r ./= metric.sqrtM⁻¹ + return r +end + +function Base.rand( + rng::AbstractRNG, + metric::DenseEuclideanMetric +) + r = randn(rng, metric.dim) + ldiv!(metric.cholM⁻¹, r) + return r +end + +Base.rand(metric::AbstractMetric) = rand(GLOBAL_RNG, metric) + #### #### Preconditioner constructors #### diff --git a/src/hamiltonian.jl b/src/hamiltonian.jl index 4c77aa89..ce6c5d9f 100644 --- a/src/hamiltonian.jl +++ b/src/hamiltonian.jl @@ -1,60 +1,100 @@ # TODO: add a type for kinetic energy +# TODO: add cache for gradients by letting ∂logπ∂θ return both log density and gradient struct Hamiltonian{M<:AbstractMetric, Tlogπ, T∂logπ∂θ} metric::M - logπ::Tlogπ - ∂logπ∂θ::T∂logπ∂θ + ℓπ::Tlogπ + ∂ℓπ∂θ::T∂logπ∂θ end # Create a `Hamiltonian` with a new `M⁻¹` -(h::Hamiltonian)(M⁻¹) = Hamiltonian(h.metric(M⁻¹), h.logπ, h.∂logπ∂θ) +(h::Hamiltonian)(M⁻¹) = Hamiltonian(h.metric(M⁻¹), h.ℓπ, h.∂ℓπ∂θ) -∂H∂θ(h::Hamiltonian, θ::AbstractVector) = -h.∂logπ∂θ(θ) +∂H∂θ(h::Hamiltonian, θ::AbstractVector) = -h.∂ℓπ∂θ(θ) ∂H∂r(h::Hamiltonian{<:UnitEuclideanMetric}, r::AbstractVector) = copy(r) ∂H∂r(h::Hamiltonian{<:DiagEuclideanMetric}, r::AbstractVector) = h.metric.M⁻¹ .* r ∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric}, r::AbstractVector) = h.metric.M⁻¹ * r -function hamiltonian_energy(h::Hamiltonian, θ::AbstractVector, r::AbstractVector) - K = kinetic_energy(h, r, θ) - if isnan(K) - K = Inf - @warn "Kinetic energy is `NaN` and is set to `Inf`." - end - V = potential_energy(h, θ) - if isnan(V) - V = Inf - @warn "Potential energy is `NaN` and is set to `Inf`." +struct DualValue{Tv<:AbstractFloat, Tg<:AbstractVector{Tv}} + value::Tv # Cached value, e.g. logπ(θ). + gradient::Tg # Cached gradient, e.g. ∇logπ(θ). +end + +struct PhasePoint{T<:AbstractVector, V<:DualValue} + θ::T # Position variables / model parameters. + r::T # Momentum variables + ℓπ::V # Cached neg potential energy for the current θ. + ℓκ::V # Cached neg kinect energy for the current r. + function PhasePoint(θ::T, r::T, ℓπ::V, ℓκ::V) where {T,V} + @argcheck length(θ) == length(r) == length(ℓπ.gradient) == length(ℓπ.gradient) + # if !(all(isfinite, θ) && all(isfinite, r) && all(isfinite, ℓπ) && all(isfinite, ℓκ)) + if !(isfinite(θ) && isfinite(r) && isfinite(ℓπ) && isfinite(ℓκ)) + @warn "The current proposal will be rejected (due to numerical error(s))..." + ℓκ = DualValue(-Inf, ℓκ.gradient) + ℓπ = DualValue(-Inf, ℓπ.gradient) + end + new{T,V}(θ, r, ℓπ, ℓκ) end - return K + V end -potential_energy(h::Hamiltonian, θ::AbstractVector) = -h.logπ(θ) +phasepoint( + h::Hamiltonian, + θ::T, + r::T; + ℓπ = DualValue(neg_energy(h, r, θ), ∂H∂θ(h, θ)), + ℓκ = DualValue(neg_energy(h, θ), ∂H∂r(h, r)) +) where {T<:AbstractVector} = PhasePoint(θ, r, ℓπ, ℓκ) + + +Base.isfinite(v::DualValue) = all(isfinite, v.value) && all(isfinite, v.gradient) +Base.isfinite(v::AbstractVector) = all(isfinite, v) +Base.isfinite(z::PhasePoint) = isfinite(z.ℓπ) && isfinite(z.ℓκ) + +### +### Negative energy (or log probability) functions. +### NOTE: the general form (i.e. non-Euclidean) of K depends on both θ and r. +### + +neg_energy(z::PhasePoint) = z.ℓπ.value + z.ℓκ.value -# Kinetic energy -# NOTE: the general form of K depends on both θ and r -kinetic_energy(h::Hamiltonian{<:UnitEuclideanMetric}, r, θ) = sum(abs2, r) / 2 -function kinetic_energy(h::Hamiltonian{<:DiagEuclideanMetric}, r, θ) - return sum(abs2(r[i]) * h.metric.M⁻¹[i] for i in 1:length(r)) / 2 +neg_energy(h::Hamiltonian, θ::AbstractVector) = h.ℓπ(θ) + +neg_energy( + h::Hamiltonian{<:UnitEuclideanMetric}, + r::T, + θ::T +) where {T<:AbstractVector} = -sum(abs2, r) / 2 + +function neg_energy( + h::Hamiltonian{<:DiagEuclideanMetric}, + r::T, + θ::T +) where {T<:AbstractVector} + _r = [abs2(r[i]) * h.metric.M⁻¹[i] for i in 1:length(r)] + return -sum(_r) / 2 end -function kinetic_energy(h::Hamiltonian{<:DenseEuclideanMetric}, r, θ) + +function neg_energy( + h::Hamiltonian{<:DenseEuclideanMetric}, + r::T, + θ::T +) where {T<:AbstractVector} mul!(h.metric._temp, h.metric.M⁻¹, r) - return dot(r, h.metric._temp) / 2 + return -dot(r, h.metric._temp) / 2 end -# Momentum sampler -function rand_momentum(rng::AbstractRNG, h::Hamiltonian{<:UnitEuclideanMetric}) - return randn(rng, h.metric.dim) -end -function rand_momentum(rng::AbstractRNG, h::Hamiltonian{<:DiagEuclideanMetric}) - r = randn(rng, h.metric.dim) - r ./= h.metric.sqrtM⁻¹ - return r -end -function rand_momentum(rng::AbstractRNG, h::Hamiltonian{<:DenseEuclideanMetric}) - r = randn(rng, h.metric.dim) - ldiv!(h.metric.cholM⁻¹, r) - return r -end +#### +#### Momentum sampler +#### + +rand_momentum( + rng::AbstractRNG, + z::PhasePoint, + h::Hamiltonian +) = phasepoint(h, z.θ, rand(rng, h.metric)) -rand_momentum(h::Hamiltonian) = rand_momentum(GLOBAL_RNG, h) +rand_momentum( + z::PhasePoint, + h::Hamiltonian +) = phasepoint(h, z.θ, rand(GLOBAL_RNG, h.metric)) diff --git a/src/integrator.jl b/src/integrator.jl index 26f2ee01..d2e4c75d 100644 --- a/src/integrator.jl +++ b/src/integrator.jl @@ -14,64 +14,27 @@ function (::Leapfrog)(ϵ::AbstractFloat) return Leapfrog(ϵ) end -function lf_momentum( - ϵ::T, - h::Hamiltonian, - θ::AbstractVector{T}, - r::AbstractVector{T} -) where {T<:Real} - _∂H∂θ = ∂H∂θ(h, θ) - !is_valid(_∂H∂θ) && return r, false - return r - ϵ * _∂H∂θ, true -end - -function lf_position( - ϵ::T, h::Hamiltonian, - θ::AbstractVector{T}, - r::AbstractVector{T} -) where {T<:Real} - return θ + ϵ * ∂H∂r(h, r) -end - # TODO: double check the function below to see if it is type stable or not function step( lf::Leapfrog{F}, h::Hamiltonian, - θ::AbstractVector{T}, - r::AbstractVector{T}, - n_steps::Int=1 + z::PhasePoint, + n_steps::Int=1; + fwd = n_steps > 0 # Simulate hamiltonian backward when n_steps < 0 ) where {F<:AbstractFloat,T<:Real} - fwd = n_steps > 0 # simulate hamiltonian backward when n_steps < 0 + @unpack θ, r = z ϵ = fwd ? lf.ϵ : -lf.ϵ - n_valid = 0 - r_new, _is_valid_1 = lf_momentum(ϵ/2, h, θ, r) + ∇θ = ∂H∂θ(h, θ) for i = 1:abs(n_steps) - θ_new = lf_position(ϵ, h, θ, r_new) - r_new, _is_valid_2 = lf_momentum(i == n_steps ? ϵ / 2 : ϵ, h, θ_new, r_new) - if (_is_valid_1 && _is_valid_2) - θ, r = θ_new, r_new - n_valid = n_valid + 1 - else - # Reverse half leapfrog step for r when breaking - # the loop immaturely. - if i > 1 && i < abs(n_steps) - r, _ = lf_momentum(-lf.ϵ / 2, h, θ, r) - end - break - end - end - return θ, r, n_valid > 0 -end - -### -### Utility function. -### - -function is_valid(v::AbstractVector{<:Real}) - if any(isnan, v) || any(isinf, v) - return false - else - return true + r = r - ϵ/2 * ∇θ # Take a half leapfrog step for momentum variable + ∇r = ∂H∂r(h, r) + θ = θ + ϵ * ∇r # Take a full leapfrog step for position variable + ∇θ = ∂H∂θ(h, θ) + r = r - ϵ/2 * ∇θ # Take a half leapfrog step for momentum variable + z′ = phasepoint(h, θ, r) + !isfinite(z′) && break + z = z′ end + return z end diff --git a/src/sampler.jl b/src/sampler.jl index b05a8a33..5d351888 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -1,29 +1,51 @@ -sample(h::Hamiltonian, prop::AbstractProposal, θ::AbstractVector{T}, n_samples::Int; verbose::Bool=true) where {T<:Real} = - sample(GLOBAL_RNG, h, prop, θ, n_samples; verbose=verbose) +## +## Sampling functions +## -function sample(rng::AbstractRNG, h::Hamiltonian, prop::AbstractProposal, θ::AbstractVector{T}, n_samples::Int; verbose::Bool=true) where {T<:Real} +sample( + h::Hamiltonian, + τ::AbstractProposal, + θ::AbstractVector{T}, + n_samples::Int; + verbose::Bool=true +) where {T<:Real} = sample(GLOBAL_RNG, h, τ, θ, n_samples; verbose=verbose) + +function sample( + rng::AbstractRNG, + h::Hamiltonian, + τ::AbstractProposal, + θ::AbstractVector{T}, + n_samples::Int; + verbose::Bool=true +) where {T<:Real} θs = Vector{Vector{T}}(undef, n_samples) Hs = Vector{T}(undef, n_samples) αs = Vector{T}(undef, n_samples) + r = rand(rng, h.metric) + z = phasepoint(h, θ, r) time = @elapsed for i = 1:n_samples - θs[i], Hs[i], αs[i] = step(rng, h, prop, i == 1 ? θ : θs[i-1]) + z, αs[i] = transition(rng, τ, h, z) + θs[i], Hs[i] = z.θ, neg_energy(z) + z = rand_momentum(rng, z, h) end - verbose && @info "Finished sampling with $time (s)" typeof(h.metric) typeof(prop) EBFMI(Hs) mean(αs) + verbose && @info "Finished sampling with $time (s)" typeof(h.metric) typeof(τ) EBFMI(Hs) mean(αs) return θs end -sample(h::Hamiltonian, - prop::AbstractProposal, +sample( + h::Hamiltonian, + τ::AbstractProposal, θ::AbstractVector{T}, n_samples::Int, adaptor::Adaptation.AbstractAdaptor, n_adapts::Int=min(div(n_samples, 10), 1_000); verbose::Bool=true -) where {T<:Real} = sample(GLOBAL_RNG, h, prop, θ, n_samples, adaptor, n_adapts; verbose=verbose) +) where {T<:Real} = sample(GLOBAL_RNG, h, τ, θ, n_samples, adaptor, n_adapts; verbose=verbose) -function sample(rng::AbstractRNG, +function sample( + rng::AbstractRNG, h::Hamiltonian, - prop::AbstractProposal, + τ::AbstractProposal, θ::AbstractVector{T}, n_samples::Int, adaptor::Adaptation.AbstractAdaptor, @@ -33,36 +55,24 @@ function sample(rng::AbstractRNG, θs = Vector{Vector{T}}(undef, n_samples) Hs = Vector{T}(undef, n_samples) αs = Vector{T}(undef, n_samples) + r = rand(rng, h.metric) + z = phasepoint(h, θ, r) time = @elapsed for i = 1:n_samples - θs[i], Hs[i], αs[i] = step(rng, h, prop, i == 1 ? θ : θs[i-1]) + z, αs[i] = transition(rng, τ, h, z) + θs[i], Hs[i] = z.θ, neg_energy(z) if i <= n_adapts adapt!(adaptor, θs[i], αs[i]) - h, prop = update(h, prop, adaptor) + h, τ = update(h, τ, adaptor) if verbose if i == n_adapts - @info "Finished $n_adapts adapation steps" typeof(adaptor) prop.integrator.ϵ h.metric + @info "Finished $n_adapts adapation steps" typeof(adaptor) τ.integrator.ϵ h.metric elseif i % Int(n_adapts / 10) == 0 - @info "Adapting $i of $n_adapts steps" typeof(adaptor) prop.integrator.ϵ h.metric + @info "Adapting $i of $n_adapts steps" typeof(adaptor) τ.integrator.ϵ h.metric end end end + z = rand_momentum(rng, z, h) end - verbose && @info "Finished $n_samples sampling steps in $time (s)" typeof(h.metric) typeof(prop) EBFMI(Hs) mean(αs) + verbose && @info "Finished $n_samples sampling steps in $time (s)" typeof(h.metric) typeof(τ) EBFMI(Hs) mean(αs) return θs end - -function step(rng::AbstractRNG, - h::Hamiltonian, - prop::AbstractTrajectory{I}, - θ::AbstractVector{T} -) where {T<:Real,I<:AbstractIntegrator} - h = update(h, θ) # Ensure h.metric has the same dim as θ. - r = rand_momentum(rng, h) - θ_new, r_new, α, H_new = transition(rng, prop, h, θ, r) - return θ_new, H_new, α -end - -step(h::Hamiltonian, - p::AbstractTrajectory, - θ::AbstractVector{T} -) where {T<:Real} = step(GLOBAL_RNG, h, p, θ) diff --git a/src/trajectory.jl b/src/trajectory.jl index f2a18aa3..e40cd2e3 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -1,16 +1,18 @@ -#### -#### Hamiltonian dynamics numerical simulation trajectories -#### +## +## Hamiltonian dynamics numerical simulation trajectories +## abstract type AbstractProposal end abstract type AbstractTrajectory{I<:AbstractIntegrator} <: AbstractProposal end -# Create a callback function for all `AbstractTrajectory` without passing random number generator -transition(at::AbstractTrajectory{I}, +# Create a callback function for all `AbstractTrajectory` +# without passing random number generator +transition( + τ::AbstractTrajectory{I}, h::Hamiltonian, θ::AbstractVector{T}, r::AbstractVector{T} -) where {I<:AbstractIntegrator,T<:Real} = transition(GLOBAL_RNG, at, h, θ, r) +) where {I<:AbstractIntegrator,T<:Real} = transition(GLOBAL_RNG, τ, h, θ, r) ### ### Standard HMC implementation with fixed leapfrog step numbers. @@ -20,29 +22,31 @@ struct StaticTrajectory{I<:AbstractIntegrator} <: AbstractTrajectory{I} n_steps :: Int end +""" +Termination (i.e. no-U-turn). +""" +struct Termination end + """ Create a `StaticTrajectory` with a new integrator """ -function (tlp::StaticTrajectory)(integrator::AbstractIntegrator) - return StaticTrajectory(integrator, tlp.n_steps) +function (τ::StaticTrajectory)(integrator::AbstractIntegrator) + return StaticTrajectory(integrator, τ.n_steps) end function transition( rng::AbstractRNG, - prop::StaticTrajectory, + τ::StaticTrajectory, h::Hamiltonian, - θ::AbstractVector{T}, - r::AbstractVector{T} + z::PhasePoint ) where {T<:Real} - H = hamiltonian_energy(h, θ, r) - θ_new, r_new, _ = step(prop.integrator, h, θ, r, prop.n_steps) - H_new = hamiltonian_energy(h, θ_new, r_new) + z′ = step(τ.integrator, h, z, τ.n_steps) # Accept via MH criteria - is_accept, α = mh_accept(rng, H, H_new) + is_accept, α = mh_accept(rng, -neg_energy(z), -neg_energy(z′)) if is_accept - θ, r, H = θ_new, -r_new, H_new + z = PhasePoint(z′.θ, -z′.r, z′.ℓπ, z′.ℓκ) end - return θ, r, α, H + return z, α end abstract type DynamicTrajectory{I<:AbstractIntegrator} <: AbstractTrajectory{I} end @@ -58,21 +62,21 @@ end """ Create a `HMCDA` with a new integrator """ -function (tlp::HMCDA)(integrator::AbstractIntegrator) - return HMCDA(integrator, tlp.λ) +function (τ::HMCDA)(integrator::AbstractIntegrator) + return HMCDA(integrator, τ.λ) end function transition( rng::AbstractRNG, - prop::HMCDA, + τ::HMCDA, h::Hamiltonian, θ::AbstractVector{T}, r::AbstractVector{T} ) where {T<:Real} - # Create the corresponding static prop - n_steps = max(1, round(Int, prop.λ / prop.integrator.ϵ)) - static_prop = StaticTrajectory(prop.integrator, n_steps) - return transition(rng, static_prop, h, θ, r) + # Create the corresponding static τ + n_steps = max(1, round(Int, τ.λ / τ.integrator.ϵ)) + static_τ = StaticTrajectory(τ.integrator, n_steps) + return transition(rng, static_τ, h, θ, r) end @@ -89,9 +93,8 @@ struct NUTS{I<:AbstractIntegrator} <: DynamicTrajectory{I} Δ_max :: AbstractFloat end -""" -Helper function to use default values -""" + +# Helper function to use default values NUTS(integrator::AbstractIntegrator) = NUTS(integrator, 10, 1000.0) """ @@ -111,8 +114,7 @@ function build_tree( rng::AbstractRNG, nt::DynamicTrajectory{I}, h::Hamiltonian, - θ::AbstractVector{T}, - r::AbstractVector{T}, + z::PhasePoint, logu::AbstractFloat, v::Int, j::Int, @@ -120,85 +122,106 @@ function build_tree( ) where {I<:AbstractIntegrator,T<:Real} if j == 0 # Base case - take one leapfrog step in the direction v. - θ′, r′, _is_valid = step(nt.integrator, h, θ, r, v) - H′ = _is_valid ? hamiltonian_energy(h, θ′, r′) : Inf + z′ = step(nt.integrator, h, z, v) + H′ = -neg_energy(z′) n′ = (logu <= -H′) ? 1 : 0 s′ = (logu < nt.Δ_max + -H′) ? 1 : 0 α′ = exp(min(0, H - H′)) - return θ′, r′, θ′, r′, θ′, r′, n′, s′, α′, 1 + return z′, z′, z′, n′, s′, α′, 1 else # Recursion - build the left and right subtrees. - θm, rm, θp, rp, θ′, r′, n′, s′, α′, n′α = build_tree(rng, nt, h, θ, r, logu, v, j - 1, H) + zm, zp, z′, n′, s′, α′, n′α = build_tree(rng, nt, h, z, logu, v, j - 1, H) if s′ == 1 if v == -1 - θm, rm, _, _, θ′′, r′′, n′′, s′′, α′′, n′′α = build_tree(rng, nt, h, θm, rm, logu, v, j - 1, H) + zm, _, z′′, n′′, s′′, α′′, n′′α = build_tree(rng, nt, h, zm, logu, v, j - 1, H) else - _, _, θp, rp, θ′′, r′′, n′′, s′′, α′′, n′′α = build_tree(rng, nt, h, θp, rp, logu, v, j - 1, H) + _, zp, z′′, n′′, s′′, α′′, n′′α = build_tree(rng, nt, h, zp, logu, v, j - 1, H) end if rand(rng) < n′′ / (n′ + n′′) - θ′ = θ′′ - r′ = r′′ + z′ = z′′ end α′ = α′ + α′′ n′α = n′α + n′′α - s′ = s′′ * (dot(θp - θm, ∂H∂r(h, rm)) >= 0 ? 1 : 0) * (dot(θp - θm, ∂H∂r(h, rp)) >= 0 ? 1 : 0) + s′ = s′′ * (dot(zp.θ - zm.θ, ∂H∂r(h, zm.r)) >= 0 ? 1 : 0) * (dot(zp.θ - zm.θ, ∂H∂r(h, zp.r)) >= 0 ? 1 : 0) n′ = n′ + n′′ end - return θm, rm, θp, rp, θ′, r′, n′, s′, α′, n′α + # s: termination stats + # α: MH stats, i.e. sum of MH accept prob for all leapfrog steps + # nα: total # of leap frog steps, i.e. phase points in a trajectory + # n: # of acceptable candicates, i.e. prob is larger than slice variable u + return zm, zp, z′, n′, s′, α′, n′α end end build_tree( nt::DynamicTrajectory{I}, h::Hamiltonian, - θ::AbstractVector{T}, - r::AbstractVector{T}, + z::PhasePoint, logu::AbstractFloat, v::Int, j::Int, H::AbstractFloat -) where {I<:AbstractIntegrator,T<:Real} = build_tree(GLOBAL_RNG, nt, h, θ, r, logu, v, j, H) +) where {I<:AbstractIntegrator,T<:Real} = build_tree(GLOBAL_RNG, nt, h, z, logu, v, j, H) function transition( rng::AbstractRNG, nt::DynamicTrajectory{I}, h::Hamiltonian, - θ::AbstractVector{T}, - r::AbstractVector{T} + z::PhasePoint ) where {I<:AbstractIntegrator,T<:Real} - H = hamiltonian_energy(h, θ, r) + θ, r = z.θ, z.r + H = -neg_energy(z) logu = log(rand(rng)) - H - θm = θ; θp = θ; rm = r; rp = r; j = 0; θ_new = θ; r_new = r; n = 1; s = 1 + zm = z; zp = z; z_new = z; j = 0; n = 1; s = 1 local α, nα while s == 1 && j <= nt.max_depth v = rand(rng, [-1, 1]) if v == -1 - θm, rm, _, _, θ′, r′,n′, s′, α, nα = build_tree(rng, nt, h, θm, rm, logu, v, j, H) + zm, _, z′, n′, s′, α, nα = build_tree(rng, nt, h, zm, logu, v, j, H) else - _, _, θp, rp, θ′, r′,n′, s′, α, nα = build_tree(rng, nt, h, θp, rp, logu, v, j, H) + zm, _, z′, n′, s′, α, nα = build_tree(rng, nt, h, zm, logu, v, j, H) end if s′ == 1 if rand(rng) < min(1, n′ / n) - θ_new = θ′ - r_new = r′ + z_new = z′ end end n = n + n′ - s = s′ * (dot(θp - θm, ∂H∂r(h, rm)) >= 0 ? 1 : 0) * (dot(θp - θm, ∂H∂r(h, rp)) >= 0 ? 1 : 0) + s = s′ * (dot(zp.θ - zm.θ, ∂H∂r(h, zm.r)) >= 0 ? 1 : 0) * (dot(zp.θ - zm.θ, ∂H∂r(h, zp.r)) >= 0 ? 1 : 0) j = j + 1 end - H_new = 0 # Warning: NUTS always return H_new = 0; - return θ_new, r_new, α / nα, H_new + return z_new, α / nα end +transition(nt::DynamicTrajectory{I}, + h::Hamiltonian, + θ::AbstractVector{T}, + r::AbstractVector{T} +) where {I<:AbstractIntegrator,T<:Real} = transition(GLOBAL_RNG, nt, h, θ, r) + + +## +## API: required by Turing.Gibbs +## + +# TODO: rename all `Turing.step` to `transition`? + +# function step(rng::AbstractRNG, h::Hamiltonian, τ::AbstractTrajectory{I}, θ::AbstractVector{T}) where {T<:Real,I<:AbstractIntegrator} +# r = rand(rng, h.metric) +# θ_new, r_new, α, H_new = transition(rng, τ, h, θ, r) +# return θ_new, H_new, α +# end +# +# step(h::Hamiltonian, p::AbstractTrajectory, θ::AbstractVector{T}) where {T<:Real} = step(GLOBAL_RNG, h, p, θ) + ### ### Find for an initial leap-frog step-size via heuristic search. ### @@ -213,11 +236,12 @@ function find_good_eps( a_min, a_cross, a_max = 0.25, 0.5, 0.75 # minimal, crossing, maximal accept ratio d = 2.0 - r = rand_momentum(rng, h) - H = hamiltonian_energy(h, θ, r) + r = rand(rng, h.metric) + z = phasepoint(h, θ, r) + H = -neg_energy(z) - θ′, r′, _is_valid = step(Leapfrog(ϵ), h, θ, r) - H_new = _is_valid ? hamiltonian_energy(h, θ′, r′) : Inf + z′ = step(Leapfrog(ϵ), h, z) + H_new = -neg_energy(z′) ΔH = H - H_new direction = ΔH > log(a_cross) ? 1 : -1 @@ -225,8 +249,8 @@ function find_good_eps( # Crossing step: increase/decrease ϵ until accept ratio cross a_cross. for _ = 1:max_n_iters ϵ′ = direction == 1 ? d * ϵ : 1 / d * ϵ - θ′, r′, _is_valid = step(Leapfrog(ϵ′), h, θ, r) - H_new = _is_valid ? hamiltonian_energy(h, θ′, r′) : Inf + z′ = step(Leapfrog(ϵ′), h, z) + H_new = -neg_energy(z′) ΔH = H - H_new DEBUG && @debug "Crossing step" direction H_new ϵ "α = $(min(1, exp(ΔH)))" @@ -244,8 +268,8 @@ function find_good_eps( ϵ, ϵ′ = ϵ < ϵ′ ? (ϵ, ϵ′) : (ϵ′, ϵ) # ensure ϵ < ϵ′ for _ = 1:max_n_iters ϵ_mid = middle(ϵ, ϵ′) - θ′, r′, _is_valid = step(Leapfrog(ϵ_mid), h, θ, r) - H_new = _is_valid ? hamiltonian_energy(h, θ′, r′) : Inf + z′ = step(Leapfrog(ϵ_mid), h, z) + H_new = -neg_energy(z′) ΔH = H - H_new DEBUG && @debug "Bisection step" H_new ϵ_mid "α = $(min(1, exp(ΔH)))" @@ -269,27 +293,43 @@ find_good_eps( ) where {T<:Real} = find_good_eps(GLOBAL_RNG, h, θ; max_n_iters=max_n_iters) -function mh_accept(rng::AbstractRNG, H::AbstractFloat, H_new::AbstractFloat) - logα = min(0, H - H_new) - return log(rand(rng)) < logα, exp(logα) +function mh_accept( + rng::AbstractRNG, + H::T, + H_new::T +) where {T<:AbstractFloat} + α = min(1.0, exp(H - H_new)) + accept = rand(rng) < α + return accept, α end -mh_accept(H::AbstractFloat, H_new::AbstractFloat) = mh_accept(GLOBAL_RNG, H, H_new) + +mh_accept( + H::T, + H_new::T +) where {T<:AbstractFloat} = mh_accept(GLOBAL_RNG, H, H_new) #### #### Adaption #### -function update(h::Hamiltonian, prop::AbstractProposal, dpc::Adaptation.AbstractPreconditioner) - return h(getM⁻¹(dpc)), prop -end +update( + h::Hamiltonian, + τ::AbstractProposal, + dpc::Adaptation.AbstractPreconditioner +) = h(getM⁻¹(dpc)), τ -function update(h::Hamiltonian, prop::AbstractProposal, da::NesterovDualAveraging) - return h, prop(prop.integrator(getϵ(da))) -end +update( + h::Hamiltonian, + τ::AbstractProposal, + da::NesterovDualAveraging +) = h, τ(τ.integrator(getϵ(da))) -function update(h::Hamiltonian, prop::AbstractProposal, ca::Adaptation.AbstractCompositeAdaptor) - return h(getM⁻¹(ca.pc)), prop(prop.integrator(getϵ(ca.ssa))) -end + +update( + h::Hamiltonian, + τ::AbstractProposal, + ca::Adaptation.AbstractCompositeAdaptor +) = h(getM⁻¹(ca.pc)), τ(τ.integrator(getϵ(ca.ssa))) function update(h::Hamiltonian, θ::AbstractVector{<:Real}) metric = h.metric diff --git a/test/hamiltonian.jl b/test/hamiltonian.jl index aaf16936..8a835ce0 100644 --- a/test/hamiltonian.jl +++ b/test/hamiltonian.jl @@ -1,4 +1,5 @@ using Test, AdvancedHMC +import AdvancedHMC: DualValue, PhasePoint include("common.jl") @@ -6,4 +7,17 @@ include("common.jl") h = Hamiltonian(UnitEuclideanMetric(D), logπ, ∂logπ∂θ) r_init = ones(D) -@test AdvancedHMC.kinetic_energy(h, r_init, θ_init) == D / 2 +@test -AdvancedHMC.neg_energy(h, r_init, θ_init) == D / 2 + + +init_z1() = PhasePoint([NaN], [NaN], DualValue(0.,[0.]), DualValue(0.,[0.])) +init_z2() = PhasePoint([Inf], [Inf], DualValue(0.,[0.]), DualValue(0.,[0.])) + +@test_logs (:warn, "The current proposal will be rejected (due to numerical error(s))...") init_z1() +@test_logs (:warn, "The current proposal will be rejected (due to numerical error(s))...") init_z2() + +z1 = init_z1() +z2 = init_z2() + +@test z1.ℓπ.value == z1.ℓπ.value +@test z1.ℓκ.value == z1.ℓκ.value diff --git a/test/hmc.jl b/test/hmc.jl index b9ae9846..7042e900 100644 --- a/test/hmc.jl +++ b/test/hmc.jl @@ -15,28 +15,28 @@ n_adapts = 2_000 DenseEuclideanMetric(D), ] h = Hamiltonian(metric, logπ, ∂logπ∂θ) - @testset "$(typeof(prop))" for prop in [ + @testset "$(typeof(τ))" for τ in [ StaticTrajectory(Leapfrog(ϵ), n_steps), NUTS(Leapfrog(find_good_eps(h, θ_init))), ] - @info "HMC and NUTS numerical test" typeof(prop) n_samples - samples = sample(h, prop, θ_init, n_samples; verbose=false) + @info "HMC and NUTS numerical test" typeof(τ) n_samples + samples = sample(h, τ, θ_init, n_samples; verbose=false) @test mean(samples[n_adapts+1:end]) ≈ zeros(D) atol=RNDATOL @testset "$(typeof(adaptor))" for adaptor in [ Preconditioner(metric), - NesterovDualAveraging(0.8, prop.integrator.ϵ), + NesterovDualAveraging(0.8, τ.integrator.ϵ), NaiveCompAdaptor( Preconditioner(metric), - NesterovDualAveraging(0.8, prop.integrator.ϵ), + NesterovDualAveraging(0.8, τ.integrator.ϵ), ), StanNUTSAdaptor( n_adapts, Preconditioner(metric), - NesterovDualAveraging(0.8, prop.integrator.ϵ), + NesterovDualAveraging(0.8, τ.integrator.ϵ), ), ] - @info "HMC and NUTS numerical test" typeof(prop) n_samples typeof(adaptor) typeof(metric) n_adapts - samples = sample(h, prop, θ_init, n_samples, adaptor, n_adapts; verbose=false) + @info "HMC and NUTS numerical test" typeof(τ) n_samples typeof(adaptor) typeof(metric) n_adapts + samples = sample(h, τ, θ_init, n_samples, adaptor, n_adapts; verbose=false) @test mean(samples[n_adapts+1:end]) ≈ zeros(D) atol=RNDATOL end end diff --git a/test/integrator.jl b/test/integrator.jl index 176787e6..b54ea5d5 100644 --- a/test/integrator.jl +++ b/test/integrator.jl @@ -6,33 +6,35 @@ lf = Leapfrog(ϵ) θ_init = randn(D) h = Hamiltonian(UnitEuclideanMetric(D), logπ, ∂logπ∂θ) -r_init = AdvancedHMC.rand_momentum(h) +r_init = AdvancedHMC.rand(h.metric) n_steps = 10 @testset "step(::Leapfrog) against steps(::Leapfrog)" begin - θ_step, r_step = copy(θ_init), copy(r_init) + z = AdvancedHMC.phasepoint(h, copy(θ_init), copy(r_init)) + z_step = z t_step = @elapsed for i = 1:n_steps - θ_step, r_step, _ = AdvancedHMC.step(lf, h, θ_step, r_step) + z_step = AdvancedHMC.step(lf, h, z_step) end - t_steps = @elapsed θ_steps, r_steps, _ = AdvancedHMC.step(lf, h, θ_init, r_init, n_steps) + t_steps = @elapsed z_steps = AdvancedHMC.step(lf, h, z, n_steps) @info "Performance of step() v.s. steps()" n_steps t_step t_steps t_step / t_steps - @test θ_step ≈ θ_steps atol=DETATOL - @test r_step ≈ r_steps atol=DETATOL + @test z_step.θ ≈ z_steps.θ atol=DETATOL + @test z_step.r ≈ z_steps.r atol=DETATOL end # using Turing: Inference # @testset "steps(::Leapfrog) against Turing.Inference._leapfrog()" begin +# z = AdvancedHMC.phasepoint(h, θ_init, r_init) # t_Turing = @elapsed θ_Turing, r_Turing, _ = Inference._leapfrog(θ_init, r_init, n_steps, ϵ, x -> (nothing, ∂logπ∂θ(x))) -# t_AHMC = @elapsed θ_AHMC, r_AHMC, _ = AdvancedHMC.step(lf, h, θ_init, r_init, n_steps) +# t_AHMC = @elapsed z_AHMC = AdvancedHMC.step(lf, h, z, n_steps) # @info "Performance of leapfrog of AdvancedHMC v.s. Turing" n_steps t_Turing t_AHMC t_Turing / t_AHMC # -# @test θ_Turing ≈ θ_AHMC atol=DETATOL -# @test r_Turing ≈ r_AHMC atol=DETATOL +# @test θ_Turing ≈ z_AHMC.θ atol=DETATOL +# @test r_Turing ≈ z_AHMC.r atol=DETATOL # end using LinearAlgebra: dot @@ -47,19 +49,20 @@ using Statistics: mean q_init = randn(D) h = Hamiltonian(UnitEuclideanMetric(D), negU, ∂negU∂q) - p_init = AdvancedHMC.rand_momentum(h) + p_init = AdvancedHMC.rand(h.metric) q, p = copy(q_init), copy(p_init) + z = AdvancedHMC.phasepoint(h, q, p) n_steps = 10_000 qs = zeros(n_steps) ps = zeros(n_steps) Hs = zeros(n_steps) for i = 1:n_steps - q, p, _ = AdvancedHMC.step(lf, h, q, p) - qs[i] = q[1] - ps[i] = p[1] - Hs[i] = AdvancedHMC.hamiltonian_energy(h, q, p) + z = AdvancedHMC.step(lf, h, z) + qs[i] = z.θ[1] + ps[i] = z.r[1] + Hs[i] = -AdvancedHMC.neg_energy(z) end # Throw first 1_000 steps diff --git a/test/proposal.jl b/test/proposal.jl index 73882401..f9a9fcc7 100644 --- a/test/proposal.jl +++ b/test/proposal.jl @@ -7,18 +7,20 @@ lf = Leapfrog(ϵ) θ_init = randn(D) h = Hamiltonian(UnitEuclideanMetric(D), logπ, ∂logπ∂θ) -prop = NUTS(Leapfrog(find_good_eps(h, θ_init))) -r_init = AdvancedHMC.rand_momentum(h) +τ = NUTS(Leapfrog(find_good_eps(h, θ_init))) +r_init = AdvancedHMC.rand(h.metric) @testset "Passing random number generator" begin for seed in [1234, 5678, 90] rng = MersenneTwister(seed) - θ1, r1 = AdvancedHMC.transition(rng, prop, h, θ_init, r_init) + z = AdvancedHMC.phasepoint(h, θ_init, r_init) + z1′, _ = AdvancedHMC.transition(rng, τ, h, z) rng = MersenneTwister(seed) - θ2, r2 = AdvancedHMC.transition(rng, prop, h, θ_init, r_init) + z = AdvancedHMC.phasepoint(h, θ_init, r_init) + z2′, _ = AdvancedHMC.transition(rng, τ, h, z) - @test θ1 == θ2 - @test r1 == r2 + @test z1′.θ == z2′.θ + @test z1′.r == z2′.r end end