Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add getparameters and setparameters!! #86

Merged
merged 24 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
e2bdfb7
added state_from_transition, parameters and setparameters!!
torfjelde Oct 21, 2021
7fa8de0
Update src/AbstractMCMC.jl
torfjelde Oct 23, 2021
0a4fd17
renamed state_from_transition to updatestate!!
torfjelde Nov 17, 2021
28bdf91
adhere to julia convention
torfjelde Nov 17, 2021
86a7826
added docs
torfjelde Nov 17, 2021
e19cea7
fixed docs
torfjelde Nov 17, 2021
d86499f
fixed docs
torfjelde Nov 17, 2021
bce436d
added example for why updatestate!! is useful
torfjelde Nov 17, 2021
21f4d56
improved MixtureState example
torfjelde Nov 17, 2021
de0e5b2
further improvements to docs
torfjelde Nov 17, 2021
23b9119
renamed parameters and setparameters!! to values and setvalues!!
torfjelde Nov 19, 2021
b9f476c
fixed typo in docs
torfjelde Nov 19, 2021
f7b6096
fixed documenting values
torfjelde Nov 19, 2021
4ca57b0
improved and fixed some bugs in docs
torfjelde Nov 19, 2021
abebd59
fixed typo in docs
torfjelde Nov 19, 2021
d1d4642
renamed values and setvalues!! to realize and realize!!
torfjelde Dec 7, 2021
c6c9554
added model to updatestate!!
torfjelde Dec 7, 2021
d9f8585
Merge branch 'master' into tor/state-transition-related
torfjelde Oct 24, 2023
600d36c
Apply suggestions from code review
torfjelde Oct 10, 2024
1bfbef1
Update docs/src/api.md
torfjelde Oct 10, 2024
d9480d1
Apply suggestions from code review
torfjelde Oct 10, 2024
3f861bf
Merge branch 'master' into tor/state-transition-related
torfjelde Oct 10, 2024
ddb588c
Update docs/src/api.md
torfjelde Oct 10, 2024
d6ab10a
version bump
sunxd3 Oct 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,181 @@ For chains of this type, AbstractMCMC defines the following two methods.
AbstractMCMC.chainscat
AbstractMCMC.chainsstack
```

## Interacting with states of samplers

To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods:
```@docs
AbstractMCMC.realize
AbstractMCMC.realize!!
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
```
and optionally
```@docs
AbstractMCMC.updatestate!!(state, transition, state_prev)
```
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
These methods can also be useful for implementing samplers which wraps some inner samplers, e.g. a mixture of samplers.

### Example: `MixtureSampler`

In a `MixtureSampler` we need two things:
- `components`: collection of samplers.
- `weights`: collection of weights representing the probability of chosing the corresponding sampler.
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

```julia
struct MixtureSampler{W,C} <: AbstractMCMC.AbstractSampler
components::C
weights::W
end
```

To implement the state, we need to keep track of a couple of things:
- `index`: the index of the sampler used in this `step`.
- `transition`: the transition resulting from this `step`.
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
- `states`: the current states of _all_ the components.
Two aspects of this might seem a bit strange:
1. We need to keep track of the states of _all_ components rather than just the state for the sampler we used previously.
2. We need to put the `transition` from the `step` into the state.

The reason for (1) is that lots of samplers keep track of more than just the previous realizations of the variables, e.g. in `AdvancedHMC.jl` we keep track of the momentum used, the metric used, etc.
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

For (2) the reason is similar: some samplers might keep track of the variables _in the state_ differently, e.g. you might have a sampler which is _independent_ of the current realizations and the state is simply `nothing`.
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

Hence, we need the `transition`, which should always contain the realizations, to make sure we can resume from the same point in the space in the next `step`.
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
```julia
struct MixtureState{T,S}
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
index::Int
transition::T
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
states::S
end
```
The `step` for a `MixtureSampler` is defined by the following generative process
```math
\begin{aligned}
i &\sim \mathrm{Categorical}(w_1, \dots, w_k) \\
X_t &\sim \mathcal{K}_i(\cdot \mid X_{t - 1})
\end{aligned}
```
where ``\mathcal{K}_i`` denotes the i-th kernel/sampler, and ``w_i`` denotes the weight/probability of choosing the i-th sampler.
[`AbstractMCMC.updatestate!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler.
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

If we let `state` be the current `MixtureState`, `i` the current component, and `i_prev` is the previous component we sampled from, then this translates into the following piece of code:

```julia
# Update the corresponding state, i.e. `state.states[i]`, using
# the state and transition from the previous iteration.
state_current = AbstractMCMC.updatestate!!(
state.states[i], state.states[i_prev], state.transition
)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

# Take a `step` for this sampler using the updated state.
transition, state_current = AbstractMCMC.step(
rng, model, sampler_current, sampler_state;
kwargs...
)
```

