
Apply jnp.exp to output of a model #14440

Answered by PhilipVinc
averageFlaxUser asked this question in Q&A

Can't you simply redefine your model so that it applies jnp.exp to its second output inside __call__?

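import jax.numpy as jnp
import flax.linen as nn

class Model(nn.Module):
    training: bool  # not used by the layers shown here

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=10)(x)
        # ... (remaining hidden layers elided in the original answer)
        x = nn.Dense(features=2)(x)
        # Split the two outputs: a mean and an unconstrained log-scale value
        mu, log_sigma = x[..., 0], x[..., 1]
        # exp maps the unconstrained value to a strictly positive sigma
        return mu, jnp.exp(log_sigma)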

Answer selected by averageFlaxUser