Skip to content

Commit

Permalink
start porting tests in ArviZ
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Mar 10, 2021
1 parent ed76e65 commit a7d52c7
Show file tree
Hide file tree
Showing 5 changed files with 706 additions and 61 deletions.
7 changes: 6 additions & 1 deletion pymc3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ def __set_compiler_flags():

from pymc3 import gp, ode, sampling
from pymc3.aesaraf import *
from pymc3.backends import load_trace, save_trace
from pymc3.backends import (
load_trace,
predictions_to_inference_data,
save_trace,
to_inference_data,
)
from pymc3.backends.tracetab import *
from pymc3.blocking import *
from pymc3.data import *
Expand Down
1 change: 1 addition & 0 deletions pymc3/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
Saved backends can be loaded using `arviz.from_netcdf`
"""
from pymc3.backends.arviz import predictions_to_inference_data, to_inference_data
from pymc3.backends.ndarray import (
NDArray,
load_trace,
Expand Down
58 changes: 56 additions & 2 deletions pymc3/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@
import numpy as np
import xarray as xr

from aesara.gof.graph import ancestors
from aesara.graph.basic import ancestors
from aesara.tensor.var import TensorVariable
from arviz import InferenceData, concat, rcParams
from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires

import pymc3

from pymc3.model import modelcontext
from pymc3.sampling import _DefaultTrace
from pymc3.util import get_default_varnames

if TYPE_CHECKING:
Expand All @@ -42,6 +41,61 @@
Var = Any # pylint: disable=invalid-name


class _DefaultTrace:
"""
Utility for collecting samples into a dictionary.
Name comes from its similarity to ``defaultdict``:
entries are lazily created.
Parameters
----------
samples : int
The number of samples that will be collected, per variable,
into the trace.
Attributes
----------
trace_dict : Dict[str, np.ndarray]
A dictionary constituting a trace. Should be extracted
after a procedure has filled the `_DefaultTrace` using the
`insert()` method
"""

trace_dict: Dict[str, np.ndarray] = {}
_len: Optional[int] = None

def __init__(self, samples: int):
self._len = samples
self.trace_dict = {}

def insert(self, k: str, v, idx: int):
"""
Insert `v` as the value of the `idx`th sample for the variable `k`.
Parameters
----------
k: str
Name of the variable.
v: anything that can go into a numpy array (including a numpy array)
The value of the `idx`th sample from variable `k`
ids: int
The index of the sample we are inserting into the trace.
"""
value_shape = np.shape(v)

# initialize if necessary
if k not in self.trace_dict:
array_shape = (self._len,) + value_shape
self.trace_dict[k] = np.empty(array_shape, dtype=np.array(v).dtype)

# do the actual insertion
if value_shape == ():
self.trace_dict[k][idx] = v
else:
self.trace_dict[k][idx, :] = v


class InferenceDataConverter: # pylint: disable=too-many-instance-attributes
"""Encapsulate InferenceData specific logic."""

Expand Down
61 changes: 3 additions & 58 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

import aesara
import aesara.gradient as tg
import arviz
import numpy as np
import packaging
import xarray
Expand All @@ -38,6 +37,7 @@
import pymc3 as pm

from pymc3.aesaraf import inputvars
from pymc3.backends.arviz import _DefaultTrace
from pymc3.backends.base import BaseTrace, MultiTrace
from pymc3.backends.ndarray import NDArray
from pymc3.blocking import DictToArrayBijection
Expand Down Expand Up @@ -344,7 +344,7 @@ def sample(
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False)
Defaults to `False`, but we'll switch to `True` in an upcoming release.
idata_kwargs : dict, optional
Keyword arguments for :func:`arviz:arviz.from_pymc3`
Keyword arguments for :func:`pymc3.to_inference_data`
mp_ctx : multiprocessing.context.BaseContent
A multiprocessing context for parallel sampling. See multiprocessing
documentation for details.
Expand Down Expand Up @@ -639,7 +639,7 @@ def sample(
ikwargs = dict(model=model, save_warmup=not discard_tuned_samples, log_likelihood=False)
if idata_kwargs:
ikwargs.update(idata_kwargs)
idata = arviz.from_pymc3(trace, **ikwargs)
idata = pm.to_inference_data(trace, **ikwargs)

if compute_convergence_checks:
if draws - tune < 100:
Expand Down Expand Up @@ -1546,61 +1546,6 @@ def stop_tuning(step):
return step


class _DefaultTrace:
"""
Utility for collecting samples into a dictionary.
Name comes from its similarity to ``defaultdict``:
entries are lazily created.
Parameters
----------
samples : int
The number of samples that will be collected, per variable,
into the trace.
Attributes
----------
trace_dict : Dict[str, np.ndarray]
A dictionary constituting a trace. Should be extracted
after a procedure has filled the `_DefaultTrace` using the
`insert()` method
"""

trace_dict: Dict[str, np.ndarray] = {}
_len: Optional[int] = None

def __init__(self, samples: int):
self._len = samples
self.trace_dict = {}

def insert(self, k: str, v, idx: int):
"""
Insert `v` as the value of the `idx`th sample for the variable `k`.
Parameters
----------
k: str
Name of the variable.
v: anything that can go into a numpy array (including a numpy array)
The value of the `idx`th sample from variable `k`
ids: int
The index of the sample we are inserting into the trace.
"""
value_shape = np.shape(v)

# initialize if necessary
if k not in self.trace_dict:
array_shape = (self._len,) + value_shape
self.trace_dict[k] = np.empty(array_shape, dtype=np.array(v).dtype)

# do the actual insertion
if value_shape == ():
self.trace_dict[k][idx] = v
else:
self.trace_dict[k][idx, :] = v


def sample_posterior_predictive(
trace,
samples: Optional[int] = None,
Expand Down
Loading

0 comments on commit a7d52c7

Please sign in to comment.