From 33ab23a474e8ddce7b6897e886112211b5f8b90b Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 20 Feb 2023 23:47:43 -0600 Subject: [PATCH] convert to list of outputs --- api/onnx_web/output.py | 2 +- api/onnx_web/serve.py | 28 ++++++++++++++---------- api/params.json | 2 +- gui/src/client.ts | 33 +++++++++++++++++------------ gui/src/components/ImageCard.tsx | 6 +++--- gui/src/components/ImageHistory.tsx | 4 ++-- gui/src/components/LoadingCard.tsx | 6 ++++-- gui/src/state.ts | 6 +++--- 8 files changed, 50 insertions(+), 37 deletions(-) diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index 330ca14d3..446cafb66 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -30,7 +30,7 @@ def hash_value(sha, param: Param): def json_params( - output: str, + output: List[str], params: ImageParams, size: Size, upscale: Optional[UpscaleParams] = None, diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 04ebae1d5..befd1fda7 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -542,11 +542,12 @@ def img2img(): ) output = make_output_name(context, "img2img", params, size, extras=(strength,)) - logger.info("img2img job queued for: %s", output) + job_name = output[0] + logger.info("img2img job queued for: %s", job_name) source = valid_image(source, min_dims=size, max_dims=size) executor.submit( - output, + job_name, run_img2img_pipeline, context, params, @@ -566,10 +567,11 @@ def txt2img(): upscale = upscale_from_request() output = make_output_name(context, "txt2img", params, size) - logger.info("txt2img job queued for: %s", output) + job_name = output[0] + logger.info("txt2img job queued for: %s", job_name) executor.submit( - output, + job_name, run_txt2img_pipeline, context, params, @@ -623,12 +625,13 @@ def inpaint(): tile_order, ), ) - logger.info("inpaint job queued for: %s", output) + job_name = output[0] + logger.info("inpaint job queued for: %s", job_name) source = valid_image(source, min_dims=size, max_dims=size) mask = valid_image(mask, min_dims=size, max_dims=size) executor.submit( - output, + job_name, run_inpaint_pipeline, context, params, @@ -660,11 +663,12 @@ def upscale(): upscale = upscale_from_request() output = make_output_name(context, "upscale", params, size) - logger.info("upscale job queued for: %s", output) + job_name = output[0] + logger.info("upscale job queued for: %s", job_name) source = valid_image(source, min_dims=size, max_dims=size) executor.submit( - output, + job_name, run_upscale_pipeline, context, params, @@ -697,6 +701,7 @@ def chain(): # get defaults from the regular parameters device, params, size = pipeline_from_request() output = make_output_name(context, "chain", params, size) + job_name = output[0] pipeline = ChainPipeline() for stage_data in data.get("stages", []): @@ -750,7 +755,7 @@ def chain(): # build and run chain pipeline empty_source = Image.new("RGB", (size.width, size.height)) executor.submit( - output, + job_name, pipeline, context, params, @@ -785,10 +790,11 @@ def blend(): upscale = upscale_from_request() output = make_output_name(context, "upscale", params, size) - logger.info("upscale job queued for: %s", output) + job_name = output[0] + logger.info("upscale job queued for: %s", job_name) executor.submit( - output, + job_name, run_blend_pipeline, context, params, diff --git a/api/params.json b/api/params.json index 599634a0b..e45aefdc3 100644 --- a/api/params.json +++ b/api/params.json @@ -32,7 +32,7 @@ "default": 0.0, "min": 0, "max": 1, - "step": 0.1 + "step": 0.01 }, "faceOutscale": { "default": 1, diff --git a/gui/src/client.ts b/gui/src/client.ts index aee390297..127ddc075 100644 --- a/gui/src/client.ts +++ b/gui/src/client.ts @@ -145,10 +145,10 @@ export interface BlendParams { * General response for most image requests. */ export interface ImageResponse { - output: { + output: Array<{ key: string; url: string; - }; + }>; params: Required & Required; size: { width: number; @@ -237,9 +237,9 @@ export interface ApiClient { /** * Check whether some pipeline's output is ready yet. */ - ready(params: ImageResponse): Promise; + ready(key: string): Promise; - cancel(params: ImageResponse): Promise; + cancel(key: string): Promise; } /** @@ -520,16 +520,16 @@ export function makeClient(root: string, f = fetch): ApiClient { method: 'POST', }); }, - async ready(params: ImageResponse): Promise { + async ready(key: string): Promise { const path = makeApiUrl(root, 'ready'); - path.searchParams.append('output', params.output.key); + path.searchParams.append('output', key); const res = await f(path); return await res.json() as ReadyResponse; }, - async cancel(params: ImageResponse): Promise { + async cancel(key: string): Promise { const path = makeApiUrl(root, 'cancel'); - path.searchParams.append('output', params.output.key); + path.searchParams.append('output', key); const res = await f(path, { method: 'PUT', @@ -546,17 +546,22 @@ export function makeClient(root: string, f = fetch): ApiClient { * that into a full URL, since it already knows the root URL of the server. */ export async function parseApiResponse(root: string, res: Response): Promise { - type LimitedResponse = Omit & { output: string }; + type LimitedResponse = Omit & { output: Array }; if (res.status === STATUS_SUCCESS) { const data = await res.json() as LimitedResponse; - const url = new URL(joinPath('output', data.output), root).toString(); + + const images = data.output.map((output) => { + const url = new URL(joinPath('output', output), root).toString(); + return { + key: output, + url, + }; + }); + return { ...data, - output: { - key: data.output, - url, - }, + output: images, }; } else { throw new Error('request error'); diff --git a/gui/src/components/ImageCard.tsx b/gui/src/components/ImageCard.tsx index 7bbcfe4ea..fcffbc134 100644 --- a/gui/src/components/ImageCard.tsx +++ b/gui/src/components/ImageCard.tsx @@ -42,7 +42,7 @@ export function ImageCard(props: ImageCardProps) { const setBlend = useStore(state, (s) => s.setBlend); async function loadSource() { - const req = await fetch(output.url); + const req = await fetch(output[0].url); return req.blob(); } @@ -88,7 +88,7 @@ export function ImageCard(props: ImageCardProps) { } function downloadImage() { - window.open(output.url, '_blank'); + window.open(output[0].url, '_blank'); } function close() { @@ -101,7 +101,7 @@ export function ImageCard(props: ImageCardProps) { return diff --git a/gui/src/components/ImageHistory.tsx b/gui/src/components/ImageHistory.tsx index 3e08bbbce..a0dd58244 100644 --- a/gui/src/components/ImageHistory.tsx +++ b/gui/src/components/ImageHistory.tsx @@ -18,11 +18,11 @@ export function ImageHistory() { const children = []; if (loading.length > 0) { - children.push(...loading.map((item) => )); + children.push(...loading.map((item) => )); } if (history.length > 0) { - children.push(...history.map((item) => )); + children.push(...history.map((item) => )); } else { if (doesExist(loading) === false) { children.push(No results. Press Generate.); diff --git a/gui/src/components/LoadingCard.tsx b/gui/src/components/LoadingCard.tsx index 340485be7..e81b55889 100644 --- a/gui/src/components/LoadingCard.tsx +++ b/gui/src/components/LoadingCard.tsx @@ -14,10 +14,12 @@ const LOADING_PERCENT = 100; const LOADING_OVERAGE = 99; export interface LoadingCardProps { + index: number; loading: ImageResponse; } export function LoadingCard(props: LoadingCardProps) { + const { index, loading } = props; const { steps } = props.loading.params; const client = mustExist(React.useContext(ClientContext)); @@ -31,8 +33,8 @@ export function LoadingCard(props: LoadingCardProps) { // eslint-disable-next-line @typescript-eslint/unbound-method const setReady = useStore(state, (s) => s.setReady); - const cancel = useMutation(() => client.cancel(props.loading)); - const ready = useQuery(`ready-${props.loading.output.key}`, () => client.ready(props.loading), { + const cancel = useMutation(() => client.cancel(loading.output[index].key)); + const ready = useQuery(`ready-${loading.output[index].key}`, () => client.ready(loading.output[index].key), { // data will always be ready without this, even if the API says its not cacheTime: 0, refetchInterval: POLL_TIME, diff --git a/gui/src/state.ts b/gui/src/state.ts index 937b5d3a7..0b06ab254 100644 --- a/gui/src/state.ts +++ b/gui/src/state.ts @@ -311,7 +311,7 @@ export function createStateSlices(server: ServerParams) { clearLoading(image) { set((prev) => ({ ...prev, - loading: prev.loading.filter((it) => it.image.output.key !== image.output.key), + loading: prev.loading.filter((it) => it.image.output[0].key !== image.output[0].key), })); }, pushHistory(image) { @@ -321,7 +321,7 @@ export function createStateSlices(server: ServerParams) { image, ...prev.history, ].slice(0, prev.limit + DEFAULT_HISTORY.scrollback), - loading: prev.loading.filter((it) => it.image.output.key !== image.output.key), + loading: prev.loading.filter((it) => it.image.output[0].key !== image.output[0].key), })); }, pushLoading(image) { @@ -354,7 +354,7 @@ export function createStateSlices(server: ServerParams) { setReady(image, ready) { set((prev) => { const loading = [...prev.loading]; - const idx = loading.findIndex((it) => it.image.output.key === image.output.key); + const idx = loading.findIndex((it) => it.image.output[0].key === image.output[0].key); if (idx >= 0) { loading[idx].ready = ready; } else {