Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type design for PhasePoint and DualValue. #51

Merged
merged 56 commits into from
May 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
a8166ea
add Parameters to REQUIRE
yebai Apr 16, 2019
e596c03
Draft type design for PhasePoint and LogDensity.
yebai Apr 16, 2019
07ad3bd
Merge branch 'master' into hg/some-type-refactoring
yebai Apr 17, 2019
9c3be85
Renamed some types.
yebai Apr 17, 2019
8134f50
Added cache for kinect energy.
yebai Apr 17, 2019
01adeb4
Added some comments.
yebai Apr 17, 2019
d464a74
Merge `Sampler.step` into `Sampler.sample`.
yebai Apr 17, 2019
90804f2
Typo fix.
yebai Apr 17, 2019
7d09b1d
Moved `step` function into trajectory.
yebai Apr 17, 2019
a121843
Fix leftover bug after merging `sampler.step`.
yebai Apr 17, 2019
fdfb80b
Added types for MHStats, and NUTS.
yebai Apr 18, 2019
cb277af
Relaxed type constraint for PhasePoint.
yebai Apr 18, 2019
e794add
Add dependency on ArgCheck.
yebai Apr 18, 2019
ffb3327
Switch to `PhasePoint` .
yebai Apr 18, 2019
6216579
Use `PhasePoint` in `transition`.
yebai Apr 18, 2019
1a8bf7e
Merge branch 'master' into hg/some-type-refactoring
yebai Apr 19, 2019
c0c1130
Merge branch 'master' into hg/some-type-refactoring
yebai Apr 20, 2019
0ce89ea
Create REQUIRE
yebai Apr 20, 2019
b5829b3
Update Project.toml
yebai Apr 20, 2019
e330bbe
Update REQUIRE
yebai Apr 20, 2019
512c2ec
Update Project.toml
yebai Apr 20, 2019
76d14e5
Merge branch 'master' into hg/some-type-refactoring
yebai Apr 24, 2019
06b7e81
Fix dependency.
yebai Apr 24, 2019
652854c
RM Manifest.toml file.
yebai Apr 24, 2019
52384a4
Added numerical checking to `PhasePoint`.
yebai Apr 24, 2019
3ff7076
Typo fix.
yebai Apr 24, 2019
5d6856c
Merge branch 'master' into hg/some-type-refactoring
yebai May 13, 2019
875404b
Re-formatting - no functionality change.
yebai May 13, 2019
11e9cdb
Re-formating - no functionality change.
yebai May 13, 2019
6725144
Removed some obsolete types.
yebai May 13, 2019
d879b95
Re-formating - no functionality change.
yebai May 13, 2019
2edfd4c
Remove `n_valid`.
yebai May 13, 2019
486dc38
Refactor leapfrog step function.
yebai May 16, 2019
7f7bed2
Cleanup.
yebai May 16, 2019
60dea85
More refactoring in rand momentum.
yebai May 16, 2019
fb935d6
Cleanup.
yebai May 16, 2019
596420b
Removed some comments.
yebai May 17, 2019
9ceef64
More cleanup.
yebai May 17, 2019
56766d6
Refactor: use `phasepoint` in `build_tree`.
yebai May 17, 2019
b56a3ca
Cleanup.
yebai May 17, 2019
e5375fe
Cleanup.
yebai May 17, 2019
ae3dfbd
Rename `prop` to `τ`.
yebai May 17, 2019
ef162d1
Cleanup.
yebai May 17, 2019
68f733a
Bugfix.
yebai May 17, 2019
be734f0
Fixed some type issues.
yebai May 18, 2019
0f04825
Re-formatting - no functionality change.
yebai May 18, 2019
c7c0442
Minor tweak of numerical error message.
yebai May 18, 2019
067b88d
`kinetic_energy` ==> `neg_energy`
yebai May 18, 2019
aa484a0
`potential_energy` ==> `neg_energy` and bugfixes.
yebai May 18, 2019
16599d4
Re-formatting - no functionality change.
yebai May 18, 2019
6f01571
Rename `logπ` and `logκ` with `ℓπ`, `ℓκ`.
yebai May 18, 2019
435ac08
Cleanup + bugfix.
yebai May 18, 2019
c6616d7
`∂logπ∂θ` ==> `∂ℓπ∂θ`.
yebai May 18, 2019
32115a5
Cleanup - no functionality change.
yebai May 18, 2019
502c9ef
Removed an obsolete comment.
yebai May 18, 2019
0d4aed3
test @warn in the constructor of PhasePoint
xukai92 May 20, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 0 additions & 164 deletions Manifest.toml

