Skip to content

Commit

Permalink
Fix bug in compute_deterministics
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Apr 12, 2024
1 parent 7ffd47d commit 6073d8d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pymc/sampling/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def compute_deterministics(

if var_names is None:
deterministics = model.deterministics
var_names = [det.name for det in deterministics]
else:
deterministics = [model[var_name] for var_name in var_names]
if not set(deterministics).issubset(set(model.deterministics)):
Expand All @@ -101,7 +102,7 @@ def compute_deterministics(
new_dataset = apply_function_over_dataset(
fn,
dataset[[rv.name for rv in model.free_RVs]],
output_var_names=[det.name for det in model.deterministics],
output_var_names=var_names,
dims=dims,
coords=coords,
sample_dims=sample_dims,
Expand Down
5 changes: 5 additions & 0 deletions tests/sampling/test_deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ def test_compute_deterministics():
assert extended_with_mu["mu"].dims == ("chain", "draw", "group")
assert_allclose(extended_with_mu["mu"], dataset["mu_raw"].cumsum("group"))

only_sigma = compute_deterministics(dataset, var_names=["sigma"], model=m, progressbar=False)
assert set(only_sigma.data_vars.variables) == {"sigma"}
assert only_sigma["sigma"].dims == ("chain", "draw")
assert_allclose(only_sigma["sigma"], np.exp(dataset["sigma_raw"]))


def test_docstring_example():
import pymc as pm
Expand Down

0 comments on commit 6073d8d

Please sign in to comment.