sd_runner.py
import os
import time
from typing import List
import torch
from diffusers import (
StableDiffusionPipeline,
PNDMScheduler,
LMSDiscreteScheduler,
DDIMScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
# Currently unused: only needed if an explicit attention_op is passed to
# enable_xformers_memory_efficient_attention (see the commented-out VAE line in setup).
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

MODEL_ID = "stabilityai/stable-diffusion-2-1"
MODEL_CACHE = "diffusers-cache"
SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"


class Predictor:
    """A predictor class that loads the model into memory and runs predictions."""
    def __init__(self, model_id):
        self.model_id = model_id
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient."""
        start_time = time.time()
        print("Loading pipeline...")
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
SAFETY_MODEL_ID,
cache_dir=MODEL_CACHE,
local_files_only=True,
)
self.pipe = StableDiffusionPipeline.from_pretrained(
self.model_id,
# safety_checker=safety_checker,
safety_checker=None,
cache_dir=MODEL_CACHE,
local_files_only=True,
).to(self.device)
self.pipe.enable_xformers_memory_efficient_attention()
# self.pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
end_time = time.time()
print(f"setup time: {end_time - start_time}")

    @torch.inference_mode()
    def predict(
        self,
        prompt,
        negative_prompt,
        width,
        height,
        num_outputs,
        num_inference_steps,
        guidance_scale,
        scheduler,
        seed,
    ):
"""Run a single prediction on the model"""
start_time = time.time()
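        # Default to a random 16-bit seed (0-65535) when none is provided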
if seed is None:
seed = int.from_bytes(os.urandom(2), "big")
print(f"Using seed: {seed}")
if width * height > 786432:
raise ValueError(
"Maximum size is 1024x768 or 768x1024 pixels, because of memory limits. Please select a lower width or height."
)
self.pipe.scheduler = make_scheduler(scheduler, self.pipe.scheduler.config)
generator = torch.Generator(self.device).manual_seed(seed)
output = self.pipe(
prompt=[prompt] * num_outputs if prompt is not None else None,
negative_prompt=[negative_prompt] * num_outputs
if negative_prompt is not None
else None,
width=width,
height=height,
guidance_scale=guidance_scale,
generator=generator,
num_inference_steps=num_inference_steps,
)
output_paths = []
for i, sample in enumerate(output.images):
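            # With the safety checker disabled in setup, nsfw_content_detected is None
            # and this guard is a no-op; it only filters if the checker is re-enabled.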
if output.nsfw_content_detected and output.nsfw_content_detected[i]:
continue
output_path = f"/tmp/out-{i}.png"
sample.save(output_path)
output_paths.append(output_path)
if len(output_paths) == 0:
raise Exception(
"NSFW content detected. Try running it again, or try a different prompt.")
end_time = time.time()
print(f"inference took {end_time - start_time} time")
return output_paths


def make_scheduler(name, config):
    """Return a scheduler instance for the given name, built from an existing config."""
    # Map names to scheduler classes and instantiate only the one requested,
    # rather than building all six on every call.
    return {
        "PNDM": PNDMScheduler,
        "KLMS": LMSDiscreteScheduler,
        "DDIM": DDIMScheduler,
        "K_EULER": EulerDiscreteScheduler,
        "K_EULER_ANCESTRAL": EulerAncestralDiscreteScheduler,
        "DPMSolverMultistep": DPMSolverMultistepScheduler,
    }[name].from_config(config)
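

# --- Usage sketch (not part of the original file) ---
# A minimal example of how this Predictor might be driven, assuming the diffusers
# caches for MODEL_ID and SAFETY_MODEL_ID already exist under MODEL_CACHE (the
# pipeline is loaded with local_files_only=True). The prompt and parameter
# values below are illustrative only, not prescribed by the code above.
if __name__ == "__main__":
    predictor = Predictor(MODEL_ID)
    predictor.setup()
    paths = predictor.predict(
        prompt="an astronaut riding a horse on mars, photorealistic",
        negative_prompt=None,
        width=768,
        height=768,
        num_outputs=1,
        num_inference_steps=25,
        guidance_scale=7.5,
        scheduler="DPMSolverMultistep",
        seed=None,
    )
    print(paths)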