This file was deleted.

3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
authors = ["The Turing Team"]
version = "0.1.5"

[deps]
Expand All @@ -10,6 +9,8 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"

[extras]
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand Down
2 changes: 2 additions & 0 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ using Statistics: mean, var, middle
using LinearAlgebra: Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky
using LazyArrays: BroadcastArray
using Random: GLOBAL_RNG, AbstractRNG
using Parameters: @unpack
using ArgCheck: @argcheck

import StatsBase: sample

Expand Down
3 changes: 2 additions & 1 deletion src/adaptation/Adaptation.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Adaptation

import Base: string
import Base: string, rand
using Random: GLOBAL_RNG, AbstractRNG
using LinearAlgebra: Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky
import LinearAlgebra, Statistics
using ..AdvancedHMC: DEBUG
Expand Down
29 changes: 29 additions & 0 deletions src/adaptation/precond.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,35 @@ function Base.getproperty(dem::DenseEuclideanMetric, d::Symbol)
return d === :dim ? size(getfield(dem, :M⁻¹), 1) : getfield(dem, d)
end

# `rand` functions for `metric` types.
function Base.rand(
rng::AbstractRNG,
metric::UnitEuclideanMetric
)
r = randn(rng, metric.dim)
return r
end

function Base.rand(
rng::AbstractRNG,
metric::DiagEuclideanMetric
)
r = randn(rng, metric.dim)
r ./= metric.sqrtM⁻¹
return r
end

function Base.rand(
rng::AbstractRNG,
metric::DenseEuclideanMetric
)
r = randn(rng, metric.dim)
ldiv!(metric.cholM⁻¹, r)
return r
end

Base.rand(metric::AbstractMetric) = rand(GLOBAL_RNG, metric)

####
#### Preconditioner constructors
####
Expand Down
116 changes: 78 additions & 38 deletions src/hamiltonian.jl
Original file line number Diff line number Diff line change
@@ -1,60 +1,100 @@
# TODO: add a type for kinetic energy
# TODO: add cache for gradients by letting ∂logπ∂θ return both log density and gradient

struct Hamiltonian{M<:AbstractMetric, Tlogπ, T∂logπ∂θ}
metric::M
logπ::Tlogπ
logπ∂θ::T∂logπ∂θ
ℓπ::Tlogπ
ℓπ∂θ::T∂logπ∂θ
end

# Create a `Hamiltonian` with a new `M⁻¹`
(h::Hamiltonian)(M⁻¹) = Hamiltonian(h.metric(M⁻¹), h.logπ, h.∂logπ∂θ)
(h::Hamiltonian)(M⁻¹) = Hamiltonian(h.metric(M⁻¹), h.ℓπ, h.∂ℓπ∂θ)

∂H∂θ(h::Hamiltonian, θ::AbstractVector) = -h.∂logπ∂θ(θ)
∂H∂θ(h::Hamiltonian, θ::AbstractVector) = -h.∂ℓπ∂θ(θ)

∂H∂r(h::Hamiltonian{<:UnitEuclideanMetric}, r::AbstractVector) = copy(r)
∂H∂r(h::Hamiltonian{<:DiagEuclideanMetric}, r::AbstractVector) = h.metric.M⁻¹ .* r
∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric}, r::AbstractVector) = h.metric.M⁻¹ * r

function hamiltonian_energy(h::Hamiltonian, θ::AbstractVector, r::AbstractVector)
K = kinetic_energy(h, r, θ)
if isnan(K)
K = Inf
@warn "Kinetic energy is `NaN` and is set to `Inf`."
end
V = potential_energy(h, θ)
if isnan(V)
V = Inf
@warn "Potential energy is `NaN` and is set to `Inf`."
struct DualValue{Tv<:AbstractFloat, Tg<:AbstractVector{Tv}}
value::Tv # Cached value, e.g. logπ(θ).
gradient::Tg # Cached gradient, e.g. ∇logπ(θ).
end

