Neural network layers #16

Closed
wants to merge 8 commits into from

Conversation

rlouf (Owner) commented Apr 4, 2020

I am opening this PR to start thinking about the design of Bayesian neural network layers. The idea is to subclass trax's constructs and allow the use of distributions for the weights and for transformations of the weights.

The goal is to be able to take any model expressed with `trax` and make it Bayesian by adding prior distributions on the weights.

Of course, we should be able to construct hierarchical models by adding hyperpriors on the priors’ parameters.

Layers are distributions over functions; let us see what it could look like on a naive MNIST example:

@mcx.model
def mnist(image):
    nn <~ ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(0, 1)),
        dense(10, Normal(0, 1)),
        softmax(),
    )
    p = nn(image)
    cat <~ Categorical(p)
    return cat

The above snippet is naive in the sense that the way Normal(0, 1) relates to each weight in the layer is not very clear. We need to specify broadcasting rules for the Bayesian layers.
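As a rough sketch of what broadcasting a scalar prior to a layer's shape could mean for a single dense layer, in plain JAX (hypothetical helper, not the API proposed here; assumes flattened 28×28 MNIST inputs):

import jax
import jax.numpy as jnp

def sample_dense_weights(rng_key, in_dim, out_dim, loc=0.0, scale=1.0):
    # The scalar Normal(loc, scale) prior is broadcast to the layer's weight
    # shape: one independent draw per entry of the weight matrix.
    loc = jnp.broadcast_to(loc, (in_dim, out_dim))
    scale = jnp.broadcast_to(scale, (in_dim, out_dim))
    return loc + scale * jax.random.normal(rng_key, (in_dim, out_dim))

w = sample_dense_weights(jax.random.PRNGKey(0), 784, 400)  # weights of the first dense(400) layer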


We should be able to easily define hierarchical models:

@mcx.model
def mnist(image):
    sigma <~ Exponential(2)
    nn <~ ml.Serial(
        dense(400, Normal(0, sigma)),
        dense(400, Normal(0, sigma)),
        dense(10, Normal(0, sigma)),
        softmax(),
    )
    p = nn(image)
    cat <~ Categorical(p)
    return cat

Forward sampling

Let us now look at the design of the forward sampler. We need to return forward samples of the layers' weights as well as of the other random variables.

We could define a sample method that draws a realization of each layer and performs a forward pass with the drawn weights.

def mnist_sampler(rng_key, image):
    nn = ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(0, 1)),
        dense(10, Normal(0, 1)),
        softmax(),
    )
    p, weights = nn.sample(rng_key, image)
    cat = Categorical(p).sample(rng_key)
    return weights, cat

where weights is a tuple that contains all the weights' realized values. This would keep an API similar to the distributions', with the added return value reflecting the fact that we are sampling a function.

Another option is

def mnist_sampler(rng_key, image):
    nn = ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(0, 1)),
        dense(10, Normal(0, 1)),
        softmax(),
    )
    weights = nn.sample(rng_key)
    p = nn(image, weights)
    cat = Categorical(p).sample(rng_key)
    return weights, cat

which feels less magical.
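As a rough sketch of what this second option could look like for a single dense layer (hypothetical class, not the actual trax subclass):

import jax
import jax.numpy as jnp

class DenseWithPrior:
    """Toy stand-in for a dense layer with a Normal prior on its weights."""

    def __init__(self, in_dim, out_dim, loc=0.0, scale=1.0):
        self.shape = (in_dim, out_dim)
        self.loc, self.scale = loc, scale

    def sample(self, rng_key):
        # Draw one realization of the weights from the prior.
        return self.loc + self.scale * jax.random.normal(rng_key, self.shape)

    def __call__(self, x, weights):
        # Deterministic forward pass with explicitly passed weights.
        return x @ weights

layer = DenseWithPrior(784, 400)
rng_key = jax.random.PRNGKey(0)
weights = layer.sample(rng_key)            # explicit sampling step
out = layer(jnp.ones((1, 784)), weights)   # forward pass is a pure function of the weights

Sampling the weights and applying the layer are two separate, pure steps, which is what makes this option feel less magical.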

Log-probability density function

def mnist_logpdf(weights, image, cat):
    logpdf = 0
    nn = ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(0, 1)),
        dense(10, Normal(0, 1)),
        softmax(),
    )
    logpdf += nn.logpdf(image, weights)
    p = nn(image, weights)
    logpdf += Categorical(p).logpdf(cat)
    return logpdf

Note: the __call__ method of the layers calls the pure_fn method, which is jit-able. I am not sure it is necessary to call it directly here.
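For reference, a minimal sketch of the two terms of this log-density in plain JAX (assumed helper names, not mcx code):

import jax.numpy as jnp
from jax.scipy.stats import norm

def dense_weights_logpdf(weights, loc=0.0, scale=1.0):
    # Independent Normal(loc, scale) prior on every entry of the weight matrix.
    return jnp.sum(norm.logpdf(weights, loc, scale))

def categorical_logpdf(p, cat):
    # Log-probability of the observed category under the predicted probabilities.
    return jnp.log(p[..., cat])

# logpdf = sum of the layers' weight log-priors + the Categorical log-likelihood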

rlouf and others added 6 commits March 26, 2020 10:07
I believe `path_length` is still in use for historical reasons; however, it makes more sense to reason in terms of the number of integration steps, and it also simplifies the cost (no dynamic computation of the number of integration steps and no casting to int). I thus replaced every mention of `path_length` in the HMC proposal and program with `num_integration_steps`.
The boundary between programs (sampling algorithms) and runtimes (executors) was not very clear. I removed any dependence on the model from the program, so responsibilities are now clear. Most of the initialization has been transferred to the runtime. I also improved the performance of the creation of initial states.
rlouf added the enhancement-api and priority-3 labels on Apr 6, 2020
ericmjl (Contributor) commented Apr 15, 2020

@rlouf I'm not sure if this might help, but would a blog post I wrote on shapes be useful to you? Just a thought, no pressure to read it.

rlouf (Owner, Author) commented Apr 15, 2020

@ericmjl Thank you for the link; I did read your post before implementing distributions. It was really helpful for diving into TFP's shape system!

Is there anything in particular you think I might have missed that could help me?

ericmjl (Contributor) commented Apr 15, 2020

@rlouf thank you for the kind words! I think (but I'm not 100% sure) maybe working backwards from the desired semantics might be helpful?

Personally, when I think of Gaussian priors on a neural network's weights, I tend to think of them as being the "same" prior (e.g. N(0, 1)) applied to every single weight matrix entry, as I haven't seen a strong reason to apply, for example, N(0, 1) to entry [0, 0] and then N(0, 3) to entry [0, 1] and so on.

I think I might still be unclear, so let me attempt an example that has some contrasts in it.

Given the following NN:

@mcx.model
def mnist(image):
    nn ~ ml.Serial(
        dense(400, Normal(0, 1)),
        dense(400, Normal(3, 4)),
        dense(10, Normal(-2, 7)),
        softmax(),
    )
    p = nn(image)
    cat ~ Categorical(p)
    return cat

I would read it as:

  • First dense layer has shape (:, 400), and so I give each weight in there N(0, 1) as the prior.
  • Second dense layer has shape (400, 400), and so I give each weight in there N(3, 4) as the prior. (No idea why I'd actually wanna do that though!)
  • Third dense layer has shape (400, 10), and so I give each weight in there N(-2, 7) as the prior. (Even more absurd prior! 😛)

I think the suggestion I have here matches your 2nd option exactly:

Every weight has the same prior distribution.

You don't have to accept the exact suggestion, but maybe implementing it one way first and then trying it out might illuminate whether it's good or not? In reimplementing an RNN, I did the layers in an opinionated, "my-way" fashion first, then realized it'd be easier and more compatible to just go stax/trax-style, and then worked with my intern to get it re-done in a stax-compatible fashion. Not much time was lost, even though in retrospect, I clearly got it wrong the first time.

rlouf (Owner, Author) commented Apr 16, 2020

Interesting feedback, thank you for taking the time to explain! The NN API is indeed a bit tricky to get right the first time.

I am currently leaning towards what you're proposing. Would you agree with simply broadcasting the parameters' shape with the layer's shape to obtain the batch_shape (see the drawing I made)? This can be done dynamically when forward sampling; since the initialization of the posterior sampler uses forward sampling to determine the layers' shapes, it would work.

This way it is also compatible with crazy specs, like a different variance for each layer weight.
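For concreteness, the default I have in mind follows NumPy's broadcasting rules (shape computation only, not mcx code):

import numpy as np

layer_shape = (400, 10)                                  # weights of the last dense layer
np.broadcast_shapes(np.shape(1.0), layer_shape)          # scalar scale              -> (400, 10)
np.broadcast_shapes(np.shape(np.ones(10)), layer_shape)  # one scale per output unit -> (400, 10)
np.broadcast_shapes(layer_shape, layer_shape)            # one scale per weight      -> (400, 10)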

ericmjl (Contributor) commented Apr 16, 2020

Would you agree with simply broadcasting the parameters' shape with the layer's shape to obtain the batch_shape (drawing I made)?

Yes, I would! It sounds like a sensible default to have.

rlouf (Owner, Author) commented Apr 17, 2020

Thank you for your insights! It feels good to have someone else's opinion.

Was your RNN project Bayesian? If so, is the code available somewhere?

ericmjl (Contributor) commented Apr 17, 2020

The RNN wasn't Bayesian, and it was mostly a re-implementation of the original, but done in JAX. Given that it's written stax-style, I'm sure it shouldn't be too hard to extend it to mcx 😄.

You can find the repo here, and we have a mini-writeup available too.

rlouf force-pushed the master branch 3 times, most recently from 48e21e4 to 025b5f1 on May 7, 2020
rlouf added 2 commits June 2, 2020 12:33
To keep a simple API when building Bayesian neural networks, we don't want to have to specify the batching shape of the prior distribution so that it matches the layer size. We therefore add a helper function that re-broadcasts a distribution to a destination shape (here the layer size) so that it can be used in the neural network internals.
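A minimal sketch of what such a helper could look like (assumed name and signature, not the actual code in this commit):

import numpy as np

def rebroadcast(loc, scale, layer_shape):
    # Broadcast the prior's parameters to the destination (layer) shape so
    # that every weight entry gets its own, possibly identical, parameters.
    return np.broadcast_to(loc, layer_shape), np.broadcast_to(scale, layer_shape)

loc, scale = rebroadcast(0.0, 1.0, (400, 10))  # scalar Normal(0, 1) prior -> one per weight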
rlouf force-pushed the master branch 10 times, most recently from ad89b00 to 4e2e2fa on September 29, 2020
rlouf (Owner, Author) commented Apr 12, 2021

Closing for now; the relevant info is in the discussions.

rlouf closed this on Apr 12, 2021