Technical Part
For those who want to know how this works.
The core technique of Tiled VAE is to estimate GroupNorm parameters across tiles for a seamless generation.
- The image is split into tiles, which are then padded with 11/32 pixels in the decoder/encoder respectively.
- When Fast Mode is disabled:
  - The original VAE forward pass is decomposed into a task queue and a task worker, which processes the tiles one by one.
  - When a GroupNorm is needed, the worker suspends, stores the current GroupNorm mean and variance, sends everything to RAM, and moves on to the next tile.
  - After the GroupNorm means and variances of all tiles are aggregated, the worker applies GroupNorm to the tiles and continues (see the sketch below).
  - A zigzag execution order is used to reduce unnecessary data transfer.
- When Fast Mode is enabled:
  - The original input is downsampled and passed through a separate task queue.
  - Its GroupNorm parameters are recorded and reused by all tiles' task queues.
  - Each tile is processed separately without any RAM-VRAM data transfer.
  - After all tiles are processed, they are written to a result buffer and returned.
ℹ Encoder color fix = only estimate GroupNorm before downsampling, i.e., run in a semi-fast mode.
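Below is a minimal PyTorch sketch of the shared-statistics idea (the function name, tile layout, and the assumption of non-overlapping tiles are illustrative, not the extension's actual code): per-group sums and squared sums are accumulated over all tiles, fused into one mean/variance, and then applied to every tile, which matches a GroupNorm computed on the full stitched image. Fast Mode skips the first pass and instead estimates the statistics once from a downsampled copy of the input.

```python
import torch

def tiled_group_norm(tiles, num_groups, weight, bias, eps=1e-6):
    """GroupNorm over an image that only exists as a list of tiles.

    tiles: list of (N, C, h_i, w_i) tensors that partition the full image.
    weight, bias: the usual per-channel affine parameters of shape (C,).
    """
    N, C = tiles[0].shape[:2]
    dev, dt = tiles[0].device, tiles[0].dtype
    # Pass 1: accumulate sum, squared sum, and element count per (sample, group).
    s = torch.zeros(N, num_groups, device=dev, dtype=dt)
    s2 = torch.zeros(N, num_groups, device=dev, dtype=dt)
    n = 0
    for t in tiles:
        g = t.reshape(N, num_groups, -1)   # fold (C/G, h, w) into one axis
        s += g.sum(dim=-1)
        s2 += (g * g).sum(dim=-1)
        n += g.shape[-1]
    mean = s / n
    var = s2 / n - mean * mean             # Var[x] = E[x^2] - (E[x])^2
    # Pass 2: normalize every tile with the shared, full-image statistics.
    out = []
    for t in tiles:
        g = t.reshape(N, num_groups, -1)
        g = (g - mean[..., None]) / torch.sqrt(var[..., None] + eps)
        t = g.reshape(t.shape) * weight.view(1, C, 1, 1) + bias.view(1, C, 1, 1)
        out.append(t)
    return out
```

The actual extension interleaves this accumulation with the rest of the VAE task queue (suspending at each GroupNorm as described above), but the statistics are fused in the same spirit.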
- For Tiled Diffusion, the latent image is split into tiles.
- In MultiDiffusion:
  - The UNet predicts the noise of each tile.
  - The tiles are denoised by the original sampler for one time step.
  - The denoised tiles are summed back into the full canvas, with each pixel divided by the number of tiles covering it (a per-pixel average; see the fusion sketch after this list).
- In Mixture of Diffusers:
  - The UNet predicts the noise of each tile.
  - All noise predictions are fused with a Gaussian weight mask.
  - The sampler denoises the whole image for one time step using the fused noise.
- The predict-and-denoise steps above repeat until all timesteps are completed.
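The two fusion rules can be sketched like this (tensor shapes, the bounding-box format, and all names here are illustrative placeholders, not the extension's API): MultiDiffusion averages the denoised tiles by per-pixel coverage, while Mixture of Diffusers fuses the per-tile noise predictions with a Gaussian window before a single whole-image denoising step.

```python
import torch

def gaussian_window(h, w, sigma=0.5):
    """2D Gaussian weight mask, peaked at the tile center (Mixture of Diffusers)."""
    y, x = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w),
                          indexing="ij")
    return torch.exp(-(x * x + y * y) / (2 * sigma * sigma))

def fuse_multidiffusion(denoised_tiles, boxes, shape):
    """Sum denoised tiles onto the canvas, divide by how often each pixel was hit."""
    canvas = torch.zeros(shape)
    hits = torch.zeros(shape)
    for tile, (y0, y1, x0, x1) in zip(denoised_tiles, boxes):
        canvas[..., y0:y1, x0:x1] += tile
        hits[..., y0:y1, x0:x1] += 1
    return canvas / hits.clamp(min=1)

def fuse_mixture_of_diffusers(noise_preds, boxes, shape):
    """Gaussian-weighted average of per-tile noise predictions."""
    fused = torch.zeros(shape)
    weight = torch.zeros(shape)
    for eps, (y0, y1, x0, x1) in zip(noise_preds, boxes):
        w = gaussian_window(y1 - y0, x1 - x0)
        fused[..., y0:y1, x0:x1] += eps * w
        weight[..., y0:y1, x0:x1] += w
    return fused / weight.clamp(min=1e-8)
```

The output of fuse_mixture_of_diffusers is then handed to the sampler for one global denoising step, matching the last bullet of the Mixture of Diffusers list above.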
⚪ Advantages
- Draw super large resolution (2k~8k) images with limited VRAM
- Seamless output without any post-processing
⚪ Drawbacks
- It is significantly slower than normal generation.
- Gradient calculation is not compatible with this hack; it will break any backward() or torch.autograd.grad() call.