-
Notifications
You must be signed in to change notification settings - Fork 122
/
train_celeba.py
79 lines (60 loc) · 2.15 KB
/
train_celeba.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from typing import Dict, Optional, Tuple
import os
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from mindiffusion.unet import NaiveUnet
from mindiffusion.ddpm import DDPM
from dotenv import load_dotenv
# Load environment variables from the local .env file; CELEBA_PATH must
# point at the root directory of the CelebA dataset laid out for ImageFolder.
load_dotenv("./.env")
CELEBA_PATH = os.getenv("CELEBA_PATH")  # None if the variable is unset
def train_celeba(
    n_epoch: int = 100, device: str = "cuda:1", load_pth: Optional[str] = None
) -> None:
    """Train a DDPM with a NaiveUnet backbone on the CelebA dataset.

    After every epoch, samples 8 images alongside 8 real images from the
    last batch into ./contents/ and checkpoints the model to
    ./ddpm_celeba.pth.

    Args:
        n_epoch: number of training epochs.
        device: torch device string to train on.
        load_pth: optional path to a checkpoint to resume from.
    """
    ddpm = DDPM(eps_model=NaiveUnet(3, 3, n_feat=128), betas=(1e-4, 0.02), n_T=1000)

    if load_pth is not None:
        # Bug fix: previously the `load_pth` argument was ignored and the
        # hard-coded "ddpm_celeba.pth" was always loaded instead.
        ddpm.load_state_dict(torch.load(load_pth))

    ddpm.to(device)

    # Resize to 128x128, convert to tensor, scale pixels into [-1, 1]
    # (matches the value_range used for the sample grid below).
    tf = transforms.Compose(
        [
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    dataset = ImageFolder(
        root=CELEBA_PATH,
        transform=tf,
    )
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=20)
    optim = torch.optim.Adam(ddpm.parameters(), lr=2e-5)

    for i in range(n_epoch):
        print(f"Epoch {i} : ")
        ddpm.train()

        pbar = tqdm(dataloader)
        loss_ema = None
        for x, _ in pbar:
            optim.zero_grad()
            x = x.to(device)
            loss = ddpm(x)
            loss.backward()
            # Exponential moving average of the loss for a smoother readout.
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.9 * loss_ema + 0.1 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        ddpm.eval()
        with torch.no_grad():
            # 8 generated images stacked above the first 8 real images
            # from the final batch, 4 per row.
            xh = ddpm.sample(8, (3, 128, 128), device)
            xset = torch.cat([xh, x[:8]], dim=0)
            grid = make_grid(xset, normalize=True, value_range=(-1, 1), nrow=4)
            save_image(grid, f"./contents/ddpm_sample_celeba{i:03d}.png")

        # Checkpoint after every epoch (was a placeholder-free f-string).
        torch.save(ddpm.state_dict(), "./ddpm_celeba.pth")
if __name__ == "__main__":
    # Script entry point: run training with the default hyperparameters.
    train_celeba()