Skip to content

Commit

Permalink
Merge pull request advimman#112 from geomagical/refinement
Browse files Browse the repository at this point in the history
Feature Refinement to Improve High Resolution Image Inpainting
  • Loading branch information
senya-ashukha authored and HeunSeungLim committed Nov 14, 2022
2 parents a7cf3ce + 63d7816 commit 58acfcd
Show file tree
Hide file tree
Showing 124 changed files with 14,580 additions and 24 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ bash docker/2_predict.sh $(pwd)/big-lama $(pwd)/LaMa_test_images $(pwd)/output d
```
Docker cuda: TODO

**4. Predict with Refinement**

On the host machine:

python3 bin/predict.py refine=True model.path=$(pwd)/big-lama indir=$(pwd)/LaMa_test_images outdir=$(pwd)/output

# Train and Eval

⚠️ Warning: The training is not fully tested yet, e.g., did not re-training after refactoring ⚠️
Expand Down
56 changes: 32 additions & 24 deletions bin/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import traceback

from saicinpainting.evaluation.utils import move_to_device

from saicinpainting.evaluation.refinement import refine_predict
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
Expand Down Expand Up @@ -56,34 +56,42 @@ def main(predict_config: OmegaConf):
predict_config.model.checkpoint)
model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
model.freeze()
model.to(device)
if not predict_config.get('refine', False):
model.to(device)

if not predict_config.indir.endswith('/'):
predict_config.indir += '/'

dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset)
with torch.no_grad():
for img_i in tqdm.trange(len(dataset)):
mask_fname = dataset.mask_filenames[img_i]
cur_out_fname = os.path.join(
predict_config.outdir,
os.path.splitext(mask_fname[len(predict_config.indir):])[0] + out_ext
)
os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)

batch = move_to_device(default_collate([dataset[img_i]]), device)
batch['mask'] = (batch['mask'] > 0) * 1
batch = model(batch)
cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()

unpad_to_size = batch.get('unpad_to_size', None)
if unpad_to_size is not None:
orig_height, orig_width = unpad_to_size
cur_res = cur_res[:orig_height, :orig_width]

cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
cv2.imwrite(cur_out_fname, cur_res)
for img_i in tqdm.trange(len(dataset)):
mask_fname = dataset.mask_filenames[img_i]
cur_out_fname = os.path.join(
predict_config.outdir,
os.path.splitext(mask_fname[len(predict_config.indir):])[0] + out_ext
)
os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
batch = default_collate([dataset[img_i]])
if predict_config.get('refine', False):
assert 'unpad_to_size' in batch, "Unpadded size is required for the refinement"
# image unpadding is taken care of in the refiner, so that output image
# is same size as the input image
cur_res = refine_predict(batch, model, **predict_config.refiner)
cur_res = cur_res[0].permute(1,2,0).detach().cpu().numpy()
else:
with torch.no_grad():
batch = move_to_device(batch, device)
batch['mask'] = (batch['mask'] > 0) * 1
batch = model(batch)
cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
unpad_to_size = batch.get('unpad_to_size', None)
if unpad_to_size is not None:
orig_height, orig_width = unpad_to_size
cur_res = cur_res[:orig_height, :orig_width]

cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
cv2.imwrite(cur_out_fname, cur_res)

except KeyboardInterrupt:
LOGGER.warning('Interrupted by user')
except Exception as ex:
Expand Down
Empty file added bin/saicinpainting/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions bin/saicinpainting/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import logging

import torch

from saicinpainting.evaluation.evaluator import InpaintingEvaluatorOnline, ssim_fid100_f1, lpips_fid100_f1
from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore


def make_evaluator(kind='default', ssim=True, lpips=True, fid=True, integral_kind=None, **kwargs):
logging.info(f'Make evaluator {kind}')
device = "cuda" if torch.cuda.is_available() else "cpu"
metrics = {}
if ssim:
metrics['ssim'] = SSIMScore()
if lpips:
metrics['lpips'] = LPIPSScore()
if fid:
metrics['fid'] = FIDScore().to(device)

if integral_kind is None:
integral_func = None
elif integral_kind == 'ssim_fid100_f1':
integral_func = ssim_fid100_f1
elif integral_kind == 'lpips_fid100_f1':
integral_func = lpips_fid100_f1
else:
raise ValueError(f'Unexpected integral_kind={integral_kind}')

if kind == 'default':
return InpaintingEvaluatorOnline(scores=metrics,
integral_func=integral_func,
integral_title=integral_kind,
**kwargs)
168 changes: 168 additions & 0 deletions bin/saicinpainting/evaluation/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import glob
import os

import cv2
import PIL.Image as Image
import numpy as np

from torch.utils.data import Dataset
import torch.nn.functional as F


def load_image(fname, mode='RGB', return_orig=False):
img = np.array(Image.open(fname).convert(mode))
if img.ndim == 3:
img = np.transpose(img, (2, 0, 1))
out_img = img.astype('float32') / 255
if return_orig:
return out_img, img
else:
return out_img


def ceil_modulo(x, mod):
if x % mod == 0:
return x
return (x // mod + 1) * mod


def pad_img_to_modulo(img, mod):
channels, height, width = img.shape
out_height = ceil_modulo(height, mod)
out_width = ceil_modulo(width, mod)
return np.pad(img, ((0, 0), (0, out_height - height), (0, out_width - width)), mode='symmetric')


def pad_tensor_to_modulo(img, mod):
batch_size, channels, height, width = img.shape
out_height = ceil_modulo(height, mod)
out_width = ceil_modulo(width, mod)
return F.pad(img, pad=(0, out_width - width, 0, out_height - height), mode='reflect')


def scale_image(img, factor, interpolation=cv2.INTER_AREA):
if img.shape[0] == 1:
img = img[0]
else:
img = np.transpose(img, (1, 2, 0))

img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation)

if img.ndim == 2:
img = img[None, ...]
else:
img = np.transpose(img, (2, 0, 1))
return img


class InpaintingDataset(Dataset):
def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None):
self.datadir = datadir
self.mask_filenames = sorted(list(glob.glob(os.path.join(self.datadir, '**', '*mask*.png'), recursive=True)))
self.img_filenames = [fname.rsplit('_mask', 1)[0] + img_suffix for fname in self.mask_filenames]
self.pad_out_to_modulo = pad_out_to_modulo
self.scale_factor = scale_factor

def __len__(self):
return len(self.mask_filenames)

def __getitem__(self, i):
image = load_image(self.img_filenames[i], mode='RGB')
mask = load_image(self.mask_filenames[i], mode='L')
result = dict(image=image, mask=mask[None, ...])

if self.scale_factor is not None:
result['image'] = scale_image(result['image'], self.scale_factor)
result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST)

if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
result['unpad_to_size'] = result['image'].shape[1:]
result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)

return result

class OurInpaintingDataset(Dataset):
def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None):
self.datadir = datadir
self.mask_filenames = sorted(list(glob.glob(os.path.join(self.datadir, 'mask', '**', '*mask*.png'), recursive=True)))
self.img_filenames = [os.path.join(self.datadir, 'img', os.path.basename(fname.rsplit('-', 1)[0].rsplit('_', 1)[0]) + '.png') for fname in self.mask_filenames]
self.pad_out_to_modulo = pad_out_to_modulo
self.scale_factor = scale_factor

def __len__(self):
return len(self.mask_filenames)

def __getitem__(self, i):
result = dict(image=load_image(self.img_filenames[i], mode='RGB'),
mask=load_image(self.mask_filenames[i], mode='L')[None, ...])

if self.scale_factor is not None:
result['image'] = scale_image(result['image'], self.scale_factor)
result['mask'] = scale_image(result['mask'], self.scale_factor)

if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)

return result

class PrecomputedInpaintingResultsDataset(InpaintingDataset):
def __init__(self, datadir, predictdir, inpainted_suffix='_inpainted.jpg', **kwargs):
super().__init__(datadir, **kwargs)
if not datadir.endswith('/'):
datadir += '/'
self.predictdir = predictdir
self.pred_filenames = [os.path.join(predictdir, os.path.splitext(fname[len(datadir):])[0] + inpainted_suffix)
for fname in self.mask_filenames]

def __getitem__(self, i):
result = super().__getitem__(i)
result['inpainted'] = load_image(self.pred_filenames[i])
if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo)
return result

class OurPrecomputedInpaintingResultsDataset(OurInpaintingDataset):
def __init__(self, datadir, predictdir, inpainted_suffix="png", **kwargs):
super().__init__(datadir, **kwargs)
if not datadir.endswith('/'):
datadir += '/'
self.predictdir = predictdir
self.pred_filenames = [os.path.join(predictdir, os.path.basename(os.path.splitext(fname)[0]) + f'_inpainted.{inpainted_suffix}')
for fname in self.mask_filenames]
# self.pred_filenames = [os.path.join(predictdir, os.path.splitext(fname[len(datadir):])[0] + inpainted_suffix)
# for fname in self.mask_filenames]

def __getitem__(self, i):
result = super().__getitem__(i)
result['inpainted'] = self.file_loader(self.pred_filenames[i])

if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo)
return result

class InpaintingEvalOnlineDataset(Dataset):
def __init__(self, indir, mask_generator, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None, **kwargs):
self.indir = indir
self.mask_generator = mask_generator
self.img_filenames = sorted(list(glob.glob(os.path.join(self.indir, '**', f'*{img_suffix}' ), recursive=True)))
self.pad_out_to_modulo = pad_out_to_modulo
self.scale_factor = scale_factor

def __len__(self):
return len(self.img_filenames)

def __getitem__(self, i):
img, raw_image = load_image(self.img_filenames[i], mode='RGB', return_orig=True)
mask = self.mask_generator(img, raw_image=raw_image)
result = dict(image=img, mask=mask)

if self.scale_factor is not None:
result['image'] = scale_image(result['image'], self.scale_factor)
result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST)

if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1:
result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo)
result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo)
return result
Loading

0 comments on commit 58acfcd

Please sign in to comment.