-
Notifications
You must be signed in to change notification settings - Fork 7
/
vis_codebook.py
120 lines (94 loc) · 4.81 KB
/
vis_codebook.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from itertools import count
from tokenize import PlainToken
import torch
import torchvision.transforms as tf
from torchvision.utils import save_image
import torchvision.utils as tvu
import numpy as np
import os
import random
from tqdm import tqdm
import cv2
from matplotlib import pyplot as plt
import seaborn as sns
from basicsr.utils.misc import set_random_seed
from basicsr.utils import img2tensor, tensor2img, imwrite
from basicsr.archs.femasr_arch import FeMaSRNet
def reconstruct_ost(model, data_dir, save_dir, maxnum=100):
    """Reconstruct sampled images of each texture class and record their code indices.

    Every entry of ``data_dir`` except ``manga109`` is treated as a texture
    class directory. For each class, the image list is shuffled and at most
    ``maxnum`` images are passed through ``model``. The reconstruction is
    written to ``save_dir/<class>/<image name>`` and the network input to the
    same relative path under ``save_dir`` with ``'rec'`` replaced by ``'org'``.
    The collected indices are saved as a dict ``{class name: [indices, ...]}``
    to ``./tmp_code_vis/code_idx_dict.pth``.

    Args:
        model: Network that is called with a ``(1, C, H, W)`` tensor scaled by
            ``1 / 255`` and returns a 4-tuple; the first item is the
            reconstruction and element 0 of the last item is stored as the
            image's code indices.
        data_dir: Directory containing one sub-directory per texture class.
        save_dir: Output directory for the reconstructed images.
        maxnum: Maximum number of images processed per class.
    """
    # Filtering instead of list.remove(): no ValueError when 'manga109' is absent.
    texture_classes = [tc for tc in os.listdir(data_dir) if tc != 'manga109']

    # Follow the model's own device instead of relying on a module-level
    # `device` global that only exists when the file is run as a script.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    save_org_dir = save_dir.replace('rec', 'org')
    code_idx_dict = {}
    for tc in texture_classes:
        img_name_list = os.listdir(os.path.join(data_dir, tc))
        random.shuffle(img_name_list)

        os.makedirs(os.path.join(save_dir, tc), exist_ok=True)
        os.makedirs(os.path.join(save_org_dir, tc), exist_ok=True)

        tmp_code_idx_list = []
        for img_name in tqdm(img_name_list[:maxnum]):
            img_path = os.path.join(data_dir, tc, img_name)
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            if img is None:
                # cv2.imread signals an unreadable / non-image file with None.
                tqdm.write(f'Skipping unreadable image: {img_path}')
                continue
            img_tensor = img2tensor(img).to(device) / 255.
            img_tensor = img_tensor.unsqueeze(0)

            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                rec, _, _, indices = model(img_tensor)
            indices = indices[0]

            imwrite(tensor2img(rec), os.path.join(save_dir, tc, img_name))
            imwrite(tensor2img(img_tensor), os.path.join(save_org_dir, tc, img_name))

            tmp_code_idx_list.append(indices)
        code_idx_dict[tc] = tmp_code_idx_list

    # torch.save does not create missing parent directories.
    os.makedirs('./tmp_code_vis', exist_ok=True)
    torch.save(code_idx_dict, './tmp_code_vis/code_idx_dict.pth')
