diff --git a/tests/models/test_stable_diffusion.py b/tests/models/test_stable_diffusion.py index 1fd0dfd5..f0c80523 100644 --- a/tests/models/test_stable_diffusion.py +++ b/tests/models/test_stable_diffusion.py @@ -2,12 +2,16 @@ import pytest -from nos.test.utils import PyTestGroup, skip_if_no_torch_cuda from nos.models.stable_diffusion import StableDiffusion +from nos.test.utils import PyTestGroup, skip_if_no_torch_cuda + + +STABLE_DIFFUSION_MODELS = StableDiffusion.configs.keys() @pytest.mark.benchmark(group=PyTestGroup.HUB) -def test_stable_diffusion_predict(): +@pytest.mark.parametrize("model", STABLE_DIFFUSION_MODELS) +def test_stable_diffusion_predict(model): """Use StableDiffusion to generate an image from a text prompt. Note: This test should be able to run with CPU or GPU. @@ -17,20 +21,18 @@ def test_stable_diffusion_predict(): """ from PIL import Image - for config in StableDiffusion.configs.values(): - import pdb; pdb.set_trace() - model = StableDiffusion(model_name=config.model_name, scheduler="ddim") - images: List[Image.Image] = model.__call__( - "astronaut on a horse on the moon", - num_images=1, - num_inference_steps=100, - guidance_scale=7.5, - width=512, - height=512, - ) - (image,) = images - assert image is not None - assert image.size == (512, 512) + model = StableDiffusion(model_name=model, scheduler="ddim") + images: List[Image.Image] = model.__call__( + "astronaut on a horse on the moon", + num_images=1, + num_inference_steps=100, + guidance_scale=7.5, + width=512, + height=512, + ) + (image,) = images + assert image is not None + assert image.size == (512, 512) @pytest.mark.benchmark(group=PyTestGroup.HUB) @@ -47,10 +49,13 @@ def test_stable_diffusion_download_all(): @skip_if_no_torch_cuda @pytest.mark.benchmark(group=PyTestGroup.MODEL_BENCHMARK) +@pytest.mark.parametrize("model", STABLE_DIFFUSION_MODELS) def test_stable_diffusion_benchmark(model): """Benchmark StableDiffusion model.""" from nos.test.benchmark import run_benchmark + model = StableDiffusion(model_name=model, scheduler="ddim") + steps = 10 time_ms = run_benchmark( lambda: model.__call__( @@ -61,4 +66,4 @@ def test_stable_diffusion_benchmark(model): ), num_iters=5, ) - print(f"BENCHMARK [{MODEL_NAME}]: {time_ms / steps:.2f} ms / step") + print(f"BENCHMARK [{model}]: {time_ms / steps:.2f} ms / step")