Skip to content

Commit

Permalink
Add Censored distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 19, 2021
1 parent 50a6117 commit a0ccd82
Show file tree
Hide file tree
Showing 5 changed files with 401 additions and 1 deletion.
2 changes: 2 additions & 0 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)

from pymc.distributions.bound import Bound
from pymc.distributions.censored import Censored
from pymc.distributions.continuous import (
AsymmetricLaplace,
Beta,
Expand Down Expand Up @@ -194,4 +195,5 @@
"logp_transform",
"logcdf",
"logpt_sum",
"Censored",
]
99 changes: 99 additions & 0 deletions pymc/distributions/censored.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import aesara.tensor as at
import numpy as np

from aesara.scalar import Clip
from aesara.tensor.elemwise import DimShuffle

from pymc.distributions.distribution import DerivedDistribution, _get_moment


class Censored(DerivedDistribution):
@classmethod
def dist(cls, distribution, lower, upper, **kwargs):
# TODO: Assert distribution is a RandomVariable
if distribution.owner.op.ndim_supp > 0:
raise NotImplemented(
"Censoring of multivariate distributions has not been implemented yet"
)
return super().dist([distribution, lower, upper], **kwargs)

@classmethod
def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None):
if lower is None:
lower = at.constant(-np.inf)
if upper is None:
upper = at.constant(np.inf)

rv_out = at.clip(dist, lower, upper)
if size is not None:
rv_out = cls.change_size(rv_out, size)
if rngs is not None:
rv_out = cls.change_rngs(rv_out, rngs)
return rv_out

@classmethod
def ndim_supp(cls, *dist_params):
return 0

@classmethod
def change_size(cls, rv, new_size):
dist, lower, upper = rv.owner.inputs
dist_node = dist.owner
# lower / upper may have broadcasted the distribution, in which case we get a Dimshuffle Op
if isinstance(dist.owner.op, DimShuffle):
dist_node = dist_node.inputs[0].owner
rng, old_size, dtype, *dist_params = dist_node.inputs
new_dist = dist_node.op.make_node(rng, new_size, dtype, *dist_params).default_output()
return cls.rv_op(new_dist, lower, upper)

@classmethod
def change_rngs(cls, rv, new_rngs):
(new_rng,) = new_rngs
dist, lower, upper = rv.owner.inputs
dist_node = dist.owner
# lower / upper may have broadcasted the distribution, in which case we get a Dimshuffle Op
if isinstance(dist.owner.op, DimShuffle):
dist_node = dist_node.inputs[0].owner
olg_rng, size, dtype, *dist_params = dist_node.inputs
new_dist = dist_node.op.make_node(new_rng, size, dtype, *dist_params).default_output()
return cls.rv_op(new_dist, lower, upper)

@classmethod
def graph_rvs(cls, dist, *bounds):
return (dist,)


@_get_moment.register(Clip)
def get_moment_censored(op, rv, dist, lower, upper):
moment = at.switch(
at.eq(lower, -np.inf),
at.switch(
at.isinf(upper),
# lower = -inf, upper = inf
0,
# lower = -inf, upper = x
upper - 1,
),
at.switch(
at.eq(upper, np.inf),
# lower = x, upper = inf
lower + 1,
# lower = x, upper = x
(lower + upper) / 2,
),
)
moment = at.full_like(dist, moment)
return moment
234 changes: 233 additions & 1 deletion pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@

from abc import ABCMeta
from functools import singledispatch
from typing import Callable, Optional, Sequence
from typing import Callable, Iterable, Optional, Sequence

import aesara

