-
Notifications
You must be signed in to change notification settings - Fork 6
/
sd_ndi.py
204 lines (167 loc) · 6.39 KB
/
sd_ndi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline
from diffusers.utils import load_image
from streamdiffusion import StreamDiffusion
from streamdiffusion.image_utils import postprocess_image
from streamdiffusion.acceleration.tensorrt import accelerate_with_tensorrt
import time
import numpy as np
import cv2 as cv
import NDIlib as ndi
from pythonosc import udp_client
from pythonosc import osc_server
from pythonosc.dispatcher import Dispatcher
from threading import Thread
from typing import List, Any, Tuple
import json
def process_image(image_np: np.ndarray, range: Tuple[int, int] = (-1, 1)) -> Tuple[torch.Tensor, np.ndarray]:
    """Convert an HWC image array into a batched CHW float tensor.

    The array is divided by 255 and then mapped linearly into ``range``:
    an input value of 0 becomes ``range[0]`` and 255 becomes ``range[1]``.

    Args:
        image_np: Image array indexed as (height, width, channels).
            Presumably 8-bit data in 0-255, which is what the 255 divisor
            assumes — confirm against the caller.
        range: ``(low, high)`` bounds of the output interval. The name
            shadows the builtin ``range`` inside this function.

    Returns:
        ``(tensor, image_np)`` where ``tensor`` has shape
        (1, channels, height, width) and ``image_np`` is the unmodified
        input array.
    """
    channels_first = torch.from_numpy(image_np).permute(2, 0, 1)
    unit_scaled = channels_first.float() / 255.0
    low, high = range[0], range[1]
    mapped = unit_scaled * (high - low) + low
    return mapped.unsqueeze(0), image_np
def np2tensor(image_np: np.ndarray) -> torch.Tensor:
    """Turn one HWC image array into a half-precision batch of size 1.

    The frame is normalized by ``process_image`` (default range), stacked
    into a batch, run through a bilinear resize whose target is the frame's
    own (height, width), and finally cast to ``torch.float16``.

    Args:
        image_np: Image array with exactly three dimensions
            (height, width, channels).

    Returns:
        A ``torch.float16`` tensor of shape (1, channels, height, width).
    """
    rows, cols, _ = image_np.shape
    normalized, _ = process_image(image_np)
    batch = torch.vstack([normalized])
    # The target size equals the input size, so the spatial shape is unchanged.
    resized = torch.nn.functional.interpolate(
        batch, size=(rows, cols), mode="bilinear", align_corners=False
    )
    return resized.to(torch.float16)
def oscprompt(address: str, *args: Any) -> None:
    """OSC handler that stores the most recent prompt.

    Writes the first OSC argument to the module-level ``shared_message``
    when the message address is ``/prompt``. Messages with any other
    address are ignored, and so is a ``/prompt`` message that carries no
    arguments (instead of raising ``IndexError`` inside the server).

    Args:
        address: OSC address of the received message.
        *args: OSC arguments of the message; only the first one is used.
    """
    global shared_message
    if address != "/prompt" or not args:
        return
    shared_message = args[0]
