Rename MeasurableVariable to MeasurableOp
Also:
* Introduce MeasurableOpMixin for string representation
* Subclass directly instead of registering manually
ricardoV94 committed Aug 1, 2024
1 parent 312b0f8 commit 78f148f
Showing 19 changed files with 78 additions and 121 deletions.
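The heart of the change is a move from `abc` virtual-subclass registration to plain inheritance. A minimal, self-contained sketch of the two styles (`FakeOp` and `MyMeasurableOp` are invented stand-ins for illustration, not PyMC classes):

```python
import abc


class MeasurableOp(abc.ABC):
    """An operation whose outputs can be assigned a measure/log-probability"""


class FakeOp:
    """Stand-in for a PyTensor Op."""


# Old style (under the pre-commit name MeasurableVariable): define the Op on
# its own, then register it as a virtual subclass of the ABC.
MeasurableOp.register(FakeOp)


# New style introduced by this commit: subclass directly, so isinstance
# checks pass without a separate registration call.
class MyMeasurableOp(MeasurableOp, FakeOp):
    pass


assert isinstance(FakeOp(), MeasurableOp)          # measurable via registration
assert isinstance(MyMeasurableOp(), MeasurableOp)  # measurable via inheritance
```

In the diff below, `RandomVariable` keeps the registration route, while the `Measurable*` placeholder classes switch to inheritance through the new `MeasurableOpMixin`.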
8 changes: 2 additions & 6 deletions pymc/distributions/distribution.py
@@ -50,7 +50,7 @@
rv_size_is_none,
shape_from_dims,
)
-from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob
+from pymc.logprob.abstract import MeasurableOp, _icdf, _logcdf, _logprob
from pymc.logprob.basic import logp
from pymc.logprob.rewriting import logprob_rewrites_db
from pymc.printing import str_for_dist
@@ -228,7 +228,7 @@ def __get__(self, instance, type_):
return descr_get(instance, type_)


-class SymbolicRandomVariable(OpFromGraph):
+class SymbolicRandomVariable(MeasurableOp, OpFromGraph):
"""Symbolic Random Variable
This is a subclasse of `OpFromGraph` which is used to encapsulate the symbolic
@@ -624,10 +624,6 @@ def dist(
return rv_out


-# Let PyMC know that the SymbolicRandomVariable has a logprob.
-MeasurableVariable.register(SymbolicRandomVariable)


@node_rewriter([SymbolicRandomVariable])
def inline_symbolic_random_variable(fgraph, node):
"""
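With `SymbolicRandomVariable` inheriting from `MeasurableOp`, the explicit registration removed above becomes unnecessary. A quick sanity check of the relationship, assuming the import paths shown in this diff:

```python
from pymc.distributions.distribution import SymbolicRandomVariable
from pymc.logprob.abstract import MeasurableOp

# Measurability now follows from the class hierarchy itself.
assert issubclass(SymbolicRandomVariable, MeasurableOp)
```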
30 changes: 23 additions & 7 deletions pymc/logprob/abstract.py
@@ -35,6 +35,7 @@
# SOFTWARE.

import abc
+import warnings

from collections.abc import Sequence
from functools import singledispatch
@@ -46,6 +47,17 @@
from pytensor.tensor.random.op import RandomVariable


+def __getattr__(name):
+    if name == "MeasurableVariable":
+        warnings.warn(
+            f"{name} has been deprecated in favor of MeasurableOp. Importing will fail in a future release.",
+            FutureWarning,
+        )
+        return MeasurableOpMixin
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")

@singledispatch
def _logprob(
op: Op,
@@ -131,14 +143,21 @@ def _icdf_helper(rv, value, **kwargs):
return rv_icdf


-class MeasurableVariable(abc.ABC):
-    """A variable that can be assigned a measure/log-probability"""
+class MeasurableOp(abc.ABC):
+    """An operation whose outputs can be assigned a measure/log-probability"""


-MeasurableVariable.register(RandomVariable)
+MeasurableOp.register(RandomVariable)


+class MeasurableOpMixin(MeasurableOp):
+    """MeasurableOp Mixin with a distinctive string representation"""
+
+    def __str__(self):
+        return f"Measurable{super().__str__()}"
+
+
-class MeasurableElemwise(Elemwise):
+class MeasurableElemwise(MeasurableOpMixin, Elemwise):
"""Base class for Measurable Elemwise variables"""

valid_scalar_types: tuple[MetaType, ...] = ()
@@ -150,6 +169,3 @@ def __init__(self, scalar_op, *args, **kwargs):
f"Acceptable types are {self.valid_scalar_types}"
)
super().__init__(scalar_op, *args, **kwargs)


-MeasurableVariable.register(MeasurableElemwise)
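The module-level `__getattr__` added above (PEP 562) keeps the old name importable while nudging users toward `MeasurableOp`. A usage sketch of the intended behaviour, assuming the shim works exactly as written in the diff:

```python
import warnings

from pymc.logprob.abstract import MeasurableOp  # the new name imports cleanly

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The deprecated name is resolved by __getattr__, which warns and hands
    # back MeasurableOpMixin instead of raising an ImportError.
    from pymc.logprob.abstract import MeasurableVariable

assert issubclass(MeasurableVariable, MeasurableOp)
assert any(issubclass(w.category, FutureWarning) for w in caught)
```

`MeasurableOpMixin` also centralizes the string representation, prefixing `Measurable` to whatever `__str__` the wrapped op already provides.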
4 changes: 2 additions & 2 deletions pymc/logprob/basic.py
@@ -56,7 +56,7 @@
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.abstract import (
-MeasurableVariable,
+MeasurableOp,
_icdf_helper,
_logcdf_helper,
_logprob,
@@ -522,7 +522,7 @@ def conditional_logp(
while q:
node = q.popleft()

-if not isinstance(node.op, MeasurableVariable):
+if not isinstance(node.op, MeasurableOp):
continue

q_values = [replacements[q_rv] for q_rv in node.outputs if q_rv in updated_rv_values]
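The `isinstance(node.op, MeasurableOp)` filter in `conditional_logp` still admits plain `RandomVariable` nodes, because `RandomVariable` is registered against the ABC in `abstract.py` above. A small check of that, assuming PyTensor's `normal` constructor lives where it usually does:

```python
from pytensor.tensor.random.basic import normal

from pymc.logprob.abstract import MeasurableOp

x = normal(0, 1)
# NormalRV is a RandomVariable subclass, so the virtual-subclass registration
# makes its apply node pass the MeasurableOp filter above.
assert isinstance(x.owner.op, MeasurableOp)
```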
14 changes: 4 additions & 10 deletions pymc/logprob/checks.py
@@ -42,18 +42,15 @@
from pytensor.tensor import TensorVariable
from pytensor.tensor.shape import SpecifyShape

-from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
+from pymc.logprob.abstract import MeasurableOp, MeasurableOpMixin, _logprob, _logprob_helper
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import replace_rvs_by_values


-class MeasurableSpecifyShape(SpecifyShape):
+class MeasurableSpecifyShape(MeasurableOpMixin, SpecifyShape):
"""A placeholder used to specify a log-likelihood for a specify-shape sub-graph."""


-MeasurableVariable.register(MeasurableSpecifyShape)


@_logprob.register(MeasurableSpecifyShape)
def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs):
(value,) = values
@@ -80,7 +77,7 @@ def find_measurable_specify_shapes(fgraph, node) -> list[TensorVariable] | None:

if not (
base_rv.owner
-and isinstance(base_rv.owner.op, MeasurableVariable)
+and isinstance(base_rv.owner.op, MeasurableOp)
and base_rv not in rv_map_feature.rv_values
):
return None # pragma: no cover
@@ -99,13 +96,10 @@ def find_measurable_specify_shapes(fgraph, node) -> list[TensorVariable] | None:
)


-class MeasurableCheckAndRaise(CheckAndRaise):
+class MeasurableCheckAndRaise(MeasurableOpMixin, CheckAndRaise):
"""A placeholder used to specify a log-likelihood for an assert sub-graph."""


-MeasurableVariable.register(MeasurableCheckAndRaise)


@_logprob.register(MeasurableCheckAndRaise)
def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs):
(value,) = values
7 changes: 2 additions & 5 deletions pymc/logprob/cumsum.py
@@ -41,17 +41,14 @@
from pytensor.tensor import TensorVariable
from pytensor.tensor.extra_ops import CumOp

-from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
+from pymc.logprob.abstract import MeasurableOpMixin, _logprob, _logprob_helper
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db


-class MeasurableCumsum(CumOp):
+class MeasurableCumsum(MeasurableOpMixin, CumOp):
"""A placeholder used to specify a log-likelihood for a cumsum sub-graph."""


-MeasurableVariable.register(MeasurableCumsum)


@_logprob.register(MeasurableCumsum)
def logprob_cumsum(op, values, base_rv, **kwargs):
"""Compute the log-likelihood graph for a `Cumsum`."""
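The cumsum file shows the recurring pattern after this commit: a `Measurable*` placeholder becomes measurable (and picks up the `Measurable...` string representation) by subclassing `MeasurableOpMixin`, and its log-probability is attached through the `_logprob` singledispatch registry. A self-contained sketch of the same pattern for a made-up pass-through op (`PassThroughOp` and its measurable counterpart are invented names, not PyMC classes):

```python
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op

from pymc.logprob.abstract import MeasurableOpMixin, _logprob, _logprob_helper


class PassThroughOp(Op):
    """Toy op that returns its input unchanged (stand-in for e.g. CumOp)."""

    def make_node(self, x):
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = inputs[0]


class MeasurablePassThroughOp(MeasurableOpMixin, PassThroughOp):
    """A placeholder used to specify a log-likelihood for a pass-through sub-graph."""


@_logprob.register(MeasurablePassThroughOp)
def logprob_pass_through(op, values, base_rv, **kwargs):
    (value,) = values
    # The op does not transform its input, so the logprob is just that of the
    # underlying random variable evaluated at the value.
    return _logprob_helper(base_rv, value, **kwargs)
```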
15 changes: 5 additions & 10 deletions pymc/logprob/mixture.py
@@ -67,7 +67,8 @@

from pymc.logprob.abstract import (
MeasurableElemwise,
-MeasurableVariable,
+MeasurableOp,
+MeasurableOpMixin,
_logprob,
_logprob_helper,
)
@@ -217,7 +218,7 @@ def rv_pull_down(x: TensorVariable) -> TensorVariable:
return fgraph.outputs[0]


-class MixtureRV(Op):
+class MixtureRV(MeasurableOpMixin, Op):
"""A placeholder used to specify a log-likelihood for a mixture sub-graph."""

__props__ = ("indices_end_idx", "out_dtype", "out_broadcastable")
@@ -235,9 +236,6 @@ def perform(self, node, inputs, outputs):
raise NotImplementedError("This is a stand-in Op.") # pragma: no cover


-MeasurableVariable.register(MixtureRV)


def get_stack_mixture_vars(
node: Apply,
) -> tuple[list[TensorVariable] | None, int | None]:
@@ -457,13 +455,10 @@ def logprob_switch_mixture(op, values, switch_cond, component_true, component_fa
)


-class MeasurableIfElse(IfElse):
+class MeasurableIfElse(MeasurableOpMixin, IfElse):
"""Measurable subclass of IfElse operator."""


-MeasurableVariable.register(MeasurableIfElse)


@node_rewriter([IfElse])
def useless_ifelse_outputs(fgraph, node):
"""Remove outputs that are shared across the IfElse branches."""
@@ -512,7 +507,7 @@ def find_measurable_ifelse_mixture(fgraph, node):
base_rvs = assume_measured_ir_outputs(valued_rvs, base_rvs)
if len(base_rvs) != op.n_outs * 2:
return None
-if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_rvs):
+if not all(var.owner and isinstance(var.owner.op, MeasurableOp) for var in base_rvs):
return None

return MeasurableIfElse(n_outs=op.n_outs).make_node(if_var, *base_rvs).outputs
22 changes: 5 additions & 17 deletions pymc/logprob/order.py
@@ -48,7 +48,7 @@
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.abstract import (
-MeasurableVariable,
+MeasurableOpMixin,
_logcdf_helper,
_logprob,
_logprob_helper,
@@ -59,20 +59,14 @@
from pymc.pytensorf import constant_fold


-class MeasurableMax(Max):
+class MeasurableMax(MeasurableOpMixin, Max):
"""A placeholder used to specify a log-likelihood for a max sub-graph."""


-MeasurableVariable.register(MeasurableMax)


-class MeasurableMaxDiscrete(Max):
+class MeasurableMaxDiscrete(MeasurableOpMixin, Max):
"""A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables"""


-MeasurableVariable.register(MeasurableMaxDiscrete)


@node_rewriter([Max])
def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
@@ -162,21 +156,15 @@ def max_logprob_discrete(op, values, base_rv, **kwargs):
return logprob


-class MeasurableMaxNeg(Max):
+class MeasurableMaxNeg(MeasurableOpMixin, Max):
"""A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.
This shows up in the graph of min, which is (neg(max(neg(x)))."""


-MeasurableVariable.register(MeasurableMaxNeg)


-class MeasurableDiscreteMaxNeg(Max):
+class MeasurableDiscreteMaxNeg(MeasurableOpMixin, Max):
"""A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables"""


-MeasurableVariable.register(MeasurableDiscreteMaxNeg)


@node_rewriter(tracks=[Max])
def find_measurable_max_neg(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
8 changes: 4 additions & 4 deletions pymc/logprob/rewriting.py
@@ -80,7 +80,7 @@
)
from pytensor.tensor.variable import TensorVariable

-from pymc.logprob.abstract import MeasurableVariable
+from pymc.logprob.abstract import MeasurableOp
from pymc.logprob.utils import DiracDelta

inc_subtensor_ops = (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)
@@ -138,7 +138,7 @@ def apply(self, fgraph):
continue
# This is where we filter only those nodes we care about:
# Nodes that have variables that we want to measure and are not yet measurable
-if isinstance(node.op, MeasurableVariable):
+if isinstance(node.op, MeasurableOp):
continue
if not any(out in rv_map_feature.needs_measuring for out in node.outputs):
continue
@@ -154,7 +154,7 @@ def apply(self, fgraph):
node_rewriter, "__name__", ""
)
# If we converted to a MeasurableVariable we're done here!
-if node not in fgraph.apply_nodes or isinstance(node.op, MeasurableVariable):
+if node not in fgraph.apply_nodes or isinstance(node.op, MeasurableOp):
# go to next node
break

@@ -273,7 +273,7 @@ def request_measurable(self, vars: Sequence[Variable]) -> list[Variable]:
# Input vars or valued vars can't be measured for derived expressions
if not var.owner or var in self.rv_values:
continue
-if isinstance(var.owner.op, MeasurableVariable):
+if isinstance(var.owner.op, MeasurableOp):
measurable.append(var)
else:
self.needs_measuring.add(var)
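Within the IR rewriting pass, the renamed check is what decides whether a node still needs attention: anything whose op is already a `MeasurableOp` is skipped, and rewriting of a node stops once it has been turned into one. A stripped-down sketch of that filter (illustrative only; the real `apply` also does `rv_map_feature` bookkeeping not shown here):

```python
from pymc.logprob.abstract import MeasurableOp


def still_needs_rewriting(node, needs_measuring):
    """Decide whether a node should be offered to the measurable IR rewrites."""
    # Already measurable: nothing left to do for this node.
    if isinstance(node.op, MeasurableOp):
        return False
    # Only bother with nodes whose outputs were actually requested.
    return any(out in needs_measuring for out in node.outputs)
```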
9 changes: 3 additions & 6 deletions pymc/logprob/scan.py
@@ -54,7 +54,7 @@
from pytensor.tensor.variable import TensorVariable
from pytensor.updates import OrderedUpdates

-from pymc.logprob.abstract import MeasurableVariable, _logprob
+from pymc.logprob.abstract import MeasurableOp, MeasurableOpMixin, _logprob
from pymc.logprob.basic import conditional_logp
from pymc.logprob.rewriting import (
PreserveRVMappings,
@@ -66,16 +66,13 @@
from pymc.logprob.utils import replace_rvs_by_values


-class MeasurableScan(Scan):
+class MeasurableScan(MeasurableOpMixin, Scan):
"""A placeholder used to specify a log-likelihood for a scan sub-graph."""

def __str__(self):
return f"Measurable({super().__str__()})"


-MeasurableVariable.register(MeasurableScan)


def convert_outer_out_to_in(
input_scan_args: ScanArgs,
outer_out_vars: Iterable[TensorVariable],
@@ -288,7 +285,7 @@ def get_random_outer_outputs(
io_type = oo_info.name[(oo_info.name.index("_", 6) + 1) :]
inner_out_type = f"inner_out_{io_type}"
io_var = getattr(scan_args, inner_out_type)[oo_info.index]
-if io_var.owner and isinstance(io_var.owner.op, MeasurableVariable):
+if io_var.owner and isinstance(io_var.owner.op, MeasurableOp):
rv_vars.append((n, oo_var, io_var))
return rv_vars
