
Introduce value variables in logprob IR #7491

Merged
merged 6 commits into pymc-devs:main on Sep 11, 2024

Conversation

@ricardoV94 (Member) commented Sep 4, 2024

Description

This supersedes the stale #6918. It introduces ValuedRV nodes in the IR so that rewrites can reason transparently about conditioning points. The main purpose is to simplify the IR rewrite logic.

It also prevents default PyTensor rewrites from breaking the dependency on valued RVs (which was behind the bug in #6917).

It also fixes some limitations in derived Scans and makes them stricter. For instance, #6909 now fails explicitly instead of silently returning a wrong result.

Related Issue

#6917: Relationship between valued RVs lost during logp inference

📚 Documentation preview 📚: https://pymc--7491.org.readthedocs.build/en/7491/

@ricardoV94 (Member, Author):

The failing test should pass after the changes in #7480

codecov bot commented Sep 7, 2024

Codecov Report

Attention: Patch coverage is 95.33528% with 16 lines in your changes missing coverage. Please review.

Project coverage is 92.42%. Comparing base (2856062) to head (bac9f6e).
Report is 6 commits behind head on main.

Files with missing lines      Patch %   Lines
pymc/logprob/scan.py          88.00%    9 Missing ⚠️
pymc/logprob/transforms.py    87.50%    3 Missing ⚠️
pymc/logprob/mixture.py       96.55%    2 Missing ⚠️
pymc/logprob/abstract.py      95.45%    1 Missing ⚠️
pymc/logprob/cumsum.py        85.71%    1 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7491      +/-   ##
==========================================
+ Coverage   92.15%   92.42%   +0.27%     
==========================================
  Files         103      103              
  Lines       17208    17104     -104     
==========================================
- Hits        15858    15809      -49     
+ Misses       1350     1295      -55     
Files with missing lines          Coverage Δ
pymc/logprob/basic.py             94.28% <100.00%> (-0.09%) ⬇️
pymc/logprob/binary.py            96.29% <100.00%> (+2.04%) ⬆️
pymc/logprob/censoring.py         97.67% <100.00%> (+2.02%) ⬆️
pymc/logprob/checks.py            100.00% <100.00%> (+9.80%) ⬆️
pymc/logprob/order.py             96.77% <100.00%> (+3.02%) ⬆️
pymc/logprob/rewriting.py         100.00% <100.00%> (+10.24%) ⬆️
pymc/logprob/tensor.py            100.00% <100.00%> (+22.95%) ⬆️
pymc/logprob/transform_value.py   98.14% <100.00%> (+2.78%) ⬆️
pymc/logprob/utils.py             92.46% <100.00%> (+0.21%) ⬆️
pymc/testing.py                   89.77% <ø> (ø)
... and 5 more

@twiecki (Member) commented Sep 9, 2024

This is pretty substantial.

@larryshamalama (Member) left a comment

Hi @ricardoV94, great PR as always, especially with so much extension of the abstraction :)

A lot of my comments are requests to clarify some parts of the code, plus me thinking out loud as I went over the PR.

Are you able to add a test showing an IR graph after applying rewrites from early_measurable_ir_rewrites_db but before measurable_ir_rewrites_db? Perhaps this could highlight where exactly, and why, PromisedValuedRVs are needed.

pymc/logprob/abstract.py — 2 review threads (outdated, resolved)


class PromisedValuedRV(Op):
    r"""Marks a variable as being promised a valued variable in the logprob method."""
Member
Why can't we directly provide a corresponding ValuedRV during the rewrite? Just asking to clarify. I see that this is (only?) used for Join and MakeVector.

Member
Revisiting this comment after skimming a bit more

I now see that there is a set of newly defined early measurable IR rewrites, and I imagine that these promised valued RVs are temporarily substituted into the intermediate graph before the other logprob rewrites are applied. Why does this approach prevent breaking interdependencies?

@ricardoV94 (Member, Author) commented Sep 10, 2024

We're trying to fix the same problem of not breaking dependencies. The funny thing with Join/MakeVector is that they combine multiple RVs, potentially interdependent, into a single composite valued node. Only in the logp function is this value split and sent to each component, but we still want to prevent the same outer problem that motivated this PR. If you have:

import pytensor.tensor as pt

x1 = pt.random.normal() * 5
x2 = pt.random.normal(x1 * 8)
xs = pt.stack([x1, x2])
xs_vv = pt.vector("xs_vv", shape=(2,))

By the time you get to the logp of xs_vv, you basically want to do the same thing the new test does, plus concatenate the logp terms. To prevent the conditioning from breaking during the IR rewrites, we want to introduce the ValuedRV nodes. We have to do this as an early rewrite, before anything comes along and breaks it. In normal situations we also do it before any rewrites, in the construct_ir_fgraph code, but for join/stack this already counts as some sort of inference.

Now why Promised and not just vanilla Valued? Just for convenience: we still want a function from xs_vv to stack([logp(x1), logp(x2)]) in the end, and if I split the values I wouldn't know how to stack them later in the loop in conditional_logp (or if this was a single logp call), or even know that I needed to. So we just "promise" there will be values, to prevent rewrites from breaking the dependency that will be required in the logp function.

Another way of thinking about it is that we are basically trying to truncate the graphs at the valued nodes and do manipulations only within these subgraphs, not across them. We could have literally split the graph, done IR rewrites in each piece, and collected the logp expressions; what we are doing is identical. Either way, we would still need a further split for the RVs within Join/MakeVector. A sketch of the resulting IR for the stack example is shown below.
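
To make this concrete, a hedged sketch of the IR for the stack example above, using the same valued_rv / promised_valued_rv notation that appears later in this thread (the exact helper names are an assumption):

 x1 = promised_valued_rv(pt.random.normal() * 5)
 x2 = promised_valued_rv(pt.random.normal(x1 * 8))
 xs = valued_rv(pt.stack([x1, x2]), xs_vv)

The promised markers keep x2's dependence on x1 from being rewritten away before conditional_logp splits xs_vv and routes one piece to each component's logp.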

b_value = b.type()
logp_b = conditional_logp({a: a_value, b: b_value})[b_value]

assert_no_rvs(logp_b)
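
For context, a self-contained version of this test could look as follows (a minimal sketch; the definitions of a_base, a, b, and a_value are assumed from the surrounding test file):

import pytensor.tensor as pt
from pymc.logprob.basic import conditional_logp
from pymc.testing import assert_no_rvs

a_base = pt.random.normal()
a = a_base * 5               # a deterministic transform of a_base is valued
b = pt.random.normal(a * 8)  # b depends on the valued variable a

a_value = a.type()
b_value = b.type()
logp_b = conditional_logp({a: a_value, b: b_value})[b_value]

assert_no_rvs(logp_b)        # no RandomVariables may remain in the logp graph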
Member
Brainstorming to see if I understand here. From #6917, IIUC assert_no_rvs(logp_b) would throw a warning (or an error?).

When would we want b to be rewritten as pm.Normal.dist(a_base * 40)? When, say, a_base_value = a_base.type() is provided in conditional_logp? My understanding is that, while both situations are mathematically equivalent, PyMC's logprob inference does not work well when a value variable (a_value) is provided for a deterministic transformation (a) of a RandomVariable instead of for the RandomVariable itself (a_base.type(); not provided here). Is this correct?

@ricardoV94 (Member, Author) commented Sep 10, 2024

Yes, the general problem is that a second variable cannot depend on a transformation of the value variable. We don't have the machinery to do that sort of inversion (other than the ad hoc transform_value rewrite). For example, the following is not supported:

import pytensor.tensor as pt
from pymc.logprob.basic import conditional_logp

x = pt.random.normal()
x_exp = pt.exp(x)
y = pt.random.normal(loc=x)

x_exp_vv = pt.scalar("x_exp_vv")
y_vv = pt.scalar("y_vv")
conditional_logp({x_exp: x_exp_vv, y: y_vv})

There is nothing wrong in principle, but we don't have the machinery to find out that the density of y should depend on a (log) transform of x_exp_vv. It would be nice to have, but I haven't come across an elegant solution. The changes in this PR prevent our rewrites from accidentally introducing such indirections, by changing the IR of this test example to something like:

 a_base = pm.Normal.dist()
 a = valued_rv(a_base * 5, a_value)
 b = valued_rv(pm.Normal.dist(a * 8), b_value)

Since there are no default PyTensor rewrites that know what to do with a valued_rv, there is no risk of "mixing" information before and after the conditioning points (in this case, constant-folding 5 * 8 = 40 in the graph of b).
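
For contrast, without the valued_rv barrier a generic rewrite was free to fold constants across the conditioning point, producing something like this (a sketch of the old failure mode, in the same notation):

 a_base = pm.Normal.dist()
 b = pm.Normal.dist(a_base * 40)  # 5 * 8 folded together: b no longer depends on a,
                                  # so a_value can never enter logp(b)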

This avoids rewrites across conditioning points that could break dependencies.

Also extend logprob derivation of Scans with multiple valued output types.
@ricardoV94 (Member, Author) commented Sep 10, 2024

@larryshamalama thanks for the thoughtful questions! I cleaned up the answers I gave above and added them to the docstrings of ValuedRV and PromisedValuedRV.

@ricardoV94 merged commit 97df9c3 into pymc-devs:main on Sep 11, 2024 (22 checks passed).
Successfully merging this pull request may close these issues.

Relationship between valued RVs lost during logp inference