Commit b6abffc: training code
YuyangYin committed Jan 3, 2024 (1 parent: c451484)
Showing 50 changed files with 2,206 additions and 1 deletion.
401 changes: 401 additions & 0 deletions Diffusion/Diffusion.py

Large diffs are not rendered by default.

384 changes: 384 additions & 0 deletions Diffusion/Model.py
@@ -0,0 +1,384 @@


import math
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
import numpy as np

class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class TimeEmbedding(nn.Module):
    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        # Precompute fixed sinusoidal embeddings for timesteps 0..T-1.
        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        emb = torch.exp(-emb)
        pos = torch.arange(T).float()
        emb = pos[:, None] * emb[None, :]
        assert list(emb.shape) == [T, d_model // 2]
        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
        assert list(emb.shape) == [T, d_model // 2, 2]
        emb = emb.view(T, d_model)

        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(emb),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

    def initialize(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)

    def forward(self, t):
        emb = self.timembedding(t)
        return emb
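
# Usage sketch for TimeEmbedding (illustrative only; the T, d_model, and dim
# values below are assumptions, not taken from the training config):
#
#     te = TimeEmbedding(T=1000, d_model=128, dim=512)
#     t = torch.randint(1000, (8,))
#     te(t).shape  # torch.Size([8, 512])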


class DownSample(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb, light_emb):
        # temb and light_emb are unused; the signature matches ResBlock so all
        # blocks can be called uniformly in the UNet down/up loops.
        x = self.main(x)
        return x


class UpSample(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb, light_emb):
        # Nearest-neighbor 2x upsampling followed by a 3x3 conv.
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.main(x)
        return x


class AttnBlock(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
            init.xavier_uniform_(module.weight)
            init.zeros_(module.bias)
        # Near-zero init on the output projection so the block starts close to identity.
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)

        # Scaled dot-product self-attention over the H*W spatial positions.
        q = q.permute(0, 2, 3, 1).view(B, H * W, C)
        k = k.view(B, C, H * W)
        w = torch.bmm(q, k) * (int(C) ** (-0.5))
        assert list(w.shape) == [B, H * W, H * W]
        w = F.softmax(w, dim=-1)

        v = v.permute(0, 2, 3, 1).view(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]
        h = h.view(B, H, W, C).permute(0, 3, 1, 2)
        h = self.proj(h)

        return x + h
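
# Quick shape check for AttnBlock (illustrative; in_ch must be divisible by
# the 32 GroupNorm groups):
#
#     blk = AttnBlock(64)
#     x = torch.randn(2, 64, 16, 16)
#     blk(x).shape  # torch.Size([2, 64, 16, 16]), residual output keeps the input shape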


class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()
        self.initialize()

        # Projects the 128-dim light embedding to a per-channel scale and shift
        # (FiLM-style conditioning). Defined after initialize() is called, so
        # its Linear keeps the default PyTorch initialization.
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(128, out_ch * 2)
        )

    def initialize(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb, light_emb):
        h = self.block1(x)
        h += self.temb_proj(temb)[:, :, None, None]
        h = self.block2(h)

        h = h + self.shortcut(x)
        h = self.attn(h)

        # Modulate the output with the light embedding: h * (1 + scale) + shift.
        light_emb = self.mlp(light_emb)
        scale, shift = torch.chunk(light_emb, 2, dim=1)
        scale = scale.view(scale.shape[0], scale.shape[1], 1, 1)
        shift = shift.view(shift.shape[0], shift.shape[1], 1, 1)
        h = h * (1 + scale) + shift
        return h
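
# Shape sketch for the FiLM-style light conditioning in ResBlock.forward
# (illustrative; the 128-dim light embedding width is fixed by self.mlp):
#
#     blk = ResBlock(in_ch=64, out_ch=64, tdim=512, dropout=0.0)
#     x = torch.randn(2, 64, 16, 16)
#     temb = torch.randn(2, 512)
#     light = torch.randn(2, 128)
#     blk(x, temb, light).shape  # torch.Size([2, 64, 16, 16])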



class UNet(nn.Module):
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(10, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channels of each downsampling stage for the upsampling path
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )
        self.initialize()

        # Fixed (non-trainable) orthogonal codebook used to encode the light condition.
        rand_mat = np.random.randn(128, 128)
        rand_otho_mat, _ = np.linalg.qr(rand_mat)
        self.light_ecode = nn.Parameter(torch.from_numpy(rand_otho_mat).float(), requires_grad=False)

    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def getEmbedding(self, diff):
        # Map a scalar light value (assumed in [-10, 10]) to a 128-dim embedding
        # by sampling the fixed orthogonal codebook.
        coord = diff / 10  # normalize to [-1, 1]
        coord = coord.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)  # [b, 1, 1, 1]
        coord = torch.cat((torch.zeros_like(coord), coord), dim=-1)  # [b, 1, 1, 2]
        b = coord.shape[0]
        mat = self.light_ecode.unsqueeze(-1).unsqueeze(0)  # [1, 128, 128, 1]
        mat = mat.expand(b, -1, -1, -1)  # [b, 128, 128, 1]
        coord = F.grid_sample(mat, coord, align_corners=True)  # [b, 128, 1, 1]
        coord = coord[:, :, 0, 0]  # [b, 128]
        return coord
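
    # Note: getEmbedding treats the fixed 128x128 orthogonal matrix as a 1-D
    # lookup table and uses F.grid_sample to linearly interpolate between
    # adjacent columns, so a continuous light value yields a continuous
    # 128-dim embedding. Sketch (illustrative values):
    #
    #     unet = UNet(T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
    #                 num_res_blocks=2, dropout=0.1)
    #     unet.getEmbedding(torch.tensor([-10.0, 0.0, 10.0])).shape  # torch.Size([3, 128])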

    def forward(self, x, t, light_emb, context_zero=None):
        device = t.device
        # Encode the light condition; optionally zero it out (e.g. for
        # classifier-free-guidance-style unconditional passes).
        light_context = self.getEmbedding(light_emb)
        context = light_context.to(device)
        if context_zero == True:
            context = torch.zeros_like(context)

        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb, context)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb, context)
        # Upsampling
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb, context)
        h = self.tail(h)

        assert len(hs) == 0
        return h


class UNet_forCNN(nn.Module):
    # Same backbone as UNet, but with a table-lookup light encoding and a
    # forward pass that stops at the bottleneck (see forward below).
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(10, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channels of each downsampling stage for the upsampling path
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )
        self.initialize()

        # Fixed (non-trainable) orthogonal codebook used to encode the light condition.
        rand_mat = np.random.randn(128, 128)
        rand_otho_mat, _ = np.linalg.qr(rand_mat)
        self.light_ecode = nn.Parameter(torch.from_numpy(rand_otho_mat).float(), requires_grad=False)

    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t, light_emb, context_zero=None):
        device = t.device
        # 128 is the embedding dimension; look each light index up in the fixed codebook.
        light_context = torch.zeros([light_emb.shape[0], 128])
        for b in range(light_emb.shape[0]):
            context_emb = light_emb[b].long().to(device)
            context_emb = self.light_ecode[context_emb]
            light_context[b] = context_emb
        context = light_context.to(device)
        if context_zero == True:
            context = torch.zeros_like(context)

        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb, context)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb, context)

        # This variant returns the bottleneck features after the middle blocks;
        # the upsampling path and tail are unused here.
        return h


if __name__ == '__main__':
    batch_size = 8
    model = UNet(
        T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
        num_res_blocks=2, dropout=0.1)
    # The head conv expects 10 input channels, and forward also needs a scalar
    # light condition per sample (assumed in [-10, 10]; see getEmbedding).
    x = torch.randn(batch_size, 10, 32, 32)
    t = torch.randint(1000, (batch_size, ))
    light_emb = torch.rand(batch_size) * 20 - 10
    y = model(x, t, light_emb)
    print(y.shape)
    print(model)