Skip to content

Commit

Permalink
feat(api): add source image filters for controlnet and others
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 14, 2023
1 parent bd99239 commit 80d00e4
Show file tree
Hide file tree
Showing 11 changed files with 722 additions and 112 deletions.
18 changes: 16 additions & 2 deletions api/onnx_web/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,32 @@
)
from .diffusers.stub_scheduler import StubScheduler
from .diffusers.upscale import run_upscale_correction
from .image import (
from .image.utils import (
expand_image,
valid_image,
)
from .image.mask_filter import (
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
mask_filter_none,
)
from .image.noise_source import (
noise_source_fill_edge,
noise_source_fill_mask,
noise_source_gaussian,
noise_source_histogram,
noise_source_normal,
noise_source_uniform,
valid_image,
)
from .image.source_filter import (
source_filter_canny,
source_filter_depth,
source_filter_hed,
source_filter_mlsd,
source_filter_normal,
source_filter_pose,
source_filter_scribble,
source_filter_segment,
)
from .onnx import OnnxRRDBNet, OnnxTensor
from .params import (
Expand Down
10 changes: 9 additions & 1 deletion api/onnx_web/diffusers/run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from logging import getLogger
from typing import Any, List
from typing import Any, List, Optional

import numpy as np
import torch
Expand All @@ -19,6 +19,7 @@
UpscaleParams,
)
from ..server import ServerContext
from ..server.load import get_source_filters
from ..utils import run_gc
from ..worker import WorkerContext
from .load import get_latents_from_seed, load_pipeline
Expand Down Expand Up @@ -222,11 +223,16 @@ def run_img2img_pipeline(
upscale: UpscaleParams,
source: Image.Image,
strength: float,
source_filter: Optional[str] = None,
) -> None:
(prompt, loras) = get_loras_from_prompt(params.prompt)
(prompt, inversions) = get_inversions_from_prompt(prompt)
params.prompt = prompt

# filter the source image
if source_filter is not None:
source = get_source_filters(source_filter)(source)

pipe = load_pipeline(
server,
params.pipeline, # this is one of the only places this can actually vary between different pipelines
Expand All @@ -243,6 +249,8 @@ def run_img2img_pipeline(
pipe_params["controlnet_conditioning_scale"] = strength
elif params.pipeline == "img2img":
pipe_params["strength"] = strength
elif params.pipeline == "pix2pix":
pipe_params["image_guidance_scale"] = strength

progress = job.get_progress_callback()
if params.lpw():
Expand Down
30 changes: 30 additions & 0 deletions api/onnx_web/image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from .utils import (
expand_image,
valid_image,
)
from .mask_filter import (
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
mask_filter_none,
)
from .noise_source import (
noise_source_fill_edge,
noise_source_fill_mask,
noise_source_gaussian,
noise_source_histogram,
noise_source_normal,
noise_source_uniform,
)
from .source_filter import (
source_filter_canny,
source_filter_depth,
source_filter_face,
source_filter_gaussian,
source_filter_hed,
source_filter_mlsd,
source_filter_noise,
source_filter_normal,
source_filter_openpose,
source_filter_scribble,
source_filter_segment,
)
167 changes: 167 additions & 0 deletions api/onnx_web/image/laion_face.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# from https://github.com/ForserX/StableDiffusionUI/blob/main/data/repo/diffusion_scripts/modules/controlnet/laion_face_common.py

from typing import Mapping

import mediapipe as mp
import numpy

# Short aliases for the mediapipe solution modules and connection sets.
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_face_detection = mp.solutions.face_detection  # Only for counting faces.
mp_face_mesh = mp.solutions.face_mesh
mp_face_connections = mp.solutions.face_mesh_connections.FACEMESH_TESSELATION
mp_hand_connections = mp.solutions.hands_connections.HAND_CONNECTIONS
mp_body_connections = mp.solutions.pose_connections.POSE_CONNECTIONS

DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
PoseLandmark = mp.solutions.drawing_styles.PoseLandmark

# Threshold applied to the shorter side of a detected face's bounding box.
min_face_size_pixels: int = 64

# Stroke settings shared by every face drawing spec below; only the colour varies.
f_thick = 2
f_rad = 1
_face_stroke = {"thickness": f_thick, "circle_radius": f_rad}

right_iris_draw = DrawingSpec(color=(10, 200, 250), **_face_stroke)
right_eye_draw = DrawingSpec(color=(10, 200, 180), **_face_stroke)
right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), **_face_stroke)
left_iris_draw = DrawingSpec(color=(250, 200, 10), **_face_stroke)
left_eye_draw = DrawingSpec(color=(180, 200, 10), **_face_stroke)
left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), **_face_stroke)
mouth_draw = DrawingSpec(color=(10, 180, 10), **_face_stroke)
head_draw = DrawingSpec(color=(10, 200, 10), **_face_stroke)

# mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
# Map every connection (edge) of the listed feature sets to the spec used to draw it.
# The sets are applied in this order, so a later set wins if an edge appears twice.
# The LEFT_IRIS / RIGHT_IRIS connection sets are deliberately not included here.
face_connection_spec = {}
for _edge_set, _edge_style in (
    (mp_face_mesh.FACEMESH_FACE_OVAL, head_draw),
    (mp_face_mesh.FACEMESH_LEFT_EYE, left_eye_draw),
    (mp_face_mesh.FACEMESH_LEFT_EYEBROW, left_eyebrow_draw),
    (mp_face_mesh.FACEMESH_RIGHT_EYE, right_eye_draw),
    (mp_face_mesh.FACEMESH_RIGHT_EYEBROW, right_eyebrow_draw),
    (mp_face_mesh.FACEMESH_LIPS, mouth_draw),
):
    for edge in _edge_set:
        face_connection_spec[edge] = _edge_style

# Landmark index -> spec for the two landmarks painted as pupils.
iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}


def draw_pupils(image, landmark_list, drawing_spec, halfwidth: int = 2):
    """Paint a filled square over each selected landmark of `image`, in place.

    We have a custom function to draw the pupils because the mp.draw_landmarks
    method requires a parameter for all landmarks. Until our PR is merged into
    mediapipe, we need this separate method.

    Args:
        image: H x W x 3 array (validated below); it is modified in place.
        landmark_list: object whose `.landmark` sequence is iterated. Each entry
            must provide `x`, `y` and `HasField`; `visibility` and `presence`
            are only read when `HasField` reports them.
        drawing_spec: either a mapping from landmark index to an object with a
            `.color`, or a single DrawingSpec whose colour is used for every
            landmark. With a mapping, indices without an entry are skipped.
        halfwidth: half the side length of the painted square, in pixels.

    Raises:
        ValueError: if `image` is not three-dimensional or does not have
            exactly three channels.
    """
    if len(image.shape) != 3:
        raise ValueError("Input image must be H,W,C.")
    image_rows, image_cols, image_channels = image.shape
    if image_channels != 3:  # BGR channels
        raise ValueError("Input image must contain three channel bgr data.")

    for idx, landmark in enumerate(landmark_list.landmark):
        # Skip landmarks that are reported as poorly visible or not present.
        if (landmark.HasField("visibility") and landmark.visibility < 0.9) or (
            landmark.HasField("presence") and landmark.presence < 0.5
        ):
            continue
        # Coordinates are treated as fractions of the image size; skip anything
        # outside [0, 1).
        if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
            continue

        image_x = int(image_cols * landmark.x)
        image_y = int(image_rows * landmark.y)

        draw_color = None
        if isinstance(drawing_spec, Mapping):
            spec = drawing_spec.get(idx)
            if spec is None:
                continue
            draw_color = spec.color
        elif isinstance(drawing_spec, DrawingSpec):
            draw_color = drawing_spec.color

        # Clamp the start of each slice to 0. A negative start would be read by
        # numpy as an index from the end, which yields an empty slice and
        # silently drops pupils that sit within `halfwidth` of the top or left
        # edge. The stop index needs no clamp: slicing past the end is clipped.
        top = max(image_y - halfwidth, 0)
        left = max(image_x - halfwidth, 0)
        image[
            top : image_y + halfwidth,
            left : image_x + halfwidth,
            :,
        ] = draw_color


def reverse_channels(image):
    """Swap channel order along the third axis: RGB becomes BGR and BGR becomes RGB.

    The result comes from basic slicing with a negative step, so for a numpy array
    it is a reversed view of `image` rather than a copy.
    """
    # Fancy indexing such as im[:, :, [2, 1, 0]] gives the same values but copies the
    # data; a step of -1 on the channel axis avoids that.
    flip_third_axis = (slice(None), slice(None), slice(None, None, -1))
    return image[flip_third_axis]


