🦄️ Noise Calibration: Plug-and-play Content-Preserving Video Enhancement using Pre-trained Video Diffusion Models (ECCV 2024)
¹ Dalian University of Technology  ² Tencent AI Lab  * Corresponding Author
In European Conference on Computer Vision (ECCV) 2024
We propose Noise Calibration, a method that substantially improves the content consistency between an original video and its SDEdit-enhanced version.
✅ No training required ✅ Less than 10% extra inference time ✅ Plug-and-play
```python
import torch
import torch.fft as fft


def get_low_or_high_fft(x, scale, is_low=True):
    """Keep the low (is_low=True) or high (is_low=False) spatial frequencies of x.

    x has shape (B, C, T, H, W); the FFT is taken over the last two (spatial) axes.
    The low band is a centered box of half-size int(scale * (H // 2)) rows by
    int(scale * (W // 2)) columns in the shifted spectrum.
    """
    # FFT over (H, W), with the zero frequency shifted to the center
    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))
    B, C, T, H, W = x_freq.shape

    # Box mask around the center: 1 inside for low-pass, 0 inside for high-pass
    crow, ccol = H // 2, W // 2
    dh, dw = int(crow * scale), int(ccol * scale)
    if is_low:
        mask = torch.zeros((B, C, T, H, W), device=x.device)
        mask[..., crow - dh:crow + dh, ccol - dw:ccol + dw] = 1
    else:
        mask = torch.ones((B, C, T, H, W), device=x.device)
        mask[..., crow - dh:crow + dh, ccol - dw:ccol + dw] = 0
    x_freq = x_freq * mask

    # IFFT back to the spatial domain and keep the real part
    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
    return x_filtered
```
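To see what the helper does without PyTorch or a GPU, here is a NumPy port (our own sketch, not the repository's code) that returns both bands at once. Because the low- and high-pass masks are complementary and the inverse FFT is linear, the two parts add back up to the input exactly. The name `split_frequencies` and the tensor sizes are hypothetical.

```python
import numpy as np


def split_frequencies(x, scale):
    """Return (low, high) parts of x over its last two axes.

    Hypothetical NumPy port of get_low_or_high_fft: the low band is a centered
    box of half-size int(scale * (H // 2)) by int(scale * (W // 2)) in the
    fftshift-ed spectrum, and the high band is everything else.
    """
    H, W = x.shape[-2:]
    crow, ccol = H // 2, W // 2
    dh, dw = int(crow * scale), int(ccol * scale)

    # 2D spectrum over (H, W) with the zero frequency moved to the center
    freq = np.fft.fftshift(np.fft.fftn(x, axes=(-2, -1)), axes=(-2, -1))

    # Box mask; broadcasts over any leading (B, C, T) axes
    mask = np.zeros((H, W))
    mask[crow - dh:crow + dh, ccol - dw:ccol + dw] = 1.0

    def to_spatial(f):
        # Undo the shift, invert the FFT and keep the real part
        return np.fft.ifftn(np.fft.ifftshift(f, axes=(-2, -1)), axes=(-2, -1)).real

    return to_spatial(freq * mask), to_spatial(freq * (1.0 - mask))


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 4, 2, 16, 16))  # (B, C, T, H, W)
low, high = split_frequencies(x, scale=0.25)
print(np.allclose(low + high, x))  # True

const = np.ones((1, 1, 1, 16, 16))
print(np.allclose(split_frequencies(const, scale=0.25)[1], 0.0))  # True
```

A constant input has all of its energy in the zero-frequency bin at the center of the shifted spectrum, so its high band vanishes whenever the box is non-empty (here `int(8 * 0.25) = 2`). Conversely, if `scale < 1 / (H // 2)` the box is empty and the low band is all zeros. The box is half-open, so the mask is not exactly symmetric about the center bin; the inverse FFT can then have a small imaginary part, which both versions discard with `.real`.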
### Noise Calibration (Algorithm 1)

```python
# Runs inside the sampler. x_r: latent of the original video; a_t: alpha-bar at step t;
# sqrt_one_minus_at = sqrt(1 - a_t); N, scale, t, c and kwargs come from the sampler.
e_t = torch.randn_like(x_r)
for _ in range(N):
    x = a_t.sqrt() * x_r + sqrt_one_minus_at * e_t  # noise x_r to step t
    e_t_theta = self.model.apply_model(x, t, c, **kwargs)  # predicted noise
    x_0_t = (x - sqrt_one_minus_at * e_t_theta) / a_t.sqrt()  # predicted x_0
    # Calibrate e_t: for a fixed e_t_theta, the implied x_0 becomes low(x_r) + high(x_0_t)
    e_t = e_t_theta + a_t.sqrt() / sqrt_one_minus_at * (
        get_low_or_high_fft(x_0_t, scale, is_low=False)
        - get_low_or_high_fft(x_r, scale, is_low=False))
x = a_t.sqrt() * x_r + sqrt_one_minus_at * e_t  # calibrated noisy latent
```
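The form of the `e_t` update can be read off from a short derivation (our own reading of the code, not quoted from the paper). Write $\bar\alpha_t$ for `a_t`, $\epsilon_\theta$ for the model's noise prediction, and $\mathcal{L}$, $\mathcal{H}$ for `get_low_or_high_fft` with `is_low=True` and `is_low=False`. Since the two masks are complementary and the inverse FFT is linear, $\mathcal{L}(y) + \mathcal{H}(y) = y$ for any real $y$.

$$
\begin{aligned}
x &= \sqrt{\bar\alpha_t}\,x_r + \sqrt{1-\bar\alpha_t}\,\epsilon_t,
\qquad
\hat{x}_0 = \frac{x - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x, t, c)}{\sqrt{\bar\alpha_t}}, \\
\tilde{x}_0 &= \mathcal{L}(x_r) + \mathcal{H}(\hat{x}_0)
\quad\text{(low band from } x_r\text{, high band from } \hat{x}_0\text{)}, \\
\sqrt{\bar\alpha_t}\,x_r + \sqrt{1-\bar\alpha_t}\,\epsilon_t' &= \sqrt{\bar\alpha_t}\,\tilde{x}_0 + \sqrt{1-\bar\alpha_t}\,\epsilon_\theta \\
\Longrightarrow\quad \epsilon_t' &= \epsilon_\theta + \frac{\sqrt{\bar\alpha_t}}{\sqrt{1-\bar\alpha_t}}\bigl(\tilde{x}_0 - x_r\bigr)
= \epsilon_\theta + \frac{\sqrt{\bar\alpha_t}}{\sqrt{1-\bar\alpha_t}}\bigl(\mathcal{H}(\hat{x}_0) - \mathcal{H}(x_r)\bigr).
\end{aligned}
$$

The last line is exactly the `e_t` update in the loop. In words: the calibrated noise is chosen so that, if the model returned the same $\epsilon_\theta$ at the new $x$, its clean-latent estimate would be $\tilde{x}_0$. The prediction does change with $x$, which is presumably why the step is repeated `N` times.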
Finally, run SDEdit starting from the calibrated noisy latent `x` at timestep t to obtain the enhanced video.
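For experimenting outside the sampler class, the loop can be wrapped in a standalone function. The sketch below is a NumPy port under our own assumptions: `eps_model(x, t)` is a hypothetical noise predictor standing in for `self.model.apply_model(x, t, c, **kwargs)`, `alpha_bar_t` plays the role of `a_t`, `high_pass` re-implements `get_low_or_high_fft(..., is_low=False)`, and `t=500` is an arbitrary value. It returns the calibrated latent `x` that would be handed to SDEdit, plus the calibrated noise.

```python
import numpy as np


def high_pass(x, scale):
    """High-frequency part of x over its last two axes.

    Hypothetical NumPy port of get_low_or_high_fft(x, scale, is_low=False).
    """
    H, W = x.shape[-2:]
    crow, ccol = H // 2, W // 2
    dh, dw = int(crow * scale), int(ccol * scale)
    mask = np.ones((H, W))
    mask[crow - dh:crow + dh, ccol - dw:ccol + dw] = 0.0  # remove the low band
    freq = np.fft.fftshift(np.fft.fftn(x, axes=(-2, -1)), axes=(-2, -1))
    freq = np.fft.ifftshift(freq * mask, axes=(-2, -1))
    return np.fft.ifftn(freq, axes=(-2, -1)).real


def noise_calibration(x_r, eps_model, t, alpha_bar_t, scale, n_iters, rng):
    """Sketch of the Algorithm 1 loop; returns (calibrated latent x, noise e_t).

    eps_model(x, t) is a hypothetical noise predictor standing in for
    self.model.apply_model(x, t, c, **kwargs).
    """
    sqrt_a = np.sqrt(alpha_bar_t)
    sqrt_one_minus_a = np.sqrt(1.0 - alpha_bar_t)
    e_t = rng.standard_normal(x_r.shape)  # initial Gaussian noise
    for _ in range(n_iters):
        x = sqrt_a * x_r + sqrt_one_minus_a * e_t  # noise x_r to step t
        e_theta = eps_model(x, t)  # predicted noise
        x0_pred = (x - sqrt_one_minus_a * e_theta) / sqrt_a  # predicted x_0
        # For a fixed e_theta, the implied x_0 becomes low(x_r) + high(x0_pred)
        e_t = e_theta + sqrt_a / sqrt_one_minus_a * (
            high_pass(x0_pred, scale) - high_pass(x_r, scale))
    x = sqrt_a * x_r + sqrt_one_minus_a * e_t
    return x, e_t


def zero_model(x, t):
    # Dummy predictor for illustration: always predicts zero noise
    return np.zeros_like(x)


x_r = np.random.default_rng(0).standard_normal((1, 4, 2, 16, 16))
x, e = noise_calibration(x_r, zero_model, t=500, alpha_bar_t=0.5,
                         scale=0.25, n_iters=1, rng=np.random.default_rng(1))
print(x.shape)  # (1, 4, 2, 16, 16)
```

With the dummy zero predictor, the predicted clean latent is $x_r + \sqrt{1-\bar\alpha_t}/\sqrt{\bar\alpha_t}\,e_0$, so one iteration reduces the calibrated noise to the high-frequency part of the initial noise $e_0$. A real denoiser would give a different result; the dummy only makes the arithmetic easy to check.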