From 005032f10099188fea86f63b6baa46a27867983f Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 8 Jun 2021 09:56:32 -0600 Subject: [PATCH] Fix LocScaleReparam(1.0) (#2863) --- pyro/infer/reparam/loc_scale.py | 2 +- tests/infer/reparam/test_loc_scale.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyro/infer/reparam/loc_scale.py b/pyro/infer/reparam/loc_scale.py index 8bc8613c40..628584492d 100644 --- a/pyro/infer/reparam/loc_scale.py +++ b/pyro/infer/reparam/loc_scale.py @@ -49,7 +49,7 @@ def __call__(self, name, fn, obs): assert obs is None, "LocScaleReparam does not support observe statements" centered = self.centered if is_identically_one(centered): - return name, fn, obs + return fn, obs event_shape = fn.event_shape fn, event_dim = self._unwrap(fn) diff --git a/tests/infer/reparam/test_loc_scale.py b/tests/infer/reparam/test_loc_scale.py index 8c2b41b2e4..40fd64f4ed 100644 --- a/tests/infer/reparam/test_loc_scale.py +++ b/tests/infer/reparam/test_loc_scale.py @@ -46,9 +46,9 @@ def model(): expected_probe = get_moments(value) if "dist_type" == "Normal": - reparam = LocScaleReparam() + reparam = LocScaleReparam(centered) else: - reparam = LocScaleReparam(shape_params=["df"]) + reparam = LocScaleReparam(centered, shape_params=["df"]) reparam_model = poutine.reparam(model, {"x": reparam}) value = poutine.trace(reparam_model).get_trace().nodes["x"]["value"] actual_probe = get_moments(value)