-
Notifications
You must be signed in to change notification settings - Fork 14
/
grid_distortion.py
75 lines (57 loc) · 2.35 KB
/
grid_distortion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# This code is adapted from Curtis Wigington's GitHub repository:
# https://github.com/cwig/simple_hwr.git
import cv2
import numpy as np
from scipy.interpolate import griddata
import sys
# Supported values for warp_image's ``interpolation`` option, mapped to the
# OpenCV interpolation flag handed to ``cv2.remap``. The same name is also
# passed to ``scipy.interpolate.griddata`` as its ``method``.
INTERPOLATION = dict(
    linear=cv2.INTER_LINEAR,
    cubic=cv2.INTER_CUBIC,
)
def warp_image(img, random_state=None, **kwargs):
    """Apply a random elastic "grid distortion" to an image.

    A regular mesh of control points is laid over the image and every point
    is jittered with Gaussian noise. The image is then resampled so that the
    content at each original control point moves to its jittered position,
    interpolating smoothly in between.

    Args:
        img: Image array as used by OpenCV, ``(h, w)`` or ``(h, w, channels)``.
            It is never modified in place.
        random_state: ``np.random.RandomState`` used for the jitter. A fresh,
            unseeded one is created when ``None``.
        **kwargs: Optional settings:
            w_mesh_interval / h_mesh_interval (default 25): spacing in pixels
                between control points along the width / height.
            w_mesh_std / h_mesh_std (default 3.0): standard deviation in
                pixels of the jitter along the width / height.
            interpolation (default "linear"): key of ``INTERPOLATION``. Used
                both as the ``griddata`` method and to select the
                ``cv2.remap`` interpolation flag.
            fit_interval_to_image (default True): round the intervals so a
                whole number of mesh cells spans the image exactly.
            draw_grid_lines (default False): draw the undistorted mesh onto a
                copy of the image before warping (for visual debugging).

    Returns:
        The warped image, with the same height and width as ``img``. Output
        pixels whose source location falls outside the image or outside the
        jittered mesh are filled with white.

    Raises:
        KeyError: if ``interpolation`` is not a key of ``INTERPOLATION``.
    """
    if random_state is None:
        random_state = np.random.RandomState()

    w_mesh_interval = kwargs.get('w_mesh_interval', 25)
    w_mesh_std = kwargs.get('w_mesh_std', 3.0)
    h_mesh_interval = kwargs.get('h_mesh_interval', 25)
    h_mesh_std = kwargs.get('h_mesh_std', 3.0)
    interpolation_method = kwargs.get('interpolation', 'linear')
    # Resolve the OpenCV flag up front so an unsupported method fails fast,
    # before the comparatively expensive griddata call.
    remap_flag = INTERPOLATION[interpolation_method]

    h, w = img.shape[:2]

    if kwargs.get("fit_interval_to_image", True):
        # Change the intervals so a whole number of cells spans the image.
        w_ratio = max(1, round(w / float(w_mesh_interval)))
        h_ratio = max(1, round(h / float(h_mesh_interval)))
        w_mesh_interval = w / float(w_ratio)
        h_mesh_interval = h / float(h_ratio)

    # Control points as (row, col) pairs. np.mgrid yields an integer array
    # for integer intervals, so cast to float; otherwise the jitter added
    # below would be silently truncated toward zero.
    source = np.mgrid[0:h + h_mesh_interval:h_mesh_interval,
                      0:w + w_mesh_interval:w_mesh_interval]
    source = source.transpose(1, 2, 0).reshape(-1, 2).astype(np.float64)

    if kwargs.get("draw_grid_lines", False):
        # Draw on a copy so the caller's image is left untouched.
        img = img.copy()
        n_channels = 1 if img.ndim == 2 else img.shape[2]
        if n_channels < 3:
            color = 0
        else:
            # Red in OpenCV's BGR channel order; any extra channels
            # (e.g. alpha) are set to 255.
            color = np.array([0, 0, 255] + [255] * (n_channels - 3))
        # Each distinct row / column only needs to be drawn once. Slicing
        # (rather than plain indexing) keeps points on the far edge
        # (row == h or col == w) harmless.
        for row in np.unique(source[:, 0].astype(int)):
            img[row:row + 1, :] = color
        for col in np.unique(source[:, 1].astype(int)):
            img[:, col:col + 1] = color

    # Perturb every control point independently: rows are jittered first,
    # then columns.
    n_points = source.shape[0]
    destination = source.copy()
    destination[:, 0] += random_state.normal(0.0, h_mesh_std, size=n_points)
    destination[:, 1] += random_state.normal(0.0, w_mesh_std, size=n_points)

    # For every output pixel, interpolate the input coordinate to sample from.
    # Pixels outside the convex hull of the jittered points would be NaN by
    # default, and cv2.remap does not specify how NaN coordinates are
    # treated. Map them well outside the image instead, so they reliably take
    # the white border value (the offset clears even the 4x4 bicubic kernel).
    outside = -10.0
    grid_rows, grid_cols = np.mgrid[0:h, 0:w]
    grid_z = griddata(destination, source, (grid_rows, grid_cols),
                      method=interpolation_method,
                      fill_value=outside).astype(np.float32)
    map_x = grid_z[:, :, 1]
    map_y = grid_z[:, :, 0]
    return cv2.remap(img, map_x, map_y, remap_flag,
                     borderValue=(255, 255, 255))
if __name__ == "__main__":
    # Demo: warp INPUT_IMAGE with the mesh drawn on it (so the distortion is
    # visible) and save the result to OUTPUT_IMAGE.
    if len(sys.argv) < 3:
        sys.exit("usage: {} INPUT_IMAGE OUTPUT_IMAGE".format(sys.argv[0]))
    input_image = sys.argv[1]
    output_image = sys.argv[2]
    img = cv2.imread(input_image)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        sys.exit("error: could not read image {!r}".format(input_image))
    img = warp_image(img, draw_grid_lines=True)
    if not cv2.imwrite(output_image, img):
        sys.exit("error: could not write image {!r}".format(output_image))