struct PhasePoint{T<:AbstractVector, V<:DualValue}
θ::T # Position variables / model parameters.
r::T # Momentum variables
ℓπ::V # Cached neg potential energy for the current θ.
ℓκ::V # Cached neg kinect energy for the current r.
function PhasePoint(θ::T, r::T, ℓπ::V, ℓκ::V) where {T,V}
yebai marked this conversation as resolved.
Show resolved Hide resolved
@argcheck length(θ) == length(r) == length(ℓπ.gradient) == length(ℓπ.gradient)
# if !(all(isfinite, θ) && all(isfinite, r) && all(isfinite, ℓπ) && all(isfinite, ℓκ))
if !(isfinite(θ) && isfinite(r) && isfinite(ℓπ) && isfinite(ℓκ))
@warn "The current proposal will be rejected (due to numerical error(s))..."
ℓκ = DualValue(-Inf, ℓκ.gradient)
ℓπ = DualValue(-Inf, ℓπ.gradient)
end
new{T,V}(θ, r, ℓπ, ℓκ)
end
return K + V
end

potential_energy(h::Hamiltonian, θ::AbstractVector) = -h.logπ(θ)
phasepoint(
h::Hamiltonian,
θ::T,
r::T;
ℓπ = DualValue(neg_energy(h, r, θ), ∂H∂θ(h, θ)),
ℓκ = DualValue(neg_energy(h, θ), ∂H∂r(h, r))
) where {T<:AbstractVector} = PhasePoint(θ, r, ℓπ, ℓκ)


Base.isfinite(v::DualValue) = all(isfinite, v.value) && all(isfinite, v.gradient)
Base.isfinite(v::AbstractVector) = all(isfinite, v)
Base.isfinite(z::PhasePoint) = isfinite(z.ℓπ) && isfinite(z.ℓκ)

###
### Negative energy (or log probability) functions.
### NOTE: the general form (i.e. non-Euclidean) of K depends on both θ and r.
###

neg_energy(z::PhasePoint) = z.ℓπ.value + z.ℓκ.value

# Kinetic energy
# NOTE: the general form of K depends on both θ and r
kinetic_energy(h::Hamiltonian{<:UnitEuclideanMetric}, r, θ) = sum(abs2, r) / 2
function kinetic_energy(h::Hamiltonian{<:DiagEuclideanMetric}, r, θ)
return sum(abs2(r[i]) * h.metric.M⁻¹[i] for i in 1:length(r)) / 2
neg_energy(h::Hamiltonian, θ::AbstractVector) = h.ℓπ(θ)

neg_energy(
h::Hamiltonian{<:UnitEuclideanMetric},
r::T,
θ::T
) where {T<:AbstractVector} = -sum(abs2, r) / 2

function neg_energy(
h::Hamiltonian{<:DiagEuclideanMetric},
r::T,
θ::T
) where {T<:AbstractVector}
_r = [abs2(r[i]) * h.metric.M⁻¹[i] for i in 1:length(r)]
return -sum(_r) / 2
end
function kinetic_energy(h::Hamiltonian{<:DenseEuclideanMetric}, r, θ)

function neg_energy(
h::Hamiltonian{<:DenseEuclideanMetric},
r::T,
θ::T
) where {T<:AbstractVector}
mul!(h.metric._temp, h.metric.M⁻¹, r)
return dot(r, h.metric._temp) / 2
return -dot(r, h.metric._temp) / 2
end

# Momentum sampler
function rand_momentum(rng::AbstractRNG, h::Hamiltonian{<:UnitEuclideanMetric})
return randn(rng, h.metric.dim)
end
function rand_momentum(rng::AbstractRNG, h::Hamiltonian{<:DiagEuclideanMetric})
r = randn(rng, h.metric.dim)
r ./= h.metric.sqrtM⁻¹
return r
end
function rand_momentum(rng::AbstractRNG, h::Hamiltonian{<:DenseEuclideanMetric})
r = randn(rng, h.metric.dim)
ldiv!(h.metric.cholM⁻¹, r)
return r
end
####
#### Momentum sampler
####

rand_momentum(
rng::AbstractRNG,
z::PhasePoint,
h::Hamiltonian
) = phasepoint(h, z.θ, rand(rng, h.metric))

rand_momentum(h::Hamiltonian) = rand_momentum(GLOBAL_RNG, h)
rand_momentum(
z::PhasePoint,
h::Hamiltonian
) = phasepoint(h, z.θ, rand(GLOBAL_RNG, h.metric))
Loading