-
Notifications
You must be signed in to change notification settings - Fork 2
/
generate_images.py
149 lines (135 loc) · 5.53 KB
/
generate_images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import torch
from tqdm import tqdm
import os
from torch.utils.data import DataLoader
from utils.datasets import Proposed_biases
from utils.DDP_manager import DDP
import utils.arg_parse as arg_parse
from torch.utils.data import Dataset
# required to run eval()
from utils.generative_models import Stable_Diffusion_XL, Stable_Diffusion
# MULTI DATA PARALLELIZATION
import torch.multiprocessing as mp
# split the dataset into chunks for each rank
class Distributed_dataset(Dataset):
    """Per-rank subset of the (prompt, caption_id) pairs that still need images.

    On construction, every ``(prompt, caption_id)`` pair returned by
    ``ds.get_data()`` is kept if the folder ``<save_path>/<caption_id>`` does
    not exist or contains fewer than ``opt['dataset_setting']['n-images']``
    entries. The kept pairs are then split into ``world_size`` contiguous
    chunks whose sizes differ by at most one, and only chunk ``rank`` is
    retained.

    Args:
        rank: index of the current process, in ``[0, world_size)``.
        world_size: total number of processes sharing the work.
        opt: configuration dict; reads ``opt['save_path']`` and
            ``opt['dataset_setting']['n-images']``.
        ds: object exposing ``get_data()``, which must yield
            ``(prompt, caption_id)`` pairs.
    """

    def __init__(self,
                 rank,
                 world_size,
                 opt,
                 ds
                 ):
        self.rank = rank
        self.world_size = world_size
        save_path = opt['save_path']
        n_images = opt['dataset_setting']['n-images']

        # Keep only captions whose output folder is missing or incomplete.
        # NOTE: incomplete folders are regenerated, which will overwrite the
        # existing images in them.
        self.data_to_generate = [
            (prompt, caption_id)
            for prompt, caption_id in ds.get_data()
            if self._is_incomplete(os.path.join(save_path, str(caption_id)), n_images)
        ]

        # Balanced contiguous split. A plain floor-division split would hand
        # the whole remainder (up to world_size - 1 extra items) to the last
        # rank. When there are fewer items than ranks, it would give *all*
        # of them to the last rank.
        total = len(self.data_to_generate)
        start = rank * total // world_size
        end = (rank + 1) * total // world_size
        self.data_to_generate = self.data_to_generate[start:end]

    @staticmethod
    def _is_incomplete(folder, n_images):
        """Return True if ``folder`` is missing or has fewer than ``n_images`` entries."""
        if not os.path.isdir(folder):
            return True
        return len(os.listdir(folder)) < n_images

    def __getitem__(self, idx):
        """Return the ``(caption, caption_id)`` pair at position ``idx``."""
        caption, caption_id = self.data_to_generate[idx]
        return caption, caption_id

    def __len__(self):
        """Return the number of pairs assigned to this rank."""
        return len(self.data_to_generate)
