# predict.py (forked from replicate/cog-stable-diffusion)
import os
from typing import List

import torch
from diffusers import PNDMScheduler, LMSDiscreteScheduler
from PIL import Image
from cog import BasePredictor, Input, Path

from image_to_image import (
    StableDiffusionImg2ImgPipeline,
    preprocess_init_image,
    preprocess_mask,
)

MODEL_CACHE = "diffusers-cache"
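
# NOTE: setup() loads with local_files_only=True, so the fp16 weights must
# already exist under MODEL_CACHE. A minimal sketch of a one-time download
# step (a hypothetical script, not part of this file; gated models may also
# require a Hugging Face auth token):
#
#   from image_to_image import StableDiffusionImg2ImgPipeline
#   StableDiffusionImg2ImgPipeline.from_pretrained(
#       "CompVis/stable-diffusion-v1-4",
#       revision="fp16",
#       torch_dtype=torch.float16,
#       cache_dir="diffusers-cache",
#   )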


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        print("Loading pipeline...")
        scheduler = PNDMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
        )
        # load fp16 weights from the local cache only; see the note on MODEL_CACHE above
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            scheduler=scheduler,
            revision="fp16",
            torch_dtype=torch.float16,
            cache_dir=MODEL_CACHE,
            local_files_only=True,
        ).to("cuda")
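
    # predict() runs under torch.inference_mode (no autograd tracking) and
    # CUDA autocast, so eligible ops run in fp16 to reduce memory use.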
    @torch.inference_mode()
    @torch.cuda.amp.autocast()
    def predict(
        self,
        prompt: str = Input(description="Input prompt", default=""),
        width: int = Input(
            description="Width of output image",
            choices=[128, 256, 512, 768, 1024],
            default=512,
        ),
        height: int = Input(
            description="Height of output image",
            choices=[128, 256, 512, 768],
            default=512,
        ),
        init_image: Path = Input(
            description="Initial image to generate variations of. Will be resized to the specified width and height",
            default=None,
        ),
        mask: Path = Input(
            description="Black and white image to use as mask for inpainting over init_image. Black pixels are inpainted and white pixels are preserved. Experimental feature, tends to work better with prompt strength of 0.5-0.7",
            default=None,
        ),
        prompt_strength: float = Input(
            description="Prompt strength when using init image. 1.0 corresponds to full destruction of information in init image",
            default=0.8,
        ),
        num_outputs: int = Input(
            description="Number of images to output", choices=[1, 4], default=1
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=7.5
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> List[Path]:
"""Run a single prediction on the model"""
if seed is None:
seed = int.from_bytes(os.urandom(2), "big")
print(f"Using seed: {seed}")
if init_image:
init_image = Image.open(init_image).convert("RGB")
init_image = preprocess_init_image(init_image, width, height).to("cuda")
# use PNDM with init images
scheduler = PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
else:
# use LMS without init images
scheduler = LMSDiscreteScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
self.pipe.scheduler = scheduler
if mask:
mask = Image.open(mask).convert("RGB")
mask = preprocess_mask(mask, width, height).to("cuda")
generator = torch.Generator("cuda").manual_seed(seed)
output = self.pipe(
prompt=[prompt] * num_outputs if prompt is not None else None,
init_image=init_image,
mask=mask,
width=width,
height=height,
prompt_strength=prompt_strength,
guidance_scale=guidance_scale,
generator=generator,
num_inference_steps=num_inference_steps,
)
if any(output["nsfw_content_detected"]):
raise Exception("NSFW content detected, please try a different prompt")
output_paths = []
for i, sample in enumerate(output["sample"]):
output_path = f"/tmp/out-{i}.png"
sample.save(output_path)
output_paths.append(Path(output_path))
return output_paths
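

# A rough sketch of how this predictor could be invoked once built with Cog
# (inputs are illustrative; any Input name defined above can be passed with -i):
#
#   cog predict \
#       -i prompt="a photo of an astronaut riding a horse" \
#       -i num_inference_steps=50 \
#       -i guidance_scale=7.5
#
# Supplying -i init_image=@photo.png switches to the PNDM image-to-image path,
# and -i mask=@mask.png additionally enables the experimental inpainting mode.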