-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
50 changed files
with
2,206 additions
and
1 deletion.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,384 @@ | ||
|
||
|
||
import math | ||
import torch | ||
from torch import nn | ||
from torch.nn import init | ||
from torch.nn import functional as F | ||
import numpy as np | ||
|
||
class Swish(nn.Module): | ||
def forward(self, x): | ||
return x * torch.sigmoid(x) | ||
|
||
|
||
class TimeEmbedding(nn.Module): | ||
def __init__(self, T, d_model, dim): | ||
assert d_model % 2 == 0 | ||
super().__init__() | ||
emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000) | ||
emb = torch.exp(-emb) | ||
pos = torch.arange(T).float() | ||
emb = pos[:, None] * emb[None, :] | ||
assert list(emb.shape) == [T, d_model // 2] | ||
emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1) | ||
assert list(emb.shape) == [T, d_model // 2, 2] | ||
emb = emb.view(T, d_model) | ||
|
||
self.timembedding = nn.Sequential( | ||
nn.Embedding.from_pretrained(emb), | ||
nn.Linear(d_model, dim), | ||
Swish(), | ||
nn.Linear(dim, dim), | ||
) | ||
self.initialize() | ||
|
||
def initialize(self): | ||
for module in self.modules(): | ||
if isinstance(module, nn.Linear): | ||
init.xavier_uniform_(module.weight) | ||
init.zeros_(module.bias) | ||
|
||
def forward(self, t): | ||
emb = self.timembedding(t) | ||
return emb | ||
|
||
|
||
class DownSample(nn.Module): | ||
def __init__(self, in_ch): | ||
super().__init__() | ||
self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1) | ||
self.initialize() | ||
|
||
def initialize(self): | ||
init.xavier_uniform_(self.main.weight) | ||
init.zeros_(self.main.bias) | ||
|
||
def forward(self, x, temb,light_emb): | ||
x = self.main(x) | ||
return x | ||
|
||
|
||
class UpSample(nn.Module): | ||
def __init__(self, in_ch): | ||
super().__init__() | ||
self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1) | ||
self.initialize() | ||
|
||
def initialize(self): | ||
init.xavier_uniform_(self.main.weight) | ||
init.zeros_(self.main.bias) | ||
|
||
def forward(self, x, temb,light_emb): | ||
_, _, H, W = x.shape | ||
x = F.interpolate( | ||
x, scale_factor=2, mode='nearest') | ||
x = self.main(x) | ||
return x | ||
|
||
|
||
class AttnBlock(nn.Module): | ||
def __init__(self, in_ch): | ||
super().__init__() | ||
self.group_norm = nn.GroupNorm(32, in_ch) | ||
self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) | ||
self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) | ||
self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) | ||
self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) | ||
self.initialize() | ||
|
||
def initialize(self): | ||
for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]: | ||
init.xavier_uniform_(module.weight) | ||
init.zeros_(module.bias) | ||
init.xavier_uniform_(self.proj.weight, gain=1e-5) | ||
|
||
def forward(self, x): | ||
B, C, H, W = x.shape | ||
h = self.group_norm(x) | ||
q = self.proj_q(h) | ||
k = self.proj_k(h) | ||
v = self.proj_v(h) | ||
|
||
q = q.permute(0, 2, 3, 1).view(B, H * W, C) | ||
k = k.view(B, C, H * W) | ||
w = torch.bmm(q, k) * (int(C) ** (-0.5)) | ||
assert list(w.shape) == [B, H * W, H * W] | ||
w = F.softmax(w, dim=-1) | ||
|
||
v = v.permute(0, 2, 3, 1).view(B, H * W, C) | ||
h = torch.bmm(w, v) | ||
assert list(h.shape) == [B, H * W, C] | ||
h = h.view(B, H, W, C).permute(0, 3, 1, 2) | ||
h = self.proj(h) | ||
|
||
return x + h | ||
|
||
|
||
class ResBlock(nn.Module): | ||
def __init__(self, in_ch, out_ch, tdim, dropout, attn=False): | ||
super().__init__() | ||
self.block1 = nn.Sequential( | ||
nn.GroupNorm(32, in_ch), | ||
Swish(), | ||
nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), | ||
) | ||
self.temb_proj = nn.Sequential( | ||
Swish(), | ||
nn.Linear(tdim, out_ch), | ||
) | ||
self.block2 = nn.Sequential( | ||
nn.GroupNorm(32, out_ch), | ||
Swish(), | ||
nn.Dropout(dropout), | ||
nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), | ||
) | ||
if in_ch != out_ch: | ||
self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0) | ||
else: | ||
self.shortcut = nn.Identity() | ||
if attn: | ||
self.attn = AttnBlock(out_ch) | ||
|
||
else: | ||
self.attn = nn.Identity() | ||
self.initialize() | ||
|
||
self.mlp = nn.Sequential( | ||
nn.SiLU(), | ||
nn.Linear(128, out_ch * 2) | ||
) | ||
def initialize(self): | ||
for module in self.modules(): | ||
if isinstance(module, (nn.Conv2d, nn.Linear)): | ||
init.xavier_uniform_(module.weight) | ||
init.zeros_(module.bias) | ||
init.xavier_uniform_(self.block2[-1].weight, gain=1e-5) | ||
|
||
def forward(self, x, temb,light_emb): | ||
h = self.block1(x) | ||
h += self.temb_proj(temb)[:, :, None, None] | ||
h = self.block2(h) | ||
|
||
h = h + self.shortcut(x) | ||
h = self.attn(h) | ||
|
||
light_emb=self.mlp(light_emb) | ||
scale, shift = torch.chunk(light_emb, 2, dim=1) | ||
scale=scale.view(scale.shape[0],scale.shape[1],1,1) | ||
shift= shift.view(shift.shape[0], shift.shape[1], 1, 1) | ||
h=h * (1 + scale) + shift | ||
return h | ||
|
||
|
||
|
||
class UNet(nn.Module): | ||
def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout): | ||
super().__init__() | ||
assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' | ||
tdim = ch * 4 | ||
self.time_embedding = TimeEmbedding(T, ch, tdim) | ||
|
||
self.head = nn.Conv2d(10, ch, kernel_size=3, stride=1, padding=1) | ||
self.downblocks = nn.ModuleList() | ||
chs = [ch] # record output channel when dowmsample for upsample | ||
now_ch = ch | ||
for i, mult in enumerate(ch_mult): | ||
out_ch = ch * mult | ||
for _ in range(num_res_blocks): | ||
self.downblocks.append(ResBlock( | ||
in_ch=now_ch, out_ch=out_ch, tdim=tdim, | ||
dropout=dropout, attn=(i in attn))) | ||
now_ch = out_ch | ||
chs.append(now_ch) | ||
if i != len(ch_mult) - 1: | ||
self.downblocks.append(DownSample(now_ch)) | ||
chs.append(now_ch) | ||
|
||
self.middleblocks = nn.ModuleList([ | ||
ResBlock(now_ch, now_ch, tdim, dropout, attn=True), | ||
ResBlock(now_ch, now_ch, tdim, dropout, attn=False), | ||
]) | ||
|
||
self.upblocks = nn.ModuleList() | ||
for i, mult in reversed(list(enumerate(ch_mult))): | ||
out_ch = ch * mult | ||
for _ in range(num_res_blocks + 1): | ||
self.upblocks.append(ResBlock( | ||
in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, | ||
dropout=dropout, attn=(i in attn))) | ||
now_ch = out_ch | ||
if i != 0: | ||
self.upblocks.append(UpSample(now_ch)) | ||
assert len(chs) == 0 | ||
|
||
self.tail = nn.Sequential( | ||
nn.GroupNorm(32, now_ch), | ||
Swish(), | ||
nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) | ||
) | ||
self.initialize() | ||
|
||
rand_mat = np.random.randn(128, 128) | ||
rand_otho_mat, _ = np.linalg.qr(rand_mat) | ||
self.light_ecode = nn.Parameter(torch.from_numpy(rand_otho_mat).float(), requires_grad=False) | ||
|
||
def initialize(self): | ||
init.xavier_uniform_(self.head.weight) | ||
init.zeros_(self.head.bias) | ||
init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) | ||
init.zeros_(self.tail[-1].bias) | ||
|
||
def getEmbedding(self, diff): | ||
coord = diff/ 10 # [-1, 1] | ||
# self.ip is [1, 128, 128, 1] | ||
coord = coord.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) | ||
# [b, 1, 1, 1] | ||
#print(coord.shape) | ||
coord = torch.cat((torch.zeros_like(coord), coord), dim=-1) | ||
#print(coord.shape) | ||
# [b, 1, 1, 2] | ||
b = coord.shape[0] | ||
mat = self.light_ecode.unsqueeze(-1).unsqueeze(0) | ||
mat = mat.expand(b, -1, -1, -1) | ||
coord = F.grid_sample(mat, coord, align_corners=True) | ||
#print(coord.shape) | ||
# [b, 128, 1, 1] | ||
coord = coord[:, :, 0, 0] | ||
#print(coord.shape) | ||
# [b, 128] | ||
return coord | ||
|
||
def forward(self, x, t,light_emb,context_zero=None): | ||
# Timestep embedding | ||
device=t.device | ||
#train | ||
#light_context=torch.zeros([light_emb.shape[0],128]) #128为emb的纬度 | ||
light_context=self.getEmbedding(light_emb) | ||
# for b in range(light_emb.shape[0]): | ||
# context_emb = light_emb[b].long().to(device) | ||
# context_emb = self.light_ecode[context_emb] | ||
# light_context[b] = context_emb | ||
context = light_context.to(device) | ||
if context_zero == True: | ||
context = torch.zeros_like(context) | ||
|
||
|
||
temb = self.time_embedding(t) #(80,512) | ||
# Downsampling | ||
h = self.head(x) | ||
hs = [h] | ||
for layer in self.downblocks: | ||
h = layer(h, temb,context) | ||
hs.append(h) | ||
# Middle | ||
for layer in self.middleblocks: | ||
h = layer(h, temb,context) | ||
# Upsampling | ||
for layer in self.upblocks: | ||
if isinstance(layer, ResBlock): | ||
h = torch.cat([h, hs.pop()], dim=1) | ||
h = layer(h, temb,context) | ||
h = self.tail(h) | ||
|
||
assert len(hs) == 0 | ||
return h | ||
|
||
|
||
class UNet_forCNN(nn.Module): | ||
def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout): | ||
super().__init__() | ||
assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' | ||
tdim = ch * 4 | ||
self.time_embedding = TimeEmbedding(T, ch, tdim) | ||
|
||
self.head = nn.Conv2d(10, ch, kernel_size=3, stride=1, padding=1) | ||
self.downblocks = nn.ModuleList() | ||
chs = [ch] # record output channel when dowmsample for upsample | ||
now_ch = ch | ||
for i, mult in enumerate(ch_mult): | ||
out_ch = ch * mult | ||
for _ in range(num_res_blocks): | ||
self.downblocks.append(ResBlock( | ||
in_ch=now_ch, out_ch=out_ch, tdim=tdim, | ||
dropout=dropout, attn=(i in attn))) | ||
now_ch = out_ch | ||
chs.append(now_ch) | ||
if i != len(ch_mult) - 1: | ||
self.downblocks.append(DownSample(now_ch)) | ||
chs.append(now_ch) | ||
|
||
self.middleblocks = nn.ModuleList([ | ||
ResBlock(now_ch, now_ch, tdim, dropout, attn=True), | ||
ResBlock(now_ch, now_ch, tdim, dropout, attn=False), | ||
]) | ||
|
||
self.upblocks = nn.ModuleList() | ||
for i, mult in reversed(list(enumerate(ch_mult))): | ||
out_ch = ch * mult | ||
for _ in range(num_res_blocks + 1): | ||
self.upblocks.append(ResBlock( | ||
in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, | ||
dropout=dropout, attn=(i in attn))) | ||
now_ch = out_ch | ||
if i != 0: | ||
self.upblocks.append(UpSample(now_ch)) | ||
assert len(chs) == 0 | ||
|
||
self.tail = nn.Sequential( | ||
nn.GroupNorm(32, now_ch), | ||
Swish(), | ||
nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) | ||
) | ||
self.initialize() | ||
|
||
rand_mat = np.random.randn(128, 128) | ||
rand_otho_mat, _ = np.linalg.qr(rand_mat) | ||
self.light_ecode = nn.Parameter(torch.from_numpy(rand_otho_mat).float(), requires_grad=False) | ||
|
||
def initialize(self): | ||
init.xavier_uniform_(self.head.weight) | ||
init.zeros_(self.head.bias) | ||
init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) | ||
init.zeros_(self.tail[-1].bias) | ||
|
||
def forward(self, x, t,light_emb,context_zero=None): | ||
# Timestep embedding | ||
device=t.device | ||
#train | ||
light_context=torch.zeros([light_emb.shape[0],128]) #128为emb的纬度 | ||
for b in range(light_emb.shape[0]): | ||
context_emb = light_emb[b].long().to(device) | ||
context_emb = self.light_ecode[context_emb] | ||
light_context[b] = context_emb | ||
context = light_context.to(device) | ||
if context_zero == True: | ||
context = torch.zeros_like(context) | ||
|
||
|
||
temb = self.time_embedding(t) #(80,512) | ||
# Downsampling | ||
h = self.head(x) | ||
hs = [h] | ||
for layer in self.downblocks: | ||
h = layer(h, temb,context) | ||
hs.append(h) | ||
# Middle | ||
for layer in self.middleblocks: | ||
h = layer(h, temb,context) | ||
|
||
|
||
#assert len(hs) == 0 | ||
return h | ||
|
||
|
||
if __name__ == '__main__': | ||
batch_size = 8 | ||
model = UNet( | ||
T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1], | ||
num_res_blocks=2, dropout=0.1) | ||
x = torch.randn(batch_size, 3, 32, 32) | ||
t = torch.randint(1000, (batch_size, )) | ||
y = model(x, t) | ||
print(y.shape) | ||
print(model) |
Oops, something went wrong.