Skip to content

Commit

Permalink
convert to list of outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 21, 2023
1 parent b65e4e7 commit 33ab23a
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 37 deletions.
2 changes: 1 addition & 1 deletion api/onnx_web/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def hash_value(sha, param: Param):


def json_params(
output: str,
output: List[str],
params: ImageParams,
size: Size,
upscale: Optional[UpscaleParams] = None,
Expand Down
28 changes: 17 additions & 11 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,11 +542,12 @@ def img2img():
)

output = make_output_name(context, "img2img", params, size, extras=(strength,))
logger.info("img2img job queued for: %s", output)
job_name = output[0]
logger.info("img2img job queued for: %s", job_name)

source = valid_image(source, min_dims=size, max_dims=size)
executor.submit(
output,
job_name,
run_img2img_pipeline,
context,
params,
Expand All @@ -566,10 +567,11 @@ def txt2img():
upscale = upscale_from_request()

output = make_output_name(context, "txt2img", params, size)
logger.info("txt2img job queued for: %s", output)
job_name = output[0]
logger.info("txt2img job queued for: %s", job_name)

executor.submit(
output,
job_name,
run_txt2img_pipeline,
context,
params,
Expand Down Expand Up @@ -623,12 +625,13 @@ def inpaint():
tile_order,
),
)
logger.info("inpaint job queued for: %s", output)
job_name = output[0]
logger.info("inpaint job queued for: %s", job_name)

