forked from google-deepmind/deepmind-research
-
Notifications
You must be signed in to change notification settings - Fork 0
/
vae.py
133 lines (111 loc) · 4.2 KB
/
vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Standard VAE class."""
from typing import Optional
import jax
import jax.numpy as jnp
from avae import decoders
from avae import encoders
from avae import kl
from avae import types
class VAE:
  """Variational autoencoder exposing VAE and AVAE training objectives.

  Wraps an encoder/decoder pair and provides the ELBO used to train a standard
  VAE (`vae_elbo`), the ELBO used to train an autoencoding VAE (`avae_elbo`),
  and a stochastic reconstruction forward pass (`__call__`).
  """

  def __init__(self, encoder: encoders.EncoderBase,
               decoder: decoders.DecoderBase, rho: Optional[float] = None):
    """Class initializer.

    Args:
      encoder: Encoder network architecture.
      decoder: Decoder network architecture.
      rho: Rho parameter used in AVAE training: the coupling coefficient of
        the conditional p(z_aux | z) = N(rho * z, (1 - rho^2) I). Only
        `avae_elbo` reads it; it is then required and expected to satisfy
        |rho| < 1.
    """
    self._encoder = encoder
    self._decoder = decoder
    self._rho = rho

  @staticmethod
  def _batch_mean_of_latent_sum(x: jnp.ndarray) -> jnp.ndarray:
    """Sums `x` over axis 1 (the latent axis) and averages the rest to a scalar."""
    return jnp.mean(jnp.sum(x, axis=1))

  def _vae_terms(self, input_data: jnp.ndarray, key: jnp.ndarray):
    """Computes the standard VAE terms shared by `vae_elbo` and `avae_elbo`.

    Args:
      input_data: Input batch of shape (batch_size, ...).
      key: Key for random number generator.

    Returns:
      Tuple `(posterior, elbo, data_fidelity, kls)` where
      `elbo = data_fidelity - kls`.
    """
    posterior = self._encoder(input_data)
    samples = self._encoder.sample(posterior, key)
    # Map the KL over the leading (batch) axis of BOTH positional arguments.
    # `in_axes` needs one entry per argument: a length-1 spec such as `[0]` is
    # not a tree prefix of the (mean, variance) argument tuple and vmap raises.
    kls = jax.vmap(kl.kl_p_with_uniform_normal, in_axes=(0, 0))(
        posterior.mean, posterior.variance)
    recons = self._decoder(samples)
    data_fidelity = self._decoder.data_fidelity(input_data, recons)
    return posterior, data_fidelity - kls, data_fidelity, kls

  def vae_elbo(
      self, input_data: jnp.ndarray,
      key: jnp.ndarray) -> types.ELBOOutputs:
    """ELBO for training VAE.

    Args:
      input_data: Input batch of shape (batch_size, ...).
      key: Key for random number generator.

    Returns:
      Computed VAE ELBO as `types.ELBOOutputs(elbo, data_fidelity, kls)` with
      `elbo = data_fidelity - kls`; `kls` is computed per example (mapped over
      the batch axis).
    """
    _, elbo, data_fidelity, kls = self._vae_terms(input_data, key)
    return types.ELBOOutputs(elbo, data_fidelity, kls)

  def avae_elbo(
      self, input_data: jnp.ndarray,
      key: jnp.ndarray) -> types.ELBOOutputs:
    """ELBO for training AVAE model.

    Adds to the VAE ELBO (i) the expected log-density of the auxiliary latent
    z_aux under p(z_aux | z) = N(rho * z, (1 - rho^2) I), and (ii) the entropy
    of the auxiliary posterior q(z_aux | x_aux). Here x_aux is a reconstruction
    of `input_data` that is treated as a constant (gradients are stopped).

    Args:
      input_data: Input batch of shape (batch_size, ...).
      key: Key for random number generator.

    Returns:
      Computed AVAE ELBO as `types.ELBOOutputs(elbo, data_fidelity, kls)`.
      `data_fidelity` and `kls` are the standard VAE terms; `elbo` keeps the
      batch dimension, with the batch-averaged auxiliary terms added to every
      example.

    Raises:
      ValueError: If the model was constructed without `rho`.
    """
    if self._rho is None:
      raise ValueError('avae_elbo requires `rho`; construct VAE with a float '
                       'rho satisfying |rho| < 1.')
    rho = self._rho

    # Auxiliary data: a reconstruction of the input, held fixed. It reuses `key`,
    # so (assuming encoder/sampler are deterministic given the key) it is decoded
    # from the same latent sample as the reconstruction term below.
    aux_images = jax.lax.stop_gradient(self(input_data, key))
    posterior, elbo, data_fidelity, kls = self._vae_terms(input_data, key)
    aux_posterior = self._encoder(aux_images)

    latent_mean = posterior.mean
    latent_var = posterior.variance
    aux_latent_mean = aux_posterior.mean
    aux_latent_var = aux_posterior.variance
    latent_dim = latent_mean.shape[1]
    cond_var = 1.0 - jnp.square(rho)  # Variance of p(z_aux | z).

    # <log p(z_aux | z)> under q(z | x) q(z_aux | x_aux):
    #   -D/2 * log(2 pi (1 - rho^2)) - E||z_aux - rho z||^2 / (2 (1 - rho^2)),
    # where E||z_aux - rho z||^2 = sum(aux_var + rho^2 var
    #                                  + (aux_mean - rho mean)^2).
    # The normalizer is constant in the network parameters, so it only shifts
    # the reported ELBO value, not its gradients.
    expected_sq_dist = self._batch_mean_of_latent_sum(
        aux_latent_var + jnp.square(rho) * latent_var +
        jnp.square(aux_latent_mean - rho * latent_mean))
    expected_log_conditional = (
        -0.5 * latent_dim * jnp.log(2.0 * jnp.pi * cond_var) -
        expected_sq_dist / (2.0 * cond_var))
    elbo += expected_log_conditional

    # Entropy of q(z_aux | x_aux) as a diagonal Gaussian:
    # sum_d 0.5 * log(2 pi e var_d), averaged over the batch.
    elbo += self._batch_mean_of_latent_sum(
        0.5 * jnp.log(2 * jnp.pi * jnp.e * aux_latent_var))

    return types.ELBOOutputs(elbo, data_fidelity, kls)

  def __call__(
      self, input_data: jnp.ndarray,
      key: jnp.ndarray) -> jnp.ndarray:
    """Reconstruction of the input data.

    Encodes `input_data`, samples latents from the posterior using `key`, and
    decodes the samples.

    Args:
      input_data: Input batch of shape (batch_size, ...).
      key: Key for random number generator.

    Returns:
      Reconstruction of the input data as jnp.ndarray of shape
      [batch_dim, observation_dims].
    """
    posterior = self._encoder(input_data)
    samples = self._encoder.sample(posterior, key)
    return self._decoder(samples)