def generate_annotation(img_rgb, max_faces: int, min_confidence: float):
    """
    Detect up to `max_faces` faces in `img_rgb` and draw their outlines on a black image.

    When the module-level `min_face_size_pixels` is greater than zero, a face is only
    kept if the shorter side of its landmark bounding box, in pixels, reaches that value.
    The returned array has the same shape and dtype as `img_rgb`; it is all zeros when
    the detector reports no faces.
    """
    with mp_face_mesh.FaceMesh(
        static_image_mode=True,
        max_num_faces=max_faces,
        refine_landmarks=True,
        min_detection_confidence=min_confidence,
    ) as detector:
        img_height, img_width, img_channels = img_rgb.shape
        assert img_channels == 3

        detected = detector.process(img_rgb).multi_face_landmarks
        if detected is None:
            print("No faces detected in controlnet image for Mediapipe face annotator.")
            return numpy.zeros_like(img_rgb)

        # Keep only the faces that are large enough.
        kept_faces = []
        for face in detected:
            points = face.landmark

            # Bounding box of all landmarks, in the landmarks' own coordinates.
            left = right = points[0].x
            top = bottom = points[0].y
            for point in points:
                left = min(left, point.x)
                top = min(top, point.y)
                right = max(right, point.x)
                bottom = max(bottom, point.y)

            if min_face_size_pixels <= 0:
                # Size filtering is switched off: every detected face is kept.
                kept_faces.append(face)
                continue

            # Scale the box by the image size to get pixels.
            width_px = abs(right - left) * img_width
            height_px = abs(bottom - top) * img_height
            if min(width_px, height_px) >= min_face_size_pixels:
                kept_faces.append(face)

        # Annotations are drawn in BGR for some reason, but a zero-filled canvas looks
        # the same in either channel order, so no flip is needed before drawing.
        canvas = numpy.zeros_like(img_rgb)

        for face in kept_faces:
            mp_drawing.draw_landmarks(
                canvas,
                face,
                connections=face_connection_spec.keys(),
                landmark_drawing_spec=None,
                connection_drawing_spec=face_connection_spec,
            )
            draw_pupils(canvas, face, iris_landmark_spec, 2)

        # Flip BGR back to RGB; copy so the caller gets an array, not a reversed view.
        return reverse_channels(canvas).copy()
44 changes: 44 additions & 0 deletions api/onnx_web/image/mask_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from PIL import Image, ImageChops, ImageFilter

from .params import Point


def mask_filter_none(
    mask: Image.Image, dims: Point, origin: Point, fill="white", **kw
) -> Image.Image:
    """
    Paste the mask onto a new RGB canvas without any further filtering.

    The canvas has size `dims`, is filled with `fill`, and receives `mask` at `origin`.
    Extra keyword arguments are accepted and ignored.
    """
    canvas_width, canvas_height = dims
    canvas = Image.new("RGB", (canvas_width, canvas_height), fill)
    canvas.paste(mask, origin)
    return canvas


def mask_filter_gaussian_multiply(
    mask: Image.Image, dims: Point, origin: Point, rounds=3, **kw
) -> Image.Image:
    """
    Gaussian blur with multiply, mask pasted at `origin` on a white canvas.

    For each of `rounds` passes the canvas is blurred (radius 5) and the blurred copy
    is multiplied back into the canvas.
    """
    canvas = mask_filter_none(mask, dims, origin)

    for _ in range(rounds):
        canvas = ImageChops.multiply(
            canvas, canvas.filter(ImageFilter.GaussianBlur(5))
        )

    return canvas


def mask_filter_gaussian_screen(
    mask: Image.Image, dims: Point, origin: Point, rounds=3, **kw
) -> Image.Image:
    """
    Gaussian blur with screen, mask pasted at `origin` on a white canvas.

    For each of `rounds` passes the canvas is blurred (radius 5) and the blurred copy
    is screened back onto the canvas.
    """
    canvas = mask_filter_none(mask, dims, origin)

    for _ in range(rounds):
        canvas = ImageChops.screen(
            canvas, canvas.filter(ImageFilter.GaussianBlur(5))
        )

    return canvas
Loading

0 comments on commit 80d00e4

Please sign in to comment.