import torch
from dalle_pytorch import DiscreteVAE, DALLE
from torchvision import transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image
BATCH_SIZE = 4
IMAGE_SIZE = 64
IMAGE_PATH = "."
EPOCHS = 1
vae = DiscreteVAE(
    image_size = IMAGE_SIZE,
    num_layers = 2,           # number of downsamples - here 64 / (2 ** 2) = a 16 x 16 feature map
    num_tokens = 1024,        # number of visual tokens. the paper used 8192, but a smaller codebook works for downsized projects
    codebook_dim = 256,       # codebook dimension
    hidden_dim = 32,          # hidden dimension
    num_resnet_blocks = 1,    # number of resnet blocks
    temperature = 0.9,        # gumbel softmax temperature; the lower this is, the harder the discretization
    straight_through = False, # straight-through for gumbel softmax. unclear if it is better one way or the other
)
## Train on images
images = torch.randn(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)  # dummy batch; the real batches come from the dataloader below
dataset = ImageFolder(
    IMAGE_PATH,
    T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor()
    ])
)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
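As a quick sanity check (not part of the original snippet), it can help to confirm that ImageFolder actually found images and that batches come out at the expected size; ImageFolder expects IMAGE_PATH to contain at least one subfolder of images and raises an error otherwise.

# optional sanity check: the first batch should be torch.Size([4, 3, 64, 64]), dtype float32
sample_images, sample_labels = next(iter(dataloader))
print(sample_images.shape, sample_images.dtype)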
## Run training for several epochs
vae_opt = torch.optim.Adam(vae.parameters(), lr = 1e-3)  # weights only update with an optimizer step; lr = 1e-3 is a starting point, tune as needed

count = 0
for epoch in range(EPOCHS):
    for images, labels in dataloader:
        vae_opt.zero_grad()
        loss = vae(images, return_loss = True)
        loss.backward()
        vae_opt.step()
        print(f'step {count}: vae loss {loss.item():.4f}')
        count += 1
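To monitor VAE quality visually, here is a small sketch (not in the original code) that saves inputs next to their reconstructions using the already-imported make_grid / save_image; it assumes the get_codebook_indices / decode methods exposed by dalle_pytorch's DiscreteVAE, so adapt if your version differs.

# sketch: save a grid with original images on the top row and VAE reconstructions below
@torch.no_grad()
def save_reconstructions(vae, images, path = 'recons.png'):
    codes = vae.get_codebook_indices(images)  # (batch, 16 * 16) discrete token ids with the settings above
    recons = vae.decode(codes)                # (batch, 3, IMAGE_SIZE, IMAGE_SIZE)
    grid = make_grid(torch.cat([images, recons]), nrow = images.shape[0])
    save_image(grid, path)

save_reconstructions(vae, next(iter(dataloader))[0])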
## Train DALLE on text-to-image pairs
dalle = DALLE(
    dim = 1024,
    vae = vae,              # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 1000, # vocab size for text
    text_seq_len = 16,      # text sequence length
    depth = 12,             # should aim to be 64
    heads = 16,             # attention heads
    dim_head = 64,          # attention head dimension
    attn_dropout = 0.1,     # attention dropout
    ff_dropout = 0.1        # feedforward dropout
)
text = torch.randint(0, 1000, (BATCH_SIZE, 16))              # random token ids standing in for real tokenized captions
images = torch.randn(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)  # random pixels standing in for real training images
loss = dalle(text, images, return_loss = True)
loss.backward()
# do the above for a long time with a lot of data ... then
images = dalle.generate_images(text)
img1 = images[0]
save_image(img1, 'img1.png')
print(images.shape)  # (BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE) -> (4, 3, 64, 64) with the settings above
Here's my code -- no idea what's happening
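For what it's worth, here is a minimal sketch of a DALLE training loop over real batches, reusing the dataloader above; the Adam optimizer, learning rate, and the random-token placeholder for captions are assumptions rather than part of the posted code (real captions would need to be tokenized to ids in [0, num_text_tokens)).

# sketch: DALLE training with weight updates; replace the random text with real tokenized captions
dalle_opt = torch.optim.Adam(dalle.parameters(), lr = 3e-4)  # assumed learning rate, tune as needed

for epoch in range(EPOCHS):
    for images, labels in dataloader:
        text = torch.randint(0, 1000, (images.shape[0], 16))  # placeholder caption tokens
        dalle_opt.zero_grad()
        loss = dalle(text, images, return_loss = True)
        loss.backward()
        dalle_opt.step()
        print(f'dalle loss {loss.item():.4f}')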