Skip to content

Commit

Permalink
feat(api): add batch size to txt2img and img2img pipelines (#195)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 20, 2023
1 parent 0deaa88 commit 5f3b848
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 50 deletions.
5 changes: 3 additions & 2 deletions api/onnx_web/chain/persist_disk.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from logging import getLogger
from typing import List

from PIL import Image

Expand All @@ -16,12 +17,12 @@ def persist_disk(
_params: ImageParams,
source: Image.Image,
*,
output: str,
output: List[str],
stage_source: Image.Image,
**kwargs,
) -> Image.Image:
source = stage_source or source

dest = save_image(server, output, source)
dest = save_image(server, output[0], source)
logger.info("saved image to %s", dest)
return source
78 changes: 40 additions & 38 deletions api/onnx_web/diffusion/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def run_txt2img_pipeline(
server: ServerContext,
params: ImageParams,
size: Size,
output: str,
outputs: List[str],
upscale: UpscaleParams,
) -> None:
latents = get_latents_from_seed(params.seed, size)
Expand All @@ -50,6 +50,7 @@ def run_txt2img_pipeline(
guidance_scale=params.cfg,
latents=latents,
negative_prompt=params.negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=params.steps,
eta=params.eta,
callback=progress,
Expand All @@ -64,27 +65,27 @@ def run_txt2img_pipeline(
guidance_scale=params.cfg,
latents=latents,
negative_prompt=params.negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=params.steps,
eta=params.eta,
callback=progress,
)

image = result.images[0]
image = run_upscale_correction(
job,
server,
StageParams(),
params,
image,
upscale=upscale,
callback=progress,
)
for image, output in zip(result.images, outputs):
image = run_upscale_correction(
job,
server,
StageParams(),
params,
image,
upscale=upscale,
callback=progress,
)

dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)

del pipe
del image
del result

run_gc([job.get_device()])
Expand All @@ -96,7 +97,7 @@ def run_img2img_pipeline(
job: JobContext,
server: ServerContext,
params: ImageParams,
output: str,
outputs: List[str],
upscale: UpscaleParams,
source: Image.Image,
strength: float,
Expand All @@ -119,6 +120,7 @@ def run_img2img_pipeline(
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=params.steps,
strength=strength,
eta=params.eta,
Expand All @@ -132,29 +134,29 @@ def run_img2img_pipeline(
generator=rng,
guidance_scale=params.cfg,
negative_prompt=params.negative_prompt,
num_images_per_prompt=params.batch,
num_inference_steps=params.steps,
strength=strength,
eta=params.eta,
callback=progress,
)

image = result.images[0]
image = run_upscale_correction(
job,
server,
StageParams(),
params,
image,
upscale=upscale,
callback=progress,
)
for image, output in zip(result.images, outputs):
image = run_upscale_correction(
job,
server,
StageParams(),
params,
image,
upscale=upscale,
callback=progress,
)

dest = save_image(server, output, image)
size = Size(*source.size)
save_params(server, output, params, size, upscale=upscale)
dest = save_image(server, output, image)
size = Size(*source.size)
save_params(server, output, params, size, upscale=upscale)

del pipe
del image
del result

run_gc([job.get_device()])
Expand All @@ -167,7 +169,7 @@ def run_inpaint_pipeline(
server: ServerContext,
params: ImageParams,
size: Size,
output: str,
outputs: List[str],
upscale: UpscaleParams,
source: Image.Image,
mask: Image.Image,
Expand Down Expand Up @@ -202,8 +204,8 @@ def run_inpaint_pipeline(
job, server, stage, params, image, upscale=upscale, callback=progress
)

dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale, border=border)
dest = save_image(server, outputs[0], image)
save_params(server, outputs[0], params, size, upscale=upscale, border=border)

del image

Expand All @@ -217,7 +219,7 @@ def run_upscale_pipeline(
server: ServerContext,
params: ImageParams,
size: Size,
output: str,
outputs: List[str],
upscale: UpscaleParams,
source: Image.Image,
) -> None:
Expand All @@ -228,8 +230,8 @@ def run_upscale_pipeline(
job, server, stage, params, source, upscale=upscale, callback=progress
)

dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)
dest = save_image(server, outputs[0], image)
save_params(server, outputs[0], params, size, upscale=upscale)

del image

Expand All @@ -243,7 +245,7 @@ def run_blend_pipeline(
server: ServerContext,
params: ImageParams,
size: Size,
output: str,
outputs: List[str],
upscale: UpscaleParams,
sources: List[Image.Image],
mask: Image.Image,
Expand All @@ -266,8 +268,8 @@ def run_blend_pipeline(
job, server, stage, params, image, upscale=upscale, callback=progress
)

dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)
dest = save_image(server, outputs[0], image)
save_params(server, outputs[0], params, size, upscale=upscale)

del image

Expand Down
17 changes: 7 additions & 10 deletions api/onnx_web/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from os import path
from struct import pack
from time import time
from typing import Any, Optional, Tuple
from typing import Any, List, Optional, Tuple

from PIL import Image

Expand Down Expand Up @@ -63,7 +63,7 @@ def make_output_name(
params: ImageParams,
size: Size,
extras: Optional[Tuple[Param]] = None,
) -> str:
) -> List[str]:
now = int(time())
sha = sha256()

Expand All @@ -82,13 +82,10 @@ def make_output_name(
for param in extras:
hash_value(sha, param)

return "%s_%s_%s_%s.%s" % (
mode,
params.seed,
sha.hexdigest(),
now,
ctx.image_format,
)
return [
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{ctx.image_format}"
for i in range(params.batch)
]


def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str:
Expand All @@ -106,7 +103,7 @@ def save_params(
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
) -> str:
path = base_join(ctx.output_path, "%s.json" % (output))
path = base_join(ctx.output_path, f"{output}.json")
json = json_params(output, params, size, upscale=upscale, border=border)
with open(path, "w") as f:
f.write(dumps(json))
Expand Down
4 changes: 4 additions & 0 deletions api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def __init__(
negative_prompt: Optional[str] = None,
lpw: bool = False,
eta: float = 0.0,
batch: int = 1,
) -> None:
self.model = model
self.scheduler = scheduler
Expand All @@ -166,6 +167,7 @@ def __init__(
self.steps = steps
self.lpw = lpw or False
self.eta = eta
self.batch = batch

def tojson(self) -> Dict[str, Optional[Param]]:
return {
Expand All @@ -178,6 +180,7 @@ def tojson(self) -> Dict[str, Optional[Param]]:
"steps": self.steps,
"lpw": self.lpw,
"eta": self.eta,
"batch": self.batch,
}

def with_args(self, **kwargs):
Expand All @@ -191,6 +194,7 @@ def with_args(self, **kwargs):
kwargs.get("negative_prompt", self.negative_prompt),
kwargs.get("lpw", self.lpw),
kwargs.get("eta", self.eta),
kwargs.get("batch", self.batch),
)


Expand Down
6 changes: 6 additions & 0 deletions api/params.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
{
"version": "0.7.1",
"batch": {
"default": 1,
"min": 1,
"max": 5,
"step": 1
},
"bottom": {
"default": 0,
"min": 0,
Expand Down
1 change: 1 addition & 0 deletions gui/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export interface BaseImgParams {
prompt: string;
negativePrompt?: string;

batch: number;
cfg: number;
steps: number;
seed: number;
Expand Down
15 changes: 15 additions & 0 deletions gui/src/components/control/ImageControl.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,21 @@ export function ImageControl(props: ImageControlProps) {
}
}}
/>
<NumericField
label='Batch Size'
min={params.batch.min}
max={params.batch.max}
step={params.batch.step}
value={controlState.batch}
onChange={(batch) => {
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
batch,
});
}
}}
/>
</Stack>
<Stack direction='row' spacing={4}>
<NumericField
Expand Down
1 change: 1 addition & 0 deletions gui/src/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ export const DEFAULT_HISTORY = {

export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> {
return {
batch: defaults.batch.default,
cfg: defaults.cfg.default,
eta: defaults.eta.default,
negativePrompt: defaults.negativePrompt.default,
Expand Down

0 comments on commit 5f3b848

Please sign in to comment.