Skip to content

Commit

Permalink
remove geweke diagnostic (#1545)
Browse files Browse the repository at this point in the history
* remove geweke diagnostic

* remove unused import, update changelog
  • Loading branch information
aloctavodia authored Feb 8, 2021
1 parent 2dff682 commit 6fa1ce8
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 274 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
### New features
* Added `to_zarr` method to InferenceData
* Added `from_zarr` method to InferenceData
* Added confidence interval band to auto-correlation plot ([1535](https://github.com/arviz-devs/arviz/pull/1535))
* Added confidence interval band to auto-correlation plot ([1535](https://github.com/arviz-devs/arviz/pull/1535))

### Maintenance and fixes

### Deprecation
* Removed Geweke diagnostic ([1545](https://github.com/arviz-devs/arviz/pull/1545))

### Documentation

Expand Down
1 change: 0 additions & 1 deletion arviz/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
"ess",
"rhat",
"mcse",
"geweke",
"autocorr",
"autocov",
"make_ufunc",
Expand Down
85 changes: 2 additions & 83 deletions arviz/stats/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from scipy import stats

from ..data import convert_to_dataset
from ..utils import Numba, _numba_var, _stack, _var_names, conditional_jit
from ..utils import Numba, _numba_var, _stack, _var_names
from .density_utils import histogram as _histogram
from .stats_utils import _circular_standard_deviation, _sqrt
from .stats_utils import autocov as _autocov
Expand All @@ -17,7 +17,7 @@
from .stats_utils import stats_variance_2d as svar
from .stats_utils import wrap_xarray_ufunc as _wrap_xarray_ufunc

__all__ = ["bfmi", "ess", "rhat", "mcse", "geweke"]
__all__ = ["bfmi", "ess", "rhat", "mcse"]


def bfmi(data):
Expand Down Expand Up @@ -415,87 +415,6 @@ def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None):
)


@conditional_jit(forceobj=True)
def geweke(ary, first=0.1, last=0.5, intervals=20):
r"""Compute z-scores for convergence diagnostics.
Compare the mean of the first % of series with the mean of the last % of series. x is divided
into a number of segments for which this difference is computed. If the series is converged,
this score should oscillate between -1 and 1.
Parameters
----------
ary : 1D array-like
The trace of some stochastic parameter.
first : float
The fraction of series at the beginning of the trace.
last : float
The fraction of series at the end to be compared with the section
at the beginning.
intervals : int
The number of segments.
Returns
-------
scores : list [[]]
Return a list of [i, score], where i is the starting index for each interval and score the
Geweke score on the interval.
Notes
-----
The Geweke score on some series x is computed by:
.. math:: \frac{E[x_s] - E[x_e]}{\sqrt{V[x_s] + V[x_e]}}
where :math:`E` stands for the mean, :math:`V` the variance,
:math:`x_s` a section at the start of the series and
:math:`x_e` a section at the end of the series.
References
----------
* Geweke (1992)
"""
# Filter out invalid intervals
return _geweke(ary, first, last, intervals)


def _geweke(ary, first, last, intervals):
_numba_flag = Numba.numba_flag
for interval in (first, last):
if interval <= 0 or interval >= 1:
raise ValueError("Invalid intervals for Geweke convergence analysis", (first, last))
if first + last >= 1:
raise ValueError("Invalid intervals for Geweke convergence analysis", (first, last))

# Initialize list of z-scores
zscores = []

# Last index value
end = len(ary) - 1

# Start intervals going up to the <last>% of the chain
last_start_idx = (1 - last) * end

# Calculate starting indices
start_indices = np.linspace(0, last_start_idx, num=intervals, endpoint=True, dtype=int)

# Loop over start indices
for start in start_indices:
# Calculate slices
first_slice = ary[start : start + int(first * (end - start))]
last_slice = ary[int(end - last * (end - start)) :]

z_score = first_slice.mean() - last_slice.mean()
if _numba_flag:
z_score /= _sqrt(svar(first_slice), svar(last_slice))
else:
z_score /= np.sqrt(first_slice.var() + last_slice.var())

zscores.append([start, z_score])

return np.array(zscores)


def ks_summary(pareto_tail_indices):
"""Display a summary of Pareto tail indices.
Expand Down
27 changes: 1 addition & 26 deletions arviz/tests/base_tests/test_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ...data import from_cmdstan, load_arviz_data
from ...plots.plot_utils import xarray_var_iter
from ...rcparams import rc_context, rcParams
from ...stats import bfmi, ess, geweke, mcse, rhat
from ...stats import bfmi, ess, mcse, rhat
from ...stats.diagnostics import (
_ess,
_ess_quantile,
Expand Down Expand Up @@ -466,31 +466,6 @@ def test_multichain_summary_array(self, draws, chains):
else:
assert round(rhat_hat, 3) == round(rhat_hat_, 3)

def test_geweke(self):
first = 0.1
last = 0.5
intervals = 100
data = np.random.randn(100000)
gw_stat = geweke(data, first, last, intervals)

# all geweke values should be between -1 and 1 for this many draws from a
# normal distribution
assert ((gw_stat[:, 1] > -1) | (gw_stat[:, 1] < 1)).all()

assert gw_stat.shape[0] == intervals
assert 100000 * last - gw_stat[:, 0].max() == 1

def test_geweke_bad_interval(self):
# lower bound
with pytest.raises(ValueError):
geweke(np.random.randn(10), first=0)
# upper bound
with pytest.raises(ValueError):
geweke(np.random.randn(10), last=1)
# sum larger than 1
with pytest.raises(ValueError):
geweke(np.random.randn(10), first=0.9, last=0.9)

def test_ks_summary(self):
"""Instead of psislw data, this test uses fake data."""
pareto_tail_indices = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2])
Expand Down
14 changes: 1 addition & 13 deletions arviz/tests/base_tests/test_diagnostics_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ...data import load_arviz_data
from ...rcparams import rcParams
from ...stats import bfmi, geweke, mcse, rhat
from ...stats import bfmi, mcse, rhat
from ...stats.diagnostics import _mc_error, ks_summary
from ...utils import Numba
from ..helpers import running_on_ci
Expand Down Expand Up @@ -77,18 +77,6 @@ def test_ks_summary_numba():
assert Numba.numba_flag == state


def test_geweke_numba():
"""Numba test for geweke."""
state = Numba.numba_flag
data = np.random.randn(100)
Numba.disable_numba()
non_numba = geweke(data)
Numba.enable_numba()
with_numba = geweke(data)
assert np.allclose(non_numba, with_numba)
assert Numba.numba_flag == state


@pytest.mark.parametrize("batches", (1, 20))
@pytest.mark.parametrize("circular", (True, False))
def test_mcse_error_numba(batches, circular):
Expand Down
1 change: 0 additions & 1 deletion doc/source/api/diagnostics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ Diagnostics
:toctree: generated/

bfmi
geweke
ess
rhat
mcse
150 changes: 1 addition & 149 deletions doc/source/user_guide/Numba.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -144,154 +144,6 @@
"**In certain scenarios, Numba outperforms numpy!** **Let's see Numba's effect on a few of ArviZ functions**"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Numba.disable_numba() # This disables numba\n",
"Numba.numba_flag"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"data = np.random.randn(1000000)\n",
"smaller_data = np.random.randn(1000)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"74.7 ms ± 3.53 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%timeit geweke(data)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.85 ms ± 158 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"%timeit geweke(smaller_data)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Numba.enable_numba() # This will re-enable numba\n",
"Numba.numba_flag # This indicates the status of Numba"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"43.8 ms ± 22.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%timeit geweke(data)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.51 ms ± 584 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"%timeit geweke(smaller_data)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Numba.enable_numba()\n",
"Numba.numba_flag"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Numba speeds up the code by a factor of two approximately. Let's check some other method**"
]
},
{
"cell_type": "code",
"execution_count": 16,
Expand Down Expand Up @@ -441,4 +293,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}

0 comments on commit 6fa1ce8

Please sign in to comment.