Skip to content

Commit

Permalink
feat(gui): implement image polling on the client
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 13, 2023
1 parent 55e8b80 commit c36dadd
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 44 deletions.
25 changes: 21 additions & 4 deletions gui/src/api/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ export interface OutpaintParams extends Img2ImgParams {
}

export interface ApiResponse {
output: string;
output: {
key: string;
url: string;
};
params: Txt2ImgResponse;
}

Expand All @@ -68,6 +71,8 @@ export interface ApiClient {

inpaint(params: InpaintParams): Promise<ApiResponse>;
outpaint(params: OutpaintParams): Promise<ApiResponse>;

ready(params: ApiResponse): Promise<{ready: boolean}>;
}

export const STATUS_SUCCESS = 200;
Expand All @@ -94,11 +99,16 @@ export function joinPath(...parts: Array<string>): string {
}

export async function imageFromResponse(root: string, res: Response): Promise<ApiResponse> {
type LimitedResponse = Omit<ApiResponse, 'output'> & {output: string};

if (res.status === STATUS_SUCCESS) {
const data = await res.json() as ApiResponse;
const output = new URL(joinPath('output', data.output), root).toString();
const data = await res.json() as LimitedResponse;
const url = new URL(joinPath('output', data.output), root).toString();
return {
output,
output: {
key: data.output,
url,
},
params: data.params,
};
} else {
Expand Down Expand Up @@ -229,5 +239,12 @@ export function makeClient(root: string, f = fetch): ApiClient {
async outpaint() {
throw new NotImplementedError();
},
async ready(params: ApiResponse): Promise<{ready: boolean}> {
const path = new URL('ready', root);
path.searchParams.append('output', params.output.key);

const res = await f(path);
return await res.json() as {ready: boolean};
}
};
}
4 changes: 2 additions & 2 deletions gui/src/components/ImageCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ export function ImageCard(props: ImageCardProps) {
}

function downloadImage() {
window.open(output, '_blank');
window.open(output.url, '_blank');
}

return <Card sx={{ maxWidth: params.width }} elevation={2}>
<CardMedia sx={{ height: params.height }}
component='img'
image={output}
image={output.url}
title={params.prompt}
/>
<CardContent>
Expand Down
10 changes: 5 additions & 5 deletions gui/src/components/ImageHistory.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { mustExist } from '@apextoaster/js-utils';
import { doesExist, mustExist } from '@apextoaster/js-utils';
import { Grid } from '@mui/material';
import { useContext } from 'react';
import * as React from 'react';
Expand All @@ -17,14 +17,14 @@ export function ImageHistory() {

const children = [];

if (loading) {
children.push(<LoadingCard key='loading' height={512} width={512} />); // TODO: get dimensions from config
if (doesExist(loading)) {
children.push(<LoadingCard key='loading' loading={loading} />);
}

if (history.length > 0) {
children.push(...history.map((item) => <ImageCard key={item.output} value={item} onDelete={removeHistory} />));
children.push(...history.map((item) => <ImageCard key={item.output.key} value={item} onDelete={removeHistory} />));
} else {
if (loading === false) {
if (doesExist(loading) === false) {
children.push(<div>No results. Press Generate.</div>);
}
}
Expand Down
14 changes: 6 additions & 8 deletions gui/src/components/Img2Img.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { mustExist } from '@apextoaster/js-utils';
import { Box, Button, Stack } from '@mui/material';
import * as React from 'react';
import { useMutation } from 'react-query';
import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand';

import { ConfigParams, IMAGE_FILTER } from '../config.js';
Expand All @@ -23,30 +23,28 @@ export function Img2Img(props: Img2ImgProps) {
const { config, model, platform } = props;

async function uploadSource() {
setLoading(true);

const output = await client.img2img({
...params,
model,
platform,
source: mustExist(source), // TODO: show an error if this doesn't exist
});

pushHistory(output);
setLoading(false);
setLoading(output);
}

const client = mustExist(useContext(ClientContext));
const upload = useMutation(uploadSource);
const query = useQueryClient();
const upload = useMutation(uploadSource, {
onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}),
});

const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.img2img);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setImg2Img = useStore(state, (s) => s.setImg2Img);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setLoading = useStore(state, (s) => s.setLoading);
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);

const [source, setSource] = useState<File>();

Expand Down
11 changes: 6 additions & 5 deletions gui/src/components/Inpaint.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { doesExist, mustExist } from '@apextoaster/js-utils';
import { FormatColorFill, Gradient } from '@mui/icons-material';
import { Box, Button, Stack } from '@mui/material';
import * as React from 'react';
import { useMutation } from 'react-query';
import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand';

import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER } from '../config.js';
Expand Down Expand Up @@ -69,7 +69,6 @@ export function Inpaint(props: InpaintProps) {

async function uploadSource() {
const canvas = mustExist(canvasRef.current);
setLoading(true);
return new Promise<void>((res, rej) => {
canvas.toBlob((blob) => {
client.inpaint({
Expand All @@ -79,8 +78,7 @@ export function Inpaint(props: InpaintProps) {
mask: mustExist(blob),
source: mustExist(source),
}).then((output) => {
pushHistory(output);
setLoading(false);
setLoading(output);
res();
}).catch((err) => rej(err));
});
Expand Down Expand Up @@ -146,7 +144,10 @@ export function Inpaint(props: InpaintProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);

const upload = useMutation(uploadSource);
const query = useQueryClient();
const upload = useMutation(uploadSource, {
onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}),
});
// eslint-disable-next-line no-null/no-null
const canvasRef = useRef<HTMLCanvasElement>(null);

Expand Down
32 changes: 27 additions & 5 deletions gui/src/components/LoadingCard.tsx
Original file line number Diff line number Diff line change
@@ -1,15 +1,37 @@
import { mustExist } from '@apextoaster/js-utils';
import { Card, CardContent, CircularProgress } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useQuery } from 'react-query';
import { useStore } from 'zustand';

import { ApiResponse } from '../api/client.js';
import { POLL_TIME } from '../config.js';
import { ClientContext, StateContext } from '../state.js';

export interface LoadingCardProps {
height: number;
width: number;
loading: ApiResponse;
}

export function LoadingCard(props: LoadingCardProps) {
return <Card sx={{ maxWidth: props.width }}>
<CardContent sx={{ height: props.height }}>
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center', minHeight: props.height }}>
const client = mustExist(React.useContext(ClientContext));

// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(mustExist(useContext(StateContext)), (state) => state.pushHistory);

const ready = useQuery('ready', () => client.ready(props.loading), {
refetchInterval: POLL_TIME,
});

React.useEffect(() => {
if (ready.status === 'success' && ready.data.ready) {
pushHistory(props.loading);
}
}, [ready.status, ready.data?.ready]);

return <Card sx={{ maxWidth: props.loading.params.width }}>
<CardContent sx={{ height: props.loading.params.height }}>
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center', minHeight: props.loading.params.height }}>
<CircularProgress />
</div>
</CardContent>
Expand Down
12 changes: 6 additions & 6 deletions gui/src/components/Txt2Img.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { mustExist } from '@apextoaster/js-utils';
import { Box, Button, Stack } from '@mui/material';
import * as React from 'react';
import { useMutation } from 'react-query';
import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand';

import { ConfigParams } from '../config.js';
Expand All @@ -22,20 +22,20 @@ export function Txt2Img(props: Txt2ImgProps) {
const { config, model, platform } = props;

async function generateImage() {
setLoading(true);

const output = await client.txt2img({
...params,
model,
platform,
});

pushHistory(output);
setLoading(false);
setLoading(output);
}

const client = mustExist(useContext(ClientContext));
const generate = useMutation(generateImage);
const query = useQueryClient();
const generate = useMutation(generateImage, {
onSuccess: () => query.invalidateQueries({ queryKey: 'ready '}),
});

const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.txt2img);
Expand Down
3 changes: 2 additions & 1 deletion gui/src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ export const DEFAULT_BRUSH = {
size: 8,
};
export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png';
export const STALE_TIME = 3_000;
export const STALE_TIME = 300_000; // 5 minutes
export const POLL_TIME = 5_000; // 5 seconds

export async function loadConfig(): Promise<Config> {
const configPath = new URL('./config.json', window.origin);
Expand Down
6 changes: 1 addition & 5 deletions gui/src/main.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,8 @@ export async function main() {
...createDefaultSlice(...slice),
}), {
name: 'onnx-web',
partialize: (oldState) => ({
...oldState,
loading: false,
}),
storage: createJSONStorage(() => localStorage),
version: 2,
version: 3,
}));

// prep react-query client
Expand Down
9 changes: 6 additions & 3 deletions gui/src/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ interface InpaintSlice {
interface HistorySlice {
history: Array<ApiResponse>;
limit: number;
loading: boolean;
loading: Maybe<ApiResponse>;

pushHistory(image: ApiResponse): void;
removeHistory(image: ApiResponse): void;
setLimit(limit: number): void;
setLoading(loading: boolean): void;
setLoading(image: Maybe<ApiResponse>): void;
}

interface DefaultSlice {
Expand Down Expand Up @@ -130,14 +130,17 @@ export function createStateSlices(base: ConfigParams) {
const createHistorySlice: StateCreator<OnnxState, [], [], HistorySlice> = (set) => ({
history: [],
limit: 4,
loading: false,
// eslint-disable-next-line no-null/no-null
loading: null,
pushHistory(image) {
set((prev) => ({
...prev,
history: [
image,
...prev.history,
],
// eslint-disable-next-line no-null/no-null
loading: null,
}));
},
removeHistory(image) {
Expand Down

0 comments on commit c36dadd

Please sign in to comment.