def vis_hrp(model, code_list_path, save_dir, samples_each_class=16):
    """Plot per-class code-index histograms and decode textures sampled from them.

    For every class stored in the dict loaded from ``code_list_path``:

    1. all index tensors of the class are flattened and concatenated, and a
       histogram of them is saved to
       ``./tmp_code_vis/code_stat/code_index_bincount_<class>.pdf``;
    2. the normalised bin counts are used as a categorical distribution from
       which 32 index maps of size 8x8 are drawn and decoded with
       ``model.decode_indices``; the decoded batch is saved as an image grid
       (16 per row) to ``./tmp_code_vis/tmp_tex_vis/<class>.jpg``.

    Args:
        model: Object exposing ``decode_indices``; it is called with an index
            tensor of shape ``(1, 1, 8, 8)``.
        code_list_path: Path of a ``torch.save``-d dict mapping class name to a
            list of index tensors.
        save_dir: Currently unused; outputs go to the fixed ``./tmp_code_vis``
            sub-directories listed above.
        samples_each_class: Currently unused; 32 samples are drawn per class.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted files.
    code_idx_dict = torch.load(code_list_path)
    latent_size = 8
    color_palette = sns.color_palette()

    # Neither plt.savefig nor save_image creates missing directories.
    os.makedirs('./tmp_code_vis/code_stat', exist_ok=True)
    os.makedirs('./tmp_code_vis/tmp_tex_vis', exist_ok=True)

    for idx, (key, value) in enumerate(code_idx_dict.items()):
        if not value:
            # torch.cat cannot concatenate an empty list; nothing to visualise.
            continue
        all_idx = torch.cat([x.flatten() for x in value])

        fig = plt.figure(figsize=(16, 8))
        # Cycle through the palette so more classes than colours does not fail.
        sns.histplot(all_idx.cpu().numpy(), color=color_palette[idx % len(color_palette)])
        plt.xlabel(key, fontsize=30)
        plt.ylabel('Count', fontsize=30)
        plt.savefig(f'./tmp_code_vis/code_stat/code_index_bincount_{key}.pdf')
        # Release the figure; otherwise one figure per class stays alive.
        plt.close(fig)

        counts = all_idx.bincount()
        # Tensor reduction instead of the builtin sum(), which iterates the
        # tensor element by element.
        dist = (counts / counts.sum()).cpu().numpy()

        vis_tex_samples = []
        with torch.no_grad():
            for _ in range(32):
                vis_tex_map = np.random.choice(np.arange(dist.shape[0]), latent_size ** 2, p=dist)
                # Match dtype and device of the stored indices.
                vis_tex_map = torch.from_numpy(vis_tex_map).to(all_idx)
                vis_tex_map = vis_tex_map.reshape(1, 1, latent_size, latent_size)
                vis_tex_samples.append(model.decode_indices(vis_tex_map))
        vis_tex_samples = torch.cat(vis_tex_samples, dim=0)
        save_image(vis_tex_samples, f'./tmp_code_vis/tmp_tex_vis/{key}.jpg', normalize=True, nrow=16)
def vis_single_code(model, codenum, save_path, up_factor=4):
    """Decode every code index on its own and save all results as one image grid.

    For each index ``0 .. codenum - 1`` a constant index map of size
    ``up_factor x up_factor`` is built; the batch of maps is decoded with
    ``model.decode_indices`` and arranged in a grid with 32 images per row.

    Args:
        model: Module exposing ``decode_indices``; it is called with an integer
            tensor of shape ``(codenum, 1, up_factor, up_factor)``.
        codenum: Number of code indices to visualise.
        save_path: Destination file of the grid image.
        up_factor: Side length of the constant index map built per code.
    """
    code_idx = torch.arange(codenum).reshape(codenum, 1, 1, 1)
    code_idx = code_idx.repeat(1, 1, up_factor, up_factor)

    # Keep the indices on the same device as the model weights.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        code_idx = code_idx.to(first_param.device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        output_img = model.decode_indices(code_idx)
    output_img = tvu.make_grid(output_img, nrow=32)

    # save_image does not create missing parent directories.
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    save_image(output_img, save_path)
def vis_rand_single_code(model, codenum, save_path, up_factor=4, vis_num=5):
    """Decode ``vis_num`` randomly drawn code indices and save them as one image grid.

    ``vis_num`` indices are drawn with ``torch.randint`` from
    ``[0, codenum)``; each is expanded to a constant index map of size
    ``up_factor x up_factor``, decoded with ``model.decode_indices`` and placed
    in a grid with 32 images per row.

    Args:
        model: Module exposing ``decode_indices``; it is called with an integer
            tensor of shape ``(vis_num, 1, up_factor, up_factor)``.
        codenum: Exclusive upper bound of the sampled code indices.
        save_path: Destination file of the grid image.
        up_factor: Side length of the constant index map built per code.
        vis_num: Number of codes to sample and visualise.
    """
    # Sample on the CPU generator (as before) so seeded runs stay reproducible,
    # then move the result to the model's device.
    code_idx = torch.randint(codenum, (vis_num,)).reshape(vis_num, 1, 1, 1)
    code_idx = code_idx.repeat(1, 1, up_factor, up_factor)

    first_param = next(model.parameters(), None)
    if first_param is not None:
        code_idx = code_idx.to(first_param.device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        output_img = model.decode_indices(code_idx)
    output_img = tvu.make_grid(output_img, nrow=32)

    # save_image does not create missing parent directories.
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    save_image(output_img, save_path)
if __name__ == '__main__':
    # Script entry point: load a pretrained FeMaSRNet and save a grid image
    # showing the decoded output of every single code index.

    # set random seeds to ensure reproducibility
    set_random_seed(123)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set up the model
    weight_path = './experiments/pretrained_models/S1_c2_512x256_net_g_best.pth'
    # The third '_'-separated field of the file name ("512x256") is split into
    # two integers; the first is later used as the number of codes to decode.
    codebook_size = os.path.basename(weight_path).split('_')[2].split('x')
    vqgan = FeMaSRNet(codebook_params=[[32, int(codebook_size[0]), int(codebook_size[1])]], LQ_stage=False).to(device)
    # map_location lets a checkpoint saved from GPU tensors load on the CPU
    # fallback selected above.
    checkpoint = torch.load(weight_path, map_location=device)
    vqgan.load_state_dict(checkpoint['params'], strict=False)
    vqgan.eval()

    os.makedirs('results/codebook_vis', exist_ok=True)
    vis_single_code(vqgan, int(codebook_size[0]), 'results/codebook_vis/{}.png'.format(os.path.basename(weight_path).split('.')[0]))

    # Alternative visualisations, enabled manually when needed:
    # vis_rand_single_code(vqgan, 256, 'codebook_vis/sample_ffhq.png', vis_num=10)
    # reconstruct_ost(vqgan, '../datasets/SR_OST_datasets/OutdoorSceneTrain_v2/', './tmp_code_vis/ost_rec', maxnum=1000)
    # vis_hrp(vqgan, './tmp_code_vis/code_idx_dict.pth', './tmp_code_vis/')