Skip to content

Commit

Permalink
Prevent Model access in random function of CustomDist
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 8, 2022
1 parent 439a973 commit bcfff6d
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 3 deletions.
14 changes: 12 additions & 2 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
find_size,
shape_from_dims,
)
from pymc.exceptions import BlockModelAccessError
from pymc.logprob.abstract import (
MeasurableVariable,
_get_measurable_outputs,
Expand All @@ -54,6 +55,7 @@
_logprob,
)
from pymc.logprob.rewriting import logprob_rewrites_db
from pymc.model import BlockModelAccess
from pymc.printing import str_for_dist
from pymc.pytensorf import collect_default_updates, convert_observed_data
from pymc.util import UNSET, _add_future_warning_tag
Expand Down Expand Up @@ -662,7 +664,10 @@ def rv_op(
size = normalize_size_param(size)
dummy_size_param = size.type()
dummy_dist_params = [dist_param.type() for dist_param in dist_params]
dummy_rv = random(*dummy_dist_params, dummy_size_param)
with BlockModelAccess(
error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API"
):
dummy_rv = random(*dummy_dist_params, dummy_size_param)
dummy_params = [dummy_size_param] + dummy_dist_params
dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,))

Expand Down Expand Up @@ -1050,7 +1055,12 @@ def is_symbolic_random(self, random, dist_params):
# Try calling random with symbolic inputs
try:
size = normalize_size_param(None)
out = random(*dist_params, size)
with BlockModelAccess(
error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API"
):
out = random(*dist_params, size)
except BlockModelAccessError:
raise
except Exception:
# If it fails we assume it was not
return False
Expand Down
4 changes: 4 additions & 0 deletions pymc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,7 @@ class TruncationError(RuntimeError):

class NotConstantValueError(ValueError):
pass


class BlockModelAccessError(RuntimeError):
pass
17 changes: 16 additions & 1 deletion pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,13 @@
from pymc.data import GenTensorVariable, Minibatch
from pymc.distributions.logprob import _joint_logp
from pymc.distributions.transforms import _default_transform
from pymc.exceptions import ImputationWarning, SamplingError, ShapeError, ShapeWarning
from pymc.exceptions import (
BlockModelAccessError,
ImputationWarning,
SamplingError,
ShapeError,
ShapeWarning,
)
from pymc.initial_point import make_initial_point_fn
from pymc.pytensorf import (
PointFunc,
Expand Down Expand Up @@ -195,6 +201,8 @@ def get_context(cls, error_if_none=True) -> Optional[T]:
if error_if_none:
raise TypeError(f"No {cls} on context stack")
return None
if isinstance(candidate, BlockModelAccess):
raise BlockModelAccessError(candidate.error_msg_on_access)
return candidate

def get_contexts(cls) -> List[T]:
Expand Down Expand Up @@ -1798,6 +1806,13 @@ def point_logps(self, point=None, round_vals=2):
Model._context_class = Model


class BlockModelAccess(Model):
"""This class can be used to prevent user access to Model contexts"""

def __init__(self, *args, error_msg_on_access="Model access is blocked", **kwargs):
self.error_msg_on_access = error_msg_on_access


def set_data(new_data, model=None, *, coords=None):
"""Sets the value of one or more data container variables. Note that the shape is also
dynamic, it is updated when the value is changed. See the examples below for two common
Expand Down
12 changes: 12 additions & 0 deletions pymc/tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
)
from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple
from pymc.distributions.transforms import log
from pymc.exceptions import BlockModelAccessError
from pymc.logprob.abstract import get_measurable_outputs, logcdf
from pymc.model import Model
from pymc.sampling import draw, sample
Expand Down Expand Up @@ -479,6 +480,17 @@ def custom_random(mu, sigma, size):
assert isinstance(new_lognormal.owner.op, CustomSymbolicDistRV)
assert tuple(new_lognormal.shape.eval()) == (2, 5, 10)

def test_error_model_access(self):
def random(size):
return pm.Flat("Flat", size=size)

with pm.Model() as m:
with pytest.raises(
BlockModelAccessError,
match="Model variables cannot be created in the random function",
):
CustomDist("custom_dist", random=random)


class TestSymbolicRandomVarible:
def test_inline(self):
Expand Down

0 comments on commit bcfff6d

Please sign in to comment.