Skip to content

Commit

Permalink
Parametrize not sweep
Browse files Browse the repository at this point in the history
  • Loading branch information
outtanames committed Sep 14, 2023
1 parent b40b218 commit ecb0e65
Showing 1 changed file with 22 additions and 17 deletions.
39 changes: 22 additions & 17 deletions tests/models/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

import pytest

from nos.test.utils import PyTestGroup, skip_if_no_torch_cuda
from nos.models.stable_diffusion import StableDiffusion
from nos.test.utils import PyTestGroup, skip_if_no_torch_cuda


STABLE_DIFFUSION_MODELS = StableDiffusion.configs.keys()


@pytest.mark.benchmark(group=PyTestGroup.HUB)
def test_stable_diffusion_predict():
@pytest.mark.parametrize("model", STABLE_DIFFUSION_MODELS)
def test_stable_diffusion_predict(model):
"""Use StableDiffusion to generate an image from a text prompt.
Note: This test should be able to run with CPU or GPU.
Expand All @@ -17,20 +21,18 @@ def test_stable_diffusion_predict():
"""
from PIL import Image

for config in StableDiffusion.configs.values():
import pdb; pdb.set_trace()
model = StableDiffusion(model_name=config.model_name, scheduler="ddim")
images: List[Image.Image] = model.__call__(
"astronaut on a horse on the moon",
num_images=1,
num_inference_steps=100,
guidance_scale=7.5,
width=512,
height=512,
)
(image,) = images
assert image is not None
assert image.size == (512, 512)
model = StableDiffusion(model_name=model, scheduler="ddim")
images: List[Image.Image] = model.__call__(
"astronaut on a horse on the moon",
num_images=1,
num_inference_steps=100,
guidance_scale=7.5,
width=512,
height=512,
)
(image,) = images
assert image is not None
assert image.size == (512, 512)


@pytest.mark.benchmark(group=PyTestGroup.HUB)
Expand All @@ -47,10 +49,13 @@ def test_stable_diffusion_download_all():

@skip_if_no_torch_cuda
@pytest.mark.benchmark(group=PyTestGroup.MODEL_BENCHMARK)
@pytest.mark.parametrize("model", STABLE_DIFFUSION_MODELS)
def test_stable_diffusion_benchmark(model):
"""Benchmark StableDiffusion model."""
from nos.test.benchmark import run_benchmark

model = StableDiffusion(model_name=model, scheduler="ddim")

steps = 10
time_ms = run_benchmark(
lambda: model.__call__(
Expand All @@ -61,4 +66,4 @@ def test_stable_diffusion_benchmark(model):
),
num_iters=5,
)
print(f"BENCHMARK [{MODEL_NAME}]: {time_ms / steps:.2f} ms / step")
print(f"BENCHMARK [{model}]: {time_ms / steps:.2f} ms / step")

0 comments on commit ecb0e65

Please sign in to comment.