Fix issue with pickling Deterministic #4120

Merged: 6 commits, Sep 20, 2020

pymc3/model.py (35 changes: 17 additions & 18 deletions)
@@ -13,7 +13,6 @@
# limitations under the License.

import collections
-import functools
import itertools
import threading
import warnings
@@ -1903,14 +1902,22 @@ def _walk_up_rv(rv, formatting='plain'):
    return all_rvs


-def _repr_deterministic_rv(rv, formatting='plain'):
-    """Make latex string for a Deterministic variable"""
-    if formatting == 'latex':
-        return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$".format(
-            name=rv.name, args=r",~".join(_walk_up_rv(rv, formatting=formatting)))
-    else:
-        return "{name} ~ Deterministic({args})".format(
-            name=rv.name, args=", ".join(_walk_up_rv(rv, formatting=formatting)))
+class DeterministicWrapper(tt.TensorVariable):
+    def _str_repr(self, formatting='plain'):
+        if formatting == 'latex':
+            return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$".format(
+                name=self.name, args=r",~".join(_walk_up_rv(self, formatting=formatting)))
+        else:
+            return "{name} ~ Deterministic({args})".format(
+                name=self.name, args=", ".join(_walk_up_rv(self, formatting=formatting)))
+
+    def _repr_latex_(self):
+        return self._str_repr(formatting='latex')
+
+    __latex__ = _repr_latex_
+
+    def __str__(self):
+        return self._str_repr(formatting='plain')


def Deterministic(name, var, model=None, dims=None):
@@ -1929,15 +1936,7 @@ def Deterministic(name, var, model=None, dims=None):
    var = var.copy(model.name_for(name))
    model.deterministics.append(var)
    model.add_random_variable(var, dims)
-    var._repr_latex_ = functools.partial(_repr_deterministic_rv, var, formatting='latex')
-    var.__latex__ = var._repr_latex_
-
-    # simply assigning var.__str__ is not enough, since str() will default to the class-
-    # defined __str__ anyway; see https://stackoverflow.com/a/5918210/1692028
-    old_type = type(var)
-    new_type = type(old_type.__name__ + '_pymc3_Deterministic', (old_type,),
-                    {'__str__': functools.partial(_repr_deterministic_rv, var, formatting='plain')})
-    var.__class__ = new_type
+    var.__class__ = DeterministicWrapper  # adds str and latex functionality
Review comment on the line above from @MarcoGorelli (Contributor), Sep 20, 2020:

Is it possible for var to be a numpy array here*, in which case, would assigning its .__class__ attribute to a subclass of tt.TensorVariable cause problems? (I'm new to PyMC3 so sorry for the basic question :) )

*the definition of floatX would suggest it

Reply from the PR author (Member):

This will indeed cause errors if var is not a TensorVariable. I think a deterministic will always be a TensorVariable; at least it is in all the test scripts and demos. It's a good question, though: I'm not sure there is any formal specification requiring that it has to be. But I cannot really see how a numpy-typed deterministic would make sense, given that all model factors are TensorVariables anyway, and the results of operations such as numpy * TensorVariable are themselves TensorVariables.
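A quick check of that last claim (an editor's sketch, not part of the PR; assumes a working Theano install):

import numpy as np
import theano.tensor as tt

x = tt.dscalar("x")
y = np.array(2.0) * x  # numpy array on the left-hand side
print(type(y))         # a TensorVariable, not an ndarray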


return var

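Editor's note on why this change fixes pickling (inferred from the diff, not stated in it): the old code built a fresh class at runtime with type(...) and attached functools.partial objects to the variable, and pickle serializes instances by looking their class up by module path, which fails for such one-off classes. A module-level DeterministicWrapper can be found by pickle, so the instance round-trips. A minimal standard-library sketch of the difference (names here are illustrative, not from PyMC3):

import pickle


class Base:
    """Stand-in for theano's TensorVariable."""


def attach_dynamic_class(obj):
    # Old approach: create a one-off subclass at runtime and swap it in.
    obj.__class__ = type(type(obj).__name__ + "_dynamic", (type(obj),),
                         {"__str__": lambda self: "dynamic repr"})
    return obj


class ModuleLevelWrapper(Base):
    # New approach: the subclass lives at module scope, so pickle can locate it.
    def __str__(self):
        return "wrapped repr"


broken = attach_dynamic_class(Base())
try:
    pickle.dumps(broken)
except (pickle.PicklingError, AttributeError) as err:
    print("dynamic class cannot be pickled:", err)

fixed = Base()
fixed.__class__ = ModuleLevelWrapper
restored = pickle.loads(pickle.dumps(fixed))
print(restored)  # "wrapped repr" - the round trip preserves the custom __str__
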
pymc3/tests/test_distributions.py (8 changes: 8 additions & 0 deletions)
@@ -1774,6 +1774,14 @@ def test___str__(self):
        for str_repr in self.expected_str:
            assert str_repr in model_str

+    def test_str(self):
+        for distribution, str_repr in zip(self.distributions, self.expected_str):
+            assert str(distribution) == str_repr
+
+        model_str = str(self.model)
+        for str_repr in self.expected_str:
+            assert str_repr in model_str
+

def test_discrete_trafo():
with pytest.raises(ValueError) as err:
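For context (an editor's sketch, not part of the diff): self.distributions and self.expected_str are fixtures defined earlier in this test class, and the plain-text representations being compared follow the "{name} ~ {Distribution}({args})" pattern, roughly:

import pymc3 as pm

with pm.Model():
    a = pm.Normal("a", mu=0.0, sigma=1.0)
    b = pm.Deterministic("b", a + a)

print(str(a))  # roughly "a ~ Normal(mu=0.0, sigma=1.0)"
print(str(b))  # roughly "b ~ Deterministic(a, a)", produced by DeterministicWrapper.__str__
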
pymc3/tests/test_model.py (25 changes: 25 additions & 0 deletions)
@@ -16,6 +16,7 @@
import theano
import theano.tensor as tt
import numpy as np
+import pickle
import pandas as pd
import numpy.testing as npt
import unittest
@@ -421,3 +422,27 @@ def test_tempered_logp_dlogp():

    npt.assert_allclose(func_nograd(x), func(x)[0])
    npt.assert_allclose(func_temp_nograd(x), func_temp(x)[0])
+
+
+def test_model_pickle(tmpdir):
+    """Tests that PyMC3 models are pickleable"""
+    with pm.Model() as model:
+        x = pm.Normal('x')
+        pm.Normal('y', observed=1)
+
+    file_path = tmpdir.join("model.p")
+    with open(file_path, 'wb') as buff:
+        pickle.dump(model, buff)
+
+
+def test_model_pickle_deterministic(tmpdir):
+    """Tests that PyMC3 models are pickleable"""
+    with pm.Model() as model:
+        x = pm.Normal('x')
+        z = pm.Normal("z")
+        pm.Deterministic("w", x/z)
+        pm.Normal('y', observed=1)
+
+    file_path = tmpdir.join("model.p")
+    with open(file_path, 'wb') as buff:
+        pickle.dump(model, buff)
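
The new tests only assert that pickle.dump succeeds. A natural follow-up (an editor's sketch, not part of the PR, reusing the pm and pickle imports already in this file) is to load the model back and check that the Deterministic's representation still works, which is what the module-level DeterministicWrapper enables:

def test_model_pickle_roundtrip(tmpdir):
    """Hypothetical extension: dump, reload, and re-inspect the Deterministic."""
    with pm.Model() as model:
        x = pm.Normal("x")
        z = pm.Normal("z")
        pm.Deterministic("w", x / z)

    file_path = tmpdir.join("model.p")
    with open(file_path, "wb") as buff:
        pickle.dump(model, buff)

    with open(file_path, "rb") as buff:
        restored = pickle.load(buff)

    # The wrapper class resolves at module scope, so str() works after unpickling.
    assert "Deterministic" in str(restored["w"])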