Skip to content

Commit

Permalink
Sweep all stable diffusion models during testing
Browse files Browse the repository at this point in the history
  • Loading branch information
outtanames committed Sep 14, 2023
1 parent c222d92 commit b40b218
Showing 1 changed file with 16 additions and 23 deletions.
39 changes: 16 additions & 23 deletions tests/models/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,11 @@
import pytest

from nos.test.utils import PyTestGroup, skip_if_no_torch_cuda


MODEL_NAME = "runwayml/stable-diffusion-v1-5"


@pytest.fixture(scope="module")
def model():
from nos.models import StableDiffusion # noqa: F401

# TODO (spillai): @pytest.parametrize("scheduler", ["ddim", "euler-discrete"])
yield StableDiffusion(model_name=MODEL_NAME, scheduler="ddim")
from nos.models.stable_diffusion import StableDiffusion


@pytest.mark.benchmark(group=PyTestGroup.HUB)
def test_stable_diffusion_predict(model):
def test_stable_diffusion_predict():
"""Use StableDiffusion to generate an image from a text prompt.
Note: This test should be able to run with CPU or GPU.
Expand All @@ -27,17 +17,20 @@ def test_stable_diffusion_predict(model):
"""
from PIL import Image

images: List[Image.Image] = model.__call__(
"astronaut on a horse on the moon",
num_images=1,
num_inference_steps=100,
guidance_scale=7.5,
width=512,
height=512,
)
(image,) = images
assert image is not None
assert image.size == (512, 512)
for config in StableDiffusion.configs.values():
import pdb; pdb.set_trace()
model = StableDiffusion(model_name=config.model_name, scheduler="ddim")
images: List[Image.Image] = model.__call__(
"astronaut on a horse on the moon",
num_images=1,
num_inference_steps=100,
guidance_scale=7.5,
width=512,
height=512,
)
(image,) = images
assert image is not None
assert image.size == (512, 512)


@pytest.mark.benchmark(group=PyTestGroup.HUB)
Expand Down

0 comments on commit b40b218

Please sign in to comment.