The full [`AbstractMCMC.step`](@ref) implementation would then be something like:

```julia
function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler, state; kwargs...)
# Sample the component to use in this `step`.
i = rand(Categorical(sampler.weights))
sampler_current = sampler.components[i]

# Update the corresponding state, i.e. `state.states[i]`, using
# the state and transition from the previous iteration.
i_prev = state.index
state_current = AbstractMCMC.updatestate!!(
model, state.states[i], state.states[i_prev], state.transition
)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

# Take a `step` for this sampler using the updated state.
transition, state_current = AbstractMCMC.step(
rng, model, sampler_current, state_current;
kwargs...
)

# Create the new states.
# NOTE: Code below will result in `states_new` being a `Vector`.
# If we wanted to allow usage of alternative containers, e.g. `Tuple`,
# it would be better to use something like `@set states[i] = state_current`
# where `@set` is from Setfield.jl.
states_new = map(1:length(state.states)) do j
if j == i
# Replace the i-th state with the new one.
state_current
else
# Otherwise we just carry over the previous ones.
state.states[j]
end
end

# Create the new `MixtureState`.
state_new = MixtureState(i, transition, states_new)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

return transition, state_new
end
```

And for the initial [`AbstractMCMC.step`](@ref) we have:

```julia
function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler; kwargs...)
# Initialize every state.
transitions_and_states = map(sampler.components) do spl
AbstractMCMC.step(rng, model, spl; kwargs...)
end

# Sample the component to use this `step`.
i = rand(Categorical(sampler.weights))
# Extract the corresponding transition.
transition = first(transitions_and_states[i])
# Extract states.
states = map(last, transitions_and_states)
# Create new `MixtureState`.
state = MixtureState(i, transition, states)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

return transition, state
end
```

Suppose we then wanted to use this with some of the packages which implements AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl), then we'd simply have to implement `realize` and `realize!!`:
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

```julia
function AbstractMCMC.updatestate!!(model, ::AdvancedMH.Transition, state_prev::AdvancedMH.Transition)
# Let's `deepcopy` just to be certain.
return deepcopy(state_prev)
end
```
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

To use `MixtureSampler` with two samplers `sampler1` and `sampler2` from `AdvancedMH.jl` as components, we'd simply do

```julia
sampler = MixtureSampler([sampler1, sampler2], [0.1, 0.9])
transition, state = AbstractMCMC.step(rng, model, sampler)
while ...
transition, state = AbstractMCMC.step(rng, model, sampler, state)
end
```

As a final note, there is one potential issue we haven't really addressed in the above implementation: a lot of samplers have their own implementations of `AbstractMCMC.AbstractModel` which means that we would also have to ensure that all the different samplers we are using would be compatible with the same model. A very easy way to fix this would be to just add a struct called `ManyModels` supporting `getindex`, e.g. `models[i]` would return the i-th `model`:

```julia
struct ManyModels{M} <: AbstractMCMC.AbstractModel
models::M
end

Base.getindex(model::ManyModels, I...) = model.models[I...]
```

Then the above `step` would just extract the `model` corresponding to the current sampler:

```julia
# Take a `step` for this sampler using the updated state.
transition, state_current = AbstractMCMC.step(
rng, model[i], sampler_current, state_current;
kwargs...
)
```

This issue should eventually disappear as the community moves towards a unified approach to implement `AbstractMCMC.AbstractModel`.
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
30 changes: 30 additions & 0 deletions src/AbstractMCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,36 @@
"""
struct MCMCSerial <: AbstractMCMCEnsemble end

"""
updatestate!!(model, state, transition_prev[, state_prev])

Return new instance of `state` using information from `model`, `transition_prev` and, optionally, `state_prev`.

Defaults to `realize!!(state, realize(transition_prev))`.
"""
function updatestate!!(model, state, transition_prev, state_prev)
return updatestate!!(state, transition_prev)

Check warning on line 90 in src/AbstractMCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractMCMC.jl#L89-L90

Added lines #L89 - L90 were not covered by tests
end
updatestate!!(model, state, transition) = realize!!(state, realize(transition))

Check warning on line 92 in src/AbstractMCMC.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractMCMC.jl#L92

Added line #L92 was not covered by tests

"""
realize!!(state, realization)

Update the realization of the `state` with `realization` and return it.

If `state` can be updated in-place, it is expected that this function returns `state` with updated
realize. Otherwise a new `state` object with the new `realization` is returned.
"""
function realize!! end

"""
realize(transition)

Return the realization of the random variables present in `transition`.
"""
function realize end

torfjelde marked this conversation as resolved.
Show resolved Hide resolved

torfjelde marked this conversation as resolved.
Show resolved Hide resolved
include("samplingstats.jl")
include("logging.jl")
include("interface.jl")
Expand Down
Loading