Skip to content

Commit

Permalink
implemented fix for escaping underscores in latex repr and added a un… (
Browse files Browse the repository at this point in the history
pymc-devs#7501)

* implemented fix for escaping underscores in latex repr and added a unit test

* updated unit test staticmethod to include underscore in var name

* add underscore escape fix to distribution repr as well as model repr, fixed testing to expect underscores in LaTeX representation to be escaped

* added cleaner method using re to escape underscores, added cleaner test to assert underscores are escaped
  • Loading branch information
Dekermanjian authored and aphc14 committed Sep 17, 2024
1 parent 7bd0f44 commit 33f701a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 7 deletions.
11 changes: 11 additions & 0 deletions pymc/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.


import re

from functools import partial

from pytensor.compile import SharedVariable
Expand Down Expand Up @@ -58,6 +60,7 @@ def str_for_dist(
if "latex" in formatting:
if print_name is not None:
print_name = r"\text{" + _latex_escape(print_name.strip("$")) + "}"
print_name = _format_underscore(print_name)

op_name = (
dist.owner.op._print_name[1]
Expand Down Expand Up @@ -114,6 +117,7 @@ def str_for_model(model: Model, formatting: str = "plain", include_params: bool
if not var_reprs:
return ""
if "latex" in formatting:
var_reprs = [_format_underscore(x) for x in var_reprs]
var_reprs = [
var_repr.replace(r"\sim", r"&\sim &").strip("$")
for var_repr in var_reprs
Expand Down Expand Up @@ -295,3 +299,10 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle):
except (ModuleNotFoundError, AttributeError):
# no ipython shell
pass


def _format_underscore(variable: str) -> str:
"""
Escapes all unescaped underscores in the variable name for LaTeX representation.
"""
return re.sub(r"(?<!\\)_", r"\\_", variable)
34 changes: 27 additions & 7 deletions tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from pytensor.tensor.random import normal
Expand Down Expand Up @@ -171,15 +172,15 @@ def setup_class(self):
r"$\text{mu} \sim \operatorname{Deterministic}(f(\text{beta},~\text{alpha}))$",
r"$\text{beta} \sim \operatorname{Normal}(0,~10)$",
r"$\text{Z} \sim \operatorname{MultivariateNormal}(f(),~f())$",
r"$\text{nb_with_p_n} \sim \operatorname{NegativeBinomial}(10,~\text{nbp})$",
r"$\text{nb\_with\_p\_n} \sim \operatorname{NegativeBinomial}(10,~\text{nbp})$",
r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Poisson}(5))$",
r"$\text{w} \sim \operatorname{Dirichlet}(\text{<constant>})$",
(
r"$\text{nested_mix} \sim \operatorname{MarginalMixture}(\text{w},"
r"$\text{nested\_mix} \sim \operatorname{MarginalMixture}(\text{w},"
r"~\operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Poisson}(5)),"
r"~\operatorname{Censored}(\operatorname{Bernoulli}(0.5),~-1,~1))$"
),
r"$\text{Y_obs} \sim \operatorname{Normal}(\text{mu},~\text{sigma})$",
r"$\text{Y\_obs} \sim \operatorname{Normal}(\text{mu},~\text{sigma})$",
r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$",
r"$\text{pred} \sim \operatorname{Deterministic}(f(\text{<normal>}))",
],
Expand All @@ -189,11 +190,11 @@ def setup_class(self):
r"$\text{mu} \sim \operatorname{Deterministic}$",
r"$\text{beta} \sim \operatorname{Normal}$",
r"$\text{Z} \sim \operatorname{MultivariateNormal}$",
r"$\text{nb_with_p_n} \sim \operatorname{NegativeBinomial}$",
r"$\text{nb\_with\_p\_n} \sim \operatorname{NegativeBinomial}$",
r"$\text{zip} \sim \operatorname{MarginalMixture}$",
r"$\text{w} \sim \operatorname{Dirichlet}$",
r"$\text{nested_mix} \sim \operatorname{MarginalMixture}$",
r"$\text{Y_obs} \sim \operatorname{Normal}$",
r"$\text{nested\_mix} \sim \operatorname{MarginalMixture}$",
r"$\text{Y\_obs} \sim \operatorname{Normal}$",
r"$\text{pot} \sim \operatorname{Potential}$",
r"$\text{pred} \sim \operatorname{Deterministic}",
],
Expand Down Expand Up @@ -256,7 +257,7 @@ def test_model_latex_repr_three_levels_model():
"$$",
"\\begin{array}{rcl}",
"\\text{mu} &\\sim & \\operatorname{Normal}(0,~5)\\\\\\text{sigma} &\\sim & "
"\\operatorname{HalfCauchy}(0,~2.5)\\\\\\text{censored_normal} &\\sim & "
"\\operatorname{HalfCauchy}(0,~2.5)\\\\\\text{censored\\_normal} &\\sim & "
"\\operatorname{Censored}(\\operatorname{Normal}(\\text{mu},~\\text{sigma}),~-2,~2)",
"\\end{array}",
"$$",
Expand Down Expand Up @@ -316,3 +317,22 @@ def random(rng, mu, size):

str_repr = model.str_repr(include_params=False)
assert str_repr == "\n".join(["x ~ CustomDistNormal", "y ~ CustomRandomNormal"])


class TestLatexRepr:
@staticmethod
def simple_model() -> Model:
with Model() as simple_model:
error = HalfNormal("error", 0.5)
alpha_a = Normal("alpha_a", 0, 1)
Normal("y", alpha_a, error)
return simple_model

def test_latex_escaped_underscore(self):
"""
Ensures that all underscores in model variable names are properly escaped for LaTeX representation
"""
model = self.simple_model()
model_str = model.str_repr(formatting="latex")
assert "\\_" in model_str
assert "_" not in model_str.replace("\\_", "")

0 comments on commit 33f701a

Please sign in to comment.