Skip to content

Commit

Permalink
Adding VQVAE and refactor trainer class (#7)
Browse files Browse the repository at this point in the history
* added vanilla.py, test_vanila.py

* modified vanilla.py

* modified vanilla.py

* finished vanilla.py for now

* added test_VanillaAE;
modified VanillaAE

* added basetrainer.py

* trying to stop codecov warning on PR

* trying to stop codecov warning on PR

* added test_basetrainer.py

* updated VanillaAETrainer;
modified BaseTrainer;
WIP: added test_VanillaAETrainer;

* First fittable VanillaAETrainer

* fixed a bug in BaseTrainer

* WIP: added vq.py

* WIP: modified vq.py

* refactored models to trainer

* added blank test_vq.py

* updated vq.py

* modified test_vq.py; fixed some bugs in vq.py

* WIP: refactor base trainer

* Major refactoring in trainer and model structure

* added index_histogram in VectorQuantizer

* fixed bugs in vq.py
  • Loading branch information
li-li-github authored May 23, 2022
1 parent 09426f3 commit d611466
Show file tree
Hide file tree
Showing 12 changed files with 544 additions and 183 deletions.
Empty file.
20 changes: 20 additions & 0 deletions cytoself/components/layers/test/test_vq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
from ..vq import VectorQuantizer


def test_VectorQuantizer():
embedding_dim = 10
num_embeddings = 3
vq = VectorQuantizer(embedding_dim, num_embeddings, 1.0)
assert vq.codebook.weight.shape == (num_embeddings, embedding_dim)

data = torch.randn((5, 10, 4, 4), dtype=torch.float32)
out = vq(data)
assert len(out[0].shape) == 0
assert out[1].shape == data.shape
assert len(out[2].shape) == 0
assert out[3].max() == 1
assert out[3].shape == (data.shape[0], num_embeddings) + data.shape[2:]
assert out[4].shape == data.shape[:1] + data.shape[2:]
assert out[5].shape == (data.shape[0], num_embeddings)
assert all(torch.round(out[5].sum(axis=1), decimals=3) == 16)
136 changes: 136 additions & 0 deletions cytoself/components/layers/vq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from typing import Optional

import torch
from torch import nn


class VectorQuantizer(nn.Module):
"""
Vector Quantization layer
"""

def __init__(
self,
embedding_dim: int,
num_embeddings: int,
commitment_cost: float = 0.25,
padding_idx: Optional[int] = None,
initializer: str = 'uniform',
**kwargs,
):
"""
Initializes Vector Quantization layer
Parameters
----------
embedding_dim : int
Embedding dimension
num_embeddings : int
Number of embeddings
commitment_cost : float
Commitment cost
padding_idx : int
If specified, the entries at padding_idx do not contribute to the gradient;
therefore, the embedding vector at padding_idx is not updated during training,
i.e. it remains as a fixed “pad”.
initializer : str
Initializing distribution
"""
super().__init__()
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.commitment_cost = commitment_cost
self.padding_idx = padding_idx

self.codebook = nn.Embedding(self.num_embeddings, self.embedding_dim, self.padding_idx)
if initializer == 'uniform':
self.codebook.weight.data.uniform_(-1.0 / self.num_embeddings, 1.0 / self.num_embeddings)

def _calc_dist(self, z):
"""
Computes distance between inputs and codebook.
Parameters
----------
z : tensor
Usually the output of encoder
Returns
-------
tensor
"""
# reshape z -> (batch, height, width, channel) and flatten
z = torch.movedim(z, 1, -1).contiguous()
z_flattened = z.view(-1, self.embedding_dim)

# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
distances = (
torch.sum(z_flattened**2, dim=1, keepdim=True)
+ torch.sum(self.codebook.weight**2, dim=1)
- 2 * torch.matmul(z_flattened, self.codebook.weight.t())
)
return distances

def _calc_metrics(self, z, z_quantized, encoding_onehot):
# compute losses
commitment_loss = torch.mean((z_quantized.detach() - z) ** 2)
quantization_loss = torch.mean((z_quantized - z.detach()) ** 2)
loss = quantization_loss + self.commitment_cost * commitment_loss

# perplexity
avg_probs = torch.mean(encoding_onehot.float(), dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
return loss, perplexity

def forward(self, z):
"""
Parameters
----------
z : tensor
input tensor; usually the output of an encoder
Returns
-------
loss : length of 0
quantized embeddings : same shape as z
perplexity : length of 0
encoding_onehot : shape of (Batch, Code, Width, Height)
encoding_indices : shape of (Batch, Width, Height)
index_histogram : shape of (Batch, Code index)
"""
distances = self._calc_dist(z)
# Use softmax as argmin and compute histogram
index_histogram = torch.sum(nn.Softmax(-1)(-distances.view((z.shape[0], -1, self.num_embeddings))), dim=1)
# find the closest encodings
encoding_indices = torch.argmin(distances, dim=1)
# Create one-hot vectors
encoding_onehot = nn.functional.one_hot(encoding_indices, self.num_embeddings)

# get quantized latent vectors
z_quantized = torch.matmul(encoding_onehot.float(), self.codebook.weight)
# reshape back to match original input shape
z_quantized = torch.movedim(
z_quantized.view((z.shape[0],) + z.shape[2:] + (self.embedding_dim,)), -1, 1
).contiguous()

# compute metrics
loss, perplexity = self._calc_metrics(z, z_quantized, encoding_onehot)
# reshape back to match original input shape
encoding_onehot = torch.movedim(
encoding_onehot.view((z.shape[0],) + z.shape[2:] + (self.num_embeddings,)), -1, 1
)

# copy the gradient from inputs to quantized z.
z_quantized = z + (z_quantized - z).detach()

return (
loss,
z_quantized,
perplexity,
encoding_onehot,
encoding_indices.view((-1,) + z.shape[2:]),
index_histogram,
)
2 changes: 1 addition & 1 deletion cytoself/datamanager/test/test_datamanager_oc.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_split_data(self):
for d, s in zip(index_list, self.datamgr.data_split):
data = test_label[d]
assert (
min(1, floor(len(label_all) * s * 0.7)) <= len(data) <= ceil(len(label_all) * s * 1.4)
min(1, floor(len(label_all) * s * 0.68)) <= len(data) <= ceil(len(label_all) * s * 1.4)
), 'Split ratio deviates too far.'

def test_split_data_notfov(self):
Expand Down
2 changes: 1 addition & 1 deletion cytoself/datamanager/utils/test/test_cumsum_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ def test_cumsum_split_arr():
out = cumsum_split(counts, (8, 1, 1), np.arange(len(counts)))
sums = [sum(counts[i]) for i in out]
for i, d in enumerate(splits):
assert d * 0.8 * 0.1 < sums[i] / sum(sums) < d * 1.3 * 0.1
assert d * 0.8 * 0.1 < sums[i] / sum(sums) < d * 1.32 * 0.1
17 changes: 17 additions & 0 deletions cytoself/trainer/autoencoder/test/test_vqvae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import torch

from cytoself.trainer.autoencoder.vqvae import VQVAE


def test_VQVAE():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
input_shape, emb_shape = (2, 100, 100), (64, 4, 4)
model = VQVAE(input_shape, emb_shape, input_shape, {'num_embeddings': 7})
model.to(device)
input_data = torch.randn((1,) + input_shape).to(device)
out = model(input_data)
assert out.shape == input_data.shape
assert len(model.vq_loss.shape) == 0
assert len(model.perplexity.shape) == 0
assert model.encoding_onehot.max() == 1
assert model.encoding_indices.shape == input_data.shape[:1] + emb_shape[1:]
54 changes: 54 additions & 0 deletions cytoself/trainer/autoencoder/vqvae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Optional
from torch import nn, Tensor

from cytoself.trainer.autoencoder.encoders.efficientenc2d import efficientenc_b0
from cytoself.trainer.autoencoder.decoders.resnet2d import DecoderResnet
from cytoself.components.layers.vq import VectorQuantizer


class VQVAE(nn.Module):
"""
Vector Quantized Variational Autoencoder model
"""

def __init__(
self,
input_shape: tuple,
emb_shape: tuple,
output_shape: tuple,
vq_args: dict,
encoder_args: Optional[dict] = None,
decoder_args: Optional[dict] = None,
encoder: Optional = None,
decoder: Optional = None,
):
super().__init__()
if encoder is None:
encoder = efficientenc_b0
if decoder is None:
decoder = DecoderResnet
if encoder_args is None:
encoder_args = {'in_channels': input_shape[0], 'out_channels': emb_shape[0]}
if decoder_args is None:
decoder_args = {'input_shape': emb_shape, 'output_shape': output_shape}
self.encoder = encoder(**encoder_args)
self.decoder = decoder(**decoder_args)
self.vq_layer = VectorQuantizer(embedding_dim=emb_shape[0], **vq_args)
self.vq_loss = None
self.perplexity = None
self.encoding_onehot = None
self.encoding_indices = None
self.index_histogram = None

def forward(self, x: Tensor) -> Tensor:
x = self.encoder(x)
(
self.vq_loss,
x,
self.perplexity,
self.encoding_onehot,
self.encoding_indices,
self.index_histogram,
) = self.vq_layer(x)
x = self.decoder(x)
return x
Loading

0 comments on commit d611466

Please sign in to comment.