Option to set a custom random seed #10

Closed · wants to merge 6 commits
6 changes: 5 additions & 1 deletion .gitignore
@@ -13,4 +13,8 @@ __pycache__/
/docs/harald_24000.wav
/docs/harald.wav
/models
/bark/assets/prompts/MeMyselfAndI.npz
.venv/*
build/*
Outputs/*
suno_bark.egg-info/*
/bark/assets/prompts/MeMyselfAndI.npz
2 changes: 1 addition & 1 deletion bark/__init__.py
@@ -1,2 +1,2 @@
from .api import generate_audio, text_to_semantic, semantic_to_waveform, save_as_prompt
from .api import generate_audio, text_to_semantic, semantic_to_waveform, save_as_prompt, set_seed
from .generation import SAMPLE_RATE, preload_models
45 changes: 45 additions & 0 deletions bark/api.py
@@ -2,6 +2,10 @@

import numpy as np

import torch
import random
import os

from .generation import codec_decode, generate_coarse, generate_fine, generate_text_semantic


@@ -123,3 +127,44 @@ def generate_audio(
else:
audio_arr = out
return audio_arr

def set_seed(seed: int = 0):
"""Set the seed

seed = 0 Generate a random seed
seed = -1 Disable deterministic algorithms
0 < seed < 2**32 Set the seed

Args:
seed: integer to use as seed

Returns:
integer used as seed
"""

original_seed = seed

# For more information, see: https://pytorch.org/docs/stable/notes/randomness.html
if seed == -1:
# Disable deterministic
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
else:
# Enable deterministic
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if seed <= 0:
# Generate random seed
# Use default_rng() because it is independent of np.random.seed()
seed = np.random.default_rng().integers(1, 2**32 - 1)

assert(0 < seed and seed < 2**32)

np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

return original_seed if original_seed != 0 else seed
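
For reference, a minimal sketch of how the new set_seed helper could be used from a plain script, assuming the package is installed from this branch; the "v2/en_speaker_1" preset name and the output filename are only illustrative:

```python
from scipy.io.wavfile import write as write_wav

from bark import SAMPLE_RATE, generate_audio, preload_models, set_seed

preload_models()

# Fix the seed so repeated runs of the same prompt produce the same audio.
# Passing 0 instead would pick a random seed and return it, so it could be logged and reused.
used_seed = set_seed(1234)

audio_array = generate_audio(
    "Hello, my name is Suno.",
    history_prompt="v2/en_speaker_1",
    text_temp=0.7,
    waveform_temp=0.7,
)
write_wav("seeded_output.wav", SAMPLE_RATE, audio_array)

# Restore non-deterministic behaviour once reproducibility is no longer needed.
set_seed(-1)
```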
25 changes: 19 additions & 6 deletions webui.py
@@ -10,7 +10,7 @@
from xml.sax import saxutils
#import nltk

from bark import SAMPLE_RATE, generate_audio
from bark import SAMPLE_RATE, generate_audio, set_seed
from bark.clonevoice import clone_voice
from bark.generation import SAMPLE_RATE, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic
from scipy.io.wavfile import write as write_wav
@@ -54,7 +54,7 @@ def generate_with_settings(text_prompt, semantic_temp=0.7, semantic_top_k=50, se
return full_generation, codec_decode(x_fine_gen)
return codec_decode(x_fine_gen)

def generate_text_to_speech(text, selected_speaker, text_temp, waveform_temp, quick_generation, complete_settings, progress=gr.Progress(track_tqdm=True)):
def generate_text_to_speech(text, selected_speaker, text_temp, waveform_temp, quick_generation, complete_settings, random_seed_number, progress=gr.Progress(track_tqdm=True)):
if text == None or len(text) < 1:
raise gr.Error('No text entered!')

@@ -80,6 +80,8 @@ def generate_text_to_speech(text, selected_speaker, text_temp, waveform_temp, qu
use_coarse_history_prompt = "Use coarse history" in complete_settings
use_fine_history_prompt = "Use fine history" in complete_settings
use_last_generation_as_history = "Use last generation as history" in complete_settings
random_seed = int(random_seed_number)

progress(0, desc="Generating")

silenceshort = np.zeros(int(0.25 * SAMPLE_RATE), dtype=np.float32) # quarter second of silence
@@ -91,6 +93,8 @@ def generate_text_to_speech(text, selected_speaker, text_temp, waveform_temp, qu
list_speak = create_clips_from_ssml(text)
prev_speaker = None
for i, clip in tqdm(enumerate(list_speak), total=len(list_speak)):
# set seed for consistent generation
set_seed(random_seed)
selected_speaker = clip[0]
# Add pause break between speakers
if i > 0 and selected_speaker != prev_speaker:
@@ -109,6 +113,8 @@ def generate_text_to_speech(text, selected_speaker, text_temp, waveform_temp, qu
else:
texts = split_and_recombine_text(text)
for i, text in tqdm(enumerate(texts), total=len(texts)):
# set seed for consistent generation
set_seed(random_seed)
print(f"\nGenerating Text ({i+1}/{len(texts)}) -> {selected_speaker}:`{text}`")
if quick_generation == True:
audio_array = generate_audio(text, selected_speaker, text_temp, waveform_temp)
@@ -150,6 +156,10 @@ def generate_text_to_speech(text, selected_speaker, text_temp, waveform_temp, qu
if text[-1] in "!?.\n" and i > 1:
all_parts += [silenceshort.copy()]

#all_parts += [audio_array, silencelong.copy()]
# reset seed
set_seed(-1)

# save & play audio
result = create_filename(OUTPUTFOLDER, "final",".wav")
save_wav(np.concatenate(all_parts), result)
@@ -298,9 +308,12 @@ def convert_text_to_ssml(text, selected_speaker):
with gr.Column():
quick_gen_checkbox = gr.Checkbox(label="Quick Generation", value=True)
with gr.Column():
settings_checkboxes = ["Use semantic history", "Use coarse history", "Use fine history", "Use last generation as history"]
complete_settings = gr.CheckboxGroup(choices=settings_checkboxes, value=settings_checkboxes, label="Detailed Generation Settings", type="value", interactive=True, visible=False)
quick_gen_checkbox.change(fn=on_quick_gen_changed, inputs=quick_gen_checkbox, outputs=complete_settings)
with gr.Row():
settings_checkboxes = ["Use semantic history", "Use coarse history", "Use fine history", "Use last generation as history"]
complete_settings = gr.CheckboxGroup(choices=settings_checkboxes, value=settings_checkboxes, label="Detailed Generation Settings", type="value", interactive=True, visible=False)
random_seed_number = gr.inputs.Number(label="Random Seed", default=-1)
#random_seed_settings = gr.NumberGroup([random_seed_number], label="Random Seed Settings", interactive=True, visible=False)
quick_gen_checkbox.change(fn=on_quick_gen_changed, inputs=quick_gen_checkbox, outputs=[complete_settings])

with gr.Row():
with gr.Column():
@@ -321,7 +334,7 @@ def convert_text_to_ssml(text, selected_speaker):
dummy = gr.Text(label="Progress")

convert_to_ssml_button.click(convert_text_to_ssml, inputs=[input_text, speaker],outputs=input_text)
tts_create_button.click(generate_text_to_speech, inputs=[input_text, speaker, text_temp, waveform_temp, quick_gen_checkbox, complete_settings],outputs=output_audio)
tts_create_button.click(generate_text_to_speech, inputs=[input_text, speaker, text_temp, waveform_temp, quick_gen_checkbox, complete_settings, random_seed_number],outputs=output_audio)
# Javascript hack to display modal confirmation dialog
js = "(x) => confirm('Are you sure? This will remove all files from output folder')"
button_delete_files.click(None, None, hidden_checkbox, _js=js)
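
Side note on the hunk above: gr.inputs.Number is Gradio's legacy (pre-3.x) input API. If the UI targets Gradio 3.x, the same control would typically be created with the top-level component instead (a sketch only, not part of this PR):

```python
# Gradio 3.x style: `value` replaces the legacy `default`, and precision=0 keeps the seed an integer.
random_seed_number = gr.Number(label="Random Seed", value=-1, precision=0)
```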