def load_config(file_path):
    """Read a JSON configuration file and return its parsed content.

    The file is decoded as UTF-8 explicitly, so the result does not depend
    on the platform's default text encoding.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file (a ``dict`` when the
        document root is a JSON object).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)
# Load settings from config.json (path is relative to the working directory)
config_data = load_config('config.json')

# Diffusion model and TensorRT engine settings
sd_model, t_index_list, engine = (
    config_data['sd_model'],
    config_data['t_index_list'],
    config_data['engine'],
)
min_batch_size, max_batch_size = (
    config_data['min_batch_size'],
    config_data['max_batch_size'],
)

# Name of the NDI source to connect to
ndi_name = config_data['ndi_name']

# OSC endpoints (the key spelling "adress" is part of the config schema)
osc_out_adress, osc_out_port = (
    config_data['osc_out_adress'],
    config_data['osc_out_port'],
)
osc_in_adress, osc_in_port = (
    config_data['osc_in_adress'],
    config_data['osc_in_port'],
)

# Echo the loaded configuration
print(config_data)
# Build the pipeline (any model that diffusers' StableDiffusionPipeline can
# load works) and move it to the GPU in half precision
pipe = StableDiffusionPipeline.from_pretrained(sd_model)
pipe = pipe.to(device=torch.device("cuda"), dtype=torch.float16)

# Number of frames batched into one stream call by the main loop
frame_buffer_size = 1

# Wrap the pipeline in StreamDiffusion
stream = StreamDiffusion(
    pipe,
    t_index_list=t_index_list,
    torch_dtype=torch.float16,
    frame_buffer_size=frame_buffer_size,
)

# Load the LCM LoRA and fuse it into the model (done unconditionally here)
stream.load_lcm_lora()
stream.fuse_lora()

# Use Tiny VAE for further acceleration
tiny_vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")
stream.vae = tiny_vae.to(device=pipe.device, dtype=pipe.dtype)

# Enable TensorRT acceleration
stream = accelerate_with_tensorrt(
    stream,
    engine,
    min_batch_size=min_batch_size,
    max_batch_size=max_batch_size,
)

# Prepare the stream with the initial prompt
prompt = "banana in space"
stream.prepare(prompt)
# NDI
# Start the NDI runtime explicitly and fail early with a clear message;
# the script calls ndi.destroy() on exit, so pair it with an initialize().
if not ndi.initialize():
    raise RuntimeError('NDI: Failed to initialize the NDI runtime.')

ndi_find = ndi.find_create_v2()
if ndi_find is None:
    raise RuntimeError('NDI: Failed to create the source finder.')

# Poll the network until a source whose name equals ndi_name shows up
source = None
while source is None:
    if not ndi.find_wait_for_sources(ndi_find, 5000):
        print('NDI: No change to the sources found.')
        continue
    sources = ndi.find_get_current_sources(ndi_find)
    print(f'NDI: Network sources ({len(sources)} found).')
    for number, candidate in enumerate(sources, start=1):
        print(f'{number}. {candidate.ndi_name}')
        if candidate.ndi_name == ndi_name:
            source = candidate
print(f'NDI: Connected to {source.ndi_name}')

# Receiver: ask for BGRX/BGRA pixel data and connect it to the chosen source
ndi_recv_create = ndi.RecvCreateV3()
ndi_recv_create.color_format = ndi.RECV_COLOR_FORMAT_BGRX_BGRA
ndi_recv = ndi.recv_create_v3(ndi_recv_create)
if ndi_recv is None:
    raise RuntimeError('NDI: Failed to create the receiver.')
ndi.recv_connect(ndi_recv, source)
# The finder is only released after the receiver has been connected
ndi.find_destroy(ndi_find)

# Sender: publishes the processed frames under the name 'SD-NDI'
send_settings = ndi.SendCreate()
send_settings.ndi_name = 'SD-NDI'
ndi_send = ndi.send_create(send_settings)
if ndi_send is None:
    raise RuntimeError('NDI: Failed to create the sender.')
# Reusable frame object that the main loop fills before each send
video_frame = ndi.VideoFrameV2()
# OSC
# Outgoing client, bound to the configured "out" address and port
client = udp_client.SimpleUDPClient(osc_out_adress, osc_out_port)

# Incoming server endpoint, taken from the configured "in" address and port
server_address, server_port = osc_in_adress, osc_in_port

# Slot written by the "/prompt" handler; None means no prompt is pending
shared_message = None

dispatcher = Dispatcher()
dispatcher.map("/prompt", oscprompt)
server = osc_server.ThreadingOSCUDPServer((server_address, server_port), dispatcher)

# Serve incoming OSC messages on a separate thread
server_thread = Thread(target=server.serve_forever)
server_thread.start()
# Run the stream until interrupted (Ctrl+C)
# Frames waiting to be batched. The list lives outside the loop so that a
# frame_buffer_size greater than 1 can actually fill up across iterations.
inputs = []
try:
    while True:
        # Take over a prompt delivered by the OSC handler. The slot is
        # cleared right after reading, before the slow stream.prepare(),
        # so a prompt arriving meanwhile is picked up on the next pass.
        pending_prompt = shared_message
        if pending_prompt is not None:
            shared_message = None
            prompt = str(pending_prompt)
            stream.prepare(prompt)
            print(f"Prompt: {prompt}")

        frame_type, video, audio, metadata = ndi.recv_capture_v2(ndi_recv, 5000)

        # Captured frames must be handed back to the receiver, including
        # the audio/metadata frames this script has no use for.
        if frame_type == ndi.FRAME_TYPE_AUDIO:
            ndi.recv_free_audio_v2(ndi_recv, audio)
            continue
        if frame_type == ndi.FRAME_TYPE_METADATA:
            ndi.recv_free_metadata(ndi_recv, metadata)
            continue
        if frame_type != ndi.FRAME_TYPE_VIDEO:
            continue

        # Copy the pixels, then release the NDI buffer immediately so it is
        # freed exactly once per frame on every code path.
        frame = np.copy(video.data)
        ndi.recv_free_video_v2(ndi_recv, video)

        # The receiver delivers BGRA, while the output path below treats
        # model images as RGB, so feed the model RGB as well.
        frame_rgb = cv.cvtColor(frame, cv.COLOR_BGRA2RGB)
        inputs.append(np2tensor(frame_rgb))
        if len(inputs) < frame_buffer_size:
            time.sleep(0.005)
            continue

        # perf_counter is monotonic and high resolution, unlike time.time()
        start_time = time.perf_counter()

        # Pick frame_buffer_size evenly spaced frames, newest first
        stride = len(inputs) // frame_buffer_size
        sampled_inputs = [
            inputs[len(inputs) - stride * i - 1]
            for i in range(frame_buffer_size)
        ]
        input_batch = torch.cat(sampled_inputs)
        inputs.clear()

        output_images = stream(
            input_batch.to(device=stream.device, dtype=stream.dtype)
        ).cpu()
        if frame_buffer_size == 1:
            output_images = [output_images]

        for output_image in output_images:
            output_image = postprocess_image(output_image, output_type="np")[0]
            open_cv_image = (output_image * 255).round().astype("uint8")
            # The frame is tagged BGRX, so reorder RGB -> BGRA to match it
            video_frame.data = cv.cvtColor(open_cv_image, cv.COLOR_RGB2BGRA)
            video_frame.FourCC = ndi.FOURCC_VIDEO_TYPE_BGRX
            ndi.send_send_video_v2(ndi_send, video_frame)

        # Report the processing rate; skip the report rather than divide by
        # zero if the clock did not advance.
        elapsed = time.perf_counter() - start_time
        if elapsed > 0:
            fps = 1 / elapsed
            client.send_message("/fps", fps)
except KeyboardInterrupt:
    # Handle KeyboardInterrupt (Ctrl+C)
    print("KeyboardInterrupt: Stopping the server")
finally:
    # Release the NDI objects, and stop the OSC server even if that fails:
    # its thread is not a daemon and would otherwise keep the process alive.
    try:
        ndi.recv_destroy(ndi_recv)
        ndi.send_destroy(ndi_send)
        ndi.destroy()
    finally:
        server.shutdown()
        server.server_close()
        server_thread.join()