From 8370d976f215f72cec9b5ad8652a3d1c6f7bbaef Mon Sep 17 00:00:00 2001 From: Cove Geary Date: Sat, 18 Feb 2023 14:13:33 -0500 Subject: [PATCH] Refactor pretty-printing and fix latex underscores, addresses 6508 --- pymc/distributions/distribution.py | 8 +- pymc/model.py | 16 +- pymc/printing.py | 367 ++++++++++++++--------------- pymc/tests/test_printing.py | 236 ++++++++++--------- 4 files changed, 309 insertions(+), 318 deletions(-) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 5ece0b1bb6..2228dc1885 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -57,7 +57,6 @@ ) from pymc.logprob.rewriting import logprob_rewrites_db from pymc.model import BlockModelAccess -from pymc.printing import str_for_dist from pymc.pytensorf import collect_default_updates, convert_observed_data from pymc.util import UNSET, _add_future_warning_tag from pymc.vartypes import string_types @@ -320,9 +319,12 @@ def __new__( ) # add in pretty-printing support - rv_out.str_repr = types.MethodType(str_for_dist, rv_out) + from pymc.printing import str_for_model_var + + rv_out.str_repr = types.MethodType(str_for_model_var, rv_out) rv_out._repr_latex_ = types.MethodType( - functools.partial(str_for_dist, formatting="latex"), rv_out + functools.partial(str_for_model_var, formatting="latex"), + rv_out, ) rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)") diff --git a/pymc/model.py b/pymc/model.py index 62bbf66605..ba45a50873 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -2015,15 +2015,13 @@ def Deterministic(name, var, model=None, dims=None): model.deterministics.append(var) model.add_named_variable(var, dims) - from pymc.printing import str_for_potential_or_deterministic + from pymc.printing import str_for_model_var var.str_repr = types.MethodType( - functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var + functools.partial(str_for_model_var, dist_name="Deterministic"), var ) var._repr_latex_ = types.MethodType( - functools.partial( - str_for_potential_or_deterministic, dist_name="Deterministic", formatting="latex" - ), + functools.partial(str_for_model_var, dist_name="Deterministic", formatting="latex"), var, ) @@ -2047,15 +2045,13 @@ def Potential(name, var, model=None): model.potentials.append(var) model.add_named_variable(var) - from pymc.printing import str_for_potential_or_deterministic + from pymc.printing import str_for_model_var var.str_repr = types.MethodType( - functools.partial(str_for_potential_or_deterministic, dist_name="Potential"), var + functools.partial(str_for_model_var, dist_name="Potential"), var ) var._repr_latex_ = types.MethodType( - functools.partial( - str_for_potential_or_deterministic, dist_name="Potential", formatting="latex" - ), + functools.partial(str_for_model_var, dist_name="Potential", formatting="latex"), var, ) diff --git a/pymc/printing.py b/pymc/printing.py index f3ce014384..7094a08640 100644 --- a/pymc/printing.py +++ b/pymc/printing.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools import itertools +import warnings -from typing import Union +from typing import List, Optional, Tuple, Union from pytensor.compile import SharedVariable from pytensor.graph.basic import Constant, walk @@ -26,91 +28,52 @@ RandomStateSharedVariable, ) +from pymc.distributions.distribution import SymbolicRandomVariable from pymc.model import Model -__all__ = [ - "str_for_dist", - "str_for_model", - "str_for_potential_or_deterministic", -] +__all__ = ["str_for_model", "str_for_model_var"] -def str_for_dist( - dist: TensorVariable, formatting: str = "plain", include_params: bool = True +def str_for_model_var( + var: TensorVariable, formatting: str = "plain", dist_name: Optional[str] = None, **kwargs ) -> str: - """Make a human-readable string representation of a Distribution in a model, either - LaTeX or plain, optionally with distribution parameter values included.""" + """Make a human-readable string representation of a Model variable. - if include_params: - # first 3 args are always (rng, size, dtype), rest is relevant for distribution - if isinstance(dist.owner.op, RandomVariable): - dist_args = [ - _str_for_input_var(x, formatting=formatting) for x in dist.owner.inputs[3:] - ] - else: - dist_args = [ - _str_for_input_var(x, formatting=formatting) - for x in dist.owner.inputs - if not isinstance(x, (RandomStateSharedVariable, RandomGeneratorSharedVariable)) - ] - - print_name = dist.name - - if "latex" in formatting: - if print_name is not None: - print_name = r"\text{" + _latex_escape(dist.name.strip("$")) + "}" - - op_name = ( - dist.owner.op._print_name[1] - if hasattr(dist.owner.op, "_print_name") - else r"\\operatorname{Unknown}" + Intended for Distribution, Deterministic, and Potential. + """ + var_name, dist_name, args_str = _get_varname_distname_args( + var, formatting=formatting, dist_name=dist_name + ) + if "include_params" in kwargs: + warnings.warn( + "The `include_params` argument has been deprecated. Future versions will always include parameter values.", + FutureWarning, + stacklevel=2, ) - if include_params: - if print_name: - return r"${} \sim {}({})$".format( - print_name, op_name, ",~".join([d.strip("$") for d in dist_args]) - ) - else: - return r"${}({})$".format(op_name, ",~".join([d.strip("$") for d in dist_args])) - - else: - if print_name: - return rf"${print_name} \sim {op_name}$" - else: - return rf"${op_name}$" - - else: # plain - dist_name = ( - dist.owner.op._print_name[0] if hasattr(dist.owner.op, "_print_name") else "Unknown" + # When removing, remove **kwargs from here and str_for_model(). + if not kwargs["include_params"]: + args_str = "" + if formatting == "latex": + out = rf"${var_name} \sim {dist_name}({args_str})$" + elif formatting == "plain": + out = f"{var_name} ~ {dist_name}({args_str})" + else: + raise ValueError( + f"Formatting method not recognized: {formatting}. Supported formatting: [latex, plain]." ) - if include_params: - if print_name: - return r"{} ~ {}({})".format(print_name, dist_name, ", ".join(dist_args)) - else: - return r"{}({})".format(dist_name, ", ".join(dist_args)) - else: - if print_name: - return rf"{print_name} ~ {dist_name}" - else: - return dist_name + return out -def str_for_model(model: Model, formatting: str = "plain", include_params: bool = True) -> str: - """Make a human-readable string representation of Model, listing all random variables - and their distributions, optionally including parameter values.""" +def str_for_model(model: Model, formatting: str = "plain", **kwargs) -> str: + """Make a human-readable string representation of Model including all random + variables and their distributions. + """ all_rv = itertools.chain(model.unobserved_RVs, model.observed_RVs, model.potentials) - - rv_reprs = [rv.str_repr(formatting=formatting, include_params=include_params) for rv in all_rv] - rv_reprs = [rv_repr for rv_repr in rv_reprs if "TransformedDistribution()" not in rv_repr] - + rv_reprs = [rv.str_repr(formatting=formatting, **kwargs) for rv in all_rv] if not rv_reprs: return "" if "latex" in formatting: - rv_reprs = [ - rv_repr.replace(r"\sim", r"&\sim &").strip("$") - for rv_repr in rv_reprs - if rv_repr is not None - ] + rv_reprs = [rv_repr.replace(r"\sim", r"&\sim &").strip("$") for rv_repr in rv_reprs] return r"""$$ \begin{{array}}{{rcl}} {} @@ -130,148 +93,176 @@ def str_for_model(model: Model, formatting: str = "plain", include_params: bool return "\n".join(rv_reprs) -def str_for_potential_or_deterministic( - var: TensorVariable, - formatting: str = "plain", - include_params: bool = True, - dist_name: str = "Deterministic", -) -> str: - """Make a human-readable string representation of a Deterministic or Potential in a model, either - LaTeX or plain, optionally with distribution parameter values included.""" - print_name = var.name if var.name is not None else "" - if "latex" in formatting: - print_name = r"\text{" + _latex_escape(print_name.strip("$")) + "}" - if include_params: - return rf"${print_name} \sim \operatorname{{{dist_name}}}({_str_for_expression(var, formatting=formatting)})$" - else: - return rf"${print_name} \sim \operatorname{{{dist_name}}}$" - else: # plain - if include_params: - return rf"{print_name} ~ {dist_name}({_str_for_expression(var, formatting=formatting)})" - else: - return rf"{print_name} ~ {dist_name}" - +def _get_varname_distname_args( + var: TensorVariable, formatting: str, dist_name: Optional[str] = None +) -> Tuple[str, str, str]: + """Generate formatted strings for the name, distribution name, and + arguments list of a Model variable. + """ + # Name and distribution name + name = var.name if var.name is not None else "" + if not dist_name and hasattr(var.owner.op, "_print_name"): + # The _print_name tuple is necessary for maximum prettiness because a few RVs + # use special formatting (e.g. superscripts) for their latex print name + dist_name = ( + var.owner.op._print_name[1] if formatting == "latex" else var.owner.op._print_name[0] + ) + elif not dist_name: + dist_name = "Unknown" + if formatting == "latex": + name = _latex_clean_command(name, command="text") + dist_name = _latex_clean_command(dist_name, command="operatorname") + # Arguments passed to the distribution or expression + if isinstance(var.owner.op, RandomVariable): + # var is the RV from a Distribution. + dist_args = var.owner.inputs[3:] # First 3 inputs are always rng, size, dtype + elif isinstance(var.owner.op, SymbolicRandomVariable): + # var is a symbolic RV from a Distribution. + dist_args = [ + x + for x in var.owner.inputs + if not isinstance(x, (RandomStateSharedVariable, RandomGeneratorSharedVariable)) + ] + elif _is_potential_or_deterministic(var): + # var is a Deterministic or a Potential. + dist_args = _walk_expression_args(var) + else: + raise ValueError(f"Unable to parse arguments for variable") + args_str = _str_for_args_list(dist_args, formatting=formatting) + if _is_potential_or_deterministic(var): + args_str = f"f({args_str})" # TODO do we still want to do this? -def _str_for_input_var(var: Variable, formatting: str) -> str: - # Avoid circular import - from pymc.distributions.distribution import SymbolicRandomVariable + # These three strings are now formatted according to `formatting`. If latex, they + # just ultimately need to be wrapped in $. + return name, dist_name, args_str - def _is_potential_or_deterministic(var: Variable) -> bool: - try: - return var.str_repr.__func__.func is str_for_potential_or_deterministic - except AttributeError: - # in case other code overrides str_repr, fallback - return False +def _str_for_input_var(var: Variable, formatting: str) -> str: + """Make a human-readable string representation for a variable that + serves as input to another variable.""" + # Check for constants first, because they won't have var.owner if isinstance(var, (Constant, SharedVariable)): - return _str_for_constant(var, formatting) - elif isinstance( - var.owner.op, (RandomVariable, SymbolicRandomVariable) - ) or _is_potential_or_deterministic(var): - # show the names for RandomVariables, Deterministics, and Potentials, rather - # than the full expression - return _str_for_input_rv(var, formatting) + # Give the constant value if it's small enough, else basic type info. + # Previously _str_for_constant() + if isinstance(var, Constant): + var_data = var.data + var_type = "constant" + else: + var_data = var.get_value() + var_type = "shared" + if var_data.size == 1: + return f"{var_data.flatten()[0]:.3g}" + else: + return f"<{var_type} {var_data.shape}>" # TODO shape info or nah? elif isinstance(var.owner.op, DimShuffle): - return _str_for_input_var(var.owner.inputs[0], formatting) - else: - return _str_for_expression(var, formatting) - - -def _str_for_input_rv(var: Variable, formatting: str) -> str: - _str = ( - var.name - if var.name is not None - else str_for_dist(var, formatting=formatting, include_params=True) - ) - if "latex" in formatting: - return _latex_text_format(_latex_escape(_str.strip("$"))) - else: - return _str - - -def _str_for_constant(var: Union[Constant, SharedVariable], formatting: str) -> str: - if isinstance(var, Constant): - var_data = var.data - var_type = "constant" - else: - var_data = var.get_value() - var_type = "shared" - - if len(var_data.shape) == 0: - return f"{var_data:.3g}" - elif len(var_data.shape) == 1 and var_data.shape[0] == 1: - return f"{var_data[0]:.3g}" - elif "latex" in formatting: - return rf"\text{{<{var_type}>}}" + # Recurse + return _str_for_input_var(var.owner.inputs[0], formatting=formatting) + elif _is_potential_or_deterministic(var) or isinstance( + var.owner.op, (RandomVariable, SymbolicRandomVariable) + ): + if var.name: + # Give the name of the RV/Potential/Deterministic if available + return var.name + else: + # But if rv comes from .dist() we print the distribution with its args + _, dist_name, args_str = _get_varname_distname_args(var, formatting=formatting) + return f"{dist_name}({args_str})" + elif hasattr(var, "owner") and var.owner: + # Return an "expression" i.e. indicate that this variable is a function of other + # variables. Looks like f(arg1, ..., argN). Previously _str_for_expression() + args = _walk_expression_args(var) + args_str = _str_for_args_list(args, formatting=formatting) + return f"f({args_str})" else: - return rf"<{var_type}>" + raise ValueError("Unidentified variable in dist or expression args") -def _str_for_expression(var: Variable, formatting: str) -> str: - # Avoid circular import - from pymc.distributions.distribution import SymbolicRandomVariable +def _walk_expression_args(var: Variable) -> List[Variable]: + """Find all arguments of an expression""" - # construct a string like f(a1, ..., aN) listing all random variables a as arguments def _expand(x): if x.owner and (not isinstance(x.owner.op, (RandomVariable, SymbolicRandomVariable))): return reversed(x.owner.inputs) - parents = [ + return [ x for x in walk(nodes=var.owner.inputs, expand=_expand) if x.owner and isinstance(x.owner.op, (RandomVariable, SymbolicRandomVariable)) ] - names = [x.name for x in parents] - - if "latex" in formatting: - return ( - r"f(" - + ",~".join([_latex_text_format(_latex_escape(n.strip("$"))) for n in names]) - + ")" - ) - else: - return r"f(" + ", ".join([n.strip("$") for n in names]) + ")" -def _latex_text_format(text: str) -> str: - if r"\operatorname{" in text: - return text +def _str_for_args_list(args: List[Variable], formatting: str) -> str: + """Create a human-readable string representation for the list of inputs + to a distribution or expression.""" + strs = [_str_for_input_var(x, formatting=formatting) for x in args] + if formatting == "latex": + # Format the str as \text{} only if it hasn't been formatted yet and it isn't numeric + strs_formatted = [ + s + if r"\text" in s or r"\operatorname" in s or s == "f()" or s.replace(".", "").isdigit() + else _latex_clean_command(s, command="text") + for s in strs + ] + return ",~".join(strs_formatted) else: - return r"\text{" + text + "}" - - -def _latex_escape(text: str) -> str: - # Note that this is *NOT* a proper LaTeX escaper, on purpose. _repr_latex_ is - # primarily used in the context of Jupyter notebooks, which render using MathJax. - # MathJax is a subset of LaTeX proper, which expects only $ to be escaped. If we were - # to also escape e.g. _ (replace with \_), then "\_" will show up in the output, etc. - return text.replace("$", r"\$") - - -def _default_repr_pretty(obj: Union[TensorVariable, Model], p, cycle): - """Handy plug-in method to instruct IPython-like REPLs to use our str_repr above.""" - # we know that our str_repr does not recurse, so we can ignore cycle - try: - output = obj.str_repr() - # Find newlines and replace them with p.break_() - # (see IPython.lib.pretty._repr_pprint) - lines = output.splitlines() - with p.group(): - for idx, output_line in enumerate(lines): - if idx: - p.break_() - p.text(output_line) - except AttributeError: - # the default fallback option (no str_repr method) - IPython.lib.pretty._repr_pprint(obj, p, cycle) + return ", ".join(strs) + + +def _latex_clean_command(text: str, command: str) -> str: + r"""Prepare text for LaTeX and maybe wrap it in a \command{}.""" + text = text.replace("$", r"\$") # TODO do we want to keep dollar signs or strip them? + if not text.startswith(rf"\{command}"): + # The printing module is designed such that text never passes through this + # function more than once. However, in some cases the text may have already + # been formatted for LaTeX -- such as when accessing a RandomVariable's + # pre-specified _print_name. In these cases we avoid the double wrap. + text = rf"\{command}{{{text}}}" + # This escape is a workaround for pymc#6508. MathJax is the latex engine in Jupyter + # notebooks, but its behavior in \text commands deviates from canonical LaTeX: stuff + # in the \text block will be rendered more or less verbatim. Meanwhile, other + # engines such as KaTeX behave differently and expect certain characters to be + # escaped to be typeset as desired. We work around this by escaping out of the + # command itself, writing the character, then continuing on with the same command. + if command == "text": + text = text.replace("_", rf"}}\_\{command}{{") + text = text.replace("~", rf"}}~\{command}{{") + return text + + +def _is_potential_or_deterministic(var: Variable) -> bool: + # This is a bit hacky but seems like the best we got. We should write + # a test to make sure that Deterministic and Potential don't get updated + # without also modifying this function. + if ( + hasattr(var, "str_repr") + and callable(var.str_repr) + and isinstance(var.str_repr.__func__, functools.partial) + ): + args = [*var.str_repr.__func__.args, *var.str_repr.__func__.keywords.values()] + return "Deterministic" in args or "Potential" in args + return False + + +def _pymc_pprint(obj: Union[TensorVariable, Model], *args, **kwargs): + """Pretty-print method that instructs IPython to use our `str_repr()`. + + Note that `str_repr()` is assigned in the initialization functions for + the objects of interest, i.e. Distribution and Model. + """ + if hasattr(obj, "str_repr") and callable(obj.str_repr): + s = obj.str_repr() + else: + s = repr(obj) + # Allow IPython to deal with newlines and the actual printing to shell + IPython.lib.pretty._repr_pprint(s, *args, **kwargs) try: # register our custom pretty printer in ipython shells import IPython - IPython.lib.pretty.for_type(TensorVariable, _default_repr_pretty) - IPython.lib.pretty.for_type(Model, _default_repr_pretty) + IPython.lib.pretty.for_type(TensorVariable, _pymc_pprint) + IPython.lib.pretty.for_type(Model, _pymc_pprint) except (ModuleNotFoundError, AttributeError): # no ipython shell pass diff --git a/pymc/tests/test_printing.py b/pymc/tests/test_printing.py index 62c360ec7f..c503ffb5fe 100644 --- a/pymc/tests/test_printing.py +++ b/pymc/tests/test_printing.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import numpy as np +import pytest from pymc import Bernoulli, Censored, HalfCauchy, Mixture, StudentT from pymc.distributions import ( @@ -31,52 +32,41 @@ class BaseTestStrAndLatexRepr: - def test__repr_latex_(self): - for distribution, tex in zip(self.distributions, self.expected[("latex", True)]): - assert distribution._repr_latex_() == tex - - model_tex = self.model._repr_latex_() - - # make sure each variable is in the model - for tex in self.expected[("latex", True)]: - for segment in tex.strip("$").split(r"\sim"): - assert segment in model_tex - - def test_str_repr(self): - for str_format in self.formats: - for dist, text in zip(self.distributions, self.expected[str_format]): - assert dist.str_repr(*str_format) == text - - model_text = self.model.str_repr(*str_format) - for text in self.expected[str_format]: - if str_format[0] == "latex": - for segment in text.strip("$").split(r"\sim"): - assert segment in model_text - else: - assert text in model_text + @pytest.mark.parametrize("formatting", ["plain", "latex"]) + def test_var_str_repr(self, formatting): + for dist, expected in zip(self.distributions, self.expected[formatting]): + assert dist.str_repr(formatting) == expected + + @pytest.mark.parametrize("formatting", ["plain", "latex"]) + def test_model_str_repr(self, formatting): + # Tests that each printed variable is in the model string, but + # does not test for ordering + model_str = self.model.str_repr(formatting) + for expected in self.expected[formatting]: + if formatting == "latex": + expected = expected.replace(r"\sim", r"&\sim &").strip("$") + assert expected in model_str class TestMonolith(BaseTestStrAndLatexRepr): def setup_class(self): # True parameter values - alpha, sigma = 1, 1 - beta = [1, 2.5] - + alpha0, sigma0, beta0 = 1, 1, [1, 2.5] # Size of dataset size = 100 - # Predictor variable X = np.random.normal(size=(size, 2)).dot(np.array([[1, 0], [0, 0.2]])) - # Simulate outcome variable - Y = alpha + X.dot(beta) + np.random.randn(size) * sigma + Y = alpha0 + X.dot(beta0) + np.random.randn(size) * sigma0 with Model() as self.model: # TODO: some variables commented out here as they're not working properly # in v4 yet (9-jul-2021), so doesn't make sense to test str/latex for them # Priors for unknown model parameters alpha = Normal("alpha", mu=0, sigma=10) - b = Normal("beta", mu=0, sigma=10, size=(2,), observed=beta) + beta = Normal( + "beta", mu=0, sigma=10, size=(2,), observed=beta0 + ) # TODO why is this observed? sigma = HalfNormal("sigma", sigma=1) # Test Cholesky parameterization @@ -86,7 +76,8 @@ def setup_class(self): # nb1 = pm.NegativeBinomial( # "nb_with_mu_alpha", mu=pm.Normal("nbmu"), alpha=pm.Gamma("nbalpha", mu=6, sigma=1) # ) - nb2 = NegativeBinomial("nb_with_p_n", p=Uniform("nbp"), n=10) + nbp = Uniform("nbp") + nb_with_p_n = NegativeBinomial("nb_with_p_n", p=nbp, n=10) # SymbolicRV zip = ZeroInflatedPoisson("zip", 0.5, 5) @@ -98,7 +89,7 @@ def setup_class(self): nested_mix = Mixture("nested_mix", w, [comp_1, comp_2]) # Expected value of outcome - mu = Deterministic("mu", floatX(alpha + dot(X, b))) + mu = Deterministic("mu", floatX(alpha + dot(X, beta))) # add a bounded variable as well # bound_var = Bound(Normal, lower=1.0)("bound_var", mu=0, sigma=10) @@ -126,71 +117,83 @@ def setup_class(self): # add a potential as well pot = Potential("pot", mu**2) - self.distributions = [alpha, sigma, mu, b, Z, nb2, zip, w, nested_mix, Y_obs, pot] - self.deterministics_or_potentials = [mu, pot] + self.distributions = [ + alpha, + sigma, + Z, + nbp, + nb_with_p_n, + zip, + w, + nested_mix, + kron_normal, + dm, + mu, + beta, + Y_obs, + pot, + ] # tuples of (formatting, include_params) - self.formats = [("plain", True), ("plain", False), ("latex", True), ("latex", False)] + # self.formats = [("plain", True), ("plain", False), ("latex", True), ("latex", False)] self.expected = { - ("plain", True): [ - r"alpha ~ N(0, 10)", - r"sigma ~ N**+(0, 1)", - r"mu ~ Deterministic(f(beta, alpha))", - r"beta ~ N(0, 10)", - r"Z ~ N(f(), f())", - r"nb_with_p_n ~ NB(10, nbp)", - r"zip ~ MarginalMixture(f(), DiracDelta(0), Pois(5))", - r"w ~ Dir()", - ( - r"nested_mix ~ MarginalMixture(w, " - r"MarginalMixture(f(), DiracDelta(0), Pois(5)), " - r"Censored(Bern(0.5), -1, 1))" - ), - r"Y_obs ~ N(mu, sigma)", - r"pot ~ Potential(f(beta, alpha))", - ], - ("plain", False): [ - r"alpha ~ N", - r"sigma ~ N**+", - r"mu ~ Deterministic", - r"beta ~ N", - r"Z ~ N", - r"nb_with_p_n ~ NB", - r"zip ~ MarginalMixture", - r"w ~ Dir", - r"nested_mix ~ MarginalMixture", - r"Y_obs ~ N", - r"pot ~ Potential", + "plain": [ + "alpha ~ N(0, 10)", + "sigma ~ N**+(0, 1)", + "Z ~ N(f(), f())", + "nbp ~ U(0, 1)", + "nb_with_p_n ~ NB(10, nbp)", + "zip ~ MarginalMixture(f(), DiracDelta(0), Pois(5))", + "w ~ Dir()", + "nested_mix ~ MarginalMixture(w, MarginalMixture(f(), DiracDelta(0), Pois(5)), Censored(Bern(0.5), -1, 1))", + "kron_normal ~ KroneckerNormal(, 0, , )", + "dm ~ DirichletMN(5, )", + "mu ~ Deterministic(f(beta, alpha))", + "beta ~ N(0, 10)", + "Y_obs ~ N(mu, sigma)", + "pot ~ Potential(f(beta, alpha))", ], - ("latex", True): [ + # ("plain", False): [ + # r"alpha ~ N", + # r"sigma ~ N**+", + # r"mu ~ Deterministic", + # r"beta ~ N", + # r"Z ~ N", + # r"nb_with_p_n ~ NB", + # r"zip ~ MarginalMixture", + # r"w ~ Dir", + # r"nested_mix ~ MarginalMixture", + # r"Y_obs ~ N", + # r"pot ~ Potential", + # ], + "latex": [ r"$\text{alpha} \sim \operatorname{N}(0,~10)$", r"$\text{sigma} \sim \operatorname{N^{+}}(0,~1)$", - r"$\text{mu} \sim \operatorname{Deterministic}(f(\text{beta},~\text{alpha}))$", - r"$\text{beta} \sim \operatorname{N}(0,~10)$", r"$\text{Z} \sim \operatorname{N}(f(),~f())$", - r"$\text{nb_with_p_n} \sim \operatorname{NB}(10,~\text{nbp})$", + r"$\text{nbp} \sim \operatorname{U}(0,~1)$", + r"$\text{nb}\_\text{with}\_\text{p}\_\text{n} \sim \operatorname{NB}(10,~\text{nbp})$", r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Pois}(5))$", - r"$\text{w} \sim \operatorname{Dir}(\text{})$", - ( - r"$\text{nested_mix} \sim \operatorname{MarginalMixture}(\text{w}," - r"~\operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Pois}(5))," - r"~\operatorname{Censored}(\operatorname{Bern}(0.5),~-1,~1))$" - ), - r"$\text{Y_obs} \sim \operatorname{N}(\text{mu},~\text{sigma})$", + r"$\text{w} \sim \operatorname{Dir}(\text{})$", + r"$\text{nested}\_\text{mix} \sim \operatorname{MarginalMixture}(\text{w},~\operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Pois}(5)),~\operatorname{Censored}(\operatorname{Bern}(0.5),~\text{-1},~1))$", + r"$\text{kron}\_\text{normal} \sim \operatorname{KroneckerNormal}(\text{},~0,~\text{},~\text{})$", + r"$\text{dm} \sim \operatorname{DirichletMN}(5,~\text{})$", + r"$\text{mu} \sim \operatorname{Deterministic}(f(\text{beta},~\text{alpha}))$", + r"$\text{beta} \sim \operatorname{N}(0,~10)$", + r"$\text{Y}\_\text{obs} \sim \operatorname{N}(\text{mu},~\text{sigma})$", r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$", ], - ("latex", False): [ - r"$\text{alpha} \sim \operatorname{N}$", - r"$\text{sigma} \sim \operatorname{N^{+}}$", - r"$\text{mu} \sim \operatorname{Deterministic}$", - r"$\text{beta} \sim \operatorname{N}$", - r"$\text{Z} \sim \operatorname{N}$", - r"$\text{nb_with_p_n} \sim \operatorname{NB}$", - r"$\text{zip} \sim \operatorname{MarginalMixture}$", - r"$\text{w} \sim \operatorname{Dir}$", - r"$\text{nested_mix} \sim \operatorname{MarginalMixture}$", - r"$\text{Y_obs} \sim \operatorname{N}$", - r"$\text{pot} \sim \operatorname{Potential}$", - ], + # ("latex", False): [ + # r"$\text{alpha} \sim \operatorname{N}$", + # r"$\text{sigma} \sim \operatorname{N^{+}}$", + # r"$\text{mu} \sim \operatorname{Deterministic}$", + # r"$\text{beta} \sim \operatorname{N}$", + # r"$\text{Z} \sim \operatorname{N}$", + # r"$\text{nb_with_p_n} \sim \operatorname{NB}$", + # r"$\text{zip} \sim \operatorname{MarginalMixture}$", + # r"$\text{w} \sim \operatorname{Dir}$", + # r"$\text{nested_mix} \sim \operatorname{MarginalMixture}$", + # r"$\text{Y_obs} \sim \operatorname{N}$", + # r"$\text{pot} \sim \operatorname{Potential}$", + # ], } @@ -207,32 +210,32 @@ def setup_class(self): self.distributions = [a, b, c, d] # tuples of (formatting, include_params) - self.formats = [("plain", True), ("plain", False), ("latex", True), ("latex", False)] + # self.formats = [("plain", True), ("plain", False), ("latex", True), ("latex", False)] self.expected = { - ("plain", True): [ + "plain": [ r"a ~ N(2, 1)", - r"b ~ N(, 1)", + r"b ~ N(, 1)", r"c ~ N(2, 1)", - r"d ~ N(, 1)", - ], - ("plain", False): [ - r"a ~ N", - r"b ~ N", - r"c ~ N", - r"d ~ N", + r"d ~ N(, 1)", ], - ("latex", True): [ + # ("plain", False): [ + # r"a ~ N", + # r"b ~ N", + # r"c ~ N", + # r"d ~ N", + # ], + "latex": [ r"$\text{a} \sim \operatorname{N}(2,~1)$", - r"$\text{b} \sim \operatorname{N}(\text{},~1)$", + r"$\text{b} \sim \operatorname{N}(\text{},~1)$", r"$\text{c} \sim \operatorname{N}(2,~1)$", - r"$\text{d} \sim \operatorname{N}(\text{},~1)$", - ], - ("latex", False): [ - r"$\text{a} \sim \operatorname{N}$", - r"$\text{b} \sim \operatorname{N}$", - r"$\text{c} \sim \operatorname{N}$", - r"$\text{d} \sim \operatorname{N}$", + r"$\text{d} \sim \operatorname{N}(\text{},~1)$", ], + # ("latex", False): [ + # r"$\text{a} \sim \operatorname{N}$", + # r"$\text{b} \sim \operatorname{N}$", + # r"$\text{c} \sim \operatorname{N}$", + # r"$\text{d} \sim \operatorname{N}$", + # ], } @@ -248,11 +251,11 @@ def test_model_latex_repr_three_levels_model(): latex_repr = censored_model.str_repr(formatting="latex") expected = [ "$$", - "\\begin{array}{rcl}", - "\\text{mu} &\\sim & \\operatorname{N}(0,~5)\\\\\\text{sigma} &\\sim & " - "\\operatorname{C^{+}}(0,~2.5)\\\\\\text{censored_normal} &\\sim & " - "\\operatorname{Censored}(\\operatorname{N}(\\text{mu},~\\text{sigma}),~-2,~2)", - "\\end{array}", + r"\begin{array}{rcl}", + r"\text{mu} &\sim & \operatorname{N}(0,~5)\\\text{sigma} &\sim & " + r"\operatorname{C^{+}}(0,~2.5)\\\text{censored}\_\text{normal} &\sim & " + r"\operatorname{Censored}(\operatorname{N}(\text{mu},~\text{sigma}),~\text{-2},~2)", + r"\end{array}", "$$", ] assert [line.strip() for line in latex_repr.split("\n")] == expected @@ -266,11 +269,10 @@ def test_model_latex_repr_mixture_model(): latex_repr = mix_model.str_repr(formatting="latex") expected = [ "$$", - "\\begin{array}{rcl}", - "\\text{w} &\\sim & " - "\\operatorname{Dir}(\\text{})\\\\\\text{mix} &\\sim & " - "\\operatorname{MarginalMixture}(\\text{w},~\\operatorname{N}(0,~5),~\\operatorname{StudentT}(7,~0,~1))", - "\\end{array}", + r"\begin{array}{rcl}", + r"\text{w} &\sim & \operatorname{Dir}(\text{})\\\text{mix} &\sim & " + r"\operatorname{MarginalMixture}(\text{w},~\operatorname{N}(0,~5),~\operatorname{StudentT}(7,~0,~1))", + r"\end{array}", "$$", ] assert [line.strip() for line in latex_repr.split("\n")] == expected