Skip to content

Commit

Permalink
add more model types to models tab
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed May 6, 2023
1 parent 6e78f40 commit 8c88fcd
Show file tree
Hide file tree
Showing 9 changed files with 260 additions and 18 deletions.
25 changes: 16 additions & 9 deletions gui/src/components/input/EditableList.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,40 @@ const { useState } = React;
export interface EditableListProps<T> {
items: Array<T>;

newItem: (s: string) => T;
newItem: (l: string, s: string) => T;
renderItem: (t: T) => React.ReactElement;
setItems: (ts: Array<T>) => void;
}

export function EditableList<T>(props: EditableListProps<T>) {
const { items, newItem, renderItem, setItems } = props;
const [nextItem, setNextItem] = useState('');
const [nextLabel, setNextLabel] = useState('');
const [nextSource, setNextSource] = useState('');

return <Stack>
{items.map((it, idx) => <Stack direction='row' key={idx}>
return <Stack spacing={2}>
{items.map((it, idx) => <Stack direction='row' key={idx} spacing={2}>
{renderItem(it)}
<Button onClick={() => setItems([
...items.slice(0, idx),
...items.slice(idx + 1, items.length),
])}>Remove</Button>
</Stack>)}
<Stack direction='row'>
<Stack direction='row' spacing={2}>
<TextField
label='Label'
variant='outlined'
value={nextLabel}
onChange={(event) => setNextLabel(event.target.value)}
/>
<TextField
label='Source'
variant='outlined'
value={nextItem}
onChange={(event) => setNextItem(event.target.value)}
value={nextSource}
onChange={(event) => setNextSource(event.target.value)}
/>
<Button onClick={() => {
setItems([...items, newItem(nextItem)]);
setNextItem('');
setItems([...items, newItem(nextLabel, nextSource)]);
setNextLabel('');
}}>New</Button>
</Stack>
</Stack>;
Expand Down
17 changes: 17 additions & 0 deletions gui/src/components/input/model/CorrectionModel.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { Stack, TextField } from '@mui/material';
import * as React from 'react';

import { CorrectionModel } from '../../../types';

export interface CorrectionModelInputProps {
model: CorrectionModel;
}

export function CorrectionModelInput(props: CorrectionModelInputProps) {
const { model } = props;

return <Stack direction='row' spacing={2}>
<TextField value={model.label} />
<TextField value={model.source} />
</Stack>;
}
21 changes: 21 additions & 0 deletions gui/src/components/input/model/DiffusionModel.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import { MenuItem, Select, Stack, TextField } from '@mui/material';
import * as React from 'react';

import { DiffusionModel } from '../../../types';

export interface DiffusionModelInputProps {
model: DiffusionModel;
}

export function DiffusionModelInput(props: DiffusionModelInputProps) {
const { model } = props;

return <Stack direction='row' spacing={2}>
<TextField label='Label' value={model.label} />
<TextField label='Source' value={model.source} />
<Select value={model.format} label='Format'>
<MenuItem value='ckpt'>ckpt</MenuItem>
<MenuItem value='safetensors'>safetensors</MenuItem>
</Select>
</Stack>;
}
26 changes: 26 additions & 0 deletions gui/src/components/input/model/ExtraNetwork.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import { MenuItem, Select, Stack, TextField } from '@mui/material';
import * as React from 'react';

import { ExtraNetwork } from '../../../types';

export interface ExtraNetworkInputProps {
model: ExtraNetwork;
}

export function ExtraNetworkInput(props: ExtraNetworkInputProps) {
const { model } = props;

return <Stack direction='row' spacing={2}>
<TextField value={model.label} label='Label' />
<TextField value={model.source} label='Source' />
<Select value={model.type} label='Type'>
<MenuItem value='inversion'>Textual Inversion</MenuItem>
<MenuItem value='lora'>LoRA or LyCORIS</MenuItem>
</Select>
<Select value={model.model} label='Model'>
<MenuItem value='sd-scripts'>LoRA - sd-scripts</MenuItem>
<MenuItem value='concept'>TI - concept</MenuItem>
<MenuItem value='embeddings'>TI - embeddings</MenuItem>
</Select>
</Stack>;
}
17 changes: 17 additions & 0 deletions gui/src/components/input/model/ExtraSource.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import * as React from 'react';
import { Stack, TextField } from '@mui/material';

import { ExtraSource } from '../../../types';

export interface ExtraSourceInputProps {
model: ExtraSource;
}

export function ExtraSourceInput(props: ExtraSourceInputProps) {
const { model } = props;

return <Stack direction='row' spacing={2}>
<TextField label='dest' value={model.dest} />
<TextField label='source' value={model.source} />
</Stack>;
}
17 changes: 17 additions & 0 deletions gui/src/components/input/model/UpscalingModel.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { Stack, TextField } from '@mui/material';
import * as React from 'react';

import { UpscalingModel } from '../../../types.js';

export interface UpscalingModelInputProps {
model: UpscalingModel;
}

export function UpscalingModelInput(props: UpscalingModelInputProps) {
const { model } = props;

return <Stack direction='row' spacing={2}>
<TextField value={model.label} />
<TextField value={model.source} />
</Stack>;
}
81 changes: 77 additions & 4 deletions gui/src/components/tab/Models.tsx
Original file line number Diff line number Diff line change
@@ -1,27 +1,41 @@
import { mustExist } from '@apextoaster/js-utils';
import { Accordion, AccordionDetails, AccordionSummary, Button, Stack } from '@mui/material';
import * as React from 'react';
import _ from 'lodash';
import { useStore } from 'zustand';

import { StateContext } from '../../state.js';
import { EditableList } from '../input/EditableList';
import { DiffusionModelInput } from '../input/model/DiffusionModel.js';
import { SafetensorFormat } from '../../types.js';
import { CorrectionModelInput } from '../input/model/CorrectionModel.js';
import { UpscalingModelInput } from '../input/model/UpscalingModel.js';
import { ExtraSourceInput } from '../input/model/ExtraSource.js';
import { ExtraNetworkInput } from '../input/model/ExtraNetwork.js';
// eslint-disable-next-line @typescript-eslint/unbound-method
const { kebabCase } = _;

export function Models() {
const state = mustExist(React.useContext(StateContext));
const extras = useStore(state, (s) => s.extras);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setExtras = useStore(state, (s) => s.setExtras);

return <Stack>
return <Stack spacing={2}>
<Accordion>
<AccordionSummary>
Diffusion Models
</AccordionSummary>
<AccordionDetails>
<EditableList
items={extras.diffusion}
newItem={(s) => s}
renderItem={(t) => <div key={t}>{t}</div>}
newItem={(l, s) => ({
format: 'safetensors' as SafetensorFormat,
label: l,
name: kebabCase(l),
source: s,
})}
renderItem={(t) => <DiffusionModelInput model={t}/>}
setItems={(diffusion) => setExtras({
...extras,
diffusion,
Expand All @@ -34,27 +48,86 @@ export function Models() {
Correction Models
</AccordionSummary>
<AccordionDetails>
<EditableList
items={extras.correction}
newItem={(l, s) => ({
format: 'safetensors' as SafetensorFormat,
label: l,
name: kebabCase(l),
source: s,
})}
renderItem={(t) => <CorrectionModelInput model={t}/>}
setItems={(correction) => setExtras({
...extras,
correction,
})}
/>
</AccordionDetails>
</Accordion>
<Accordion>
<AccordionSummary>
Upscaling Models
</AccordionSummary>
<AccordionDetails>
<EditableList
items={extras.upscaling}
newItem={(l, s) => ({
format: 'safetensors' as SafetensorFormat,
label: l,
name: kebabCase(l),
scale: 4,
source: s,
})}
renderItem={(t) => <UpscalingModelInput model={t}/>}
setItems={(upscaling) => setExtras({
...extras,
upscaling,
})}
/>
</AccordionDetails>
</Accordion>
<Accordion>
<AccordionSummary>
Additional Networks
Extra Networks
</AccordionSummary>
<AccordionDetails>
<EditableList
items={extras.networks}
newItem={(l, s) => ({
format: 'safetensors' as SafetensorFormat,
label: l,
model: 'embeddings' as const,
name: kebabCase(l),
source: s,
type: 'inversion' as const,
})}
renderItem={(t) => <ExtraNetworkInput model={t}/>}
setItems={(networks) => setExtras({
...extras,
networks,
})}
/>
</AccordionDetails>
</Accordion>
<Accordion>
<AccordionSummary>
Other Sources
</AccordionSummary>
<AccordionDetails>
<EditableList
items={extras.sources}
newItem={(l, s) => ({
format: 'safetensors' as SafetensorFormat,
label: l,
name: kebabCase(l),
source: s,
})}
renderItem={(t) => <ExtraSourceInput model={t}/>}
setItems={(sources) => setExtras({
...extras,
sources,
})}
/>
</AccordionDetails>
</Accordion>
<Button color='warning'>Save & Convert</Button>
Expand Down
10 changes: 5 additions & 5 deletions gui/src/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {
UpscaleReqParams,
} from './client/api.js';
import { Config, ConfigFiles, ConfigState, ServerParams } from './config.js';
import { ExtrasFile } from './types.js';

export type Theme = PaletteMode | ''; // tri-state, '' is unset

Expand All @@ -38,10 +39,6 @@ interface HistoryItem {
retry: RetryParams;
}

interface ExtrasFile {
diffusion: Array<string>;
}

interface BrushSlice {
brush: BrushParams;

Expand Down Expand Up @@ -556,15 +553,18 @@ export function createStateSlices(server: ServerParams) {
next.resetTxt2Img();
next.resetUpscaleTab();
next.resetBlend();
// TODO: reset more stuff
return next;
});
},
});

const createExtraSlice: Slice<ExtraSlice> = (set) => ({
extras: {
correction: [],
diffusion: [],
networks: [],
sources: [],
upscaling: [],
},
setExtras(extras) {
set((prev) => ({
Expand Down
64 changes: 64 additions & 0 deletions gui/src/types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
export type TorchFormat = 'bin' | 'ckpt' | 'pt' | 'pth';
export type OnnxFormat = 'onnx';
export type SafetensorFormat = 'safetensors';

export interface BaseModel {
/**
* Format of the model, used when downloading files that may not have a format in their URL.
*/
format: OnnxFormat | SafetensorFormat | TorchFormat;

/**
* Localized label of the model.
*/
label: string;

/**
* Filename of the model.
*/
name: string;

/**
* Source URL or local path.
*/
source: string;
}

export interface DiffusionModel extends BaseModel {
config?: string;
image_size?: string;
inversions?: Array<unknown>;
loras?: Array<unknown>;
pipeline?: string;
vae?: string;
version?: string;
}

export interface UpscalingModel extends BaseModel {
model?: 'bsrgan' | 'resrgan' | 'swinir';
scale: number;
}

export interface CorrectionModel extends BaseModel {
model?: 'codeformer' | 'gfpgan';
}

export interface ExtraNetwork extends BaseModel {
model: 'concept' | 'embeddings' | 'cloneofsimo' | 'sd-scripts';
type: 'inversion' | 'lora';
}

export interface ExtraSource {
dest?: string;
format?: string;
name: string;
source: string;
}

export interface ExtrasFile {
correction: Array<CorrectionModel>;
diffusion: Array<DiffusionModel>;
upscaling: Array<UpscalingModel>;
networks: Array<ExtraNetwork>;
sources: Array<ExtraSource>;
}

0 comments on commit 8c88fcd

Please sign in to comment.