Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature Refinement to Improve High Resolution Image Inpainting #112

Merged
merged 6 commits into from
Jul 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ bash docker/2_predict.sh $(pwd)/big-lama $(pwd)/LaMa_test_images $(pwd)/output d
```
Docker cuda: TODO

**4. Predict with Refinement**

On the host machine:

python3 bin/predict.py refine=True model.path=$(pwd)/big-lama indir=$(pwd)/LaMa_test_images outdir=$(pwd)/output

# Train and Eval

⚠️ Warning: The training is not fully tested yet, e.g., did not re-training after refactoring ⚠️
Expand Down
56 changes: 32 additions & 24 deletions bin/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import traceback

from saicinpainting.evaluation.utils import move_to_device

from saicinpainting.evaluation.refinement import refine_predict
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
Expand Down Expand Up @@ -56,34 +56,42 @@ def main(predict_config: OmegaConf):
predict_config.model.checkpoint)
model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu')
model.freeze()
model.to(device)
if not predict_config.get('refine', False):
model.to(device)

if not predict_config.indir.endswith('/'):
predict_config.indir += '/'

dataset = make_default_val_dataset(predict_config.indir, **predict_config.dataset)
with torch.no_grad():
for img_i in tqdm.trange(len(dataset)):
mask_fname = dataset.mask_filenames[img_i]
cur_out_fname = os.path.join(
predict_config.outdir,
os.path.splitext(mask_fname[len(predict_config.indir):])[0] + out_ext
)
os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)

batch = move_to_device(default_collate([dataset[img_i]]), device)
batch['mask'] = (batch['mask'] > 0) * 1
batch = model(batch)
cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()

unpad_to_size = batch.get('unpad_to_size', None)
if unpad_to_size is not None:
orig_height, orig_width = unpad_to_size
cur_res = cur_res[:orig_height, :orig_width]

cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
cv2.imwrite(cur_out_fname, cur_res)
for img_i in tqdm.trange(len(dataset)):
mask_fname = dataset.mask_filenames[img_i]
cur_out_fname = os.path.join(
predict_config.outdir,
os.path.splitext(mask_fname[len(predict_config.indir):])[0] + out_ext
)
os.makedirs(os.path.dirname(cur_out_fname), exist_ok=True)
batch = default_collate([dataset[img_i]])
if predict_config.get('refine', False):
assert 'unpad_to_size' in batch, "Unpadded size is required for the refinement"
# image unpadding is taken care of in the refiner, so that output image
# is same size as the input image
cur_res = refine_predict(batch, model, **predict_config.refiner)
cur_res = cur_res[0].permute(1,2,0).detach().cpu().numpy()
else:
with torch.no_grad():
batch = move_to_device(batch, device)
batch['mask'] = (batch['mask'] > 0) * 1
batch = model(batch)
cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
unpad_to_size = batch.get('unpad_to_size', None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L84-87 should be outside if-else - they need to be executed regardless refinement

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L84-87 already get addressed inside the refiner. Refiner works on unpadded images (it does the necesssary padding internally and then unpads the output appropriately). We can:

  • add an assertion to check unpad_to_size is not None
  • enable refiner to just run on padded image, if unpad_to_size is None.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I see. I'd move padding-unpadding from the refiner to predict.py - so both parts of the code are simplified and no logic duplication is introduced. What do you think, is it possible and does it make sense?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can let the refiner get padded input. But refiner still needs some padding in place. Because -

Suppose your input image is a square of size 1000. Then the original image isn't padded because 1000%8==0, but in the refiner, once we downscale the image, it's size becomes 500, and 500%8!=0. So we have to pad it to make it 504x504.

So we can't get rid of lines 301 and 302 in refinement.py, but we can:

  • let the padded image to be input to refiner, so that we take L84-87 outside the if-else.
  • refiner then doesn't check unpad_to_size argument at all.
  • Padding would happen in the refiner to ensure downscaled image size is divisible by 8.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refiner still needs some padding in place

I see, thank you for the clarification! Let's just leave that piece of the as is - and add a comment about "padding-unpadding is handled within refiner"

image size is divisible by 8.

Padding size depends on depth of the generator and thus needs to be configurable

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha, will add the comment. Yeah the padding size of the refiner is not exactly 8, but exactly equal to dataset.pad_out_to_modulo in the predict config. I'll add a comment there in the PR

if unpad_to_size is not None:
orig_height, orig_width = unpad_to_size
cur_res = cur_res[:orig_height, :orig_width]

cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')
cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
cv2.imwrite(cur_out_fname, cur_res)

except KeyboardInterrupt:
LOGGER.warning('Interrupted by user')
except Exception as ex:
Expand Down
10 changes: 10 additions & 0 deletions configs/prediction/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,13 @@ dataset:

device: cuda
out_key: inpainted

refine: False # refiner will only run if this is True
refiner:
gpu_ids: 0,1 # the GPU ids of the machine to use. If only single GPU, use: "0,"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest using only 0 by default - or even introduce "None" default (so refiner would rely on the parent device setting). That would make this work by default in more environments without any modifications by default.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually the refiner needs around 24GB GPU to process 1.8 megapixel images (~1200x1500). Since most people have two 12GB GPUs instead, we decided to split the model onto two GPUs, that's why the default config setting.
Do you suggest to still make it None by default?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, right, I have not thought about memory consumption. It seems that the most of the consumption comes from storing activations for backward... And you're splitting res-blocks between GPUs to distribute that memory - not to speedup inference - because GPUs are called sequentially.

I have a couple ideas how to overcome that without complex logic or requirement to have two GPUs:

  • Set param.requires_grad_(False) for all parameters in the generator - that will lead to storing only activations, not gradients for parameters.
  • Use activation checkpointing - it does something very similar to what you're doing - it splits a nn.Sequential in multiple chunks and runs each chunk with torch.no_grad - so only activations between chunks have to be stored. That will slow the optimization down, but maybe not severely.
  • torch.cuda.amp - optimize in fp16 instead of fp32. In case of refinement there is no adversarial training, so there should not be stability issues due to reduced precision (but I'm not sure)

Copy link
Contributor Author

@ankuPRK ankuPRK May 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the ideas! param.requires_grad_ is already set to False, since we freeze the model here: https://github.com/geomagical/lama-with-refiner/blob/24a20f804390c6ab969c28abbe999c940f8d6a56/bin/predict.py#L58
I also manually verified the requires_grad for all the params of the model, they were False.

We were already looking at activation checkpointing, will focus on it more now that you have also mentioned it. Will try the third idea also.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.cuda.amp isn't working because pytorch doesn't seem to support Half dtype for torch.fft.rfftn. PFA link to the relevant issues in Pytorch repo:

pytorch/pytorch#70664
pytorch/pytorch#71680

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also manually verified the requires_grad for all the params of the model

Great, thank you!

torch.cuda.amp isn't working because pytorch doesn't seem to support Half

Sure, I've forgot that I've already tried half and failed because of that... We could wrap rfftn/irfftn with conversion to and from .float(), but I'm not sure there wouldn't be other issues..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @windj007, sorry for coming back after 2 months! We picked up the experiments, our findings:

  1. We were able to perform the optimization in mixed precision. I haven't benchmarked it quantitatively, but qualitative results look good. However, for 1024x1024 images, it only reduces the memory from 21-22GB -> 17-18GB, so it is still not sufficient to fit on a single 12GB GPU
  2. We also tried to play with checkpointing. Performing it naively throws RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn, which we bypass by setting use_reentrant=False. However, this setting has some memory leak problem, which causes the GPU consumption to increase at each training loop, eventually leading to OOM error. We plan to raise this issue on the Pytorch repo.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please share your code for mixed-precision?

Copy link
Contributor Author

@ankuPRK ankuPRK Aug 2, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, it's in amp_float16 branch of our fork:

https://github.com/geomagical/lama-with-refiner/tree/amp_float16

You can get this code by:

git clone [email protected]:geomagical/lama-with-refiner.git
git checkout amp_float16

Also, I've changed the config file of the refiner to run on a single GPU. But yeah feel free to play around with config parameters or anything :)

Link to the config file in the code: https://github.com/geomagical/lama-with-refiner/blob/amp_float16/configs/prediction/default.yaml

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much 🙏

modulo: ${dataset.pad_out_to_modulo}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@windj007 refiner padding is defined here

n_iters: 15 # number of iterations of refinement for each scale
lr: 0.002 # learning rate
min_side: 512 # all sides of image on all scales should be >= min_side / sqrt(2)
max_scales: 3 # max number of downscaling scales for the image-mask pyramid
px_budget: 1800000 # pixels budget. Any image will be resized to satisfy height*width <= px_budget
Loading