From 038adb152b7dc96c48d5010b85f268ad3788d28e Mon Sep 17 00:00:00 2001 From: Lance Chua Date: Mon, 29 Apr 2024 01:48:47 +0800 Subject: [PATCH 1/3] Update docs when list is passed to trace in sample for partial trace --- pymc/sampling/mcmc.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 7cbd2da98a..59a71b7938 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -630,10 +630,7 @@ def sample( else: kwargs["nuts"] = {"target_accept": kwargs.pop("target_accept")} if isinstance(trace, list): - raise DeprecationWarning( - "We have removed support for partial traces because it simplified things." - " Please open an issue if & why this is a problem for you." - ) + raise DeprecationWarning("Please use `var_names` keyword argument for partial traces.") model = modelcontext(model) if not model.free_RVs: From 45e6e7934a3d623bd21965468e59ca23ead805a1 Mon Sep 17 00:00:00 2001 From: Lance Chua Date: Mon, 29 Apr 2024 02:56:41 +0800 Subject: [PATCH 2/3] Change DeprecationWarning to ValueError Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/sampling/mcmc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 59a71b7938..f05fcc2dd0 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -630,7 +630,7 @@ def sample( else: kwargs["nuts"] = {"target_accept": kwargs.pop("target_accept")} if isinstance(trace, list): - raise DeprecationWarning("Please use `var_names` keyword argument for partial traces.") + raise ValueError("Please use `var_names` keyword argument for partial traces.") model = modelcontext(model) if not model.free_RVs: From 4498433c75a6298e80e92d9c829134bebc674aa3 Mon Sep 17 00:00:00 2001 From: Lance Chua Date: Mon, 29 Apr 2024 09:54:16 +0800 Subject: [PATCH 3/3] Update test_partial_trace_unsupported --- tests/sampling/test_mcmc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 647e8ada70..8611a7028b 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -507,11 +507,11 @@ def test_empty_model(): error.match("any free variables") -def test_partial_trace_unsupported(): +def test_partial_trace_with_trace_unsupported(): with pm.Model() as model: a = pm.Normal("a", mu=0, sigma=1) b = pm.Normal("b", mu=0, sigma=1) - with pytest.raises(DeprecationWarning, match="removed support"): + with pytest.raises(ValueError, match="var_names"): pm.sample(trace=[a])