Skip to content

Commit

Permalink
Refactor pretty-printing and fix latex underscores, addresses 6508
Browse files Browse the repository at this point in the history
  • Loading branch information
covertg committed Feb 18, 2023
1 parent c2ce47f commit 8370d97
Show file tree
Hide file tree
Showing 4 changed files with 309 additions and 318 deletions.
8 changes: 5 additions & 3 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
)
from pymc.logprob.rewriting import logprob_rewrites_db
from pymc.model import BlockModelAccess
from pymc.printing import str_for_dist
from pymc.pytensorf import collect_default_updates, convert_observed_data
from pymc.util import UNSET, _add_future_warning_tag
from pymc.vartypes import string_types
Expand Down Expand Up @@ -320,9 +319,12 @@ def __new__(
)

# add in pretty-printing support
rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
from pymc.printing import str_for_model_var

rv_out.str_repr = types.MethodType(str_for_model_var, rv_out)
rv_out._repr_latex_ = types.MethodType(
functools.partial(str_for_dist, formatting="latex"), rv_out
functools.partial(str_for_model_var, formatting="latex"),
rv_out,
)

rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
Expand Down
16 changes: 6 additions & 10 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2015,15 +2015,13 @@ def Deterministic(name, var, model=None, dims=None):
model.deterministics.append(var)
model.add_named_variable(var, dims)

from pymc.printing import str_for_potential_or_deterministic
from pymc.printing import str_for_model_var

var.str_repr = types.MethodType(
functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var
functools.partial(str_for_model_var, dist_name="Deterministic"), var
)
var._repr_latex_ = types.MethodType(
functools.partial(
str_for_potential_or_deterministic, dist_name="Deterministic", formatting="latex"
),
functools.partial(str_for_model_var, dist_name="Deterministic", formatting="latex"),
var,
)

Expand All @@ -2047,15 +2045,13 @@ def Potential(name, var, model=None):
model.potentials.append(var)
model.add_named_variable(var)

from pymc.printing import str_for_potential_or_deterministic
from pymc.printing import str_for_model_var

var.str_repr = types.MethodType(
functools.partial(str_for_potential_or_deterministic, dist_name="Potential"), var
functools.partial(str_for_model_var, dist_name="Potential"), var
)
var._repr_latex_ = types.MethodType(
functools.partial(
str_for_potential_or_deterministic, dist_name="Potential", formatting="latex"
),
functools.partial(str_for_model_var, dist_name="Potential", formatting="latex"),
var,
)

Expand Down
Loading

0 comments on commit 8370d97

Please sign in to comment.