Apply jnp.exp to output of a model #14440
averageFlaxUser asked this question in Q&A. Answered by PhilipVinc.
I have a NN that outputs
But I get a warning from JAX at
How should I deal with this? My model is defined using
PhilipVinc replied on Feb 13, 2023:
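Can't you simply redefine your model to be:

```python
import jax.numpy as jnp
import flax.linen as nn


class Model(nn.Module):
    training: bool

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=10)(x)
        ...  # further hidden layers
        x = nn.Dense(features=2)(x)
        # Split the two output units into the mean and the unconstrained scale
        mu, sigma = x[..., 0], x[..., 1]
        # jnp.exp maps the unconstrained value to a strictly positive scale
        return mu, jnp.exp(sigma)
```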
Answer selected by averageFlaxUser.