Skip to content

Commit

Permalink
fix(gui): improve performance while using image controls
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 12, 2023
1 parent cff3210 commit 35e2e1d
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 72 deletions.
9 changes: 9 additions & 0 deletions gui/esbuild.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import { build } from 'esbuild';
import { createRequire } from 'node:module';
import { join } from 'path';
import alias from 'esbuild-plugin-alias';

const require = createRequire(import.meta.url);
const root = process.cwd();

build({
Expand All @@ -14,5 +17,11 @@ build({
keepNames: true,
outdir: 'out/bundle/',
platform: 'browser',
plugins: [
alias({
'react-dom$': 'react-dom/profiling',
'scheduler/tracing': 'scheduler/tracing-profiling',
})
],
sourcemap: true,
}).catch(() => process.exit(1));
1 change: 1 addition & 0 deletions gui/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"chai": "^4.3.7",
"chai-as-promised": "^7.1.1",
"esbuild": "^0.16.14",
"esbuild-plugin-alias": "^0.2.1",
"eslint": "^8.31.0",
"eslint-plugin-chai": "^0.0.1",
"eslint-plugin-chai-expect": "^3.0.0",
Expand Down
15 changes: 9 additions & 6 deletions gui/src/components/ImageHistory.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,31 @@ import { ImageCard } from './ImageCard.js';
import { LoadingCard } from './LoadingCard.js';

export function ImageHistory() {
const state = useStore(mustExist(useContext(StateContext)));
const { images } = state.history;
const history = useStore(mustExist(useContext(StateContext)), (state) => state.history);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setHistory = useStore(mustExist(useContext(StateContext)), (state) => state.setHistory);

const { images } = history;

const children = [];

if (state.history.loading) {
if (history.loading) {
children.push(<LoadingCard key='loading' height={512} width={512} />); // TODO: get dimensions from config
}

function removeHistory(image: ApiResponse) {
state.setHistory(images.filter((item) => image.output !== item.output));
setHistory(images.filter((item) => image.output !== item.output));
}

if (images.length > 0) {
children.push(...images.map((item) => <ImageCard key={item.output} value={item} onDelete={removeHistory} />));
} else {
if (state.history.loading === false) {
if (history.loading === false) {
children.push(<div>No results. Press Generate.</div>);
}
}

const limited = children.slice(0, state.history.limit);
const limited = children.slice(0, history.limit);

return <Grid container spacing={2}>{limited.map((child, idx) => <Grid item key={idx} xs={6}>{child}</Grid>)}</Grid>;
}
26 changes: 17 additions & 9 deletions gui/src/components/Img2Img.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,40 +23,48 @@ export function Img2Img(props: Img2ImgProps) {
const { config, model, platform } = props;

async function uploadSource() {
state.setLoading(true);
setLoading(true);

const output = await client.img2img({
...state.img2img,
...params,
model,
platform,
source: mustExist(source), // TODO: show an error if this doesn't exist
});

state.pushHistory(output);
state.setLoading(false);
pushHistory(output);
setLoading(false);
}

const client = mustExist(useContext(ClientContext));
const upload = useMutation(uploadSource);
const state = useStore(mustExist(useContext(StateContext)));

const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.img2img);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setImg2Img = useStore(state, (s) => s.setImg2Img);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setLoading = useStore(state, (s) => s.setLoading);
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);

const [source, setSource] = useState<File>();

return <Box>
<Stack spacing={2}>
<ImageInput filter={IMAGE_FILTER} label='Source' onChange={setSource} />
<ImageControl config={config} params={state.img2img} onChange={(newParams) => {
state.setImg2Img(newParams);
<ImageControl config={config} params={params} onChange={(newParams) => {
setImg2Img(newParams);
}} />
<NumericField
decimal
label='Strength'
min={config.strength.min}
max={config.strength.max}
step={config.strength.step}
value={state.img2img.strength}
value={params.strength}
onChange={(value) => {
state.setImg2Img({
setImg2Img({
strength: value,
});
}}
Expand Down
22 changes: 15 additions & 7 deletions gui/src/components/Inpaint.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,18 @@ export function Inpaint(props: InpaintProps) {

async function uploadSource() {
const canvas = mustExist(canvasRef.current);
state.setLoading(true);
setLoading(true);
return new Promise<void>((res, rej) => {
canvas.toBlob((blob) => {
client.inpaint({
...state.inpaint,
...params,
model,
platform,
mask: mustExist(blob),
source: mustExist(source),
}).then((output) => {
state.pushHistory(output);
state.setLoading(false);
pushHistory(output);
setLoading(false);
res();
}).catch((err) => rej(err));
});
Expand Down Expand Up @@ -138,10 +138,18 @@ export function Inpaint(props: InpaintProps) {
ctx.putImageData(image, 0, 0);
}

const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.inpaint);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setInpaint = useStore(state, (s) => s.setInpaint);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setLoading = useStore(state, (s) => s.setLoading);
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);

const upload = useMutation(uploadSource);
// eslint-disable-next-line no-null/no-null
const canvasRef = useRef<HTMLCanvasElement>(null);
const state = useStore(mustExist(useContext(StateContext)));

// painting state
const [clicks, setClicks] = useState<Array<Point>>([]);
Expand Down Expand Up @@ -259,9 +267,9 @@ export function Inpaint(props: InpaintProps) {
</Stack>
<ImageControl
config={config}
params={state.inpaint}
params={params}
onChange={(newParams) => {
state.setInpaint(newParams);
setInpaint(newParams);
}}
/>
<Button onClick={() => upload.mutate()}>Generate</Button>
Expand Down
30 changes: 19 additions & 11 deletions gui/src/components/Txt2Img.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,36 +22,44 @@ export function Txt2Img(props: Txt2ImgProps) {
const { config, model, platform } = props;

async function generateImage() {
state.setLoading(true);
setLoading(true);

const output = await client.txt2img({
...state.txt2img,
...params,
model,
platform,
});

state.pushHistory(output);
state.setLoading(false);
pushHistory(output);
setLoading(false);
}

const client = mustExist(useContext(ClientContext));
const generate = useMutation(generateImage);
const state = useStore(mustExist(useContext(StateContext)));

const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.txt2img);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setTxt2Img = useStore(state, (s) => s.setTxt2Img);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setLoading = useStore(state, (s) => s.setLoading);
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);

return <Box>
<Stack spacing={2}>
<ImageControl config={config} params={state.txt2img} onChange={(newParams) => {
state.setTxt2Img(newParams);
<ImageControl config={config} params={params} onChange={(newParams) => {
setTxt2Img(newParams);
}} />
<Stack direction='row' spacing={4}>
<NumericField
label='Width'
min={config.width.min}
max={config.width.max}
step={config.width.step}
value={state.txt2img.width}
value={params.width}
onChange={(value) => {
state.setTxt2Img({
setTxt2Img({
width: value,
});
}}
Expand All @@ -61,9 +69,9 @@ export function Txt2Img(props: Txt2ImgProps) {
min={config.height.min}
max={config.height.max}
step={config.height.step}
value={state.txt2img.height}
value={params.height}
onChange={(value) => {
state.setTxt2Img({
setTxt2Img({
height: value,
});
}}
Expand Down
78 changes: 39 additions & 39 deletions gui/src/main.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -66,45 +66,6 @@ export async function main() {
inpaint: {
...defaults,
},
setLimit(limit) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
limit,
},
}));
},
setLoading(loading) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
loading,
},
}));
},
pushHistory(newImage: ApiResponse) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
images: [
newImage,
...oldState.history.images,
].slice(0, oldState.history.limit),
},
}));
},
setHistory(newHistory: Array<ApiResponse>) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
images: newHistory,
},
}));
},
setDefaults(newParams) {
set((oldState) => ({
...oldState,
Expand Down Expand Up @@ -168,6 +129,45 @@ export async function main() {
},
}));
},
setLimit(limit) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
limit,
},
}));
},
setLoading(loading) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
loading,
},
}));
},
pushHistory(newImage: ApiResponse) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
images: [
newImage,
...oldState.history.images,
].slice(0, oldState.history.limit),
},
}));
},
setHistory(newHistory: Array<ApiResponse>) {
set((oldState) => ({
...oldState,
history: {
...oldState.history,
images: newHistory,
},
}));
},
}), {
name: 'onnx-web',
partialize: (oldState) => ({
Expand Down
45 changes: 45 additions & 0 deletions gui/src/state.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import { ApiResponse, BaseImgParams, Img2ImgParams, InpaintParams, Txt2ImgParams } from './api/client.js';
import { ConfigState } from './config.js';

interface TabState<TabParams extends BaseImgParams> {
params: ConfigState<Required<TabParams>>;

reset(): void;
update(params: Partial<ConfigState<Required<TabParams>>>): void;
}

interface OnnxState {
defaults: {
params: Required<BaseImgParams>;
update(newParams: Partial<BaseImgParams>): void;
};
txt2img: {
params: ConfigState<Required<Txt2ImgParams>>;

reset(): void;
update(newParams: Partial<ConfigState<Required<Txt2ImgParams>>>): void;
};
img2img: {
params: ConfigState<Required<Img2ImgParams>>;

reset(): void;
update(newParams: Partial<ConfigState<Required<Img2ImgParams>>>): void;
};
inpaint: {
params: ConfigState<Required<InpaintParams>>;

reset(): void;
update(newParams: Partial<ConfigState<Required<InpaintParams>>>): void;
};
history: {
images: Array<ApiResponse>;
limit: number;
loading: boolean;

setLimit(limit: number): void;
setLoading(loading: boolean): void;
setHistory(newHistory: Array<ApiResponse>): void;
pushHistory(newImage: ApiResponse): void;
};
}

5 changes: 5 additions & 0 deletions gui/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,11 @@ es-to-primitive@^1.2.1:
is-date-object "^1.0.1"
is-symbol "^1.0.2"

esbuild-plugin-alias@^0.2.1:
version "0.2.1"
resolved "https://registry.yarnpkg.com/esbuild-plugin-alias/-/esbuild-plugin-alias-0.2.1.tgz#45a86cb941e20e7c2bc68a2bea53562172494fcb"
integrity sha512-jyfL/pwPqaFXyKnj8lP8iLk6Z0m099uXR45aSN8Av1XD4vhvQutxxPzgA2bTcAwQpa1zCXDcWOlhFgyP3GKqhQ==

esbuild@^0.16.14:
version "0.16.14"
resolved "https://registry.yarnpkg.com/esbuild/-/esbuild-0.16.14.tgz#366249a0a0fd431d3ab706195721ef1014198919"
Expand Down

0 comments on commit 35e2e1d

Please sign in to comment.