Skip to content

Commit

Permalink
add test for ldm uncond
Browse files Browse the repository at this point in the history
  • Loading branch information
patil-suraj committed Jun 29, 2022
1 parent 65788e4 commit 859ffea
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
GradTTSPipeline,
GradTTSScheduler,
LatentDiffusionPipeline,
LatentDiffusionUncondPipeline,
NCSNpp,
PNDMPipeline,
PNDMScheduler,
Expand All @@ -46,7 +47,6 @@
UNetLDMModel,
UNetModel,
VQModel,
AutoencoderKL,
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
Expand Down Expand Up @@ -915,7 +915,7 @@ def prepare_init_args_and_inputs_for_common(self):
"out_ch": 3,
"resolution": 32,
"z_channels": 4,
"attn_resolutions": []
"attn_resolutions": [],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
Expand All @@ -925,7 +925,7 @@ def test_forward_signature(self):

def test_training(self):
pass

def test_from_pretrained_hub(self):
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
self.assertIsNotNone(model)
Expand Down Expand Up @@ -1151,6 +1151,19 @@ def test_score_sde_vp_pipeline(self):
assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4

@slow
def test_ldm_uncond(self):
ldm = LatentDiffusionUncondPipeline.from_pretrained("fusing/latent-diffusion-celeba-256")

generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5)

image_slice = image[0, -1, -3:, -3:].cpu()

assert image.shape == (1, 3, 256, 256)
expected_slice = torch.tensor([0.5025, 0.4121, 0.3851, 0.4806, 0.3996, 0.3745, 0.4839, 0.4559, 0.4293])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

def test_module_from_pipeline(self):
model = DiffWave(num_res_layers=4)
noise_scheduler = DDPMScheduler(timesteps=12)
Expand Down

0 comments on commit 859ffea

Please sign in to comment.