forked from advimman/lama
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request advimman#112 from geomagical/refinement
Feature Refinement to Improve High Resolution Image Inpainting
- Loading branch information
Showing
124 changed files
with
14,580 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import logging | ||
|
||
import torch | ||
|
||
from saicinpainting.evaluation.evaluator import InpaintingEvaluatorOnline, ssim_fid100_f1, lpips_fid100_f1 | ||
from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore | ||
|
||
|
||
def make_evaluator(kind='default', ssim=True, lpips=True, fid=True, integral_kind=None, **kwargs): | ||
logging.info(f'Make evaluator {kind}') | ||
device = "cuda" if torch.cuda.is_available() else "cpu" | ||
metrics = {} | ||
if ssim: | ||
metrics['ssim'] = SSIMScore() | ||
if lpips: | ||
metrics['lpips'] = LPIPSScore() | ||
if fid: | ||
metrics['fid'] = FIDScore().to(device) | ||
|
||
if integral_kind is None: | ||
integral_func = None | ||
elif integral_kind == 'ssim_fid100_f1': | ||
integral_func = ssim_fid100_f1 | ||
elif integral_kind == 'lpips_fid100_f1': | ||
integral_func = lpips_fid100_f1 | ||
else: | ||
raise ValueError(f'Unexpected integral_kind={integral_kind}') | ||
|
||
if kind == 'default': | ||
return InpaintingEvaluatorOnline(scores=metrics, | ||
integral_func=integral_func, | ||
integral_title=integral_kind, | ||
**kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
import glob | ||
import os | ||
|
||
import cv2 | ||
import PIL.Image as Image | ||
import numpy as np | ||
|
||
from torch.utils.data import Dataset | ||
import torch.nn.functional as F | ||
|
||
|
||
def load_image(fname, mode='RGB', return_orig=False): | ||
img = np.array(Image.open(fname).convert(mode)) | ||
if img.ndim == 3: | ||
img = np.transpose(img, (2, 0, 1)) | ||
out_img = img.astype('float32') / 255 | ||
if return_orig: | ||
return out_img, img | ||
else: | ||
return out_img | ||
|
||
|
||
def ceil_modulo(x, mod): | ||
if x % mod == 0: | ||
return x | ||
return (x // mod + 1) * mod | ||
|
||
|
||
def pad_img_to_modulo(img, mod): | ||
channels, height, width = img.shape | ||
out_height = ceil_modulo(height, mod) | ||
out_width = ceil_modulo(width, mod) | ||
return np.pad(img, ((0, 0), (0, out_height - height), (0, out_width - width)), mode='symmetric') | ||
|
||
|
||
def pad_tensor_to_modulo(img, mod): | ||
batch_size, channels, height, width = img.shape | ||
out_height = ceil_modulo(height, mod) | ||
out_width = ceil_modulo(width, mod) | ||
return F.pad(img, pad=(0, out_width - width, 0, out_height - height), mode='reflect') | ||
|
||
|
||
def scale_image(img, factor, interpolation=cv2.INTER_AREA): | ||
if img.shape[0] == 1: | ||
img = img[0] | ||
else: | ||
img = np.transpose(img, (1, 2, 0)) | ||
|
||
img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation) | ||
|
||
if img.ndim == 2: | ||
img = img[None, ...] | ||
else: | ||
img = np.transpose(img, (2, 0, 1)) | ||
return img | ||
|
||
|
||
class InpaintingDataset(Dataset): | ||
def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None): | ||
self.datadir = datadir | ||
self.mask_filenames = sorted(list(glob.glob(os.path.join(self.datadir, '**', '*mask*.png'), recursive=True))) | ||
self.img_filenames = [fname.rsplit('_mask', 1)[0] + img_suffix for fname in self.mask_filenames] | ||
self.pad_out_to_modulo = pad_out_to_modulo | ||
self.scale_factor = scale_factor | ||
|
||
def __len__(self): | ||
return len(self.mask_filenames) | ||
|
||
def __getitem__(self, i): | ||
image = load_image(self.img_filenames[i], mode='RGB') | ||
mask = load_image(self.mask_filenames[i], mode='L') | ||
result = dict(image=image, mask=mask[None, ...]) | ||
|
||
if self.scale_factor is not None: | ||
result['image'] = scale_image(result['image'], self.scale_factor) | ||
result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST) | ||
|
||
if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: | ||
result['unpad_to_size'] = result['image'].shape[1:] | ||
result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo) | ||
result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo) | ||
|
||
return result | ||
|
||
class OurInpaintingDataset(Dataset): | ||
def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None): | ||
self.datadir = datadir | ||
self.mask_filenames = sorted(list(glob.glob(os.path.join(self.datadir, 'mask', '**', '*mask*.png'), recursive=True))) | ||
self.img_filenames = [os.path.join(self.datadir, 'img', os.path.basename(fname.rsplit('-', 1)[0].rsplit('_', 1)[0]) + '.png') for fname in self.mask_filenames] | ||
self.pad_out_to_modulo = pad_out_to_modulo | ||
self.scale_factor = scale_factor | ||
|
||
def __len__(self): | ||
return len(self.mask_filenames) | ||
|
||
def __getitem__(self, i): | ||
result = dict(image=load_image(self.img_filenames[i], mode='RGB'), | ||
mask=load_image(self.mask_filenames[i], mode='L')[None, ...]) | ||
|
||
if self.scale_factor is not None: | ||
result['image'] = scale_image(result['image'], self.scale_factor) | ||
result['mask'] = scale_image(result['mask'], self.scale_factor) | ||
|
||
if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: | ||
result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo) | ||
result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo) | ||
|
||
return result | ||
|
||
class PrecomputedInpaintingResultsDataset(InpaintingDataset): | ||
def __init__(self, datadir, predictdir, inpainted_suffix='_inpainted.jpg', **kwargs): | ||
super().__init__(datadir, **kwargs) | ||
if not datadir.endswith('/'): | ||
datadir += '/' | ||
self.predictdir = predictdir | ||
self.pred_filenames = [os.path.join(predictdir, os.path.splitext(fname[len(datadir):])[0] + inpainted_suffix) | ||
for fname in self.mask_filenames] | ||
|
||
def __getitem__(self, i): | ||
result = super().__getitem__(i) | ||
result['inpainted'] = load_image(self.pred_filenames[i]) | ||
if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: | ||
result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo) | ||
return result | ||
|
||
class OurPrecomputedInpaintingResultsDataset(OurInpaintingDataset): | ||
def __init__(self, datadir, predictdir, inpainted_suffix="png", **kwargs): | ||
super().__init__(datadir, **kwargs) | ||
if not datadir.endswith('/'): | ||
datadir += '/' | ||
self.predictdir = predictdir | ||
self.pred_filenames = [os.path.join(predictdir, os.path.basename(os.path.splitext(fname)[0]) + f'_inpainted.{inpainted_suffix}') | ||
for fname in self.mask_filenames] | ||
# self.pred_filenames = [os.path.join(predictdir, os.path.splitext(fname[len(datadir):])[0] + inpainted_suffix) | ||
# for fname in self.mask_filenames] | ||
|
||
def __getitem__(self, i): | ||
result = super().__getitem__(i) | ||
result['inpainted'] = self.file_loader(self.pred_filenames[i]) | ||
|
||
if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: | ||
result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo) | ||
return result | ||
|
||
class InpaintingEvalOnlineDataset(Dataset): | ||
def __init__(self, indir, mask_generator, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None, **kwargs): | ||
self.indir = indir | ||
self.mask_generator = mask_generator | ||
self.img_filenames = sorted(list(glob.glob(os.path.join(self.indir, '**', f'*{img_suffix}' ), recursive=True))) | ||
self.pad_out_to_modulo = pad_out_to_modulo | ||
self.scale_factor = scale_factor | ||
|
||
def __len__(self): | ||
return len(self.img_filenames) | ||
|
||
def __getitem__(self, i): | ||
img, raw_image = load_image(self.img_filenames[i], mode='RGB', return_orig=True) | ||
mask = self.mask_generator(img, raw_image=raw_image) | ||
result = dict(image=img, mask=mask) | ||
|
||
if self.scale_factor is not None: | ||
result['image'] = scale_image(result['image'], self.scale_factor) | ||
result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST) | ||
|
||
if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: | ||
result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo) | ||
result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo) | ||
return result |
Oops, something went wrong.