Skip to content

Commit

Permalink
React-stable-audio (#377)
Browse files Browse the repository at this point in the history
* fix react build

* stable audio demo in React UI

* improve layout
  • Loading branch information
rsxdalv authored Sep 20, 2024
1 parent f10143a commit 1a61b1a
Show file tree
Hide file tree
Showing 8 changed files with 341 additions and 4 deletions.
4 changes: 4 additions & 0 deletions react-ui/src/components/Header.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ export const routes: Route[] = [
href: "/magnet",
text: "MAGNeT",
},
{
href: "/stable-audio",
text: "Stable Audio (Demo)",
},
{
href: "/demucs",
text: "Demucs",
Expand Down
122 changes: 122 additions & 0 deletions react-ui/src/components/StableAudioInputs.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import React from "react";
import { StableAudioParams } from "../tabs/StableAudioParams";
import { PromptTextArea } from "./PromptTextArea";
import { HandleChange } from "../types/HandleChange";
import { ParameterSlider } from "./GenericSlider";
import { ModelDropdown } from "./component/ModelDropdown";

const StableAudioModels = ({
params,
onChange,
label,
name,
}: {
params: StableAudioParams;
onChange: HandleChange;
label: string;
name: string;
}) => (
<div />
// <ModelDropdown
// name={name}
// label={label}
// options={params.model_name}
// value={params.model_name}
// onChange={onChange}
// />
);

interface StableAudioInputsProps {
stableAudioParams: StableAudioParams;
handleChange: HandleChange;
setStableAudioParams: React.Dispatch<React.SetStateAction<StableAudioParams>>;
}

export const StableAudioInputs: React.FC<StableAudioInputsProps> = ({
stableAudioParams,
handleChange,
setStableAudioParams,
}) => {
return (
<div className="grid grid-cols-2 gap-4">
<div className="flex flex-col gap-2">
<PromptTextArea
params={stableAudioParams}
handleChange={handleChange}
label="Text"
name="text"
/>
<PromptTextArea
params={stableAudioParams}
handleChange={handleChange}
label="Negative Prompt"
name="negative_prompt"
/>
</div>
<div className="grid grid-cols-2 gap-2">
<div className="flex flex-col gap-2 cell">
<ParameterSlider
params={stableAudioParams}
onChange={handleChange}
label="Seconds Start"
name="seconds_start"
min="0"
max="512"
step="1"
/>
<ParameterSlider
params={stableAudioParams}
onChange={handleChange}
label="Seconds Total"
name="seconds_total"
min="0"
max="512"
step="1"
/>
</div>
<div className="flex gap-2 cell">
<ParameterSlider
params={stableAudioParams}
onChange={handleChange}
label="CFG Scale"
name="cfg_scale"
min="0"
max="50"
step="0.01"
decimals={2}
orientation="vertical"
className="h-40"
/>
<ParameterSlider
params={stableAudioParams}
onChange={handleChange}
label="Steps"
name="steps"
min="0"
max="500"
step="1"
orientation="vertical"
className="h-40"
/>
<ParameterSlider
params={stableAudioParams}
onChange={handleChange}
label="Preview Every"
name="preview_every"
min="0"
max="100"
step="1"
orientation="vertical"
className="h-40"
/>
</div>
<StableAudioModels
params={stableAudioParams}
onChange={handleChange}
label="Model"
name="model_name"
/>
</div>
</div>
);
};
1 change: 0 additions & 1 deletion react-ui/src/components/ui/slider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ const Slider = React.forwardRef<
>(({ className, ...props }, ref) => (
<SliderPrimitive.Root
ref={ref}
// eslint-disable-next-line tailwindcss/no-contradicting-classname
className={cn(
"relative flex touch-none select-none",
"data-[orientation='horizontal']:h-2 data-[orientation='horizontal']:w-full data-[orientation='horizontal']:items-center",
Expand Down
17 changes: 17 additions & 0 deletions react-ui/src/functions/generateWithStableAudio.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import {
StableAudioParams,
StableAudioResult,
} from "../tabs/StableAudioParams";
import { remove_use_random_seed } from "./remove_use_random_seed";

export async function generateWithStableAudio(
stableAudioParams: StableAudioParams
) {
const body = JSON.stringify(remove_use_random_seed(stableAudioParams));
const response = await fetch("/api/gradio/stable_audio_generate", {
method: "POST",
body,
});

return (await response.json()) as StableAudioResult;
}
33 changes: 33 additions & 0 deletions react-ui/src/pages/api/gradio/[name].tsx
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,32 @@ async function musicgen({ melody, model, ...params }) {
return { audio, metadata, folder_root };
}

const stable_audio_generate = async ({ init_audio, text, ...params }) => {
const exampleAudio = await getFile(init_audio);
const result = await gradioPredict("/stable_audio_generate", {
prompt: text, // temporary fix until forking stable-audio
// prompt: "Hello!!",
// negative_prompt: "Hello!!",
// seconds_start: 0,
// seconds_total: 0,
// cfg_scale: 0,
// steps: 1,
// preview_every: 0,
// seed: "Hello!!",
// sampler_type: "dpmpp-2m-sde",
// sigma_min: 0,
// sigma_max: 0,
// cfg_rescale: 0,
// use_init: true,
init_audio: exampleAudio,
// init_noise_level: 0.1,
...params,
});

const [audio, gallery] = result?.data;
return { audio, gallery };
};

const bark_voice_tokenizer_load = ({ tokenizer, use_gpu }) =>
gradioPredict<[string]>("/bark_voice_tokenizer_load", [
tokenizer,
Expand Down Expand Up @@ -547,4 +573,11 @@ const endpoints = {
musicgen_audiogen_unload_model: passThrough(
"/musicgen_audiogen_unload_model"
),
// stable_audio_generate: passThrough("/stable_audio_generate"),
stable_audio_generate,
stable_audio_inpaint: passThrough("/stable_audio_inpaint"),
stable_audio_get_models: () =>
gradioPredict<[GradioChoices]>("/stable_audio_get_models").then((result) =>
extractChoicesTuple(result?.data[0])
),
};
56 changes: 56 additions & 0 deletions react-ui/src/pages/stable-audio.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import React from "react";
import { Template } from "../components/Template";
import { AudioOutput } from "../components/AudioComponents";
import { HyperParameters } from "../components/HyperParameters";
import { GenerationHistorySimple } from "../components/GenerationHistory";
import { useStableAudioPage } from "../tabs/StableAudioParams";
import { StableAudioInputs } from "../components/StableAudioInputs";

const StableAudioPage = () => {
const {
stableAudioParams,
setStableAudioParams,
historyData,
setHistoryData,
consumer: stableAudioConsumer,
handleChange,
funcs,
} = useStableAudioPage();

return (
<Template title="Stable Audio">
<div className="p-4 grid grid-cols-1 gap-4">
<StableAudioInputs
stableAudioParams={stableAudioParams}
handleChange={handleChange}
setStableAudioParams={setStableAudioParams}
/>
<div className="flex gap-2 col-span-2">
<AudioOutput
audioOutput={historyData[0]?.audio}
label="Stable Audio Output"
funcs={funcs}
metadata={historyData[0]}
filter={["sendToStableAudio"]}
/>
<HyperParameters
genParams={stableAudioParams}
consumer={stableAudioConsumer}
prefix="stableAudio"
/>
</div>

<GenerationHistorySimple
name="stableAudio"
setHistoryData={setHistoryData}
historyData={historyData}
funcs={funcs}
nameKey="folder_root"
filter={["sendToStableAudio"]}
/>
</div>
</Template>
);
};

export default StableAudioPage;
106 changes: 106 additions & 0 deletions react-ui/src/tabs/StableAudioParams.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import useLocalStorage from "../hooks/useLocalStorage";
import { useHistory } from "../hooks/useHistory";
import { parseFormChange } from "../data/parseFormChange";
import { useSeedHelper } from "../functions/results/useSeedHelper";
import { favorite } from "../functions/favorite";
import { Seeded } from "../types/Seeded";
import { GradioFile } from "../types/GradioFile";
import { generateWithStableAudio } from "../functions/generateWithStableAudio";

export const stableAudioId = "stableAudioParams.v3";

export type StableAudioParams = Seeded & {
text: string;
negative_prompt: string;
seconds_start: number;
seconds_total: number;
cfg_scale: number;
steps: number;
preview_every: number;
sampler_type: string;
sigma_min: number;
sigma_max: number;
cfg_rescale: number;
use_init: boolean;
init_audio?: string; // GradioFile;
init_noise_level: number;
};

export const initialStableAudioParams: StableAudioParams = {
seed: 0,
use_random_seed: true,

text: "lofi hip hop beats to relax/study to",
negative_prompt: "",
seconds_start: 0,
seconds_total: 60,
cfg_scale: 7,
steps: 100,
preview_every: 0,
sampler_type: "dpmpp-3m-sde",
sigma_min: 0.03,
sigma_max: 500,
cfg_rescale: 0,
use_init: false,
init_audio: undefined,
init_noise_level: 0.1,
};

export type StableAudioResult = {
audio: GradioFile;
folder_root: string;
metadata: {
_version: string;
_hash_version: string;
_type: string;
prompt: string;
negative_prompt: string;
seconds_start: number;
seconds_total: number;
cfg_scale: number;
steps: number;
preview_every: number;
date: string;
seed: string;
};
};

export function useStableAudioPage() {
const [stableAudioParams, setStableAudioParams] = useLocalStorage(
stableAudioId,
initialStableAudioParams
);

const [historyData, setHistoryData] =
useHistory<StableAudioResult>("stableAudio");

const consumer = async (params: StableAudioParams) => {
const result = await generateWithStableAudio(params);
setHistoryData((prev) => [result, ...prev]);
return result;
};

const funcs = {
favorite: (metadata: any) => favorite(metadata),
useSeed: useSeedHelper(setStableAudioParams),
useParameters: (_url: string, data?: StableAudioResult) => {
const params = data?.metadata;
if (!params) return;
setStableAudioParams({
...stableAudioParams,
...params,
seed: Number(params.seed),
});
},
};

return {
stableAudioParams,
setStableAudioParams,
historyData,
setHistoryData,
consumer,
handleChange: parseFormChange(setStableAudioParams),
funcs,
};
}
6 changes: 3 additions & 3 deletions tts_webui/stable_audio/stable_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def save_result(audio, *generation_args):
def create_sampling_ui(model_config, inpainting=False):
with gr.Row():
with gr.Column(scale=6):
prompt = gr.Textbox(show_label=False, placeholder="Prompt")
text = gr.Textbox(show_label=False, placeholder="Prompt")
negative_prompt = gr.Textbox(
show_label=False, placeholder="Negative prompt"
)
Expand Down Expand Up @@ -461,7 +461,7 @@ def create_sampling_ui(model_config, inpainting=False):
) # still working on the usefulness of this

inputs = [
prompt,
text,
negative_prompt,
seconds_start_slider,
seconds_total_slider,
Expand Down Expand Up @@ -499,7 +499,7 @@ def create_sampling_ui(model_config, inpainting=False):
)

inputs = [
prompt,
text,
negative_prompt,
seconds_start_slider,
seconds_total_slider,
Expand Down

0 comments on commit 1a61b1a

Please sign in to comment.