from aeppl.logprob import _logcdf, _logprob
from aesara import tensor as at
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.var import RandomStateSharedVariable
from aesara.tensor.var import TensorVariable
Expand Down Expand Up @@ -354,11 +355,242 @@ def dist(
return rv_out


class DerivedDistribution:
def __new__(
cls,
name: str,
*args,
rngs: Optional[Iterable] = None,
dims: Optional[Dims] = None,
initval=None,
observed=None,
total_size=None,
transform=UNSET,
**kwargs,
) -> TensorVariable:
"""Adds a TensorVariable corresponding to a PyMC derived distribution to the current model.
Note that all remaining kwargs must be compatible with ``.dist()``
Parameters
----------
cls : type
A PyMC distribution.
name : str
Name for the new model variable.
rngs : optional
Random number generator to use with the RandomVariable.
dims : tuple, optional
A tuple of dimension names known to the model.
initval : optional
Numeric or symbolic untransformed initial value of matching shape,
or one of the following initial value strategies: "moment", "prior".
Depending on the sampler's settings, a random jitter may be added to numeric, symbolic
or moment-based initial values in the transformed space.
observed : optional
Observed data to be passed when registering the random variable in the model.
See ``Model.register_rv``.
total_size : float, optional
See ``Model.register_rv``.
transform : optional
See ``Model.register_rv``.
**kwargs
Keyword arguments that will be forwarded to ``.dist()``.
Most prominently: ``shape`` and ``size``
Returns
-------
var : TensorVariable
The created variable, registered in the Model.
"""

try:
from pymc.model import Model

model = Model.get_context()
except TypeError:
raise TypeError(
"No model on context stack, which is needed to "
"instantiate distributions. Add variable inside "
"a 'with model:' block, or use the '.dist' syntax "
"for a standalone distribution."
)

if "testval" in kwargs:
initval = kwargs.pop("testval")
warnings.warn(
"The `testval` argument is deprecated; use `initval`.",
FutureWarning,
stacklevel=2,
)

if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

if dims is not None and "shape" in kwargs:
raise ValueError(
f"Passing both `dims` ({dims}) and `shape` ({kwargs['shape']}) is not supported!"
)
if dims is not None and "size" in kwargs:
raise ValueError(
f"Passing both `dims` ({dims}) and `size` ({kwargs['size']}) is not supported!"
)
dims = convert_dims(dims)

if rngs is None:
rngs = [model.next_rng() for _ in cls.graph_rvs(args)]

# Create the RV without dims information, because that's not something tracked at the Aesara level.
# If necessary we'll later replicate to a different size implied by already known dims.
rv_out = cls.dist(*args, rngs=rngs, **kwargs)
ndim_actual = rv_out.ndim
resize_shape = None

# # `dims` are only available with this API, because `.dist()` can be used
# # without a modelcontext and dims are not tracked at the Aesara level.
if dims is not None:
ndim_resize, resize_shape, dims = resize_from_dims(dims, ndim_actual, model)
elif observed is not None:
ndim_resize, resize_shape, observed = resize_from_observed(observed, ndim_actual)

if resize_shape:
# A batch size was specified through `dims`, or implied by `observed`.
rv_out = cls.change_size(
rv=rv_out,
new_size=resize_shape,
)

rv_out = model.register_rv(
rv_out,
name,
observed,
total_size,
dims=dims,
transform=transform,
initval=initval,
)

# TODO: Refactor this
# add in pretty-printing support
rv_out.str_repr = lambda *args, **kwargs: name
rv_out._repr_latex_ = f"\\text{name}"
# rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
# rv_out._repr_latex_ = types.MethodType(
# functools.partial(str_for_dist, formatting="latex"), rv_out
# )

return rv_out

@classmethod
def dist(
cls,
dist_params,
*,
shape: Optional[Shape] = None,
size: Optional[Size] = None,
**kwargs,
) -> TensorVariable:
"""Creates a TensorVariable corresponding to the `cls` derived distribution.
Parameters
----------
dist_params : array-like
The inputs to the `RandomVariable` `Op`.
shape : int, tuple, Variable, optional
A tuple of sizes for each dimension of the new RV.
An Ellipsis (...) may be inserted in the last position to short-hand refer to
all the dimensions that the RV would get if no shape/size/dims were passed at all.
size : int, tuple, Variable, optional
For creating the RV like in Aesara/NumPy.
Returns
-------
var : TensorVariable
"""

if "testval" in kwargs:
kwargs.pop("testval")
warnings.warn(
"The `.dist(testval=...)` argument is deprecated and has no effect. "
"Initial values for sampling/optimization can be specified with `initval` in a modelcontext. "
"For using Aesara's test value features, you must assign the `.tag.test_value` yourself.",
FutureWarning,
stacklevel=2,
)
if "initval" in kwargs:
raise TypeError(
"Unexpected keyword argument `initval`. "
"This argument is not available for the `.dist()` API."
)

if "dims" in kwargs:
raise NotImplementedError("The use of a `.dist(dims=...)` API is not supported.")
if shape is not None and size is not None:
raise ValueError(
f"Passing both `shape` ({shape}) and `size` ({size}) is not supported!"
)

shape = convert_shape(shape)
size = convert_size(size)

create_size, ndim_expected, ndim_batch, ndim_supp = find_size(
shape=shape, size=size, ndim_supp=cls.ndim_supp(*dist_params)
)
# Create the RV with a `size` right away.
# This is not necessarily the final result.
graph = cls.rv_op(*dist_params, size=create_size, **kwargs)
# TODO: Refactor this branch
# graph = maybe_resize(
# graph,
# cls.rv_op,
# dist_params,
# ndim_expected,
# ndim_batch,
# ndim_supp,
# shape,
# size,
# **kwargs,
# )

rngs = kwargs.pop("rngs", None)
if rngs is not None:
graph_rvs = cls.graph_rvs(*graph.owner.inputs)
assert len(rngs) == len(graph_rvs)
for rng, rv_out in zip(rngs, graph_rvs):
if (
rv_out.owner
and isinstance(rv_out.owner.op, RandomVariable)
and isinstance(rng, RandomStateSharedVariable)
and not getattr(rng, "default_update", None)
):
# This tells `aesara.function` that the shared RNG variable
# is mutable, which--in turn--tells the `FunctionGraph`
# `Supervisor` feature to allow in-place updates on the variable.
# Without it, the `RandomVariable`s could not be optimized to allow
# in-place RNG updates, forcing all sample results from compiled
# functions to be the same on repeated evaluations.
new_rng = rv_out.owner.outputs[0]
rv_out.update = (rng, new_rng)
rng.default_update = new_rng

# TODO: Create new attr error stating that these are not available for DerivedDistribution
# rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
# rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
# rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
return graph


@singledispatch
def _get_moment(op, rv, *rv_inputs) -> TensorVariable:
raise NotImplementedError(f"Variable {rv} of type {op} has no get_moment implementation.")


@_get_moment.register(Elemwise)
def _get_moment_elemwise(op, rv, *dist_params):
"""For Elemwise Ops, dispatch on respective scalar_op"""
return _get_moment(op.scalar_op, rv, *dist_params)


def get_moment(rv: TensorVariable) -> TensorVariable:
"""Method for choosing a representative point/value
Expand Down
Loading

0 comments on commit a0ccd82

Please sign in to comment.