-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding VQVAE and refactor trainer class (#7)
* added vanilla.py, test_vanila.py * modified vanilla.py * modified vanilla.py * finished vanilla.py for now * added test_VanillaAE; modified VanillaAE * added basetrainer.py * trying to stop codecov warning on PR * trying to stop codecov warning on PR * added test_basetrainer.py * updated VanillaAETrainer; modified BaseTrainer; WIP: added test_VanillaAETrainer; * First fittable VanillaAETrainer * fixed a bug in BaseTrainer * WIP: added vq.py * WIP: modified vq.py * refactored models to trainer * added blank test_vq.py * updated vq.py * modified test_vq.py; fixed some bugs in vq.py * WIP: refactor base trainer * Major refactoring in trainer and model structure * added index_histogram in VectorQuantizer * fixed bugs in vq.py
- Loading branch information
1 parent
09426f3
commit d611466
Showing
12 changed files
with
544 additions
and
183 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import torch | ||
from ..vq import VectorQuantizer | ||
|
||
|
||
def test_VectorQuantizer():
    """Check the codebook shape and the shape/value contract of every forward output."""
    embedding_dim, num_embeddings = 10, 3
    vq = VectorQuantizer(embedding_dim, num_embeddings, 1.0)
    assert vq.codebook.weight.shape == (num_embeddings, embedding_dim)

    # Input laid out as (batch, channel == embedding_dim, 4, 4).
    data = torch.randn((5, embedding_dim, 4, 4), dtype=torch.float32)
    batch_shape, spatial_shape = data.shape[:1], data.shape[2:]
    loss, quantized, perplexity, onehot, indices, histogram = vq(data)

    # Scalar metrics are 0-dimensional tensors.
    assert len(loss.shape) == 0
    assert len(perplexity.shape) == 0

    # Quantized output mirrors the input layout.
    assert quantized.shape == data.shape

    # One-hot codes put the code axis right after the batch axis.
    assert onehot.max() == 1
    assert onehot.shape == (data.shape[0], num_embeddings) + spatial_shape
    assert indices.shape == batch_shape + spatial_shape

    # Soft histogram: one row per sample; each row sums to the 4 * 4 = 16 spatial positions.
    assert histogram.shape == (data.shape[0], num_embeddings)
    assert bool((torch.round(histogram.sum(dim=1), decimals=3) == 16).all())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
from typing import Optional | ||
|
||
import torch | ||
from torch import nn | ||
|
||
|
||
class VectorQuantizer(nn.Module):
    """
    Vector Quantization layer

    Replaces every channel vector of the input with its nearest codebook entry
    (squared Euclidean distance) and returns the VQ loss together with a few
    codebook-usage statistics.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_embeddings: int,
        commitment_cost: float = 0.25,
        padding_idx: Optional[int] = None,
        initializer: str = 'uniform',
        **kwargs,
    ):
        """
        Initializes Vector Quantization layer

        Parameters
        ----------
        embedding_dim : int
            Embedding dimension; must equal the size of dim 1 of the forward input
        num_embeddings : int
            Number of embeddings (codebook size)
        commitment_cost : float
            Weight of the commitment term in the returned loss
        padding_idx : int
            If specified, the entries at padding_idx do not contribute to the gradient;
            therefore, the embedding vector at padding_idx is not updated during training,
            i.e. it remains as a fixed "pad".
        initializer : str
            Initializing distribution. Only 'uniform' is handled here; any other value
            keeps the default initialization of nn.Embedding.
        **kwargs
            Accepted for call-site compatibility; not used by this layer.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        self.padding_idx = padding_idx

        self.codebook = nn.Embedding(self.num_embeddings, self.embedding_dim, self.padding_idx)
        if initializer == 'uniform':
            # NOTE: this fills every row of the codebook, including the padding_idx row.
            self.codebook.weight.data.uniform_(-1.0 / self.num_embeddings, 1.0 / self.num_embeddings)

    def _calc_dist(self, z):
        """
        Computes distance between inputs and codebook.

        Parameters
        ----------
        z : tensor
            Usually the output of encoder; dim 1 must have size embedding_dim

        Returns
        -------
        tensor
            Squared Euclidean distances with shape (number of vectors, num_embeddings),
            where the vectors are z's channel vectors flattened batch-first.
        """
        # move channels last and flatten to (N, embedding_dim)
        z = torch.movedim(z, 1, -1).contiguous()
        z_flattened = z.view(-1, self.embedding_dim)

        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        distances = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.codebook.weight**2, dim=1)
            - 2 * torch.matmul(z_flattened, self.codebook.weight.t())
        )
        return distances

    def _calc_metrics(self, z, z_quantized, encoding_onehot):
        """
        Computes the VQ loss and the codebook perplexity.

        Parameters
        ----------
        z : tensor
            Input of the layer
        z_quantized : tensor
            Quantized vectors, same shape as z, still attached to the codebook weights
        encoding_onehot : tensor
            Flat one-hot assignments of shape (number of vectors, num_embeddings)

        Returns
        -------
        loss : 0-dim tensor
            quantization loss + commitment_cost * commitment loss
        perplexity : 0-dim tensor
            exp of the entropy of the average code usage
        """
        # commitment term moves z towards the (frozen) codes;
        # quantization term moves the codes towards the (frozen) z
        commitment_loss = torch.mean((z_quantized.detach() - z) ** 2)
        quantization_loss = torch.mean((z_quantized - z.detach()) ** 2)
        loss = quantization_loss + self.commitment_cost * commitment_loss

        # perplexity; the epsilon keeps log finite for unused codes
        avg_probs = torch.mean(encoding_onehot.float(), dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        return loss, perplexity

    def forward(self, z):
        """
        Parameters
        ----------
        z : tensor
            input tensor; usually the output of an encoder, laid out as
            (Batch, embedding_dim, *spatial)

        Returns
        -------
        loss : length of 0
        quantized embeddings : same shape as z
        perplexity : length of 0
        encoding_onehot : shape of (Batch, Code, *spatial)
        encoding_indices : shape of (Batch, *spatial)
        index_histogram : shape of (Batch, Code index)
        """
        batch_size = z.shape[0]
        spatial_shape = tuple(z.shape[2:])

        distances = self._calc_dist(z)
        # Use softmax as a soft argmin and sum over positions to get a per-sample histogram
        index_histogram = torch.softmax(-distances.view(batch_size, -1, self.num_embeddings), dim=-1).sum(dim=1)
        # find the closest encodings
        encoding_indices = torch.argmin(distances, dim=1)
        # Create one-hot vectors (used for perplexity and returned to the caller)
        encoding_onehot = nn.functional.one_hot(encoding_indices, self.num_embeddings)

        # get quantized latent vectors through the embedding lookup rather than a dense
        # one-hot matmul: the values are the same codebook rows, but the lookup honours
        # padding_idx in the backward pass and works for any codebook dtype.
        z_quantized = self.codebook(encoding_indices)
        # reshape back to match original input shape
        z_quantized = torch.movedim(
            z_quantized.view((batch_size,) + spatial_shape + (self.embedding_dim,)), -1, 1
        ).contiguous()

        # compute metrics (needs the flat one-hot, so do this before reshaping it)
        loss, perplexity = self._calc_metrics(z, z_quantized, encoding_onehot)
        # reshape back to match original input shape
        encoding_onehot = torch.movedim(
            encoding_onehot.view((batch_size,) + spatial_shape + (self.num_embeddings,)), -1, 1
        )

        # copy the gradient from inputs to quantized z (straight-through estimator).
        z_quantized = z + (z_quantized - z).detach()

        return (
            loss,
            z_quantized,
            perplexity,
            encoding_onehot,
            encoding_indices.view((-1,) + spatial_shape),
            index_histogram,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import torch | ||
|
||
from cytoself.trainer.autoencoder.vqvae import VQVAE | ||
|
||
|
||
def test_VQVAE():
    """Run one forward pass through VQVAE and check the output and the cached VQ statistics."""
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    input_shape = (2, 100, 100)
    emb_shape = (64, 4, 4)
    vq_args = {'num_embeddings': 7}

    # The model reconstructs its input, so the output shape equals the input shape.
    model = VQVAE(input_shape, emb_shape, input_shape, vq_args)
    model.to(device)

    batch = torch.randn((1,) + input_shape).to(device)
    reconstruction = model(batch)
    assert reconstruction.shape == batch.shape

    # The forward pass stores the VQ layer outputs on the model.
    assert len(model.vq_loss.shape) == 0
    assert len(model.perplexity.shape) == 0
    assert model.encoding_onehot.max() == 1
    assert model.encoding_indices.shape == batch.shape[:1] + emb_shape[1:]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from typing import Optional | ||
from torch import nn, Tensor | ||
|
||
from cytoself.trainer.autoencoder.encoders.efficientenc2d import efficientenc_b0 | ||
from cytoself.trainer.autoencoder.decoders.resnet2d import DecoderResnet | ||
from cytoself.components.layers.vq import VectorQuantizer | ||
|
||
|
||
class VQVAE(nn.Module):
    """
    Vector Quantized Variational Autoencoder model

    Chains an encoder, a VectorQuantizer layer and a decoder. Each forward pass
    also stores the quantizer's auxiliary outputs on the instance (vq_loss,
    perplexity, encoding_onehot, encoding_indices, index_histogram).
    """

    def __init__(
        self,
        input_shape: tuple,
        emb_shape: tuple,
        output_shape: tuple,
        vq_args: dict,
        encoder_args: Optional[dict] = None,
        decoder_args: Optional[dict] = None,
        encoder: Optional = None,
        decoder: Optional = None,
    ):
        """
        Parameters
        ----------
        input_shape : tuple
            Input shape; element 0 is used as the encoder's in_channels by default
        emb_shape : tuple
            Embedding shape; element 0 is the embedding dimension of the VQ layer
        output_shape : tuple
            Output shape handed to the default decoder arguments
        vq_args : dict
            Keyword arguments for VectorQuantizer (embedding_dim is supplied here)
        encoder_args : dict
            Keyword arguments for the encoder; built from the shapes when None
        decoder_args : dict
            Keyword arguments for the decoder; built from the shapes when None
        encoder : callable
            Encoder constructor; efficientenc_b0 when None
        decoder : callable
            Decoder constructor; DecoderResnet when None
        """
        super().__init__()
        encoder_factory = efficientenc_b0 if encoder is None else encoder
        decoder_factory = DecoderResnet if decoder is None else decoder
        encoder_kwargs = (
            {'in_channels': input_shape[0], 'out_channels': emb_shape[0]} if encoder_args is None else encoder_args
        )
        decoder_kwargs = (
            {'input_shape': emb_shape, 'output_shape': output_shape} if decoder_args is None else decoder_args
        )

        # Submodules are registered in the order encoder -> decoder -> vq_layer.
        self.encoder = encoder_factory(**encoder_kwargs)
        self.decoder = decoder_factory(**decoder_kwargs)
        self.vq_layer = VectorQuantizer(embedding_dim=emb_shape[0], **vq_args)

        # Filled in by forward(); None until the first forward pass.
        self.vq_loss = None
        self.perplexity = None
        self.encoding_onehot = None
        self.encoding_indices = None
        self.index_histogram = None

    def forward(self, x: Tensor) -> Tensor:
        """
        Encode x, quantize the encoding, cache the VQ outputs on self and return the decoded tensor.
        """
        encoded = self.encoder(x)
        vq_outputs = self.vq_layer(encoded)
        (
            self.vq_loss,
            quantized,
            self.perplexity,
            self.encoding_onehot,
            self.encoding_indices,
            self.index_histogram,
        ) = vq_outputs
        return self.decoder(quantized)
Oops, something went wrong.