Add Ellipsis-support for the shape kwarg
michaelosthege committed Jun 4, 2021
1 parent c7d477c commit 46c225e
Showing 3 changed files with 42 additions and 27 deletions.
3 changes: 2 additions & 1 deletion RELEASE-NOTES.md
@@ -12,8 +12,9 @@
- The `CAR` distribution has been added to allow for use of conditional autoregressions which often are used in spatial and network models.
- The dimensionality of model variables can now be parametrized through either of `shape`, `dims` or `size` (see [#4696](https://github.com/pymc-devs/pymc3/pull/4696)):
- With `shape` the length of dimensions must be given numerically or as scalar Aesara `Variables`. Numeric entries in `shape` restrict the model variable to the exact length and re-sizing is no longer possible.
- `dims` keeps model variables re-sizeable (for example through `pm.Data`) and leads to well defined coordinates in `InferenceData` objects. An `Ellipsis` (`...`) in the last position of `dims` can be used as short-hand notation for implied dimensions.
- `dims` keeps model variables re-sizeable (for example through `pm.Data`) and leads to well defined coordinates in `InferenceData` objects.
- The `size` kwarg behaves like it does in Aesara/NumPy. For univariate RVs it is the same as `shape`, but for multivariate RVs it depends on how the RV implements broadcasting to dimensionality greater than `RVOp.ndim_supp`.
- An `Ellipsis` (`...`) in the last position of `shape` or `dims` can be used as short-hand notation for implied dimensions.
- Add `logcdf` method to Kumaraswamy distribution (see [#4706](https://github.com/pymc-devs/pymc3/pull/4706)).
- ...
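To make the three parametrizations above concrete, here is a minimal sketch (the distribution, coordinate names, and lengths are purely illustrative, and the API is assumed to be the PyMC3 development version this commit targets):

```python
import numpy as np
import pymc3 as pm

coords = {"year": [2019, 2020, 2021]}
with pm.Model(coords=coords) as model:
    mu = np.zeros(3)  # the parameter implies one dimension of length 3

    # `shape`: numeric entries pin lengths exactly; a trailing Ellipsis
    # stands in for the dimensions implied by the parameters.
    x = pm.Normal("x", mu=mu, shape=(5, ...))      # resulting shape: (5, 3)

    # `dims`: named (re-sizeable) dimensions; the Ellipsis again fills in
    # the dimension implied by `mu`.
    y = pm.Normal("y", mu=mu, dims=("year", ...))  # resulting shape: (3, 3)

    # `size`: Aesara/NumPy semantics; for a univariate RV it acts like `shape`.
    z = pm.Normal("z", mu=mu, size=(5, 3))         # resulting shape: (5, 3)
```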

50 changes: 32 additions & 18 deletions pymc3/distributions/distribution.py
@@ -378,6 +378,9 @@ def dist(
The inputs to the `RandomVariable` `Op`.
shape : int, tuple, Variable, optional
A tuple of sizes for each dimension of the new RV.
An Ellipsis (...) may be inserted in the last position as short-hand for
all the dimensions that the RV would get if no shape/size/dims were passed at all.
size : int, tuple, Variable, optional
For creating the RV like in Aesara/NumPy.
initval : optional
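The same convention applies to the `.dist()` API documented here; a short sketch mirroring the assertions in the tests added further down:

```python
import numpy as np
import pymc3 as pm

mu = np.ones(3)  # the parameter implies one dimension of length 3

# Numeric `shape` entries fix the lengths exactly.
assert pm.Normal.dist(mu=mu, shape=(5, 3)).eval().shape == (5, 3)

# A trailing Ellipsis keeps the implied dimension and prepends the rest.
assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3)

# `size` follows Aesara/NumPy; for the univariate Normal it matches `shape`.
assert pm.Normal.dist(mu=mu, size=(4, 3)).eval().shape == (4, 3)
```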
@@ -414,9 +417,16 @@
create_size = None

if shape is not None:
ndim_expected = len(tuple(shape))
ndim_batch = ndim_expected - ndim_supp
create_size = tuple(shape)[:ndim_batch]
if Ellipsis in shape:
# Ellipsis short-hands all implied dimensions. Therefore
# we don't know how many dimensions to expect.
ndim_expected = ndim_batch = None
# Create the RV with its implied shape and resize later
create_size = None
else:
ndim_expected = len(tuple(shape))
ndim_batch = ndim_expected - ndim_supp
create_size = tuple(shape)[:ndim_batch]
elif size is not None:
ndim_expected = ndim_supp + len(tuple(size))
ndim_batch = ndim_expected - ndim_supp
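For reference, a simplified stand-alone sketch of the bookkeeping in this branch (not the actual PyMC3 helper, only the arithmetic it performs):

```python
def sketch_shape_to_create_size(shape, ndim_supp):
    """Mimic the branch above: derive the `size` used to create the RV from `shape`."""
    if Ellipsis in shape:
        # With an Ellipsis the number of implied dimensions is not known here,
        # so the RV is first created at its implied shape and resized afterwards.
        return dict(ndim_expected=None, ndim_batch=None, create_size=None)
    ndim_expected = len(tuple(shape))
    ndim_batch = ndim_expected - ndim_supp
    return dict(
        ndim_expected=ndim_expected,
        ndim_batch=ndim_batch,
        create_size=tuple(shape)[:ndim_batch],
    )

# MvNormal has ndim_supp == 1, so shape=(5, 4, 3) creates the RV with size=(5, 4):
print(sketch_shape_to_create_size((5, 4, 3), ndim_supp=1))
# shape=(6, 5, ...) defers everything to the resize step handled further down:
print(sketch_shape_to_create_size((6, 5, ...), ndim_supp=1))
```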
@@ -429,21 +439,25 @@
ndims_unexpected = ndim_actual != ndim_expected

if shape is not None and ndims_unexpected:
# This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)).
# Recreate the RV without passing `size` to create it with just the implied dimensions.
rv_out = cls.rv_op(*dist_params, size=None, **kwargs)

# Now resize by the "extra" dimensions that were not implied from support and parameters
if rv_out.ndim < ndim_expected:
expand_shape = shape[: ndim_expected - rv_out.ndim]
rv_out = change_rv_size(rv_var=rv_out, new_size=expand_shape, expand=True)
if not rv_out.ndim == ndim_expected:
raise ShapeError(
f"Failed to create the RV with the expected dimensionality. "
f"This indicates a severe problem. Please open an issue.",
actual=ndim_actual,
expected=ndim_batch + ndim_supp,
)
if Ellipsis in shape:
# Resize and we're done!
rv_out = change_rv_size(rv_var=rv_out, new_size=shape[:-1], expand=True)
else:
# This is rare, but happens, for example, with MvNormal(np.ones((2, 3)), np.eye(3), shape=(2, 3)).
# Recreate the RV without passing `size` to create it with just the implied dimensions.
rv_out = cls.rv_op(*dist_params, size=None, **kwargs)

# Now resize by any remaining "extra" dimensions that were not implied from support and parameters
if rv_out.ndim < ndim_expected:
expand_shape = shape[: ndim_expected - rv_out.ndim]
rv_out = change_rv_size(rv_var=rv_out, new_size=expand_shape, expand=True)
if not rv_out.ndim == ndim_expected:
raise ShapeError(
f"Failed to create the RV with the expected dimensionality. "
f"This indicates a severe problem. Please open an issue.",
actual=ndim_actual,
expected=ndim_batch + ndim_supp,
)

# Warn about the edge cases where the RV Op creates more dimensions than
# it should based on `size` and `RVOp.ndim_supp`.
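Putting the Ellipsis branch and the resize step together, a minimal end-to-end example (it mirrors `test_mvnormal_shape_size_difference` in the test changes below; the concrete sizes are only illustrative):

```python
import numpy as np
import pymc3 as pm

# The parameters imply shape (4, 3, 2): batch dimensions (4, 3) from `mu`
# plus the support dimension of length 2.
rv = pm.MvNormal.dist(mu=np.ones((4, 3, 2)), cov=np.eye(2), shape=(6, 5, ...))

# `shape[:-1] == (6, 5)` is handed to change_rv_size(..., expand=True), which
# prepends it to the implied shape, yielding (6, 5, 4, 3, 2).
assert rv.ndim == 5
assert tuple(rv.shape.eval()) == (6, 5, 4, 3, 2)
```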
16 changes: 8 additions & 8 deletions pymc3/tests/test_shape_handling.py
@@ -236,7 +236,7 @@ class TestShapeDimsSize:
[
"implicit",
"shape",
# "shape...",
"shape...",
"dims",
"dims...",
"size",
@@ -273,9 +273,9 @@ def test_param_and_batch_shape_combos(
if parametrization == "shape":
rv = pm.Normal("rv", mu=mu, shape=batch_shape + param_shape)
assert rv.eval().shape == expected_shape
# elif parametrization == "shape...":
# rv = pm.Normal("rv", mu=mu, shape=(*batch_shape, ...))
# assert rv.eval().shape == batch_shape + param_shape
elif parametrization == "shape...":
rv = pm.Normal("rv", mu=mu, shape=(*batch_shape, ...))
assert rv.eval().shape == batch_shape + param_shape
elif parametrization == "dims":
rv = pm.Normal("rv", mu=mu, dims=batch_dims + param_dims)
assert rv.eval().shape == expected_shape
@@ -376,7 +376,7 @@ def test_dist_api_works(self):
pm.Normal.dist(mu=mu, dims=("town",))
assert pm.Normal.dist(mu=mu, shape=(3,)).eval().shape == (3,)
assert pm.Normal.dist(mu=mu, shape=(5, 3)).eval().shape == (5, 3)
# assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3)
assert pm.Normal.dist(mu=mu, shape=(7, ...)).eval().shape == (7, 3)
assert pm.Normal.dist(mu=mu, size=(3,)).eval().shape == (3,)
assert pm.Normal.dist(mu=mu, size=(4, 3)).eval().shape == (4, 3)

Expand All @@ -402,9 +402,9 @@ def test_mvnormal_shape_size_difference(self):
assert rv.ndim == 3
assert tuple(rv.shape.eval()) == (5, 4, 3)

# rv = pm.MvNormal.dist(mu=np.ones((4, 3, 2)), cov=np.eye(2), shape=(6, 5, ...))
# assert rv.ndim == 5
# assert tuple(rv.shape.eval()) == (6, 5, 4, 3, 2)
rv = pm.MvNormal.dist(mu=np.ones((4, 3, 2)), cov=np.eye(2), shape=(6, 5, ...))
assert rv.ndim == 5
assert tuple(rv.shape.eval()) == (6, 5, 4, 3, 2)

with pytest.warns(None):
rv = pm.MvNormal.dist(mu=[1, 2, 3], cov=np.eye(3), size=(5, 4))
