## MCX's philosophy
3. Inference should be performant. Sequential inference should be a first-class
   citizen.

See the [documentation](https://rlouf.github.io/mcx) for more information. See [this issue](https://github.com/rlouf/mcx/issues/1) for an updated roadmap for v0.1.

## Current API

Note that there are still many moving pieces in `mcx` and the API may change
slightly.

```python
import jax
from jax import numpy as np
import mcx
import mcx.distributions as dist

x_data = np.array([2.3, 8.2, 1.8])
y_data = np.array([1.7, 7., 3.1])
rng_key = jax.random.PRNGKey(0)
observations = {'x': x_data, 'predictions': y_data, 'lmbda': 3.}

@mcx.model
def linear_regression(x, lmbda=1.):
    scale <~ dist.Exponential(lmbda)
    coefs <~ dist.Normal(np.zeros(x.shape[-1]))
    y = np.dot(x, coefs)
    predictions <~ dist.Normal(y, scale)
    return predictions

kernel = mcx.HMC(100)
sampler = mcx.sampler(
    rng_key,
    linear_regression,
    kernel,
    **observations
)
posterior = sampler.run()
```

## MCX's future

MCX's core is very flexible, so we can start considering the following
applications:

- **Neural network layers:** You can follow discussions about the API in [this Pull Request](https://github.com/rlouf/mcx/pull/16).
- **Programs with stochastic support:** Discussion in this [Issue](https://github.com/rlouf/mcx/issues/37).
- **Tools for causal inference:** Made easier by the internal representation as a
graph.

You are more than welcome to contribute to these discussions, or suggest
potential future directions.


## Batch sampling

Like most PPLs, MCX implements a batch sampling runtime:

```python
sampler = mcx.sampler(
    rng_key,
    linear_regression,
    kernel,
    **observations
)
posterior = sampler.run()
```

The warmup trace is discarded by default but you can obtain it by running:

```python
warmup_posterior = sampler.warmup()
posterior = sampler.run()
```

You can extract more samples from the chain after a run and combine the
two traces:

```python
new_posterior = sampler.run()
final_posterior = posterior + new_posterior
```

## Currently implemented
By default MCX samples using a Python `for` loop and displays a progress bar.
For faster sampling (but without the progress bar) you can use:

```python
posterior = sampler.run(accelerate=True)
```

See [this issue](https://github.com/rlouf/mcx/issues/1) for an updated roadmap for v0.1.
You could combine the two in a notebook: first get a lower bound on the
sampling rate with the default loop, then decide on a number of samples for the
accelerated run.
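
A minimal sketch of that workflow, reusing the `sampler` built above. Note that
`run()` accepting a `num_samples` argument is an assumption here, not confirmed
`mcx` API:

```python
import time

# Short pilot run with the interpreted loop to estimate the sampling
# rate (assumes `run(num_samples=...)` exists; check the signature).
start = time.time()
pilot = sampler.run(num_samples=1_000)
rate = 1_000 / (time.time() - start)
print(f"lower bound: ~{rate:.0f} samples/s")

# Use that estimate to pick a budget for the long, accelerated run.
posterior = sampler.run(num_samples=50_000, accelerate=True)
```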

## Iterative sampling

Sampling the posterior is an iterative process. Yet most libraries only provide
batch sampling. The generator runtime is already implemented in `mcx`, which
opens many possibilities such as:

- Dynamic interruption of inference (say after getting a set number of
  effective samples; see the sketch after the code block below);
- Real-time monitoring of inference with something like tensorboard;
- Easier debugging.

```python
samples = mcx.iterative_sampler(
    rng_key,
    linear_regression,
    kernel,
    **observations
)

for sample in samples:
    print(sample)
```
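
As a sketch of the first possibility above: accumulate draws and stop once a
diagnostic crosses a threshold. `effective_sample_size` below is a hypothetical
stand-in, not part of `mcx`; substitute a real estimator such as `arviz.ess`:

```python
def effective_sample_size(draws):
    # Crude placeholder: replace with a real ESS estimate in practice.
    return len(draws) / 10

# Dynamically interrupt inference once we have enough effective samples.
collected = []
for sample in samples:
    collected.append(sample)
    if len(collected) % 1_000 == 0:
        if effective_sample_size(collected) >= 500:
            break
```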

### Note

`sampler` and `iterative_sampler` share a very similar API and philosophy; they
will likely be merged before the 0.1 release:

```python
sampler = mcx.sampler(
    rng_key,
    linear_regression,
    kernel,
    **observations
)

posterior = sampler.run()

for sample in sampler:
    print(sample)
```

so it is possible to switch between the two execution modes seamlessly.

## Sequential sampling

One of Bayesian statistics' promises is the ability to update one's knowledge as
more data becomes available. In practice, few libraries allow this in a
straightforward way. Sequential inference is useful in at least two types of
application:

- Training models with data that does not fit in memory. For deep models,
obviously, but not necessarily;
- Training models where data progressively arrives. Think A/B testing for
  instance, where we need to update our knowledge as more users arrive.

Sequential Markov Chain Monte Carlo is already implemented in `mcx`. However,
more work is needed to diagnose the obtained samples and possibly stop sampling
dynamically.

```python
sampler = mcx.sequential_sampler(
    rng_key,
    linear_regression,
    kernel,
    **observations
)
posterior = sampler.run()

updated_posterior = sampler.update(posterior, **new_observations)
```
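
As a sketch of the A/B testing scenario above: fold in new data as it arrives,
reusing the previous posterior as the starting point for each update. The
`daily_batches` iterable is hypothetical, a stream of observation dicts:

```python
# Update the posterior batch by batch, e.g. once per day of an A/B test.
# `daily_batches` is a hypothetical stream of dicts with the same keys
# as `observations`.
posterior = sampler.run()
for batch in daily_batches:
    posterior = sampler.update(posterior, **batch)
```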


## Important note

MCX builds atop the excellent ideas that have come up in the past 10 years of
probabilistic programming, whether from Stan (NUTS and the very knowledgeable
community), PyMC3 & PyMC4 (for their simple API), TensorFlow Probability (for
its shape system and inference vectorization), (Num)Pyro (for the use of JAX in
the backend), Gen.jl and Turing.jl (for composable inference), Soss.jl
(generative model API), Anglican, and many that I forget.
