Skip to content

Commit

Permalink
automatically use cuda for mms (#379)
Browse files Browse the repository at this point in the history
  • Loading branch information
rsxdalv authored Sep 23, 2024
1 parent 644017e commit 1e3e75a
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions tts_webui/mms/mms_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from transformers import VitsTokenizer, VitsModel
import gradio as gr

from tts_webui.decorators.gradio_dict_decorator import dictionarize, gradio_dict_decorator
from tts_webui.decorators.gradio_dict_decorator import dictionarize
from tts_webui.utils.manage_model_state import manage_model_state
from tts_webui.utils.list_dir_models import unload_model_button
from tts_webui.decorators.decorator_apply_torch_seed import decorator_apply_torch_seed
Expand All @@ -24,11 +24,15 @@

@manage_model_state("mms")
def preload_models_if_needed(language="eng") -> tuple[VitsModel, VitsTokenizer]:
return VitsModel.from_pretrained(
f"facebook/mms-tts-{language}"
), VitsTokenizer.from_pretrained( # type: ignore
f"facebook/mms-tts-{language}"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = VitsModel.from_pretrained( # type: ignore
f"facebook/mms-tts-{language}",
)
model = model.to(device) # type: ignore
tokenizer = VitsTokenizer.from_pretrained( # type: ignore
f"facebook/mms-tts-{language}",
) # type: ignore
return model, tokenizer


@decorator_extension_outer
Expand All @@ -53,10 +57,10 @@ def generate_audio_with_mms(
model.speaking_rate = speaking_rate
model.noise_scale = noise_scale
model.noise_scale_duration = noise_scale_duration
inputs = tokenizer(text=text, return_tensors="pt")
inputs = tokenizer(text=text, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model(**inputs) # type: ignore
waveform = outputs.waveform[0].numpy().squeeze()
waveform = outputs.waveform[0].cpu().numpy().squeeze()
return {
"audio_out": (model.config.sampling_rate, waveform),
}
Expand Down

0 comments on commit 1e3e75a

Please sign in to comment.