diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index e08f3010ef..f1daf593aa 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -12,8 +12,9 @@ - The `CAR` distribution has been added to allow for use of conditional autoregressions which often are used in spatial and network models. - The dimensionality of model variables can now be parametrized through either of `shape`, `dims` or `size` (see [#4696](https://github.com/pymc-devs/pymc3/pull/4696)): - With `shape` the length of dimensions must be given numerically or as scalar Aesara `Variables`. Numeric entries in `shape` restrict the model variable to the exact length and re-sizing is no longer possible. - - `dims` keeps model variables re-sizeable (for example through `pm.Data`) and leads to well defined coordinates in `InferenceData` objects. An `Ellipsis` (`...`) in the last position of `dims` can be used as short-hand notation for implied dimensions. + - `dims` keeps model variables re-sizeable (for example through `pm.Data`) and leads to well defined coordinates in `InferenceData` objects. - The `size` kwarg behaves like it does in Aesara/NumPy. For univariate RVs it is the same as `shape`, but for multivariate RVs it depends on how the RV implements broadcasting to dimensionality greater than `RVOp.ndim_supp`. + - An `Ellipsis` (`...`) in the last position of `shape` or `dims` can be used as short-hand notation for implied dimensions. - Add `logcdf` method to Kumaraswamy distribution (see [#4706](https://github.com/pymc-devs/pymc3/pull/4706)). - ... diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 562332a59d..a334a8627e 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -378,6 +378,9 @@ def dist( The inputs to the `RandomVariable` `Op`. shape : int, tuple, Variable, optional A tuple of sizes for each dimension of the new RV. + + An Ellipsis (...) may be inserted in the last position to short-hand refer to + all the dimensions that the RV would get if no shape/size/dims were passed at all. size : int, tuple, Variable, optional For creating the RV like in Aesara/NumPy. initival : optional @@ -414,9 +417,16 @@ def dist( create_size = None if shape is not None: - ndim_expected = len(tuple(shape)) - ndim_batch = ndim_expected - ndim_supp - create_size = tuple(shape)[:ndim_batch] + if Ellipsis in shape: + # Ellipsis short-hands all implied dimensions. Therefore + # we don't know how many dimensions to expect. + ndim_expected = ndim_batch = None + # Create the RV with its implied shape and resize later + create_size = None + else: + ndim_expected = len(tuple(shape)) + ndim_batch = ndim_expected - ndim_supp + create_size = tuple(shape)[:ndim_batch] elif size is not None: ndim_expected = ndim_supp + len(tuple(size)) ndim_batch = ndim_expected - ndim_supp @@ -429,21 +439,25 @@ def dist( ndims_unexpected = ndim_actual != ndim_expected if shape is not None and ndims_unexpected: - # This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)). - # Recreate the RV without passing `size` to created it with just the implied dimensions. - rv_out = cls.rv_op(*dist_params, size=None, **kwargs) - - # Now resize by the "extra" dimensions that were not implied from support and parameters - if rv_out.ndim < ndim_expected: - expand_shape = shape[: ndim_expected - rv_out.ndim] - rv_out = change_rv_size(rv_var=rv_out, new_size=expand_shape, expand=True) - if not rv_out.ndim == ndim_expected: - raise ShapeError( - f"Failed to create the RV with the expected dimensionality. " - f"This indicates a severe problem. Please open an issue.", - actual=ndim_actual, - expected=ndim_batch + ndim_supp, - ) + if Ellipsis in shape: + # Resize and we're done! + rv_out = change_rv_size(rv_var=rv_out, new_size=shape[:-1], expand=True) + else: + # This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)). + # Recreate the RV without passing `size` to created it with just the implied dimensions. + rv_out = cls.rv_op(*dist_params, size=None, **kwargs) + + # Now resize by any remaining "extra" dimensions that were not implied from support and parameters + if rv_out.ndim < ndim_expected: + expand_shape = shape[: ndim_expected - rv_out.ndim] + rv_out = change_rv_size(rv_var=rv_out, new_size=expand_shape, expand=True) + if not rv_out.ndim == ndim_expected: + raise ShapeError( + f"Failed to create the RV with the expected dimensionality. " + f"This indicates a severe problem. Please open an issue.", + actual=ndim_actual, + expected=ndim_batch + ndim_supp, + ) # Warn about the edge cases where the RV Op creates more dimensions than # it should based on `size` and `RVOp.ndim_supp`. diff --git a/pymc3/tests/test_shape_handling.py b/pymc3/tests/test_shape_handling.py index 5bb7aa0288..2b75ba765e 100644 --- a/pymc3/tests/test_shape_handling.py +++ b/pymc3/tests/test_shape_handling.py @@ -236,7 +236,7 @@ class TestShapeDimsSize: [ "implicit", "shape", - # "shape...", + "shape...", "dims", "dims...", "size", @@ -273,9 +273,9 @@ def test_param_and_batch_shape_combos( if parametrization == "shape": rv = pm.Normal("rv", mu=mu, shape=batch_shape + param_shape) assert rv.eval().shape == expected_shape - # elif parametrization == "shape...": - # rv = pm.Normal("rv", mu=mu, shape=(*batch_shape, ...)) - # assert rv.eval().shape == batch_shape + param_shape + elif parametrization == "shape...": + rv = pm.Normal("rv", mu=mu, shape=(*batch_shape, ...)) + assert rv.eval().shape == batch_shape + param_shape elif parametrization == "dims": rv = pm.Normal("rv", mu=mu, dims=batch_dims + param_dims) assert rv.eval().shape == expected_shape @@ -376,7 +376,7 @@ def test_dist_api_works(self): pm.Normal.dist(mu=mu, dims=("town",)) assert pm.Normal.dist(mu=mu, shape=(3,)).eval().shape == (3,) assert pm.Normal.dist(mu=mu, shape=(5, 3)).eval().shape == (5, 3) - # assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3) + assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3) assert pm.Normal.dist(mu=mu, size=(3,)).eval().shape == (3,) assert pm.Normal.dist(mu=mu, size=(4, 3)).eval().shape == (4, 3) @@ -402,9 +402,9 @@ def test_mvnormal_shape_size_difference(self): assert rv.ndim == 3 assert tuple(rv.shape.eval()) == (5, 4, 3) - # rv = pm.MvNormal.dist(mu=np.ones((4, 3, 2)), cov=np.eye(2), shape=(6, 5, ...)) - # assert rv.ndim == 5 - # assert tuple(rv.shape.eval()) == (6, 5, 4, 3, 2) + rv = pm.MvNormal.dist(mu=np.ones((4, 3, 2)), cov=np.eye(2), shape=(6, 5, ...)) + assert rv.ndim == 5 + assert tuple(rv.shape.eval()) == (6, 5, 4, 3, 2) with pytest.warns(None): rv = pm.MvNormal.dist(mu=[1, 2, 3], cov=np.eye(3), size=(5, 4))