Skip to content

Commit

Permalink
Refactoring printing module with aeppl/printing.py with additions
Browse files Browse the repository at this point in the history
  • Loading branch information
larryshamalama committed Jan 9, 2023
1 parent 76003e8 commit e9d7210
Show file tree
Hide file tree
Showing 5 changed files with 565 additions and 229 deletions.
8 changes: 6 additions & 2 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,9 +1249,13 @@ def _zero_inflated_mixture(*, name, nonzero_p, nonzero_dist, **kwargs):
nonzero_dist,
]
if name is not None:
return Mixture(name, weights, comp_dists, **kwargs)
out_rv = Mixture(name, weights, comp_dists, **kwargs)
else:
return Mixture.dist(weights, comp_dists, **kwargs)
out_rv = Mixture.dist(weights, comp_dists, **kwargs)

out_rv.is_zero_inflated = True

return out_rv


class ZeroInflatedPoisson:
Expand Down
8 changes: 0 additions & 8 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import contextvars
import functools
import sys
import types
import warnings

from abc import ABCMeta
Expand Down Expand Up @@ -57,7 +56,6 @@
)
from pymc.logprob.rewriting import logprob_rewrites_db
from pymc.model import BlockModelAccess
from pymc.printing import str_for_dist
from pymc.pytensorf import collect_default_updates, convert_observed_data
from pymc.util import UNSET, _add_future_warning_tag
from pymc.vartypes import string_types
Expand Down Expand Up @@ -317,12 +315,6 @@ def __new__(
initval=initval,
)

# add in pretty-printing support
rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
rv_out._repr_latex_ = types.MethodType(
functools.partial(str_for_dist, formatting="latex"), rv_out
)

rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
rv_out.random = _make_nice_attr_error("rv.random()", "pm.draw(rv)")
Expand Down
1 change: 1 addition & 0 deletions pymc/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ class Mixture(Distribution):
"""

rv_type = MarginalMixtureRV
is_zero_inflated = False

@classmethod
def dist(cls, w, comp_dists, **kwargs):
Expand Down
49 changes: 23 additions & 26 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import itertools
import threading
import types
import warnings
Expand Down Expand Up @@ -589,10 +589,7 @@ def __init__(

from pymc.printing import str_for_model

self.str_repr = types.MethodType(str_for_model, self)
self._repr_latex_ = types.MethodType(
functools.partial(str_for_model, formatting="latex"), self
)
self._repr_latex_ = types.MethodType(str_for_model, self)

@property
def model(self):
Expand Down Expand Up @@ -2015,17 +2012,17 @@ def Deterministic(name, var, model=None, dims=None):
model.deterministics.append(var)
model.add_named_variable(var, dims)

from pymc.printing import str_for_potential_or_deterministic
# from pymc.printing import str_for_potential_or_deterministic

var.str_repr = types.MethodType(
functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var
)
var._repr_latex_ = types.MethodType(
functools.partial(
str_for_potential_or_deterministic, dist_name="Deterministic", formatting="latex"
),
var,
)
# var.str_repr = types.MethodType(
# functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var
# )
# var._repr_latex_ = types.MethodType(
# functools.partial(
# str_for_potential_or_deterministic, dist_name="Deterministic", formatting="latex"
# ),
# var,
# )

return var

Expand All @@ -2047,16 +2044,16 @@ def Potential(name, var, model=None):
model.potentials.append(var)
model.add_named_variable(var)

from pymc.printing import str_for_potential_or_deterministic

var.str_repr = types.MethodType(
functools.partial(str_for_potential_or_deterministic, dist_name="Potential"), var
)
var._repr_latex_ = types.MethodType(
functools.partial(
str_for_potential_or_deterministic, dist_name="Potential", formatting="latex"
),
var,
)
# from pymc.printing import str_for_potential_or_deterministic

# var.str_repr = types.MethodType(
# functools.partial(str_for_potential_or_deterministic, dist_name="Potential"), var
# )
# var._repr_latex_ = types.MethodType(
# functools.partial(
# str_for_potential_or_deterministic, dist_name="Potential", formatting="latex"
# ),
# var,
# )

return var
Loading

0 comments on commit e9d7210

Please sign in to comment.