Update hubconf.py to streamline model loading processes and improve configuration management. #62

Open · wants to merge 3 commits into master
107 changes: 61 additions & 46 deletions hubconf.py
@@ -4,66 +4,81 @@
import os
from pathlib import Path
from safetensors import safe_open

import torch
from inference import Mars5TTS, InferenceConfig

-ar_url = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_ar-2000000.pt"
-nar_url = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.pt"
-
-ar_sf_url = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_ar-2000000.safetensors"
-nar_sf_url = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.safetensors"

-def mars5_english(pretrained=True, progress=True, device=None, ckpt_format='safetensors',
-                  ar_path=None, nar_path=None) -> Mars5TTS:
-    """ Load mars5 english model on `device`, optionally show `progress`. """
-    if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-    assert ckpt_format in ['safetensors', 'pt'], "checkpoint format must be 'safetensors' or 'pt'"
-
-    logging.info(f"Using device: {device}")
-    if pretrained == False: raise AssertionError('Only pretrained model currently supported.')
-    logging.info("Loading AR checkpoint...")
-
-    if ar_path is None:
-        if ckpt_format == 'safetensors':
-            ar_ckpt = _load_safetensors_ckpt(ar_sf_url, progress=progress)
-        elif ckpt_format == 'pt':
-            ar_ckpt = torch.hub.load_state_dict_from_url(
-                ar_url, progress=progress, check_hash=False, map_location='cpu'
-            )
-    else: ar_ckpt = torch.load(str(ar_path), map_location='cpu')
+# Centralized checkpoint URLs for easy management and updates.
+# Keys are named "<model>_<format>" so they match the f-string lookups below.
+CHECKPOINT_URLS = {
+    "ar_pt": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.4/mars5_en_checkpoints_ar-3000000.pt",
+    "nar_pt": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.pt",
+    "ar_safetensors": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.4/mars5_en_checkpoints_ar-3000000.safetensors",
+    "nar_safetensors": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.safetensors"
+}

logging.info("Loading NAR checkpoint...")
if nar_path is None:
if ckpt_format == 'safetensors':
nar_ckpt = _load_safetensors_ckpt(nar_sf_url, progress=progress)
elif ckpt_format == 'pt':
nar_ckpt = torch.hub.load_state_dict_from_url(
nar_url, progress=progress, check_hash=False, map_location='cpu'
)
else: nar_ckpt = torch.load(str(nar_path), map_location='cpu')
logging.info("Initializing modules...")
mars5 = Mars5TTS(ar_ckpt, nar_ckpt, device=device)
return mars5, InferenceConfig


-def _load_safetensors_ckpt(url, progress):
-    """ Loads checkpoint from a safetensors file """
+def load_checkpoint(url, progress=True, ckpt_format='pt'):
+    """ Helper function to download and load a checkpoint, reducing duplication """
     hub_dir = torch.hub.get_dir()
     model_dir = os.path.join(hub_dir, 'checkpoints')
     os.makedirs(model_dir, exist_ok=True)
     parts = torch.hub.urlparse(url)
     filename = os.path.basename(parts.path)
     cached_file = os.path.join(model_dir, filename)

     if not os.path.exists(cached_file):
         # download it
         torch.hub.download_url_to_file(url, cached_file, None, progress=progress)
-    # load checkpoint
+
+    if ckpt_format == 'safetensors':
+        return _load_safetensors_ckpt(cached_file)
+    else:
+        return torch.load(cached_file, map_location='cpu')

+def _load_safetensors_ckpt(file_path):
+    """ Loads a safetensors checkpoint file """
     ckpt = {}
-    with safe_open(cached_file, framework='pt', device='cpu') as f:
+    with safe_open(file_path, framework='pt', device='cpu') as f:
         metadata = f.metadata()
         ckpt['vocab'] = {'texttok.model': metadata['texttok.model'], 'speechtok.model': metadata['speechtok.model']}
-        ckpt['model'] = {}
-        for k in f.keys(): ckpt['model'][k] = f.get_tensor(k)
+        ckpt['model'] = {k: f.get_tensor(k) for k in f.keys()}
     return ckpt
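
For reference, a minimal, hypothetical sketch of how a file that _load_safetensors_ckpt can read might be produced; the tensor names and tokenizer payloads below are placeholders, not the actual MARS5 contents:

    import torch
    from safetensors.torch import save_file

    # Placeholder model weights; real checkpoints hold the full state dict
    tensors = {'transformer.weight': torch.zeros(4, 4)}
    # Tokenizers travel as strings in the safetensors metadata block
    metadata = {
        'texttok.model': '<serialized text tokenizer>',
        'speechtok.model': '<serialized speech tokenizer>',
    }
    save_file(tensors, 'mars5_demo.safetensors', metadata=metadata)
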


+# Load the Mars5 English model on `device`, optionally showing progress.
+# This function also accepts user-provided checkpoint paths, supporting
+# both .pt and .safetensors formats.
+
+def mars5_english(pretrained=True, progress=True, device=None, ckpt_format='safetensors', ar_path=None, nar_path=None):
+    assert ckpt_format in ['safetensors', 'pt'], "checkpoint format must be 'safetensors' or 'pt'"
+
+    if device is None:
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    logging.info(f"Using device: {device}")
+
+    if not pretrained:
+        raise ValueError('Only pretrained models are currently supported.')

+    # Determine the checkpoint format from the file extension when paths are provided;
+    # local files go straight to the matching loader, URLs go through load_checkpoint.
+    if ar_path is not None:
+        if ar_path.endswith('.pt'):
+            ar_ckpt = torch.load(str(ar_path), map_location='cpu')
+        elif ar_path.endswith('.safetensors'):
+            ar_ckpt = _load_safetensors_ckpt(ar_path)
+        else:
+            raise NotImplementedError("Unsupported file format for ar_path. Please provide a .pt or .safetensors file.")
+    else:
+        ar_ckpt = load_checkpoint(CHECKPOINT_URLS[f'ar_{ckpt_format}'], progress, ckpt_format)
+
+    if nar_path is not None:
+        if nar_path.endswith('.pt'):
+            nar_ckpt = torch.load(str(nar_path), map_location='cpu')
+        elif nar_path.endswith('.safetensors'):
+            nar_ckpt = _load_safetensors_ckpt(nar_path)
+        else:
+            raise NotImplementedError("Unsupported file format for nar_path. Please provide a .pt or .safetensors file.")
+    else:
+        nar_ckpt = load_checkpoint(CHECKPOINT_URLS[f'nar_{ckpt_format}'], progress, ckpt_format)

logging.info("Initializing models...")
return Mars5TTS(ar_ckpt, nar_ckpt, device=device), InferenceConfig
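
After this refactor, loading should still work through the standard torch.hub entry point. A minimal usage sketch, assuming the published repo slug and entry-point name stay unchanged (the trust_repo flag is standard torch.hub usage, not something introduced by this PR):

    import torch

    # Downloads the checkpoints on first use and caches them under torch.hub's directory
    mars5, config_class = torch.hub.load(
        'Camb-ai/MARS5-TTS', 'mars5_english',
        ckpt_format='safetensors', trust_repo=True
    )
    cfg = config_class()  # InferenceConfig instance, assuming its fields have defaults

Note that the entry point returns the InferenceConfig class itself rather than an instance, so callers construct their own configuration object.
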