source = valid_image(source, min_dims=size, max_dims=size)
mask = valid_image(mask, min_dims=size, max_dims=size)
executor.submit(
output,
job_name,
run_inpaint_pipeline,
context,
params,
Expand Down Expand Up @@ -660,11 +663,12 @@ def upscale():
upscale = upscale_from_request()

output = make_output_name(context, "upscale", params, size)
logger.info("upscale job queued for: %s", output)
job_name = output[0]
logger.info("upscale job queued for: %s", job_name)

source = valid_image(source, min_dims=size, max_dims=size)
executor.submit(
output,
job_name,
run_upscale_pipeline,
context,
params,
Expand Down Expand Up @@ -697,6 +701,7 @@ def chain():
# get defaults from the regular parameters
device, params, size = pipeline_from_request()
output = make_output_name(context, "chain", params, size)
job_name = output[0]

pipeline = ChainPipeline()
for stage_data in data.get("stages", []):
Expand Down Expand Up @@ -750,7 +755,7 @@ def chain():
# build and run chain pipeline
empty_source = Image.new("RGB", (size.width, size.height))
executor.submit(
output,
job_name,
pipeline,
context,
params,
Expand Down Expand Up @@ -785,10 +790,11 @@ def blend():
upscale = upscale_from_request()

output = make_output_name(context, "upscale", params, size)
logger.info("upscale job queued for: %s", output)
job_name = output[0]
logger.info("upscale job queued for: %s", job_name)

executor.submit(
output,
job_name,
run_blend_pipeline,
context,
params,
Expand Down
2 changes: 1 addition & 1 deletion api/params.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"default": 0.0,
"min": 0,
"max": 1,
"step": 0.1
"step": 0.01
},
"faceOutscale": {
"default": 1,
Expand Down
33 changes: 19 additions & 14 deletions gui/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ export interface BlendParams {
* General response for most image requests.
*/
export interface ImageResponse {
output: {
output: Array<{
key: string;
url: string;
};
}>;
params: Required<BaseImgParams> & Required<ModelParams>;
size: {
width: number;
Expand Down Expand Up @@ -237,9 +237,9 @@ export interface ApiClient {
/**
* Check whether some pipeline's output is ready yet.
*/
ready(params: ImageResponse): Promise<ReadyResponse>;
ready(key: string): Promise<ReadyResponse>;

cancel(params: ImageResponse): Promise<boolean>;
cancel(key: string): Promise<boolean>;
}

/**
Expand Down Expand Up @@ -520,16 +520,16 @@ export function makeClient(root: string, f = fetch): ApiClient {
method: 'POST',
});
},
async ready(params: ImageResponse): Promise<ReadyResponse> {
async ready(key: string): Promise<ReadyResponse> {
const path = makeApiUrl(root, 'ready');
path.searchParams.append('output', params.output.key);
path.searchParams.append('output', key);

const res = await f(path);
return await res.json() as ReadyResponse;
},
async cancel(params: ImageResponse): Promise<boolean> {
async cancel(key: string): Promise<boolean> {
const path = makeApiUrl(root, 'cancel');
path.searchParams.append('output', params.output.key);
path.searchParams.append('output', key);

const res = await f(path, {
method: 'PUT',
Expand All @@ -546,17 +546,22 @@ export function makeClient(root: string, f = fetch): ApiClient {
* that into a full URL, since it already knows the root URL of the server.
*/
export async function parseApiResponse(root: string, res: Response): Promise<ImageResponse> {
type LimitedResponse = Omit<ImageResponse, 'output'> & { output: string };
type LimitedResponse = Omit<ImageResponse, 'output'> & { output: Array<string> };

if (res.status === STATUS_SUCCESS) {
const data = await res.json() as LimitedResponse;
const url = new URL(joinPath('output', data.output), root).toString();

const images = data.output.map((output) => {
const url = new URL(joinPath('output', output), root).toString();
return {
key: output,
url,
};
});

return {
...data,
output: {
key: data.output,
url,
},
output: images,
};
} else {
throw new Error('request error');
Expand Down
6 changes: 3 additions & 3 deletions gui/src/components/ImageCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ export function ImageCard(props: ImageCardProps) {
const setBlend = useStore(state, (s) => s.setBlend);

async function loadSource() {
const req = await fetch(output.url);
const req = await fetch(output[0].url);
return req.blob();
}

Expand Down Expand Up @@ -88,7 +88,7 @@ export function ImageCard(props: ImageCardProps) {
}

function downloadImage() {
window.open(output.url, '_blank');
window.open(output[0].url, '_blank');
}

function close() {
Expand All @@ -101,7 +101,7 @@ export function ImageCard(props: ImageCardProps) {
return <Card sx={{ maxWidth: config.params.width.default }} elevation={2}>
<CardMedia sx={{ height: config.params.height.default }}
component='img'
image={output.url}
image={output[0].url}
title={params.prompt}
/>
<CardContent>
Expand Down
4 changes: 2 additions & 2 deletions gui/src/components/ImageHistory.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ export function ImageHistory() {
const children = [];

if (loading.length > 0) {
children.push(...loading.map((item) => <LoadingCard key={`loading-${item.image.output.key}`} loading={item.image} />));
children.push(...loading.map((item) => <LoadingCard key={`loading-${item.image.output[0].key}`} index={0} loading={item.image} />));
}

if (history.length > 0) {
children.push(...history.map((item) => <ImageCard key={`history-${item.output.key}`} value={item} onDelete={removeHistory} />));
children.push(...history.map((item) => <ImageCard key={`history-${item.output[0].key}`} value={item} onDelete={removeHistory} />));
} else {
if (doesExist(loading) === false) {
children.push(<Typography>No results. Press Generate.</Typography>);
Expand Down
6 changes: 4 additions & 2 deletions gui/src/components/LoadingCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ const LOADING_PERCENT = 100;
const LOADING_OVERAGE = 99;

export interface LoadingCardProps {
index: number;
loading: ImageResponse;
}

export function LoadingCard(props: LoadingCardProps) {
const { index, loading } = props;
const { steps } = props.loading.params;

const client = mustExist(React.useContext(ClientContext));
Expand All @@ -31,8 +33,8 @@ export function LoadingCard(props: LoadingCardProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const setReady = useStore(state, (s) => s.setReady);

const cancel = useMutation(() => client.cancel(props.loading));
const ready = useQuery(`ready-${props.loading.output.key}`, () => client.ready(props.loading), {
const cancel = useMutation(() => client.cancel(loading.output[index].key));
const ready = useQuery(`ready-${loading.output[index].key}`, () => client.ready(loading.output[index].key), {
// data will always be ready without this, even if the API says its not
cacheTime: 0,
refetchInterval: POLL_TIME,
Expand Down
6 changes: 3 additions & 3 deletions gui/src/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ export function createStateSlices(server: ServerParams) {
clearLoading(image) {
set((prev) => ({
...prev,
loading: prev.loading.filter((it) => it.image.output.key !== image.output.key),
loading: prev.loading.filter((it) => it.image.output[0].key !== image.output[0].key),
}));
},
pushHistory(image) {
Expand All @@ -321,7 +321,7 @@ export function createStateSlices(server: ServerParams) {
image,
...prev.history,
].slice(0, prev.limit + DEFAULT_HISTORY.scrollback),
loading: prev.loading.filter((it) => it.image.output.key !== image.output.key),
loading: prev.loading.filter((it) => it.image.output[0].key !== image.output[0].key),
}));
},
pushLoading(image) {
Expand Down Expand Up @@ -354,7 +354,7 @@ export function createStateSlices(server: ServerParams) {
setReady(image, ready) {
set((prev) => {
const loading = [...prev.loading];
const idx = loading.findIndex((it) => it.image.output.key === image.output.key);
const idx = loading.findIndex((it) => it.image.output[0].key === image.output[0].key);
if (idx >= 0) {
loading[idx].ready = ready;
} else {
Expand Down

0 comments on commit 33ab23a

Please sign in to comment.