-
Notifications
You must be signed in to change notification settings - Fork 2
/
deep_interpolate.py
199 lines (173 loc) · 8.14 KB
/
deep_interpolate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
"""Frame Interpolation Core Code"""
import os
import argparse
from typing import Callable
import cv2
from tqdm import tqdm
from interpolate_engine import InterpolateEngine
from interpolate import Interpolate
from webui_utils.simple_log import SimpleLog
from webui_utils.simple_utils import max_steps, sortable_float_index
from webui_utils.file_utils import create_directory
def main():
    """Command-line entry point: interpolate frames between two images.

    Parses the CLI options, builds an InterpolateEngine from ``--model``,
    and writes the recursively interpolated frames as PNGs into
    ``--output_path``.
    """
    parser = argparse.ArgumentParser(description="Video Frame Interpolation (deep)")

    # (flag, add_argument keyword arguments), in the order shown by --help
    cli_options = (
        ("--model", {"default": "./pretrained_models/pretrained_VFIformer/net_220.pth",
                     "type": str}),
        ("--gpu_ids", {"type": str, "default": "0",
                       "help": "gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU"}),
        ("--img_before", {"default": "./images/image0.png", "type": str,
                          "help": "Path to before frame image"}),
        ("--img_after", {"default": "./images/image2.png", "type": str,
                         "help": "Path to after frame image"}),
        ("--depth", {"default": 2, "type": int,
                     "help": "how many doublings of the frames"}),
        ("--output_path", {"default": "./output", "type": str,
                           "help": "Output path for interpolated PNGs"}),
        ("--base_filename", {"default": "interpolated_frame", "type": str,
                             "help": "Base filename for interpolated PNGs"}),
        ("--verbose", {"dest": "verbose", "default": False, "action": "store_true",
                       "help": "Show extra details"}),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    opts = parser.parse_args()

    logger = SimpleLog(opts.verbose)
    create_directory(opts.output_path)
    model_engine = InterpolateEngine(opts.model, opts.gpu_ids)
    frame_interpolater = Interpolate(model_engine.model, logger.log)
    deep = DeepInterpolate(frame_interpolater, logger.log)
    deep.split_frames(opts.img_before, opts.img_after, opts.depth,
                      opts.output_path, opts.base_filename)
class DeepInterpolate():
"""Encapsulates logic for the Frame Interpolation feature"""
def __init__(self,
interpolater : Interpolate,
log_fn : Callable | None):
self.interpolater = interpolater
self.log_fn = log_fn
self.split_count = 0
self.frame_register = []
self.progress = None
self.output_paths = []
def split_frames(self,
before_filepath,
after_filepath,
num_splits,
output_path,
base_filename,
progress_label="Frame",
continued=False,
resynthesis=False):
"""Invoke the Frame Interpolation feature"""
self.init_frame_register()
self.reset_split_manager(num_splits)
num_steps = max_steps(num_splits)
self.init_progress(num_splits, num_steps, progress_label)
output_filepath_prefix = os.path.join(output_path, base_filename)
self._set_up_outer_frames(before_filepath, after_filepath, output_filepath_prefix)
self._recursive_split_frames(0.0, 1.0, output_filepath_prefix)
self._integerize_filenames(output_path, base_filename, continued, resynthesis)
self.close_progress()
def _set_up_outer_frames(self,
before_file,
after_file,
output_filepath_prefix):
"""Start with the original frames at 0.0 and 1.0"""
img0 = cv2.imread(before_file)
img1 = cv2.imread(after_file)
# create outer 0.0 and 1.0 versions of original frames
before_index, after_index = 0.0, 1.0
before_file = self.indexed_filepath(output_filepath_prefix, before_index)
after_file = self.indexed_filepath(output_filepath_prefix, after_index)
cv2.imwrite(before_file, img0)
self.register_frame(before_file)
self.log("copied " + before_file)
cv2.imwrite(after_file, img1)
self.register_frame(after_file)
self.log("copied " + after_file)
def _recursive_split_frames(self,
first_index : float,
last_index : float,
filepath_prefix : str):
"""Create a new frame between the given frames, and re-enter to split deeper"""
if self.enter_split():
mid_index = first_index + (last_index - first_index) / 2.0
first_filepath = self.indexed_filepath(filepath_prefix, first_index)
last_filepath = self.indexed_filepath(filepath_prefix, last_index)
mid_filepath = self.indexed_filepath(filepath_prefix, mid_index)
self.interpolater.create_between_frame(first_filepath, last_filepath, mid_filepath)
self.register_frame(mid_filepath)
self.step_progress()
# deal with two new split regions
self._recursive_split_frames(first_index, mid_index, filepath_prefix)
self._recursive_split_frames(mid_index, last_index, filepath_prefix)
self.exit_split()
def _integerize_filenames(self, output_path, base_name, continued, resynthesis):
"""Keep the interpolated frame files with an index number for sorting"""
file_prefix = os.path.join(output_path, base_name)
frame_files = self.sorted_registered_frames()
num_files = len(frame_files)
num_width = len(str(num_files))
index = 0
self.output_paths = []
for file in frame_files:
if resynthesis and (index == 0 or index == num_files - 1):
# if a resynthesis process, keep only the interpolated frames
os.remove(file)
self.log("resynthesis - removed uneeded " + file)
elif continued and index == 0:
# if a continuation from a previous set of frames, delete the first frame
# to maintain continuity since it's duplicate of the previous round last frame
os.remove(file)
self.log("continuation - removed uneeded " + file)
else:
new_filename = file_prefix + str(index).zfill(num_width) + ".png"
os.replace(file, new_filename)
self.output_paths.append(new_filename)
self.log("renamed " + file + " to " + new_filename)
index += 1
def reset_split_manager(self, num_splits : int):
"""Start managing split depths of a new round of searches"""
self.split_count = num_splits
def enter_split(self):
"""Enter a split depth if allowed, returns True if so"""
if self.split_count < 1:
return False
self.split_count -= 1
return True
def exit_split(self):
"""Exit the current split depth"""
self.split_count += 1
def init_frame_register(self):
"""Start managing interpolated frame files for a new round of searches"""
self.frame_register = []
def register_frame(self, filepath : str):
"""Register a found frame file"""
self.frame_register.append(filepath)
def sorted_registered_frames(self):
"""Return a sorted list of the currently registered found frame files"""
return sorted(self.frame_register)
def init_progress(self, num_splits, _max, description):
"""Start managing progress bar for a new found of searches"""
if num_splits < 2:
self.progress = None
else:
self.progress = tqdm(range(_max), desc=description)
def step_progress(self):
"""Advance the progress bar"""
if self.progress:
self.progress.update()
self.progress.refresh()
def close_progress(self):
"""Done with the progress bar"""
if self.progress:
self.progress.close()
# filepath prefix representing the split position while splitting
def indexed_filepath(self, filepath_prefix, index):
"""Filepath prefix representing the split position while splitting"""
float_index = sortable_float_index(index, fixed_width=True)
return filepath_prefix + f"{float_index}.png"
def log(self, message):
"""Logging"""
if self.log_fn:
self.log_fn(message)
if __name__ == "__main__":
    # Run the command-line tool only when executed as a script, not on import.
    main()