Soft Inpainting #14208

Merged Dec 14, 2023 · 36 commits

Commits
dec791d
Removed code which forces the inpainting mask to be 0 or 1. Now fract…
CodeHatchling Nov 28, 2023
bbba133
Removed conflicting step that replaces the softly inpainted latents w…
CodeHatchling Nov 28, 2023
e715e46
Implements "scheduling" for blending of the original latents and a la…
CodeHatchling Nov 28, 2023
a6e5846
Nerfs the aggressive post-processing step of overlaying the original …
CodeHatchling Nov 28, 2023
debf836
Added UI elements to control blending parameters.
CodeHatchling Nov 28, 2023
c5c7fa0
Added slider for detail preservation strength, removed largely needle…
CodeHatchling Nov 29, 2023
284fd8f
Tweaked UI sliders and labels.
CodeHatchling Nov 29, 2023
c7a1ff8
Tweaked default values.
CodeHatchling Nov 29, 2023
609dea3
Added utility functions related to processing masks.
CodeHatchling Dec 3, 2023
73ab982
Blend masks are now produced afterward, based on an estimate of the v…
CodeHatchling Dec 3, 2023
bb04d40
Rewrote latent_blend() to use in-place operations and to aggressively…
CodeHatchling Dec 3, 2023
3bd3a09
Merge remote-tracking branch 'origin/dev' into soft-inpainting
CodeHatchling Dec 3, 2023
28a2b5b
Fixed a math mistake.
CodeHatchling Dec 3, 2023
552f8bc
"Uncrop" the original denoised image for the composite step, fixing a…
CodeHatchling Dec 3, 2023
aaacf48
Organized the settings and UI of soft inpainting to allow for togglin…
CodeHatchling Dec 4, 2023
259d33c
Enables the original functionality to be toggled on and off.
CodeHatchling Dec 4, 2023
976c105
Cleaned up code, moved main code contributions into soft_inpainting.py
CodeHatchling Dec 4, 2023
1455159
Fixed issue with whitespace, removed commented out code that was mean…
CodeHatchling Dec 4, 2023
57f29bd
Re-introduce latent blending step from the vanilla inpainting procedure.
CodeHatchling Dec 5, 2023
60c6022
Restored original formatting.
CodeHatchling Dec 5, 2023
b32a334
Applies a convert('RGBA') operation early to mimic previous behaviour.
CodeHatchling Dec 5, 2023
6fc1242
Fixed issue where batched inpainting (batch size > 1) wouldn't work b…
CodeHatchling Dec 5, 2023
49bbf11
Fixed unused import.
CodeHatchling Dec 5, 2023
3886481
Merge remote-tracking branch 'origin2/dev' into soft-inpainting
CodeHatchling Dec 5, 2023
e90d433
A custom blending function can be provided by p, replacing the use of…
CodeHatchling Dec 6, 2023
4608f62
Removed changes in some scripts since the arguments for soft painting…
CodeHatchling Dec 7, 2023
ac45789
Removed soft inpainting, added hooks for softpainting to work instead.
CodeHatchling Dec 7, 2023
2abc417
Re-implemented soft inpainting via a script. Also fixed some mistakes…
CodeHatchling Dec 7, 2023
8dbacc7
Fixed "No newline at end of file".
CodeHatchling Dec 7, 2023
56604f0
Moved image filters used by soft inpainting into soft_inpainting.py f…
CodeHatchling Dec 7, 2023
0ef4a4c
Fixed error that occurs when using vanilla samplers (somehow).
CodeHatchling Dec 7, 2023
f284ae2
Added parameters for the composite stage, fixed batched generation.
CodeHatchling Dec 8, 2023
fc3e246
Fixed complaint about whitespace, updated help section for a parameter.
CodeHatchling Dec 8, 2023
659f62e
Fixed grammar error.
CodeHatchling Dec 8, 2023
b241447
soft_inpainting now appears in the "inpaint" section, and will not ac…
CodeHatchling Dec 9, 2023
f1ff932
Formatted soft_inpainting.
CodeHatchling Dec 9, 2023
Changes from all commits
1 change: 1 addition & 0 deletions modules/images.py
@@ -791,3 +791,4 @@ def flatten(img, bgcolor):
         img = background

     return img.convert('RGB')
+
92 changes: 67 additions & 25 deletions modules/processing.py
@@ -62,28 +62,35 @@ def apply_color_correction(correction, original_image):
     return image.convert('RGB')


-def apply_overlay(image, paste_loc, index, overlays):
-    if overlays is None or index >= len(overlays):
-        return image
+def uncrop(image, dest_size, paste_loc):
+    x, y, w, h = paste_loc
+    base_image = Image.new('RGBA', dest_size)
+    image = images.resize_image(1, image, w, h)
+    base_image.paste(image, (x, y))
+    image = base_image

-    overlay = overlays[index]
+    return image
+
+
+def apply_overlay(image, paste_loc, overlay):
+    if overlay is None:
+        return image

     if paste_loc is not None:
-        x, y, w, h = paste_loc
-        base_image = Image.new('RGBA', (overlay.width, overlay.height))
-        image = images.resize_image(1, image, w, h)
-        base_image.paste(image, (x, y))
-        image = base_image
+        image = uncrop(image, (overlay.width, overlay.height), paste_loc)

     image = image.convert('RGBA')
     image.alpha_composite(overlay)
     image = image.convert('RGB')

     return image


-def create_binary_mask(image):
+def create_binary_mask(image, round=True):
     if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
-        image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+        if round:
+            image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
+        else:
+            image = image.split()[-1].convert("L")
     else:
         image = image.convert('L')
     return image
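For orientation, a minimal usage sketch of the extracted uncrop() helper (coordinates and sizes invented for illustration, and assuming the webui modules that provide images.resize_image are importable): it resizes the cropped region to (w, h) and pastes it at (x, y) on a transparent RGBA canvas of dest_size.

from PIL import Image

region = Image.new('RGB', (64, 64), 'white')  # e.g. a crop produced by inpaint-at-full-res
full = uncrop(region, dest_size=(512, 512), paste_loc=(100, 80, 64, 64))  # paste_loc is (x, y, w, h)
assert full.size == (512, 512)  # RGBA canvas, transparent outside the pasted region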
@@ -308,7 +315,7 @@ def unclip_image_conditioning(self, source_image):
             c_adm = torch.cat((c_adm, noise_level_emb), 1)
         return c_adm

-    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
+    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
         self.is_using_inpainting_conditioning = True

         # Handle the different mask inputs
@@ -320,8 +327,10 @@ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=N
             conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
             conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

-            # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
-            conditioning_mask = torch.round(conditioning_mask)
+            if round_image_mask:
+                # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
+                conditioning_mask = torch.round(conditioning_mask)
+
         else:
             conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
@@ -345,7 +354,7 @@ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=N

         return image_conditioning

-    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
         source_image = devices.cond_cast_float(source_image)

         # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
@@ -357,7 +366,7 @@ def img2img_image_conditioning(self, source_image, latent_image, image_mask=None
             return self.edit_image_conditioning(source_image)

         if self.sampler.conditioning_key in {'hybrid', 'concat'}:
-            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)

         if self.sampler.conditioning_key == "crossattn-adm":
             return self.unclip_image_conditioning(source_image)
@@ -867,6 +876,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

+            if p.scripts is not None:
+                ps = scripts.PostSampleArgs(samples_ddim)
+                p.scripts.post_sample(p, ps)
+                samples_ddim = ps.samples
+
             if getattr(samples_ddim, 'already_decoded', False):
                 x_samples_ddim = samples_ddim
             else:
@@ -922,13 +936,31 @@ def infotext(index=0, use_main_prompt=False):
                     pp = scripts.PostprocessImageArgs(image)
                     p.scripts.postprocess_image(p, pp)
                     image = pp.image

+                mask_for_overlay = getattr(p, "mask_for_overlay", None)
+                overlay_image = p.overlay_images[i] if getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images) else None
+
+                if p.scripts is not None:
+                    ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
+                    p.scripts.postprocess_maskoverlay(p, ppmo)
+                    mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image
+
                 if p.color_corrections is not None and i < len(p.color_corrections):
                     if save_samples and opts.save_images_before_color_correction:
-                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
+                        image_without_cc = apply_overlay(image, p.paste_to, overlay_image)
                         images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                     image = apply_color_correction(p.color_corrections[i], image)

-                image = apply_overlay(image, p.paste_to, i, p.overlay_images)
+                # If the intention is to show the output from the model
+                # that is being composited over the original image,
+                # we need to keep the original image around
+                # and use it in the composite step.
+                original_denoised_image = image.copy()
+
+                if p.paste_to is not None:
+                    original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to)
+
+                image = apply_overlay(image, p.paste_to, overlay_image)

                 if save_samples:
                     images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
@@ -938,16 +970,17 @@ def infotext(index=0, use_main_prompt=False):
                 if opts.enable_pnginfo:
                     image.info["parameters"] = text
                 output_images.append(image)
-                if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
+
+                if mask_for_overlay is not None:
                     if opts.return_mask or opts.save_mask:
-                        image_mask = p.mask_for_overlay.convert('RGB')
+                        image_mask = mask_for_overlay.convert('RGB')
                         if save_samples and opts.save_mask:
                             images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
                         if opts.return_mask:
                             output_images.append(image_mask)

                     if opts.return_mask_composite or opts.save_mask_composite:
-                        image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+                        image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
                         if save_samples and opts.save_mask_composite:
                             images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
                         if opts.return_mask_composite:
@@ -1351,6 +1384,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
     mask_blur_x: int = 4
     mask_blur_y: int = 4
     mask_blur: int = None
+    mask_round: bool = True
     inpainting_fill: int = 0
     inpaint_full_res: bool = True
     inpaint_full_res_padding: int = 0
@@ -1396,7 +1430,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
         if image_mask is not None:
             # image_mask is passed in as RGBA by Gradio to support alpha masks,
             # but we still want to support binary masks.
-            image_mask = create_binary_mask(image_mask)
+            image_mask = create_binary_mask(image_mask, round=self.mask_round)

             if self.inpainting_mask_invert:
                 image_mask = ImageOps.invert(image_mask)
@@ -1503,7 +1537,8 @@ def init(self, all_prompts, all_seeds, all_subseeds):
             latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
             latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
             latmask = latmask[0]
-            latmask = np.around(latmask)
+            if self.mask_round:
+                latmask = np.around(latmask)
             latmask = np.tile(latmask[None], (4, 1, 1))

             self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
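To make the effect of mask_round concrete, a tiny illustration (values invented for the example):

import numpy as np

latmask = np.array([0.0, 0.25, 0.5, 1.0], dtype=np.float32)
np.around(latmask)  # -> [0., 0., 0., 1.] (np.around rounds 0.5 to the nearest even value)
# With mask_round=False the fractional values survive, so the derived self.mask / self.nmask
# act as soft per-pixel blend weights instead of a hard 0/1 stencil.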
@@ -1515,7 +1550,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
             elif self.inpainting_fill == 3:
                 self.init_latent = self.init_latent * self.mask

-        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
+        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)

     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         x = self.rng.next()
@@ -1527,7 +1562,14 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
         samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

         if self.mask is not None:
-            samples = samples * self.nmask + self.init_latent * self.mask
+            blended_samples = samples * self.nmask + self.init_latent * self.mask
+
+            if self.scripts is not None:
+                mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
+                self.scripts.on_mask_blend(self, mba)
+                blended_samples = mba.blended_latent
+
+            samples = blended_samples

         del x
         devices.torch_gc()
70 changes: 70 additions & 0 deletions modules/scripts.py
@@ -11,11 +11,31 @@

 AlwaysVisible = object()

+
+class MaskBlendArgs:
+    def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
+        self.current_latent = current_latent
+        self.nmask = nmask
+        self.init_latent = init_latent
+        self.mask = mask
+        self.blended_latent = blended_latent
+
+        self.denoiser = denoiser
+        self.is_final_blend = denoiser is None
+        self.sigma = sigma
+
+
+class PostSampleArgs:
+    def __init__(self, samples):
+        self.samples = samples
+

 class PostprocessImageArgs:
     def __init__(self, image):
         self.image = image

+
+class PostProcessMaskOverlayArgs:
+    def __init__(self, index, mask_for_overlay, overlay_image):
+        self.index = index
+        self.mask_for_overlay = mask_for_overlay
+        self.overlay_image = overlay_image
+

 class PostprocessBatchListArgs:
     def __init__(self, images):
@@ -206,13 +226,39 @@ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwarg

         pass

+    def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
+        """
+        Called in inpainting mode when the original content is blended with the inpainted content.
+        This is called at every step in the denoising process and once at the end.
+        If is_final_blend is true, this is called for the final blending stage.
+        Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
+        """
+
+        pass
+
+    def post_sample(self, p, ps: PostSampleArgs, *args):
+        """
+        Called after the samples have been generated,
+        but before they have been decoded by the VAE, if applicable.
+        Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
+        """
+
+        pass
+
     def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
         """
        Called for every image after it has been generated.
        """

         pass

+    def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
+        """
+        Called for every image after it has been generated, before the inpaint mask and
+        overlay image are composited onto it. Scripts may replace ppmo.mask_for_overlay
+        and ppmo.overlay_image here.
+        """
+
+        pass
+
     def postprocess(self, p, processed, *args):
         """
         This function is called after processing ends for AlwaysVisible scripts.
@@ -767,6 +813,22 @@ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
             except Exception:
                 errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)

+    def post_sample(self, p, ps: PostSampleArgs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.post_sample(p, ps, *script_args)
+            except Exception:
+                errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
+
+    def on_mask_blend(self, p, mba: MaskBlendArgs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.on_mask_blend(p, mba, *script_args)
+            except Exception:
+                errors.report(f"Error running on_mask_blend: {script.filename}", exc_info=True)
+
     def postprocess_image(self, p, pp: PostprocessImageArgs):
         for script in self.alwayson_scripts:
             try:
@@ -775,6 +837,14 @@ def postprocess_image(self, p, pp: PostprocessImageArgs):
             except Exception:
                 errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)

+    def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.postprocess_maskoverlay(p, ppmo, *script_args)
+            except Exception:
+                errors.report(f"Error running postprocess_maskoverlay: {script.filename}", exc_info=True)
+
     def before_component(self, component, **kwargs):
         for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
             try:
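Taken together, the new hooks let an always-on script observe or replace each blend. A minimal, hypothetical sketch (the class and its behaviour are invented for illustration; only the hook names and signatures come from this PR):

import modules.scripts as scripts

class SoftBlendExample(scripts.Script):
    def title(self):
        return "Soft blend example"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def on_mask_blend(self, p, mba, *args):
        # Reproduce the default convex blend; a real script could substitute
        # any function of current_latent and init_latent here.
        mba.blended_latent = mba.current_latent * mba.nmask + mba.init_latent * mba.mask

    def post_sample(self, p, ps, *args):
        # ps.samples holds the latents before (optional) VAE decoding.
        pass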
21 changes: 19 additions & 2 deletions modules/sd_samplers_cfg_denoiser.py
@@ -56,6 +56,9 @@ def __init__(self, sampler):
         self.sampler = sampler
         self.model_wrap = None
         self.p = None
+
+        # NOTE: masking before denoising can cause the original latents to be oversmoothed
+        # as the original latents do not have noise
         self.mask_before_denoising = False

     @property
@@ -105,8 +108,21 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):

         assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

+        # If we use masks, blending between the denoised and original latent images occurs here.
+        def apply_blend(current_latent):
+            blended_latent = current_latent * self.nmask + self.init_latent * self.mask
+
+            if self.p.scripts is not None:
+                from modules import scripts
+                mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
+                self.p.scripts.on_mask_blend(self.p, mba)
+                blended_latent = mba.blended_latent
+
+            return blended_latent
+
+        # Blend in the original latents (before)
         if self.mask_before_denoising and self.mask is not None:
-            x = self.init_latent * self.mask + self.nmask * x
+            x = apply_blend(x)

         batch_size = len(conds_list)
         repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -207,8 +223,9 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
         else:
             denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

+        # Blend in the original latents (after)
         if not self.mask_before_denoising and self.mask is not None:
-            denoised = self.init_latent * self.mask + self.nmask * denoised
+            denoised = apply_blend(denoised)

         self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

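For reference, both blend sites (before and after denoising) now funnel through the same convex combination. As a standalone sketch, assuming nmask is the complement of mask as in the img2img init above, and latents shaped e.g. [batch, 4, h, w]:

import torch

def default_latent_blend(current_latent: torch.Tensor,
                         init_latent: torch.Tensor,
                         mask: torch.Tensor) -> torch.Tensor:
    # mask is 1.0 where the original latents are kept;
    # nmask = 1 - mask weights the freshly denoised latents.
    nmask = 1.0 - mask
    return current_latent * nmask + init_latent * mask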