diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 597ea03500..c96316a8cd 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -46,6 +46,7 @@ find_size, shape_from_dims, ) +from pymc.exceptions import BlockModelAccessError from pymc.logprob.abstract import ( MeasurableVariable, _get_measurable_outputs, @@ -54,6 +55,7 @@ _logprob, ) from pymc.logprob.rewriting import logprob_rewrites_db +from pymc.model import BlockModelAccess from pymc.printing import str_for_dist from pymc.pytensorf import collect_default_updates, convert_observed_data from pymc.util import UNSET, _add_future_warning_tag @@ -662,7 +664,10 @@ def rv_op( size = normalize_size_param(size) dummy_size_param = size.type() dummy_dist_params = [dist_param.type() for dist_param in dist_params] - dummy_rv = random(*dummy_dist_params, dummy_size_param) + with BlockModelAccess( + error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API" + ): + dummy_rv = random(*dummy_dist_params, dummy_size_param) dummy_params = [dummy_size_param] + dummy_dist_params dummy_updates_dict = collect_default_updates(dummy_params, (dummy_rv,)) @@ -1050,7 +1055,12 @@ def is_symbolic_random(self, random, dist_params): # Try calling random with symbolic inputs try: size = normalize_size_param(None) - out = random(*dist_params, size) + with BlockModelAccess( + error_msg_on_access="Model variables cannot be created in the random function. Use the `.dist` API" + ): + out = random(*dist_params, size) + except BlockModelAccessError: + raise except Exception: # If it fails we assume it was not return False diff --git a/pymc/exceptions.py b/pymc/exceptions.py index 5b4141f303..7a18167d5c 100644 --- a/pymc/exceptions.py +++ b/pymc/exceptions.py @@ -82,3 +82,7 @@ class TruncationError(RuntimeError): class NotConstantValueError(ValueError): pass + + +class BlockModelAccessError(RuntimeError): + pass diff --git a/pymc/model.py b/pymc/model.py index 8c9d85af2b..2bb703f144 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -53,7 +53,13 @@ from pymc.data import GenTensorVariable, Minibatch from pymc.distributions.logprob import _joint_logp from pymc.distributions.transforms import _default_transform -from pymc.exceptions import ImputationWarning, SamplingError, ShapeError, ShapeWarning +from pymc.exceptions import ( + BlockModelAccessError, + ImputationWarning, + SamplingError, + ShapeError, + ShapeWarning, +) from pymc.initial_point import make_initial_point_fn from pymc.pytensorf import ( PointFunc, @@ -195,6 +201,8 @@ def get_context(cls, error_if_none=True) -> Optional[T]: if error_if_none: raise TypeError(f"No {cls} on context stack") return None + if isinstance(candidate, BlockModelAccess): + raise BlockModelAccessError(candidate.error_msg_on_access) return candidate def get_contexts(cls) -> List[T]: @@ -1798,6 +1806,13 @@ def point_logps(self, point=None, round_vals=2): Model._context_class = Model +class BlockModelAccess(Model): + """This class can be used to prevent user access to Model contexts""" + + def __init__(self, *args, error_msg_on_access="Model access is blocked", **kwargs): + self.error_msg_on_access = error_msg_on_access + + def set_data(new_data, model=None, *, coords=None): """Sets the value of one or more data container variables. Note that the shape is also dynamic, it is updated when the value is changed. See the examples below for two common diff --git a/pymc/tests/distributions/test_distribution.py b/pymc/tests/distributions/test_distribution.py index 50b263238b..ae7d625ac3 100644 --- a/pymc/tests/distributions/test_distribution.py +++ b/pymc/tests/distributions/test_distribution.py @@ -45,6 +45,7 @@ ) from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple from pymc.distributions.transforms import log +from pymc.exceptions import BlockModelAccessError from pymc.logprob.abstract import get_measurable_outputs, logcdf from pymc.model import Model from pymc.sampling import draw, sample @@ -479,6 +480,17 @@ def custom_random(mu, sigma, size): assert isinstance(new_lognormal.owner.op, CustomSymbolicDistRV) assert tuple(new_lognormal.shape.eval()) == (2, 5, 10) + def test_error_model_access(self): + def random(size): + return pm.Flat("Flat", size=size) + + with pm.Model() as m: + with pytest.raises( + BlockModelAccessError, + match="Model variables cannot be created in the random function", + ): + CustomDist("custom_dist", random=random) + class TestSymbolicRandomVarible: def test_inline(self):