class DDP_image_gen(DDP):
    """Per-rank worker that builds the generative model and saves its images.

    All attributes read by ``main`` are set *before*
    ``super().__init__(rank, world_size)`` is called. NOTE(review): this
    ordering presumably exists because the ``DDP`` base class sets
    ``self.rank`` / ``self.world_size`` / ``self.device`` and invokes
    ``self.main()`` from its own ``__init__``; confirm against
    utils/DDP_manager before reordering.

    Args:
        rank: index of this process.
        world_size: total number of processes.
        opt: configuration dict; reads ``seed``, ``generator``, ``save_path``,
            ``dataset_setting['n-images']`` and ``gen_setting['batch_size']``.
        ds: dataset object forwarded to ``Distributed_dataset``.
    """

    def __init__(
            self,
            rank,
            world_size,
            opt,
            ds
    ):
        self.seed = opt['seed']
        self.gen_info = opt['generator']
        self.save_path = opt['save_path']
        os.makedirs(self.save_path, exist_ok=True)
        self.n_images = opt['dataset_setting']['n-images']
        self.batch_size = opt['gen_setting']['batch_size']
        self.pos_prompt = self.gen_info['pos_prompt']
        self.opt = opt
        self.ds = ds
        super(DDP_image_gen, self).__init__(rank, world_size)

    def split_batches(self, l, n_images):
        """Yield consecutive slices of ``l`` of length ``n_images`` (the last may be shorter)."""
        for i in range(0, len(l), n_images):
            yield l[i: i + n_images]

    @staticmethod
    def _caption_id_str(caption_id):
        """Convert one collated caption id (tensor element or plain value) to a folder name."""
        # DataLoader's default collate stacks numeric ids into a tensor, so
        # indexing it yields a 0-d tensor; ``.item()`` recovers the scalar.
        if isinstance(caption_id, torch.Tensor):
            return str(caption_id.item())
        return str(caption_id)

    def main(self):
        """Generate images for this rank's pending prompts and save them as JPEGs.

        The model output for each step is cut into consecutive groups of
        ``n_images``. Group ``k`` is written to
        ``<save_path>/<caption_ids[k]>/<i>.jpg``, where ``i`` is the image's
        index within the group; existing files with those names are
        overwritten.
        """
        # SECURITY: ``eval`` executes the configured string as Python in this
        # module's namespace. It is only meant to resolve a class name such as
        # ``Stable_Diffusion_XL``; never load this config from an untrusted source.
        generative_model = eval(self.gen_info['class'])(gen_info=self.gen_info, device=self.device, n_images=self.n_images)

        ds = Distributed_dataset(
            rank=self.rank,
            world_size=self.world_size,
            opt=self.opt,
            ds=self.ds
        )
        print(f'Rank {self.rank} has {len(ds)} samples to generate')
        # NOTE(review): ``self.batch_size`` (opt['gen_setting']['batch_size'])
        # is never used: one prompt is generated per step. The grouping below
        # already handles several prompts per step if this is ever raised.
        loader = DataLoader(ds, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)

        for prompts, caption_ids in tqdm(loader, position=self.rank, desc=f'Rank {self.rank}'):
            torch.cuda.empty_cache()
            prompts = [p + ' ' + self.pos_prompt for p in prompts]
            gen_images = generative_model.generate_images(prompt=prompts)
            for batch_idx, images in enumerate(self.split_batches(gen_images, self.n_images)):
                caption_id = self._caption_id_str(caption_ids[batch_idx])
                save_dir = os.path.join(self.save_path, caption_id)
                os.makedirs(save_dir, exist_ok=True)
                for image_idx, image in enumerate(images):
                    image_path = os.path.join(save_dir, f'{image_idx}.jpg')
                    image.save(image_path)
                    # Sanity check that the file actually landed on disk.
                    if not os.path.isfile(image_path):
                        print(f'ERROR: image {image_idx} of caption {caption_id} not saved')
def run(
        rank,
        world_size,
        opt,
        ds
):
    """Per-process entry point handed to ``mp.spawn``.

    Seeds torch's RNG from ``opt['seed']``, then constructs a
    ``DDP_image_gen`` for this rank. The constructed object is discarded.
    NOTE(review): generation is presumably triggered by the ``DDP`` base
    class constructor; confirm against utils/DDP_manager.

    Args:
        rank: process index supplied by ``mp.spawn``.
        world_size: total number of processes.
        opt: configuration dict (must contain ``'seed'``).
        ds: dataset object forwarded to the generator.
    """
    seed_value = opt['seed']
    torch.manual_seed(seed_value)
    DDP_image_gen(rank=rank, world_size=world_size, opt=opt, ds=ds)
if __name__ == "__main__":
    # One worker process is spawned per visible CUDA device.
    world_size = torch.cuda.device_count()
    print(f'Using {world_size} GPUs')
    # With nprocs=0, mp.spawn starts no workers and returns immediately, so
    # without this guard a machine with no visible GPU would exit
    # "successfully" having generated nothing.
    if world_size == 0:
        raise SystemExit('ERROR: no CUDA devices found; image generation needs at least one GPU')

    # Parse arguments
    opt = arg_parse.argparse_generate_images()
    mp.set_start_method('spawn')

    # Load dataset
    ds = Proposed_biases(
        dataset_path=opt['dataset_setting']['proposed_biases_path'],
        max_prompts=opt['gen_setting']['max_prompts_per_bias'],
        filter_threshold=opt['gen_setting']['filter_threshold'],
        hard_threshold=opt['gen_setting']['hard_threshold'],
        merge_threshold=opt['gen_setting']['merge_threshold'],
        valid_bias_fn=opt['dataset_setting']['valid_bias_fn'],
        filter_caption_fn=opt['dataset_setting']['filter_caption_fn'],
        all_images=opt['dataset_setting']['all_images']
    )

    # Start DDP: each spawned process calls run(rank, world_size, opt, ds).
    mp.spawn(
        run,
        args=(world_size, opt, ds, ),
        nprocs=world_size
    )