
Consistent API for user friendly distribution customization #4530

Closed
ricardoV94 opened this issue Mar 11, 2021 · 10 comments · Fixed by #6361

Comments

@ricardoV94
Member

ricardoV94 commented Mar 11, 2021

The following common "customized distributions" are missing, faulty, or implemented only for unobserved variables. It would be nice to have a standard API in place before all of these get implemented for both unobserved and observed distributions.

This issue is intended to discuss a possible API going forward when implementing these custom distributions. I will illustrate with the example of a user who wants to create an observable shifted exponential distribution as in #4507. I will call this helper method pm.Shift, but probably something like pm.Affine, which allows both shifting and scaling, would be better.

# We want to shift by a random normal variable
a = pm.Normal('a', 0, 1)

# Adapt Deterministic syntax. Requires naming the raw variable and is probably
# incompatible with the current RV registration logic
x = pm.Shift('x', pm.Exponential('x_raw', 1), shift=a, observed=data)

# Adapt syntax used in pm.Mixture. Requires the non-intuitive .dist() call
x = pm.Shift('x', pm.Exponential.dist(1), shift=a, observed=data)

# Adapt syntax from pm.Bound. Not very intuitive. Also, distribution and transformation
# parameters are separated from the class names. Not obvious whether data is expected
# to be on the original or shifted scale (it should be on the shifted scale)
x = pm.Shift(pm.Exponential, shift=a)('x', 1, observed=data)

# Add a generic "modify" argument
x = pm.Exponential('x', 1, modify=pm.Shift(shift=a), observed=data)

# Add all necessary arguments to all distributions
# sort = bool
# lower / upper = float (for truncation / censoring)
# shift / scale = float 
x = pm.Exponential('x', 1, shift=a, observed=data)

# Add new methods to RVs and (optionally) separate initialization from conditioning
# to remove ambiguity as to whether the data should be on the shifted scale (should be)
x = pm.Exponential('x', 1).shift(a)
x.observe(data)

# Same logic as above but with operator overloading
# Conflicts with normal Aesara operators
x = pm.Exponential('x', 1) + a
x.observe(data)

# Other examples of (impractical) operator overloading.
x = pm.Exponential('x', 1, size=2) 
new_x = x + a  # Shifting
new_x = x < a  # Truncating 
new_x = x[x > a] = a  # Censoring
new_x = x[(None, a)]  # Truncating alternative
new_x = x[[None, a]]  # Censoring alternative 
new_x = x[0] < x[1]  # Sorting, more generally x[:-1] < x[1:]
new_x.observe(data)

# Something else?
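Whatever the API ends up looking like, the logp side of a shifted distribution reduces to a change of variables with unit Jacobian. A minimal scipy sketch (the helper name and signature are hypothetical, not PyMC API):

```python
import numpy as np
from scipy import stats

def shifted_exponential_logp(value, lam, shift):
    # Hypothetical helper: logp of Y = X + shift with X ~ Exponential(lam).
    # The shift has unit Jacobian, so we evaluate the base density at
    # value - shift; note that `value` is expected on the *shifted* scale.
    return stats.expon(scale=1.0 / lam).logpdf(value - shift)

print(shifted_exponential_logp(3.0, 1.0, 2.0))  # same as Exponential(1) logp at 1.0
```

Values below the shift correctly get -inf, since the base density is zero there.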
@brandonwillard
Contributor

brandonwillard commented Mar 11, 2021

We need an interface that maintains the separation between random variables and their corresponding (transformed) log-likelihoods. This is important, because most users won't need/want to interact with the transformed log-likelihoods.

We can always—and currently—define a model in the transformed space, which is what your use of pm.Shift and statements like pm.Exponential('x', 1) + a seem to imply, but the goal is to preserve the original, often more understandable untransformed "sample-space" graph while still allowing for easy, user-defined reparameterizations of the "measure-space" (i.e. log-likelihood) graph.

Also, we should keep in mind that these transforms are almost exclusively a utility for (some) samplers. Because of that fact, we shouldn't accidentally expand this transform interface and functionality into areas outside of its applicability.

Otherwise, my main concern right now revolves around the storage, access, organization, and flexibility surrounding transforms.

For instance, in #4521, the transform interface is essentially the same as v3—i.e. default transforms are applied automatically and customization occurs through a transform keyword—with the exception that the Model object contains no corresponding transformed random variables. (We could add transformed random variables, but I don't see the point, because it's not like we need to sample them, right?)

Instead, the Model.vars list, which consists of the symbolic arguments to the model's total log-likelihood (e.g. the a and b parameters in P(Y=y | A=a, B=b) for unobserved RVs A and B), is always in the transformed space. This works, but it's rather inflexible, because it puts all the log-likelihood results in the transformed space by default, and we should really provide results in the same form that the user specified.

Also, one can't change the transforms in this—or the current/v3—context, at least not without recreating the entire model. This is where we run into the concerns surrounding #4529, because we have everything we need to create newly transformed log-likelihoods in a Model object, but we can't because we've decided to weave a very specific log-likelihood into the fabric of Model's identity.

In early discussions about v4, I mentioned that we should leave it to the functions that require transforms to transform "naive" (partially) user-specified log-likelihood graphs themselves. I still believe this is the best approach. In this case, the Model object might only carry information about the default/user-specified transforms for each random variable's log-likelihood input variable.

CC @rlouf

@ricardoV94
Member Author

ricardoV94 commented Mar 12, 2021

I probably shouldn't have called these transforms. They are more like random-variable factories that create customized versions of default distributions, although they sometimes look similar and may require different automatic transformations.

I changed the original post.

A discussion for automatic and user defined transformations is still needed, but maybe it can be done separately?

@ricardoV94 ricardoV94 changed the title Consistent API for user friendly variable transformations Consistent API for user friendly distribution customization Mar 12, 2021
@rlouf
Contributor

rlouf commented Mar 12, 2021

A discussion for automatic and user defined transformations is still needed, but maybe it can be done separately?

Happy to participate in this discussion; we talked about it with Brandon this week and I have strong opinions. I'm working on an alternative approach that generalizes to any transport map (like "normalizing flows"). It will most certainly be implemented in BlackJAX.

The automatic part will still be the PPL's responsibility however.

@ricardoV94
Member Author

In early discussions about v4, I mentioned that we should leave it to the functions that require transforms to transform "naive" (partially) user-specified log-likelihood graphs themselves. I still believe this is the best approach. In this case, the Model object might only carry information about the default/user-specified transforms for each random variable's log-likelihood input variable.

That sounds cleaner indeed.

Do you want to open a discussion issue for this (and for @rlouf's strong opinions), now that I realize I was talking about something different (and that my examples don't look like a useful introduction to that discussion)?

@brandonwillard
Contributor

Do you want to open a discussion issue for this (and for @rlouf's strong opinions), now that I realize I was talking about something different (and that my examples don't look like a useful introduction to that discussion)?

Even so, I would like to make sure that I understand what you're interested in here.

Is this issue about explicitly transforming random variables at the user/model definition level: e.g. X = Normal(0, 1); Z = b * X + a? In other words, are you talking about something that's closer to v3's Deterministics?

@ricardoV94
Member Author

ricardoV94 commented Mar 12, 2021

It's about how to offer the ability to create non standard distributions from standard distributions in a coherent way.

  • Exponential -> Mixture of exponentials
  • Exponential -> Sorted exponentials
  • Exponential -> Truncated exponential
  • Exponential -> Censored exponential
  • Exponential -> Shifted / scaled exponential

All of which should allow random sampling as well as logp evaluation.
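For reference, the truncated and censored variants differ only in how the tail mass is handled: truncation renormalizes the density, while censoring piles the tail mass onto the bound. A rough scipy sketch of the two logps (helper names are illustrative, not PyMC API):

```python
import numpy as np
from scipy import stats

def truncated_exponential_logp(value, lam, upper):
    # X ~ Exponential(lam) truncated to x < upper:
    # renormalize the density by the retained mass P(X < upper).
    base = stats.expon(scale=1.0 / lam)
    return np.where(value < upper,
                    base.logpdf(value) - base.logcdf(upper),
                    -np.inf)

def censored_exponential_logp(value, lam, upper):
    # X ~ Exponential(lam) right-censored at upper:
    # density below the bound, a point mass P(X >= upper) at the bound.
    base = stats.expon(scale=1.0 / lam)
    return np.where(value < upper, base.logpdf(value), base.logsf(upper))
```

Both also have closed-form random sampling (rejection/renormalized inverse-CDF for truncation, plain clipping for censoring), which is what makes them natural candidates for a factory API.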

@brandonwillard
Contributor

brandonwillard commented Mar 12, 2021

The sampling can already be accomplished entirely through Aesara—as long as it can be expressed using Aesara Ops, that is.

For example, if you want a mixture:

import aesara
import aesara.tensor as at
import aesara.tensor.random.basic as ar


S_rv = ar.bernoulli(0.5, size=10)
Z_rv = at.stack([ar.normal(0, 1, size=10), ar.normal(100, 1, size=10)])
Y_rv = at.choose(S_rv, Z_rv)

Y_sampler = aesara.function([], Y_rv)
>>> Y_sampler()
array([  0.1251257 , 100.11507077,  -0.6616374 ,  98.30614853,
        98.92522735, 100.47781691, 101.20954759,  -0.79534251,
        -0.51445939,  -0.40730754])

The log-likelihood capabilities would be provided by pymc3.distributions.logpt in v4 (e.g. logpt(Y_rv, y_vals) would return a graph of the log-likelihood for the above mixture).

The current implementation of logpt won't work for that example right now, but we can definitely make it work.
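For the mixture above, the graph logpt would have to produce corresponds to marginalizing out the Bernoulli indicator. A hand-written numpy/scipy sketch of that marginal logp, assuming the same components and weights as the Aesara snippet (the function name is illustrative):

```python
import numpy as np
from scipy import stats

def mixture_logp(value, p=0.5):
    # Marginal logp of choose(S, [N(0, 1), N(100, 1)]) with S ~ Bernoulli(p):
    # sum out the indicator in log space with logaddexp.
    comp0 = np.log1p(-p) + stats.norm(0, 1).logpdf(value)    # S = 0 branch
    comp1 = np.log(p) + stats.norm(100, 1).logpdf(value)     # S = 1 branch
    return np.logaddexp(comp0, comp1)
```

This is essentially what pm.Mixture does by hand today; the point of the proposal is to derive it automatically from the sampling graph.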

@ricardoV94
Member Author

ricardoV94 commented Mar 13, 2021

Yeah, I was wondering if something like that is feasible. It would be really cool to have an "if it can be sampled, it can be measured" guarantee.

Do you think this can be done for the following examples?

y = HalfNormal(1)
x = Exponential(y)
z = x * -1

And now we have a negative exponential, and the user could condition z on negative data; behind the scenes we would use the standard exponential logpt, evaluating it at data * -1?

Basically a change of variable?
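For the negation example the change of variables is trivial, since |d(-z)/dz| = 1. A scipy sketch of what the rewritten logpt would evaluate (illustrative only, not PyMC code):

```python
import numpy as np
from scipy import stats

def neg_exponential_logp(value, lam):
    # Change of variables for Z = -X with X ~ Exponential(lam):
    # the Jacobian of z -> -z has absolute value 1, so
    # logp_Z(z) = logp_X(-z), which is -inf for positive z.
    return stats.expon(scale=1.0 / lam).logpdf(-value)
```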

y = Normal(0, 1)
x = tt.exp(y)

x now follows a Lognormal distribution
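The Lognormal claim can be checked numerically: here the transform is nonlinear, so the change of variables picks up the log-Jacobian of exp, i.e. logp_X(x) = Normal.logpdf(log x) - log x. A quick scipy verification (illustrative, not PyMC code):

```python
import numpy as np
from scipy import stats

def exp_normal_logp(value):
    # Change of variables for X = exp(Y), Y ~ Normal(0, 1):
    # logp_X(x) = logp_Y(log x) + log|d(log x)/dx|
    #           = Normal(0, 1).logpdf(log x) - log x,
    # which is exactly the Lognormal(0, 1) density.
    return stats.norm(0, 1).logpdf(np.log(value)) - np.log(value)

x = 2.5
assert np.isclose(exp_normal_logp(x), stats.lognorm(s=1.0).logpdf(x))
```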

Or

y = Exponential(1)
x = y[y < 5]
z = tt.clip(y, y, 5)

And we would figure out that x should have the logpt of a right-truncated exponential and z that of a right-censored exponential?
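The clip/censoring correspondence can be checked by simulation: clipping Exponential(1) draws at 5 piles exactly the tail mass P(Y >= 5) onto the bound, which is the point mass a right-censored logpt would have to assign there. A numpy sketch (seed and sample size arbitrary):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(42)

# Draw from Exponential(1) and right-censor at 5, mimicking clip(y, ..., 5).
y = rng.exponential(scale=1.0, size=200_000)
z = np.minimum(y, 5.0)

# The empirical mass exactly at the bound should match P(Y >= 5) = exp(-5).
empirical_mass = np.mean(z == 5.0)
print(empirical_mass, stats.expon().sf(5.0))
```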

@brandonwillard
Contributor

... I was wondering if something like that is feasible.

It is. The question is "To what extent can it be done?", and the answer must be "Enough to be useful.", at the very least.

Do you think this can be done for the following examples?

If you can obtain a closed form for the log-likelihood (using operations supported by Aesara, of course), then we can do it.

@ricardoV94
Member Author

ricardoV94 commented Apr 23, 2021

One new suggestion from @twiecki related to RandomWalk factories in #4653 (comment) (and also in #4047):

y = pm.Flat.dist() + pm.Normal.dist(sigma=2).cumsum()
y = pm.RandomVariable('y',  y, observed=data)

Which could be generated by a helper method as well

y = pm.RandomWalk('y', init=pm.Flat.dist(), dist=pm.Normal.dist(sigma=2), observed=data)
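For a concrete picture of what such a helper has to provide: the walk is sampled via a cumulative sum of increments, and its logp factors into the logp of those increments. A numpy/scipy sketch with illustrative names (a Flat init is replaced here by a fixed value):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)

def random_walk_sample(steps, sigma=2.0, init=0.0):
    # Forward sampling: y = init + cumsum(Normal(0, sigma) increments),
    # mirroring pm.Flat.dist() + pm.Normal.dist(sigma=2).cumsum() above.
    return init + rng.normal(0.0, sigma, size=steps).cumsum()

def random_walk_logp(y, sigma=2.0, init=0.0):
    # The walk's logp is the sum of the logps of its first differences.
    increments = np.diff(np.concatenate([[init], y]))
    return stats.norm(0.0, sigma).logpdf(increments).sum()
```

The point of the factory would be to derive random_walk_logp automatically from the cumsum sampling graph instead of writing it by hand.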
