Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update hubconf.py to streamline model loading processes and improve configuration management. #62

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 36 additions & 48 deletions hubconf.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,57 @@
dependencies = ['torch', 'torchaudio', 'numpy', 'vocos', 'safetensors']
[Inline review thread by NourMerey — marked as resolved.]

import logging
import os
from pathlib import Path
from safetensors import safe_open

import torch
from inference import Mars5TTS, InferenceConfig

# Release assets for the v0.3 English checkpoints: PyTorch pickles (.pt)
# and their safetensors counterparts.
_CKPT_BASE = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3"

ar_url = f"{_CKPT_BASE}/mars5_en_checkpoints_ar-2000000.pt"
nar_url = f"{_CKPT_BASE}/mars5_en_checkpoints_nar-1980000.pt"

ar_sf_url = f"{_CKPT_BASE}/mars5_en_checkpoints_ar-2000000.safetensors"
nar_sf_url = f"{_CKPT_BASE}/mars5_en_checkpoints_nar-1980000.safetensors"

def mars5_english(pretrained=True, progress=True, device=None, ckpt_format='safetensors',
                  ar_path=None, nar_path=None) -> Mars5TTS:
    """Load the MARS5 English model on `device`, optionally showing `progress`.

    Args:
        pretrained: must be True — only pretrained weights are published.
        progress: show a download progress bar while fetching checkpoints.
        device: torch device string; defaults to 'cuda' when available, else 'cpu'.
        ckpt_format: 'safetensors' or 'pt' — which release asset to download.
        ar_path / nar_path: optional local checkpoint paths; when given, the
            corresponding download is skipped and the file is torch.load'ed.

    Returns:
        A `(Mars5TTS, InferenceConfig)` tuple — the initialized model plus the
        config class callers use to build inference settings.

    Raises:
        AssertionError: on an unknown `ckpt_format` or `pretrained=False`
            (explicit `raise`, so validation survives `python -O`).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logging.info(f"Using device: {device}")

    # Explicit raises instead of `assert`/`== False`: bare asserts are stripped
    # under -O; AssertionError is kept so existing callers' handlers still match.
    if ckpt_format not in ('safetensors', 'pt'):
        raise AssertionError("checkpoint format must be 'safetensors' or 'pt'")
    if not pretrained:
        raise AssertionError('Only pretrained model currently supported.')

    logging.info("Loading AR checkpoint...")
    if ar_path is None:
        if ckpt_format == 'safetensors':
            ar_ckpt = _load_safetensors_ckpt(ar_sf_url, progress=progress)
        else:  # 'pt' — validated above
            ar_ckpt = torch.hub.load_state_dict_from_url(
                ar_url, progress=progress, check_hash=False, map_location='cpu'
            )
    else:
        ar_ckpt = torch.load(str(ar_path), map_location='cpu')

    logging.info("Loading NAR checkpoint...")
    if nar_path is None:
        if ckpt_format == 'safetensors':
            nar_ckpt = _load_safetensors_ckpt(nar_sf_url, progress=progress)
        else:
            nar_ckpt = torch.hub.load_state_dict_from_url(
                nar_url, progress=progress, check_hash=False, map_location='cpu'
            )
    else:
        nar_ckpt = torch.load(str(nar_path), map_location='cpu')

    logging.info("Initializing modules...")
    mars5 = Mars5TTS(ar_ckpt, nar_ckpt, device=device)
    return mars5, InferenceConfig
# Packages torch.hub must see installed before loading this repo.
dependencies = ['torch', 'torchaudio', 'numpy', 'vocos', 'safetensors']

# Centralized checkpoint URLs for easy management and updates.
# 'ar'/'nar' are the PyTorch (.pt) assets; 'ar_sf'/'nar_sf' the safetensors ones.
CHECKPOINT_URLS = {
    "ar": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_ar-2000000.pt",
    "nar": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.pt",
    "ar_sf": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_ar-2000000.safetensors",
    "nar_sf": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.safetensors",
}

def load_checkpoint(url, progress=True, ckpt_format='pt'):
    """Download (if not already cached) and load a checkpoint from `url`.

    The file is cached under `<torch.hub.get_dir()>/checkpoints/` keyed by the
    URL's basename, so repeated calls skip the download.

    Args:
        url: release-asset URL of the checkpoint file.
        progress: show a download progress bar.
        ckpt_format: 'safetensors' routes through `_load_safetensors_ckpt`;
            anything else is treated as a torch pickle ('pt').

    Returns:
        The loaded checkpoint dict, with tensors mapped to CPU.
    """
    # Stdlib parser: `torch.hub.urlparse` is a private re-export, not public API.
    from urllib.parse import urlparse

    hub_dir = torch.hub.get_dir()
    model_dir = os.path.join(hub_dir, 'checkpoints')
    os.makedirs(model_dir, exist_ok=True)
    filename = os.path.basename(urlparse(url).path)
    cached_file = os.path.join(model_dir, filename)

    if not os.path.exists(cached_file):
        torch.hub.download_url_to_file(url, cached_file, None, progress=progress)

    if ckpt_format == 'safetensors':
        return _load_safetensors_ckpt(cached_file)
    return torch.load(cached_file, map_location='cpu')

def _load_safetensors_ckpt(file_path):
    """Load a safetensors checkpoint into ``{'vocab': ..., 'model': ...}``.

    The text/speech tokenizer models travel in the safetensors metadata; the
    weight tensors are materialized on CPU.
    """
    ckpt = {}
    with safe_open(file_path, framework='pt', device='cpu') as f:
        metadata = f.metadata()
        # Tokenizer definitions are stored as metadata strings, not tensors.
        ckpt['vocab'] = {'texttok.model': metadata['texttok.model'],
                         'speechtok.model': metadata['speechtok.model']}
        ckpt['model'] = {k: f.get_tensor(k) for k in f.keys()}
    return ckpt

def mars5_english(pretrained=True, progress=True, device=None, ckpt_format='safetensors', ar_path=None, nar_path=None):
    """Load the MARS5 English model on `device`, optionally showing `progress`.

    Args:
        pretrained: must be True — only pretrained weights are published.
        progress: show a download progress bar while fetching checkpoints.
        device: torch device string; defaults to 'cuda' when available, else 'cpu'.
        ckpt_format: 'safetensors' or 'pt' — which release asset to download.
        ar_path / nar_path: optional local checkpoint paths; when given, the
            corresponding download is skipped and the file is torch.load'ed.

    Returns:
        A `(Mars5TTS, InferenceConfig)` tuple.

    Raises:
        ValueError: when `pretrained` is False or `ckpt_format` is unknown.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logging.info(f"Using device: {device}")

    if not pretrained:
        raise ValueError('Only pretrained models are currently supported.')
    if ckpt_format not in ('safetensors', 'pt'):
        raise ValueError("checkpoint format must be 'safetensors' or 'pt'")

    # BUG FIX: the dict keys are 'ar'/'nar' (pt) and 'ar_sf'/'nar_sf'
    # (safetensors) — f'ar_{ckpt_format}' built nonexistent keys like
    # 'ar_safetensors' and raised KeyError on every call.
    suffix = '_sf' if ckpt_format == 'safetensors' else ''

    if ar_path is None:
        ar_ckpt = load_checkpoint(CHECKPOINT_URLS['ar' + suffix], progress, ckpt_format)
    else:
        ar_ckpt = torch.load(ar_path, map_location='cpu')

    if nar_path is None:
        nar_ckpt = load_checkpoint(CHECKPOINT_URLS['nar' + suffix], progress, ckpt_format)
    else:
        nar_ckpt = torch.load(nar_path, map_location='cpu')

    logging.info("Initializing models...")
    return Mars5TTS(ar_ckpt, nar_ckpt, device=device), InferenceConfig