-
Notifications
You must be signed in to change notification settings - Fork 3
/
runpod_infer.py
137 lines (116 loc) · 4.18 KB
/
runpod_infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
'''
RunPod | serverless-ckpt-template | runpod_infer.py
Entry point for job requests from RunPod serverless platform.
'''
import os
import argparse
import sd_runner
import runpod
from runpod.serverless.utils import rp_download, rp_cleanup, rp_upload
from runpod.serverless.utils.rp_validator import validate
# Side lengths (pixels) accepted for both the 'width' and 'height' inputs.
_ALLOWED_DIMENSIONS = (128, 256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024)

# Scheduler names accepted by the 'scheduler' input.
_ALLOWED_SCHEDULERS = ('DDIM', 'K_EULER', 'DPMSolverMultistep', 'K_EULER_ANCESTRAL', 'PNDM', 'KLMS')

# Schema passed to rp_validator.validate() by handler(). Every field lists its
# `type` and whether it is `required`; optional fields also carry a `default`,
# and some carry a `constraints` predicate on the value.
INPUT_SCHEMA = dict(
    prompt=dict(type=str, required=True),
    negative_prompt=dict(type=str, required=False, default=None),
    width=dict(
        type=int,
        required=False,
        default=768,
        constraints=lambda width: width in _ALLOWED_DIMENSIONS,
    ),
    height=dict(
        type=int,
        required=False,
        default=768,
        constraints=lambda height: height in _ALLOWED_DIMENSIONS,
    ),
    # range(1, 4) admits 1, 2 or 3 images per request.
    num_outputs=dict(
        type=int,
        required=False,
        default=1,
        constraints=lambda num_outputs: num_outputs in range(1, 4),
    ),
    # range(1, 500) admits 1..499 steps.
    num_inference_steps=dict(
        type=int,
        required=False,
        default=50,
        constraints=lambda num_inference_steps: num_inference_steps in range(1, 500),
    ),
    guidance_scale=dict(
        type=float,
        required=False,
        default=7.5,
        constraints=lambda guidance_scale: 0 <= guidance_scale <= 20,
    ),
    scheduler=dict(
        type=str,
        required=False,
        default='DPMSolverMultistep',
        constraints=lambda scheduler: scheduler in _ALLOWED_SCHEDULERS,
    ),
    # NOTE(review): this default is a 16-bit random number drawn ONCE when the
    # module is imported, so it is shared by every job that omits 'seed'.
    seed=dict(
        type=int,
        required=False,
        default=int.from_bytes(os.urandom(2), "big"),
    ),
)
def handler(job):
    '''
    Handle one RunPod serverless job.

    Validates ``job['input']`` against INPUT_SCHEMA, runs the model on the
    validated values, uploads every generated image and describes each one.

    Args:
        job: job payload; this function reads ``job['input']`` and ``job['id']``.

    Returns:
        ``{"errors": ...}`` when validation fails; otherwise a list with one
        dict per generated image holding the uploaded image URL and the
        generation settings that were actually used.
    '''
    job_input = job['input']
    job_output = []

    # -------------------------------- Validation -------------------------------- #
    validated_input = validate(job_input, INPUT_SCHEMA)
    if 'errors' in validated_input:
        return {"errors": validated_input['errors']}
    valid_input = validated_input['validated_input']

    # INPUT_SCHEMA's seed default is computed once at import time, so without
    # this every seedless job on this worker would reuse the same seed.
    seed = valid_input['seed']
    if job_input.get('seed') is None:
        seed = int.from_bytes(os.urandom(2), "big")

    image_paths = model_runner.predict(
        prompt=valid_input['prompt'],
        negative_prompt=valid_input['negative_prompt'],
        width=valid_input['width'],
        height=valid_input['height'],
        num_outputs=valid_input['num_outputs'],
        num_inference_steps=valid_input['num_inference_steps'],
        guidance_scale=valid_input['guidance_scale'],
        scheduler=valid_input['scheduler'],
        seed=seed
    )

    # Report the validated values (schema defaults filled in) rather than the
    # raw request, which may omit any optional field.
    for index, img_path in enumerate(image_paths):
        image_url = rp_upload.upload_image(job['id'], img_path)
        job_output.append({
            "image": image_url,
            "prompt": valid_input["prompt"],
            "negative_prompt": valid_input["negative_prompt"],
            "width": valid_input['width'],
            "height": valid_input['height'],
            "num_inference_steps": valid_input['num_inference_steps'],
            "guidance_scale": valid_input['guidance_scale'],
            "scheduler": valid_input['scheduler'],
            # NOTE(review): assumes the predictor seeds the i-th image with
            # seed + i — confirm against sd_runner.Predictor.predict.
            "seed": seed + index
        })

    return job_output
# ---------------------------------------------------------------------------- #
#                                Main                                          #
# ---------------------------------------------------------------------------- #
def model_id_from_url(model_url):
    '''
    Derive a Hugging Face model id ("owner/name") from a model URL.

    Takes the last two "/"-separated segments of the URL, ignoring any
    trailing slash. Returns None when the URL is missing/empty, does not
    mention huggingface.co, or has fewer than two segments.
    '''
    if not model_url or "huggingface.co" not in model_url:
        return None
    url_parts = model_url.rstrip("/").split("/")
    if len(url_parts) < 2:
        return None
    return f"{url_parts[-2]}/{url_parts[-1]}"


parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--model_url", type=str,
                    default=None, help="Model URL")

if __name__ == "__main__":
    args = parser.parse_args()

    model_id = model_id_from_url(args.model_url)
    if model_id is None:
        # parser.error() prints usage and exits, instead of crashing later with
        # a TypeError (missing URL) or NameError (unbound model_id).
        parser.error("--model_url must be a huggingface.co model URL, "
                     "e.g. https://huggingface.co/<owner>/<model>")

    # Module-level global: handler() looks up model_runner when it is called.
    model_runner = sd_runner.Predictor(model_id)
    model_runner.setup()

    runpod.serverless.start({"handler": handler})