fix: cutouts generation (#138)
* fix: cutouts generation

* style: fix overload and cli autocomplete

Co-authored-by: Jina Dev Bot <[email protected]>
hanxiao and jina-bot committed Aug 4, 2022
1 parent aec0e67 commit b6bd8dd
Showing 8 changed files with 65 additions and 179 deletions.
2 changes: 1 addition & 1 deletion discoart/__init__.py
@@ -2,7 +2,7 @@

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

__version__ = '0.10.16'
__version__ = '0.11.0'

__all__ = ['create', 'cheatsheet']

2 changes: 1 addition & 1 deletion discoart/config.py
@@ -27,7 +27,7 @@
) as ymlfile:
cut_schedules = yaml.load(ymlfile, Loader=Loader)

_legacy_args = {'clip_sequential_evaluation', 'fuzzy_prompt'}
_legacy_args = {'clip_sequential_evaluation', 'fuzzy_prompt', 'skip_augs'}


def load_config(
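With `skip_augs` added to `_legacy_args`, configs and calls that still pass the key keep loading instead of raising. The actual handling sits in the collapsed body of `load_config`; purely as a sketch of the idea, with a hypothetical helper name:

import warnings

def _drop_legacy_args(user_config: dict) -> dict:
    # hypothetical helper: warn about deprecated keys, then silently drop them
    for k in _legacy_args & user_config.keys():
        warnings.warn(f'`{k}` is a legacy argument and is ignored', DeprecationWarning)
    return {k: v for k, v in user_config.items() if k not in _legacy_args}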
2 changes: 0 additions & 2 deletions discoart/create.py
@@ -58,7 +58,6 @@ def create(
sat_scale: Optional[Union[int, str]] = 0,
save_rate: Optional[int] = 20,
seed: Optional[int] = None,
skip_augs: Optional[Union[bool, str]] = False,
skip_event: Optional[
Union['multiprocessing.Event', 'asyncio.Event', 'threading.Event']
] = None,
@@ -134,7 +133,6 @@ def create(**kwargs) -> Optional['DocumentArray']:
:param sat_scale: Saturation scale. Optional, set to zero to turn off. If used, sat_scale will help mitigate oversaturation. If your image is too saturated, increase sat_scale to reduce the saturation. [DiscoArt] Can be scheduled via syntax `[val1]*400+[val2]*600`.
:param save_rate: [DiscoArt] The number of steps to save intermediate results. It is a replacement to original `display_rate` parameter. Set it to -1 for not saving any intermediate result.
:param seed: Deep in the diffusion code, there is a random number ‘seed’ which is used as the basis for determining the initial state of the diffusion. By default, this is random, but you can also specify your own seed. This is useful if you like a particular result and would like to run more iterations that will be similar. After each run, the actual seed value used will be reported in the parameters report, and can be reused if desired by entering seed # here. If a specific numerical seed is used repeatedly, the resulting images will be quite similar but not identical.
:param skip_augs: Controls whether to skip torchvision augmentations. [DiscoArt] Can be scheduled via syntax `[val1]*400+[val2]*600`.
:param skip_event: [DiscoArt] A multiprocessing/asyncio/threading.Event that once set, will skip the current run and move to the next run as defined in `n_batches`.
:param skip_steps: Consider the chart shown here. Noise scheduling (denoise strength) starts very high and progressively gets lower and lower as diffusion steps progress. The noise levels in the first few steps are very high, so images change dramatically in early steps. As DD moves along the curve, noise levels (and thus the amount an image changes per step) decline, and image coherence from one step to the next increases. The first few steps of denoising are often so dramatic that some steps (maybe 10-15% of total) can be skipped without affecting the final image. You can experiment with this as a way to cut render times. If you skip too many steps, however, the remaining noise may not be high enough to generate new content, and thus may not have ‘time left’ to finish an image satisfactorily. Also, depending on your other settings, you may need to skip steps to prevent CLIP from overshooting your goal, resulting in ‘blown out’ colors (hyper saturated, solid white, or solid black regions) or otherwise poor image quality. Consider that the denoising process is at its strongest in the early steps, so skipping steps can sometimes mitigate other problems. Lastly, if using an init_image, you will need to skip ~50% of the diffusion steps to retain the shapes in the original init image. However, if you’re using an init_image, you can also adjust skip_steps up or down for creative reasons. With low skip_steps you can get a result "inspired by" the init_image which will retain the colors and rough layout and shapes but look quite different. With high skip_steps you can preserve most of the init_image contents and just do fine tuning of the texture.
:param steps: When creating an image, the denoising curve is subdivided into steps for processing. Each step (or iteration) involves the AI looking at subsets of the image called ‘cuts’ and calculating the ‘direction’ the image should be guided to be more like the prompt. Then it adjusts the image with the help of the diffusion denoiser, and moves to the next step. Increasing steps will provide more opportunities for the AI to adjust the image, and each adjustment will be smaller, and thus will yield a more precise, detailed image. Increasing steps comes at the expense of longer render times. Also, while increasing steps should generally increase image quality, there is a diminishing return on additional steps beyond 250 - 500 steps. However, some intricate images can take 1000, 2000, or more steps. It is really up to the user. Just know that the render time is directly related to the number of steps, and many other parameters have a major impact on image quality, without costing additional time.
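For reference, here is a minimal `create()` call exercising the parameters documented above; the prompt and numbers are illustrative only, and `sat_scale` shows the `[val1]*400+[val2]*600` schedule syntax:

from discoart import create

da = create(
    text_prompts='a watercolor lighthouse in a storm',  # illustrative prompt
    steps=250,                      # diffusion steps
    save_rate=20,                   # save an intermediate result every 20 steps
    seed=42,                        # fix the seed for a reproducible run
    skip_steps=25,                  # skip ~10% of the early, high-noise steps
    sat_scale='[0]*200+[2000]*50',  # scheduled: off early, then damp saturation
)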
1 change: 0 additions & 1 deletion discoart/helper.py
@@ -731,7 +731,6 @@ def _get_schedule_table(args) -> Dict:
'cut_ic_pow',
'use_secondary_model',
'cutn_batches',
'skip_augs',
'clip_guidance_scale',
'tv_scale',
'range_scale',
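The list above enumerates the scheduleable arguments (this commit drops `skip_augs` from it). Conceptually, each scheduled value expands into one entry per diffusion step; a rough sketch of that expansion, not the actual `_get_schedule_table` implementation:

# what the string '[0.5]*400+[0.1]*600' expands to: a 1000-element list, one value per step
schedule = [0.5] * 400 + [0.1] * 600
value_at = lambda step: schedule[min(step, len(schedule) - 1)]  # clamp past the last entry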
190 changes: 41 additions & 149 deletions discoart/nn/make_cutouts.py
@@ -1,82 +1,28 @@
import math

import torch
from resize_right import resize
from torch import nn
from torch.nn import functional as F
from torchvision import transforms as T
from torchvision.transforms import functional as TF


def sinc(x):
return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))


def lanczos(x, a):
cond = torch.logical_and(-a < x, x < a)
out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([]))
return out / out.sum()


def ramp(ratio, width):
n = math.ceil(width / ratio + 1)
out = torch.empty([n])
cur = 0
for i in range(out.shape[0]):
out[i] = cur
cur += ratio
return torch.cat([-out[1:].flip([0]), out])[1:-1]


class MakeCutouts(nn.Module):
def __init__(self, cut_size, cutn, skip_augs=False):
super().__init__()
self.cut_size = cut_size
self.cutn = cutn
self.skip_augs = skip_augs
self.augs = T.Compose(
[
T.RandomHorizontalFlip(p=0.5),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.RandomPerspective(distortion_scale=0.4, p=0.7),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.RandomGrayscale(p=0.15),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
# T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
]
)

def forward(self, input):
input = T.Pad(input.shape[2] // 4, fill=0)(input)
sideY, sideX = input.shape[2:4]
max_size = min(sideX, sideY)

cutouts = []
for ch in range(self.cutn):
if ch > self.cutn - self.cutn // 4:
cutout = input.clone()
else:
size = int(
max_size
* torch.zeros(
1,
)
.normal_(mean=0.8, std=0.3)
.clip(float(self.cut_size / max_size), 1.0)
)
offsetx = torch.randint(0, abs(sideX - size + 1), ())
offsety = torch.randint(0, abs(sideY - size + 1), ())
cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]

if not self.skip_augs:
cutout = self.augs(cutout)
cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
del cutout

cutouts = torch.cat(cutouts, dim=0)
return cutouts
augment = torch.jit.script(
torch.nn.Sequential(
*[
T.RandomHorizontalFlip(p=0.5),
T.RandomAffine(
degrees=10,
translate=(0.05, 0.05),
interpolation=T.InterpolationMode.BILINEAR,
),
T.RandomGrayscale(p=0.1),
T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
T.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
),
]
)
)
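The old per-instance `T.Compose` pipelines, and the `skip_augs` flag that guarded them, are replaced by this single module-level pipeline: a TorchScript-compiled `nn.Sequential` that also folds CLIP's `Normalize` into the augmentation, so callers no longer normalize separately. A tiny application sketch with an illustrative shape:

import torch

cut = torch.rand(1, 3, 224, 224)  # one cutout in [0, 1]
clip_ready = augment(cut)         # flip/affine/grayscale/jitter + CLIP normalization in one pass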


class MakeCutoutsDango(nn.Module):
@@ -87,33 +33,18 @@ def __init__(
InnerCrop=0,
IC_Size_Pow=0.5,
IC_Grey_P=0.2,
skip_augs=False,
):
super().__init__()
self.cut_size = cut_size
self.Overview = Overview
self.InnerCrop = InnerCrop
self.IC_Size_Pow = IC_Size_Pow
self.IC_Grey_P = IC_Grey_P
self.skip_augs = skip_augs
self.augs = T.Compose(
[
T.RandomHorizontalFlip(p=0.5),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.RandomAffine(
degrees=10,
translate=(0.05, 0.05),
interpolation=T.InterpolationMode.BILINEAR,
),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.RandomGrayscale(p=0.1),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
]
)

def forward(self, input):
cutouts = []
return torch.cat([augment(c) for c in self._cut_generator(input)])

def _cut_generator(self, input):
gray = T.Grayscale(3)
sideY, sideX = input.shape[2:4]
max_size = min(sideX, sideY)
Expand All @@ -127,65 +58,26 @@ def forward(self, input):
(sideX - max_size) // 2,
(sideX - max_size) // 2,
),
**padargs,
)
cutout = resize(pad_input, out_shape=output_shape)

if self.Overview > 0:
if self.Overview <= 4:
if self.Overview >= 1:
cutouts.append(cutout)
if self.Overview >= 2:
cutouts.append(gray(cutout))
if self.Overview >= 3:
cutouts.append(TF.hflip(cutout))
if self.Overview == 4:
cutouts.append(gray(TF.hflip(cutout)))
cutout = resize(pad_input, out_shape=output_shape)
for j in range(self.Overview):
if j == 1:
yield gray(cutout)
elif j == 2:
yield TF.hflip(cutout)
elif j == 3:
yield gray(TF.hflip(cutout))
else:
cutout = resize(pad_input, out_shape=output_shape)
for _ in range(self.Overview):
cutouts.append(cutout)

if self.InnerCrop > 0:
for i in range(self.InnerCrop):
size = int(
torch.rand([]) ** self.IC_Size_Pow * (max_size - min_size)
+ min_size
)
offsetx = torch.randint(0, sideX - size + 1, ())
offsety = torch.randint(0, sideY - size + 1, ())
cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
if i <= int(self.IC_Grey_P * self.InnerCrop):
cutout = gray(cutout)
cutout = resize(cutout, out_shape=output_shape)
cutouts.append(cutout)

cutouts = torch.cat(cutouts)
if not self.skip_augs:
cutouts = self.augs(cutouts)
return cutouts


def resample(input, size, align_corners=True):
n, c, h, w = input.shape
dh, dw = size

input = input.reshape([n * c, 1, h, w])

if dh < h:
kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
pad_h = (kernel_h.shape[0] - 1) // 2
input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
input = F.conv2d(input, kernel_h[None, None, :, None])

if dw < w:
kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
pad_w = (kernel_w.shape[0] - 1) // 2
input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
input = F.conv2d(input, kernel_w[None, None, None, :])

input = input.reshape([n, c, h, w])
return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)


padargs = {}
yield cutout

for i in range(self.InnerCrop):
size = int(
torch.rand([]) ** self.IC_Size_Pow * (max_size - min_size) + min_size
)
offsetx = torch.randint(0, sideX - size + 1, ())
offsety = torch.randint(0, sideY - size + 1, ())
cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
if i <= int(self.IC_Grey_P * self.InnerCrop):
cutout = gray(cutout)
yield resize(cutout, out_shape=output_shape)
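The rewritten `MakeCutoutsDango` now separates sampling from augmentation: `_cut_generator` lazily yields the overview and inner crops, while `forward` pushes each one through the shared `augment` pipeline and concatenates the results. A usage sketch, assuming `output_shape` in the collapsed code resolves to `(1, 3, cut_size, cut_size)` and a single-image batch:

import torch

cutter = MakeCutoutsDango(224, Overview=4, InnerCrop=4, IC_Size_Pow=0.5, IC_Grey_P=0.2)
img = torch.rand(1, 3, 512, 640)  # RGB image in [0, 1]
cuts = cutter(img)                # roughly (Overview + InnerCrop, 3, 224, 224), already CLIP-normalized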
1 change: 0 additions & 1 deletion discoart/resources/default.yml
@@ -51,7 +51,6 @@ clip_models_schedules:
use_vertical_symmetry: False
use_horizontal_symmetry: False
transformation_percent: [0.09]
skip_augs: False

on_misspelled_token: ignore
diffusion_model_config:
4 changes: 0 additions & 4 deletions discoart/resources/docstrings.yml
@@ -172,10 +172,6 @@ clip_models: |
use_vertical_symmetry: Enforce symmetry over x axis of the image on [tr_st*steps for tr_st in transformation_steps] steps of the diffusion process
use_horizontal_symmetry: Enforce symmetry over y axis of the image on [tr_st*steps for tr_st in transformation_steps] steps of the diffusion process
transformation_percent: Steps expressed in percentages in which the symmetry is enforced
skip_augs: |
Controls whether to skip torchvision augmentations.
[DiscoArt] Can be scheduled via syntax `[val1]*400+[val2]*600`.


on_misspelled_token: |
42 changes: 22 additions & 20 deletions discoart/runner.py
@@ -33,19 +33,22 @@
from .persist import _sample_thread, _persist_thread, _save_progress_thread
from .prompt import PromptPlanner

inv_normalize = T.Normalize(
mean=[-0.48145466 / 0.26862954, -0.4578275 / 0.26130258, -0.40821073 / 0.27577711],
std=[1 / 0.26862954, 1 / 0.26130258, 1 / 0.27577711],
)
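Because CLIP normalization now happens inside the cutout `augment` pipeline (see `make_cutouts.py` above), the runner keeps only this inverse transform, used to bring cutouts back to pixel space when `visualize_cuts` is enabled. For `Normalize` with mean m and std s, y = (x - m) / s, so `Normalize(mean=-m/s, std=1/s)` maps y back to x. A quick, illustrative sanity check:

import torch
from torchvision import transforms as T

clip_norm = T.Normalize(
    mean=[0.48145466, 0.4578275, 0.40821073],
    std=[0.26862954, 0.26130258, 0.27577711],
)
x = torch.rand(1, 3, 8, 8)
assert torch.allclose(inv_normalize(clip_norm(x)), x, atol=1e-5)  # inv_normalize as defined above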


def do_run(args, models, device, events) -> 'DocumentArray':
skip_event, stop_event = events

_is_jupyter = is_jupyter()

output_dir = get_output_dir(args.name_docarray)

logger.info('preparing models...')

model, diffusion, clip_models, secondary_model = models
normalize = T.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
)
lpips_model = lpips.LPIPS(net='vgg').to(device)

side_x, side_y = ((args.width_height[j] // 64) * 64 for j in (0, 1))
@@ -217,21 +220,22 @@ def cond_fn(x, t, **kwargs):
else:
continue

for _ in range(scheduler.cutn_batches):
cuts = MakeCutoutsDango(
model_stat['input_resolution'],
Overview=scheduler.cut_overview,
InnerCrop=scheduler.cut_innercut,
IC_Size_Pow=scheduler.cut_ic_pow,
IC_Grey_P=scheduler.cut_icgray_p,
skip_augs=scheduler.skip_augs,
)
cuts = MakeCutoutsDango(
model_stat['input_resolution'],
Overview=scheduler.cut_overview,
InnerCrop=scheduler.cut_innercut,
IC_Size_Pow=scheduler.cut_ic_pow,
IC_Grey_P=scheduler.cut_icgray_p,
)

for _ in range(scheduler.cutn_batches):
clip_in = cuts(x_in.add(1).div(2))

if args.visualize_cuts and not is_cuts_visualized:
_cuts_da = DocumentArray.empty(clip_in.shape[0])
_cuts_da.tensors = (clip_in * 255).detach().cpu().numpy()
_cuts_da.tensors = (
(inv_normalize(clip_in) * 255).detach().cpu().numpy()
)
_cuts_da.plot_image_sprites(
os.path.join(output_dir, f'{_nb}-cuts-{num_step}.png'),
show_index=True,
@@ -240,9 +244,7 @@ def cond_fn(x, t, **kwargs):
is_cuts_visualized = True

image_embeds = (
model_stat['clip_model']
.encode_image(normalize(clip_in))
.unsqueeze(1)
model_stat['clip_model'].encode_image(clip_in).unsqueeze(1)
)

dists = spherical_dist_loss(
@@ -350,7 +352,7 @@ def cond_fn(x, t, **kwargs):
new_seed = org_seed + _nb
_set_seed(new_seed)
args.seed = new_seed
if is_jupyter():
if _is_jupyter:
redraw_widget(
_handlers,
_redraw_fn,
@@ -377,7 +379,7 @@ def cond_fn(x, t, **kwargs):
clip_denoised=args.clip_denoised,
model_kwargs={},
cond_fn=cond_fn,
progress=True,
progress=_is_jupyter,
skip_timesteps=skip_steps,
init_image=init,
randomize_class=args.randomize_class,
@@ -394,7 +396,7 @@ def cond_fn(x, t, **kwargs):
clip_denoised=args.clip_denoised,
model_kwargs={},
cond_fn=cond_fn,
progress=True,
progress=_is_jupyter,
skip_timesteps=skip_steps,
init_image=init,
randomize_class=args.randomize_class,
