Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

settings for model locations and new settings tab #90

Merged
merged 2 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ https://rsxdalv.github.io/bark-speaker-directory/

## Changelog
July 21:
* Fix hubert not working with CPU only
* Add Google Colab demo
* Fix hubert not working with CPU only (https://github.com/rsxdalv/tts-generation-webui/pull/87)
* Add Google Colab demo (https://github.com/rsxdalv/tts-generation-webui/pull/88)
* New settings tab and model locations (for advanced users) (https://github.com/rsxdalv/tts-generation-webui/pull/90)

July 19:
* Add Tortoise Optimizations, Thank you https://github.com/manmay-nakhashi https://github.com/rsxdalv/tts-generation-webui/pull/79 (Implements https://github.com/rsxdalv/tts-generation-webui/issues/18)
Expand Down
16 changes: 14 additions & 2 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,20 @@ def reload_config_and_restart_ui():
)
voices_tab(register_use_as_history_button)

settings_tab_bark()
settings_tab_gradio(reload_config_and_restart_ui, gradio_interface_options)
with gr.Tab("Settings"):
from src.settings_tab_gradio import settings_tab_gradio

settings_tab_gradio(reload_config_and_restart_ui, gradio_interface_options)

from src.bark.settings_tab_bark import settings_tab_bark

settings_tab_bark()
from src.utils.model_location_settings_tab import (
model_location_settings_tab,
)

model_location_settings_tab()

remixer_input = simple_remixer_tab()
Joutai.singleton.tabs.render()

Expand Down
41 changes: 19 additions & 22 deletions src/bark/settings_tab_bark.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ def settings_tab_bark() -> None:
with gr.Row(variant="panel"):
gr.Markdown("### Codec:")
codec_use_gpu = gr.Checkbox(
label="Use GPU for codec", value=config["model"]["codec_use_gpu"],
scale=2
label="Use GPU for codec",
value=config["model"]["codec_use_gpu"],
scale=2,
)

save_beacon = gr.Markdown("")
Expand Down Expand Up @@ -86,10 +87,6 @@ def settings_tab_bark() -> None:
label="Offload GPU models to CPU", value=ENV_OFFLOAD_CPU
)

save_environment_button = gr.Button(
value="Save Environment Variables and Exit"
)

def save_environment_variables(
environment_suno_use_small_models,
environment_suno_enable_mps,
Expand All @@ -100,24 +97,24 @@ def save_environment_variables(
)
os.environ["SUNO_ENABLE_MPS"] = str(environment_suno_enable_mps)
os.environ["SUNO_OFFLOAD_CPU"] = str(environment_suno_offload_cpu)
with open("../../.env", "w") as outfile:
outfile.write(
generate_env(
environment_suno_use_small_models,
environment_suno_enable_mps,
environment_suno_offload_cpu,
)
from src.utils.setup_or_recover import write_env

write_env(
generate_env(
environment_suno_use_small_models=environment_suno_use_small_models,
environment_suno_enable_mps=environment_suno_enable_mps,
environment_suno_offload_cpu=environment_suno_offload_cpu,
)
os._exit(0)
)

save_environment_button.click(
fn=save_environment_variables,
inputs=[
environment_suno_use_small_models,
environment_suno_enable_mps,
environment_suno_offload_cpu,
],
)
env_inputs = [
environment_suno_use_small_models,
environment_suno_enable_mps,
environment_suno_offload_cpu,
]

for i in env_inputs:
i.change(fn=save_environment_variables, inputs=env_inputs)

inputs = [
text_use_gpu,
Expand Down
50 changes: 26 additions & 24 deletions src/settings_tab_gradio.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,109 +40,111 @@ def settings_tab_gradio(
] = {
"inline": gr.Checkbox(
label="inline: Display inline in an iframe",
value=gradio_interface_options["inline"],
value=gradio_interface_options.get("inline", None),
),
"inbrowser": gr.Checkbox(
label="inbrowser: Automatically launch in a new tab",
value=gradio_interface_options["inbrowser"],
value=gradio_interface_options.get("inbrowser", None),
),
"share": gr.Checkbox(
label="share: Create a publicly shareable link",
value=gradio_interface_options["share"],
value=gradio_interface_options.get("share", None),
),
"debug": gr.Checkbox(
label="debug: Block the main thread from running",
value=gradio_interface_options["debug"],
value=gradio_interface_options.get("debug", None),
),
"enable_queue": gr.Checkbox(
label="enable_queue: Serve inference requests through a queue",
value=gradio_interface_options["enable_queue"],
value=gradio_interface_options.get("enable_queue", None),
),
"max_threads": gr.Slider(
minimum=1,
maximum=100,
step=1,
label="max_threads: Maximum number of total threads",
value=gradio_interface_options["max_threads"],
value=gradio_interface_options.get("max_threads", None),
),
"auth": gr.Textbox(
label="auth: Username and password required to access interface, username:password",
value=gradio_interface_options["auth"],
value=gradio_interface_options.get("auth", None),
),
"auth_message": gr.Textbox(
label="auth_message: HTML message provided on login page",
value=gradio_interface_options["auth_message"],
value=gradio_interface_options.get("auth_message", None),
),
"prevent_thread_lock": gr.Checkbox(
label="prevent_thread_lock: Block the main thread while the server is running",
value=gradio_interface_options["prevent_thread_lock"],
value=gradio_interface_options.get("prevent_thread_lock", None),
),
"show_error": gr.Checkbox(
label="show_error: Display errors in an alert modal",
value=gradio_interface_options["show_error"],
value=gradio_interface_options.get("show_error", None),
),
"server_name": gr.Textbox(
label="server_name: Make app accessible on local network",
value=gradio_interface_options["server_name"],
value=gradio_interface_options.get("server_name", None),
),
"server_port": gr.Textbox(
label="server_port: Start gradio app on this port",
value=gradio_interface_options["server_port"],
value=gradio_interface_options.get("server_port", None),
),
"show_tips": gr.Checkbox(
label="show_tips: Show tips about new Gradio features",
value=gradio_interface_options["show_tips"],
value=gradio_interface_options.get("show_tips", None),
),
"height": gr.Slider(
minimum=100,
maximum=1000,
step=10,
label="height: Height in pixels of the iframe element",
value=gradio_interface_options["height"],
value=gradio_interface_options.get("height", None),
),
"width": gr.Slider(
minimum=100,
maximum=1000,
step=10,
label="width: Width in pixels of the iframe element",
value=gradio_interface_options["width"],
value=gradio_interface_options.get("width", None),
),
"favicon_path": gr.Textbox(
label="favicon_path: Path to a file (.png, .gif, or .ico) to use as the favicon",
value=gradio_interface_options["favicon_path"],
value=gradio_interface_options.get("favicon_path", None),
),
"ssl_keyfile": gr.Textbox(
label="ssl_keyfile: Path to a file to use as the private key file to create a local server "
"running on https",
value=gradio_interface_options["ssl_keyfile"],
value=gradio_interface_options.get("ssl_keyfile", None),
),
"ssl_certfile": gr.Textbox(
label="ssl_certfile: Path to a file to use as the signed certificate for https",
value=gradio_interface_options["ssl_certfile"],
value=gradio_interface_options.get("ssl_certfile", None),
),
"ssl_keyfile_password": gr.Textbox(
label="ssl_keyfile_password: Password to use with the ssl certificate for https",
value=gradio_interface_options["ssl_keyfile_password"],
value=gradio_interface_options.get(
"ssl_keyfile_password", None
),
),
"ssl_verify": gr.Checkbox(
label="ssl_verify: Skip certificate validation",
value=gradio_interface_options["ssl_verify"],
value=gradio_interface_options.get("ssl_verify", None),
),
"quiet": gr.Checkbox(
label="quiet: Suppress most print statements",
value=gradio_interface_options["quiet"],
value=gradio_interface_options.get("quiet", None),
),
"show_api": gr.Checkbox(
label="show_api: Show the api docs in the footer of the app",
value=gradio_interface_options["show_api"],
value=gradio_interface_options.get("show_api", None),
),
"file_directories": gr.Textbox(
label="file_directories: List of directories that gradio is allowed to serve files from",
value=gradio_interface_options["file_directories"],
value=gradio_interface_options.get("file_directories", None),
),
"_frontend": gr.Checkbox(
label="_frontend: Frontend",
value=gradio_interface_options["_frontend"],
value=gradio_interface_options.get("_frontend", None),
),
}

Expand Down
114 changes: 114 additions & 0 deletions src/utils/model_location_settings_tab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import gradio as gr

def model_location_settings():
    """Render the experimental model-location settings UI.

    Exposes textboxes for the cache-related environment variables
    (HUGGINGFACE_HUB_CACHE, HF_HOME, TORCH_HOME, XDG_CACHE_HOME).
    Every edit is applied immediately to the process environment and
    persisted to the project's .env file via src.utils.setup_or_recover;
    per the on-screen note, a restart is required before the new
    locations take effect.
    """
    with gr.Column():
        gr.Markdown("## Model Location Settings (Experimental!)")

        gr.Markdown("- Requires restart to apply")

        gr.Markdown(
            """
1. **HUGGINGFACE_HUB_CACHE**:
    - This environment variable is used to specify the location of the Hugging Face cache, which stores downloaded models and other assets used by Hugging Face Transformers library.
    - By default, the cache directory is usually set to your user's home directory.
    - You can customize the cache directory by either providing the `cache_dir` argument in the methods of Hugging Face Transformers library, or by setting this environment variable to the desired path.

2. **HF_HOME**:
    - This environment variable is also related to the Hugging Face cache.
    - It allows you to set a custom directory for the Hugging Face cache, overriding the default location (usually the user's home directory).
    - If `HUGGINGFACE_HUB_CACHE` is not set, this environment variable can be used as an alternative to specify the Hugging Face cache location.

3. **TORCH_HOME**:
    - This environment variable allows you to set a custom directory for the Torch Hub cache, similar to `PATH_TO_HUB_DIR`.
    - If `PATH_TO_HUB_DIR` is not set, you can use this environment variable to specify the Torch Hub cache location.

4. **XDG_CACHE_HOME**:
    - This environment variable allows you to set a custom directory for caching various applications, including the Torch Hub cache.
    - If `PATH_TO_HUB_DIR` and `TORCH_HOME` are not set, you can use this environment variable to specify the Torch Hub cache location.

Please note that the "tts-generation-webui Directory" is currently fixed and cannot be changed. This directory contains the TTS (Text-to-Speech) models used by the web UI. The models are stored inside the `data/models/` directory within the `tts-generation-webui` directory.

Using these environment variables allows for better management and sharing of cached assets between different installations and projects, making it easier to access and reuse models and data across different user interfaces and projects.
"""
        )

        # Hugging Face cache
        gr.Markdown("### Hugging Face Cache")
        # HUGGINGFACE_HUB_CACHE env variable
        model_location_hf_env_var = gr.Textbox(
            label="Environment: HUGGINGFACE_HUB_CACHE",
            value="",
            placeholder="Unset",
        )
        # HF_HOME env variable
        model_location_hf_env_var2 = gr.Textbox(
            label="Environment: HF_HOME",
            value="",
            placeholder="Unset",
        )

        # Torch Hub cache
        gr.Markdown("### Torch Hub Cache")
        # model_location_th_set_dir = gr.Textbox(
        #     label="Environment: PATH_TO_HUB_DIR", value="Default"
        # )  # for hub.set_dir(<PATH_TO_HUB_DIR>)
        model_location_th_home = gr.Textbox(
            label="Environment: TORCH_HOME",
            value="",
            placeholder="Unset, default: ~/.cache/torch/",
        )
        model_location_th_xdg = gr.Textbox(
            label="Environment: XDG_CACHE_HOME",
            value="",
            placeholder="Unset, default: ~/.cache/",
        )

        # tts-generation-webui directory (read-only for now)
        gr.Markdown("### tts-generation-webui Directory (can't be changed yet)")
        gr.Textbox(
            label="Model Location (TTS Default)",
            value="./tts-generation-webui/data/models/",
            interactive=False,
        )

        inputs = [
            model_location_hf_env_var,
            model_location_hf_env_var2,
            model_location_th_home,
            model_location_th_xdg,
        ]

        # Status line that shows "saved" after each successful autosave.
        save_beacon = gr.Markdown()

        def save_environment_variables2(
            model_location_hf_env_var,
            model_location_hf_env_var2,
            model_location_th_home,
            model_location_th_xdg,
        ):
            """Apply the textbox values to os.environ and persist them.

            Fix: an empty textbox now means "leave unset" — the variable is
            removed from the environment rather than set to "". An empty
            value is not equivalent to an unset one for these cache
            variables (e.g. HF_HOME="" would point the cache at a bogus
            location), and the UI placeholder explicitly reads "Unset".
            """
            import os

            values = {
                "HUGGINGFACE_HUB_CACHE": model_location_hf_env_var,
                "HF_HOME": model_location_hf_env_var2,
                "TORCH_HOME": model_location_th_home,
                "XDG_CACHE_HOME": model_location_th_xdg,
            }
            for name, value in values.items():
                value = str(value).strip()
                if value:
                    os.environ[name] = value
                else:
                    os.environ.pop(name, None)

            from src.utils.setup_or_recover import generate_env, write_env

            # Persist to .env so the settings survive a restart.
            # NOTE(review): raw textbox values are forwarded unchanged so the
            # .env layout matches what generate_env expects — confirm against
            # src/utils/setup_or_recover.py.
            write_env(
                generate_env(
                    model_location_hf_env_var=model_location_hf_env_var,
                    model_location_hf_env_var2=model_location_hf_env_var2,
                    model_location_th_home=model_location_th_home,
                    model_location_th_xdg=model_location_th_xdg,
                )
            )
            return "saved"

        # Autosave on every edit of any field; the beacon reports completion.
        for i in inputs:
            i.change(
                fn=save_environment_variables2, inputs=inputs, outputs=[save_beacon]
            )


def model_location_settings_tab():
    """Mount the model-location settings UI inside its own Gradio tab."""
    tab = gr.Tab("Model Location Settings")
    with tab:
        model_location_settings()
Loading