-
Notifications
You must be signed in to change notification settings - Fork 7
/
generate_distortions.py
49 lines (41 loc) · 1.6 KB
/
generate_distortions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import h5py
import hydra
import copy
import numpy as np
from PIL import Image
import utils.io as io
from grit_paths import GritPaths
def _undistorted_image_id(dist_image_id):
    """Map a distorted image id to the id of its undistorted source image.

    Drops the first two '/'-separated components of ``dist_image_id`` and
    rejoins the remainder.
    NOTE(review): presumably those two components are a distortion-specific
    directory prefix — confirm against how the distorted ids are built.
    """
    return '/'.join(dist_image_id.split('/')[2:])


def _apply_delta(image, delta):
    """Add a stored per-pixel ``delta`` to ``image`` and return a uint8 PIL image.

    The sum is clipped to [0, 255] before the uint8 cast. Without the clip, a
    delta stored in a wider or signed dtype would wrap on conversion
    (e.g. -1 -> 255, 256 -> 0). If ``delta`` is itself uint8, NumPy already
    wraps the addition modulo 256, the clip is a no-op, and that encoding is
    preserved. Float sums are still truncated toward zero by ``astype``, as
    before.
    """
    summed = np.asarray(image) + delta
    return Image.fromarray(np.clip(summed, 0, 255).astype(np.uint8))


def _distort_subset_task(grit_paths, images_dir, task, subset):
    """Write every distorted image listed for one (task, subset) pair.

    Args:
        grit_paths: ``GritPaths`` used to locate the samples JSON and the
            HDF5 file of distortion deltas for ``task``/``subset``.
        images_dir: root directory for both source and distorted images.
        task: task name passed to ``grit_paths``.
        subset: subset name passed to ``grit_paths``.
    """
    samples = io.load_json_object(grit_paths.samples(task, subset))
    # Only samples whose image id contains 'distorted' are regenerated here.
    samples = [s for s in samples if 'distorted' in s['image_id']]
    # The context manager closes the HDF5 file even if an image fails to load
    # or save; previously every deltas file stayed open for the whole run.
    with h5py.File(grit_paths.dist_deltas(task, subset), 'r') as deltas:
        for sample in samples:
            # NOTE(review): the dataset key is the example id minus its last
            # 5 characters — presumably a fixed-length suffix; confirm.
            delta = deltas[sample['example_id'][:-5]][()]
            dist_image_id = sample['image_id']
            src_path = os.path.join(
                images_dir, _undistorted_image_id(dist_image_id))
            with Image.open(src_path) as src:
                img = src.convert('RGB')
            dist_img = _apply_delta(img, delta)
            dist_img_path = os.path.join(images_dir, dist_image_id)
            io.mkdir_if_not_exists(
                os.path.dirname(dist_img_path),
                recursive=True)
            dist_img.save(dist_img_path)


@hydra.main(config_path='configs',config_name='default')
def main(cfg):
    """Regenerate distorted GRIT images from stored per-pixel deltas.

    Reads ``cfg.subsets_to_distort`` (default ``['ablation', 'test']`` when
    None), ``cfg.tasks_to_distort``, ``cfg.grit.base`` (passed to
    ``GritPaths``) and ``cfg.grit.images`` (image root). For each subset and
    task, it writes each distorted image as the undistorted source image plus
    its stored delta.
    """
    if cfg.subsets_to_distort is not None:
        subsets_to_distort = cfg.subsets_to_distort
    else:
        subsets_to_distort = ['ablation','test']
    # GritPaths depends only on cfg.grit.base, so build it once instead of
    # once per subset.
    grit_paths = GritPaths(cfg.grit.base)
    for subset in subsets_to_distort:
        print(subset)
        for task in cfg.tasks_to_distort:
            print(f'- {task}')
            _distort_subset_task(grit_paths, cfg.grit.images, task, subset)


if __name__=='__main__':
    main()