diff --git a/docs/audiocraft/data/audio.html b/docs/audiocraft/data/audio.html new file mode 100644 index 00000000..617b6e7b --- /dev/null +++ b/docs/audiocraft/data/audio.html @@ -0,0 +1,522 @@ + + + + + + +audiocraft.data.audio API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.data.audio

+
+
+

Audio IO methods are defined in this module (info, read, write), +We rely on av library for faster read when possible, otherwise on torchaudio.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Audio IO methods are defined in this module (info, read, write),
+We rely on av library for faster read when possible, otherwise on torchaudio.
+"""
+
+from dataclasses import dataclass
+from pathlib import Path
+import logging
+import typing as tp
+
+import numpy as np
+import soundfile
+import torch
+from torch.nn import functional as F
+import torchaudio as ta
+
+import av
+
+from .audio_utils import f32_pcm, i16_pcm, normalize_audio
+
+
+_av_initialized = False
+
+
+def _init_av():
+    global _av_initialized
+    if _av_initialized:
+        return
+    logger = logging.getLogger('libav.mp3')
+    logger.setLevel(logging.ERROR)
+    _av_initialized = True
+
+
+@dataclass(frozen=True)
+class AudioFileInfo:
+    sample_rate: int
+    duration: float
+    channels: int
+
+
+def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+    _init_av()
+    with av.open(str(filepath)) as af:
+        stream = af.streams.audio[0]
+        sample_rate = stream.codec_context.sample_rate
+        duration = float(stream.duration * stream.time_base)
+        channels = stream.channels
+        return AudioFileInfo(sample_rate, duration, channels)
+
+
+def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+    info = soundfile.info(filepath)
+    return AudioFileInfo(info.samplerate, info.duration, info.channels)
+
+
+def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+    # torchaudio no longer returns useful duration informations for some formats like mp3s.
+    filepath = Path(filepath)
+    if filepath.suffix in ['.flac', '.ogg']:  # TODO: Validate .ogg can be safely read with av_info
+        # ffmpeg has some weird issue with flac.
+        return _soundfile_info(filepath)
+    else:
+        return _av_info(filepath)
+
+
+def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
+    """FFMPEG-based audio file reading using PyAV bindings.
+    Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
+
+    Args:
+        filepath (str or Path): Path to audio file to read.
+        seek_time (float): Time at which to start reading in the file.
+        duration (float): Duration to read from the file. If set to -1, the whole file is read.
+    Returns:
+        Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate
+    """
+    _init_av()
+    with av.open(str(filepath)) as af:
+        stream = af.streams.audio[0]
+        sr = stream.codec_context.sample_rate
+        num_frames = int(sr * duration) if duration >= 0 else -1
+        frame_offset = int(sr * seek_time)
+        # we need a small negative offset otherwise we get some edge artifact
+        # from the mp3 decoder.
+        af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
+        frames = []
+        length = 0
+        for frame in af.decode(streams=stream.index):
+            current_offset = int(frame.rate * frame.pts * frame.time_base)
+            strip = max(0, frame_offset - current_offset)
+            buf = torch.from_numpy(frame.to_ndarray())
+            if buf.shape[0] != stream.channels:
+                buf = buf.view(-1, stream.channels).t()
+            buf = buf[:, strip:]
+            frames.append(buf)
+            length += buf.shape[1]
+            if num_frames > 0 and length >= num_frames:
+                break
+        assert frames
+        # If the above assert fails, it is likely because we seeked past the end of file point,
+        # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
+        # This will need proper debugging, in due time.
+        wav = torch.cat(frames, dim=1)
+        assert wav.shape[0] == stream.channels
+        if num_frames > 0:
+            wav = wav[:, :num_frames]
+        return f32_pcm(wav), sr
+
+
+def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
+               duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
+    """Read audio by picking the most appropriate backend tool based on the audio format.
+
+    Args:
+        filepath (str or Path): Path to audio file to read.
+        seek_time (float): Time at which to start reading in the file.
+        duration (float): Duration to read from the file. If set to -1, the whole file is read.
+        pad (bool): Pad output audio if not reaching expected duration.
+    Returns:
+        Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate.
+    """
+    fp = Path(filepath)
+    if fp.suffix in ['.flac', '.ogg']:  # TODO: check if we can safely use av_read for .ogg
+        # There is some bug with ffmpeg and reading flac
+        info = _soundfile_info(filepath)
+        frames = -1 if duration <= 0 else int(duration * info.sample_rate)
+        frame_offset = int(seek_time * info.sample_rate)
+        wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
+        assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
+        wav = torch.from_numpy(wav).t().contiguous()
+        if len(wav.shape) == 1:
+            wav = torch.unsqueeze(wav, 0)
+    elif (
+        fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
+        and duration <= 0 and seek_time == 0
+    ):
+        # Torchaudio is faster if we load an entire file at once.
+        wav, sr = ta.load(fp)
+    else:
+        wav, sr = _av_read(filepath, seek_time, duration)
+    if pad and duration > 0:
+        expected_frames = int(duration * sr)
+        wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
+    return wav, sr
+
+
+def audio_write(stem_name: tp.Union[str, Path],
+                wav: torch.Tensor, sample_rate: int,
+                format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
+                strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+                rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+                loudness_compressor: bool = False,
+                log_clipping: bool = True, make_parent_dir: bool = True,
+                add_suffix: bool = True) -> Path:
+    """Convenience function for saving audio to disk. Returns the filename the audio was written to.
+
+    Args:
+        stem_name (str or Path): Filename without extension which will be added automatically.
+        format (str): Either "wav" or "mp3".
+        mp3_rate (int): kbps when using mp3s.
+        normalize (bool): if `True` (default), normalizes according to the prescribed
+            strategy (see after). If `False`, the strategy is only used in case clipping
+            would happen.
+        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
+            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
+            with extra headroom to avoid clipping. 'clip' just clips.
+        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
+        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
+            than the `peak_clip` one to avoid further clipping.
+        loudness_headroom_db (float): Target loudness for loudness normalization.
+        loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
+         when strategy is 'loudness'log_clipping (bool): If True, basic logging on stderr when clipping still
+            occurs despite strategy (only for 'rms').
+        make_parent_dir (bool): Make parent directory if it doesn't exist.
+    Returns:
+        Path: Path of the saved audio.
+    """
+    assert wav.dtype.is_floating_point, "wav is not floating point"
+    if wav.dim() == 1:
+        wav = wav[None]
+    elif wav.dim() > 2:
+        raise ValueError("Input wav should be at most 2 dimension.")
+    assert wav.isfinite().all()
+    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
+                          rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping,
+                          sample_rate=sample_rate, stem_name=str(stem_name))
+    kwargs: dict = {}
+    if format == 'mp3':
+        suffix = '.mp3'
+        kwargs.update({"compression": mp3_rate})
+    elif format == 'wav':
+        wav = i16_pcm(wav)
+        suffix = '.wav'
+        kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
+    else:
+        raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
+    if not add_suffix:
+        suffix = ''
+    path = Path(str(stem_name) + suffix)
+    if make_parent_dir:
+        path.parent.mkdir(exist_ok=True, parents=True)
+    try:
+        ta.save(path, wav, sample_rate, **kwargs)
+    except Exception:
+        if path.exists():
+            # we do not want to leave half written files around.
+            path.unlink()
+        raise
+    return path
+
+
+
+
+
+
+
+

Functions

+
+
+def audio_info(filepath: Union[str, pathlib.Path]) ‑> AudioFileInfo +
+
+
+
+ +Expand source code + +
def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+    # torchaudio no longer returns useful duration informations for some formats like mp3s.
+    filepath = Path(filepath)
+    if filepath.suffix in ['.flac', '.ogg']:  # TODO: Validate .ogg can be safely read with av_info
+        # ffmpeg has some weird issue with flac.
+        return _soundfile_info(filepath)
+    else:
+        return _av_info(filepath)
+
+
+
+def audio_read(filepath: Union[str, pathlib.Path], seek_time: float = 0.0, duration: float = -1.0, pad: bool = False) ‑> Tuple[torch.Tensor, int] +
+
+

Read audio by picking the most appropriate backend tool based on the audio format.

+

Args

+
+
filepath : str or Path
+
Path to audio file to read.
+
seek_time : float
+
Time at which to start reading in the file.
+
duration : float
+
Duration to read from the file. If set to -1, the whole file is read.
+
pad : bool
+
Pad output audio if not reaching expected duration.
+
+

Returns

+
+
Tuple[torch.Tensor, int]
+
Tuple containing audio data and sample rate.
+
+
+ +Expand source code + +
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
+               duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
+    """Read audio by picking the most appropriate backend tool based on the audio format.
+
+    Args:
+        filepath (str or Path): Path to audio file to read.
+        seek_time (float): Time at which to start reading in the file.
+        duration (float): Duration to read from the file. If set to -1, the whole file is read.
+        pad (bool): Pad output audio if not reaching expected duration.
+    Returns:
+        Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate.
+    """
+    fp = Path(filepath)
+    if fp.suffix in ['.flac', '.ogg']:  # TODO: check if we can safely use av_read for .ogg
+        # There is some bug with ffmpeg and reading flac
+        info = _soundfile_info(filepath)
+        frames = -1 if duration <= 0 else int(duration * info.sample_rate)
+        frame_offset = int(seek_time * info.sample_rate)
+        wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
+        assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
+        wav = torch.from_numpy(wav).t().contiguous()
+        if len(wav.shape) == 1:
+            wav = torch.unsqueeze(wav, 0)
+    elif (
+        fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
+        and duration <= 0 and seek_time == 0
+    ):
+        # Torchaudio is faster if we load an entire file at once.
+        wav, sr = ta.load(fp)
+    else:
+        wav, sr = _av_read(filepath, seek_time, duration)
+    if pad and duration > 0:
+        expected_frames = int(duration * sr)
+        wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
+    return wav, sr
+
+
+
+def audio_write(stem_name: Union[str, pathlib.Path], wav: torch.Tensor, sample_rate: int, format: str = 'wav', mp3_rate: int = 320, normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1, rms_headroom_db: float = 18, loudness_headroom_db: float = 14, loudness_compressor: bool = False, log_clipping: bool = True, make_parent_dir: bool = True, add_suffix: bool = True) ‑> pathlib.Path +
+
+

Convenience function for saving audio to disk. Returns the filename the audio was written to.

+

Args

+
+
stem_name : str or Path
+
Filename without extension which will be added automatically.
+
format : str
+
Either "wav" or "mp3".
+
mp3_rate : int
+
kbps when using mp3s.
+
normalize : bool
+
if True (default), normalizes according to the prescribed +strategy (see after). If False, the strategy is only used in case clipping +would happen.
+
strategy : str
+
Can be either 'clip', 'peak', or 'rms'. Default is 'peak', +i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square +with extra headroom to avoid clipping. 'clip' just clips.
+
peak_clip_headroom_db : float
+
Headroom in dB when doing 'peak' or 'clip' strategy.
+
rms_headroom_db : float
+
Headroom in dB when doing 'rms' strategy. This must be much larger +than the peak_clip one to avoid further clipping.
+
loudness_headroom_db : float
+
Target loudness for loudness normalization.
+
loudness_compressor : bool
+
Uses tanh for soft clipping when strategy is 'loudness'.
+
when strategy is 'loudness'log_clipping (bool): If True, basic logging on stderr when clipping still
+
occurs despite strategy (only for 'rms').
+
make_parent_dir : bool
+
Make parent directory if it doesn't exist.
+
+

Returns

+
+
Path
+
Path of the saved audio.
+
+
+ +Expand source code + +
def audio_write(stem_name: tp.Union[str, Path],
+                wav: torch.Tensor, sample_rate: int,
+                format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
+                strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+                rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+                loudness_compressor: bool = False,
+                log_clipping: bool = True, make_parent_dir: bool = True,
+                add_suffix: bool = True) -> Path:
+    """Convenience function for saving audio to disk. Returns the filename the audio was written to.
+
+    Args:
+        stem_name (str or Path): Filename without extension which will be added automatically.
+        format (str): Either "wav" or "mp3".
+        mp3_rate (int): kbps when using mp3s.
+        normalize (bool): if `True` (default), normalizes according to the prescribed
+            strategy (see after). If `False`, the strategy is only used in case clipping
+            would happen.
+        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
+            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
+            with extra headroom to avoid clipping. 'clip' just clips.
+        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
+        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
+            than the `peak_clip` one to avoid further clipping.
+        loudness_headroom_db (float): Target loudness for loudness normalization.
+        loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
+         when strategy is 'loudness'log_clipping (bool): If True, basic logging on stderr when clipping still
+            occurs despite strategy (only for 'rms').
+        make_parent_dir (bool): Make parent directory if it doesn't exist.
+    Returns:
+        Path: Path of the saved audio.
+    """
+    assert wav.dtype.is_floating_point, "wav is not floating point"
+    if wav.dim() == 1:
+        wav = wav[None]
+    elif wav.dim() > 2:
+        raise ValueError("Input wav should be at most 2 dimension.")
+    assert wav.isfinite().all()
+    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
+                          rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping,
+                          sample_rate=sample_rate, stem_name=str(stem_name))
+    kwargs: dict = {}
+    if format == 'mp3':
+        suffix = '.mp3'
+        kwargs.update({"compression": mp3_rate})
+    elif format == 'wav':
+        wav = i16_pcm(wav)
+        suffix = '.wav'
+        kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
+    else:
+        raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
+    if not add_suffix:
+        suffix = ''
+    path = Path(str(stem_name) + suffix)
+    if make_parent_dir:
+        path.parent.mkdir(exist_ok=True, parents=True)
+    try:
+        ta.save(path, wav, sample_rate, **kwargs)
+    except Exception:
+        if path.exists():
+            # we do not want to leave half written files around.
+            path.unlink()
+        raise
+    return path
+
+
+
+
+
+

Classes

+
+
+class AudioFileInfo +(sample_rate: int, duration: float, channels: int) +
+
+

AudioFileInfo(sample_rate: int, duration: float, channels: int)

+
+ +Expand source code + +
class AudioFileInfo:
+    sample_rate: int
+    duration: float
+    channels: int
+
+

Class variables

+
+
var channels : int
+
+
+
+
var duration : float
+
+
+
+
var sample_rate : int
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/data/audio_dataset.html b/docs/audiocraft/data/audio_dataset.html new file mode 100644 index 00000000..c3b10ba1 --- /dev/null +++ b/docs/audiocraft/data/audio_dataset.html @@ -0,0 +1,1539 @@ + + + + + + +audiocraft.data.audio_dataset API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.data.audio_dataset

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import copy
+from concurrent.futures import ThreadPoolExecutor, Future
+from dataclasses import dataclass, fields
+from contextlib import ExitStack
+import gzip
+import json
+import logging
+import os
+from pathlib import Path
+import random
+import sys
+import typing as tp
+
+import torch
+import torch.nn.functional as F
+
+from .audio import audio_read, audio_info
+from .audio_utils import convert_audio
+from .zip import PathInZip
+
+try:
+    import dora
+except ImportError:
+    dora = None  # type: ignore
+
+
+@dataclass(order=True)
+class BaseInfo:
+
+    @classmethod
+    def _dict2fields(cls, dictionary: dict):
+        return {
+            field.name: dictionary[field.name]
+            for field in fields(cls) if field.name in dictionary
+        }
+
+    @classmethod
+    def from_dict(cls, dictionary: dict):
+        _dictionary = cls._dict2fields(dictionary)
+        return cls(**_dictionary)
+
+    def to_dict(self):
+        return {
+            field.name: self.__getattribute__(field.name)
+            for field in fields(self)
+            }
+
+
+@dataclass(order=True)
+class AudioMeta(BaseInfo):
+    path: str
+    duration: float
+    sample_rate: int
+    amplitude: tp.Optional[float] = None
+    weight: tp.Optional[float] = None
+    # info_path is used to load additional information about the audio file that is stored in zip files.
+    info_path: tp.Optional[PathInZip] = None
+
+    @classmethod
+    def from_dict(cls, dictionary: dict):
+        base = cls._dict2fields(dictionary)
+        if 'info_path' in base and base['info_path'] is not None:
+            base['info_path'] = PathInZip(base['info_path'])
+        return cls(**base)
+
+    def to_dict(self):
+        d = super().to_dict()
+        if d['info_path'] is not None:
+            d['info_path'] = str(d['info_path'])
+        return d
+
+
+@dataclass(order=True)
+class SegmentInfo(BaseInfo):
+    meta: AudioMeta
+    seek_time: float
+    n_frames: int  # actual number of frames without padding
+    total_frames: int  # total number of frames, padding included
+    sample_rate: int  # actual sample rate
+
+
+DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
+
+logger = logging.getLogger(__name__)
+
+
+def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
+    """AudioMeta from a path to an audio file.
+
+    Args:
+        file_path (str): Resolved path of valid audio file.
+        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
+    Returns:
+        AudioMeta: Audio file path and its metadata.
+    """
+    info = audio_info(file_path)
+    amplitude: tp.Optional[float] = None
+    if not minimal:
+        wav, sr = audio_read(file_path)
+        amplitude = wav.abs().max().item()
+    return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)
+
+
+def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
+    """If Dora is available as a dependency, try to resolve potential relative paths
+    in list of AudioMeta. This method is expected to be used when loading meta from file.
+
+    Args:
+        m (AudioMeta): Audio meta to resolve.
+        fast (bool): If True, uses a really fast check for determining if a file is already absolute or not.
+            Only valid on Linux/Mac.
+    Returns:
+        AudioMeta: Audio meta with resolved path.
+    """
+    def is_abs(m):
+        if fast:
+            return str(m)[0] == '/'
+        else:
+            os.path.isabs(str(m))
+
+    if not dora:
+        return m
+
+    if not is_abs(m.path):
+        m.path = dora.git_save.to_absolute_path(m.path)
+    if m.info_path is not None and not is_abs(m.info_path.zip_path):
+        m.info_path.zip_path = dora.git_save.to_absolute_path(m.path)
+    return m
+
+
+def find_audio_files(path: tp.Union[Path, str],
+                     exts: tp.List[str] = DEFAULT_EXTS,
+                     resolve: bool = True,
+                     minimal: bool = True,
+                     progress: bool = False,
+                     workers: int = 0) -> tp.List[AudioMeta]:
+    """Build a list of AudioMeta from a given path,
+    collecting relevant audio files and fetching meta info.
+
+    Args:
+        path (str or Path): Path to folder containing audio files.
+        exts (list of str): List of file extensions to consider for audio files.
+        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
+        progress (bool): Whether to log progress on audio files collection.
+        workers (int): number of parallel workers, if 0, use only the current thread.
+    Returns:
+        List[AudioMeta]: List of audio file path and its metadata.
+    """
+    audio_files = []
+    futures: tp.List[Future] = []
+    pool: tp.Optional[ThreadPoolExecutor] = None
+    with ExitStack() as stack:
+        if workers > 0:
+            pool = ThreadPoolExecutor(workers)
+            stack.enter_context(pool)
+
+        if progress:
+            print("Finding audio files...")
+        for root, folders, files in os.walk(path, followlinks=True):
+            for file in files:
+                full_path = Path(root) / file
+                if full_path.suffix.lower() in exts:
+                    audio_files.append(full_path)
+                    if pool is not None:
+                        futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
+                    if progress:
+                        print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)
+
+        if progress:
+            print("Getting audio metadata...")
+        meta: tp.List[AudioMeta] = []
+        for idx, file_path in enumerate(audio_files):
+            try:
+                if pool is None:
+                    m = _get_audio_meta(str(file_path), minimal)
+                else:
+                    m = futures[idx].result()
+                if resolve:
+                    m = _resolve_audio_meta(m)
+            except Exception as err:
+                print("Error with", str(file_path), err, file=sys.stderr)
+                continue
+            meta.append(m)
+            if progress:
+                print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
+    meta.sort()
+    return meta
+
+
+def load_audio_meta(path: tp.Union[str, Path],
+                    resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
+    """Load list of AudioMeta from an optionally compressed json file.
+
+    Args:
+        path (str or Path): Path to JSON file.
+        resolve (bool): Whether to resolve the path from AudioMeta (default=True).
+        fast (bool): activates some tricks to make things faster.
+    Returns:
+        List[AudioMeta]: List of audio file path and its total duration.
+    """
+    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
+    with open_fn(path, 'rb') as fp:  # type: ignore
+        lines = fp.readlines()
+    meta = []
+    for line in lines:
+        d = json.loads(line)
+        m = AudioMeta.from_dict(d)
+        if resolve:
+            m = _resolve_audio_meta(m, fast=fast)
+        meta.append(m)
+    return meta
+
+
+def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
+    """Save the audio metadata to the file pointer as json.
+
+    Args:
+        path (str or Path): Path to JSON file.
+        metadata (list of BaseAudioMeta): List of audio meta to save.
+    """
+    Path(path).parent.mkdir(exist_ok=True, parents=True)
+    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
+    with open_fn(path, 'wb') as fp:  # type: ignore
+        for m in meta:
+            json_str = json.dumps(m.to_dict()) + '\n'
+            json_bytes = json_str.encode('utf-8')
+            fp.write(json_bytes)
+
+
+class AudioDataset:
+    """Base audio dataset.
+
+    The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
+    and potentially additional information, by creating random segments from the list of audio
+    files referenced in the metadata and applying minimal data pre-processing such as resampling,
+    mixing of channels, padding, etc.
+
+    If no segment_duration value is provided, the AudioDataset will return the full wav for each
+    audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
+    duration, applying padding if required.
+
+    By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
+    allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
+    original audio meta.
+
+    Args:
+        meta (tp.List[AudioMeta]): List of audio files metadata.
+        segment_duration (float): Optional segment duration of audio to load.
+            If not specified, the dataset will load the full audio segment from the file.
+        shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
+        sample_rate (int): Target sample rate of the loaded audio samples.
+        channels (int): Target number of channels of the loaded audio samples.
+        sample_on_duration (bool): Set to `True` to sample segments with probability
+            dependent on audio file duration. This is only used if `segment_duration` is provided.
+        sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
+            `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
+            of the file duration and file weight. This is only used if `segment_duration` is provided.
+        min_segment_ratio (float): Minimum segment ratio to use when the audio file
+            is shorter than the desired segment.
+        max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
+        return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
+        min_audio_duration (tp.Optional[float], optional): Minimum audio file duration, in seconds, if provided
+            audio shorter than this will be filtered out.
+        max_audio_duration (tp.Optional[float], optional): Maximal audio file duration in seconds, if provided
+            audio longer than this will be filtered out.
+    """
+    def __init__(self,
+                 meta: tp.List[AudioMeta],
+                 segment_duration: tp.Optional[float] = None,
+                 shuffle: bool = True,
+                 num_samples: int = 10_000,
+                 sample_rate: int = 48_000,
+                 channels: int = 2,
+                 pad: bool = True,
+                 sample_on_duration: bool = True,
+                 sample_on_weight: bool = True,
+                 min_segment_ratio: float = 0.5,
+                 max_read_retry: int = 10,
+                 return_info: bool = False,
+                 min_audio_duration: tp.Optional[float] = None,
+                 max_audio_duration: tp.Optional[float] = None
+                 ):
+        assert len(meta) > 0, 'No audio meta provided to AudioDataset. Please check loading of audio meta.'
+        assert segment_duration is None or segment_duration > 0
+        assert segment_duration is None or min_segment_ratio >= 0
+        logging.debug(f'sample_on_duration: {sample_on_duration}')
+        logging.debug(f'sample_on_weight: {sample_on_weight}')
+        logging.debug(f'pad: {pad}')
+        logging.debug(f'min_segment_ratio: {min_segment_ratio}')
+
+        self.segment_duration = segment_duration
+        self.min_segment_ratio = min_segment_ratio
+        self.max_audio_duration = max_audio_duration
+        self.min_audio_duration = min_audio_duration
+        if self.min_audio_duration is not None and self.max_audio_duration is not None:
+            assert self.min_audio_duration <= self.max_audio_duration
+        self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
+        assert len(self.meta)  # Fail fast if all data has been filtered.
+        self.total_duration = sum(d.duration for d in self.meta)
+
+        if segment_duration is None:
+            num_samples = len(self.meta)
+        self.num_samples = num_samples
+        self.shuffle = shuffle
+        self.sample_rate = sample_rate
+        self.channels = channels
+        self.pad = pad
+        self.sample_on_weight = sample_on_weight
+        self.sample_on_duration = sample_on_duration
+        self.sampling_probabilities = self._get_sampling_probabilities()
+        self.max_read_retry = max_read_retry
+        self.return_info = return_info
+
+    def __len__(self):
+        return self.num_samples
+
+    def _get_sampling_probabilities(self, normalized: bool = True):
+        """Return the sampling probabilities for each file inside `self.meta`.
+        """
+        scores: tp.List[float] = []
+        for file_meta in self.meta:
+            score = 1.
+            if self.sample_on_weight and file_meta.weight is not None:
+                score *= file_meta.weight
+            if self.sample_on_duration:
+                score *= file_meta.duration
+            scores.append(score)
+        probabilities = torch.tensor(scores)
+        if normalized:
+            probabilities /= probabilities.sum()
+        return probabilities
+
+    def sample_file(self, rng: torch.Generator) -> AudioMeta:
+        """Sample a given file from `self.meta`. Can be overriden in subclasses.
+        This is only called if `segment_duration` is not None.
+
+        You must use the provided random number generator `rng` for reproducibility.
+        """
+        if not self.sample_on_weight and not self.sample_on_duration:
+            file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
+        else:
+            file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
+
+        return self.meta[file_index]
+
+    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
+        if self.segment_duration is None:
+            file_meta = self.meta[index]
+            out, sr = audio_read(file_meta.path)
+            out = convert_audio(out, sr, self.sample_rate, self.channels)
+            n_frames = out.shape[-1]
+            segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
+                                       sample_rate=self.sample_rate)
+        else:
+            rng = torch.Generator()
+            if self.shuffle:
+                # We use index, plus extra randomness
+                rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
+            else:
+                # We only use index
+                rng.manual_seed(index)
+
+            for retry in range(self.max_read_retry):
+                file_meta = self.sample_file(rng)
+                # We add some variance in the file position even if audio file is smaller than segment
+                # without ending up with empty segments
+                max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
+                seek_time = torch.rand(1, generator=rng).item() * max_seek
+                try:
+                    out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
+                    out = convert_audio(out, sr, self.sample_rate, self.channels)
+                    n_frames = out.shape[-1]
+                    target_frames = int(self.segment_duration * self.sample_rate)
+                    if self.pad:
+                        out = F.pad(out, (0, target_frames - n_frames))
+                    segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
+                                               sample_rate=self.sample_rate)
+                except Exception as exc:
+                    logger.warning("Error opening file %s: %r", file_meta.path, exc)
+                    if retry == self.max_read_retry - 1:
+                        raise
+                else:
+                    break
+
+        if self.return_info:
+            # Returns the wav and additional information on the wave segment
+            return out, segment_info
+        else:
+            return out
+
+    def collater(self, samples):
+        """The collater function has to be provided to the dataloader
+        if AudioDataset has return_info=True in order to properly collate
+        the samples of a batch.
+        """
+        if self.segment_duration is None and len(samples) > 1:
+            assert self.pad, "Must allow padding when batching examples of different durations."
+
+        # In this case the audio reaching the collater is of variable length as segment_duration=None.
+        to_pad = self.segment_duration is None and self.pad
+        if to_pad:
+            max_len = max([wav.shape[-1] for wav, _ in samples])
+
+            def _pad_wav(wav):
+                return F.pad(wav, (0, max_len - wav.shape[-1]))
+
+        if self.return_info:
+            if len(samples) > 0:
+                assert len(samples[0]) == 2
+                assert isinstance(samples[0][0], torch.Tensor)
+                assert isinstance(samples[0][1], SegmentInfo)
+
+            wavs = [wav for wav, _ in samples]
+            segment_infos = [copy.deepcopy(info) for _, info in samples]
+
+            if to_pad:
+                # Each wav could be of a different duration as they are not segmented.
+                for i in range(len(samples)):
+                    # Determines the total legth of the signal with padding, so we update here as we pad.
+                    segment_infos[i].total_frames = max_len
+                    wavs[i] = _pad_wav(wavs[i])
+
+            wav = torch.stack(wavs)
+            return wav, segment_infos
+        else:
+            assert isinstance(samples[0], torch.Tensor)
+            if to_pad:
+                samples = [_pad_wav(s) for s in samples]
+            return torch.stack(samples)
+
+    def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
+        """Filters out audio files with short durations.
+        Removes from meta files that have durations that will not allow to samples examples from them.
+        """
+        orig_len = len(meta)
+
+        # Filter data that is too short.
+        if self.min_audio_duration is not None:
+            meta = [m for m in meta if m.duration >= self.min_audio_duration]
+
+        # Filter data that is too long.
+        if self.max_audio_duration is not None:
+            meta = [m for m in meta if m.duration <= self.max_audio_duration]
+
+        filtered_len = len(meta)
+        removed_percentage = 100*(1-float(filtered_len)/orig_len)
+        msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
+        if removed_percentage < 10:
+            logging.debug(msg)
+        else:
+            logging.warning(msg)
+        return meta
+
+    @classmethod
+    def from_meta(cls, root: tp.Union[str, Path], **kwargs):
+        """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
+
+        Args:
+            root (str or Path): Path to root folder containing audio files.
+            kwargs: Additional keyword arguments for the AudioDataset.
+        """
+        root = Path(root)
+        if root.is_dir():
+            if (root / 'data.jsonl').exists():
+                root = root / 'data.jsonl'
+            elif (root / 'data.jsonl.gz').exists():
+                root = root / 'data.jsonl.gz'
+            else:
+                raise ValueError("Don't know where to read metadata from in the dir. "
+                                 "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
+        meta = load_audio_meta(root)
+        return cls(meta, **kwargs)
+
+    @classmethod
+    def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
+                  exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
+        """Instantiate AudioDataset from a path containing (possibly nested) audio files.
+
+        Args:
+            root (str or Path): Path to root folder containing audio files.
+            minimal_meta (bool): Whether to only load minimal metadata or not.
+            exts (list of str): Extensions for audio files.
+            kwargs: Additional keyword arguments for the AudioDataset.
+        """
+        root = Path(root)
+        if root.is_file():
+            meta = load_audio_meta(root, resolve=True)
+        else:
+            meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
+        return cls(meta, **kwargs)
+
+
+def main():
+    logging.basicConfig(stream=sys.stderr, level=logging.INFO)
+    parser = argparse.ArgumentParser(
+        prog='audio_dataset',
+        description='Generate .jsonl files by scanning a folder.')
+    parser.add_argument('root', help='Root folder with all the audio files')
+    parser.add_argument('output_meta_file',
+                        help='Output file to store the metadata, ')
+    parser.add_argument('--complete',
+                        action='store_false', dest='minimal', default=True,
+                        help='Retrieve all metadata, even the one that are expansive '
+                             'to compute (e.g. normalization).')
+    parser.add_argument('--resolve',
+                        action='store_true', default=False,
+                        help='Resolve the paths to be absolute and with no symlinks.')
+    parser.add_argument('--workers',
+                        default=10, type=int,
+                        help='Number of workers.')
+    args = parser.parse_args()
+    meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
+                            resolve=args.resolve, minimal=args.minimal, workers=args.workers)
+    save_audio_meta(args.output_meta_file, meta)
+
+
+if __name__ == '__main__':
+    main()
+
+
+
+
+
+
+
+

Functions

+
+
+def find_audio_files(path: Union[str, pathlib.Path], exts: List[str] = ['.wav', '.mp3', '.flac', '.ogg', '.m4a'], resolve: bool = True, minimal: bool = True, progress: bool = False, workers: int = 0) ‑> List[AudioMeta] +
+
+

Build a list of AudioMeta from a given path, +collecting relevant audio files and fetching meta info.

+

Args

+
+
path : str or Path
+
Path to folder containing audio files.
+
exts : list of str
+
List of file extensions to consider for audio files.
+
minimal : bool
+
Whether to only load the minimal set of metadata (takes longer if not).
+
progress : bool
+
Whether to log progress on audio files collection.
+
workers : int
+
number of parallel workers, if 0, use only the current thread.
+
+

Returns

+
+
List[AudioMeta]
+
List of audio file path and its metadata.
+
+
+ +Expand source code + +
def find_audio_files(path: tp.Union[Path, str],
+                     exts: tp.List[str] = DEFAULT_EXTS,
+                     resolve: bool = True,
+                     minimal: bool = True,
+                     progress: bool = False,
+                     workers: int = 0) -> tp.List[AudioMeta]:
+    """Build a list of AudioMeta from a given path,
+    collecting relevant audio files and fetching meta info.
+
+    Args:
+        path (str or Path): Path to folder containing audio files.
+        exts (list of str): List of file extensions to consider for audio files.
+        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
+        progress (bool): Whether to log progress on audio files collection.
+        workers (int): number of parallel workers, if 0, use only the current thread.
+    Returns:
+        List[AudioMeta]: List of audio file path and its metadata.
+    """
+    audio_files = []
+    futures: tp.List[Future] = []
+    pool: tp.Optional[ThreadPoolExecutor] = None
+    with ExitStack() as stack:
+        if workers > 0:
+            pool = ThreadPoolExecutor(workers)
+            stack.enter_context(pool)
+
+        if progress:
+            print("Finding audio files...")
+        for root, folders, files in os.walk(path, followlinks=True):
+            for file in files:
+                full_path = Path(root) / file
+                if full_path.suffix.lower() in exts:
+                    audio_files.append(full_path)
+                    if pool is not None:
+                        futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
+                    if progress:
+                        print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)
+
+        if progress:
+            print("Getting audio metadata...")
+        meta: tp.List[AudioMeta] = []
+        for idx, file_path in enumerate(audio_files):
+            try:
+                if pool is None:
+                    m = _get_audio_meta(str(file_path), minimal)
+                else:
+                    m = futures[idx].result()
+                if resolve:
+                    m = _resolve_audio_meta(m)
+            except Exception as err:
+                print("Error with", str(file_path), err, file=sys.stderr)
+                continue
+            meta.append(m)
+            if progress:
+                print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
+    meta.sort()
+    return meta
+
+
+
+def load_audio_meta(path: Union[str, pathlib.Path], resolve: bool = True, fast: bool = True) ‑> List[AudioMeta] +
+
+

Load list of AudioMeta from an optionally compressed json file.

+

Args

+
+
path : str or Path
+
Path to JSON file.
+
resolve : bool
+
Whether to resolve the path from AudioMeta (default=True).
+
fast : bool
+
activates some tricks to make things faster.
+
+

Returns

+
+
List[AudioMeta]
+
List of audio file path and its total duration.
+
+
+ +Expand source code + +
def load_audio_meta(path: tp.Union[str, Path],
+                    resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
+    """Load list of AudioMeta from an optionally compressed json file.
+
+    Args:
+        path (str or Path): Path to JSON file.
+        resolve (bool): Whether to resolve the path from AudioMeta (default=True).
+        fast (bool): activates some tricks to make things faster.
+    Returns:
+        List[AudioMeta]: List of audio file path and its total duration.
+    """
+    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
+    with open_fn(path, 'rb') as fp:  # type: ignore
+        lines = fp.readlines()
+    meta = []
+    for line in lines:
+        d = json.loads(line)
+        m = AudioMeta.from_dict(d)
+        if resolve:
+            m = _resolve_audio_meta(m, fast=fast)
+        meta.append(m)
+    return meta
+
+
+
+def main() +
+
+
+
+ +Expand source code + +
def main():
+    logging.basicConfig(stream=sys.stderr, level=logging.INFO)
+    parser = argparse.ArgumentParser(
+        prog='audio_dataset',
+        description='Generate .jsonl files by scanning a folder.')
+    parser.add_argument('root', help='Root folder with all the audio files')
+    parser.add_argument('output_meta_file',
+                        help='Output file to store the metadata, ')
+    parser.add_argument('--complete',
+                        action='store_false', dest='minimal', default=True,
+                        help='Retrieve all metadata, even the one that are expansive '
+                             'to compute (e.g. normalization).')
+    parser.add_argument('--resolve',
+                        action='store_true', default=False,
+                        help='Resolve the paths to be absolute and with no symlinks.')
+    parser.add_argument('--workers',
+                        default=10, type=int,
+                        help='Number of workers.')
+    args = parser.parse_args()
+    meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
+                            resolve=args.resolve, minimal=args.minimal, workers=args.workers)
+    save_audio_meta(args.output_meta_file, meta)
+
+
+
+def save_audio_meta(path: Union[str, pathlib.Path], meta: List[AudioMeta]) +
+
+

Save the audio metadata to the file pointer as json.

+

Args

+
+
path : str or Path
+
Path to JSON file.
+
metadata : list of BaseAudioMeta
+
List of audio meta to save.
+
+
+ +Expand source code + +
def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
+    """Save the audio metadata to the file pointer as json.
+
+    Args:
+        path (str or Path): Path to JSON file.
+        metadata (list of BaseAudioMeta): List of audio meta to save.
+    """
+    Path(path).parent.mkdir(exist_ok=True, parents=True)
+    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
+    with open_fn(path, 'wb') as fp:  # type: ignore
+        for m in meta:
+            json_str = json.dumps(m.to_dict()) + '\n'
+            json_bytes = json_str.encode('utf-8')
+            fp.write(json_bytes)
+
+
+
+
+
+

Classes

+
+
+class AudioDataset +(meta: List[AudioMeta], segment_duration: Optional[float] = None, shuffle: bool = True, num_samples: int = 10000, sample_rate: int = 48000, channels: int = 2, pad: bool = True, sample_on_duration: bool = True, sample_on_weight: bool = True, min_segment_ratio: float = 0.5, max_read_retry: int = 10, return_info: bool = False, min_audio_duration: Optional[float] = None, max_audio_duration: Optional[float] = None) +
+
+

Base audio dataset.

+

The dataset takes a list of AudioMeta and create a dataset composed of segments of audio +and potentially additional information, by creating random segments from the list of audio +files referenced in the metadata and applying minimal data pre-processing such as resampling, +mixing of channels, padding, etc.

+

If no segment_duration value is provided, the AudioDataset will return the full wav for each +audio file. Otherwise, it will randomly sample audio files and create a segment of the specified +duration, applying padding if required.

+

By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True +allows to return a tuple containing the torch Tensor and additional metadata on the segment and the +original audio meta.

+

Args

+
+
meta : tp.List[AudioMeta]
+
List of audio files metadata.
+
segment_duration : float
+
Optional segment duration of audio to load. +If not specified, the dataset will load the full audio segment from the file.
+
shuffle : bool
+
Set to True to have the data reshuffled at every epoch.
+
sample_rate : int
+
Target sample rate of the loaded audio samples.
+
channels : int
+
Target number of channels of the loaded audio samples.
+
sample_on_duration : bool
+
Set to True to sample segments with probability +dependent on audio file duration. This is only used if segment_duration is provided.
+
sample_on_weight : bool
+
Set to True to sample segments using the weight entry of +AudioMeta. If sample_on_duration is also True, the actual weight will be the product +of the file duration and file weight. This is only used if segment_duration is provided.
+
min_segment_ratio : float
+
Minimum segment ratio to use when the audio file +is shorter than the desired segment.
+
max_read_retry : int
+
Maximum number of retries to sample an audio segment from the dataset.
+
return_info : bool
+
Whether to return the wav only or return wav along with segment info and metadata.
+
min_audio_duration : tp.Optional[float], optional
+
Minimum audio file duration, in seconds, if provided +audio shorter than this will be filtered out.
+
max_audio_duration : tp.Optional[float], optional
+
Maximal audio file duration in seconds, if provided +audio longer than this will be filtered out.
+
+
+ +Expand source code + +
class AudioDataset:
+    """Base audio dataset.
+
+    The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
+    and potentially additional information, by creating random segments from the list of audio
+    files referenced in the metadata and applying minimal data pre-processing such as resampling,
+    mixing of channels, padding, etc.
+
+    If no segment_duration value is provided, the AudioDataset will return the full wav for each
+    audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
+    duration, applying padding if required.
+
+    By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
+    allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
+    original audio meta.
+
+    Args:
+        meta (tp.List[AudioMeta]): List of audio files metadata.
+        segment_duration (float): Optional segment duration of audio to load.
+            If not specified, the dataset will load the full audio segment from the file.
+        shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
+        sample_rate (int): Target sample rate of the loaded audio samples.
+        channels (int): Target number of channels of the loaded audio samples.
+        sample_on_duration (bool): Set to `True` to sample segments with probability
+            dependent on audio file duration. This is only used if `segment_duration` is provided.
+        sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
+            `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
+            of the file duration and file weight. This is only used if `segment_duration` is provided.
+        min_segment_ratio (float): Minimum segment ratio to use when the audio file
+            is shorter than the desired segment.
+        max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
+        return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
+        min_audio_duration (tp.Optional[float], optional): Minimum audio file duration, in seconds, if provided
+            audio shorter than this will be filtered out.
+        max_audio_duration (tp.Optional[float], optional): Maximal audio file duration in seconds, if provided
+            audio longer than this will be filtered out.
+    """
+    def __init__(self,
+                 meta: tp.List[AudioMeta],
+                 segment_duration: tp.Optional[float] = None,
+                 shuffle: bool = True,
+                 num_samples: int = 10_000,
+                 sample_rate: int = 48_000,
+                 channels: int = 2,
+                 pad: bool = True,
+                 sample_on_duration: bool = True,
+                 sample_on_weight: bool = True,
+                 min_segment_ratio: float = 0.5,
+                 max_read_retry: int = 10,
+                 return_info: bool = False,
+                 min_audio_duration: tp.Optional[float] = None,
+                 max_audio_duration: tp.Optional[float] = None
+                 ):
+        assert len(meta) > 0, 'No audio meta provided to AudioDataset. Please check loading of audio meta.'
+        assert segment_duration is None or segment_duration > 0
+        assert segment_duration is None or min_segment_ratio >= 0
+        logging.debug(f'sample_on_duration: {sample_on_duration}')
+        logging.debug(f'sample_on_weight: {sample_on_weight}')
+        logging.debug(f'pad: {pad}')
+        logging.debug(f'min_segment_ratio: {min_segment_ratio}')
+
+        self.segment_duration = segment_duration
+        self.min_segment_ratio = min_segment_ratio
+        self.max_audio_duration = max_audio_duration
+        self.min_audio_duration = min_audio_duration
+        if self.min_audio_duration is not None and self.max_audio_duration is not None:
+            assert self.min_audio_duration <= self.max_audio_duration
+        self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
+        assert len(self.meta)  # Fail fast if all data has been filtered.
+        self.total_duration = sum(d.duration for d in self.meta)
+
+        if segment_duration is None:
+            num_samples = len(self.meta)
+        self.num_samples = num_samples
+        self.shuffle = shuffle
+        self.sample_rate = sample_rate
+        self.channels = channels
+        self.pad = pad
+        self.sample_on_weight = sample_on_weight
+        self.sample_on_duration = sample_on_duration
+        self.sampling_probabilities = self._get_sampling_probabilities()
+        self.max_read_retry = max_read_retry
+        self.return_info = return_info
+
+    def __len__(self):
+        return self.num_samples
+
+    def _get_sampling_probabilities(self, normalized: bool = True):
+        """Return the sampling probabilities for each file inside `self.meta`.
+        """
+        scores: tp.List[float] = []
+        for file_meta in self.meta:
+            score = 1.
+            if self.sample_on_weight and file_meta.weight is not None:
+                score *= file_meta.weight
+            if self.sample_on_duration:
+                score *= file_meta.duration
+            scores.append(score)
+        probabilities = torch.tensor(scores)
+        if normalized:
+            probabilities /= probabilities.sum()
+        return probabilities
+
+    def sample_file(self, rng: torch.Generator) -> AudioMeta:
+        """Sample a given file from `self.meta`. Can be overriden in subclasses.
+        This is only called if `segment_duration` is not None.
+
+        You must use the provided random number generator `rng` for reproducibility.
+        """
+        if not self.sample_on_weight and not self.sample_on_duration:
+            file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
+        else:
+            file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
+
+        return self.meta[file_index]
+
+    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
+        if self.segment_duration is None:
+            file_meta = self.meta[index]
+            out, sr = audio_read(file_meta.path)
+            out = convert_audio(out, sr, self.sample_rate, self.channels)
+            n_frames = out.shape[-1]
+            segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
+                                       sample_rate=self.sample_rate)
+        else:
+            rng = torch.Generator()
+            if self.shuffle:
+                # We use index, plus extra randomness
+                rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
+            else:
+                # We only use index
+                rng.manual_seed(index)
+
+            for retry in range(self.max_read_retry):
+                file_meta = self.sample_file(rng)
+                # We add some variance in the file position even if audio file is smaller than segment
+                # without ending up with empty segments
+                max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
+                seek_time = torch.rand(1, generator=rng).item() * max_seek
+                try:
+                    out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
+                    out = convert_audio(out, sr, self.sample_rate, self.channels)
+                    n_frames = out.shape[-1]
+                    target_frames = int(self.segment_duration * self.sample_rate)
+                    if self.pad:
+                        out = F.pad(out, (0, target_frames - n_frames))
+                    segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
+                                               sample_rate=self.sample_rate)
+                except Exception as exc:
+                    logger.warning("Error opening file %s: %r", file_meta.path, exc)
+                    if retry == self.max_read_retry - 1:
+                        raise
+                else:
+                    break
+
+        if self.return_info:
+            # Returns the wav and additional information on the wave segment
+            return out, segment_info
+        else:
+            return out
+
+    def collater(self, samples):
+        """The collater function has to be provided to the dataloader
+        if AudioDataset has return_info=True in order to properly collate
+        the samples of a batch.
+        """
+        if self.segment_duration is None and len(samples) > 1:
+            assert self.pad, "Must allow padding when batching examples of different durations."
+
+        # In this case the audio reaching the collater is of variable length as segment_duration=None.
+        to_pad = self.segment_duration is None and self.pad
+        if to_pad:
+            max_len = max([wav.shape[-1] for wav, _ in samples])
+
+            def _pad_wav(wav):
+                return F.pad(wav, (0, max_len - wav.shape[-1]))
+
+        if self.return_info:
+            if len(samples) > 0:
+                assert len(samples[0]) == 2
+                assert isinstance(samples[0][0], torch.Tensor)
+                assert isinstance(samples[0][1], SegmentInfo)
+
+            wavs = [wav for wav, _ in samples]
+            segment_infos = [copy.deepcopy(info) for _, info in samples]
+
+            if to_pad:
+                # Each wav could be of a different duration as they are not segmented.
+                for i in range(len(samples)):
+                    # Determines the total legth of the signal with padding, so we update here as we pad.
+                    segment_infos[i].total_frames = max_len
+                    wavs[i] = _pad_wav(wavs[i])
+
+            wav = torch.stack(wavs)
+            return wav, segment_infos
+        else:
+            assert isinstance(samples[0], torch.Tensor)
+            if to_pad:
+                samples = [_pad_wav(s) for s in samples]
+            return torch.stack(samples)
+
+    def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
+        """Filters out audio files with short durations.
+        Removes from meta files that have durations that will not allow to samples examples from them.
+        """
+        orig_len = len(meta)
+
+        # Filter data that is too short.
+        if self.min_audio_duration is not None:
+            meta = [m for m in meta if m.duration >= self.min_audio_duration]
+
+        # Filter data that is too long.
+        if self.max_audio_duration is not None:
+            meta = [m for m in meta if m.duration <= self.max_audio_duration]
+
+        filtered_len = len(meta)
+        removed_percentage = 100*(1-float(filtered_len)/orig_len)
+        msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
+        if removed_percentage < 10:
+            logging.debug(msg)
+        else:
+            logging.warning(msg)
+        return meta
+
+    @classmethod
+    def from_meta(cls, root: tp.Union[str, Path], **kwargs):
+        """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
+
+        Args:
+            root (str or Path): Path to root folder containing audio files.
+            kwargs: Additional keyword arguments for the AudioDataset.
+        """
+        root = Path(root)
+        if root.is_dir():
+            if (root / 'data.jsonl').exists():
+                root = root / 'data.jsonl'
+            elif (root / 'data.jsonl.gz').exists():
+                root = root / 'data.jsonl.gz'
+            else:
+                raise ValueError("Don't know where to read metadata from in the dir. "
+                                 "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
+        meta = load_audio_meta(root)
+        return cls(meta, **kwargs)
+
+    @classmethod
+    def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
+                  exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
+        """Instantiate AudioDataset from a path containing (possibly nested) audio files.
+
+        Args:
+            root (str or Path): Path to root folder containing audio files.
+            minimal_meta (bool): Whether to only load minimal metadata or not.
+            exts (list of str): Extensions for audio files.
+            kwargs: Additional keyword arguments for the AudioDataset.
+        """
+        root = Path(root)
+        if root.is_file():
+            meta = load_audio_meta(root, resolve=True)
+        else:
+            meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
+        return cls(meta, **kwargs)
+
+

Static methods

+
+
+def from_meta(root: Union[str, pathlib.Path], **kwargs) +
+
+

Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.

+

Args

+
+
root : str or Path
+
Path to root folder containing audio files.
+
kwargs
+
Additional keyword arguments for the AudioDataset.
+
+
+ +Expand source code + +
@classmethod
+def from_meta(cls, root: tp.Union[str, Path], **kwargs):
+    """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
+
+    Args:
+        root (str or Path): Path to root folder containing audio files.
+        kwargs: Additional keyword arguments for the AudioDataset.
+    """
+    root = Path(root)
+    if root.is_dir():
+        if (root / 'data.jsonl').exists():
+            root = root / 'data.jsonl'
+        elif (root / 'data.jsonl.gz').exists():
+            root = root / 'data.jsonl.gz'
+        else:
+            raise ValueError("Don't know where to read metadata from in the dir. "
+                             "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
+    meta = load_audio_meta(root)
+    return cls(meta, **kwargs)
+
+
+
+def from_path(root: Union[str, pathlib.Path], minimal_meta: bool = True, exts: List[str] = ['.wav', '.mp3', '.flac', '.ogg', '.m4a'], **kwargs) +
+
+

Instantiate AudioDataset from a path containing (possibly nested) audio files.

+

Args

+
+
root : str or Path
+
Path to root folder containing audio files.
+
minimal_meta : bool
+
Whether to only load minimal metadata or not.
+
exts : list of str
+
Extensions for audio files.
+
kwargs
+
Additional keyword arguments for the AudioDataset.
+
+
+ +Expand source code + +
@classmethod
+def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
+              exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
+    """Instantiate AudioDataset from a path containing (possibly nested) audio files.
+
+    Args:
+        root (str or Path): Path to root folder containing audio files.
+        minimal_meta (bool): Whether to only load minimal metadata or not.
+        exts (list of str): Extensions for audio files.
+        kwargs: Additional keyword arguments for the AudioDataset.
+    """
+    root = Path(root)
+    if root.is_file():
+        meta = load_audio_meta(root, resolve=True)
+    else:
+        meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
+    return cls(meta, **kwargs)
+
+
+
+

Methods

+
+
+def collater(self, samples) +
+
+

The collater function has to be provided to the dataloader +if AudioDataset has return_info=True in order to properly collate +the samples of a batch.

+
+ +Expand source code + +
def collater(self, samples):
+    """The collater function has to be provided to the dataloader
+    if AudioDataset has return_info=True in order to properly collate
+    the samples of a batch.
+    """
+    if self.segment_duration is None and len(samples) > 1:
+        assert self.pad, "Must allow padding when batching examples of different durations."
+
+    # In this case the audio reaching the collater is of variable length as segment_duration=None.
+    to_pad = self.segment_duration is None and self.pad
+    if to_pad:
+        max_len = max([wav.shape[-1] for wav, _ in samples])
+
+        def _pad_wav(wav):
+            return F.pad(wav, (0, max_len - wav.shape[-1]))
+
+    if self.return_info:
+        if len(samples) > 0:
+            assert len(samples[0]) == 2
+            assert isinstance(samples[0][0], torch.Tensor)
+            assert isinstance(samples[0][1], SegmentInfo)
+
+        wavs = [wav for wav, _ in samples]
+        segment_infos = [copy.deepcopy(info) for _, info in samples]
+
+        if to_pad:
+            # Each wav could be of a different duration as they are not segmented.
+            for i in range(len(samples)):
+                # Determines the total legth of the signal with padding, so we update here as we pad.
+                segment_infos[i].total_frames = max_len
+                wavs[i] = _pad_wav(wavs[i])
+
+        wav = torch.stack(wavs)
+        return wav, segment_infos
+    else:
+        assert isinstance(samples[0], torch.Tensor)
+        if to_pad:
+            samples = [_pad_wav(s) for s in samples]
+        return torch.stack(samples)
+
+
+
+def sample_file(self, rng: torch._C.Generator) ‑> AudioMeta +
+
+

Sample a given file from self.meta. Can be overriden in subclasses. +This is only called if segment_duration is not None.

+

You must use the provided random number generator rng for reproducibility.

+
+ +Expand source code + +
def sample_file(self, rng: torch.Generator) -> AudioMeta:
+    """Sample a given file from `self.meta`. Can be overriden in subclasses.
+    This is only called if `segment_duration` is not None.
+
+    You must use the provided random number generator `rng` for reproducibility.
+    """
+    if not self.sample_on_weight and not self.sample_on_duration:
+        file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
+    else:
+        file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
+
+    return self.meta[file_index]
+
+
+
+
+
+class AudioMeta +(path: str, duration: float, sample_rate: int, amplitude: Optional[float] = None, weight: Optional[float] = None, info_path: Optional[PathInZip] = None) +
+
+

AudioMeta(path: str, duration: float, sample_rate: int, amplitude: Union[float, NoneType] = None, weight: Union[float, NoneType] = None, info_path: Union[audiocraft.data.zip.PathInZip, NoneType] = None)

+
+ +Expand source code + +
class AudioMeta(BaseInfo):
+    path: str
+    duration: float
+    sample_rate: int
+    amplitude: tp.Optional[float] = None
+    weight: tp.Optional[float] = None
+    # info_path is used to load additional information about the audio file that is stored in zip files.
+    info_path: tp.Optional[PathInZip] = None
+
+    @classmethod
+    def from_dict(cls, dictionary: dict):
+        base = cls._dict2fields(dictionary)
+        if 'info_path' in base and base['info_path'] is not None:
+            base['info_path'] = PathInZip(base['info_path'])
+        return cls(**base)
+
+    def to_dict(self):
+        d = super().to_dict()
+        if d['info_path'] is not None:
+            d['info_path'] = str(d['info_path'])
+        return d
+
+

Ancestors

+ +

Class variables

+
+
var amplitude : Optional[float]
+
+
+
+
var duration : float
+
+
+
+
var info_path : Optional[PathInZip]
+
+
+
+
var path : str
+
+
+
+
var sample_rate : int
+
+
+
+
var weight : Optional[float]
+
+
+
+
+

Static methods

+
+
+def from_dict(dictionary: dict) +
+
+
+
+ +Expand source code + +
@classmethod
+def from_dict(cls, dictionary: dict):
+    base = cls._dict2fields(dictionary)
+    if 'info_path' in base and base['info_path'] is not None:
+        base['info_path'] = PathInZip(base['info_path'])
+    return cls(**base)
+
+
+
+

Methods

+
+
+def to_dict(self) +
+
+
+
+ +Expand source code + +
def to_dict(self):
+    d = super().to_dict()
+    if d['info_path'] is not None:
+        d['info_path'] = str(d['info_path'])
+    return d
+
+
+
+
+
+class BaseInfo +
+
+

BaseInfo()

+
+ +Expand source code + +
class BaseInfo:
+
+    @classmethod
+    def _dict2fields(cls, dictionary: dict):
+        return {
+            field.name: dictionary[field.name]
+            for field in fields(cls) if field.name in dictionary
+        }
+
+    @classmethod
+    def from_dict(cls, dictionary: dict):
+        _dictionary = cls._dict2fields(dictionary)
+        return cls(**_dictionary)
+
+    def to_dict(self):
+        return {
+            field.name: self.__getattribute__(field.name)
+            for field in fields(self)
+            }
+
+

Subclasses

+ +

Static methods

+
+
+def from_dict(dictionary: dict) +
+
+
+
+ +Expand source code + +
@classmethod
+def from_dict(cls, dictionary: dict):
+    _dictionary = cls._dict2fields(dictionary)
+    return cls(**_dictionary)
+
+
+
+

Methods

+
+
+def to_dict(self) +
+
+
+
+ +Expand source code + +
def to_dict(self):
+    return {
+        field.name: self.__getattribute__(field.name)
+        for field in fields(self)
+        }
+
+
+
+
+
+class SegmentInfo +(meta: AudioMeta, seek_time: float, n_frames: int, total_frames: int, sample_rate: int) +
+
+

SegmentInfo(meta: audiocraft.data.audio_dataset.AudioMeta, seek_time: float, n_frames: int, total_frames: int, sample_rate: int)

+
+ +Expand source code + +
class SegmentInfo(BaseInfo):
+    meta: AudioMeta
+    seek_time: float
+    n_frames: int  # actual number of frames without padding
+    total_frames: int  # total number of frames, padding included
+    sample_rate: int  # actual sample rate
+
+

Ancestors

+ +

Subclasses

+ +

Class variables

+
+
var metaAudioMeta
+
+
+
+
var n_frames : int
+
+
+
+
var sample_rate : int
+
+
+
+
var seek_time : float
+
+
+
+
var total_frames : int
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/data/audio_utils.html b/docs/audiocraft/data/audio_utils.html new file mode 100644 index 00000000..20744e25 --- /dev/null +++ b/docs/audiocraft/data/audio_utils.html @@ -0,0 +1,519 @@ + + + + + + +audiocraft.data.audio_utils API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.data.audio_utils

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+import typing as tp
+
+import julius
+import torch
+import torchaudio
+
+
+def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
+    """Convert audio to the given number of channels.
+
+    Args:
+        wav (torch.Tensor): Audio wave of shape [B, C, T].
+        channels (int): Expected number of channels as output.
+    Returns:
+        torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
+    """
+    *shape, src_channels, length = wav.shape
+    if src_channels == channels:
+        pass
+    elif channels == 1:
+        # Case 1:
+        # The caller asked 1-channel audio, and the stream has multiple
+        # channels, downmix all channels.
+        wav = wav.mean(dim=-2, keepdim=True)
+    elif src_channels == 1:
+        # Case 2:
+        # The caller asked for multiple channels, but the input file has
+        # a single channel, replicate the audio over all channels.
+        wav = wav.expand(*shape, channels, length)
+    elif src_channels >= channels:
+        # Case 3:
+        # The caller asked for multiple channels, and the input file has
+        # more channels than requested. In that case return the first channels.
+        wav = wav[..., :channels, :]
+    else:
+        # Case 4: What is a reasonable choice here?
+        raise ValueError('The audio file has less channels than requested but is not mono.')
+    return wav
+
+
+def convert_audio(wav: torch.Tensor, from_rate: float,
+                  to_rate: float, to_channels: int) -> torch.Tensor:
+    """Convert audio to new sample rate and number of audio channels.
+    """
+    wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
+    wav = convert_audio_channels(wav, to_channels)
+    return wav
+
+
+def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
+                       loudness_compressor: bool = False, energy_floor: float = 2e-3):
+    """Normalize an input signal to a user loudness in dB LKFS.
+    Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
+
+    Args:
+        wav (torch.Tensor): Input multichannel audio data.
+        sample_rate (int): Sample rate.
+        loudness_headroom_db (float): Target loudness of the output in dB LUFS.
+        loudness_compressor (bool): Uses tanh for soft clipping.
+        energy_floor (float): anything below that RMS level will not be rescaled.
+    Returns:
+        output (torch.Tensor): Loudness normalized output data.
+    """
+    energy = wav.pow(2).mean().sqrt().item()
+    if energy < energy_floor:
+        return wav
+    transform = torchaudio.transforms.Loudness(sample_rate)
+    input_loudness_db = transform(wav).item()
+    # calculate the gain needed to scale to the desired loudness level
+    delta_loudness = -loudness_headroom_db - input_loudness_db
+    gain = 10.0 ** (delta_loudness / 20.0)
+    output = gain * wav
+    if loudness_compressor:
+        output = torch.tanh(output)
+    assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
+    return output
+
+
+def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
+    """Utility function to clip the audio with logging if specified."""
+    max_scale = wav.abs().max()
+    if log_clipping and max_scale > 1:
+        clamp_prob = (wav.abs() > 1).float().mean().item()
+        print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
+              clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
+    wav.clamp_(-1, 1)
+
+
+def normalize_audio(wav: torch.Tensor, normalize: bool = True,
+                    strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+                    rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+                    loudness_compressor: bool = False, log_clipping: bool = False,
+                    sample_rate: tp.Optional[int] = None,
+                    stem_name: tp.Optional[str] = None) -> torch.Tensor:
+    """Normalize the audio according to the prescribed strategy (see after).
+
+    Args:
+        wav (torch.Tensor): Audio data.
+        normalize (bool): if `True` (default), normalizes according to the prescribed
+            strategy (see after). If `False`, the strategy is only used in case clipping
+            would happen.
+        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
+            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
+            with extra headroom to avoid clipping. 'clip' just clips.
+        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
+        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
+            than the `peak_clip` one to avoid further clipping.
+        loudness_headroom_db (float): Target loudness for loudness normalization.
+        loudness_compressor (bool): If True, uses tanh based soft clipping.
+        log_clipping (bool): If True, basic logging on stderr when clipping still
+            occurs despite strategy (only for 'rms').
+        sample_rate (int): Sample rate for the audio data (required for loudness).
+        stem_name (Optional[str]): Stem name for clipping logging.
+    Returns:
+        torch.Tensor: Normalized audio.
+    """
+    scale_peak = 10 ** (-peak_clip_headroom_db / 20)
+    scale_rms = 10 ** (-rms_headroom_db / 20)
+    if strategy == 'peak':
+        rescaling = (scale_peak / wav.abs().max())
+        if normalize or rescaling < 1:
+            wav = wav * rescaling
+    elif strategy == 'clip':
+        wav = wav.clamp(-scale_peak, scale_peak)
+    elif strategy == 'rms':
+        mono = wav.mean(dim=0)
+        rescaling = scale_rms / mono.pow(2).mean().sqrt()
+        if normalize or rescaling < 1:
+            wav = wav * rescaling
+        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
+    elif strategy == 'loudness':
+        assert sample_rate is not None, "Loudness normalization requires sample rate."
+        wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
+        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
+    else:
+        assert wav.abs().max() < 1
+        assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
+    return wav
+
+
+def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
+    """Convert audio to float 32 bits PCM format.
+    """
+    if wav.dtype.is_floating_point:
+        return wav
+    else:
+        assert wav.dtype == torch.int16
+        return wav.float() / 2**15
+
+
+def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
+    """Convert audio to int 16 bits PCM format.
+
+    ..Warning:: There exist many formula for doing this convertion. None are perfect
+    due to the asymetry of the int16 range. One either have possible clipping, DC offset,
+    or inconsistancies with f32_pcm. If the given wav doesn't have enough headroom,
+    it is possible that `i16_pcm(f32_pcm)) != Identity`.
+    """
+    if wav.dtype.is_floating_point:
+        assert wav.abs().max() <= 1
+        candidate = (wav * 2 ** 15).round()
+        if candidate.max() >= 2 ** 15:  # clipping would occur
+            candidate = (wav * (2 ** 15 - 1)).round()
+        return candidate.short()
+    else:
+        assert wav.dtype == torch.int16
+        return wav
+
+
+
+
+
+
+
+

Functions

+
+
+def convert_audio(wav: torch.Tensor, from_rate: float, to_rate: float, to_channels: int) ‑> torch.Tensor +
+
+

Convert audio to new sample rate and number of audio channels.

+
+ +Expand source code + +
def convert_audio(wav: torch.Tensor, from_rate: float,
+                  to_rate: float, to_channels: int) -> torch.Tensor:
+    """Convert audio to new sample rate and number of audio channels.
+    """
+    wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
+    wav = convert_audio_channels(wav, to_channels)
+    return wav
+
+
+
+def convert_audio_channels(wav: torch.Tensor, channels: int = 2) ‑> torch.Tensor +
+
+

Convert audio to the given number of channels.

+

Args

+
+
wav : torch.Tensor
+
Audio wave of shape [B, C, T].
+
channels : int
+
Expected number of channels as output.
+
+

Returns

+
+
torch.Tensor
+
Downmixed or unchanged audio wave [B, C, T].
+
+
+ +Expand source code + +
def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
+    """Convert audio to the given number of channels.
+
+    Args:
+        wav (torch.Tensor): Audio wave of shape [B, C, T].
+        channels (int): Expected number of channels as output.
+    Returns:
+        torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
+    """
+    *shape, src_channels, length = wav.shape
+    if src_channels == channels:
+        pass
+    elif channels == 1:
+        # Case 1:
+        # The caller asked 1-channel audio, and the stream has multiple
+        # channels, downmix all channels.
+        wav = wav.mean(dim=-2, keepdim=True)
+    elif src_channels == 1:
+        # Case 2:
+        # The caller asked for multiple channels, but the input file has
+        # a single channel, replicate the audio over all channels.
+        wav = wav.expand(*shape, channels, length)
+    elif src_channels >= channels:
+        # Case 3:
+        # The caller asked for multiple channels, and the input file has
+        # more channels than requested. In that case return the first channels.
+        wav = wav[..., :channels, :]
+    else:
+        # Case 4: What is a reasonable choice here?
+        raise ValueError('The audio file has less channels than requested but is not mono.')
+    return wav
+
+
+
+def f32_pcm(wav: torch.Tensor) ‑> torch.Tensor +
+
+

Convert audio to float 32 bits PCM format.

+
+ +Expand source code + +
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
+    """Convert audio to float 32 bits PCM format.
+    """
+    if wav.dtype.is_floating_point:
+        return wav
+    else:
+        assert wav.dtype == torch.int16
+        return wav.float() / 2**15
+
+
+
+def i16_pcm(wav: torch.Tensor) ‑> torch.Tensor +
+
+

Convert audio to int 16 bits PCM format.

+
+

Warning: There exist many formula for doing this convertion. None are perfect

+
+

due to the asymetry of the int16 range. One either have possible clipping, DC offset, +or inconsistancies with f32_pcm. If the given wav doesn't have enough headroom, +it is possible that i16_pcm(f32_pcm)) != Identity.

+
+ +Expand source code + +
def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
+    """Convert audio to int 16 bits PCM format.
+
+    ..Warning:: There exist many formula for doing this convertion. None are perfect
+    due to the asymetry of the int16 range. One either have possible clipping, DC offset,
+    or inconsistancies with f32_pcm. If the given wav doesn't have enough headroom,
+    it is possible that `i16_pcm(f32_pcm)) != Identity`.
+    """
+    if wav.dtype.is_floating_point:
+        assert wav.abs().max() <= 1
+        candidate = (wav * 2 ** 15).round()
+        if candidate.max() >= 2 ** 15:  # clipping would occur
+            candidate = (wav * (2 ** 15 - 1)).round()
+        return candidate.short()
+    else:
+        assert wav.dtype == torch.int16
+        return wav
+
+
+
+def normalize_audio(wav: torch.Tensor, normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1, rms_headroom_db: float = 18, loudness_headroom_db: float = 14, loudness_compressor: bool = False, log_clipping: bool = False, sample_rate: Optional[int] = None, stem_name: Optional[str] = None) ‑> torch.Tensor +
+
+

Normalize the audio according to the prescribed strategy (see after).

+

Args

+
+
wav : torch.Tensor
+
Audio data.
+
normalize : bool
+
if True (default), normalizes according to the prescribed +strategy (see after). If False, the strategy is only used in case clipping +would happen.
+
strategy : str
+
Can be either 'clip', 'peak', or 'rms'. Default is 'peak', +i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square +with extra headroom to avoid clipping. 'clip' just clips.
+
peak_clip_headroom_db : float
+
Headroom in dB when doing 'peak' or 'clip' strategy.
+
rms_headroom_db : float
+
Headroom in dB when doing 'rms' strategy. This must be much larger +than the peak_clip one to avoid further clipping.
+
loudness_headroom_db : float
+
Target loudness for loudness normalization.
+
loudness_compressor : bool
+
If True, uses tanh based soft clipping.
+
log_clipping : bool
+
If True, basic logging on stderr when clipping still +occurs despite strategy (only for 'rms').
+
sample_rate : int
+
Sample rate for the audio data (required for loudness).
+
stem_name : Optional[str]
+
Stem name for clipping logging.
+
+

Returns

+
+
torch.Tensor
+
Normalized audio.
+
+
+ +Expand source code + +
def normalize_audio(wav: torch.Tensor, normalize: bool = True,
+                    strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+                    rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+                    loudness_compressor: bool = False, log_clipping: bool = False,
+                    sample_rate: tp.Optional[int] = None,
+                    stem_name: tp.Optional[str] = None) -> torch.Tensor:
+    """Normalize the audio according to the prescribed strategy (see after).
+
+    Args:
+        wav (torch.Tensor): Audio data.
+        normalize (bool): if `True` (default), normalizes according to the prescribed
+            strategy (see after). If `False`, the strategy is only used in case clipping
+            would happen.
+        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
+            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
+            with extra headroom to avoid clipping. 'clip' just clips.
+        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
+        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
+            than the `peak_clip` one to avoid further clipping.
+        loudness_headroom_db (float): Target loudness for loudness normalization.
+        loudness_compressor (bool): If True, uses tanh based soft clipping.
+        log_clipping (bool): If True, basic logging on stderr when clipping still
+            occurs despite strategy (only for 'rms').
+        sample_rate (int): Sample rate for the audio data (required for loudness).
+        stem_name (Optional[str]): Stem name for clipping logging.
+    Returns:
+        torch.Tensor: Normalized audio.
+    """
+    scale_peak = 10 ** (-peak_clip_headroom_db / 20)
+    scale_rms = 10 ** (-rms_headroom_db / 20)
+    if strategy == 'peak':
+        rescaling = (scale_peak / wav.abs().max())
+        if normalize or rescaling < 1:
+            wav = wav * rescaling
+    elif strategy == 'clip':
+        wav = wav.clamp(-scale_peak, scale_peak)
+    elif strategy == 'rms':
+        mono = wav.mean(dim=0)
+        rescaling = scale_rms / mono.pow(2).mean().sqrt()
+        if normalize or rescaling < 1:
+            wav = wav * rescaling
+        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
+    elif strategy == 'loudness':
+        assert sample_rate is not None, "Loudness normalization requires sample rate."
+        wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
+        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
+    else:
+        assert wav.abs().max() < 1
+        assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
+    return wav
+
+
+
+def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14, loudness_compressor: bool = False, energy_floor: float = 0.002) +
+
+

Normalize an input signal to a user loudness in dB LKFS. +Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.

+

Args

+
+
wav : torch.Tensor
+
Input multichannel audio data.
+
sample_rate : int
+
Sample rate.
+
loudness_headroom_db : float
+
Target loudness of the output in dB LUFS.
+
loudness_compressor : bool
+
Uses tanh for soft clipping.
+
energy_floor : float
+
anything below that RMS level will not be rescaled.
+
+

Returns

+

output (torch.Tensor): Loudness normalized output data.

+
+ +Expand source code + +
def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
+                       loudness_compressor: bool = False, energy_floor: float = 2e-3):
+    """Normalize an input signal to a user loudness in dB LKFS.
+    Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
+
+    Args:
+        wav (torch.Tensor): Input multichannel audio data.
+        sample_rate (int): Sample rate.
+        loudness_headroom_db (float): Target loudness of the output in dB LUFS.
+        loudness_compressor (bool): Uses tanh for soft clipping.
+        energy_floor (float): anything below that RMS level will not be rescaled.
+    Returns:
+        output (torch.Tensor): Loudness normalized output data.
+    """
+    energy = wav.pow(2).mean().sqrt().item()
+    if energy < energy_floor:
+        return wav
+    transform = torchaudio.transforms.Loudness(sample_rate)
+    input_loudness_db = transform(wav).item()
+    # calculate the gain needed to scale to the desired loudness level
+    delta_loudness = -loudness_headroom_db - input_loudness_db
+    gain = 10.0 ** (delta_loudness / 20.0)
+    output = gain * wav
+    if loudness_compressor:
+        output = torch.tanh(output)
+    assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
+    return output
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/data/index.html b/docs/audiocraft/data/index.html new file mode 100644 index 00000000..84b4c7b4 --- /dev/null +++ b/docs/audiocraft/data/index.html @@ -0,0 +1,94 @@ + + + + + + +audiocraft.data API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.data

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from . import audio, audio_dataset
+
+
+
+

Sub-modules

+
+
audiocraft.data.audio
+
+

Audio IO methods are defined in this module (info, read, write), +We rely on av library for faster read when possible, otherwise on torchaudio.

+
+
audiocraft.data.audio_dataset
+
+
+
+
audiocraft.data.audio_utils
+
+
+
+
audiocraft.data.zip
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/data/zip.html b/docs/audiocraft/data/zip.html new file mode 100644 index 00000000..d8bcfcef --- /dev/null +++ b/docs/audiocraft/data/zip.html @@ -0,0 +1,289 @@ + + + + + + +audiocraft.data.zip API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.data.zip

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing
+import zipfile
+
+from dataclasses import dataclass
+from functools import lru_cache
+from typing_extensions import Literal
+
+
+DEFAULT_SIZE = 32
+MODE = Literal['r', 'w', 'x', 'a']
+
+
+@dataclass(order=True)
+class PathInZip:
+    """Class for holding a path of file within a zip file.
+
+    Args:
+        path: The convention is <path_to_zip>:<relative_path_inside_zip>
+            Let's assume there is a zip file /some/location/foo.zip
+            and inside of it is a json file located at /data/file1.json,
+            Then we expect path = "/some/location/foo.zip:/data/file1.json"
+    """
+
+    INFO_PATH_SEP = ':'
+    zip_path: str
+    file_path: str
+
+    def __init__(self, path: str) -> None:
+        split_path = path.split(self.INFO_PATH_SEP)
+        assert len(split_path) == 2
+        self.zip_path, self.file_path = split_path
+
+    @classmethod
+    def from_paths(cls, zip_path: str, file_path: str):
+        return cls(zip_path + cls.INFO_PATH_SEP + file_path)
+
+    def __str__(self) -> str:
+        return self.zip_path + self.INFO_PATH_SEP + self.file_path
+
+
+def _open_zip(path: str, mode: MODE = 'r'):
+    return zipfile.ZipFile(path, mode)
+
+
+_cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
+
+
+def set_zip_cache_size(max_size: int):
+    """Sets the maximal LRU caching for zip file opening.
+
+    Args:
+        max_size: the maximal LRU cache.
+    """
+    global _cached_open_zip
+    _cached_open_zip = lru_cache(max_size)(_open_zip)
+
+
+def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
+    """Opens a file stored inside a zip and returns a file-like object.
+
+    Args:
+        path_in_zip: A PathInZip object representing the file to return a file-like object of.
+        mode: The mode in which to open the file with.
+    Returns:
+        A file-like object for PathInZip.
+    """
+    zf = _cached_open_zip(path_in_zip.zip_path)
+    return zf.open(path_in_zip.file_path)
+
+
+
+
+
+
+
+

Functions

+
+
+def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') ‑>  +
+
+

Opens a file stored inside a zip and returns a file-like object.

+

Args

+
+
path_in_zip
+
A PathInZip object representing the file to return a file-like object of.
+
mode
+
The mode in which to open the file with.
+
+

Returns

+

A file-like object for PathInZip.

+
+ +Expand source code + +
def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
+    """Opens a file stored inside a zip and returns a file-like object.
+
+    Args:
+        path_in_zip: A PathInZip object representing the file to return a file-like object of.
+        mode: The mode in which to open the file with.
+    Returns:
+        A file-like object for PathInZip.
+    """
+    zf = _cached_open_zip(path_in_zip.zip_path)
+    return zf.open(path_in_zip.file_path)
+
+
+
+def set_zip_cache_size(max_size: int) +
+
+

Sets the maximal LRU caching for zip file opening.

+

Args

+
+
max_size
+
the maximal LRU cache.
+
+
+ +Expand source code + +
def set_zip_cache_size(max_size: int):
+    """Sets the maximal LRU caching for zip file opening.
+
+    Args:
+        max_size: the maximal LRU cache.
+    """
+    global _cached_open_zip
+    _cached_open_zip = lru_cache(max_size)(_open_zip)
+
+
+
+
+
+

Classes

+
+
+class PathInZip +(path: str) +
+
+

Class for holding a path of file within a zip file.

+

Args

+
+
path
+
The convention is : +Let's assume there is a zip file /some/location/foo.zip +and inside of it is a json file located at /data/file1.json, +Then we expect path = "/some/location/foo.zip:/data/file1.json"
+
+
+ +Expand source code + +
class PathInZip:
+    """Class for holding a path of file within a zip file.
+
+    Args:
+        path: The convention is <path_to_zip>:<relative_path_inside_zip>
+            Let's assume there is a zip file /some/location/foo.zip
+            and inside of it is a json file located at /data/file1.json,
+            Then we expect path = "/some/location/foo.zip:/data/file1.json"
+    """
+
+    INFO_PATH_SEP = ':'
+    zip_path: str
+    file_path: str
+
+    def __init__(self, path: str) -> None:
+        split_path = path.split(self.INFO_PATH_SEP)
+        assert len(split_path) == 2
+        self.zip_path, self.file_path = split_path
+
+    @classmethod
+    def from_paths(cls, zip_path: str, file_path: str):
+        return cls(zip_path + cls.INFO_PATH_SEP + file_path)
+
+    def __str__(self) -> str:
+        return self.zip_path + self.INFO_PATH_SEP + self.file_path
+
+

Class variables

+
+
var INFO_PATH_SEP
+
+
+
+
var file_path : str
+
+
+
+
var zip_path : str
+
+
+
+
+

Static methods

+
+
+def from_paths(zip_path: str, file_path: str) +
+
+
+
+ +Expand source code + +
@classmethod
+def from_paths(cls, zip_path: str, file_path: str):
+    return cls(zip_path + cls.INFO_PATH_SEP + file_path)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/index.html b/docs/audiocraft/index.html new file mode 100644 index 00000000..2a77ad7d --- /dev/null +++ b/docs/audiocraft/index.html @@ -0,0 +1,95 @@ + + + + + + +audiocraft API documentation + + + + + + + + + + + +
+
+
+

Package audiocraft

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from . import data, modules, models
+
+__version__ = '0.0.2a2'
+
+
+
+

Sub-modules

+
+
audiocraft.data
+
+
+
+
audiocraft.models
+
+
+
+
audiocraft.modules
+
+
+
+
audiocraft.quantization
+
+
+
+
audiocraft.utils
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/models/builders.html b/docs/audiocraft/models/builders.html new file mode 100644 index 00000000..867a760b --- /dev/null +++ b/docs/audiocraft/models/builders.html @@ -0,0 +1,556 @@ + + + + + + +audiocraft.models.builders API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.models.builders

+
+
+

All the functions to build the relevant models and modules +from the Hydra config.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+All the functions to build the relevant models and modules
+from the Hydra config.
+"""
+
+import typing as tp
+import warnings
+
+import audiocraft
+import omegaconf
+import torch
+
+from .encodec import CompressionModel, EncodecModel, FlattenedCompressionModel  # noqa
+from .lm import LMModel
+from ..modules.codebooks_patterns import (
+    CodebooksPatternProvider,
+    DelayedPatternProvider,
+    ParallelPatternProvider,
+    UnrolledPatternProvider,
+    VALLEPattern,
+    MusicLMPattern,
+)
+from ..modules.conditioners import (
+    BaseConditioner,
+    ConditioningProvider,
+    LUTConditioner,
+    T5Conditioner,
+    ConditionFuser,
+    ChromaStemConditioner,
+)
+from .. import quantization as qt
+from ..utils.utils import dict_from_config
+
+
+def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
+    klass = {
+        'no_quant': qt.DummyQuantizer,
+        'rvq': qt.ResidualVectorQuantizer
+    }[quantizer]
+    kwargs = dict_from_config(getattr(cfg, quantizer))
+    if quantizer != 'no_quant':
+        kwargs['dimension'] = dimension
+    return klass(**kwargs)
+
+
+def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
+    if encoder_name == 'seanet':
+        kwargs = dict_from_config(getattr(cfg, 'seanet'))
+        encoder_override_kwargs = kwargs.pop('encoder')
+        decoder_override_kwargs = kwargs.pop('decoder')
+        encoder_kwargs = {**kwargs, **encoder_override_kwargs}
+        decoder_kwargs = {**kwargs, **decoder_override_kwargs}
+        encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
+        decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
+        return encoder, decoder
+    else:
+        raise KeyError(f'Unexpected compression model {cfg.compression_model}')
+
+
+def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
+    """Instantiate a compression model.
+    """
+    if cfg.compression_model == 'encodec':
+        kwargs = dict_from_config(getattr(cfg, 'encodec'))
+        encoder_name = kwargs.pop('autoencoder')
+        quantizer_name = kwargs.pop('quantizer')
+        encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
+        quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
+        frame_rate = kwargs['sample_rate'] // encoder.hop_length
+        renormalize = kwargs.pop('renormalize', None)
+        renorm = kwargs.pop('renorm')
+        if renormalize is None:
+            renormalize = renorm is not None
+            warnings.warn("You are using a deprecated EnCodec model. Please migrate to new renormalization.")
+        return EncodecModel(encoder, decoder, quantizer,
+                            frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
+    else:
+        raise KeyError(f'Unexpected compression model {cfg.compression_model}')
+
+
+def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
+    """Instantiate a transformer LM.
+    """
+    if cfg.lm_model == 'transformer_lm':
+        kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
+        n_q = kwargs['n_q']
+        q_modeling = kwargs.pop('q_modeling', None)
+        codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
+        attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
+        cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
+        cfg_prob, cfg_coef = cls_free_guidance["training_dropout"], cls_free_guidance["inference_coef"]
+        fuser = get_condition_fuser(cfg)
+        condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
+        if len(fuser.fuse2cond['cross']) > 0:  # enforce cross-att programatically
+            kwargs['cross_attention'] = True
+        if codebooks_pattern_cfg.modeling is None:
+            assert q_modeling is not None, \
+                'LM model should either have a codebook pattern defined or transformer_lm.q_modeling'
+            codebooks_pattern_cfg = omegaconf.OmegaConf.create(
+                {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
+            )
+        pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
+        return LMModel(
+            pattern_provider=pattern_provider,
+            condition_provider=condition_provider,
+            fuser=fuser,
+            cfg_dropout=cfg_prob,
+            cfg_coef=cfg_coef,
+            attribute_dropout=attribute_dropout,
+            dtype=getattr(torch, cfg.dtype),
+            device=cfg.device,
+            **kwargs
+        ).to(cfg.device)
+    else:
+        raise KeyError(f'Unexpected LM model {cfg.lm_model}')
+
+
+def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
+    """Instantiate a conditioning model.
+    """
+    device = cfg.device
+    duration = cfg.dataset.segment_duration
+    cfg = getattr(cfg, "conditioners")
+    cfg = omegaconf.OmegaConf.create({}) if cfg is None else cfg
+    conditioners: tp.Dict[str, BaseConditioner] = {}
+    with omegaconf.open_dict(cfg):
+        condition_provider_args = cfg.pop('args', {})
+    for cond, cond_cfg in cfg.items():
+        model_type = cond_cfg["model"]
+        model_args = cond_cfg[model_type]
+        if model_type == "t5":
+            conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
+        elif model_type == "lut":
+            conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
+        elif model_type == "chroma_stem":
+            model_args.pop('cache_path', None)
+            conditioners[str(cond)] = ChromaStemConditioner(
+                output_dim=output_dim,
+                duration=duration,
+                device=device,
+                **model_args
+            )
+        else:
+            raise ValueError(f"unrecognized conditioning model: {model_type}")
+    conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
+    return conditioner
+
+
+def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
+    """Instantiate a condition fuser object.
+    """
+    fuser_cfg = getattr(cfg, "fuser")
+    fuser_methods = ["sum", "cross", "prepend", "input_interpolate"]
+    fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
+    kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
+    fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
+    return fuser
+
+
+def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
+    """Instantiate a codebooks pattern provider object.
+    """
+    pattern_providers = {
+        'parallel': ParallelPatternProvider,
+        'delay': DelayedPatternProvider,
+        'unroll': UnrolledPatternProvider,
+        'valle': VALLEPattern,
+        'musiclm': MusicLMPattern,
+    }
+    name = cfg.modeling
+    kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
+    klass = pattern_providers[name]
+    return klass(n_q, **kwargs)
+
+
+def get_debug_compression_model(device='cpu'):
+    """Instantiate a debug compression model to be used for unit tests.
+    """
+    seanet_kwargs = {
+        'n_filters': 4,
+        'n_residual_layers': 1,
+        'dimension': 32,
+        'ratios': [10, 8, 16]  # 25 Hz at 32kHz
+    }
+    encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
+    decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
+    quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
+    init_x = torch.randn(8, 32, 128)
+    quantizer(init_x, 1)  # initialize kmeans etc.
+    compression_model = EncodecModel(
+        encoder, decoder, quantizer,
+        frame_rate=25, sample_rate=32000, channels=1).to(device)
+    return compression_model.eval()
+
+
+def get_debug_lm_model(device='cpu'):
+    """Instantiate a debug LM to be used for unit tests.
+    """
+    pattern = DelayedPatternProvider(n_q=4)
+    dim = 16
+    providers = {
+        'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
+    }
+    condition_provider = ConditioningProvider(providers)
+    fuser = ConditionFuser(
+        {'cross': ['description'], 'prepend': [],
+         'sum': [], 'input_interpolate': []})
+    lm = LMModel(
+        pattern, condition_provider, fuser,
+        n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
+        cross_attention=True, causal=True)
+    return lm.to(device).eval()
+
+
+
+
+
+
+
+

Functions

+
+
+def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.dictconfig.DictConfig) ‑> CodebooksPatternProvider +
+
+

Instantiate a codebooks pattern provider object.

+
+ +Expand source code + +
def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
+    """Instantiate a codebooks pattern provider object.
+    """
+    pattern_providers = {
+        'parallel': ParallelPatternProvider,
+        'delay': DelayedPatternProvider,
+        'unroll': UnrolledPatternProvider,
+        'valle': VALLEPattern,
+        'musiclm': MusicLMPattern,
+    }
+    name = cfg.modeling
+    kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
+    klass = pattern_providers[name]
+    return klass(n_q, **kwargs)
+
+
+
+def get_compression_model(cfg: omegaconf.dictconfig.DictConfig) ‑> CompressionModel +
+
+

Instantiate a compression model.

+
+ +Expand source code + +
def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
+    """Instantiate a compression model.
+    """
+    if cfg.compression_model == 'encodec':
+        kwargs = dict_from_config(getattr(cfg, 'encodec'))
+        encoder_name = kwargs.pop('autoencoder')
+        quantizer_name = kwargs.pop('quantizer')
+        encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
+        quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
+        frame_rate = kwargs['sample_rate'] // encoder.hop_length
+        renormalize = kwargs.pop('renormalize', None)
+        renorm = kwargs.pop('renorm')
+        if renormalize is None:
+            renormalize = renorm is not None
+            warnings.warn("You are using a deprecated EnCodec model. Please migrate to new renormalization.")
+        return EncodecModel(encoder, decoder, quantizer,
+                            frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
+    else:
+        raise KeyError(f'Unexpected compression model {cfg.compression_model}')
+
+
+
+def get_condition_fuser(cfg: omegaconf.dictconfig.DictConfig) ‑> ConditionFuser +
+
+

Instantiate a condition fuser object.

+
+ +Expand source code + +
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
+    """Instantiate a condition fuser object.
+    """
+    fuser_cfg = getattr(cfg, "fuser")
+    fuser_methods = ["sum", "cross", "prepend", "input_interpolate"]
+    fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
+    kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
+    fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
+    return fuser
+
+
+
+def get_conditioner_provider(output_dim: int, cfg: omegaconf.dictconfig.DictConfig) ‑> ConditioningProvider +
+
+

Instantiate a conditioning model.

+
+ +Expand source code + +
def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
+    """Instantiate a conditioning model.
+    """
+    device = cfg.device
+    duration = cfg.dataset.segment_duration
+    cfg = getattr(cfg, "conditioners")
+    cfg = omegaconf.OmegaConf.create({}) if cfg is None else cfg
+    conditioners: tp.Dict[str, BaseConditioner] = {}
+    with omegaconf.open_dict(cfg):
+        condition_provider_args = cfg.pop('args', {})
+    for cond, cond_cfg in cfg.items():
+        model_type = cond_cfg["model"]
+        model_args = cond_cfg[model_type]
+        if model_type == "t5":
+            conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
+        elif model_type == "lut":
+            conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
+        elif model_type == "chroma_stem":
+            model_args.pop('cache_path', None)
+            conditioners[str(cond)] = ChromaStemConditioner(
+                output_dim=output_dim,
+                duration=duration,
+                device=device,
+                **model_args
+            )
+        else:
+            raise ValueError(f"unrecognized conditioning model: {model_type}")
+    conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
+    return conditioner
+
+
+
+def get_debug_compression_model(device='cpu') +
+
+

Instantiate a debug compression model to be used for unit tests.

+
+ +Expand source code + +
def get_debug_compression_model(device='cpu'):
+    """Instantiate a debug compression model to be used for unit tests.
+    """
+    seanet_kwargs = {
+        'n_filters': 4,
+        'n_residual_layers': 1,
+        'dimension': 32,
+        'ratios': [10, 8, 16]  # 25 Hz at 32kHz
+    }
+    encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
+    decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
+    quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
+    init_x = torch.randn(8, 32, 128)
+    quantizer(init_x, 1)  # initialize kmeans etc.
+    compression_model = EncodecModel(
+        encoder, decoder, quantizer,
+        frame_rate=25, sample_rate=32000, channels=1).to(device)
+    return compression_model.eval()
+
+
+
+def get_debug_lm_model(device='cpu') +
+
+

Instantiate a debug LM to be used for unit tests.

+
+ +Expand source code + +
def get_debug_lm_model(device='cpu'):
+    """Instantiate a debug LM to be used for unit tests.
+    """
+    pattern = DelayedPatternProvider(n_q=4)
+    dim = 16
+    providers = {
+        'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
+    }
+    condition_provider = ConditioningProvider(providers)
+    fuser = ConditionFuser(
+        {'cross': ['description'], 'prepend': [],
+         'sum': [], 'input_interpolate': []})
+    lm = LMModel(
+        pattern, condition_provider, fuser,
+        n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
+        cross_attention=True, causal=True)
+    return lm.to(device).eval()
+
+
+
+def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.dictconfig.DictConfig) +
+
+
+
+ +Expand source code + +
def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
+    if encoder_name == 'seanet':
+        kwargs = dict_from_config(getattr(cfg, 'seanet'))
+        encoder_override_kwargs = kwargs.pop('encoder')
+        decoder_override_kwargs = kwargs.pop('decoder')
+        encoder_kwargs = {**kwargs, **encoder_override_kwargs}
+        decoder_kwargs = {**kwargs, **decoder_override_kwargs}
+        encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
+        decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
+        return encoder, decoder
+    else:
+        raise KeyError(f'Unexpected compression model {cfg.compression_model}')
+
+
+
+def get_lm_model(cfg: omegaconf.dictconfig.DictConfig) ‑> LMModel +
+
+

Instantiate a transformer LM.

+
+ +Expand source code + +
def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
+    """Instantiate a transformer LM.
+    """
+    if cfg.lm_model == 'transformer_lm':
+        kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
+        n_q = kwargs['n_q']
+        q_modeling = kwargs.pop('q_modeling', None)
+        codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
+        attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
+        cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
+        cfg_prob, cfg_coef = cls_free_guidance["training_dropout"], cls_free_guidance["inference_coef"]
+        fuser = get_condition_fuser(cfg)
+        condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
+        if len(fuser.fuse2cond['cross']) > 0:  # enforce cross-att programatically
+            kwargs['cross_attention'] = True
+        if codebooks_pattern_cfg.modeling is None:
+            assert q_modeling is not None, \
+                'LM model should either have a codebook pattern defined or transformer_lm.q_modeling'
+            codebooks_pattern_cfg = omegaconf.OmegaConf.create(
+                {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
+            )
+        pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
+        return LMModel(
+            pattern_provider=pattern_provider,
+            condition_provider=condition_provider,
+            fuser=fuser,
+            cfg_dropout=cfg_prob,
+            cfg_coef=cfg_coef,
+            attribute_dropout=attribute_dropout,
+            dtype=getattr(torch, cfg.dtype),
+            device=cfg.device,
+            **kwargs
+        ).to(cfg.device)
+    else:
+        raise KeyError(f'Unexpected LM model {cfg.lm_model}')
+
+
+
+def get_quantizer(quantizer: str, cfg: omegaconf.dictconfig.DictConfig, dimension: int) ‑> BaseQuantizer +
+
+
+
+ +Expand source code + +
def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
+    klass = {
+        'no_quant': qt.DummyQuantizer,
+        'rvq': qt.ResidualVectorQuantizer
+    }[quantizer]
+    kwargs = dict_from_config(getattr(cfg, quantizer))
+    if quantizer != 'no_quant':
+        kwargs['dimension'] = dimension
+    return klass(**kwargs)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/models/encodec.html b/docs/audiocraft/models/encodec.html new file mode 100644 index 00000000..6413ca4d --- /dev/null +++ b/docs/audiocraft/models/encodec.html @@ -0,0 +1,1306 @@ + + + + + + +audiocraft.models.encodec API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.models.encodec

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+import typing as tp
+
+from einops import rearrange
+import torch
+from torch import nn
+
+from .. import quantization as qt
+
+
+class CompressionModel(ABC, nn.Module):
+
+    @abstractmethod
+    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+        ...
+
+    @abstractmethod
+    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        """See `EncodecModel.encode`"""
+        ...
+
+    @abstractmethod
+    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+        """See `EncodecModel.decode`"""
+        ...
+
+    @property
+    @abstractmethod
+    def channels(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def frame_rate(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def sample_rate(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def cardinality(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def num_codebooks(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def total_codebooks(self) -> int:
+        ...
+
+    @abstractmethod
+    def set_num_codebooks(self, n: int):
+        """Set the active number of codebooks used by the quantizer.
+        """
+        ...
+
+
+class EncodecModel(CompressionModel):
+    """Encodec model operating on the raw waveform.
+
+    Args:
+        encoder (nn.Module): Encoder network.
+        decoder (nn.Module): Decoder network.
+        quantizer (qt.BaseQuantizer): Quantizer network.
+        frame_rate (int): Frame rate for the latent representation.
+        sample_rate (int): Audio sample rate.
+        channels (int): Number of audio channels.
+        causal (bool): Whether to use a causal version of the model.
+        renormalize (bool): Whether to renormalize the audio before running the model.
+    """
+    # we need assignement to override the property in the abstract class,
+    # I couldn't find a better way...
+    frame_rate: int = 0
+    sample_rate: int = 0
+    channels: int = 0
+
+    def __init__(self,
+                 encoder: nn.Module,
+                 decoder: nn.Module,
+                 quantizer: qt.BaseQuantizer,
+                 frame_rate: int,
+                 sample_rate: int,
+                 channels: int,
+                 causal: bool = False,
+                 renormalize: bool = False):
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.quantizer = quantizer
+        self.frame_rate = frame_rate
+        self.sample_rate = sample_rate
+        self.channels = channels
+        self.renormalize = renormalize
+        self.causal = causal
+        if self.causal:
+            # we force disabling here to avoid handling linear overlap of segments
+            # as supported in original EnCodec codebase.
+            assert not self.renormalize, 'Causal model does not support renormalize'
+
+    @property
+    def total_codebooks(self):
+        """Total number of quantizer codebooks available.
+        """
+        return self.quantizer.total_codebooks
+
+    @property
+    def num_codebooks(self):
+        """Active number of codebooks used by the quantizer.
+        """
+        return self.quantizer.num_codebooks
+
+    def set_num_codebooks(self, n: int):
+        """Set the active number of codebooks used by the quantizer.
+        """
+        self.quantizer.set_num_codebooks(n)
+
+    @property
+    def cardinality(self):
+        """Cardinality of each codebook.
+        """
+        return self.quantizer.bins
+
+    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        scale: tp.Optional[torch.Tensor]
+        if self.renormalize:
+            mono = x.mean(dim=1, keepdim=True)
+            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
+            scale = 1e-8 + volume
+            x = x / scale
+            scale = scale.view(-1, 1)
+        else:
+            scale = None
+        return x, scale
+
+    def postprocess(self,
+                    x: torch.Tensor,
+                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
+        if scale is not None:
+            assert self.renormalize
+            x = x * scale.view(-1, 1, 1)
+        return x
+
+    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+        assert x.dim() == 3
+        length = x.shape[-1]
+        x, scale = self.preprocess(x)
+
+        emb = self.encoder(x)
+        q_res = self.quantizer(emb, self.frame_rate)
+        out = self.decoder(q_res.x)
+
+        # remove extra padding added by the encoder and decoder
+        assert out.shape[-1] >= length, (out.shape[-1], length)
+        out = out[..., :length]
+
+        q_res.x = self.postprocess(out, scale)
+
+        return q_res
+
+    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        """Encode the given input tensor to quantized representation along with scale parameter.
+
+        Args:
+            x (torch.Tensor): Float tensor of shape [B, C, T]
+
+        Returns:
+            codes, scale (tp.Tuple[torch.Tensor, torch.Tensor]): Tuple composed of:
+                codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
+                scale a float tensor containing the scale for audio renormalizealization.
+        """
+        assert x.dim() == 3
+        x, scale = self.preprocess(x)
+        emb = self.encoder(x)
+        codes = self.quantizer.encode(emb)
+        return codes, scale
+
+    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+        """Decode the given codes to a reconstructed representation, using the scale to perform
+        audio denormalization if needed.
+
+        Args:
+            codes (torch.Tensor): Int tensor of shape [B, K, T]
+            scale (tp.Optional[torch.Tensor]): Float tensor containing the scale value.
+
+        Returns:
+            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
+        """
+        emb = self.quantizer.decode(codes)
+        out = self.decoder(emb)
+        out = self.postprocess(out, scale)
+        # out contains extra padding added by the encoder and decoder
+        return out
+
+
+class FlattenedCompressionModel(CompressionModel):
+    """Wraps a CompressionModel and flatten its codebooks, e.g.
+    instead of returning [B, K, T], return [B, S, T * (K // S)] with
+    S the number of codebooks per step, and `K // S` the number of 'virtual steps'
+    for each real time step.
+
+    Args:
+        model (CompressionModel): compression model to wrap.
+        codebooks_per_step (int): number of codebooks to keep per step,
+            this must divide the number of codebooks provided by the wrapped model.
+        extend_cardinality (bool): if True, and for instance if codebooks_per_step = 1,
+            if each codebook has a cardinality N, then the first codebook will
+            use the range [0, N - 1], and the second [N, 2 N - 1] etc.
+            On decoding, this can lead to potentially invalid sequences.
+            Any invalid entry will be silently remapped to the proper range
+            with a modulo.
+    """
+    def __init__(self, model: CompressionModel, codebooks_per_step: int = 1,
+                 extend_cardinality: bool = True):
+        super().__init__()
+        self.model = model
+        self.codebooks_per_step = codebooks_per_step
+        self.extend_cardinality = extend_cardinality
+
+    @property
+    def total_codebooks(self):
+        return self.model.total_codebooks
+
+    @property
+    def num_codebooks(self):
+        """Active number of codebooks used by the quantizer.
+
+        ..Warning:: this reports the number of codebooks after the flattening
+        of the codebooks!
+        """
+        assert self.model.num_codebooks % self.codebooks_per_step == 0
+        return self.codebooks_per_step
+
+    def set_num_codebooks(self, n: int):
+        """Set the active number of codebooks used by the quantizer.
+
+        ..Warning:: this sets the number of codebooks **before** the flattening
+        of the codebooks.
+        """
+        assert n % self.codebooks_per_step == 0
+        self.model.set_num_codebooks(n)
+
+    @property
+    def num_virtual_steps(self) -> int:
+        """Return the number of virtual steps, e.g. one real step
+        will be split into that many steps.
+        """
+        return self.model.num_codebooks // self.codebooks_per_step
+
+    @property
+    def frame_rate(self) -> int:
+        return self.model.frame_rate * self.num_virtual_steps
+
+    @property
+    def sample_rate(self) -> int:
+        return self.model.sample_rate
+
+    @property
+    def channels(self) -> int:
+        return self.model.channels
+
+    @property
+    def cardinality(self):
+        """Cardinality of each codebook.
+        """
+        if self.extend_cardinality:
+            return self.model.cardinality * self.num_virtual_steps
+        else:
+            return self.model.cardinality
+
+    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+        raise NotImplementedError("Not supported, use encode and decode.")
+
+    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        indices, scales = self.model.encode(x)
+        B, K, T = indices.shape
+        indices = rearrange(indices, 'b (k v) t -> b k t v', k=self.codebooks_per_step)
+        if self.extend_cardinality:
+            for virtual_step in range(1, self.num_virtual_steps):
+                indices[..., virtual_step] += self.model.cardinality * virtual_step
+        indices = rearrange(indices, 'b k t v -> b k (t v)')
+        return (indices, scales)
+
+    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+        B, K, T = codes.shape
+        assert T % self.num_virtual_steps == 0
+        codes = rearrange(codes, 'b k (t v) -> b (k v) t', v=self.num_virtual_steps)
+        # We silently ignore potential errors from the LM when
+        # using extend_cardinality.
+        codes = codes % self.model.cardinality
+        return self.model.decode(codes, scale)
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class CompressionModel +(*args, **kwargs) +
+
+

Helper class that provides a standard way to create an ABC using +inheritance.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class CompressionModel(ABC, nn.Module):
+
+    @abstractmethod
+    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+        ...
+
+    @abstractmethod
+    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        """See `EncodecModel.encode`"""
+        ...
+
+    @abstractmethod
+    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+        """See `EncodecModel.decode`"""
+        ...
+
+    @property
+    @abstractmethod
+    def channels(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def frame_rate(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def sample_rate(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def cardinality(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def num_codebooks(self) -> int:
+        ...
+
+    @property
+    @abstractmethod
+    def total_codebooks(self) -> int:
+        ...
+
+    @abstractmethod
+    def set_num_codebooks(self, n: int):
+        """Set the active number of codebooks used by the quantizer.
+        """
+        ...
+
+

Ancestors

+
    +
  • abc.ABC
  • +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Instance variables

+
+
var cardinality : int
+
+
+
+ +Expand source code + +
@property
+@abstractmethod
+def cardinality(self) -> int:
+    ...
+
+
+
var channels : int
+
+
+
+ +Expand source code + +
@property
+@abstractmethod
+def channels(self) -> int:
+    ...
+
+
+
var frame_rate : int
+
+
+
+ +Expand source code + +
@property
+@abstractmethod
+def frame_rate(self) -> int:
+    ...
+
+
+
var num_codebooks : int
+
+
+
+ +Expand source code + +
@property
+@abstractmethod
+def num_codebooks(self) -> int:
+    ...
+
+
+
var sample_rate : int
+
+
+
+ +Expand source code + +
@property
+@abstractmethod
+def sample_rate(self) -> int:
+    ...
+
+
+
var total_codebooks : int
+
+
+
+ +Expand source code + +
@property
+@abstractmethod
+def total_codebooks(self) -> int:
+    ...
+
+
+
+

Methods

+
+
+def decode(self, codes: torch.Tensor, scale: Optional[torch.Tensor] = None) +
+
+ +
+ +Expand source code + +
@abstractmethod
+def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+    """See `EncodecModel.decode`"""
+    ...
+
+
+
+def encode(self, x: torch.Tensor) ‑> Tuple[torch.Tensor, Optional[torch.Tensor]] +
+
+ +
+ +Expand source code + +
@abstractmethod
+def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+    """See `EncodecModel.encode`"""
+    ...
+
+
+
+def forward(self, x: torch.Tensor) ‑> QuantizedResult +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
@abstractmethod
+def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+    ...
+
+
+
+def set_num_codebooks(self, n: int) +
+
+

Set the active number of codebooks used by the quantizer.

+
+ +Expand source code + +
@abstractmethod
+def set_num_codebooks(self, n: int):
+    """Set the active number of codebooks used by the quantizer.
+    """
+    ...
+
+
+
+
+
+class EncodecModel +(encoder: torch.nn.modules.module.Module, decoder: torch.nn.modules.module.Module, quantizer: BaseQuantizer, frame_rate: int, sample_rate: int, channels: int, causal: bool = False, renormalize: bool = False) +
+
+

Encodec model operating on the raw waveform.

+

Args

+
+
encoder : nn.Module
+
Encoder network.
+
decoder : nn.Module
+
Decoder network.
+
quantizer : qt.BaseQuantizer
+
Quantizer network.
+
frame_rate : int
+
Frame rate for the latent representation.
+
sample_rate : int
+
Audio sample rate.
+
channels : int
+
Number of audio channels.
+
causal : bool
+
Whether to use a causal version of the model.
+
renormalize : bool
+
Whether to renormalize the audio before running the model.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class EncodecModel(CompressionModel):
+    """Encodec model operating on the raw waveform.
+
+    Args:
+        encoder (nn.Module): Encoder network.
+        decoder (nn.Module): Decoder network.
+        quantizer (qt.BaseQuantizer): Quantizer network.
+        frame_rate (int): Frame rate for the latent representation.
+        sample_rate (int): Audio sample rate.
+        channels (int): Number of audio channels.
+        causal (bool): Whether to use a causal version of the model.
+        renormalize (bool): Whether to renormalize the audio before running the model.
+    """
+    # we need assignement to override the property in the abstract class,
+    # I couldn't find a better way...
+    frame_rate: int = 0
+    sample_rate: int = 0
+    channels: int = 0
+
+    def __init__(self,
+                 encoder: nn.Module,
+                 decoder: nn.Module,
+                 quantizer: qt.BaseQuantizer,
+                 frame_rate: int,
+                 sample_rate: int,
+                 channels: int,
+                 causal: bool = False,
+                 renormalize: bool = False):
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.quantizer = quantizer
+        self.frame_rate = frame_rate
+        self.sample_rate = sample_rate
+        self.channels = channels
+        self.renormalize = renormalize
+        self.causal = causal
+        if self.causal:
+            # we force disabling here to avoid handling linear overlap of segments
+            # as supported in original EnCodec codebase.
+            assert not self.renormalize, 'Causal model does not support renormalize'
+
+    @property
+    def total_codebooks(self):
+        """Total number of quantizer codebooks available.
+        """
+        return self.quantizer.total_codebooks
+
+    @property
+    def num_codebooks(self):
+        """Active number of codebooks used by the quantizer.
+        """
+        return self.quantizer.num_codebooks
+
+    def set_num_codebooks(self, n: int):
+        """Set the active number of codebooks used by the quantizer.
+        """
+        self.quantizer.set_num_codebooks(n)
+
+    @property
+    def cardinality(self):
+        """Cardinality of each codebook.
+        """
+        return self.quantizer.bins
+
+    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        scale: tp.Optional[torch.Tensor]
+        if self.renormalize:
+            mono = x.mean(dim=1, keepdim=True)
+            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
+            scale = 1e-8 + volume
+            x = x / scale
+            scale = scale.view(-1, 1)
+        else:
+            scale = None
+        return x, scale
+
+    def postprocess(self,
+                    x: torch.Tensor,
+                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
+        if scale is not None:
+            assert self.renormalize
+            x = x * scale.view(-1, 1, 1)
+        return x
+
+    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+        assert x.dim() == 3
+        length = x.shape[-1]
+        x, scale = self.preprocess(x)
+
+        emb = self.encoder(x)
+        q_res = self.quantizer(emb, self.frame_rate)
+        out = self.decoder(q_res.x)
+
+        # remove extra padding added by the encoder and decoder
+        assert out.shape[-1] >= length, (out.shape[-1], length)
+        out = out[..., :length]
+
+        q_res.x = self.postprocess(out, scale)
+
+        return q_res
+
+    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        """Encode the given input tensor to quantized representation along with scale parameter.
+
+        Args:
+            x (torch.Tensor): Float tensor of shape [B, C, T]
+
+        Returns:
+            codes, scale (tp.Tuple[torch.Tensor, torch.Tensor]): Tuple composed of:
+                codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
+                scale a float tensor containing the scale for audio renormalizealization.
+        """
+        assert x.dim() == 3
+        x, scale = self.preprocess(x)
+        emb = self.encoder(x)
+        codes = self.quantizer.encode(emb)
+        return codes, scale
+
+    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+        """Decode the given codes to a reconstructed representation, using the scale to perform
+        audio denormalization if needed.
+
+        Args:
+            codes (torch.Tensor): Int tensor of shape [B, K, T]
+            scale (tp.Optional[torch.Tensor]): Float tensor containing the scale value.
+
+        Returns:
+            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
+        """
+        emb = self.quantizer.decode(codes)
+        out = self.decoder(emb)
+        out = self.postprocess(out, scale)
+        # out contains extra padding added by the encoder and decoder
+        return out
+
+

Ancestors

+ +

Class variables

+
+
var channels : int
+
+
+
+
var frame_rate : int
+
+
+
+
var sample_rate : int
+
+
+
+
+

Instance variables

+
+
var cardinality
+
+

Cardinality of each codebook.

+
+ +Expand source code + +
@property
+def cardinality(self):
+    """Cardinality of each codebook.
+    """
+    return self.quantizer.bins
+
+
+
var num_codebooks
+
+

Active number of codebooks used by the quantizer.

+
+ +Expand source code + +
@property
+def num_codebooks(self):
+    """Active number of codebooks used by the quantizer.
+    """
+    return self.quantizer.num_codebooks
+
+
+
var total_codebooks
+
+

Total number of quantizer codebooks available.

+
+ +Expand source code + +
@property
+def total_codebooks(self):
+    """Total number of quantizer codebooks available.
+    """
+    return self.quantizer.total_codebooks
+
+
+
+

Methods

+
+
+def decode(self, codes: torch.Tensor, scale: Optional[torch.Tensor] = None) +
+
+

Decode the given codes to a reconstructed representation, using the scale to perform +audio denormalization if needed.

+

Args

+
+
codes : torch.Tensor
+
Int tensor of shape [B, K, T]
+
scale : tp.Optional[torch.Tensor]
+
Float tensor containing the scale value.
+
+

Returns

+

out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.

+
+ +Expand source code + +
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+    """Decode the given codes to a reconstructed representation, using the scale to perform
+    audio denormalization if needed.
+
+    Args:
+        codes (torch.Tensor): Int tensor of shape [B, K, T]
+        scale (tp.Optional[torch.Tensor]): Float tensor containing the scale value.
+
+    Returns:
+        out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
+    """
+    emb = self.quantizer.decode(codes)
+    out = self.decoder(emb)
+    out = self.postprocess(out, scale)
+    # out contains extra padding added by the encoder and decoder
+    return out
+
+
+
+def encode(self, x: torch.Tensor) ‑> Tuple[torch.Tensor, Optional[torch.Tensor]] +
+
+

Encode the given input tensor to quantized representation along with scale parameter.

+

Args

+
+
x : torch.Tensor
+
Float tensor of shape [B, C, T]
+
+

Returns

+

codes, scale (tp.Tuple[torch.Tensor, torch.Tensor]): Tuple composed of: +codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep. +scale a float tensor containing the scale for audio renormalizealization.

+
+ +Expand source code + +
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+    """Encode the given input tensor to quantized representation along with scale parameter.
+
+    Args:
+        x (torch.Tensor): Float tensor of shape [B, C, T]
+
+    Returns:
+        codes, scale (tp.Tuple[torch.Tensor, torch.Tensor]): Tuple composed of:
+            codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
+            scale a float tensor containing the scale for audio renormalizealization.
+    """
+    assert x.dim() == 3
+    x, scale = self.preprocess(x)
+    emb = self.encoder(x)
+    codes = self.quantizer.encode(emb)
+    return codes, scale
+
+
+
+def postprocess(self, x: torch.Tensor, scale: Optional[torch.Tensor] = None) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
def postprocess(self,
+                x: torch.Tensor,
+                scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
+    if scale is not None:
+        assert self.renormalize
+        x = x * scale.view(-1, 1, 1)
+    return x
+
+
+
+def preprocess(self, x: torch.Tensor) ‑> Tuple[torch.Tensor, Optional[torch.Tensor]] +
+
+
+
+ +Expand source code + +
def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+    scale: tp.Optional[torch.Tensor]
+    if self.renormalize:
+        mono = x.mean(dim=1, keepdim=True)
+        volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
+        scale = 1e-8 + volume
+        x = x / scale
+        scale = scale.view(-1, 1)
+    else:
+        scale = None
+    return x, scale
+
+
+
+

Inherited members

+ +
+
+class FlattenedCompressionModel +(model: CompressionModel, codebooks_per_step: int = 1, extend_cardinality: bool = True) +
+
+

Wraps a CompressionModel and flatten its codebooks, e.g. +instead of returning [B, K, T], return [B, S, T * (K // S)] with +S the number of codebooks per step, and K // S the number of 'virtual steps' +for each real time step.

+

Args

+
+
model : CompressionModel
+
compression model to wrap.
+
codebooks_per_step : int
+
number of codebooks to keep per step, +this must divide the number of codebooks provided by the wrapped model.
+
extend_cardinality : bool
+
if True, and for instance if codebooks_per_step = 1, +if each codebook has a cardinality N, then the first codebook will +use the range [0, N - 1], and the second [N, 2 N - 1] etc. +On decoding, this can lead to potentially invalid sequences. +Any invalid entry will be silently remapped to the proper range +with a modulo.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class FlattenedCompressionModel(CompressionModel):
+    """Wraps a CompressionModel and flatten its codebooks, e.g.
+    instead of returning [B, K, T], return [B, S, T * (K // S)] with
+    S the number of codebooks per step, and `K // S` the number of 'virtual steps'
+    for each real time step.
+
+    Args:
+        model (CompressionModel): compression model to wrap.
+        codebooks_per_step (int): number of codebooks to keep per step,
+            this must divide the number of codebooks provided by the wrapped model.
+        extend_cardinality (bool): if True, and for instance if codebooks_per_step = 1,
+            if each codebook has a cardinality N, then the first codebook will
+            use the range [0, N - 1], and the second [N, 2 N - 1] etc.
+            On decoding, this can lead to potentially invalid sequences.
+            Any invalid entry will be silently remapped to the proper range
+            with a modulo.
+    """
+    def __init__(self, model: CompressionModel, codebooks_per_step: int = 1,
+                 extend_cardinality: bool = True):
+        super().__init__()
+        self.model = model
+        self.codebooks_per_step = codebooks_per_step
+        self.extend_cardinality = extend_cardinality
+
+    @property
+    def total_codebooks(self):
+        return self.model.total_codebooks
+
+    @property
+    def num_codebooks(self):
+        """Active number of codebooks used by the quantizer.
+
+        ..Warning:: this reports the number of codebooks after the flattening
+        of the codebooks!
+        """
+        assert self.model.num_codebooks % self.codebooks_per_step == 0
+        return self.codebooks_per_step
+
+    def set_num_codebooks(self, n: int):
+        """Set the active number of codebooks used by the quantizer.
+
+        ..Warning:: this sets the number of codebooks **before** the flattening
+        of the codebooks.
+        """
+        assert n % self.codebooks_per_step == 0
+        self.model.set_num_codebooks(n)
+
+    @property
+    def num_virtual_steps(self) -> int:
+        """Return the number of virtual steps, e.g. one real step
+        will be split into that many steps.
+        """
+        return self.model.num_codebooks // self.codebooks_per_step
+
+    @property
+    def frame_rate(self) -> int:
+        return self.model.frame_rate * self.num_virtual_steps
+
+    @property
+    def sample_rate(self) -> int:
+        return self.model.sample_rate
+
+    @property
+    def channels(self) -> int:
+        return self.model.channels
+
+    @property
+    def cardinality(self):
+        """Cardinality of each codebook.
+        """
+        if self.extend_cardinality:
+            return self.model.cardinality * self.num_virtual_steps
+        else:
+            return self.model.cardinality
+
+    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+        raise NotImplementedError("Not supported, use encode and decode.")
+
+    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+        indices, scales = self.model.encode(x)
+        B, K, T = indices.shape
+        indices = rearrange(indices, 'b (k v) t -> b k t v', k=self.codebooks_per_step)
+        if self.extend_cardinality:
+            for virtual_step in range(1, self.num_virtual_steps):
+                indices[..., virtual_step] += self.model.cardinality * virtual_step
+        indices = rearrange(indices, 'b k t v -> b k (t v)')
+        return (indices, scales)
+
+    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+        B, K, T = codes.shape
+        assert T % self.num_virtual_steps == 0
+        codes = rearrange(codes, 'b k (t v) -> b (k v) t', v=self.num_virtual_steps)
+        # We silently ignore potential errors from the LM when
+        # using extend_cardinality.
+        codes = codes % self.model.cardinality
+        return self.model.decode(codes, scale)
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Instance variables

+
+
var cardinality
+
+

Cardinality of each codebook.

+
+ +Expand source code + +
@property
+def cardinality(self):
+    """Cardinality of each codebook.
+    """
+    if self.extend_cardinality:
+        return self.model.cardinality * self.num_virtual_steps
+    else:
+        return self.model.cardinality
+
+
+
var channels : int
+
+
+
+ +Expand source code + +
@property
+def channels(self) -> int:
+    return self.model.channels
+
+
+
var frame_rate : int
+
+
+
+ +Expand source code + +
@property
+def frame_rate(self) -> int:
+    return self.model.frame_rate * self.num_virtual_steps
+
+
+
var num_codebooks
+
+

Active number of codebooks used by the quantizer.

+
+

Warning: this reports the number of codebooks after the flattening

+
+

of the codebooks!

+
+ +Expand source code + +
@property
+def num_codebooks(self):
+    """Active number of codebooks used by the quantizer.
+
+    ..Warning:: this reports the number of codebooks after the flattening
+    of the codebooks!
+    """
+    assert self.model.num_codebooks % self.codebooks_per_step == 0
+    return self.codebooks_per_step
+
+
+
var num_virtual_steps : int
+
+

Return the number of virtual steps, e.g. one real step +will be split into that many steps.

+
+ +Expand source code + +
@property
+def num_virtual_steps(self) -> int:
+    """Return the number of virtual steps, e.g. one real step
+    will be split into that many steps.
+    """
+    return self.model.num_codebooks // self.codebooks_per_step
+
+
+
var sample_rate : int
+
+
+
+ +Expand source code + +
@property
+def sample_rate(self) -> int:
+    return self.model.sample_rate
+
+
+
var total_codebooks
+
+
+
+ +Expand source code + +
@property
+def total_codebooks(self):
+    return self.model.total_codebooks
+
+
+
+

Methods

+
+
+def set_num_codebooks(self, n: int) +
+
+

Set the active number of codebooks used by the quantizer.

+
+

Warning: this sets the number of codebooks before the flattening

+
+

of the codebooks.

+
+ +Expand source code + +
def set_num_codebooks(self, n: int):
+    """Set the active number of codebooks used by the quantizer.
+
+    ..Warning:: this sets the number of codebooks **before** the flattening
+    of the codebooks.
+    """
+    assert n % self.codebooks_per_step == 0
+    self.model.set_num_codebooks(n)
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/models/index.html b/docs/audiocraft/models/index.html new file mode 100644 index 00000000..f088aa04 --- /dev/null +++ b/docs/audiocraft/models/index.html @@ -0,0 +1,104 @@ + + + + + + +audiocraft.models API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.models

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .musicgen import MusicGen
+from .lm import LMModel
+from .encodec import CompressionModel, EncodecModel
+
+
+
+

Sub-modules

+
+
audiocraft.models.builders
+
+

All the functions to build the relevant models and modules +from the Hydra config.

+
+
audiocraft.models.encodec
+
+
+
+
audiocraft.models.lm
+
+
+
+
audiocraft.models.loaders
+
+

Utility functions to load from the checkpoints. +Each checkpoint is a torch.saved dict with the following keys: +- 'xp.cfg': the hydra config as dumped …

+
+
audiocraft.models.musicgen
+
+

Main model for using MusicGen. This will combine all the required components +and provide easy access to the generation API.

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/models/lm.html b/docs/audiocraft/models/lm.html new file mode 100644 index 00000000..0f6a515b --- /dev/null +++ b/docs/audiocraft/models/lm.html @@ -0,0 +1,1721 @@ + + + + + + +audiocraft.models.lm API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.models.lm

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from functools import partial
+import logging
+import math
+import typing as tp
+
+import torch
+from torch import nn
+
+from ..utils import utils
+from ..modules.streaming import StreamingModule, State
+from ..modules.transformer import StreamingTransformer, create_norm_fn
+from ..modules.conditioners import (
+    ConditionFuser,
+    ClassifierFreeGuidanceDropout,
+    AttributeDropout,
+    ConditioningProvider,
+    ConditioningAttributes,
+    ConditionType,
+)
+from ..modules.codebooks_patterns import CodebooksPatternProvider
+from ..modules.activations import get_activation_fn
+
+
+logger = logging.getLogger(__name__)
+ConditionTensors = tp.Dict[str, ConditionType]
+CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
+
+
+def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
+    """LM layer initialization.
+    Inspired from xlformers: https://github.com/fairinternal/xlformers
+
+    Args:
+        method (str): Method name for init function. Valid options are:
+            'gaussian', 'uniform'.
+        input_dim (int): Input dimension of the initialized module.
+        init_depth (Optional[int]): Optional init depth value used to rescale
+            the standard deviation if defined.
+    """
+    # Compute std
+    std = 1 / math.sqrt(input_dim)
+    # Rescale with depth
+    if init_depth is not None:
+        std = std / math.sqrt(2 * init_depth)
+
+    if method == 'gaussian':
+        return partial(
+            torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
+        )
+    elif method == 'uniform':
+        bound = math.sqrt(3) * std  # ensure the standard deviation is `std`
+        return partial(torch.nn.init.uniform_, a=-bound, b=bound)
+    else:
+        raise ValueError("Unsupported layer initialization method")
+
+
+def init_layer(m: nn.Module,
+               method: str,
+               init_depth: tp.Optional[int] = None,
+               zero_bias_init: bool = False):
+    """Wrapper around ``get_init_fn`` for proper initialization of LM modules.
+
+    Args:
+        m (nn.Module): Module to initialize.
+        method (str): Method name for the init function.
+        init_depth (Optional[int]): Optional init depth value used to rescale
+            the standard deviation if defined.
+        zero_bias_init (bool): Whether to initialize the bias to 0 or not.
+    """
+    if isinstance(m, nn.Linear):
+        init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
+        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
+            weight = m.weight.float()
+            init_fn(weight)
+            m.weight.data[:] = weight.half()
+        else:
+            init_fn(m.weight)
+        if zero_bias_init and m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+    elif isinstance(m, nn.Embedding):
+        init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
+        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
+            weight = m.weight.float()
+            init_fn(weight)
+            m.weight.data[:] = weight.half()
+        else:
+            init_fn(m.weight)
+
+
+class ScaledEmbedding(nn.Embedding):
+    """Boost learning rate for embeddings (with `scale`).
+    """
+    def __init__(self, *args, lr=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.lr = lr
+
+    def make_optim_group(self):
+        group = {"params": list(self.parameters())}
+        if self.lr is not None:
+            group["lr"] = self.lr
+        return group
+
+
+@dataclass
+class LMOutput:
+    # The logits are already re-aligned with the input codes
+    # hence no extra shift is required, e.g. when computing CE
+    logits: torch.Tensor  # [B, K, T, card]
+    mask: torch.Tensor  # [B, K, T]
+
+
+class LMModel(StreamingModule):
+    """Transformer-based language model on multiple streams of codes.
+
+    Args:
+        pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
+        condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
+        fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
+        n_q (int): Number of parallel streams to model.
+        card (int): Cardinality, vocabulary size.
+        dim (int): Dimension of the transformer encoder.
+        num_heads (int): Number of heads for the transformer encoder.
+        hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
+        norm (str): Normalization method.
+        norm_first (bool): Use pre-norm instead of post-norm.
+        emb_lr (Optional[float]): Embedding-specific learning rate.
+        bias_proj (bool): Use bias for output projections.
+        weight_init (Optional[str]): Method for weight initialization.
+        depthwise_init (Optional[str]): Method for depthwise weight initialization.
+        zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
+        cfg_dropout (float): Classifier-free guidance dropout.
+        cfg_coef (float): Classifier-free guidance coefficient.
+        attribute_dropout (dict): Attribute dropout probabilities.
+        two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
+        **kwargs: Additional parameters for the transformer encoder.
+    """
+    def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
+                 fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
+                 hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
+                 emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
+                 weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
+                 zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
+                 attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
+                 **kwargs):
+        super().__init__()
+        self.cfg_coef = cfg_coef
+        self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
+        self.att_dropout = AttributeDropout(p=attribute_dropout)
+        self.condition_provider = condition_provider
+        self.fuser = fuser
+        self.card = card
+        embed_dim = self.card + 1
+        self.n_q = n_q
+        self.dim = dim
+        self.pattern_provider = pattern_provider
+        self.two_step_cfg = two_step_cfg
+        self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
+        if 'activation' in kwargs:
+            kwargs['activation'] = get_activation_fn(kwargs['activation'])
+        self.transformer = StreamingTransformer(
+            d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
+            norm=norm, norm_first=norm_first, **kwargs)
+        self.out_norm: tp.Optional[nn.Module] = None
+        if norm_first:
+            self.out_norm = create_norm_fn(norm, dim)
+        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
+        self._init_weights(weight_init, depthwise_init, zero_bias_init)
+        self._fsdp: tp.Optional[nn.Module]
+        self.__dict__['_fsdp'] = None
+
+    def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
+        """Initialization of the transformer module weights.
+
+        Args:
+            weight_init (Optional[str]): Weight initialization strategy. See ``get_init_fn`` for valid options.
+            depthwise_init (Optional[str]): Depwthwise initialization strategy. The following options are valid:
+                'current' where the depth corresponds to the current layer index or 'global' where the total number
+                of layer is used as depth. If not set, no depthwise initialization strategy is used.
+            zero_bias_init (bool): Whether to initalize bias to zero or not.
+        """
+        assert depthwise_init is None or depthwise_init in ['current', 'global']
+        assert depthwise_init is None or weight_init is not None, \
+            "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
+        assert not zero_bias_init or weight_init is not None, \
+            "If 'zero_bias_init', a 'weight_init' method should be provided"
+
+        if weight_init is None:
+            return
+
+        for emb_layer in self.emb:
+            init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
+
+        for layer_idx, tr_layer in enumerate(self.transformer.layers):
+            depth = None
+            if depthwise_init == 'current':
+                depth = layer_idx + 1
+            elif depthwise_init == 'global':
+                depth = len(self.transformer.layers)
+            init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
+            tr_layer.apply(init_fn)
+
+        for linear in self.linears:
+            init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
+
+    @property
+    def special_token_id(self) -> int:
+        return self.card
+
+    @property
+    def num_codebooks(self) -> int:
+        return self.n_q
+
+    def forward(self, sequence: torch.Tensor,
+                conditions: tp.List[ConditioningAttributes],
+                condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
+        """Apply language model on sequence and conditions.
+        Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
+        S the sequence steps, return the logits with shape [B, card, K, S].
+
+        Args:
+            indices (torch.Tensor): indices of the codes to model.
+            conditions (list[ConditioningAttributes]): conditionings to use when modeling
+                the given codes. Note that when evaluating multiple time with the same conditioning
+                you should pre-compute those and pass them as `condition_tensors`.
+            condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning
+                tensors, see `conditions`.
+        Returns:
+            torch.Tensor: Logits.
+        """
+        B, K, S = sequence.shape
+        assert K == self.num_codebooks, 'Sequence shape must match the specified number of codebooks'
+        input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
+        if condition_tensors is None:
+            assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
+            # apply dropout modules
+            conditions = self.cfg_dropout(conditions)
+            conditions = self.att_dropout(conditions)
+            tokenized = self.condition_provider.tokenize(conditions)
+            # encode conditions and fuse, both have a streaming cache to not recompute when generating.
+            condition_tensors = self.condition_provider(tokenized)
+        else:
+            assert not conditions, "Shouldn't pass both conditions and condition_tensors."
+
+        input_, cross_attention_input = self.fuser(input_, condition_tensors)
+
+        out = self.transformer(input_, cross_attention_src=cross_attention_input)
+        if self.out_norm:
+            out = self.out_norm(out)
+        logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)  # [B, K, S, card]
+
+        # remove the prefix from the model outputs
+        if len(self.fuser.fuse2cond['prepend']) > 0:
+            logits = logits[:, :, -S:]
+
+        return logits  # [B, K, S, card]
+
+    def compute_predictions(
+            self, codes: torch.Tensor,
+            conditions: tp.List[ConditioningAttributes],
+            condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
+        """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
+        forward using the specified codes interleaving pattern.
+
+        Args:
+            codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
+                K the number of codebooks and T the number of timesteps.
+            conditions (list[ConditioningAttributes]): conditionings to use when modeling
+                the given codes. Note that when evaluating multiple time with the same conditioning
+                you should pre-compute those and pass them as `condition_tensors`.
+            condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning
+                tensors, see `conditions`.
+        Returns:
+            LMOutput: Language model outputs
+                logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
+                    i.e. the first item corresponds to logits to predict the first code, meaning that
+                    no additional shifting of codes and logits is required.
+                mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
+                    Given the specified interleaving strategies, parts of the logits and codes should
+                    not be considered as valid predictions because of invalid context.
+        """
+        B, K, T = codes.shape
+        codes = codes.contiguous()
+        # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
+        pattern = self.pattern_provider.get_pattern(T)
+        sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
+            codes, self.special_token_id, keep_only_valid_steps=True
+        )
+        # apply model on pattern sequence
+        model = self if self._fsdp is None else self._fsdp
+        logits = model(sequence_codes, conditions, condition_tensors)  # [B, K, S, card]
+        # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
+        # and provide the corresponding mask over invalid positions of tokens
+        logits = logits.permute(0, 3, 1, 2)  # [B, card, K, S]
+        # note: we use nans as special token to make it obvious if we feed unexpected logits
+        logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
+            logits, float('nan'), keep_only_valid_steps=True
+        )
+        logits = logits.permute(0, 2, 3, 1)  # [B, K, T, card]
+        logits_mask = logits_mask[None, :, :].expand(B, -1, -1)  # [K, T] -> [B, K, T]
+        return LMOutput(logits, logits_mask)
+
+    def _sample_next_token(self,
+                           sequence: torch.Tensor,
+                           cfg_conditions: CFGConditions,
+                           unconditional_state: State,
+                           use_sampling: bool = False,
+                           temp: float = 1.0,
+                           top_k: int = 0,
+                           top_p: float = 0.0,
+                           cfg_coef: tp.Optional[float] = None) -> torch.Tensor:
+        """Sample next token from the model given a sequence and a set of conditions. The model supports
+        multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
+
+        Args:
+            sequence (torch.Tensor): Current sequence of shape [B, K, S]
+                with K corresponding to the number of codebooks and S the number of sequence steps.
+                S = 1 in streaming mode, except for the first step that contains a bigger prompt.
+            condition_tensors (Dict[str, ConditionType): Set of conditions. If CFG is used,
+                should be twice the batch size, being the concatenation of the conditions + null conditions.
+            use_sampling (bool): Whether to use a sampling strategy or not.
+            temp (float): Sampling temperature.
+            top_k (int): K for "top-k" sampling.
+            top_p (float): P for "top-p" sampling.
+            cfg_coef (float): classifier free guidance coefficient
+        Returns:
+            next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
+        """
+        B = sequence.shape[0]
+        cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
+        model = self if self._fsdp is None else self._fsdp
+        if self.two_step_cfg and cfg_conditions != {}:
+            assert isinstance(cfg_conditions, tuple)
+            condition_tensors, null_condition_tensors = cfg_conditions
+            cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
+            state = self.get_streaming_state()
+            self.set_streaming_state(unconditional_state)
+            uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
+            unconditional_state.update(self.get_streaming_state())
+            self.set_streaming_state(state)
+            logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef
+        else:
+            assert isinstance(cfg_conditions, dict)
+            condition_tensors = cfg_conditions
+            if condition_tensors:
+                # Preparing for CFG, predicting both conditional and unconditional logits.
+                sequence = torch.cat([sequence, sequence], dim=0)
+            all_logits = model(
+                sequence,
+                conditions=[], condition_tensors=condition_tensors)
+            if condition_tensors:
+                cond_logits, uncond_logits = all_logits.split(B, dim=0)  # [B, K, T, card]
+                logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
+            else:
+                logits = all_logits
+
+        logits = logits.permute(0, 1, 3, 2)  # [B, K, card, T]
+        logits = logits[..., -1]  # [B x K x card]
+
+        # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
+        if use_sampling and temp > 0.0:
+            probs = torch.softmax(logits / temp, dim=-1)
+            if top_p > 0.0:
+                next_token = utils.sample_top_p(probs, p=top_p)
+            elif top_k > 0:
+                next_token = utils.sample_top_k(probs, k=top_k)
+            else:
+                next_token = utils.multinomial(probs, num_samples=1)
+        else:
+            next_token = torch.argmax(logits, dim=-1, keepdim=True)
+
+        return next_token
+
+    @torch.no_grad()
+    def generate(self,
+                 prompt: tp.Optional[torch.Tensor] = None,
+                 conditions: tp.List[ConditioningAttributes] = [],
+                 num_samples: tp.Optional[int] = None,
+                 max_gen_len: int = 256,
+                 use_sampling: bool = True,
+                 temp: float = 1.0,
+                 top_k: int = 250,
+                 top_p: float = 0.0,
+                 cfg_coef: tp.Optional[float] = None,
+                 two_step_cfg: bool = False,
+                 remove_prompts: bool = False,
+                 check: bool = False,
+                 callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
+        """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
+        be perform in a greedy fashion or using sampling with top K and top P strategies.
+
+        Args:
+            prompt (Optional[torch.Tensor]): Prompt tokens of shape [B, K, T].
+            conditions_tensors (Dict[str, torch.Tensor]): Set of conditions or None.
+            num_samples (int or None): Number of samples to generate when no prompt and no conditions are given.
+            max_gen_len (int): Maximum generation length.
+            use_sampling (bool): Whether to use a sampling strategy or not.
+            temp (float): Sampling temperature.
+            top_k (int): K for "top-k" sampling.
+            top_p (float): P for "top-p" sampling.
+            remove_prompts (bool): Whether to remove prompts from generation or not.
+        Returns:
+            torch.Tensor: Generated tokens.
+        """
+        assert not self.training, "generation shouldn't be used in training mode."
+        first_param = next(iter(self.parameters()))
+        device = first_param.device
+
+        # Checking all input shapes are consistents.
+        possible_num_samples = []
+        if num_samples is not None:
+            possible_num_samples.append(num_samples)
+        elif prompt is not None:
+            possible_num_samples.append(prompt.shape[0])
+        elif conditions:
+            possible_num_samples.append(len(conditions))
+        else:
+            possible_num_samples.append(1)
+        assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsitent inputs shapes"
+        num_samples = possible_num_samples[0]
+
+        # below we create set of conditions: one conditional and one unconditional
+        # to do that we merge the regular condition together with the null condition
+        # we then do 1 forward pass instead of 2.
+        # the reason for that is two-fold:
+        # 1. it is about x2 faster than doing 2 forward passes
+        # 2. avoid the streaming API treating the 2 passes as part of different time steps
+        # We also support doing two different passes, in particular to ensure that
+        # the padding structure is exactly the same between train anf test.
+        # With a batch size of 1, this can be slower though.
+        cfg_conditions: CFGConditions
+        two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
+        if conditions:
+            null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
+            if two_step_cfg:
+                cfg_conditions = (
+                    self.condition_provider(self.condition_provider.tokenize(conditions)),
+                    self.condition_provider(self.condition_provider.tokenize(null_conditions)),
+                )
+            else:
+                conditions = conditions + null_conditions
+                tokenized = self.condition_provider.tokenize(conditions)
+                cfg_conditions = self.condition_provider(tokenized)
+        else:
+            cfg_conditions = {}
+
+        if prompt is None:
+            assert num_samples > 0
+            prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
+
+        B, K, T = prompt.shape
+        start_offset = T
+        assert start_offset < max_gen_len
+
+        pattern = self.pattern_provider.get_pattern(max_gen_len)
+        # this token is used as default value for codes that are not generated yet
+        unknown_token = -1
+
+        # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
+        gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
+        # filling the gen_codes with the prompt if needed
+        gen_codes[..., :start_offset] = prompt
+        # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
+        gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
+        # retrieve the start_offset in the sequence:
+        # it is the first sequence step that contains the `start_offset` timestep
+        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
+        assert start_offset_sequence is not None
+
+        with self.streaming():
+            unconditional_state = self.get_streaming_state()
+            prev_offset = 0
+            gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
+            for offset in range(start_offset_sequence, gen_sequence_len):
+                # get current sequence (note that the streaming API is providing the caching over previous offsets)
+                curr_sequence = gen_sequence[..., prev_offset:offset]
+                curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
+                if check:
+                    # check coherence between mask and sequence
+                    assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
+                    # should never happen as gen_sequence is filled progressively
+                    assert not (curr_sequence == unknown_token).any()
+                # sample next token from the model, next token shape is [B, K, 1]
+                next_token = self._sample_next_token(
+                    curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
+                    cfg_coef=cfg_coef)
+                # ensure the tokens that should be masked are properly set to special_token_id
+                # as the model never output special_token_id
+                valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
+                next_token[~valid_mask] = self.special_token_id
+                # ensure we don't overwrite prompt tokens, we only write over unknown tokens
+                # (then mask tokens should be left as is as well, which is correct)
+                gen_sequence[..., offset:offset+1] = torch.where(
+                    gen_sequence[..., offset:offset+1] == unknown_token,
+                    next_token, gen_sequence[..., offset:offset+1]
+                )
+                prev_offset = offset
+                if callback is not None:
+                    callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
+        unconditional_state.clear()
+
+        # ensure sequence has been entirely filled
+        assert not (gen_sequence == unknown_token).any()
+        # ensure gen_sequence pattern and mask are matching
+        # which means the gen_sequence is valid according to the pattern
+        assert (
+            gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
+        ).all()
+        # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
+        out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
+
+        # sanity checks over the returned codes and corresponding masks
+        assert (out_codes[..., :max_gen_len] != unknown_token).all()
+        assert (out_mask[..., :max_gen_len] == 1).all()
+
+        out_start_offset = start_offset if remove_prompts else 0
+        out_codes = out_codes[..., out_start_offset:max_gen_len]
+
+        # ensure the returned codes are all valid
+        assert (out_codes >= 0).all() and (out_codes <= self.card).all()
+        return out_codes
+
+
+
+
+
+
+
+

Functions

+
+
+def get_init_fn(method: str, input_dim: int, init_depth: Optional[int] = None) +
+
+

LM layer initialization. +Inspired from xlformers: https://github.com/fairinternal/xlformers

+

Args

+
+
method : str
+
Method name for init function. Valid options are: +'gaussian', 'uniform'.
+
input_dim : int
+
Input dimension of the initialized module.
+
init_depth : Optional[int]
+
Optional init depth value used to rescale +the standard deviation if defined.
+
+
+ +Expand source code + +
def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
+    """LM layer initialization.
+    Inspired from xlformers: https://github.com/fairinternal/xlformers
+
+    Args:
+        method (str): Method name for init function. Valid options are:
+            'gaussian', 'uniform'.
+        input_dim (int): Input dimension of the initialized module.
+        init_depth (Optional[int]): Optional init depth value used to rescale
+            the standard deviation if defined.
+    """
+    # Compute std
+    std = 1 / math.sqrt(input_dim)
+    # Rescale with depth
+    if init_depth is not None:
+        std = std / math.sqrt(2 * init_depth)
+
+    if method == 'gaussian':
+        return partial(
+            torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
+        )
+    elif method == 'uniform':
+        bound = math.sqrt(3) * std  # ensure the standard deviation is `std`
+        return partial(torch.nn.init.uniform_, a=-bound, b=bound)
+    else:
+        raise ValueError("Unsupported layer initialization method")
+
+
+
+def init_layer(m: torch.nn.modules.module.Module, method: str, init_depth: Optional[int] = None, zero_bias_init: bool = False) +
+
+

Wrapper around get_init_fn() for proper initialization of LM modules.

+

Args

+
+
m : nn.Module
+
Module to initialize.
+
method : str
+
Method name for the init function.
+
init_depth : Optional[int]
+
Optional init depth value used to rescale +the standard deviation if defined.
+
zero_bias_init : bool
+
Whether to initialize the bias to 0 or not.
+
+
+ +Expand source code + +
def init_layer(m: nn.Module,
+               method: str,
+               init_depth: tp.Optional[int] = None,
+               zero_bias_init: bool = False):
+    """Wrapper around ``get_init_fn`` for proper initialization of LM modules.
+
+    Args:
+        m (nn.Module): Module to initialize.
+        method (str): Method name for the init function.
+        init_depth (Optional[int]): Optional init depth value used to rescale
+            the standard deviation if defined.
+        zero_bias_init (bool): Whether to initialize the bias to 0 or not.
+    """
+    if isinstance(m, nn.Linear):
+        init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
+        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
+            weight = m.weight.float()
+            init_fn(weight)
+            m.weight.data[:] = weight.half()
+        else:
+            init_fn(m.weight)
+        if zero_bias_init and m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+    elif isinstance(m, nn.Embedding):
+        init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
+        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
+            weight = m.weight.float()
+            init_fn(weight)
+            m.weight.data[:] = weight.half()
+        else:
+            init_fn(m.weight)
+
+
+
+
+
+

Classes

+
+
+class LMModel +(pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider, fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8, hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False, emb_lr: Optional[float] = None, bias_proj: bool = True, weight_init: Optional[str] = None, depthwise_init: Optional[str] = None, zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0, attribute_dropout: Dict[str, Dict[str, float]] = {}, two_step_cfg: bool = False, **kwargs) +
+
+

Transformer-based language model on multiple streams of codes.

+

Args

+
+
pattern_provider : CodebooksPatternProvider
+
Pattern provider for codebook interleaving.
+
condition_provider : MusicConditioningProvider
+
Conditioning provider from metadata.
+
fuser : ConditionFuser
+
Fuser handling the fusing of conditions with language model input.
+
n_q : int
+
Number of parallel streams to model.
+
card : int
+
Cardinality, vocabulary size.
+
dim : int
+
Dimension of the transformer encoder.
+
num_heads : int
+
Number of heads for the transformer encoder.
+
hidden_scale : int
+
Scale for hidden feed forward dimension of the transformer encoder.
+
norm : str
+
Normalization method.
+
norm_first : bool
+
Use pre-norm instead of post-norm.
+
emb_lr : Optional[float]
+
Embedding-specific learning rate.
+
bias_proj : bool
+
Use bias for output projections.
+
weight_init : Optional[str]
+
Method for weight initialization.
+
depthwise_init : Optional[str]
+
Method for depthwise weight initialization.
+
zero_bias_init : bool
+
If true and bias in Linears, initialize bias to zeros.
+
cfg_dropout : float
+
Classifier-free guidance dropout.
+
cfg_coef : float
+
Classifier-free guidance coefficient.
+
attribute_dropout : dict
+
Attribute dropout probabilities.
+
two_step_cfg : bool
+
Whether to run classifier free-guidance with 2 distinct steps.
+
**kwargs
+
Additional parameters for the transformer encoder.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class LMModel(StreamingModule):
+    """Transformer-based language model on multiple streams of codes.
+
+    Args:
+        pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
+        condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
+        fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
+        n_q (int): Number of parallel streams to model.
+        card (int): Cardinality, vocabulary size.
+        dim (int): Dimension of the transformer encoder.
+        num_heads (int): Number of heads for the transformer encoder.
+        hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
+        norm (str): Normalization method.
+        norm_first (bool): Use pre-norm instead of post-norm.
+        emb_lr (Optional[float]): Embedding-specific learning rate.
+        bias_proj (bool): Use bias for output projections.
+        weight_init (Optional[str]): Method for weight initialization.
+        depthwise_init (Optional[str]): Method for depthwise weight initialization.
+        zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
+        cfg_dropout (float): Classifier-free guidance dropout.
+        cfg_coef (float): Classifier-free guidance coefficient.
+        attribute_dropout (dict): Attribute dropout probabilities.
+        two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
+        **kwargs: Additional parameters for the transformer encoder.
+    """
+    def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
+                 fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
+                 hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
+                 emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
+                 weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
+                 zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
+                 attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
+                 **kwargs):
+        super().__init__()
+        self.cfg_coef = cfg_coef
+        self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
+        self.att_dropout = AttributeDropout(p=attribute_dropout)
+        self.condition_provider = condition_provider
+        self.fuser = fuser
+        self.card = card
+        embed_dim = self.card + 1
+        self.n_q = n_q
+        self.dim = dim
+        self.pattern_provider = pattern_provider
+        self.two_step_cfg = two_step_cfg
+        self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
+        if 'activation' in kwargs:
+            kwargs['activation'] = get_activation_fn(kwargs['activation'])
+        self.transformer = StreamingTransformer(
+            d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
+            norm=norm, norm_first=norm_first, **kwargs)
+        self.out_norm: tp.Optional[nn.Module] = None
+        if norm_first:
+            self.out_norm = create_norm_fn(norm, dim)
+        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
+        self._init_weights(weight_init, depthwise_init, zero_bias_init)
+        self._fsdp: tp.Optional[nn.Module]
+        self.__dict__['_fsdp'] = None
+
+    def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
+        """Initialization of the transformer module weights.
+
+        Args:
+            weight_init (Optional[str]): Weight initialization strategy. See ``get_init_fn`` for valid options.
+            depthwise_init (Optional[str]): Depwthwise initialization strategy. The following options are valid:
+                'current' where the depth corresponds to the current layer index or 'global' where the total number
+                of layer is used as depth. If not set, no depthwise initialization strategy is used.
+            zero_bias_init (bool): Whether to initalize bias to zero or not.
+        """
+        assert depthwise_init is None or depthwise_init in ['current', 'global']
+        assert depthwise_init is None or weight_init is not None, \
+            "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
+        assert not zero_bias_init or weight_init is not None, \
+            "If 'zero_bias_init', a 'weight_init' method should be provided"
+
+        if weight_init is None:
+            return
+
+        for emb_layer in self.emb:
+            init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
+
+        for layer_idx, tr_layer in enumerate(self.transformer.layers):
+            depth = None
+            if depthwise_init == 'current':
+                depth = layer_idx + 1
+            elif depthwise_init == 'global':
+                depth = len(self.transformer.layers)
+            init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
+            tr_layer.apply(init_fn)
+
+        for linear in self.linears:
+            init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
+
+    @property
+    def special_token_id(self) -> int:
+        return self.card
+
+    @property
+    def num_codebooks(self) -> int:
+        return self.n_q
+
+    def forward(self, sequence: torch.Tensor,
+                conditions: tp.List[ConditioningAttributes],
+                condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
+        """Apply language model on sequence and conditions.
+        Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
+        S the sequence steps, return the logits with shape [B, card, K, S].
+
+        Args:
+            indices (torch.Tensor): indices of the codes to model.
+            conditions (list[ConditioningAttributes]): conditionings to use when modeling
+                the given codes. Note that when evaluating multiple time with the same conditioning
+                you should pre-compute those and pass them as `condition_tensors`.
+            condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning
+                tensors, see `conditions`.
+        Returns:
+            torch.Tensor: Logits.
+        """
+        B, K, S = sequence.shape
+        assert K == self.num_codebooks, 'Sequence shape must match the specified number of codebooks'
+        input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
+        if condition_tensors is None:
+            assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
+            # apply dropout modules
+            conditions = self.cfg_dropout(conditions)
+            conditions = self.att_dropout(conditions)
+            tokenized = self.condition_provider.tokenize(conditions)
+            # encode conditions and fuse, both have a streaming cache to not recompute when generating.
+            condition_tensors = self.condition_provider(tokenized)
+        else:
+            assert not conditions, "Shouldn't pass both conditions and condition_tensors."
+
+        input_, cross_attention_input = self.fuser(input_, condition_tensors)
+
+        out = self.transformer(input_, cross_attention_src=cross_attention_input)
+        if self.out_norm:
+            out = self.out_norm(out)
+        logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)  # [B, K, S, card]
+
+        # remove the prefix from the model outputs
+        if len(self.fuser.fuse2cond['prepend']) > 0:
+            logits = logits[:, :, -S:]
+
+        return logits  # [B, K, S, card]
+
+    def compute_predictions(
+            self, codes: torch.Tensor,
+            conditions: tp.List[ConditioningAttributes],
+            condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
+        """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
+        forward using the specified codes interleaving pattern.
+
+        Args:
+            codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
+                K the number of codebooks and T the number of timesteps.
+            conditions (list[ConditioningAttributes]): conditionings to use when modeling
+                the given codes. Note that when evaluating multiple time with the same conditioning
+                you should pre-compute those and pass them as `condition_tensors`.
+            condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning
+                tensors, see `conditions`.
+        Returns:
+            LMOutput: Language model outputs
+                logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
+                    i.e. the first item corresponds to logits to predict the first code, meaning that
+                    no additional shifting of codes and logits is required.
+                mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
+                    Given the specified interleaving strategies, parts of the logits and codes should
+                    not be considered as valid predictions because of invalid context.
+        """
+        B, K, T = codes.shape
+        codes = codes.contiguous()
+        # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
+        pattern = self.pattern_provider.get_pattern(T)
+        sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
+            codes, self.special_token_id, keep_only_valid_steps=True
+        )
+        # apply model on pattern sequence
+        model = self if self._fsdp is None else self._fsdp
+        logits = model(sequence_codes, conditions, condition_tensors)  # [B, K, S, card]
+        # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
+        # and provide the corresponding mask over invalid positions of tokens
+        logits = logits.permute(0, 3, 1, 2)  # [B, card, K, S]
+        # note: we use nans as special token to make it obvious if we feed unexpected logits
+        logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
+            logits, float('nan'), keep_only_valid_steps=True
+        )
+        logits = logits.permute(0, 2, 3, 1)  # [B, K, T, card]
+        logits_mask = logits_mask[None, :, :].expand(B, -1, -1)  # [K, T] -> [B, K, T]
+        return LMOutput(logits, logits_mask)
+
+    def _sample_next_token(self,
+                           sequence: torch.Tensor,
+                           cfg_conditions: CFGConditions,
+                           unconditional_state: State,
+                           use_sampling: bool = False,
+                           temp: float = 1.0,
+                           top_k: int = 0,
+                           top_p: float = 0.0,
+                           cfg_coef: tp.Optional[float] = None) -> torch.Tensor:
+        """Sample next token from the model given a sequence and a set of conditions. The model supports
+        multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
+
+        Args:
+            sequence (torch.Tensor): Current sequence of shape [B, K, S]
+                with K corresponding to the number of codebooks and S the number of sequence steps.
+                S = 1 in streaming mode, except for the first step that contains a bigger prompt.
+            condition_tensors (Dict[str, ConditionType): Set of conditions. If CFG is used,
+                should be twice the batch size, being the concatenation of the conditions + null conditions.
+            use_sampling (bool): Whether to use a sampling strategy or not.
+            temp (float): Sampling temperature.
+            top_k (int): K for "top-k" sampling.
+            top_p (float): P for "top-p" sampling.
+            cfg_coef (float): classifier free guidance coefficient
+        Returns:
+            next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
+        """
+        B = sequence.shape[0]
+        cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
+        model = self if self._fsdp is None else self._fsdp
+        if self.two_step_cfg and cfg_conditions != {}:
+            assert isinstance(cfg_conditions, tuple)
+            condition_tensors, null_condition_tensors = cfg_conditions
+            cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
+            state = self.get_streaming_state()
+            self.set_streaming_state(unconditional_state)
+            uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
+            unconditional_state.update(self.get_streaming_state())
+            self.set_streaming_state(state)
+            logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef
+        else:
+            assert isinstance(cfg_conditions, dict)
+            condition_tensors = cfg_conditions
+            if condition_tensors:
+                # Preparing for CFG, predicting both conditional and unconditional logits.
+                sequence = torch.cat([sequence, sequence], dim=0)
+            all_logits = model(
+                sequence,
+                conditions=[], condition_tensors=condition_tensors)
+            if condition_tensors:
+                cond_logits, uncond_logits = all_logits.split(B, dim=0)  # [B, K, T, card]
+                logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
+            else:
+                logits = all_logits
+
+        logits = logits.permute(0, 1, 3, 2)  # [B, K, card, T]
+        logits = logits[..., -1]  # [B x K x card]
+
+        # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
+        if use_sampling and temp > 0.0:
+            probs = torch.softmax(logits / temp, dim=-1)
+            if top_p > 0.0:
+                next_token = utils.sample_top_p(probs, p=top_p)
+            elif top_k > 0:
+                next_token = utils.sample_top_k(probs, k=top_k)
+            else:
+                next_token = utils.multinomial(probs, num_samples=1)
+        else:
+            next_token = torch.argmax(logits, dim=-1, keepdim=True)
+
+        return next_token
+
+    @torch.no_grad()
+    def generate(self,
+                 prompt: tp.Optional[torch.Tensor] = None,
+                 conditions: tp.List[ConditioningAttributes] = [],
+                 num_samples: tp.Optional[int] = None,
+                 max_gen_len: int = 256,
+                 use_sampling: bool = True,
+                 temp: float = 1.0,
+                 top_k: int = 250,
+                 top_p: float = 0.0,
+                 cfg_coef: tp.Optional[float] = None,
+                 two_step_cfg: bool = False,
+                 remove_prompts: bool = False,
+                 check: bool = False,
+                 callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
+        """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
+        be perform in a greedy fashion or using sampling with top K and top P strategies.
+
+        Args:
+            prompt (Optional[torch.Tensor]): Prompt tokens of shape [B, K, T].
+            conditions_tensors (Dict[str, torch.Tensor]): Set of conditions or None.
+            num_samples (int or None): Number of samples to generate when no prompt and no conditions are given.
+            max_gen_len (int): Maximum generation length.
+            use_sampling (bool): Whether to use a sampling strategy or not.
+            temp (float): Sampling temperature.
+            top_k (int): K for "top-k" sampling.
+            top_p (float): P for "top-p" sampling.
+            remove_prompts (bool): Whether to remove prompts from generation or not.
+        Returns:
+            torch.Tensor: Generated tokens.
+        """
+        assert not self.training, "generation shouldn't be used in training mode."
+        first_param = next(iter(self.parameters()))
+        device = first_param.device
+
+        # Checking all input shapes are consistents.
+        possible_num_samples = []
+        if num_samples is not None:
+            possible_num_samples.append(num_samples)
+        elif prompt is not None:
+            possible_num_samples.append(prompt.shape[0])
+        elif conditions:
+            possible_num_samples.append(len(conditions))
+        else:
+            possible_num_samples.append(1)
+        assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsitent inputs shapes"
+        num_samples = possible_num_samples[0]
+
+        # below we create set of conditions: one conditional and one unconditional
+        # to do that we merge the regular condition together with the null condition
+        # we then do 1 forward pass instead of 2.
+        # the reason for that is two-fold:
+        # 1. it is about x2 faster than doing 2 forward passes
+        # 2. avoid the streaming API treating the 2 passes as part of different time steps
+        # We also support doing two different passes, in particular to ensure that
+        # the padding structure is exactly the same between train anf test.
+        # With a batch size of 1, this can be slower though.
+        cfg_conditions: CFGConditions
+        two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
+        if conditions:
+            null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
+            if two_step_cfg:
+                cfg_conditions = (
+                    self.condition_provider(self.condition_provider.tokenize(conditions)),
+                    self.condition_provider(self.condition_provider.tokenize(null_conditions)),
+                )
+            else:
+                conditions = conditions + null_conditions
+                tokenized = self.condition_provider.tokenize(conditions)
+                cfg_conditions = self.condition_provider(tokenized)
+        else:
+            cfg_conditions = {}
+
+        if prompt is None:
+            assert num_samples > 0
+            prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
+
+        B, K, T = prompt.shape
+        start_offset = T
+        assert start_offset < max_gen_len
+
+        pattern = self.pattern_provider.get_pattern(max_gen_len)
+        # this token is used as default value for codes that are not generated yet
+        unknown_token = -1
+
+        # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
+        gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
+        # filling the gen_codes with the prompt if needed
+        gen_codes[..., :start_offset] = prompt
+        # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
+        gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
+        # retrieve the start_offset in the sequence:
+        # it is the first sequence step that contains the `start_offset` timestep
+        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
+        assert start_offset_sequence is not None
+
+        with self.streaming():
+            unconditional_state = self.get_streaming_state()
+            prev_offset = 0
+            gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
+            for offset in range(start_offset_sequence, gen_sequence_len):
+                # get current sequence (note that the streaming API is providing the caching over previous offsets)
+                curr_sequence = gen_sequence[..., prev_offset:offset]
+                curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
+                if check:
+                    # check coherence between mask and sequence
+                    assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
+                    # should never happen as gen_sequence is filled progressively
+                    assert not (curr_sequence == unknown_token).any()
+                # sample next token from the model, next token shape is [B, K, 1]
+                next_token = self._sample_next_token(
+                    curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
+                    cfg_coef=cfg_coef)
+                # ensure the tokens that should be masked are properly set to special_token_id
+                # as the model never output special_token_id
+                valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
+                next_token[~valid_mask] = self.special_token_id
+                # ensure we don't overwrite prompt tokens, we only write over unknown tokens
+                # (then mask tokens should be left as is as well, which is correct)
+                gen_sequence[..., offset:offset+1] = torch.where(
+                    gen_sequence[..., offset:offset+1] == unknown_token,
+                    next_token, gen_sequence[..., offset:offset+1]
+                )
+                prev_offset = offset
+                if callback is not None:
+                    callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
+        unconditional_state.clear()
+
+        # ensure sequence has been entirely filled
+        assert not (gen_sequence == unknown_token).any()
+        # ensure gen_sequence pattern and mask are matching
+        # which means the gen_sequence is valid according to the pattern
+        assert (
+            gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
+        ).all()
+        # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
+        out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
+
+        # sanity checks over the returned codes and corresponding masks
+        assert (out_codes[..., :max_gen_len] != unknown_token).all()
+        assert (out_mask[..., :max_gen_len] == 1).all()
+
+        out_start_offset = start_offset if remove_prompts else 0
+        out_codes = out_codes[..., out_start_offset:max_gen_len]
+
+        # ensure the returned codes are all valid
+        assert (out_codes >= 0).all() and (out_codes <= self.card).all()
+        return out_codes
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Instance variables

+
+
var num_codebooks : int
+
+
+
+ +Expand source code + +
@property
+def num_codebooks(self) -> int:
+    return self.n_q
+
+
+
var special_token_id : int
+
+
+
+ +Expand source code + +
@property
+def special_token_id(self) -> int:
+    return self.card
+
+
+
+

Methods

+
+
+def compute_predictions(self, codes: torch.Tensor, conditions: List[ConditioningAttributes], condition_tensors: Optional[Dict[str, Tuple[torch.Tensor, torch.Tensor]]] = None) ‑> LMOutput +
+
+

Given an input tensor of codes [B, K, T] and list of conditions, runs the model +forward using the specified codes interleaving pattern.

+

Args

+
+
codes : torch.Tensor
+
Input codes of shape [B, K, T] with B the batch size, +K the number of codebooks and T the number of timesteps.
+
conditions : list[ConditioningAttributes]
+
conditionings to use when modeling +the given codes. Note that when evaluating multiple time with the same conditioning +you should pre-compute those and pass them as condition_tensors.
+
condition_tensors : dict[str, ConditionType] or None
+
pre-computed conditioning +tensors, see conditions.
+
+

Returns

+
+
LMOutput
+
Language model outputs +logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes, +i.e. the first item corresponds to logits to predict the first code, meaning that +no additional shifting of codes and logits is required. +mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions. +Given the specified interleaving strategies, parts of the logits and codes should +not be considered as valid predictions because of invalid context.
+
+
+ +Expand source code + +
def compute_predictions(
+        self, codes: torch.Tensor,
+        conditions: tp.List[ConditioningAttributes],
+        condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
+    """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
+    forward using the specified codes interleaving pattern.
+
+    Args:
+        codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
+            K the number of codebooks and T the number of timesteps.
+        conditions (list[ConditioningAttributes]): conditionings to use when modeling
+            the given codes. Note that when evaluating multiple time with the same conditioning
+            you should pre-compute those and pass them as `condition_tensors`.
+        condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning
+            tensors, see `conditions`.
+    Returns:
+        LMOutput: Language model outputs
+            logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
+                i.e. the first item corresponds to logits to predict the first code, meaning that
+                no additional shifting of codes and logits is required.
+            mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
+                Given the specified interleaving strategies, parts of the logits and codes should
+                not be considered as valid predictions because of invalid context.
+    """
+    B, K, T = codes.shape
+    codes = codes.contiguous()
+    # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
+    pattern = self.pattern_provider.get_pattern(T)
+    sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
+        codes, self.special_token_id, keep_only_valid_steps=True
+    )
+    # apply model on pattern sequence
+    model = self if self._fsdp is None else self._fsdp
+    logits = model(sequence_codes, conditions, condition_tensors)  # [B, K, S, card]
+    # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
+    # and provide the corresponding mask over invalid positions of tokens
+    logits = logits.permute(0, 3, 1, 2)  # [B, card, K, S]
+    # note: we use nans as special token to make it obvious if we feed unexpected logits
+    logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
+        logits, float('nan'), keep_only_valid_steps=True
+    )
+    logits = logits.permute(0, 2, 3, 1)  # [B, K, T, card]
+    logits_mask = logits_mask[None, :, :].expand(B, -1, -1)  # [K, T] -> [B, K, T]
+    return LMOutput(logits, logits_mask)
+
+
+
+def forward(self, sequence: torch.Tensor, conditions: List[ConditioningAttributes], condition_tensors: Optional[Dict[str, Tuple[torch.Tensor, torch.Tensor]]] = None) ‑> torch.Tensor +
+
+

Apply language model on sequence and conditions. +Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and +S the sequence steps, return the logits with shape [B, card, K, S].

+

Args

+
+
indices : torch.Tensor
+
indices of the codes to model.
+
conditions : list[ConditioningAttributes]
+
conditionings to use when modeling +the given codes. Note that when evaluating multiple time with the same conditioning +you should pre-compute those and pass them as condition_tensors.
+
condition_tensors : dict[str, ConditionType] or None
+
pre-computed conditioning +tensors, see conditions.
+
+

Returns

+
+
torch.Tensor
+
Logits.
+
+
+ +Expand source code + +
def forward(self, sequence: torch.Tensor,
+            conditions: tp.List[ConditioningAttributes],
+            condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
+    """Apply language model on sequence and conditions.
+    Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
+    S the sequence steps, return the logits with shape [B, card, K, S].
+
+    Args:
+        indices (torch.Tensor): indices of the codes to model.
+        conditions (list[ConditioningAttributes]): conditionings to use when modeling
+            the given codes. Note that when evaluating multiple time with the same conditioning
+            you should pre-compute those and pass them as `condition_tensors`.
+        condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning
+            tensors, see `conditions`.
+    Returns:
+        torch.Tensor: Logits.
+    """
+    B, K, S = sequence.shape
+    assert K == self.num_codebooks, 'Sequence shape must match the specified number of codebooks'
+    input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
+    if condition_tensors is None:
+        assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
+        # apply dropout modules
+        conditions = self.cfg_dropout(conditions)
+        conditions = self.att_dropout(conditions)
+        tokenized = self.condition_provider.tokenize(conditions)
+        # encode conditions and fuse, both have a streaming cache to not recompute when generating.
+        condition_tensors = self.condition_provider(tokenized)
+    else:
+        assert not conditions, "Shouldn't pass both conditions and condition_tensors."
+
+    input_, cross_attention_input = self.fuser(input_, condition_tensors)
+
+    out = self.transformer(input_, cross_attention_src=cross_attention_input)
+    if self.out_norm:
+        out = self.out_norm(out)
+    logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)  # [B, K, S, card]
+
+    # remove the prefix from the model outputs
+    if len(self.fuser.fuse2cond['prepend']) > 0:
+        logits = logits[:, :, -S:]
+
+    return logits  # [B, K, S, card]
+
+
+
+def generate(self, prompt: Optional[torch.Tensor] = None, conditions: List[ConditioningAttributes] = [], num_samples: Optional[int] = None, max_gen_len: int = 256, use_sampling: bool = True, temp: float = 1.0, top_k: int = 250, top_p: float = 0.0, cfg_coef: Optional[float] = None, two_step_cfg: bool = False, remove_prompts: bool = False, check: bool = False, callback: Optional[Callable[[int, int], None]] = None) ‑> torch.Tensor +
+
+

Generate tokens sampling from the model given a prompt or unconditionally. Generation can +be perform in a greedy fashion or using sampling with top K and top P strategies.

+

Args

+
+
prompt : Optional[torch.Tensor]
+
Prompt tokens of shape [B, K, T].
+
conditions_tensors : Dict[str, torch.Tensor]
+
Set of conditions or None.
+
num_samples : int or None
+
Number of samples to generate when no prompt and no conditions are given.
+
max_gen_len : int
+
Maximum generation length.
+
use_sampling : bool
+
Whether to use a sampling strategy or not.
+
temp : float
+
Sampling temperature.
+
top_k : int
+
K for "top-k" sampling.
+
top_p : float
+
P for "top-p" sampling.
+
remove_prompts : bool
+
Whether to remove prompts from generation or not.
+
+

Returns

+
+
torch.Tensor
+
Generated tokens.
+
+
+ +Expand source code + +
@torch.no_grad()
+def generate(self,
+             prompt: tp.Optional[torch.Tensor] = None,
+             conditions: tp.List[ConditioningAttributes] = [],
+             num_samples: tp.Optional[int] = None,
+             max_gen_len: int = 256,
+             use_sampling: bool = True,
+             temp: float = 1.0,
+             top_k: int = 250,
+             top_p: float = 0.0,
+             cfg_coef: tp.Optional[float] = None,
+             two_step_cfg: bool = False,
+             remove_prompts: bool = False,
+             check: bool = False,
+             callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
+    """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
+    be perform in a greedy fashion or using sampling with top K and top P strategies.
+
+    Args:
+        prompt (Optional[torch.Tensor]): Prompt tokens of shape [B, K, T].
+        conditions_tensors (Dict[str, torch.Tensor]): Set of conditions or None.
+        num_samples (int or None): Number of samples to generate when no prompt and no conditions are given.
+        max_gen_len (int): Maximum generation length.
+        use_sampling (bool): Whether to use a sampling strategy or not.
+        temp (float): Sampling temperature.
+        top_k (int): K for "top-k" sampling.
+        top_p (float): P for "top-p" sampling.
+        remove_prompts (bool): Whether to remove prompts from generation or not.
+    Returns:
+        torch.Tensor: Generated tokens.
+    """
+    assert not self.training, "generation shouldn't be used in training mode."
+    first_param = next(iter(self.parameters()))
+    device = first_param.device
+
+    # Checking all input shapes are consistents.
+    possible_num_samples = []
+    if num_samples is not None:
+        possible_num_samples.append(num_samples)
+    elif prompt is not None:
+        possible_num_samples.append(prompt.shape[0])
+    elif conditions:
+        possible_num_samples.append(len(conditions))
+    else:
+        possible_num_samples.append(1)
+    assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsitent inputs shapes"
+    num_samples = possible_num_samples[0]
+
+    # below we create set of conditions: one conditional and one unconditional
+    # to do that we merge the regular condition together with the null condition
+    # we then do 1 forward pass instead of 2.
+    # the reason for that is two-fold:
+    # 1. it is about x2 faster than doing 2 forward passes
+    # 2. avoid the streaming API treating the 2 passes as part of different time steps
+    # We also support doing two different passes, in particular to ensure that
+    # the padding structure is exactly the same between train anf test.
+    # With a batch size of 1, this can be slower though.
+    cfg_conditions: CFGConditions
+    two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
+    if conditions:
+        null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
+        if two_step_cfg:
+            cfg_conditions = (
+                self.condition_provider(self.condition_provider.tokenize(conditions)),
+                self.condition_provider(self.condition_provider.tokenize(null_conditions)),
+            )
+        else:
+            conditions = conditions + null_conditions
+            tokenized = self.condition_provider.tokenize(conditions)
+            cfg_conditions = self.condition_provider(tokenized)
+    else:
+        cfg_conditions = {}
+
+    if prompt is None:
+        assert num_samples > 0
+        prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
+
+    B, K, T = prompt.shape
+    start_offset = T
+    assert start_offset < max_gen_len
+
+    pattern = self.pattern_provider.get_pattern(max_gen_len)
+    # this token is used as default value for codes that are not generated yet
+    unknown_token = -1
+
+    # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
+    gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
+    # filling the gen_codes with the prompt if needed
+    gen_codes[..., :start_offset] = prompt
+    # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
+    gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
+    # retrieve the start_offset in the sequence:
+    # it is the first sequence step that contains the `start_offset` timestep
+    start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
+    assert start_offset_sequence is not None
+
+    with self.streaming():
+        unconditional_state = self.get_streaming_state()
+        prev_offset = 0
+        gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
+        for offset in range(start_offset_sequence, gen_sequence_len):
+            # get current sequence (note that the streaming API is providing the caching over previous offsets)
+            curr_sequence = gen_sequence[..., prev_offset:offset]
+            curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
+            if check:
+                # check coherence between mask and sequence
+                assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
+                # should never happen as gen_sequence is filled progressively
+                assert not (curr_sequence == unknown_token).any()
+            # sample next token from the model, next token shape is [B, K, 1]
+            next_token = self._sample_next_token(
+                curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
+                cfg_coef=cfg_coef)
+            # ensure the tokens that should be masked are properly set to special_token_id
+            # as the model never output special_token_id
+            valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
+            next_token[~valid_mask] = self.special_token_id
+            # ensure we don't overwrite prompt tokens, we only write over unknown tokens
+            # (then mask tokens should be left as is as well, which is correct)
+            gen_sequence[..., offset:offset+1] = torch.where(
+                gen_sequence[..., offset:offset+1] == unknown_token,
+                next_token, gen_sequence[..., offset:offset+1]
+            )
+            prev_offset = offset
+            if callback is not None:
+                callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
+    unconditional_state.clear()
+
+    # ensure sequence has been entirely filled
+    assert not (gen_sequence == unknown_token).any()
+    # ensure gen_sequence pattern and mask are matching
+    # which means the gen_sequence is valid according to the pattern
+    assert (
+        gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
+    ).all()
+    # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
+    out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
+
+    # sanity checks over the returned codes and corresponding masks
+    assert (out_codes[..., :max_gen_len] != unknown_token).all()
+    assert (out_mask[..., :max_gen_len] == 1).all()
+
+    out_start_offset = start_offset if remove_prompts else 0
+    out_codes = out_codes[..., out_start_offset:max_gen_len]
+
+    # ensure the returned codes are all valid
+    assert (out_codes >= 0).all() and (out_codes <= self.card).all()
+    return out_codes
+
+
+
+

Inherited members

+ +
+
+class LMOutput +(logits: torch.Tensor, mask: torch.Tensor) +
+
+

LMOutput(logits: torch.Tensor, mask: torch.Tensor)

+
+ +Expand source code + +
class LMOutput:
+    # The logits are already re-aligned with the input codes
+    # hence no extra shift is required, e.g. when computing CE
+    logits: torch.Tensor  # [B, K, T, card]
+    mask: torch.Tensor  # [B, K, T]
+
+

Class variables

+
+
var logits : torch.Tensor
+
+
+
+
var mask : torch.Tensor
+
+
+
+
+
+
+class ScaledEmbedding +(*args, lr=None, **kwargs) +
+
+

Boost learning rate for embeddings (with scale).

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ScaledEmbedding(nn.Embedding):
+    """Boost learning rate for embeddings (with `scale`).
+    """
+    def __init__(self, *args, lr=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.lr = lr
+
+    def make_optim_group(self):
+        group = {"params": list(self.parameters())}
+        if self.lr is not None:
+            group["lr"] = self.lr
+        return group
+
+

Ancestors

+
    +
  • torch.nn.modules.sparse.Embedding
  • +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var embedding_dim : int
+
+
+
+
var freeze : bool
+
+
+
+
var max_norm : Optional[float]
+
+
+
+
var norm_type : float
+
+
+
+
var num_embeddings : int
+
+
+
+
var padding_idx : Optional[int]
+
+
+
+
var scale_grad_by_freq : bool
+
+
+
+
var sparse : bool
+
+
+
+
var weight : torch.Tensor
+
+
+
+
+

Methods

+
+
+def make_optim_group(self) +
+
+
+
+ +Expand source code + +
def make_optim_group(self):
+    group = {"params": list(self.parameters())}
+    if self.lr is not None:
+        group["lr"] = self.lr
+    return group
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/models/loaders.html b/docs/audiocraft/models/loaders.html new file mode 100644 index 00000000..64f4bf06 --- /dev/null +++ b/docs/audiocraft/models/loaders.html @@ -0,0 +1,217 @@ + + + + + + +audiocraft.models.loaders API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.models.loaders

+
+
+

Utility functions to load from the checkpoints. +Each checkpoint is a torch.saved dict with the following keys: +- 'xp.cfg': the hydra config as dumped during training. This should be used +to rebuild the object using the audiocraft.models.builders functions, +- 'model_best_state': a readily loadable best state for the model, including +the conditioner. The model obtained from xp.cfg should be compatible +with this state dict. In the case of a LM, the encodec model would not be +bundled along but instead provided separately.

+

Those functions also support loading from a remote location with the Torch Hub API. +They also support overriding some parameters, in particular the device and dtype +of the returned model.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utility functions to load from the checkpoints.
+Each checkpoint is a torch.saved dict with the following keys:
+- 'xp.cfg': the hydra config as dumped during training. This should be used
+    to rebuild the object using the audiocraft.models.builders functions,
+- 'model_best_state': a readily loadable best state for the model, including
+    the conditioner. The model obtained from `xp.cfg` should be compatible
+    with this state dict. In the case of a LM, the encodec model would not be
+    bundled along but instead provided separately.
+
+Those functions also support loading from a remote location with the Torch Hub API.
+They also support overriding some parameters, in particular the device and dtype
+of the returned model.
+"""
+
+from pathlib import Path
+from huggingface_hub import hf_hub_download
+import typing as tp
+import os
+
+from omegaconf import OmegaConf
+import torch
+
+from . import builders
+
+
+HF_MODEL_CHECKPOINTS_MAP = {
+    "small": "facebook/musicgen-small",
+    "medium": "facebook/musicgen-medium",
+    "large": "facebook/musicgen-large",
+    "melody": "facebook/musicgen-melody",
+}
+
+
+def _get_state_dict(
+    file_or_url_or_id: tp.Union[Path, str],
+    filename: tp.Optional[str] = None,
+    device='cpu',
+    cache_dir: tp.Optional[str] = None,
+):
+    # Return the state dict either from a file or url
+    file_or_url_or_id = str(file_or_url_or_id)
+    assert isinstance(file_or_url_or_id, str)
+
+    if os.path.isfile(file_or_url_or_id):
+        return torch.load(file_or_url_or_id, map_location=device)
+
+    if os.path.isdir(file_or_url_or_id):
+        file = f"{file_or_url_or_id}/{filename}"
+        return torch.load(file, map_location=device)
+
+    elif file_or_url_or_id.startswith('https://'):
+        return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
+
+    elif file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP:
+        assert filename is not None, "filename needs to be defined if using HF checkpoints"
+
+        repo_id = HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id]
+        file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
+        return torch.load(file, map_location=device)
+
+    else:
+        raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.")
+
+
+def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
+    pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
+    cfg = OmegaConf.create(pkg['xp.cfg'])
+    cfg.device = str(device)
+    model = builders.get_compression_model(cfg)
+    model.load_state_dict(pkg['best_state'])
+    model.eval()
+    return model
+
+
+def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
+    pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
+    cfg = OmegaConf.create(pkg['xp.cfg'])
+    cfg.device = str(device)
+    if cfg.device == 'cpu':
+        cfg.dtype = 'float32'
+    else:
+        cfg.dtype = 'float16'
+    model = builders.get_lm_model(cfg)
+    model.load_state_dict(pkg['best_state'])
+    model.eval()
+    model.cfg = cfg
+    return model
+
+
+
+
+
+
+
+

Functions

+
+
+def load_compression_model(file_or_url_or_id: Union[str, pathlib.Path], device='cpu', cache_dir: Optional[str] = None) +
+
+
+
+ +Expand source code + +
def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
+    pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
+    cfg = OmegaConf.create(pkg['xp.cfg'])
+    cfg.device = str(device)
+    model = builders.get_compression_model(cfg)
+    model.load_state_dict(pkg['best_state'])
+    model.eval()
+    return model
+
+
+
+def load_lm_model(file_or_url_or_id: Union[str, pathlib.Path], device='cpu', cache_dir: Optional[str] = None) +
+
+
+
+ +Expand source code + +
def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
+    pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
+    cfg = OmegaConf.create(pkg['xp.cfg'])
+    cfg.device = str(device)
+    if cfg.device == 'cpu':
+        cfg.dtype = 'float32'
+    else:
+        cfg.dtype = 'float16'
+    model = builders.get_lm_model(cfg)
+    model.load_state_dict(pkg['best_state'])
+    model.eval()
+    model.cfg = cfg
+    return model
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/models/musicgen.html b/docs/audiocraft/models/musicgen.html new file mode 100644 index 00000000..e6a3a2c8 --- /dev/null +++ b/docs/audiocraft/models/musicgen.html @@ -0,0 +1,1135 @@ + + + + + + +audiocraft.models.musicgen API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.models.musicgen

+
+
+

Main model for using MusicGen. This will combine all the required components +and provide easy access to the generation API.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Main model for using MusicGen. This will combine all the required components
+and provide easy access to the generation API.
+"""
+
+import os
+import typing as tp
+
+import torch
+
+from .encodec import CompressionModel
+from .lm import LMModel
+from .builders import get_debug_compression_model, get_debug_lm_model
+from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
+from ..data.audio_utils import convert_audio
+from ..modules.conditioners import ConditioningAttributes, WavCondition
+from ..utils.autocast import TorchAutocast
+
+
+MelodyList = tp.List[tp.Optional[torch.Tensor]]
+MelodyType = tp.Union[torch.Tensor, MelodyList]
+
+
+class MusicGen:
+    """MusicGen main model with convenient generation API.
+
+    Args:
+        name (str): name of the model.
+        compression_model (CompressionModel): Compression model
+            used to map audio to invertible discrete representations.
+        lm (LMModel): Language model over discrete representations.
+    """
+    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
+                 max_duration: float = 30):
+        self.name = name
+        self.compression_model = compression_model
+        self.lm = lm
+        self.max_duration = max_duration
+        self.device = next(iter(lm.parameters())).device
+        self.generation_params: dict = {}
+        self.set_generation_params(duration=15)  # 15 seconds by default
+        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
+        if self.device.type == 'cpu':
+            self.autocast = TorchAutocast(enabled=False)
+        else:
+            self.autocast = TorchAutocast(
+                enabled=True, device_type=self.device.type, dtype=torch.float16)
+
+    @property
+    def frame_rate(self) -> int:
+        """Roughly the number of AR steps per seconds."""
+        return self.compression_model.frame_rate
+
+    @property
+    def sample_rate(self) -> int:
+        """Sample rate of the generated audio."""
+        return self.compression_model.sample_rate
+
+    @property
+    def audio_channels(self) -> int:
+        """Audio channels of the generated audio."""
+        return self.compression_model.channels
+
+    @staticmethod
+    def get_pretrained(name: str = 'melody', device=None):
+        """Return pretrained model, we provide four models:
+        - small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
+        - medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
+        - melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-melody
+        - large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
+        """
+
+        if device is None:
+            if torch.cuda.device_count():
+                device = 'cuda'
+            else:
+                device = 'cpu'
+
+        if name == 'debug':
+            # used only for unit tests
+            compression_model = get_debug_compression_model(device)
+            lm = get_debug_lm_model(device)
+            return MusicGen(name, compression_model, lm)
+
+        if name not in HF_MODEL_CHECKPOINTS_MAP:
+            if not os.path.isfile(name) and not os.path.isdir(name):
+                raise ValueError(
+                    f"{name} is not a valid checkpoint name. "
+                    f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
+                )
+
+        cache_dir = os.environ.get('MUSICGEN_ROOT', None)
+        compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
+        lm = load_lm_model(name, device=device, cache_dir=cache_dir)
+        if name == 'melody':
+            lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
+
+        return MusicGen(name, compression_model, lm)
+
+    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
+                              top_p: float = 0.0, temperature: float = 1.0,
+                              duration: float = 30.0, cfg_coef: float = 3.0,
+                              two_step_cfg: bool = False, extend_stride: float = 18):
+        """Set the generation parameters for MusicGen.
+
+        Args:
+            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
+            top_k (int, optional): top_k used for sampling. Defaults to 250.
+            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
+            temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
+            duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
+            cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
+            two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
+                instead of batching together the two. This has some impact on how things
+                are padded but seems to have little impact in practice.
+            extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
+                should we extend the audio each time. Larger values will mean less context is
+                preserved, and shorter value will require extra computations.
+        """
+        assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
+        self.extend_stride = extend_stride
+        self.duration = duration
+        self.generation_params = {
+            'use_sampling': use_sampling,
+            'temp': temperature,
+            'top_k': top_k,
+            'top_p': top_p,
+            'cfg_coef': cfg_coef,
+            'two_step_cfg': two_step_cfg,
+        }
+
+    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
+        """Override the default progress callback."""
+        self._progress_callback = progress_callback
+
+    def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
+        """Generate samples in an unconditional manner.
+
+        Args:
+            num_samples (int): Number of samples to be generated.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+        return self._generate_tokens(attributes, prompt_tokens, progress)
+
+    def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
+        """Generate samples conditioned on text.
+
+        Args:
+            descriptions (tp.List[str]): A list of strings used as text conditioning.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+        assert prompt_tokens is None
+        return self._generate_tokens(attributes, prompt_tokens, progress)
+
+    def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
+                             melody_sample_rate: int, progress: bool = False) -> torch.Tensor:
+        """Generate samples conditioned on text and melody.
+
+        Args:
+            descriptions (tp.List[str]): A list of strings used as text conditioning.
+            melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
+                melody conditioning. Should have shape [B, C, T] with B matching the description length,
+                C=1 or 2. It can be [C, T] if there is a single description. It can also be
+                a list of [C, T] tensors.
+            melody_sample_rate: (int): Sample rate of the melody waveforms.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        if isinstance(melody_wavs, torch.Tensor):
+            if melody_wavs.dim() == 2:
+                melody_wavs = melody_wavs[None]
+            if melody_wavs.dim() != 3:
+                raise ValueError("Melody wavs should have a shape [B, C, T].")
+            melody_wavs = list(melody_wavs)
+        else:
+            for melody in melody_wavs:
+                if melody is not None:
+                    assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
+
+        melody_wavs = [
+            convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
+            if wav is not None else None
+            for wav in melody_wavs]
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
+                                                                        melody_wavs=melody_wavs)
+        assert prompt_tokens is None
+        return self._generate_tokens(attributes, prompt_tokens, progress)
+
+    def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
+                              descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
+                              progress: bool = False) -> torch.Tensor:
+        """Generate samples conditioned on audio prompts.
+
+        Args:
+            prompt (torch.Tensor): A batch of waveforms used for continuation.
+                Prompt should be [B, C, T], or [C, T] if only one sample is generated.
+            prompt_sample_rate (int): Sampling rate of the given audio waveforms.
+            descriptions (tp.List[str], optional): A list of strings used as text conditioning. Defaults to None.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        if prompt.dim() == 2:
+            prompt = prompt[None]
+        if prompt.dim() != 3:
+            raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
+        prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
+        if descriptions is None:
+            descriptions = [None] * len(prompt)
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
+        assert prompt_tokens is not None
+        return self._generate_tokens(attributes, prompt_tokens, progress)
+
+    @torch.no_grad()
+    def _prepare_tokens_and_attributes(
+            self,
+            descriptions: tp.Sequence[tp.Optional[str]],
+            prompt: tp.Optional[torch.Tensor],
+            melody_wavs: tp.Optional[MelodyList] = None,
+    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
+        """Prepare model inputs.
+
+        Args:
+            descriptions (tp.List[str]): A list of strings used as text conditioning.
+            prompt (torch.Tensor): A batch of waveforms used for continuation.
+            melody_wavs (tp.Optional[torch.Tensor], optional): A batch of waveforms
+                used as melody conditioning. Defaults to None.
+        """
+        attributes = [
+            ConditioningAttributes(text={'description': description})
+            for description in descriptions]
+
+        if melody_wavs is None:
+            for attr in attributes:
+                attr.wav['self_wav'] = WavCondition(
+                    torch.zeros((1, 1), device=self.device),
+                    torch.tensor([0], device=self.device),
+                    path='null_wav')  # type: ignore
+        else:
+            if self.name != "melody":
+                raise RuntimeError("This model doesn't support melody conditioning. "
+                                   "Use the `melody` model.")
+            assert len(melody_wavs) == len(descriptions), \
+                f"number of melody wavs must match number of descriptions! " \
+                f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}"
+            for attr, melody in zip(attributes, melody_wavs):
+                if melody is None:
+                    attr.wav['self_wav'] = WavCondition(
+                        torch.zeros((1, 1), device=self.device),
+                        torch.tensor([0], device=self.device),
+                        path='null_wav')  # type: ignore
+                else:
+                    attr.wav['self_wav'] = WavCondition(
+                        melody.to(device=self.device),
+                        torch.tensor([melody.shape[-1]], device=self.device))
+
+        if prompt is not None:
+            if descriptions is not None:
+                assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
+            prompt = prompt.to(self.device)
+            prompt_tokens, scale = self.compression_model.encode(prompt)
+            assert scale is None
+        else:
+            prompt_tokens = None
+        return attributes, prompt_tokens
+
+    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
+                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
+        """Generate discrete audio tokens given audio prompt and/or conditions.
+
+        Args:
+            attributes (tp.List[ConditioningAttributes]): Conditions used for generation (text/melody).
+            prompt_tokens (tp.Optional[torch.Tensor]): Audio prompt used for continuation.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        Returns:
+            torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
+        """
+        total_gen_len = int(self.duration * self.frame_rate)
+        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
+        current_gen_offset: int = 0
+
+        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
+            generated_tokens += current_gen_offset
+            if self._progress_callback is not None:
+                # Note that total_gen_len might be quite wrong depending on the
+                # codebook pattern used, but with delay it is almost accurate.
+                self._progress_callback(generated_tokens, total_gen_len)
+            else:
+                print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
+
+        if prompt_tokens is not None:
+            assert max_prompt_len >= prompt_tokens.shape[-1], \
+                "Prompt is longer than audio to generate"
+
+        callback = None
+        if progress:
+            callback = _progress_callback
+
+        if self.duration <= self.max_duration:
+            # generate by sampling from LM, simple case.
+            with self.autocast:
+                gen_tokens = self.lm.generate(
+                    prompt_tokens, attributes,
+                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)
+
+        else:
+            # now this gets a bit messier, we need to handle prompts,
+            # melody conditioning etc.
+            ref_wavs = [attr.wav['self_wav'] for attr in attributes]
+            all_tokens = []
+            if prompt_tokens is None:
+                prompt_length = 0
+            else:
+                all_tokens.append(prompt_tokens)
+                prompt_length = prompt_tokens.shape[-1]
+
+            stride_tokens = int(self.frame_rate * self.extend_stride)
+
+            while current_gen_offset + prompt_length < total_gen_len:
+                time_offset = current_gen_offset / self.frame_rate
+                chunk_duration = min(self.duration - time_offset, self.max_duration)
+                max_gen_len = int(chunk_duration * self.frame_rate)
+                for attr, ref_wav in zip(attributes, ref_wavs):
+                    wav_length = ref_wav.length.item()
+                    if wav_length == 0:
+                        continue
+                    # We will extend the wav periodically if it not long enough.
+                    # we have to do it here rather than in conditioners.py as otherwise
+                    # we wouldn't have the full wav.
+                    initial_position = int(time_offset * self.sample_rate)
+                    wav_target_length = int(self.max_duration * self.sample_rate)
+                    print(initial_position / self.sample_rate, wav_target_length / self.sample_rate)
+                    positions = torch.arange(initial_position,
+                                             initial_position + wav_target_length, device=self.device)
+                    attr.wav['self_wav'] = WavCondition(
+                        ref_wav[0][:, positions % wav_length],
+                        torch.full_like(ref_wav[1], wav_target_length))
+                with self.autocast:
+                    gen_tokens = self.lm.generate(
+                        prompt_tokens, attributes,
+                        callback=callback, max_gen_len=max_gen_len, **self.generation_params)
+                if prompt_tokens is None:
+                    all_tokens.append(gen_tokens)
+                else:
+                    all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
+                prompt_tokens = gen_tokens[:, :, stride_tokens:]
+                prompt_length = prompt_tokens.shape[-1]
+                current_gen_offset += stride_tokens
+
+            gen_tokens = torch.cat(all_tokens, dim=-1)
+
+        # generate audio
+        assert gen_tokens.dim() == 3
+        with torch.no_grad():
+            gen_audio = self.compression_model.decode(gen_tokens, None)
+        return gen_audio
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class MusicGen +(name: str, compression_model: CompressionModel, lm: LMModel, max_duration: float = 30) +
+
+

MusicGen main model with convenient generation API.

+

Args

+
+
name : str
+
name of the model.
+
compression_model : CompressionModel
+
Compression model +used to map audio to invertible discrete representations.
+
lm : LMModel
+
Language model over discrete representations.
+
+
+ +Expand source code + +
class MusicGen:
+    """MusicGen main model with convenient generation API.
+
+    Args:
+        name (str): name of the model.
+        compression_model (CompressionModel): Compression model
+            used to map audio to invertible discrete representations.
+        lm (LMModel): Language model over discrete representations.
+    """
+    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
+                 max_duration: float = 30):
+        self.name = name
+        self.compression_model = compression_model
+        self.lm = lm
+        self.max_duration = max_duration
+        self.device = next(iter(lm.parameters())).device
+        self.generation_params: dict = {}
+        self.set_generation_params(duration=15)  # 15 seconds by default
+        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
+        if self.device.type == 'cpu':
+            self.autocast = TorchAutocast(enabled=False)
+        else:
+            self.autocast = TorchAutocast(
+                enabled=True, device_type=self.device.type, dtype=torch.float16)
+
+    @property
+    def frame_rate(self) -> int:
+        """Roughly the number of AR steps per seconds."""
+        return self.compression_model.frame_rate
+
+    @property
+    def sample_rate(self) -> int:
+        """Sample rate of the generated audio."""
+        return self.compression_model.sample_rate
+
+    @property
+    def audio_channels(self) -> int:
+        """Audio channels of the generated audio."""
+        return self.compression_model.channels
+
+    @staticmethod
+    def get_pretrained(name: str = 'melody', device=None):
+        """Return pretrained model, we provide four models:
+        - small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
+        - medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
+        - melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-melody
+        - large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
+        """
+
+        if device is None:
+            if torch.cuda.device_count():
+                device = 'cuda'
+            else:
+                device = 'cpu'
+
+        if name == 'debug':
+            # used only for unit tests
+            compression_model = get_debug_compression_model(device)
+            lm = get_debug_lm_model(device)
+            return MusicGen(name, compression_model, lm)
+
+        if name not in HF_MODEL_CHECKPOINTS_MAP:
+            if not os.path.isfile(name) and not os.path.isdir(name):
+                raise ValueError(
+                    f"{name} is not a valid checkpoint name. "
+                    f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
+                )
+
+        cache_dir = os.environ.get('MUSICGEN_ROOT', None)
+        compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
+        lm = load_lm_model(name, device=device, cache_dir=cache_dir)
+        if name == 'melody':
+            lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
+
+        return MusicGen(name, compression_model, lm)
+
+    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
+                              top_p: float = 0.0, temperature: float = 1.0,
+                              duration: float = 30.0, cfg_coef: float = 3.0,
+                              two_step_cfg: bool = False, extend_stride: float = 18):
+        """Set the generation parameters for MusicGen.
+
+        Args:
+            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
+            top_k (int, optional): top_k used for sampling. Defaults to 250.
+            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
+            temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
+            duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
+            cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
+            two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
+                instead of batching together the two. This has some impact on how things
+                are padded but seems to have little impact in practice.
+            extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
+                should we extend the audio each time. Larger values will mean less context is
+                preserved, and shorter value will require extra computations.
+        """
+        assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
+        self.extend_stride = extend_stride
+        self.duration = duration
+        self.generation_params = {
+            'use_sampling': use_sampling,
+            'temp': temperature,
+            'top_k': top_k,
+            'top_p': top_p,
+            'cfg_coef': cfg_coef,
+            'two_step_cfg': two_step_cfg,
+        }
+
+    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
+        """Override the default progress callback."""
+        self._progress_callback = progress_callback
+
+    def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
+        """Generate samples in an unconditional manner.
+
+        Args:
+            num_samples (int): Number of samples to be generated.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+        return self._generate_tokens(attributes, prompt_tokens, progress)
+
+    def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
+        """Generate samples conditioned on text.
+
+        Args:
+            descriptions (tp.List[str]): A list of strings used as text conditioning.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+        assert prompt_tokens is None
+        return self._generate_tokens(attributes, prompt_tokens, progress)
+
+    def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
+                             melody_sample_rate: int, progress: bool = False) -> torch.Tensor:
+        """Generate samples conditioned on text and melody.
+
+        Args:
+            descriptions (tp.List[str]): A list of strings used as text conditioning.
+            melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
+                melody conditioning. Should have shape [B, C, T] with B matching the description length,
+                C=1 or 2. It can be [C, T] if there is a single description. It can also be
+                a list of [C, T] tensors.
+            melody_sample_rate: (int): Sample rate of the melody waveforms.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        if isinstance(melody_wavs, torch.Tensor):
+            if melody_wavs.dim() == 2:
+                melody_wavs = melody_wavs[None]
+            if melody_wavs.dim() != 3:
+                raise ValueError("Melody wavs should have a shape [B, C, T].")
+            melody_wavs = list(melody_wavs)
+        else:
+            for melody in melody_wavs:
+                if melody is not None:
+                    assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
+
+        melody_wavs = [
+            convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
+            if wav is not None else None
+            for wav in melody_wavs]
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
+                                                                        melody_wavs=melody_wavs)
+        assert prompt_tokens is None
+        return self._generate_tokens(attributes, prompt_tokens, progress)
+
+    def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
+                              descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
+                              progress: bool = False) -> torch.Tensor:
+        """Generate samples conditioned on audio prompts.
+
+        Args:
+            prompt (torch.Tensor): A batch of waveforms used for continuation.
+                Prompt should be [B, C, T], or [C, T] if only one sample is generated.
+            prompt_sample_rate (int): Sampling rate of the given audio waveforms.
+            descriptions (tp.List[str], optional): A list of strings used as text conditioning. Defaults to None.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        """
+        if prompt.dim() == 2:
+            prompt = prompt[None]
+        if prompt.dim() != 3:
+            raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
+        prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
+        if descriptions is None:
+            descriptions = [None] * len(prompt)
+        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
+        assert prompt_tokens is not None
+        return self._generate_tokens(attributes, prompt_tokens, progress)
+
+    @torch.no_grad()
+    def _prepare_tokens_and_attributes(
+            self,
+            descriptions: tp.Sequence[tp.Optional[str]],
+            prompt: tp.Optional[torch.Tensor],
+            melody_wavs: tp.Optional[MelodyList] = None,
+    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
+        """Prepare model inputs.
+
+        Args:
+            descriptions (tp.List[str]): A list of strings used as text conditioning.
+            prompt (torch.Tensor): A batch of waveforms used for continuation.
+            melody_wavs (tp.Optional[torch.Tensor], optional): A batch of waveforms
+                used as melody conditioning. Defaults to None.
+        """
+        attributes = [
+            ConditioningAttributes(text={'description': description})
+            for description in descriptions]
+
+        if melody_wavs is None:
+            for attr in attributes:
+                attr.wav['self_wav'] = WavCondition(
+                    torch.zeros((1, 1), device=self.device),
+                    torch.tensor([0], device=self.device),
+                    path='null_wav')  # type: ignore
+        else:
+            if self.name != "melody":
+                raise RuntimeError("This model doesn't support melody conditioning. "
+                                   "Use the `melody` model.")
+            assert len(melody_wavs) == len(descriptions), \
+                f"number of melody wavs must match number of descriptions! " \
+                f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}"
+            for attr, melody in zip(attributes, melody_wavs):
+                if melody is None:
+                    attr.wav['self_wav'] = WavCondition(
+                        torch.zeros((1, 1), device=self.device),
+                        torch.tensor([0], device=self.device),
+                        path='null_wav')  # type: ignore
+                else:
+                    attr.wav['self_wav'] = WavCondition(
+                        melody.to(device=self.device),
+                        torch.tensor([melody.shape[-1]], device=self.device))
+
+        if prompt is not None:
+            if descriptions is not None:
+                assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
+            prompt = prompt.to(self.device)
+            prompt_tokens, scale = self.compression_model.encode(prompt)
+            assert scale is None
+        else:
+            prompt_tokens = None
+        return attributes, prompt_tokens
+
+    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
+                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
+        """Generate discrete audio tokens given audio prompt and/or conditions.
+
+        Args:
+            attributes (tp.List[ConditioningAttributes]): Conditions used for generation (text/melody).
+            prompt_tokens (tp.Optional[torch.Tensor]): Audio prompt used for continuation.
+            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+        Returns:
+            torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
+        """
+        total_gen_len = int(self.duration * self.frame_rate)
+        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
+        current_gen_offset: int = 0
+
+        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
+            generated_tokens += current_gen_offset
+            if self._progress_callback is not None:
+                # Note that total_gen_len might be quite wrong depending on the
+                # codebook pattern used, but with delay it is almost accurate.
+                self._progress_callback(generated_tokens, total_gen_len)
+            else:
+                print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
+
+        if prompt_tokens is not None:
+            assert max_prompt_len >= prompt_tokens.shape[-1], \
+                "Prompt is longer than audio to generate"
+
+        callback = None
+        if progress:
+            callback = _progress_callback
+
+        if self.duration <= self.max_duration:
+            # generate by sampling from LM, simple case.
+            with self.autocast:
+                gen_tokens = self.lm.generate(
+                    prompt_tokens, attributes,
+                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)
+
+        else:
+            # now this gets a bit messier, we need to handle prompts,
+            # melody conditioning etc.
+            ref_wavs = [attr.wav['self_wav'] for attr in attributes]
+            all_tokens = []
+            if prompt_tokens is None:
+                prompt_length = 0
+            else:
+                all_tokens.append(prompt_tokens)
+                prompt_length = prompt_tokens.shape[-1]
+
+            stride_tokens = int(self.frame_rate * self.extend_stride)
+
+            while current_gen_offset + prompt_length < total_gen_len:
+                time_offset = current_gen_offset / self.frame_rate
+                chunk_duration = min(self.duration - time_offset, self.max_duration)
+                max_gen_len = int(chunk_duration * self.frame_rate)
+                for attr, ref_wav in zip(attributes, ref_wavs):
+                    wav_length = ref_wav.length.item()
+                    if wav_length == 0:
+                        continue
+                    # We will extend the wav periodically if it not long enough.
+                    # we have to do it here rather than in conditioners.py as otherwise
+                    # we wouldn't have the full wav.
+                    initial_position = int(time_offset * self.sample_rate)
+                    wav_target_length = int(self.max_duration * self.sample_rate)
+                    print(initial_position / self.sample_rate, wav_target_length / self.sample_rate)
+                    positions = torch.arange(initial_position,
+                                             initial_position + wav_target_length, device=self.device)
+                    attr.wav['self_wav'] = WavCondition(
+                        ref_wav[0][:, positions % wav_length],
+                        torch.full_like(ref_wav[1], wav_target_length))
+                with self.autocast:
+                    gen_tokens = self.lm.generate(
+                        prompt_tokens, attributes,
+                        callback=callback, max_gen_len=max_gen_len, **self.generation_params)
+                if prompt_tokens is None:
+                    all_tokens.append(gen_tokens)
+                else:
+                    all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
+                prompt_tokens = gen_tokens[:, :, stride_tokens:]
+                prompt_length = prompt_tokens.shape[-1]
+                current_gen_offset += stride_tokens
+
+            gen_tokens = torch.cat(all_tokens, dim=-1)
+
+        # generate audio
+        assert gen_tokens.dim() == 3
+        with torch.no_grad():
+            gen_audio = self.compression_model.decode(gen_tokens, None)
+        return gen_audio
+
+

Static methods

+
+
+def get_pretrained(name: str = 'melody', device=None) +
+
+

Return pretrained model, we provide four models: +- small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small +- medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium +- melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-melody +- large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large

+
+ +Expand source code + +
@staticmethod
+def get_pretrained(name: str = 'melody', device=None):
+    """Return pretrained model, we provide four models:
+    - small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
+    - medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
+    - melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-melody
+    - large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
+    """
+
+    if device is None:
+        if torch.cuda.device_count():
+            device = 'cuda'
+        else:
+            device = 'cpu'
+
+    if name == 'debug':
+        # used only for unit tests
+        compression_model = get_debug_compression_model(device)
+        lm = get_debug_lm_model(device)
+        return MusicGen(name, compression_model, lm)
+
+    if name not in HF_MODEL_CHECKPOINTS_MAP:
+        if not os.path.isfile(name) and not os.path.isdir(name):
+            raise ValueError(
+                f"{name} is not a valid checkpoint name. "
+                f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
+            )
+
+    cache_dir = os.environ.get('MUSICGEN_ROOT', None)
+    compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
+    lm = load_lm_model(name, device=device, cache_dir=cache_dir)
+    if name == 'melody':
+        lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
+
+    return MusicGen(name, compression_model, lm)
+
+
+
+

Instance variables

+
+
var audio_channels : int
+
+

Audio channels of the generated audio.

+
+ +Expand source code + +
@property
+def audio_channels(self) -> int:
+    """Audio channels of the generated audio."""
+    return self.compression_model.channels
+
+
+
var frame_rate : int
+
+

Roughly the number of AR steps per seconds.

+
+ +Expand source code + +
@property
+def frame_rate(self) -> int:
+    """Roughly the number of AR steps per seconds."""
+    return self.compression_model.frame_rate
+
+
+
var sample_rate : int
+
+

Sample rate of the generated audio.

+
+ +Expand source code + +
@property
+def sample_rate(self) -> int:
+    """Sample rate of the generated audio."""
+    return self.compression_model.sample_rate
+
+
+
+

Methods

+
+
+def generate(self, descriptions: List[str], progress: bool = False) ‑> torch.Tensor +
+
+

Generate samples conditioned on text.

+

Args

+
+
descriptions : tp.List[str]
+
A list of strings used as text conditioning.
+
progress : bool, optional
+
Flag to display progress of the generation process. Defaults to False.
+
+
+ +Expand source code + +
def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
+    """Generate samples conditioned on text.
+
+    Args:
+        descriptions (tp.List[str]): A list of strings used as text conditioning.
+        progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+    """
+    attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+    assert prompt_tokens is None
+    return self._generate_tokens(attributes, prompt_tokens, progress)
+
+
+
+def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int, descriptions: Optional[List[Optional[str]]] = None, progress: bool = False) ‑> torch.Tensor +
+
+

Generate samples conditioned on audio prompts.

+

Args

+
+
prompt : torch.Tensor
+
A batch of waveforms used for continuation. +Prompt should be [B, C, T], or [C, T] if only one sample is generated.
+
prompt_sample_rate : int
+
Sampling rate of the given audio waveforms.
+
descriptions : tp.List[str], optional
+
A list of strings used as text conditioning. Defaults to None.
+
progress : bool, optional
+
Flag to display progress of the generation process. Defaults to False.
+
+
+ +Expand source code + +
def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
+                          descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
+                          progress: bool = False) -> torch.Tensor:
+    """Generate samples conditioned on audio prompts.
+
+    Args:
+        prompt (torch.Tensor): A batch of waveforms used for continuation.
+            Prompt should be [B, C, T], or [C, T] if only one sample is generated.
+        prompt_sample_rate (int): Sampling rate of the given audio waveforms.
+        descriptions (tp.List[str], optional): A list of strings used as text conditioning. Defaults to None.
+        progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+    """
+    if prompt.dim() == 2:
+        prompt = prompt[None]
+    if prompt.dim() != 3:
+        raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
+    prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
+    if descriptions is None:
+        descriptions = [None] * len(prompt)
+    attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
+    assert prompt_tokens is not None
+    return self._generate_tokens(attributes, prompt_tokens, progress)
+
+
+
+def generate_unconditional(self, num_samples: int, progress: bool = False) ‑> torch.Tensor +
+
+

Generate samples in an unconditional manner.

+

Args

+
+
num_samples : int
+
Number of samples to be generated.
+
progress : bool, optional
+
Flag to display progress of the generation process. Defaults to False.
+
+
+ +Expand source code + +
def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
+    """Generate samples in an unconditional manner.
+
+    Args:
+        num_samples (int): Number of samples to be generated.
+        progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+    """
+    descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
+    attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+    return self._generate_tokens(attributes, prompt_tokens, progress)
+
+
+
+def generate_with_chroma(self, descriptions: List[str], melody_wavs: Union[torch.Tensor, List[Optional[torch.Tensor]]], melody_sample_rate: int, progress: bool = False) ‑> torch.Tensor +
+
+

Generate samples conditioned on text and melody.

+

Args

+
+
descriptions : tp.List[str]
+
A list of strings used as text conditioning.
+
melody_wavs
+
(torch.Tensor or list of Tensor): A batch of waveforms used as +melody conditioning. Should have shape [B, C, T] with B matching the description length, +C=1 or 2. It can be [C, T] if there is a single description. It can also be +a list of [C, T] tensors.
+
melody_sample_rate
+
(int): Sample rate of the melody waveforms.
+
progress : bool, optional
+
Flag to display progress of the generation process. Defaults to False.
+
+
+ +Expand source code + +
def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
+                         melody_sample_rate: int, progress: bool = False) -> torch.Tensor:
+    """Generate samples conditioned on text and melody.
+
+    Args:
+        descriptions (tp.List[str]): A list of strings used as text conditioning.
+        melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
+            melody conditioning. Should have shape [B, C, T] with B matching the description length,
+            C=1 or 2. It can be [C, T] if there is a single description. It can also be
+            a list of [C, T] tensors.
+        melody_sample_rate: (int): Sample rate of the melody waveforms.
+        progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+    """
+    if isinstance(melody_wavs, torch.Tensor):
+        if melody_wavs.dim() == 2:
+            melody_wavs = melody_wavs[None]
+        if melody_wavs.dim() != 3:
+            raise ValueError("Melody wavs should have a shape [B, C, T].")
+        melody_wavs = list(melody_wavs)
+    else:
+        for melody in melody_wavs:
+            if melody is not None:
+                assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
+
+    melody_wavs = [
+        convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
+        if wav is not None else None
+        for wav in melody_wavs]
+    attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
+                                                                    melody_wavs=melody_wavs)
+    assert prompt_tokens is None
+    return self._generate_tokens(attributes, prompt_tokens, progress)
+
+
+
+def set_custom_progress_callback(self, progress_callback: Optional[Callable[[int, int], None]] = None) +
+
+

Override the default progress callback.

+
+ +Expand source code + +
def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
+    """Override the default progress callback."""
+    self._progress_callback = progress_callback
+
+
+
+def set_generation_params(self, use_sampling: bool = True, top_k: int = 250, top_p: float = 0.0, temperature: float = 1.0, duration: float = 30.0, cfg_coef: float = 3.0, two_step_cfg: bool = False, extend_stride: float = 18) +
+
+

Set the generation parameters for MusicGen.

+

Args

+
+
use_sampling : bool, optional
+
Use sampling if True, else do argmax decoding. Defaults to True.
+
top_k : int, optional
+
top_k used for sampling. Defaults to 250.
+
top_p : float, optional
+
top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
+
temperature : float, optional
+
Softmax temperature parameter. Defaults to 1.0.
+
duration : float, optional
+
Duration of the generated waveform. Defaults to 30.0.
+
cfg_coef : float, optional
+
Coefficient used for classifier free guidance. Defaults to 3.0.
+
two_step_cfg : bool, optional
+
If True, performs 2 forward for Classifier Free Guidance, +instead of batching together the two. This has some impact on how things +are padded but seems to have little impact in practice.
+
extend_stride
+
when doing extended generation (i.e. more than 30 seconds), by how much +should we extend the audio each time. Larger values will mean less context is +preserved, and shorter value will require extra computations.
+
+
+ +Expand source code + +
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
+                          top_p: float = 0.0, temperature: float = 1.0,
+                          duration: float = 30.0, cfg_coef: float = 3.0,
+                          two_step_cfg: bool = False, extend_stride: float = 18):
+    """Set the generation parameters for MusicGen.
+
+    Args:
+        use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
+        top_k (int, optional): top_k used for sampling. Defaults to 250.
+        top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
+        temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
+        duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
+        cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
+        two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
+            instead of batching together the two. This has some impact on how things
+            are padded but seems to have little impact in practice.
+        extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
+            should we extend the audio each time. Larger values will mean less context is
+            preserved, and shorter value will require extra computations.
+    """
+    assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
+    self.extend_stride = extend_stride
+    self.duration = duration
+    self.generation_params = {
+        'use_sampling': use_sampling,
+        'temp': temperature,
+        'top_k': top_k,
+        'top_p': top_p,
+        'cfg_coef': cfg_coef,
+        'two_step_cfg': two_step_cfg,
+    }
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/modules/activations.html b/docs/audiocraft/modules/activations.html new file mode 100644 index 00000000..08efaf8a --- /dev/null +++ b/docs/audiocraft/modules/activations.html @@ -0,0 +1,523 @@ + + + + + + +audiocraft.modules.activations API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.activations

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from typing import Union, Callable
+
+
+class CustomGLU(nn.Module):
+    """Custom Gated Linear Unit activation.
+    Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
+    of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
+    function (i.e. sigmoid, swish, etc.).
+
+    Args:
+        activation (nn.Module): The custom activation to apply in the Gated Linear Unit
+        dim (int): the dimension on which to split the input. Default: -1
+
+    Shape:
+        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
+
+    Examples::
+        >>> m = CustomGLU(nn.Sigmoid())
+        >>> input = torch.randn(4, 2)
+        >>> output = m(input)
+    """
+    def __init__(self, activation: nn.Module, dim: int = -1):
+        super(CustomGLU, self).__init__()
+        self.dim = dim
+        self.activation = activation
+
+    def forward(self, x: Tensor):
+        assert x.shape[self.dim] % 2 == 0  # M = N / 2
+        a, b = torch.chunk(x, 2, dim=self.dim)
+        return a * self.activation(b)
+
+
+class SwiGLU(CustomGLU):
+    """SiLU Gated Linear Unit activation.
+    Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
+    the first half of the input matrices, :math:`b` is the second half.
+
+    Args:
+        dim (int): the dimension on which to split the input. Default: -1
+    """
+    def __init__(self, dim: int = -1):
+        super(SwiGLU, self).__init__(nn.SiLU(), dim)
+
+
+class GeGLU(CustomGLU):
+    """GeLU Gated Linear Unit activation.
+    Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
+    the first half of the input matrices, :math:`b` is the second half.
+
+    Args:
+        dim (int): the dimension on which to split the input. Default: -1
+    """
+    def __init__(self, dim: int = -1):
+        super(GeGLU, self).__init__(nn.GELU(), dim)
+
+
+class ReGLU(CustomGLU):
+    """ReLU Gated Linear Unit activation.
+    Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
+    the first half of the input matrices, :math:`b` is the second half.
+
+    Args:
+        dim (int): the dimension on which to split the input. Default: -1
+    """
+    def __init__(self, dim: int = -1):
+        super(ReGLU, self).__init__(nn.ReLU(), dim)
+
+
+def get_activation_fn(
+    activation: Union[str, Callable[[Tensor], Tensor]]
+) -> Union[str, Callable[[Tensor], Tensor]]:
+    """Helper function to map an activation string to the activation class.
+    If the supplied activation is not a string that is recognized, the activation is passed back.
+
+    Args:
+        activation (Union[str, Callable[[Tensor], Tensor]]): Activation to check
+    """
+    if isinstance(activation, str):
+        if activation == "reglu":
+            return ReGLU()
+        elif activation == "geglu":
+            return GeGLU()
+        elif activation == "swiglu":
+            return SwiGLU()
+    return activation
+
+
+
+
+
+
+
+

Functions

+
+
+def get_activation_fn(activation: Union[str, Callable[[torch.Tensor], torch.Tensor]]) ‑> Union[str, Callable[[torch.Tensor], torch.Tensor]] +
+
+

Helper function to map an activation string to the activation class. +If the supplied activation is not a string that is recognized, the activation is passed back.

+

Args

+
+
activation : Union[str, Callable[[Tensor], Tensor]]
+
Activation to check
+
+
+ +Expand source code + +
def get_activation_fn(
+    activation: Union[str, Callable[[Tensor], Tensor]]
+) -> Union[str, Callable[[Tensor], Tensor]]:
+    """Helper function to map an activation string to the activation class.
+    If the supplied activation is not a string that is recognized, the activation is passed back.
+
+    Args:
+        activation (Union[str, Callable[[Tensor], Tensor]]): Activation to check
+    """
+    if isinstance(activation, str):
+        if activation == "reglu":
+            return ReGLU()
+        elif activation == "geglu":
+            return GeGLU()
+        elif activation == "swiglu":
+            return SwiGLU()
+    return activation
+
+
+
+
+
+

Classes

+
+
+class CustomGLU +(activation: torch.nn.modules.module.Module, dim: int = -1) +
+
+

Custom Gated Linear Unit activation. +Applies a modified gated linear unit :math:a * f(b) where :math:a is the first half +of the input matrices, :math:b is the second half, and :math:f is a provided activation +function (i.e. sigmoid, swish, etc.).

+

Args

+
+
activation : nn.Module
+
The custom activation to apply in the Gated Linear Unit
+
dim : int
+
the dimension on which to split the input. Default: -1
+
+

Shape

+
    +
  • Input: :math:(st_1, N, st_2) where * means, any number of additional +dimensions
  • +
  • Output: :math:(st_1, M, st_2) where :math:M=N/2
  • +
+

Examples:: +>>> m = CustomGLU(nn.Sigmoid()) +>>> input = torch.randn(4, 2) +>>> output = m(input)

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class CustomGLU(nn.Module):
+    """Custom Gated Linear Unit activation.
+    Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
+    of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
+    function (i.e. sigmoid, swish, etc.).
+
+    Args:
+        activation (nn.Module): The custom activation to apply in the Gated Linear Unit
+        dim (int): the dimension on which to split the input. Default: -1
+
+    Shape:
+        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
+
+    Examples::
+        >>> m = CustomGLU(nn.Sigmoid())
+        >>> input = torch.randn(4, 2)
+        >>> output = m(input)
+    """
+    def __init__(self, activation: nn.Module, dim: int = -1):
+        super(CustomGLU, self).__init__()
+        self.dim = dim
+        self.activation = activation
+
+    def forward(self, x: Tensor):
+        assert x.shape[self.dim] % 2 == 0  # M = N / 2
+        a, b = torch.chunk(x, 2, dim=self.dim)
+        return a * self.activation(b)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x: torch.Tensor) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x: Tensor):
+    assert x.shape[self.dim] % 2 == 0  # M = N / 2
+    a, b = torch.chunk(x, 2, dim=self.dim)
+    return a * self.activation(b)
+
+
+
+
+
+class GeGLU +(dim: int = -1) +
+
+

GeLU Gated Linear Unit activation. +Applies GeLU Gated Linear Unit :math:a * GELU(b) where :math:a is +the first half of the input matrices, :math:b is the second half.

+

Args

+
+
dim : int
+
the dimension on which to split the input. Default: -1
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class GeGLU(CustomGLU):
+    """GeLU Gated Linear Unit activation.
+    Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
+    the first half of the input matrices, :math:`b` is the second half.
+
+    Args:
+        dim (int): the dimension on which to split the input. Default: -1
+    """
+    def __init__(self, dim: int = -1):
+        super(GeGLU, self).__init__(nn.GELU(), dim)
+
+

Ancestors

+
    +
  • CustomGLU
  • +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class ReGLU +(dim: int = -1) +
+
+

ReLU Gated Linear Unit activation. +Applies ReLU Gated Linear Unit :math:a * ReLU(b) where :math:a is +the first half of the input matrices, :math:b is the second half.

+

Args

+
+
dim : int
+
the dimension on which to split the input. Default: -1
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ReGLU(CustomGLU):
+    """ReLU Gated Linear Unit activation.
+    Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
+    the first half of the input matrices, :math:`b` is the second half.
+
+    Args:
+        dim (int): the dimension on which to split the input. Default: -1
+    """
+    def __init__(self, dim: int = -1):
+        super(ReGLU, self).__init__(nn.ReLU(), dim)
+
+

Ancestors

+
    +
  • CustomGLU
  • +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class SwiGLU +(dim: int = -1) +
+
+

SiLU Gated Linear Unit activation. +Applies SiLU Gated Linear Unit :math:a * SiLU(b) where :math:a is +the first half of the input matrices, :math:b is the second half.

+

Args

+
+
dim : int
+
the dimension on which to split the input. Default: -1
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class SwiGLU(CustomGLU):
+    """SiLU Gated Linear Unit activation.
+    Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
+    the first half of the input matrices, :math:`b` is the second half.
+
+    Args:
+        dim (int): the dimension on which to split the input. Default: -1
+    """
+    def __init__(self, dim: int = -1):
+        super(SwiGLU, self).__init__(nn.SiLU(), dim)
+
+

Ancestors

+
    +
  • CustomGLU
  • +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/modules/codebooks_patterns.html b/docs/audiocraft/modules/codebooks_patterns.html new file mode 100644 index 00000000..d3339e1b --- /dev/null +++ b/docs/audiocraft/modules/codebooks_patterns.html @@ -0,0 +1,1818 @@ + + + + + + +audiocraft.modules.codebooks_patterns API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.codebooks_patterns

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import namedtuple
+from dataclasses import dataclass
+from functools import lru_cache
+import logging
+import typing as tp
+
+from abc import ABC, abstractmethod
+import torch
+
+LayoutCoord = namedtuple('LayoutCoord', ['t', 'q'])  # (timestep, codebook index)
+PatternLayout = tp.List[tp.List[LayoutCoord]]  # Sequence of coordinates
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Pattern:
+    """Base implementation of a pattern over a sequence with multiple codebooks.
+
+    The codebook pattern consists in a layout, defining for each sequence step
+    the list of coordinates of each codebook timestep in the resulting interleaved sequence.
+    The first item of the pattern is always an empty list in order to properly insert a special token
+    to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
+    and ``timesteps`` the number of timesteps corresponding to the original sequence.
+
+    The pattern provides convenient methods to build and revert interleaved sequences from it:
+    ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
+        to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size,
+        K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
+        for the output sequence. The unfilled positions are replaced with a special token and the built sequence
+        is returned along with a mask indicating valid tokens.
+    ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
+        of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
+        to fill and specify invalid positions if needed.
+    See the dedicated methods for more details.
+    """
+    # Pattern layout, for each sequence step, we have a list of coordinates
+    # corresponding to the original codebook timestep and position.
+    # The first list is always an empty list in order to properly insert
+    # a special token to start with.
+    layout: PatternLayout
+    timesteps: int
+    n_q: int
+
+    def __post_init__(self):
+        assert len(self.layout) > 0
+        assert self.layout[0] == []
+        self._validate_layout()
+        self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
+        self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
+        logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
+
+    def _validate_layout(self):
+        """Runs checks on the layout to ensure a valid pattern is defined.
+        A pattern is considered invalid if:
+            - Multiple timesteps for a same codebook are defined in the same sequence step
+            - The timesteps for a given codebook are not in ascending order as we advance in the sequence
+              (this would mean that we have future timesteps before past timesteps).
+        """
+        q_timesteps = {q: 0 for q in range(self.n_q)}
+        for s, seq_coords in enumerate(self.layout):
+            if len(seq_coords) > 0:
+                qs = set()
+                for coord in seq_coords:
+                    qs.add(coord.q)
+                    last_q_timestep = q_timesteps[coord.q]
+                    assert coord.t >= last_q_timestep, \
+                        f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
+                    q_timesteps[coord.q] = coord.t
+                # each sequence step contains at max 1 coordinate per codebook
+                assert len(qs) == len(seq_coords), \
+                    f"Multiple entries for a same codebook are found at step {s}"
+
+    @property
+    def num_sequence_steps(self):
+        return len(self.layout) - 1
+
+    @property
+    def max_delay(self):
+        max_t_in_seq_coords = 0
+        for seq_coords in self.layout[1:]:
+            for coords in seq_coords:
+                max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
+        return max_t_in_seq_coords - self.timesteps
+
+    @property
+    def valid_layout(self):
+        valid_step = len(self.layout) - self.max_delay
+        return self.layout[:valid_step]
+
+    def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
+        """Get codebook coordinates in the layout that corresponds to the specified timestep t
+        and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
+        and the actual codebook coordinates.
+        """
+        assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
+        if q is not None:
+            assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
+        coords = []
+        for s, seq_codes in enumerate(self.layout):
+            for code in seq_codes:
+                if code.t == t and (q is None or code.q == q):
+                    coords.append((s, code))
+        return coords
+
+    def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
+        return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
+
+    def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
+        steps_with_timesteps = self.get_steps_with_timestep(t, q)
+        return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
+
+    def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
+                                                device: tp.Union[torch.device, str] = 'cpu'):
+        """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
+
+        Args:
+            timesteps (int): Maximum number of timesteps steps to consider.
+            keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
+            device (Union[torch.device, str]): Device for created tensors.
+        Returns:
+            indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
+        """
+        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
+        assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
+        # use the proper layout based on whether we limit ourselves to valid steps only or not,
+        # note that using the valid_layout will result in a truncated sequence up to the valid steps
+        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
+        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
+        indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
+        mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
+        # fill indexes with last sequence step value that will correspond to our special token
+        # the last value is n_q * timesteps as we have flattened z and append special token as the last token
+        # which will correspond to the index: n_q * timesteps
+        indexes[:] = n_q * timesteps
+        # iterate over the pattern and fill scattered indexes and mask
+        for s, sequence_coords in enumerate(ref_layout):
+            for coords in sequence_coords:
+                if coords.t < timesteps:
+                    indexes[coords.q, s] = coords.t + coords.q * timesteps
+                    mask[coords.q, s] = 1
+        indexes = torch.from_numpy(indexes).to(device)
+        mask = torch.from_numpy(mask).to(device)
+        return indexes, mask
+
+    def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+        """Build sequence corresponding to the pattern from the input tensor z.
+        The sequence is built using up to sequence_steps if specified, and non-pattern
+        coordinates are filled with the special token.
+
+        Args:
+            z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
+            special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
+            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+                Steps that are beyond valid steps will be replaced by the special_token in that case.
+        Returns:
+            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
+                corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
+            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
+        """
+        B, K, T = z.shape
+        indexes, mask = self._build_pattern_sequence_scatter_indexes(
+            T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
+        )
+        z = z.view(B, -1)
+        # we append the special token as the last index of our flattened z tensor
+        z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
+        values = z[:, indexes.view(-1)]
+        values = values.view(B, K, indexes.shape[-1])
+        return values, indexes, mask
+
+    def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
+                                                 keep_only_valid_steps: bool = False,
+                                                 is_model_output: bool = False,
+                                                 device: tp.Union[torch.device, str] = 'cpu'):
+        """Builds scatter indexes required to retrieve the original multi-codebook sequence
+        from interleaving pattern.
+
+        Args:
+            sequence_steps (int): Sequence steps.
+            n_q (int): Number of codebooks.
+            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+                Steps that are beyond valid steps will be replaced by the special_token in that case.
+            is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
+            device (Union[torch.device, str]): Device for created tensors.
+        Returns:
+            torch.Tensor: Indexes for reconstructing the output, of shape [K, T].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+        """
+        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
+        # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
+        timesteps = self.timesteps
+        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
+        assert sequence_steps <= len(ref_layout), \
+            f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
+
+        # ensure we take the appropriate indexes to keep the model output from the first special token as well
+        if is_model_output:
+            ref_layout = ref_layout[1:]
+
+        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
+        indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
+        mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
+        # fill indexes with last sequence step value that will correspond to our special token
+        indexes[:] = n_q * sequence_steps
+        for s, sequence_codes in enumerate(ref_layout):
+            if s < sequence_steps:
+                for code in sequence_codes:
+                    if code.t < timesteps:
+                        indexes[code.q, code.t] = s + code.q * sequence_steps
+                        mask[code.q, code.t] = 1
+        indexes = torch.from_numpy(indexes).to(device)
+        mask = torch.from_numpy(mask).to(device)
+        return indexes, mask
+
+    def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+        """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
+        The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
+        are filled with the special token.
+
+        Args:
+            s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
+            special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
+        Returns:
+            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
+                corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
+            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+        """
+        B, K, S = s.shape
+        indexes, mask = self._build_reverted_sequence_scatter_indexes(
+            S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
+        )
+        s = s.view(B, -1)
+        # we append the special token as the last index of our flattened z tensor
+        s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
+        values = s[:, indexes.view(-1)]
+        values = values.view(B, K, indexes.shape[-1])
+        return values, indexes, mask
+
+    def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
+        """Revert model logits obtained on a sequence built from the pattern
+        back to a tensor matching the original sequence.
+
+        This method is similar to ``revert_pattern_sequence`` with the following specificities:
+        1. It is designed to work with the extra cardinality dimension
+        2. We return the logits for the first sequence item that matches the special_token and
+        which matching target in the original sequence is the first item of the sequence,
+        while we skip the last logits as there is no matching target
+        """
+        B, card, K, S = logits.shape
+        indexes, mask = self._build_reverted_sequence_scatter_indexes(
+            S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
+        )
+        logits = logits.reshape(B, card, -1)
+        # we append the special token as the last index of our flattened z tensor
+        logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1)  # [B, card, K x S]
+        values = logits[:, :, indexes.view(-1)]
+        values = values.view(B, card, K, indexes.shape[-1])
+        return values, indexes, mask
+
+
+class CodebooksPatternProvider(ABC):
+    """Abstraction around providing pattern for interleaving codebooks.
+
+    The CodebooksPatternProvider abstraction allows to implement various strategies to
+    define interleaving pattern of sequences composed of multiple codebooks. For a given
+    number of codebooks `n_q`, the pattern provider can generate a specified pattern
+    corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
+    can be used to construct a new sequence from the original codes respecting the specified
+    pattern. The pattern is defined as a list of list of code coordinates, code coordinate
+    being a tuple with the original timestep and codebook to build the new sequence.
+    Note that all patterns must start with an empty list that is then used to insert a first
+    sequence step of special tokens in the newly generated sequence.
+
+    Args:
+        n_q (int): number of codebooks.
+        cached (bool): if True, patterns for a given length are cached. In general
+            that should be true for efficiency reason to avoid synchronization points.
+    """
+    def __init__(self, n_q: int, cached: bool = True):
+        assert n_q > 0
+        self.n_q = n_q
+        self.get_pattern = lru_cache(100)(self.get_pattern)  # type: ignore
+
+    @abstractmethod
+    def get_pattern(self, timesteps: int) -> Pattern:
+        """Builds pattern with specific interleaving between codebooks.
+
+        Args:
+            timesteps (int): Total numer of timesteps.
+        """
+        raise NotImplementedError()
+
+
+class DelayedPatternProvider(CodebooksPatternProvider):
+    """Provider for delayed pattern across delayed codebooks.
+    Codebooks are delayed in the sequence and sequence steps will contain codebooks
+    from different timesteps.
+
+    Example:
+        Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
+        [[1, 2, 3, 4],
+        [1, 2, 3, 4],
+        [1, 2, 3, 4]]
+        The resulting sequence obtained from the returned pattern is:
+        [[S, 1, 2, 3, 4],
+        [S, S, 1, 2, 3],
+        [S, S, S, 1, 2]]
+        (with S being a special token)
+
+    Args:
+        n_q (int): Number of codebooks.
+        delays (Optional[List[int]]): Delay for each of the codebooks.
+            If delays not defined, each codebook is delayed by 1 compared to the previous one.
+        flatten_first (int): Flatten the first N timesteps.
+        empty_initial (int): Prepend with N empty list of coordinates.
+    """
+    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
+                 flatten_first: int = 0, empty_initial: int = 0):
+        super().__init__(n_q)
+        if delays is None:
+            delays = list(range(n_q))
+        self.delays = delays
+        self.flatten_first = flatten_first
+        self.empty_initial = empty_initial
+        assert len(self.delays) == self.n_q
+        assert sorted(self.delays) == self.delays
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        out: PatternLayout = [[]]
+        max_delay = max(self.delays)
+        if self.empty_initial:
+            out += [[] for _ in range(self.empty_initial)]
+        if self.flatten_first:
+            for t in range(min(timesteps, self.flatten_first)):
+                for q in range(self.n_q):
+                    out.append([LayoutCoord(t, q)])
+        for t in range(self.flatten_first, timesteps + max_delay):
+            v = []
+            for q, delay in enumerate(self.delays):
+                t_for_q = t - delay
+                if t_for_q >= self.flatten_first:
+                    v.append(LayoutCoord(t_for_q, q))
+            out.append(v)
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class ParallelPatternProvider(DelayedPatternProvider):
+    """Provider for parallel pattern across codebooks.
+    This pattern provider is a special case of the delayed pattern with actually no delay,
+    hence delays=repeat(0, n_q).
+
+    Args:
+        n_q (int): Number of codebooks.
+    """
+    def __init__(self, n_q: int):
+        super().__init__(n_q, [0] * n_q)
+
+
+class UnrolledPatternProvider(CodebooksPatternProvider):
+    """Provider for unrolling codebooks pattern.
+    This pattern provider enables to represent the codebook flattened completely or only to some extend
+    while also specifying a given delay between the flattened codebooks representation, allowing to
+    unroll the codebooks in the sequence.
+
+    Example:
+        1. Flattening of the codebooks.
+        By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
+        taking n_q = 3 and timesteps = 4:
+        [[1, 2, 3, 4],
+         [1, 2, 3, 4],
+         [1, 2, 3, 4]]
+        will result into:
+        [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
+         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
+        2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
+        for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
+        taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
+        [[1, 2, 3, 4],
+         [1, 2, 3, 4],
+         [1, 2, 3, 4]]
+        will result into:
+        [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
+        3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
+        allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
+        same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
+        and delays = [0, 3, 3]:
+        [[1, 2, 3, 4],
+         [1, 2, 3, 4],
+         [1, 2, 3, 4]]
+        will result into:
+        [[S, S, S, 1, S, 2, S, 3, S, 4],
+         [S, S, S, 1, S, 2, S, 3, S, 4],
+         [1, 2, 3, S, 4, S, 5, S, 6, S]]
+
+    Args:
+        n_q (int): Number of codebooks.
+        flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined,
+            the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
+            have n_q extra steps for each timestep.
+        delays (Optional[List[int]]): Delay for each of the codebooks. If not defined,
+            no delay is added and therefore will default to [0] * ``n_q``.
+            Note that two codebooks that will be flattened to the same inner step
+            should have the same delay, otherwise the pattern is considered as invalid.
+    """
+    FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
+
+    def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
+                 delays: tp.Optional[tp.List[int]] = None):
+        super().__init__(n_q)
+        if flattening is None:
+            flattening = list(range(n_q))
+        if delays is None:
+            delays = [0] * n_q
+        assert len(flattening) == n_q
+        assert len(delays) == n_q
+        assert sorted(flattening) == flattening
+        assert sorted(delays) == delays
+        self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
+        self.max_delay = max(delays)
+
+    def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
+        """Build a flattened codebooks representation as a dictionary of inner step
+        and the actual codebook indices corresponding to the flattened codebook. For convenience, we
+        also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
+        """
+        flattened_codebooks: dict = {}
+        for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
+            if inner_step not in flattened_codebooks:
+                flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
+            else:
+                flat_codebook = flattened_codebooks[inner_step]
+                assert flat_codebook.delay == delay, (
+                    "Delay and flattening between codebooks is inconsistent: ",
+                    "two codebooks flattened to the same position should have the same delay."
+                )
+                flat_codebook.codebooks.append(q)
+            flattened_codebooks[inner_step] = flat_codebook
+        return flattened_codebooks
+
+    @property
+    def _num_inner_steps(self):
+        """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
+        """
+        return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
+
+    def num_virtual_steps(self, timesteps: int) -> int:
+        return timesteps * self._num_inner_steps + 1
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        """Builds pattern for delay across codebooks.
+
+        Args:
+            timesteps (int): Total numer of timesteps.
+        """
+        # the PatternLayout is built as a tuple of sequence position and list of coordinates
+        # so that it can be reordered properly given the required delay between codebooks of given timesteps
+        indexed_out: list = [(-1, [])]
+        max_timesteps = timesteps + self.max_delay
+        for t in range(max_timesteps):
+            # for each timestep, we unroll the flattened codebooks,
+            # emitting the sequence step with the corresponding delay
+            for step in range(self._num_inner_steps):
+                if step in self._flattened_codebooks:
+                    # we have codebooks at this virtual step to emit
+                    step_codebooks = self._flattened_codebooks[step]
+                    t_for_q = t + step_codebooks.delay
+                    coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
+                    if t_for_q < max_timesteps and t < max_timesteps:
+                        indexed_out.append((t_for_q, coords))
+                else:
+                    # there is no codebook in this virtual step so we emit an empty list
+                    indexed_out.append((t, []))
+        out = [coords for _, coords in sorted(indexed_out)]
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class VALLEPattern(CodebooksPatternProvider):
+    """Almost VALL-E style pattern. We futher allow some delays for the
+    codebooks other than the first one.
+
+    Args:
+        n_q (int): Number of codebooks.
+        delays (Optional[List[int]]): Delay for each of the codebooks.
+            If delays not defined, each codebook is delayed by 1 compared to the previous one.
+    """
+    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
+        super().__init__(n_q)
+        if delays is None:
+            delays = [0] * (n_q - 1)
+        self.delays = delays
+        assert len(self.delays) == self.n_q - 1
+        assert sorted(self.delays) == self.delays
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        out: PatternLayout = [[]]
+        for t in range(timesteps):
+            out.append([LayoutCoord(t, 0)])
+        max_delay = max(self.delays)
+        for t in range(timesteps + max_delay):
+            v = []
+            for q, delay in enumerate(self.delays):
+                t_for_q = t - delay
+                if t_for_q >= 0:
+                    v.append(LayoutCoord(t_for_q, q + 1))
+            out.append(v)
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class MusicLMPattern(CodebooksPatternProvider):
+    """Almost MusicLM style pattern. This is equivalent to full flattening
+    but in a different order.
+
+    Args:
+        n_q (int): Number of codebooks.
+        group_by (int): Number of codebooks to group together.
+    """
+    def __init__(self, n_q: int, group_by: int = 2):
+        super().__init__(n_q)
+        self.group_by = group_by
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        out: PatternLayout = [[]]
+        for offset in range(0, self.n_q, self.group_by):
+            for t in range(timesteps):
+                for q in range(offset, offset + self.group_by):
+                    out.append([LayoutCoord(t, q)])
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class CodebooksPatternProvider +(n_q: int, cached: bool = True) +
+
+

Abstraction around providing pattern for interleaving codebooks.

+

The CodebooksPatternProvider abstraction allows to implement various strategies to +define interleaving pattern of sequences composed of multiple codebooks. For a given +number of codebooks n_q, the pattern provider can generate a specified pattern +corresponding to a sequence of T timesteps with n_q parallel codebooks. This pattern +can be used to construct a new sequence from the original codes respecting the specified +pattern. The pattern is defined as a list of list of code coordinates, code coordinate +being a tuple with the original timestep and codebook to build the new sequence. +Note that all patterns must start with an empty list that is then used to insert a first +sequence step of special tokens in the newly generated sequence.

+

Args

+
+
n_q : int
+
number of codebooks.
+
cached : bool
+
if True, patterns for a given length are cached. In general +that should be true for efficiency reason to avoid synchronization points.
+
+
+ +Expand source code + +
class CodebooksPatternProvider(ABC):
+    """Abstraction around providing pattern for interleaving codebooks.
+
+    The CodebooksPatternProvider abstraction allows to implement various strategies to
+    define interleaving pattern of sequences composed of multiple codebooks. For a given
+    number of codebooks `n_q`, the pattern provider can generate a specified pattern
+    corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
+    can be used to construct a new sequence from the original codes respecting the specified
+    pattern. The pattern is defined as a list of list of code coordinates, code coordinate
+    being a tuple with the original timestep and codebook to build the new sequence.
+    Note that all patterns must start with an empty list that is then used to insert a first
+    sequence step of special tokens in the newly generated sequence.
+
+    Args:
+        n_q (int): number of codebooks.
+        cached (bool): if True, patterns for a given length are cached. In general
+            that should be true for efficiency reason to avoid synchronization points.
+    """
+    def __init__(self, n_q: int, cached: bool = True):
+        assert n_q > 0
+        self.n_q = n_q
+        self.get_pattern = lru_cache(100)(self.get_pattern)  # type: ignore
+
+    @abstractmethod
+    def get_pattern(self, timesteps: int) -> Pattern:
+        """Builds pattern with specific interleaving between codebooks.
+
+        Args:
+            timesteps (int): Total numer of timesteps.
+        """
+        raise NotImplementedError()
+
+

Ancestors

+
    +
  • abc.ABC
  • +
+

Subclasses

+ +

Methods

+
+
+def get_pattern(self, timesteps: int) ‑> Pattern +
+
+

Builds pattern with specific interleaving between codebooks.

+

Args

+
+
timesteps : int
+
Total numer of timesteps.
+
+
+ +Expand source code + +
@abstractmethod
+def get_pattern(self, timesteps: int) -> Pattern:
+    """Builds pattern with specific interleaving between codebooks.
+
+    Args:
+        timesteps (int): Total numer of timesteps.
+    """
+    raise NotImplementedError()
+
+
+
+
+
+class DelayedPatternProvider +(n_q: int, delays: Optional[List[int]] = None, flatten_first: int = 0, empty_initial: int = 0) +
+
+

Provider for delayed pattern across delayed codebooks. +Codebooks are delayed in the sequence and sequence steps will contain codebooks +from different timesteps.

+

Example

+

Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence: +[[1, 2, 3, 4], +[1, 2, 3, 4], +[1, 2, 3, 4]] +The resulting sequence obtained from the returned pattern is: +[[S, 1, 2, 3, 4], +[S, S, 1, 2, 3], +[S, S, S, 1, 2]] +(with S being a special token)

+

Args

+
+
n_q : int
+
Number of codebooks.
+
delays : Optional[List[int]]
+
Delay for each of the codebooks. +If delays not defined, each codebook is delayed by 1 compared to the previous one.
+
flatten_first : int
+
Flatten the first N timesteps.
+
empty_initial : int
+
Prepend with N empty list of coordinates.
+
+
+ +Expand source code + +
class DelayedPatternProvider(CodebooksPatternProvider):
+    """Provider for delayed pattern across delayed codebooks.
+    Codebooks are delayed in the sequence and sequence steps will contain codebooks
+    from different timesteps.
+
+    Example:
+        Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
+        [[1, 2, 3, 4],
+        [1, 2, 3, 4],
+        [1, 2, 3, 4]]
+        The resulting sequence obtained from the returned pattern is:
+        [[S, 1, 2, 3, 4],
+        [S, S, 1, 2, 3],
+        [S, S, S, 1, 2]]
+        (with S being a special token)
+
+    Args:
+        n_q (int): Number of codebooks.
+        delays (Optional[List[int]]): Delay for each of the codebooks.
+            If delays not defined, each codebook is delayed by 1 compared to the previous one.
+        flatten_first (int): Flatten the first N timesteps.
+        empty_initial (int): Prepend with N empty list of coordinates.
+    """
+    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
+                 flatten_first: int = 0, empty_initial: int = 0):
+        super().__init__(n_q)
+        if delays is None:
+            delays = list(range(n_q))
+        self.delays = delays
+        self.flatten_first = flatten_first
+        self.empty_initial = empty_initial
+        assert len(self.delays) == self.n_q
+        assert sorted(self.delays) == self.delays
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        out: PatternLayout = [[]]
+        max_delay = max(self.delays)
+        if self.empty_initial:
+            out += [[] for _ in range(self.empty_initial)]
+        if self.flatten_first:
+            for t in range(min(timesteps, self.flatten_first)):
+                for q in range(self.n_q):
+                    out.append([LayoutCoord(t, q)])
+        for t in range(self.flatten_first, timesteps + max_delay):
+            v = []
+            for q, delay in enumerate(self.delays):
+                t_for_q = t - delay
+                if t_for_q >= self.flatten_first:
+                    v.append(LayoutCoord(t_for_q, q))
+            out.append(v)
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+

Ancestors

+ +

Subclasses

+ +

Inherited members

+ +
+
+class LayoutCoord +(t, q) +
+
+

LayoutCoord(t, q)

+

Ancestors

+
    +
  • builtins.tuple
  • +
+

Instance variables

+
+
var q
+
+

Alias for field number 1

+
+
var t
+
+

Alias for field number 0

+
+
+
+
+class MusicLMPattern +(n_q: int, group_by: int = 2) +
+
+

Almost MusicLM style pattern. This is equivalent to full flattening +but in a different order.

+

Args

+
+
n_q : int
+
Number of codebooks.
+
group_by : int
+
Number of codebooks to group together.
+
+
+ +Expand source code + +
class MusicLMPattern(CodebooksPatternProvider):
+    """Almost MusicLM style pattern. This is equivalent to full flattening
+    but in a different order.
+
+    Args:
+        n_q (int): Number of codebooks.
+        group_by (int): Number of codebooks to group together.
+    """
+    def __init__(self, n_q: int, group_by: int = 2):
+        super().__init__(n_q)
+        self.group_by = group_by
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        out: PatternLayout = [[]]
+        for offset in range(0, self.n_q, self.group_by):
+            for t in range(timesteps):
+                for q in range(offset, offset + self.group_by):
+                    out.append([LayoutCoord(t, q)])
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+

Ancestors

+ +

Inherited members

+ +
+
+class ParallelPatternProvider +(n_q: int) +
+
+

Provider for parallel pattern across codebooks. +This pattern provider is a special case of the delayed pattern with actually no delay, +hence delays=repeat(0, n_q).

+

Args

+
+
n_q : int
+
Number of codebooks.
+
+
+ +Expand source code + +
class ParallelPatternProvider(DelayedPatternProvider):
+    """Provider for parallel pattern across codebooks.
+    This pattern provider is a special case of the delayed pattern with actually no delay,
+    hence delays=repeat(0, n_q).
+
+    Args:
+        n_q (int): Number of codebooks.
+    """
+    def __init__(self, n_q: int):
+        super().__init__(n_q, [0] * n_q)
+
+

Ancestors

+ +

Inherited members

+ +
+
+class Pattern +(layout: List[List[LayoutCoord]], timesteps: int, n_q: int) +
+
+

Base implementation of a pattern over a sequence with multiple codebooks.

+

The codebook pattern consists in a layout, defining for each sequence step +the list of coordinates of each codebook timestep in the resulting interleaved sequence. +The first item of the pattern is always an empty list in order to properly insert a special token +to start with. For convenience, we also keep track of n_q the number of codebooks used for the pattern +and timesteps the number of timesteps corresponding to the original sequence.

+

The pattern provides convenient methods to build and revert interleaved sequences from it: +build_pattern_sequence maps a given a dense input tensor of multi-codebook sequence from [B, K, T] +to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size, +K being the number of codebooks, T the number of original timesteps and S the number of sequence steps +for the output sequence. The unfilled positions are replaced with a special token and the built sequence +is returned along with a mask indicating valid tokens. +revert_pattern_sequence maps back an interleaved sequence of shape [B, K, S] to the original alignment +of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask +to fill and specify invalid positions if needed. +See the dedicated methods for more details.

+
+ +Expand source code + +
class Pattern:
+    """Base implementation of a pattern over a sequence with multiple codebooks.
+
+    The codebook pattern consists in a layout, defining for each sequence step
+    the list of coordinates of each codebook timestep in the resulting interleaved sequence.
+    The first item of the pattern is always an empty list in order to properly insert a special token
+    to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
+    and ``timesteps`` the number of timesteps corresponding to the original sequence.
+
+    The pattern provides convenient methods to build and revert interleaved sequences from it:
+    ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
+        to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size,
+        K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
+        for the output sequence. The unfilled positions are replaced with a special token and the built sequence
+        is returned along with a mask indicating valid tokens.
+    ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
+        of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
+        to fill and specify invalid positions if needed.
+    See the dedicated methods for more details.
+    """
+    # Pattern layout, for each sequence step, we have a list of coordinates
+    # corresponding to the original codebook timestep and position.
+    # The first list is always an empty list in order to properly insert
+    # a special token to start with.
+    layout: PatternLayout
+    timesteps: int
+    n_q: int
+
+    def __post_init__(self):
+        assert len(self.layout) > 0
+        assert self.layout[0] == []
+        self._validate_layout()
+        self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
+        self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
+        logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
+
+    def _validate_layout(self):
+        """Runs checks on the layout to ensure a valid pattern is defined.
+        A pattern is considered invalid if:
+            - Multiple timesteps for a same codebook are defined in the same sequence step
+            - The timesteps for a given codebook are not in ascending order as we advance in the sequence
+              (this would mean that we have future timesteps before past timesteps).
+        """
+        q_timesteps = {q: 0 for q in range(self.n_q)}
+        for s, seq_coords in enumerate(self.layout):
+            if len(seq_coords) > 0:
+                qs = set()
+                for coord in seq_coords:
+                    qs.add(coord.q)
+                    last_q_timestep = q_timesteps[coord.q]
+                    assert coord.t >= last_q_timestep, \
+                        f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
+                    q_timesteps[coord.q] = coord.t
+                # each sequence step contains at max 1 coordinate per codebook
+                assert len(qs) == len(seq_coords), \
+                    f"Multiple entries for a same codebook are found at step {s}"
+
+    @property
+    def num_sequence_steps(self):
+        return len(self.layout) - 1
+
+    @property
+    def max_delay(self):
+        max_t_in_seq_coords = 0
+        for seq_coords in self.layout[1:]:
+            for coords in seq_coords:
+                max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
+        return max_t_in_seq_coords - self.timesteps
+
+    @property
+    def valid_layout(self):
+        valid_step = len(self.layout) - self.max_delay
+        return self.layout[:valid_step]
+
+    def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
+        """Get codebook coordinates in the layout that corresponds to the specified timestep t
+        and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
+        and the actual codebook coordinates.
+        """
+        assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
+        if q is not None:
+            assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
+        coords = []
+        for s, seq_codes in enumerate(self.layout):
+            for code in seq_codes:
+                if code.t == t and (q is None or code.q == q):
+                    coords.append((s, code))
+        return coords
+
+    def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
+        return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
+
+    def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
+        steps_with_timesteps = self.get_steps_with_timestep(t, q)
+        return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
+
+    def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
+                                                device: tp.Union[torch.device, str] = 'cpu'):
+        """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
+
+        Args:
+            timesteps (int): Maximum number of timesteps steps to consider.
+            keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
+            device (Union[torch.device, str]): Device for created tensors.
+        Returns:
+            indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
+        """
+        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
+        assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
+        # use the proper layout based on whether we limit ourselves to valid steps only or not,
+        # note that using the valid_layout will result in a truncated sequence up to the valid steps
+        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
+        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
+        indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
+        mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
+        # fill indexes with last sequence step value that will correspond to our special token
+        # the last value is n_q * timesteps as we have flattened z and append special token as the last token
+        # which will correspond to the index: n_q * timesteps
+        indexes[:] = n_q * timesteps
+        # iterate over the pattern and fill scattered indexes and mask
+        for s, sequence_coords in enumerate(ref_layout):
+            for coords in sequence_coords:
+                if coords.t < timesteps:
+                    indexes[coords.q, s] = coords.t + coords.q * timesteps
+                    mask[coords.q, s] = 1
+        indexes = torch.from_numpy(indexes).to(device)
+        mask = torch.from_numpy(mask).to(device)
+        return indexes, mask
+
+    def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+        """Build sequence corresponding to the pattern from the input tensor z.
+        The sequence is built using up to sequence_steps if specified, and non-pattern
+        coordinates are filled with the special token.
+
+        Args:
+            z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
+            special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
+            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+                Steps that are beyond valid steps will be replaced by the special_token in that case.
+        Returns:
+            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
+                corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
+            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
+        """
+        B, K, T = z.shape
+        indexes, mask = self._build_pattern_sequence_scatter_indexes(
+            T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
+        )
+        z = z.view(B, -1)
+        # we append the special token as the last index of our flattened z tensor
+        z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
+        values = z[:, indexes.view(-1)]
+        values = values.view(B, K, indexes.shape[-1])
+        return values, indexes, mask
+
+    def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
+                                                 keep_only_valid_steps: bool = False,
+                                                 is_model_output: bool = False,
+                                                 device: tp.Union[torch.device, str] = 'cpu'):
+        """Builds scatter indexes required to retrieve the original multi-codebook sequence
+        from interleaving pattern.
+
+        Args:
+            sequence_steps (int): Sequence steps.
+            n_q (int): Number of codebooks.
+            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+                Steps that are beyond valid steps will be replaced by the special_token in that case.
+            is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
+            device (Union[torch.device, str]): Device for created tensors.
+        Returns:
+            torch.Tensor: Indexes for reconstructing the output, of shape [K, T].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+        """
+        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
+        # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
+        timesteps = self.timesteps
+        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
+        assert sequence_steps <= len(ref_layout), \
+            f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
+
+        # ensure we take the appropriate indexes to keep the model output from the first special token as well
+        if is_model_output:
+            ref_layout = ref_layout[1:]
+
+        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
+        indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
+        mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
+        # fill indexes with last sequence step value that will correspond to our special token
+        indexes[:] = n_q * sequence_steps
+        for s, sequence_codes in enumerate(ref_layout):
+            if s < sequence_steps:
+                for code in sequence_codes:
+                    if code.t < timesteps:
+                        indexes[code.q, code.t] = s + code.q * sequence_steps
+                        mask[code.q, code.t] = 1
+        indexes = torch.from_numpy(indexes).to(device)
+        mask = torch.from_numpy(mask).to(device)
+        return indexes, mask
+
+    def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+        """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
+        The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
+        are filled with the special token.
+
+        Args:
+            s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
+            special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
+        Returns:
+            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
+                corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
+            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+        """
+        B, K, S = s.shape
+        indexes, mask = self._build_reverted_sequence_scatter_indexes(
+            S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
+        )
+        s = s.view(B, -1)
+        # we append the special token as the last index of our flattened z tensor
+        s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
+        values = s[:, indexes.view(-1)]
+        values = values.view(B, K, indexes.shape[-1])
+        return values, indexes, mask
+
+    def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
+        """Revert model logits obtained on a sequence built from the pattern
+        back to a tensor matching the original sequence.
+
+        This method is similar to ``revert_pattern_sequence`` with the following specificities:
+        1. It is designed to work with the extra cardinality dimension
+        2. We return the logits for the first sequence item that matches the special_token and
+        which matching target in the original sequence is the first item of the sequence,
+        while we skip the last logits as there is no matching target
+        """
+        B, card, K, S = logits.shape
+        indexes, mask = self._build_reverted_sequence_scatter_indexes(
+            S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
+        )
+        logits = logits.reshape(B, card, -1)
+        # we append the special token as the last index of our flattened z tensor
+        logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1)  # [B, card, K x S]
+        values = logits[:, :, indexes.view(-1)]
+        values = values.view(B, card, K, indexes.shape[-1])
+        return values, indexes, mask
+
+

Class variables

+
+
var layout : List[List[LayoutCoord]]
+
+
+
+
var n_q : int
+
+
+
+
var timesteps : int
+
+
+
+
+

Instance variables

+
+
var max_delay
+
+
+
+ +Expand source code + +
@property
+def max_delay(self):
+    max_t_in_seq_coords = 0
+    for seq_coords in self.layout[1:]:
+        for coords in seq_coords:
+            max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
+    return max_t_in_seq_coords - self.timesteps
+
+
+
var num_sequence_steps
+
+
+
+ +Expand source code + +
@property
+def num_sequence_steps(self):
+    return len(self.layout) - 1
+
+
+
var valid_layout
+
+
+
+ +Expand source code + +
@property
+def valid_layout(self):
+    valid_step = len(self.layout) - self.max_delay
+    return self.layout[:valid_step]
+
+
+
+

Methods

+
+
+def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False) +
+
+

Build sequence corresponding to the pattern from the input tensor z. +The sequence is built using up to sequence_steps if specified, and non-pattern +coordinates are filled with the special token.

+

Args

+
+
z : torch.Tensor
+
Input tensor of multi-codebooks sequence, of shape [B, K, T].
+
special_token : int
+
Special token used to fill non-pattern coordinates in the new sequence.
+
keep_only_valid_steps : bool
+
Build a sequence from the pattern up to valid (= fully defined) steps. +Steps that are beyond valid steps will be replaced by the special_token in that case.
+
+

Returns

+

values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S +corresponding either to the sequence_steps if provided, otherwise to the length of the pattern. +indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S]. +mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].

+
+ +Expand source code + +
def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+    """Build sequence corresponding to the pattern from the input tensor z.
+    The sequence is built using up to sequence_steps if specified, and non-pattern
+    coordinates are filled with the special token.
+
+    Args:
+        z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
+        special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
+        keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+            Steps that are beyond valid steps will be replaced by the special_token in that case.
+    Returns:
+        values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
+            corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
+        indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
+        mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
+    """
+    B, K, T = z.shape
+    indexes, mask = self._build_pattern_sequence_scatter_indexes(
+        T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
+    )
+    z = z.view(B, -1)
+    # we append the special token as the last index of our flattened z tensor
+    z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
+    values = z[:, indexes.view(-1)]
+    values = values.view(B, K, indexes.shape[-1])
+    return values, indexes, mask
+
+
+
+def get_first_step_with_timesteps(self, t: int, q: Optional[int] = None) ‑> Optional[int] +
+
+
+
+ +Expand source code + +
def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
+    steps_with_timesteps = self.get_steps_with_timestep(t, q)
+    return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
+
+
+
+def get_sequence_coords_with_timestep(self, t: int, q: Optional[int] = None) +
+
+

Get codebook coordinates in the layout that corresponds to the specified timestep t +and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step +and the actual codebook coordinates.

+
+ +Expand source code + +
def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
+    """Get codebook coordinates in the layout that corresponds to the specified timestep t
+    and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
+    and the actual codebook coordinates.
+    """
+    assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
+    if q is not None:
+        assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
+    coords = []
+    for s, seq_codes in enumerate(self.layout):
+        for code in seq_codes:
+            if code.t == t and (q is None or code.q == q):
+                coords.append((s, code))
+    return coords
+
+
+
+def get_steps_with_timestep(self, t: int, q: Optional[int] = None) ‑> List[int] +
+
+
+
+ +Expand source code + +
def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
+    return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
+
+
+
+def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False) +
+
+

Revert model logits obtained on a sequence built from the pattern +back to a tensor matching the original sequence.

+

This method is similar to revert_pattern_sequence with the following specificities: +1. It is designed to work with the extra cardinality dimension +2. We return the logits for the first sequence item that matches the special_token and +which matching target in the original sequence is the first item of the sequence, +while we skip the last logits as there is no matching target

+
+ +Expand source code + +
def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
+    """Revert model logits obtained on a sequence built from the pattern
+    back to a tensor matching the original sequence.
+
+    This method is similar to ``revert_pattern_sequence`` with the following specificities:
+    1. It is designed to work with the extra cardinality dimension
+    2. We return the logits for the first sequence item that matches the special_token and
+    which matching target in the original sequence is the first item of the sequence,
+    while we skip the last logits as there is no matching target
+    """
+    B, card, K, S = logits.shape
+    indexes, mask = self._build_reverted_sequence_scatter_indexes(
+        S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
+    )
+    logits = logits.reshape(B, card, -1)
+    # we append the special token as the last index of our flattened z tensor
+    logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1)  # [B, card, K x S]
+    values = logits[:, :, indexes.view(-1)]
+    values = values.view(B, card, K, indexes.shape[-1])
+    return values, indexes, mask
+
+
+
+def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False) +
+
+

Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving. +The sequence is reverted using up to timesteps if specified, and non-pattern coordinates +are filled with the special token.

+

Args

+
+
s : torch.Tensor
+
Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
+
special_token : int or float
+
Special token used to fill non-pattern coordinates in the new sequence.
+
+

Returns

+

values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T +corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise. +indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T]. +mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].

+
+ +Expand source code + +
def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+    """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
+    The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
+    are filled with the special token.
+
+    Args:
+        s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
+        special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
+    Returns:
+        values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
+            corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
+        indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
+        mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+    """
+    B, K, S = s.shape
+    indexes, mask = self._build_reverted_sequence_scatter_indexes(
+        S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
+    )
+    s = s.view(B, -1)
+    # we append the special token as the last index of our flattened z tensor
+    s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
+    values = s[:, indexes.view(-1)]
+    values = values.view(B, K, indexes.shape[-1])
+    return values, indexes, mask
+
+
+
+
+
+class UnrolledPatternProvider +(n_q: int, flattening: Optional[List[int]] = None, delays: Optional[List[int]] = None) +
+
+

Provider for unrolling codebooks pattern. +This pattern provider enables to represent the codebook flattened completely or only to some extend +while also specifying a given delay between the flattened codebooks representation, allowing to +unroll the codebooks in the sequence.

+

Example

+
    +
  1. Flattening of the codebooks. +By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q), +taking n_q = 3 and timesteps = 4: +[[1, 2, 3, 4], +[1, 2, 3, 4], +[1, 2, 3, 4]] +will result into: +[[S, S, 1, S, S, 2, S, S, 3, S, S, 4], +[S, 1, S, S, 2, S, S, 3, S, S, 4, S], +[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
  2. +
  3. Partial flattening of the codebooks. The flattening parameter allows to specify the inner step +for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example +taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]: +[[1, 2, 3, 4], +[1, 2, 3, 4], +[1, 2, 3, 4]] +will result into: +[[S, 1, S, S, 2, S, S, 3, S, S, 4, S], +[S, 1, S, S, 2, S, S, 3, S, S, 4, S], +[1, S, S, 2, S, S, 3, S, S, 4, S, S]]
  4. +
  5. Flattening with delay. The delay parameter allows to further unroll the sequence of codebooks +allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the +same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1] +and delays = [0, 3, 3]: +[[1, 2, 3, 4], +[1, 2, 3, 4], +[1, 2, 3, 4]] +will result into: +[[S, S, S, 1, S, 2, S, 3, S, 4], +[S, S, S, 1, S, 2, S, 3, S, 4], +[1, 2, 3, S, 4, S, 5, S, 6, S]]
  6. +
+

Args

+
+
n_q : int
+
Number of codebooks.
+
flattening : Optional[List[int]]
+
Flattening schema over the codebooks. If not defined, +the codebooks will be flattened to 1 codebook per step, meaning that the sequence will +have n_q extra steps for each timestep.
+
delays : Optional[List[int]]
+
Delay for each of the codebooks. If not defined, +no delay is added and therefore will default to [0] * n_q. +Note that two codebooks that will be flattened to the same inner step +should have the same delay, otherwise the pattern is considered as invalid.
+
+
+ +Expand source code + +
class UnrolledPatternProvider(CodebooksPatternProvider):
+    """Provider for unrolling codebooks pattern.
+    This pattern provider enables to represent the codebook flattened completely or only to some extend
+    while also specifying a given delay between the flattened codebooks representation, allowing to
+    unroll the codebooks in the sequence.
+
+    Example:
+        1. Flattening of the codebooks.
+        By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
+        taking n_q = 3 and timesteps = 4:
+        [[1, 2, 3, 4],
+         [1, 2, 3, 4],
+         [1, 2, 3, 4]]
+        will result into:
+        [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
+         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
+        2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
+        for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
+        taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
+        [[1, 2, 3, 4],
+         [1, 2, 3, 4],
+         [1, 2, 3, 4]]
+        will result into:
+        [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
+        3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
+        allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
+        same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
+        and delays = [0, 3, 3]:
+        [[1, 2, 3, 4],
+         [1, 2, 3, 4],
+         [1, 2, 3, 4]]
+        will result into:
+        [[S, S, S, 1, S, 2, S, 3, S, 4],
+         [S, S, S, 1, S, 2, S, 3, S, 4],
+         [1, 2, 3, S, 4, S, 5, S, 6, S]]
+
+    Args:
+        n_q (int): Number of codebooks.
+        flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined,
+            the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
+            have n_q extra steps for each timestep.
+        delays (Optional[List[int]]): Delay for each of the codebooks. If not defined,
+            no delay is added and therefore will default to [0] * ``n_q``.
+            Note that two codebooks that will be flattened to the same inner step
+            should have the same delay, otherwise the pattern is considered as invalid.
+    """
+    FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
+
+    def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
+                 delays: tp.Optional[tp.List[int]] = None):
+        super().__init__(n_q)
+        if flattening is None:
+            flattening = list(range(n_q))
+        if delays is None:
+            delays = [0] * n_q
+        assert len(flattening) == n_q
+        assert len(delays) == n_q
+        assert sorted(flattening) == flattening
+        assert sorted(delays) == delays
+        self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
+        self.max_delay = max(delays)
+
+    def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
+        """Build a flattened codebooks representation as a dictionary of inner step
+        and the actual codebook indices corresponding to the flattened codebook. For convenience, we
+        also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
+        """
+        flattened_codebooks: dict = {}
+        for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
+            if inner_step not in flattened_codebooks:
+                flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
+            else:
+                flat_codebook = flattened_codebooks[inner_step]
+                assert flat_codebook.delay == delay, (
+                    "Delay and flattening between codebooks is inconsistent: ",
+                    "two codebooks flattened to the same position should have the same delay."
+                )
+                flat_codebook.codebooks.append(q)
+            flattened_codebooks[inner_step] = flat_codebook
+        return flattened_codebooks
+
+    @property
+    def _num_inner_steps(self):
+        """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
+        """
+        return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
+
+    def num_virtual_steps(self, timesteps: int) -> int:
+        return timesteps * self._num_inner_steps + 1
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        """Builds pattern for delay across codebooks.
+
+        Args:
+            timesteps (int): Total numer of timesteps.
+        """
+        # the PatternLayout is built as a tuple of sequence position and list of coordinates
+        # so that it can be reordered properly given the required delay between codebooks of given timesteps
+        indexed_out: list = [(-1, [])]
+        max_timesteps = timesteps + self.max_delay
+        for t in range(max_timesteps):
+            # for each timestep, we unroll the flattened codebooks,
+            # emitting the sequence step with the corresponding delay
+            for step in range(self._num_inner_steps):
+                if step in self._flattened_codebooks:
+                    # we have codebooks at this virtual step to emit
+                    step_codebooks = self._flattened_codebooks[step]
+                    t_for_q = t + step_codebooks.delay
+                    coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
+                    if t_for_q < max_timesteps and t < max_timesteps:
+                        indexed_out.append((t_for_q, coords))
+                else:
+                    # there is no codebook in this virtual step so we emit an empty list
+                    indexed_out.append((t, []))
+        out = [coords for _, coords in sorted(indexed_out)]
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+

Ancestors

+ +

Class variables

+
+
var FlattenedCodebook
+
+
+
+
+

Methods

+
+
+def get_pattern(self, timesteps: int) ‑> Pattern +
+
+

Builds pattern for delay across codebooks.

+

Args

+
+
timesteps : int
+
Total numer of timesteps.
+
+
+ +Expand source code + +
def get_pattern(self, timesteps: int) -> Pattern:
+    """Builds pattern for delay across codebooks.
+
+    Args:
+        timesteps (int): Total numer of timesteps.
+    """
+    # the PatternLayout is built as a tuple of sequence position and list of coordinates
+    # so that it can be reordered properly given the required delay between codebooks of given timesteps
+    indexed_out: list = [(-1, [])]
+    max_timesteps = timesteps + self.max_delay
+    for t in range(max_timesteps):
+        # for each timestep, we unroll the flattened codebooks,
+        # emitting the sequence step with the corresponding delay
+        for step in range(self._num_inner_steps):
+            if step in self._flattened_codebooks:
+                # we have codebooks at this virtual step to emit
+                step_codebooks = self._flattened_codebooks[step]
+                t_for_q = t + step_codebooks.delay
+                coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
+                if t_for_q < max_timesteps and t < max_timesteps:
+                    indexed_out.append((t_for_q, coords))
+            else:
+                # there is no codebook in this virtual step so we emit an empty list
+                indexed_out.append((t, []))
+    out = [coords for _, coords in sorted(indexed_out)]
+    return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+
+def num_virtual_steps(self, timesteps: int) ‑> int +
+
+
+
+ +Expand source code + +
def num_virtual_steps(self, timesteps: int) -> int:
+    return timesteps * self._num_inner_steps + 1
+
+
+
+
+
+class VALLEPattern +(n_q: int, delays: Optional[List[int]] = None) +
+
+

Almost VALL-E style pattern. We futher allow some delays for the +codebooks other than the first one.

+

Args

+
+
n_q : int
+
Number of codebooks.
+
delays : Optional[List[int]]
+
Delay for each of the codebooks. +If delays not defined, each codebook is delayed by 1 compared to the previous one.
+
+
+ +Expand source code + +
class VALLEPattern(CodebooksPatternProvider):
+    """Almost VALL-E style pattern. We futher allow some delays for the
+    codebooks other than the first one.
+
+    Args:
+        n_q (int): Number of codebooks.
+        delays (Optional[List[int]]): Delay for each of the codebooks.
+            If delays not defined, each codebook is delayed by 1 compared to the previous one.
+    """
+    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
+        super().__init__(n_q)
+        if delays is None:
+            delays = [0] * (n_q - 1)
+        self.delays = delays
+        assert len(self.delays) == self.n_q - 1
+        assert sorted(self.delays) == self.delays
+
+    def get_pattern(self, timesteps: int) -> Pattern:
+        out: PatternLayout = [[]]
+        for t in range(timesteps):
+            out.append([LayoutCoord(t, 0)])
+        max_delay = max(self.delays)
+        for t in range(timesteps + max_delay):
+            v = []
+            for q, delay in enumerate(self.delays):
+                t_for_q = t - delay
+                if t_for_q >= 0:
+                    v.append(LayoutCoord(t_for_q, q + 1))
+            out.append(v)
+        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+

Ancestors

+ +

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/modules/conditioners.html b/docs/audiocraft/modules/conditioners.html new file mode 100644 index 00000000..049caf6c --- /dev/null +++ b/docs/audiocraft/modules/conditioners.html @@ -0,0 +1,3573 @@ + + + + + + +audiocraft.modules.conditioners API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.conditioners

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+from copy import deepcopy
+from dataclasses import dataclass, field
+from itertools import chain
+import logging
+import math
+import random
+import re
+import typing as tp
+import warnings
+
+from einops import rearrange
+from num2words import num2words
+import spacy
+from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
+import torchaudio
+import torch
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
+
+from .streaming import StreamingModule
+from .transformer import create_sin_embedding
+from ..data.audio_dataset import SegmentInfo
+from ..utils.autocast import TorchAutocast
+from ..utils.utils import hash_trick, length_to_mask, collate
+
+
+logger = logging.getLogger(__name__)
+TextCondition = tp.Optional[str]  # a text condition can be a string or None (if doesn't exist)
+ConditionType = tp.Tuple[Tensor, Tensor]  # condition, mask
+
+
+class WavCondition(tp.NamedTuple):
+    wav: Tensor
+    length: Tensor
+    path: tp.List[tp.Optional[str]] = []
+
+
+def nullify_condition(condition: ConditionType, dim: int = 1):
+    """This function transforms an input condition to a null condition.
+    The way it is done by converting it to a single zero vector similarly
+    to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
+
+    Args:
+        condition (ConditionType): a tuple of condition and mask (tp.Tuple[Tensor, Tensor])
+        dim (int): the dimension that will be truncated (should be the time dimension)
+        WARNING!: dim should not be the batch dimension!
+    Returns:
+        ConditionType: a tuple of null condition and mask
+    """
+    assert dim != 0, "dim cannot be the batch dimension!"
+    assert type(condition) == tuple and \
+        type(condition[0]) == Tensor and \
+        type(condition[1]) == Tensor, "'nullify_condition' got an unexpected input type!"
+    cond, mask = condition
+    B = cond.shape[0]
+    last_dim = cond.dim() - 1
+    out = cond.transpose(dim, last_dim)
+    out = 0. * out[..., :1]
+    out = out.transpose(dim, last_dim)
+    mask = torch.zeros((B, 1), device=out.device).int()
+    assert cond.dim() == out.dim()
+    return out, mask
+
+
+def nullify_wav(wav: Tensor) -> WavCondition:
+    """Create a nullified WavCondition from a wav tensor with appropriate shape.
+
+    Args:
+        wav (Tensor): tensor of shape [B, T]
+    Returns:
+        WavCondition: wav condition with nullified wav.
+    """
+    null_wav, _ = nullify_condition((wav, torch.zeros_like(wav)), dim=wav.dim() - 1)
+    return WavCondition(
+        wav=null_wav,
+        length=torch.tensor([0] * wav.shape[0], device=wav.device),
+        path=['null_wav'] * wav.shape[0]
+    )
+
+
+@dataclass
+class ConditioningAttributes:
+    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
+    wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
+
+    def __getitem__(self, item):
+        return getattr(self, item)
+
+    @property
+    def text_attributes(self):
+        return self.text.keys()
+
+    @property
+    def wav_attributes(self):
+        return self.wav.keys()
+
+    @property
+    def attributes(self):
+        return {"text": self.text_attributes, "wav": self.wav_attributes}
+
+    def to_flat_dict(self):
+        return {
+            **{f"text.{k}": v for k, v in self.text.items()},
+            **{f"wav.{k}": v for k, v in self.wav.items()},
+        }
+
+    @classmethod
+    def from_flat_dict(cls, x):
+        out = cls()
+        for k, v in x.items():
+            kind, att = k.split(".")
+            out[kind][att] = v
+        return out
+
+
+class SegmentWithAttributes(SegmentInfo):
+    """Base class for all dataclasses that are used for conditioning.
+    All child classes should implement `to_condition_attributes` that converts
+    the existing attributes to a dataclass of type ConditioningAttributes.
+    """
+    def to_condition_attributes(self) -> ConditioningAttributes:
+        raise NotImplementedError()
+
+
+class Tokenizer:
+    """Base class for all tokenizers
+    (in case we want to introduce more advances tokenizers in the future).
+    """
+    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]:
+        raise NotImplementedError()
+
+
+class WhiteSpaceTokenizer(Tokenizer):
+    """This tokenizer should be used for natural language descriptions.
+    For example:
+    ["he didn't, know he's going home.", 'shorter sentence'] =>
+    [[78, 62, 31,  4, 78, 25, 19, 34],
+    [59, 77,  0,  0,  0,  0,  0,  0]]
+    """
+    PUNCTUATIONS = "?:!.,;"
+
+    def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
+                 lemma: bool = True, stopwords: bool = True) -> None:
+        self.n_bins = n_bins
+        self.pad_idx = pad_idx
+        self.lemma = lemma
+        self.stopwords = stopwords
+        try:
+            self.nlp = spacy.load(language)
+        except IOError:
+            spacy.cli.download(language)  # type: ignore
+            self.nlp = spacy.load(language)
+
+    @tp.no_type_check
+    def __call__(
+        self,
+        texts: tp.List[tp.Optional[str]],
+        return_text: bool = False
+    ) -> tp.Tuple[Tensor, Tensor]:
+        """Take a list of strings and convert them to a tensor of indices.
+
+        Args:
+            texts (tp.List[str]): List of strings.
+            return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
+        Returns:
+            tp.Tuple[Tensor, Tensor]:
+                - Indices of words in the LUT.
+                - And a mask indicating where the padding tokens are
+        """
+        output, lengths = [], []
+        texts = deepcopy(texts)
+        for i, text in enumerate(texts):
+            # if current sample doesn't have a certain attribute, replace with pad token
+            if text is None:
+                output.append(Tensor([self.pad_idx]))
+                lengths.append(0)
+                continue
+
+            # convert numbers to words
+            text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text)  # type: ignore
+            # normalize text
+            text = self.nlp(text)  # type: ignore
+            # remove stopwords
+            if self.stopwords:
+                text = [w for w in text if not w.is_stop]  # type: ignore
+            # remove punctuations
+            text = [w for w in text if w.text not in self.PUNCTUATIONS]  # type: ignore
+            # lemmatize if needed
+            text = [getattr(t, "lemma_" if self.lemma else "text") for t in text]  # type: ignore
+
+            texts[i] = " ".join(text)
+            lengths.append(len(text))
+            # convert to tensor
+            tokens = Tensor([hash_trick(w, self.n_bins) for w in text])
+            output.append(tokens)
+
+        mask = length_to_mask(torch.IntTensor(lengths)).int()
+        padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
+        if return_text:
+            return padded_output, mask, texts  # type: ignore
+        return padded_output, mask
+
+
+class NoopTokenizer(Tokenizer):
+    """This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
+    The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
+    strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will
+    split it to ["Jeff", "Buckley"] and return an index per word.
+
+    For example:
+    ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
+    ["Metal", "Rock", "Classical"] => [0, 223, 51]
+    """
+    def __init__(self, n_bins: int, pad_idx: int = 0):
+        self.n_bins = n_bins
+        self.pad_idx = pad_idx
+
+    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]:
+        output, lengths = [], []
+        for text in texts:
+            # if current sample doesn't have a certain attribute, replace with pad token
+            if text is None:
+                output.append(self.pad_idx)
+                lengths.append(0)
+            else:
+                output.append(hash_trick(text, self.n_bins))
+                lengths.append(1)
+
+        tokens = torch.LongTensor(output).unsqueeze(1)
+        mask = length_to_mask(torch.IntTensor(lengths)).int()
+        return tokens, mask
+
+
+class BaseConditioner(nn.Module):
+    """Base model for all conditioner modules. We allow the output dim to be different
+    than the hidden dim for two reasons: 1) keep our LUTs small when the vocab is large;
+    2) make all condition dims consistent.
+
+    Args:
+        dim (int): Hidden dim of the model (text-encoder/LUT).
+        output_dim (int): Output dim of the conditioner.
+    """
+    def __init__(self, dim, output_dim):
+        super().__init__()
+        self.dim = dim
+        self.output_dim = output_dim
+        self.output_proj = nn.Linear(dim, output_dim)
+
+    def tokenize(self, *args, **kwargs) -> tp.Any:
+        """Should be any part of the processing that will lead to a synchronization
+        point, e.g. BPE tokenization with transfer to the GPU.
+
+        The returned value will be saved and return later when calling forward().
+        """
+        raise NotImplementedError()
+
+    def forward(self, inputs: tp.Any) -> ConditionType:
+        """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
+        Outputs a ConditionType, after the input data was embedded as a dense vector.
+
+        Returns:
+            ConditionType:
+                - A tensor of size [B, T, D] where B is the batch size, T is the length of the
+                  output embedding and D is the dimension of the embedding.
+                - And a mask indicating where the padding tokens.
+        """
+        raise NotImplementedError()
+
+
+class TextConditioner(BaseConditioner):
+    ...
+
+
+class LUTConditioner(TextConditioner):
+    """Lookup table TextConditioner.
+
+    Args:
+        n_bins (int): Number of bins.
+        dim (int): Hidden dim of the model (text-encoder/LUT).
+        output_dim (int): Output dim of the conditioner.
+        tokenizer (str): Name of the tokenizer.
+        pad_idx (int, optional): Index for padding token. Defaults to 0.
+    """
+    def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
+        super().__init__(dim, output_dim)
+        self.embed = nn.Embedding(n_bins, dim)
+        self.tokenizer: Tokenizer
+        if tokenizer == "whitespace":
+            self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
+        elif tokenizer == "noop":
+            self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
+        else:
+            raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
+
+    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        device = self.embed.weight.device
+        tokens, mask = self.tokenizer(x)
+        tokens, mask = tokens.to(device), mask.to(device)
+        return tokens, mask
+
+    def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
+        tokens, mask = inputs
+        embeds = self.embed(tokens)
+        embeds = self.output_proj(embeds)
+        embeds = (embeds * mask.unsqueeze(-1))
+        return embeds, mask
+
+
+class T5Conditioner(TextConditioner):
+    """T5-based TextConditioner.
+
+    Args:
+        name (str): Name of the T5 model.
+        output_dim (int): Output dim of the conditioner.
+        finetune (bool): Whether to fine-tune T5 at train time.
+        device (str): Device for T5 Conditioner.
+        autocast_dtype (tp.Optional[str], optional): Autocast dtype.
+        word_dropout (float, optional): Word dropout probability.
+        normalize_text (bool, optional): Whether to apply text normalization.
+    """
+    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
+              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
+              "google/flan-t5-xl", "google/flan-t5-xxl"]
+    MODELS_DIMS = {
+        "t5-small": 512,
+        "t5-base": 768,
+        "t5-large": 1024,
+        "t5-3b": 1024,
+        "t5-11b": 1024,
+        "google/flan-t5-small": 512,
+        "google/flan-t5-base": 768,
+        "google/flan-t5-large": 1024,
+        "google/flan-t5-3b": 1024,
+        "google/flan-t5-11b": 1024,
+    }
+
+    def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
+                 autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
+                 normalize_text: bool = False):
+        assert name in self.MODELS, f"unrecognized t5 model name (should in {self.MODELS})"
+        super().__init__(self.MODELS_DIMS[name], output_dim)
+        self.device = device
+        self.name = name
+        self.finetune = finetune
+        self.word_dropout = word_dropout
+
+        if autocast_dtype is None or self.device == 'cpu':
+            self.autocast = TorchAutocast(enabled=False)
+            if self.device != 'cpu':
+                logger.warning("T5 has no autocast, this might lead to NaN")
+        else:
+            dtype = getattr(torch, autocast_dtype)
+            assert isinstance(dtype, torch.dtype)
+            logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
+            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
+        # Let's disable logging temporarily because T5 will vomit some errors otherwise.
+        # thanks https://gist.github.com/simon-weber/7853144
+        previous_level = logging.root.manager.disable
+        logging.disable(logging.ERROR)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            try:
+                self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
+                t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
+            finally:
+                logging.disable(previous_level)
+        if finetune:
+            self.t5 = t5
+        else:
+            # this makes sure that the t5 models is not part
+            # of the saved checkpoint
+            self.__dict__["t5"] = t5.to(device)
+
+        self.normalize_text = normalize_text
+        if normalize_text:
+            self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)
+
+    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
+        # if current sample doesn't have a certain attribute, replace with empty string
+        entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
+        if self.normalize_text:
+            _, _, entries = self.text_normalizer(entries, return_text=True)
+        if self.word_dropout > 0. and self.training:
+            new_entries = []
+            for entry in entries:
+                words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
+                new_entries.append(" ".join(words))
+            entries = new_entries
+
+        empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
+
+        inputs = self.t5_tokenizer(entries, return_tensors="pt", padding=True).to(self.device)
+        mask = inputs["attention_mask"]
+        mask[empty_idx, :] = 0  # zero-out index where the input is non-existant
+        return inputs
+
+    def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
+        mask = inputs["attention_mask"]
+        with torch.set_grad_enabled(self.finetune), self.autocast:
+            embeds = self.t5(**inputs).last_hidden_state
+        embeds = self.output_proj(embeds.to(self.output_proj.weight))
+        embeds = (embeds * mask.unsqueeze(-1))
+        return embeds, mask
+
+
+class WaveformConditioner(BaseConditioner):
+    """Base class for all conditioners that take a waveform as input.
+    Classes that inherit must implement `_get_wav_embedding` that outputs
+    a continuous tensor, and `_downsampling_factor` that returns the down-sampling
+    factor of the embedding model.
+
+    Args:
+        dim (int): The internal representation dimension.
+        output_dim (int): Output dimension.
+        device (tp.Union[torch.device, str]): Device.
+    """
+    def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
+        super().__init__(dim, output_dim)
+        self.device = device
+
+    def tokenize(self, wav_length: WavCondition) -> WavCondition:
+        wav, length, path = wav_length
+        assert length is not None
+        return WavCondition(wav.to(self.device), length.to(self.device), path)
+
+    def _get_wav_embedding(self, wav: Tensor) -> Tensor:
+        """Gets as input a wav and returns a dense vector of conditions."""
+        raise NotImplementedError()
+
+    def _downsampling_factor(self):
+        """Returns the downsampling factor of the embedding model."""
+        raise NotImplementedError()
+
+    def forward(self, inputs: WavCondition) -> ConditionType:
+        """
+        Args:
+            input (WavCondition): Tuple of (waveform, lengths).
+        Returns:
+            ConditionType: Dense vector representing the conditioning along with its' mask.
+        """
+        wav, lengths, path = inputs
+        with torch.no_grad():
+            embeds = self._get_wav_embedding(wav)
+        embeds = embeds.to(self.output_proj.weight)
+        embeds = self.output_proj(embeds)
+
+        if lengths is not None:
+            lengths = lengths / self._downsampling_factor()
+            mask = length_to_mask(lengths, max_len=embeds.shape[1]).int()  # type: ignore
+        else:
+            mask = torch.ones_like(embeds)
+        embeds = (embeds * mask.unsqueeze(2).to(self.device))
+
+        return embeds, mask
+
+
+class ChromaStemConditioner(WaveformConditioner):
+    """Chroma conditioner that uses DEMUCS to first filter out drums and bass. The is followed by
+    the insight the drums and bass often dominate the chroma, leading to the chroma not containing the
+    information about melody.
+
+    Args:
+        output_dim (int): Output dimension for the conditioner.
+        sample_rate (int): Sample rate for the chroma extractor.
+        n_chroma (int): Number of chroma for the chroma extractor.
+        radix2_exp (int): Radix2 exponent for the chroma extractor.
+        duration (float): Duration used during training. This is later used for correct padding
+            in case we are using chroma as prefix.
+        match_len_on_eval (bool, optional): If True then all chromas are padded to the training
+            duration. Defaults to False.
+        eval_wavs (str, optional): Path to a json egg with waveform, this waveforms are used as
+            conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
+            Defaults to None.
+        n_eval_wavs (int, optional): Limits the number of waveforms used for conditioning. Defaults to 0.
+        device (tp.Union[torch.device, str], optional): Device for the conditioner.
+        **kwargs: Additional parameters for the chroma extractor.
+    """
+    def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
+                 duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
+                 n_eval_wavs: int = 0, device: tp.Union[torch.device, str] = "cpu", **kwargs):
+        from demucs import pretrained
+        super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
+        self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32)
+        self.sample_rate = sample_rate
+        self.match_len_on_eval = match_len_on_eval
+        self.duration = duration
+        self.__dict__["demucs"] = pretrained.get_model('htdemucs').to(device)
+        self.stem2idx = {'drums': 0, 'bass': 1, 'other': 2, 'vocal': 3}
+        self.stem_idx = torch.LongTensor([self.stem2idx['vocal'], self.stem2idx['other']]).to(device)
+        self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma, radix2_exp=radix2_exp,
+                                      device=device, **kwargs)
+        self.chroma_len = self._get_chroma_len()
+
+    def _downsampling_factor(self):
+        return self.chroma.winhop
+
+    def _get_chroma_len(self):
+        """Get length of chroma during training"""
+        dummy_wav = torch.zeros((1, self.sample_rate * self.duration), device=self.device)
+        dummy_chr = self.chroma(dummy_wav)
+        return dummy_chr.shape[1]
+
+    @torch.no_grad()
+    def _get_filtered_wav(self, wav):
+        from demucs.apply import apply_model
+        from demucs.audio import convert_audio
+        with self.autocast:
+            wav = convert_audio(wav, self.sample_rate, self.demucs.samplerate, self.demucs.audio_channels)
+            stems = apply_model(self.demucs, wav, device=self.device)
+            stems = stems[:, self.stem_idx]  # extract stem
+            stems = stems.sum(1)  # merge extracted stems
+            stems = stems.mean(1, keepdim=True)  # mono
+            stems = convert_audio(stems, self.demucs.samplerate, self.sample_rate, 1)
+            return stems
+
+    @torch.no_grad()
+    def _get_wav_embedding(self, wav):
+        # avoid 0-size tensors when we are working with null conds
+        if wav.shape[-1] == 1:
+            return self.chroma(wav)
+        stems = self._get_filtered_wav(wav)
+        chroma = self.chroma(stems)
+
+        if self.match_len_on_eval:
+            b, t, c = chroma.shape
+            if t > self.chroma_len:
+                chroma = chroma[:, :self.chroma_len]
+                logger.debug(f'chroma was truncated! ({t} -> {chroma.shape[1]})')
+            elif t < self.chroma_len:
+                # chroma = F.pad(chroma, (0, 0, 0, self.chroma_len - t))
+                n_repeat = int(math.ceil(self.chroma_len / t))
+                chroma = chroma.repeat(1, n_repeat, 1)
+                chroma = chroma[:, :self.chroma_len]
+                logger.debug(f'chroma was zero-padded! ({t} -> {chroma.shape[1]})')
+        return chroma
+
+
+class ChromaExtractor(nn.Module):
+    """Chroma extraction class, handles chroma extraction and quantization.
+
+    Args:
+        sample_rate (int): Sample rate.
+        n_chroma (int): Number of chroma to consider.
+        radix2_exp (int): Radix2 exponent.
+        nfft (tp.Optional[int], optional): Number of FFT.
+        winlen (tp.Optional[int], optional): Window length.
+        winhop (tp.Optional[int], optional): Window hop size.
+        argmax (bool, optional): Whether to use argmax. Defaults to False.
+        norm (float, optional): Norm for chroma normalization. Defaults to inf.
+        device (tp.Union[torch.device, str], optional): Device to use. Defaults to cpu.
+    """
+    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12,
+                 nfft: tp.Optional[int] = None, winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None,
+                 argmax: bool = False, norm: float = torch.inf, device: tp.Union[torch.device, str] = "cpu"):
+        super().__init__()
+        from librosa import filters
+        self.device = device
+        self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32)
+        self.winlen = winlen or 2 ** radix2_exp
+        self.nfft = nfft or self.winlen
+        self.winhop = winhop or (self.winlen // 4)
+        self.sr = sample_rate
+        self.n_chroma = n_chroma
+        self.norm = norm
+        self.argmax = argmax
+        self.window = torch.hann_window(self.winlen).to(device)
+        self.fbanks = torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
+                                                      n_chroma=self.n_chroma)).to(device)
+        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
+                                                      hop_length=self.winhop, power=2, center=True,
+                                                      pad=0, normalized=True).to(device)
+
+    def forward(self, wav):
+        with self.autocast:
+            T = wav.shape[-1]
+            # in case we are getting a wav that was dropped out (nullified)
+            # make sure wav length is no less that nfft
+            if T < self.nfft:
+                pad = self.nfft - T
+                r = 0 if pad % 2 == 0 else 1
+                wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
+                assert wav.shape[-1] == self.nfft, f'expected len {self.nfft} but got {wav.shape[-1]}'
+            spec = self.spec(wav).squeeze(1)
+            raw_chroma = torch.einsum("cf,...ft->...ct", self.fbanks, spec)
+            norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
+            norm_chroma = rearrange(norm_chroma, "b d t -> b t d")
+
+            if self.argmax:
+                idx = norm_chroma.argmax(-1, keepdims=True)
+                norm_chroma[:] = 0
+                norm_chroma.scatter_(dim=-1, index=idx, value=1)
+
+            return norm_chroma
+
+
+def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str):
+    """Utility function for nullifying an attribute inside an ConditioningAttributes object.
+    If the condition is of type "wav", then nullify it using "nullify_condition".
+    If the condition is of any other type, set its' value to None.
+    Works in-place.
+    """
+    if condition_type not in ["text", "wav"]:
+        raise ValueError(
+            "dropout_condition got an unexpected condition type!"
+            f" expected 'wav' or 'text' but got '{condition_type}'"
+        )
+
+    if condition not in getattr(sample, condition_type):
+        raise ValueError(
+            "dropout_condition received an unexpected condition!"
+            f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
+            f"but got '{condition}' of type '{condition_type}'!"
+        )
+
+    if condition_type == "wav":
+        wav, length, path = sample.wav[condition]
+        sample.wav[condition] = nullify_wav(wav)
+    else:
+        sample.text[condition] = None
+
+    return sample
+
+
+class DropoutModule(nn.Module):
+    """Base class for all dropout modules."""
+    def __init__(self, seed: int = 1234):
+        super().__init__()
+        self.rng = torch.Generator()
+        self.rng.manual_seed(seed)
+
+
+class AttributeDropout(DropoutModule):
+    """Applies dropout with a given probability per attribute. This is different from the behavior of
+    ClassifierFreeGuidanceDropout as this allows for attributes to be dropped out separately. For example,
+    "artist" can be dropped while "genre" remains. This is in contrast to ClassifierFreeGuidanceDropout
+    where if "artist" is dropped "genre" must also be dropped.
+
+    Args:
+        p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
+            ...
+            "genre": 0.1,
+            "artist": 0.5,
+            "wav": 0.25,
+            ...
+        active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
+        seed (int, optional): Random seed.
+    """
+    def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
+        super().__init__(seed=seed)
+        self.active_on_eval = active_on_eval
+        # construct dict that return the values from p otherwise 0
+        self.p = {}
+        for condition_type, probs in p.items():
+            self.p[condition_type] = defaultdict(lambda: 0, probs)
+
+    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+        """
+        Args:
+            samples (tp.List[ConditioningAttributes]): List of conditions.
+        Returns:
+            tp.List[ConditioningAttributes]: List of conditions after certain attributes were set to None.
+        """
+        if not self.training and not self.active_on_eval:
+            return samples
+
+        samples = deepcopy(samples)
+
+        for condition_type, ps in self.p.items():  # for condition types [text, wav]
+            for condition, p in ps.items():  # for attributes of each type (e.g., [artist, genre])
+                if torch.rand(1, generator=self.rng).item() < p:
+                    for sample in samples:
+                        dropout_condition(sample, condition_type, condition)
+
+        return samples
+
+    def __repr__(self):
+        return f"AttributeDropout({dict(self.p)})"
+
+
+class ClassifierFreeGuidanceDropout(DropoutModule):
+    """Applies Classifier Free Guidance dropout, meaning all attributes
+    are dropped with the same probability.
+
+    Args:
+        p (float): Probability to apply condition dropout during training.
+        seed (int): Random seed.
+    """
+    def __init__(self, p: float, seed: int = 1234):
+        super().__init__(seed=seed)
+        self.p = p
+
+    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+        """
+        Args:
+            samples (tp.List[ConditioningAttributes]): List of conditions.
+        Returns:
+            tp.List[ConditioningAttributes]: List of conditions after all attributes were set to None.
+        """
+        if not self.training:
+            return samples
+
+        # decide on which attributes to drop in a batched fashion
+        drop = torch.rand(1, generator=self.rng).item() < self.p
+        if not drop:
+            return samples
+
+        # nullify conditions of all attributes
+        samples = deepcopy(samples)
+
+        for condition_type in ["wav", "text"]:
+            for sample in samples:
+                for condition in sample.attributes[condition_type]:
+                    dropout_condition(sample, condition_type, condition)
+
+        return samples
+
+    def __repr__(self):
+        return f"ClassifierFreeGuidanceDropout(p={self.p})"
+
+
+class ConditioningProvider(nn.Module):
+    """Main class to provide conditions given all the supported conditioners.
+
+    Args:
+        conditioners (dict): Dictionary of conditioners.
+        merge_text_conditions_p (float, optional): Probability to merge all text sources
+            into a single text condition. Defaults to 0.
+        drop_desc_p (float, optional): Probability to drop the original description
+            when merging all text sources into a single text condition. Defaults to 0.
+        device (tp.Union[torch.device, str], optional): Device for conditioners and output condition types.
+    """
+    def __init__(
+        self,
+        conditioners: tp.Dict[str, BaseConditioner],
+        merge_text_conditions_p: float = 0,
+        drop_desc_p: float = 0,
+        device: tp.Union[torch.device, str] = "cpu",
+    ):
+        super().__init__()
+        self.device = device
+        self.merge_text_conditions_p = merge_text_conditions_p
+        self.drop_desc_p = drop_desc_p
+        self.conditioners = nn.ModuleDict(conditioners)
+
+    @property
+    def text_conditions(self):
+        return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
+
+    @property
+    def wav_conditions(self):
+        return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
+
+    @property
+    def has_wav_condition(self):
+        return len(self.wav_conditions) > 0
+
+    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
+        """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
+        This should be called before starting any real GPU work to avoid synchronization points.
+        This will return a dict matching conditioner names to their arbitrary tokenized representations.
+
+        Args:
+            inputs (list[ConditioningAttribres]): List of ConditioningAttributes objects containing
+                text and wav conditions.
+        """
+        assert all([type(x) == ConditioningAttributes for x in inputs]), \
+            "got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]" \
+            f" but types were {set([type(x) for x in inputs])}"
+
+        output = {}
+        text = self._collate_text(inputs)
+        wavs = self._collate_wavs(inputs)
+
+        assert set(text.keys() | wavs.keys()).issubset(set(self.conditioners.keys())), \
+            f"got an unexpected attribute! Expected {self.conditioners.keys()}, got {text.keys(), wavs.keys()}"
+
+        for attribute, batch in chain(text.items(), wavs.items()):
+            output[attribute] = self.conditioners[attribute].tokenize(batch)
+        return output
+
+    def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
+        """Compute pairs of `(embedding, mask)` using the configured conditioners
+        and the tokenized representations. The output is for example:
+
+            {
+                "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
+                "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
+                ...
+            }
+
+        Args:
+            tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
+        """
+        output = {}
+        for attribute, inputs in tokenized.items():
+            condition, mask = self.conditioners[attribute](inputs)
+            output[attribute] = (condition, mask)
+        return output
+
+    def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
+        """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
+        are the attributes and the values are the aggregated input per attribute.
+        For example:
+        Input:
+        [
+            ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
+            ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
+        ]
+        Output:
+        {
+            "genre": ["Rock", "Hip-hop"],
+            "description": ["A rock song with a guitar solo", "A hip-hop verse"]
+        }
+        """
+        batch_per_attribute: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
+
+        def _merge_conds(cond, merge_text_conditions_p=0, drop_desc_p=0):
+            def is_valid(k, v):
+                k_valid = k in ['key', 'bpm', 'genre', 'moods', 'instrument']
+                v_valid = v is not None and isinstance(v, (int, float, str, list))
+                return k_valid and v_valid
+
+            def process_value(v):
+                if isinstance(v, (int, float, str)):
+                    return v
+                if isinstance(v, list):
+                    return ", ".join(v)
+                else:
+                    RuntimeError(f"unknown type for text value! ({type(v), v})")
+
+            desc = cond.text['description']
+            meta_data = ""
+            if random.uniform(0, 1) < merge_text_conditions_p:
+                meta_pairs = [f'{k}: {process_value(v)}' for k, v in cond.text.items() if is_valid(k, v)]
+                random.shuffle(meta_pairs)
+                meta_data = ". ".join(meta_pairs)
+                desc = desc if not random.uniform(0, 1) < drop_desc_p else None
+
+            if desc is None:
+                desc = meta_data if len(meta_data) > 1 else None
+            else:
+                desc = desc.rstrip('.') + ". " + meta_data
+            cond.text['description'] = desc.strip() if desc else None
+
+        if self.training and self.merge_text_conditions_p:
+            for sample in samples:
+                _merge_conds(sample, self.merge_text_conditions_p, self.drop_desc_p)
+
+        texts = [x.text for x in samples]
+        for text in texts:
+            for condition in self.text_conditions:
+                batch_per_attribute[condition].append(text[condition])
+
+        return batch_per_attribute
+
+    def _collate_wavs(self, samples: tp.List[ConditioningAttributes]):
+        """Generate a dict where the keys are attributes by which we fetch similar wavs,
+        and the values are Tensors of wavs according to said attribtues.
+
+        *Note*: by the time the samples reach this function, each sample should have some waveform
+        inside the "wav" attribute. It should be either:
+        1. A real waveform
+        2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
+        3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
+
+        Args:
+            samples (tp.List[ConditioningAttributes]): List of ConditioningAttributes samples.
+        Returns:
+            dict: A dicionary mapping an attribute name to wavs.
+        """
+        wavs = defaultdict(list)
+        lens = defaultdict(list)
+        paths = defaultdict(list)
+        out = {}
+
+        for sample in samples:
+            for attribute in self.wav_conditions:
+                wav, length, path = sample.wav[attribute]
+                wavs[attribute].append(wav.flatten())
+                lens[attribute].append(length)
+                paths[attribute].append(path)
+
+        # stack all wavs to a single tensor
+        for attribute in self.wav_conditions:
+            stacked_wav, _ = collate(wavs[attribute], dim=0)
+            out[attribute] = WavCondition(stacked_wav.unsqueeze(1),
+                                          torch.cat(lens['self_wav']), paths[attribute])  # type: ignore
+
+        return out
+
+
+class ConditionFuser(StreamingModule):
+    """Condition fuser handles the logic to combine the different conditions
+    to the actual model input.
+
+    Args:
+        fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
+            each condition. For example:
+            {
+                "prepend": ["description"],
+                "sum": ["genre", "bpm"],
+                "cross": ["description"],
+            }
+        cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
+        cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
+    """
+    FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
+
+    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
+                 cross_attention_pos_emb_scale: float = 1.0):
+        super().__init__()
+        assert all(
+            [k in self.FUSING_METHODS for k in fuse2cond.keys()]
+        ), f"got invalid fuse method, allowed methods: {self.FUSING_MEHTODS}"
+        self.cross_attention_pos_emb = cross_attention_pos_emb
+        self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
+        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
+        self.cond2fuse: tp.Dict[str, str] = {}
+        for fuse_method, conditions in fuse2cond.items():
+            for condition in conditions:
+                self.cond2fuse[condition] = fuse_method
+
+    def forward(
+        self,
+        input: Tensor,
+        conditions: tp.Dict[str, ConditionType]
+    ) -> tp.Tuple[Tensor, tp.Optional[Tensor]]:
+        """Fuse the conditions to the provided model input.
+
+        Args:
+            input (Tensor): Transformer input.
+            conditions (tp.Dict[str, ConditionType]): Dict of conditions.
+        Returns:
+            tp.Tuple[Tensor, Tensor]: The first tensor is the transformer input
+                after the conditions have been fused. The second output tensor is the tensor
+                used for cross-attention or None if no cross attention inputs exist.
+        """
+        B, T, _ = input.shape
+
+        if 'offsets' in self._streaming_state:
+            first_step = False
+            offsets = self._streaming_state['offsets']
+        else:
+            first_step = True
+            offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
+
+        assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
+            f"given conditions contain unknown attributes for fuser, " \
+            f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
+        cross_attention_output = None
+        for cond_type, (cond, cond_mask) in conditions.items():
+            op = self.cond2fuse[cond_type]
+            if op == "sum":
+                input += cond
+            elif op == "input_interpolate":
+                cond = rearrange(cond, "b t d -> b d t")
+                cond = F.interpolate(cond, size=input.shape[1])
+                input += rearrange(cond, "b d t -> b t d")
+            elif op == "prepend":
+                if first_step:
+                    input = torch.cat([cond, input], dim=1)
+            elif op == "cross":
+                if cross_attention_output is not None:
+                    cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
+                else:
+                    cross_attention_output = cond
+            else:
+                raise ValueError(f"unknown op ({op})")
+
+        if self.cross_attention_pos_emb and cross_attention_output is not None:
+            positions = torch.arange(
+                cross_attention_output.shape[1],
+                device=cross_attention_output.device
+            ).view(1, -1, 1)
+            pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
+            cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
+
+        if self._is_streaming:
+            self._streaming_state['offsets'] = offsets + T
+
+        return input, cross_attention_output
+
+
+
+
+
+
+
+

Functions

+
+
+def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) +
+
+

Utility function for nullifying an attribute inside an ConditioningAttributes object. +If the condition is of type "wav", then nullify it using "nullify_condition". +If the condition is of any other type, set its' value to None. +Works in-place.

+
+ +Expand source code + +
def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str):
+    """Utility function for nullifying an attribute inside an ConditioningAttributes object.
+    If the condition is of type "wav", then nullify it using "nullify_condition".
+    If the condition is of any other type, set its' value to None.
+    Works in-place.
+    """
+    if condition_type not in ["text", "wav"]:
+        raise ValueError(
+            "dropout_condition got an unexpected condition type!"
+            f" expected 'wav' or 'text' but got '{condition_type}'"
+        )
+
+    if condition not in getattr(sample, condition_type):
+        raise ValueError(
+            "dropout_condition received an unexpected condition!"
+            f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
+            f"but got '{condition}' of type '{condition_type}'!"
+        )
+
+    if condition_type == "wav":
+        wav, length, path = sample.wav[condition]
+        sample.wav[condition] = nullify_wav(wav)
+    else:
+        sample.text[condition] = None
+
+    return sample
+
+
+
+def nullify_condition(condition: Tuple[torch.Tensor, torch.Tensor], dim: int = 1) +
+
+

This function transforms an input condition to a null condition. +The way it is done by converting it to a single zero vector similarly +to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.

+

Args

+
+
condition : ConditionType
+
a tuple of condition and mask (tp.Tuple[Tensor, Tensor])
+
dim : int
+
the dimension that will be truncated (should be the time dimension)
+
+

WARNING!: dim should not be the batch dimension!

+

Returns

+
+
ConditionType
+
a tuple of null condition and mask
+
+
+ +Expand source code + +
def nullify_condition(condition: ConditionType, dim: int = 1):
+    """This function transforms an input condition to a null condition.
+    The way it is done by converting it to a single zero vector similarly
+    to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
+
+    Args:
+        condition (ConditionType): a tuple of condition and mask (tp.Tuple[Tensor, Tensor])
+        dim (int): the dimension that will be truncated (should be the time dimension)
+        WARNING!: dim should not be the batch dimension!
+    Returns:
+        ConditionType: a tuple of null condition and mask
+    """
+    assert dim != 0, "dim cannot be the batch dimension!"
+    assert type(condition) == tuple and \
+        type(condition[0]) == Tensor and \
+        type(condition[1]) == Tensor, "'nullify_condition' got an unexpected input type!"
+    cond, mask = condition
+    B = cond.shape[0]
+    last_dim = cond.dim() - 1
+    out = cond.transpose(dim, last_dim)
+    out = 0. * out[..., :1]
+    out = out.transpose(dim, last_dim)
+    mask = torch.zeros((B, 1), device=out.device).int()
+    assert cond.dim() == out.dim()
+    return out, mask
+
+
+
+def nullify_wav(wav: torch.Tensor) ‑> WavCondition +
+
+

Create a nullified WavCondition from a wav tensor with appropriate shape.

+

Args

+
+
wav : Tensor
+
tensor of shape [B, T]
+
+

Returns

+
+
WavCondition
+
wav condition with nullified wav.
+
+
+ +Expand source code + +
def nullify_wav(wav: Tensor) -> WavCondition:
+    """Create a nullified WavCondition from a wav tensor with appropriate shape.
+
+    Args:
+        wav (Tensor): tensor of shape [B, T]
+    Returns:
+        WavCondition: wav condition with nullified wav.
+    """
+    null_wav, _ = nullify_condition((wav, torch.zeros_like(wav)), dim=wav.dim() - 1)
+    return WavCondition(
+        wav=null_wav,
+        length=torch.tensor([0] * wav.shape[0], device=wav.device),
+        path=['null_wav'] * wav.shape[0]
+    )
+
+
+
+
+
+

Classes

+
+
+class AttributeDropout +(p: Dict[str, Dict[str, float]], active_on_eval: bool = False, seed: int = 1234) +
+
+

Applies dropout with a given probability per attribute. This is different from the behavior of +ClassifierFreeGuidanceDropout as this allows for attributes to be dropped out separately. For example, +"artist" can be dropped while "genre" remains. This is in contrast to ClassifierFreeGuidanceDropout +where if "artist" is dropped "genre" must also be dropped.

+

Args

+
+
p : tp.Dict[str, float]
+
A dict mapping between attributes and dropout probability. For example: +… +"genre": 0.1, +"artist": 0.5, +"wav": 0.25, +…
+
active_on_eval : bool, optional
+
Whether the dropout is active at eval. Default to False.
+
seed : int, optional
+
Random seed.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class AttributeDropout(DropoutModule):
+    """Applies dropout with a given probability per attribute. This is different from the behavior of
+    ClassifierFreeGuidanceDropout as this allows for attributes to be dropped out separately. For example,
+    "artist" can be dropped while "genre" remains. This is in contrast to ClassifierFreeGuidanceDropout
+    where if "artist" is dropped "genre" must also be dropped.
+
+    Args:
+        p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
+            ...
+            "genre": 0.1,
+            "artist": 0.5,
+            "wav": 0.25,
+            ...
+        active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
+        seed (int, optional): Random seed.
+    """
+    def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
+        super().__init__(seed=seed)
+        self.active_on_eval = active_on_eval
+        # construct dict that return the values from p otherwise 0
+        self.p = {}
+        for condition_type, probs in p.items():
+            self.p[condition_type] = defaultdict(lambda: 0, probs)
+
+    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+        """
+        Args:
+            samples (tp.List[ConditioningAttributes]): List of conditions.
+        Returns:
+            tp.List[ConditioningAttributes]: List of conditions after certain attributes were set to None.
+        """
+        if not self.training and not self.active_on_eval:
+            return samples
+
+        samples = deepcopy(samples)
+
+        for condition_type, ps in self.p.items():  # for condition types [text, wav]
+            for condition, p in ps.items():  # for attributes of each type (e.g., [artist, genre])
+                if torch.rand(1, generator=self.rng).item() < p:
+                    for sample in samples:
+                        dropout_condition(sample, condition_type, condition)
+
+        return samples
+
+    def __repr__(self):
+        return f"AttributeDropout({dict(self.p)})"
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, samples: List[ConditioningAttributes]) ‑> List[ConditioningAttributes] +
+
+

Args

+
+
samples : tp.List[ConditioningAttributes]
+
List of conditions.
+
+

Returns

+
+
tp.List[ConditioningAttributes]
+
List of conditions after certain attributes were set to None.
+
+
+ +Expand source code + +
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+    """
+    Args:
+        samples (tp.List[ConditioningAttributes]): List of conditions.
+    Returns:
+        tp.List[ConditioningAttributes]: List of conditions after certain attributes were set to None.
+    """
+    if not self.training and not self.active_on_eval:
+        return samples
+
+    samples = deepcopy(samples)
+
+    for condition_type, ps in self.p.items():  # for condition types [text, wav]
+        for condition, p in ps.items():  # for attributes of each type (e.g., [artist, genre])
+            if torch.rand(1, generator=self.rng).item() < p:
+                for sample in samples:
+                    dropout_condition(sample, condition_type, condition)
+
+    return samples
+
+
+
+
+
+class BaseConditioner +(dim, output_dim) +
+
+

Base model for all conditioner modules. We allow the output dim to be different +than the hidden dim for two reasons: 1) keep our LUTs small when the vocab is large; +2) make all condition dims consistent.

+

Args

+
+
dim : int
+
Hidden dim of the model (text-encoder/LUT).
+
output_dim : int
+
Output dim of the conditioner.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class BaseConditioner(nn.Module):
+    """Base model for all conditioner modules. We allow the output dim to be different
+    than the hidden dim for two reasons: 1) keep our LUTs small when the vocab is large;
+    2) make all condition dims consistent.
+
+    Args:
+        dim (int): Hidden dim of the model (text-encoder/LUT).
+        output_dim (int): Output dim of the conditioner.
+    """
+    def __init__(self, dim, output_dim):
+        super().__init__()
+        self.dim = dim
+        self.output_dim = output_dim
+        self.output_proj = nn.Linear(dim, output_dim)
+
+    def tokenize(self, *args, **kwargs) -> tp.Any:
+        """Should be any part of the processing that will lead to a synchronization
+        point, e.g. BPE tokenization with transfer to the GPU.
+
+        The returned value will be saved and return later when calling forward().
+        """
+        raise NotImplementedError()
+
+    def forward(self, inputs: tp.Any) -> ConditionType:
+        """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
+        Outputs a ConditionType, after the input data was embedded as a dense vector.
+
+        Returns:
+            ConditionType:
+                - A tensor of size [B, T, D] where B is the batch size, T is the length of the
+                  output embedding and D is the dimension of the embedding.
+                - And a mask indicating where the padding tokens.
+        """
+        raise NotImplementedError()
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, inputs: Any) ‑> Tuple[torch.Tensor, torch.Tensor] +
+
+

Gets input that should be used as conditioning (e.g, genre, description or a waveform). +Outputs a ConditionType, after the input data was embedded as a dense vector.

+

Returns

+

ConditionType: +- A tensor of size [B, T, D] where B is the batch size, T is the length of the +output embedding and D is the dimension of the embedding. +- And a mask indicating where the padding tokens.

+
+ +Expand source code + +
def forward(self, inputs: tp.Any) -> ConditionType:
+    """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
+    Outputs a ConditionType, after the input data was embedded as a dense vector.
+
+    Returns:
+        ConditionType:
+            - A tensor of size [B, T, D] where B is the batch size, T is the length of the
+              output embedding and D is the dimension of the embedding.
+            - And a mask indicating where the padding tokens.
+    """
+    raise NotImplementedError()
+
+
+
+def tokenize(self, *args, **kwargs) ‑> Any +
+
+

Should be any part of the processing that will lead to a synchronization +point, e.g. BPE tokenization with transfer to the GPU.

+

The returned value will be saved and return later when calling forward().

+
+ +Expand source code + +
def tokenize(self, *args, **kwargs) -> tp.Any:
+    """Should be any part of the processing that will lead to a synchronization
+    point, e.g. BPE tokenization with transfer to the GPU.
+
+    The returned value will be saved and return later when calling forward().
+    """
+    raise NotImplementedError()
+
+
+
+
+
+class ChromaExtractor +(sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: Optional[int] = None, winlen: Optional[int] = None, winhop: Optional[int] = None, argmax: bool = False, norm: float = inf, device: Union[torch.device, str] = 'cpu') +
+
+

Chroma extraction class, handles chroma extraction and quantization.

+

Args

+
+
sample_rate : int
+
Sample rate.
+
n_chroma : int
+
Number of chroma to consider.
+
radix2_exp : int
+
Radix2 exponent.
+
nfft : tp.Optional[int], optional
+
Number of FFT.
+
winlen : tp.Optional[int], optional
+
Window length.
+
winhop : tp.Optional[int], optional
+
Window hop size.
+
argmax : bool, optional
+
Whether to use argmax. Defaults to False.
+
norm : float, optional
+
Norm for chroma normalization. Defaults to inf.
+
device : tp.Union[torch.device, str], optional
+
Device to use. Defaults to cpu.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ChromaExtractor(nn.Module):
+    """Chroma extraction class, handles chroma extraction and quantization.
+
+    Args:
+        sample_rate (int): Sample rate.
+        n_chroma (int): Number of chroma to consider.
+        radix2_exp (int): Radix2 exponent.
+        nfft (tp.Optional[int], optional): Number of FFT.
+        winlen (tp.Optional[int], optional): Window length.
+        winhop (tp.Optional[int], optional): Window hop size.
+        argmax (bool, optional): Whether to use argmax. Defaults to False.
+        norm (float, optional): Norm for chroma normalization. Defaults to inf.
+        device (tp.Union[torch.device, str], optional): Device to use. Defaults to cpu.
+    """
+    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12,
+                 nfft: tp.Optional[int] = None, winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None,
+                 argmax: bool = False, norm: float = torch.inf, device: tp.Union[torch.device, str] = "cpu"):
+        super().__init__()
+        from librosa import filters
+        self.device = device
+        self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32)
+        self.winlen = winlen or 2 ** radix2_exp
+        self.nfft = nfft or self.winlen
+        self.winhop = winhop or (self.winlen // 4)
+        self.sr = sample_rate
+        self.n_chroma = n_chroma
+        self.norm = norm
+        self.argmax = argmax
+        self.window = torch.hann_window(self.winlen).to(device)
+        self.fbanks = torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
+                                                      n_chroma=self.n_chroma)).to(device)
+        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
+                                                      hop_length=self.winhop, power=2, center=True,
+                                                      pad=0, normalized=True).to(device)
+
+    def forward(self, wav):
+        with self.autocast:
+            T = wav.shape[-1]
+            # in case we are getting a wav that was dropped out (nullified)
+            # make sure wav length is no less that nfft
+            if T < self.nfft:
+                pad = self.nfft - T
+                r = 0 if pad % 2 == 0 else 1
+                wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
+                assert wav.shape[-1] == self.nfft, f'expected len {self.nfft} but got {wav.shape[-1]}'
+            spec = self.spec(wav).squeeze(1)
+            raw_chroma = torch.einsum("cf,...ft->...ct", self.fbanks, spec)
+            norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
+            norm_chroma = rearrange(norm_chroma, "b d t -> b t d")
+
+            if self.argmax:
+                idx = norm_chroma.argmax(-1, keepdims=True)
+                norm_chroma[:] = 0
+                norm_chroma.scatter_(dim=-1, index=idx, value=1)
+
+            return norm_chroma
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, wav) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, wav):
+    with self.autocast:
+        T = wav.shape[-1]
+        # in case we are getting a wav that was dropped out (nullified)
+        # make sure wav length is no less that nfft
+        if T < self.nfft:
+            pad = self.nfft - T
+            r = 0 if pad % 2 == 0 else 1
+            wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
+            assert wav.shape[-1] == self.nfft, f'expected len {self.nfft} but got {wav.shape[-1]}'
+        spec = self.spec(wav).squeeze(1)
+        raw_chroma = torch.einsum("cf,...ft->...ct", self.fbanks, spec)
+        norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
+        norm_chroma = rearrange(norm_chroma, "b d t -> b t d")
+
+        if self.argmax:
+            idx = norm_chroma.argmax(-1, keepdims=True)
+            norm_chroma[:] = 0
+            norm_chroma.scatter_(dim=-1, index=idx, value=1)
+
+        return norm_chroma
+
+
+
+
+
+class ChromaStemConditioner +(output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int, duration: float, match_len_on_eval: bool = True, eval_wavs: Optional[str] = None, n_eval_wavs: int = 0, device: Union[torch.device, str] = 'cpu', **kwargs) +
+
+

Chroma conditioner that uses DEMUCS to first filter out drums and bass. The is followed by +the insight the drums and bass often dominate the chroma, leading to the chroma not containing the +information about melody.

+

Args

+
+
output_dim : int
+
Output dimension for the conditioner.
+
sample_rate : int
+
Sample rate for the chroma extractor.
+
n_chroma : int
+
Number of chroma for the chroma extractor.
+
radix2_exp : int
+
Radix2 exponent for the chroma extractor.
+
duration : float
+
Duration used during training. This is later used for correct padding +in case we are using chroma as prefix.
+
match_len_on_eval : bool, optional
+
If True then all chromas are padded to the training +duration. Defaults to False.
+
eval_wavs : str, optional
+
Path to a json egg with waveform, this waveforms are used as +conditions during eval (for cases where we don't want to leak test conditions like MusicCaps). +Defaults to None.
+
n_eval_wavs : int, optional
+
Limits the number of waveforms used for conditioning. Defaults to 0.
+
device : tp.Union[torch.device, str], optional
+
Device for the conditioner.
+
**kwargs
+
Additional parameters for the chroma extractor.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ChromaStemConditioner(WaveformConditioner):
+    """Chroma conditioner that uses DEMUCS to first filter out drums and bass. The is followed by
+    the insight the drums and bass often dominate the chroma, leading to the chroma not containing the
+    information about melody.
+
+    Args:
+        output_dim (int): Output dimension for the conditioner.
+        sample_rate (int): Sample rate for the chroma extractor.
+        n_chroma (int): Number of chroma for the chroma extractor.
+        radix2_exp (int): Radix2 exponent for the chroma extractor.
+        duration (float): Duration used during training. This is later used for correct padding
+            in case we are using chroma as prefix.
+        match_len_on_eval (bool, optional): If True then all chromas are padded to the training
+            duration. Defaults to False.
+        eval_wavs (str, optional): Path to a json egg with waveform, this waveforms are used as
+            conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
+            Defaults to None.
+        n_eval_wavs (int, optional): Limits the number of waveforms used for conditioning. Defaults to 0.
+        device (tp.Union[torch.device, str], optional): Device for the conditioner.
+        **kwargs: Additional parameters for the chroma extractor.
+    """
+    def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
+                 duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
+                 n_eval_wavs: int = 0, device: tp.Union[torch.device, str] = "cpu", **kwargs):
+        from demucs import pretrained
+        super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
+        self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32)
+        self.sample_rate = sample_rate
+        self.match_len_on_eval = match_len_on_eval
+        self.duration = duration
+        self.__dict__["demucs"] = pretrained.get_model('htdemucs').to(device)
+        self.stem2idx = {'drums': 0, 'bass': 1, 'other': 2, 'vocal': 3}
+        self.stem_idx = torch.LongTensor([self.stem2idx['vocal'], self.stem2idx['other']]).to(device)
+        self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma, radix2_exp=radix2_exp,
+                                      device=device, **kwargs)
+        self.chroma_len = self._get_chroma_len()
+
+    def _downsampling_factor(self):
+        return self.chroma.winhop
+
+    def _get_chroma_len(self):
+        """Get length of chroma during training"""
+        dummy_wav = torch.zeros((1, self.sample_rate * self.duration), device=self.device)
+        dummy_chr = self.chroma(dummy_wav)
+        return dummy_chr.shape[1]
+
+    @torch.no_grad()
+    def _get_filtered_wav(self, wav):
+        from demucs.apply import apply_model
+        from demucs.audio import convert_audio
+        with self.autocast:
+            wav = convert_audio(wav, self.sample_rate, self.demucs.samplerate, self.demucs.audio_channels)
+            stems = apply_model(self.demucs, wav, device=self.device)
+            stems = stems[:, self.stem_idx]  # extract stem
+            stems = stems.sum(1)  # merge extracted stems
+            stems = stems.mean(1, keepdim=True)  # mono
+            stems = convert_audio(stems, self.demucs.samplerate, self.sample_rate, 1)
+            return stems
+
+    @torch.no_grad()
+    def _get_wav_embedding(self, wav):
+        # avoid 0-size tensors when we are working with null conds
+        if wav.shape[-1] == 1:
+            return self.chroma(wav)
+        stems = self._get_filtered_wav(wav)
+        chroma = self.chroma(stems)
+
+        if self.match_len_on_eval:
+            b, t, c = chroma.shape
+            if t > self.chroma_len:
+                chroma = chroma[:, :self.chroma_len]
+                logger.debug(f'chroma was truncated! ({t} -> {chroma.shape[1]})')
+            elif t < self.chroma_len:
+                # chroma = F.pad(chroma, (0, 0, 0, self.chroma_len - t))
+                n_repeat = int(math.ceil(self.chroma_len / t))
+                chroma = chroma.repeat(1, n_repeat, 1)
+                chroma = chroma[:, :self.chroma_len]
+                logger.debug(f'chroma was zero-padded! ({t} -> {chroma.shape[1]})')
+        return chroma
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class ClassifierFreeGuidanceDropout +(p: float, seed: int = 1234) +
+
+

Applies Classifier Free Guidance dropout, meaning all attributes +are dropped with the same probability.

+

Args

+
+
p : float
+
Probability to apply condition dropout during training.
+
seed : int
+
Random seed.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ClassifierFreeGuidanceDropout(DropoutModule):
+    """Applies Classifier Free Guidance dropout, meaning all attributes
+    are dropped with the same probability.
+
+    Args:
+        p (float): Probability to apply condition dropout during training.
+        seed (int): Random seed.
+    """
+    def __init__(self, p: float, seed: int = 1234):
+        super().__init__(seed=seed)
+        self.p = p
+
+    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+        """
+        Args:
+            samples (tp.List[ConditioningAttributes]): List of conditions.
+        Returns:
+            tp.List[ConditioningAttributes]: List of conditions after all attributes were set to None.
+        """
+        if not self.training:
+            return samples
+
+        # decide on which attributes to drop in a batched fashion
+        drop = torch.rand(1, generator=self.rng).item() < self.p
+        if not drop:
+            return samples
+
+        # nullify conditions of all attributes
+        samples = deepcopy(samples)
+
+        for condition_type in ["wav", "text"]:
+            for sample in samples:
+                for condition in sample.attributes[condition_type]:
+                    dropout_condition(sample, condition_type, condition)
+
+        return samples
+
+    def __repr__(self):
+        return f"ClassifierFreeGuidanceDropout(p={self.p})"
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, samples: List[ConditioningAttributes]) ‑> List[ConditioningAttributes] +
+
+

Args

+
+
samples : tp.List[ConditioningAttributes]
+
List of conditions.
+
+

Returns

+
+
tp.List[ConditioningAttributes]
+
List of conditions after all attributes were set to None.
+
+
+ +Expand source code + +
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+    """
+    Args:
+        samples (tp.List[ConditioningAttributes]): List of conditions.
+    Returns:
+        tp.List[ConditioningAttributes]: List of conditions after all attributes were set to None.
+    """
+    if not self.training:
+        return samples
+
+    # decide on which attributes to drop in a batched fashion
+    drop = torch.rand(1, generator=self.rng).item() < self.p
+    if not drop:
+        return samples
+
+    # nullify conditions of all attributes
+    samples = deepcopy(samples)
+
+    for condition_type in ["wav", "text"]:
+        for sample in samples:
+            for condition in sample.attributes[condition_type]:
+                dropout_condition(sample, condition_type, condition)
+
+    return samples
+
+
+
+
+
+class ConditionFuser +(fuse2cond: Dict[str, List[str]], cross_attention_pos_emb: bool = False, cross_attention_pos_emb_scale: float = 1.0) +
+
+

Condition fuser handles the logic to combine the different conditions +to the actual model input.

+

Args

+
+
fuse2cond : tp.Dict[str, str]
+
A dictionary that says how to fuse +each condition. For example: +{ +"prepend": ["description"], +"sum": ["genre", "bpm"], +"cross": ["description"], +}
+
cross_attention_pos_emb : bool, optional
+
Use positional embeddings in cross attention.
+
cross_attention_pos_emb_scale : int
+
Scale for positional embeddings in cross attention if used.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ConditionFuser(StreamingModule):
+    """Condition fuser handles the logic to combine the different conditions
+    to the actual model input.
+
+    Args:
+        fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
+            each condition. For example:
+            {
+                "prepend": ["description"],
+                "sum": ["genre", "bpm"],
+                "cross": ["description"],
+            }
+        cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
+        cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
+    """
+    FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
+
+    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
+                 cross_attention_pos_emb_scale: float = 1.0):
+        super().__init__()
+        assert all(
+            [k in self.FUSING_METHODS for k in fuse2cond.keys()]
+        ), f"got invalid fuse method, allowed methods: {self.FUSING_MEHTODS}"
+        self.cross_attention_pos_emb = cross_attention_pos_emb
+        self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
+        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
+        self.cond2fuse: tp.Dict[str, str] = {}
+        for fuse_method, conditions in fuse2cond.items():
+            for condition in conditions:
+                self.cond2fuse[condition] = fuse_method
+
+    def forward(
+        self,
+        input: Tensor,
+        conditions: tp.Dict[str, ConditionType]
+    ) -> tp.Tuple[Tensor, tp.Optional[Tensor]]:
+        """Fuse the conditions to the provided model input.
+
+        Args:
+            input (Tensor): Transformer input.
+            conditions (tp.Dict[str, ConditionType]): Dict of conditions.
+        Returns:
+            tp.Tuple[Tensor, Tensor]: The first tensor is the transformer input
+                after the conditions have been fused. The second output tensor is the tensor
+                used for cross-attention or None if no cross attention inputs exist.
+        """
+        B, T, _ = input.shape
+
+        if 'offsets' in self._streaming_state:
+            first_step = False
+            offsets = self._streaming_state['offsets']
+        else:
+            first_step = True
+            offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
+
+        assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
+            f"given conditions contain unknown attributes for fuser, " \
+            f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
+        cross_attention_output = None
+        for cond_type, (cond, cond_mask) in conditions.items():
+            op = self.cond2fuse[cond_type]
+            if op == "sum":
+                input += cond
+            elif op == "input_interpolate":
+                cond = rearrange(cond, "b t d -> b d t")
+                cond = F.interpolate(cond, size=input.shape[1])
+                input += rearrange(cond, "b d t -> b t d")
+            elif op == "prepend":
+                if first_step:
+                    input = torch.cat([cond, input], dim=1)
+            elif op == "cross":
+                if cross_attention_output is not None:
+                    cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
+                else:
+                    cross_attention_output = cond
+            else:
+                raise ValueError(f"unknown op ({op})")
+
+        if self.cross_attention_pos_emb and cross_attention_output is not None:
+            positions = torch.arange(
+                cross_attention_output.shape[1],
+                device=cross_attention_output.device
+            ).view(1, -1, 1)
+            pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
+            cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
+
+        if self._is_streaming:
+            self._streaming_state['offsets'] = offsets + T
+
+        return input, cross_attention_output
+
+

Ancestors

+ +

Class variables

+
+
var FUSING_METHODS
+
+
+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, input: torch.Tensor, conditions: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ‑> Tuple[torch.Tensor, Optional[torch.Tensor]] +
+
+

Fuse the conditions to the provided model input.

+

Args

+
+
input : Tensor
+
Transformer input.
+
conditions : tp.Dict[str, ConditionType]
+
Dict of conditions.
+
+

Returns

+
+
tp.Tuple[Tensor, Tensor]
+
The first tensor is the transformer input +after the conditions have been fused. The second output tensor is the tensor +used for cross-attention or None if no cross attention inputs exist.
+
+
+ +Expand source code + +
def forward(
+    self,
+    input: Tensor,
+    conditions: tp.Dict[str, ConditionType]
+) -> tp.Tuple[Tensor, tp.Optional[Tensor]]:
+    """Fuse the conditions to the provided model input.
+
+    Args:
+        input (Tensor): Transformer input.
+        conditions (tp.Dict[str, ConditionType]): Dict of conditions.
+    Returns:
+        tp.Tuple[Tensor, Tensor]: The first tensor is the transformer input
+            after the conditions have been fused. The second output tensor is the tensor
+            used for cross-attention or None if no cross attention inputs exist.
+    """
+    B, T, _ = input.shape
+
+    if 'offsets' in self._streaming_state:
+        first_step = False
+        offsets = self._streaming_state['offsets']
+    else:
+        first_step = True
+        offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
+
+    assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
+        f"given conditions contain unknown attributes for fuser, " \
+        f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
+    cross_attention_output = None
+    for cond_type, (cond, cond_mask) in conditions.items():
+        op = self.cond2fuse[cond_type]
+        if op == "sum":
+            input += cond
+        elif op == "input_interpolate":
+            cond = rearrange(cond, "b t d -> b d t")
+            cond = F.interpolate(cond, size=input.shape[1])
+            input += rearrange(cond, "b d t -> b t d")
+        elif op == "prepend":
+            if first_step:
+                input = torch.cat([cond, input], dim=1)
+        elif op == "cross":
+            if cross_attention_output is not None:
+                cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
+            else:
+                cross_attention_output = cond
+        else:
+            raise ValueError(f"unknown op ({op})")
+
+    if self.cross_attention_pos_emb and cross_attention_output is not None:
+        positions = torch.arange(
+            cross_attention_output.shape[1],
+            device=cross_attention_output.device
+        ).view(1, -1, 1)
+        pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
+        cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
+
+    if self._is_streaming:
+        self._streaming_state['offsets'] = offsets + T
+
+    return input, cross_attention_output
+
+
+
+

Inherited members

+ +
+
+class ConditioningAttributes +(text: Dict[str, Optional[str]] = <factory>, wav: Dict[str, WavCondition] = <factory>) +
+
+

ConditioningAttributes(text: Dict[str, Union[str, NoneType]] = , wav: Dict[str, audiocraft.modules.conditioners.WavCondition] = )

+
+ +Expand source code + +
class ConditioningAttributes:
+    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
+    wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
+
+    def __getitem__(self, item):
+        return getattr(self, item)
+
+    @property
+    def text_attributes(self):
+        return self.text.keys()
+
+    @property
+    def wav_attributes(self):
+        return self.wav.keys()
+
+    @property
+    def attributes(self):
+        return {"text": self.text_attributes, "wav": self.wav_attributes}
+
+    def to_flat_dict(self):
+        return {
+            **{f"text.{k}": v for k, v in self.text.items()},
+            **{f"wav.{k}": v for k, v in self.wav.items()},
+        }
+
+    @classmethod
+    def from_flat_dict(cls, x):
+        out = cls()
+        for k, v in x.items():
+            kind, att = k.split(".")
+            out[kind][att] = v
+        return out
+
+

Class variables

+
+
var text : Dict[str, Optional[str]]
+
+
+
+
var wav : Dict[str, WavCondition]
+
+
+
+
+

Static methods

+
+
+def from_flat_dict(x) +
+
+
+
+ +Expand source code + +
@classmethod
+def from_flat_dict(cls, x):
+    out = cls()
+    for k, v in x.items():
+        kind, att = k.split(".")
+        out[kind][att] = v
+    return out
+
+
+
+

Instance variables

+
+
var attributes
+
+
+
+ +Expand source code + +
@property
+def attributes(self):
+    return {"text": self.text_attributes, "wav": self.wav_attributes}
+
+
+
var text_attributes
+
+
+
+ +Expand source code + +
@property
+def text_attributes(self):
+    return self.text.keys()
+
+
+
var wav_attributes
+
+
+
+ +Expand source code + +
@property
+def wav_attributes(self):
+    return self.wav.keys()
+
+
+
+

Methods

+
+
+def to_flat_dict(self) +
+
+
+
+ +Expand source code + +
def to_flat_dict(self):
+    return {
+        **{f"text.{k}": v for k, v in self.text.items()},
+        **{f"wav.{k}": v for k, v in self.wav.items()},
+    }
+
+
+
+
+
+class ConditioningProvider +(conditioners: Dict[str, BaseConditioner], merge_text_conditions_p: float = 0, drop_desc_p: float = 0, device: Union[torch.device, str] = 'cpu') +
+
+

Main class to provide conditions given all the supported conditioners.

+

Args

+
+
conditioners : dict
+
Dictionary of conditioners.
+
merge_text_conditions_p : float, optional
+
Probability to merge all text sources +into a single text condition. Defaults to 0.
+
drop_desc_p : float, optional
+
Probability to drop the original description +when merging all text sources into a single text condition. Defaults to 0.
+
device : tp.Union[torch.device, str], optional
+
Device for conditioners and output condition types.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ConditioningProvider(nn.Module):
+    """Main class to provide conditions given all the supported conditioners.
+
+    Args:
+        conditioners (dict): Dictionary of conditioners.
+        merge_text_conditions_p (float, optional): Probability to merge all text sources
+            into a single text condition. Defaults to 0.
+        drop_desc_p (float, optional): Probability to drop the original description
+            when merging all text sources into a single text condition. Defaults to 0.
+        device (tp.Union[torch.device, str], optional): Device for conditioners and output condition types.
+    """
+    def __init__(
+        self,
+        conditioners: tp.Dict[str, BaseConditioner],
+        merge_text_conditions_p: float = 0,
+        drop_desc_p: float = 0,
+        device: tp.Union[torch.device, str] = "cpu",
+    ):
+        super().__init__()
+        self.device = device
+        self.merge_text_conditions_p = merge_text_conditions_p
+        self.drop_desc_p = drop_desc_p
+        self.conditioners = nn.ModuleDict(conditioners)
+
+    @property
+    def text_conditions(self):
+        return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
+
+    @property
+    def wav_conditions(self):
+        return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
+
+    @property
+    def has_wav_condition(self):
+        return len(self.wav_conditions) > 0
+
+    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
+        """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
+        This should be called before starting any real GPU work to avoid synchronization points.
+        This will return a dict matching conditioner names to their arbitrary tokenized representations.
+
+        Args:
+            inputs (list[ConditioningAttribres]): List of ConditioningAttributes objects containing
+                text and wav conditions.
+        """
+        assert all([type(x) == ConditioningAttributes for x in inputs]), \
+            "got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]" \
+            f" but types were {set([type(x) for x in inputs])}"
+
+        output = {}
+        text = self._collate_text(inputs)
+        wavs = self._collate_wavs(inputs)
+
+        assert set(text.keys() | wavs.keys()).issubset(set(self.conditioners.keys())), \
+            f"got an unexpected attribute! Expected {self.conditioners.keys()}, got {text.keys(), wavs.keys()}"
+
+        for attribute, batch in chain(text.items(), wavs.items()):
+            output[attribute] = self.conditioners[attribute].tokenize(batch)
+        return output
+
+    def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
+        """Compute pairs of `(embedding, mask)` using the configured conditioners
+        and the tokenized representations. The output is for example:
+
+            {
+                "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
+                "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
+                ...
+            }
+
+        Args:
+            tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
+        """
+        output = {}
+        for attribute, inputs in tokenized.items():
+            condition, mask = self.conditioners[attribute](inputs)
+            output[attribute] = (condition, mask)
+        return output
+
+    def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
+        """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
+        are the attributes and the values are the aggregated input per attribute.
+        For example:
+        Input:
+        [
+            ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
+            ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
+        ]
+        Output:
+        {
+            "genre": ["Rock", "Hip-hop"],
+            "description": ["A rock song with a guitar solo", "A hip-hop verse"]
+        }
+        """
+        batch_per_attribute: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
+
+        def _merge_conds(cond, merge_text_conditions_p=0, drop_desc_p=0):
+            def is_valid(k, v):
+                k_valid = k in ['key', 'bpm', 'genre', 'moods', 'instrument']
+                v_valid = v is not None and isinstance(v, (int, float, str, list))
+                return k_valid and v_valid
+
+            def process_value(v):
+                if isinstance(v, (int, float, str)):
+                    return v
+                if isinstance(v, list):
+                    return ", ".join(v)
+                else:
+                    RuntimeError(f"unknown type for text value! ({type(v), v})")
+
+            desc = cond.text['description']
+            meta_data = ""
+            if random.uniform(0, 1) < merge_text_conditions_p:
+                meta_pairs = [f'{k}: {process_value(v)}' for k, v in cond.text.items() if is_valid(k, v)]
+                random.shuffle(meta_pairs)
+                meta_data = ". ".join(meta_pairs)
+                desc = desc if not random.uniform(0, 1) < drop_desc_p else None
+
+            if desc is None:
+                desc = meta_data if len(meta_data) > 1 else None
+            else:
+                desc = desc.rstrip('.') + ". " + meta_data
+            cond.text['description'] = desc.strip() if desc else None
+
+        if self.training and self.merge_text_conditions_p:
+            for sample in samples:
+                _merge_conds(sample, self.merge_text_conditions_p, self.drop_desc_p)
+
+        texts = [x.text for x in samples]
+        for text in texts:
+            for condition in self.text_conditions:
+                batch_per_attribute[condition].append(text[condition])
+
+        return batch_per_attribute
+
+    def _collate_wavs(self, samples: tp.List[ConditioningAttributes]):
+        """Generate a dict where the keys are attributes by which we fetch similar wavs,
+        and the values are Tensors of wavs according to said attribtues.
+
+        *Note*: by the time the samples reach this function, each sample should have some waveform
+        inside the "wav" attribute. It should be either:
+        1. A real waveform
+        2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
+        3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
+
+        Args:
+            samples (tp.List[ConditioningAttributes]): List of ConditioningAttributes samples.
+        Returns:
+            dict: A dicionary mapping an attribute name to wavs.
+        """
+        wavs = defaultdict(list)
+        lens = defaultdict(list)
+        paths = defaultdict(list)
+        out = {}
+
+        for sample in samples:
+            for attribute in self.wav_conditions:
+                wav, length, path = sample.wav[attribute]
+                wavs[attribute].append(wav.flatten())
+                lens[attribute].append(length)
+                paths[attribute].append(path)
+
+        # stack all wavs to a single tensor
+        for attribute in self.wav_conditions:
+            stacked_wav, _ = collate(wavs[attribute], dim=0)
+            out[attribute] = WavCondition(stacked_wav.unsqueeze(1),
+                                          torch.cat(lens['self_wav']), paths[attribute])  # type: ignore
+
+        return out
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Instance variables

+
+
var has_wav_condition
+
+
+
+ +Expand source code + +
@property
+def has_wav_condition(self):
+    return len(self.wav_conditions) > 0
+
+
+
var text_conditions
+
+
+
+ +Expand source code + +
@property
+def text_conditions(self):
+    return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
+
+
+
var wav_conditions
+
+
+
+ +Expand source code + +
@property
+def wav_conditions(self):
+    return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
+
+
+
+

Methods

+
+
+def forward(self, tokenized: Dict[str, Any]) ‑> Dict[str, Tuple[torch.Tensor, torch.Tensor]] +
+
+

Compute pairs of (embedding, mask) using the configured conditioners +and the tokenized representations. The output is for example:

+
{
+    "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
+    "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
+    ...
+}
+
+

Args

+
+
tokenized : dict
+
Dict of tokenized representations as returned by tokenize().
+
+
+ +Expand source code + +
def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
+    """Compute pairs of `(embedding, mask)` using the configured conditioners
+    and the tokenized representations. The output is for example:
+
+        {
+            "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
+            "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
+            ...
+        }
+
+    Args:
+        tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
+    """
+    output = {}
+    for attribute, inputs in tokenized.items():
+        condition, mask = self.conditioners[attribute](inputs)
+        output[attribute] = (condition, mask)
+    return output
+
+
+
+def tokenize(self, inputs: List[ConditioningAttributes]) ‑> Dict[str, Any] +
+
+

Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly. +This should be called before starting any real GPU work to avoid synchronization points. +This will return a dict matching conditioner names to their arbitrary tokenized representations.

+

Args

+
+
inputs : list[ConditioningAttribres]
+
List of ConditioningAttributes objects containing +text and wav conditions.
+
+
+ +Expand source code + +
def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
+    """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
+    This should be called before starting any real GPU work to avoid synchronization points.
+    This will return a dict matching conditioner names to their arbitrary tokenized representations.
+
+    Args:
+        inputs (list[ConditioningAttribres]): List of ConditioningAttributes objects containing
+            text and wav conditions.
+    """
+    assert all([type(x) == ConditioningAttributes for x in inputs]), \
+        "got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]" \
+        f" but types were {set([type(x) for x in inputs])}"
+
+    output = {}
+    text = self._collate_text(inputs)
+    wavs = self._collate_wavs(inputs)
+
+    assert set(text.keys() | wavs.keys()).issubset(set(self.conditioners.keys())), \
+        f"got an unexpected attribute! Expected {self.conditioners.keys()}, got {text.keys(), wavs.keys()}"
+
+    for attribute, batch in chain(text.items(), wavs.items()):
+        output[attribute] = self.conditioners[attribute].tokenize(batch)
+    return output
+
+
+
+
+
+class DropoutModule +(seed: int = 1234) +
+
+

Base class for all dropout modules.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class DropoutModule(nn.Module):
+    """Base class for all dropout modules."""
+    def __init__(self, seed: int = 1234):
+        super().__init__()
+        self.rng = torch.Generator()
+        self.rng.manual_seed(seed)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, *input: Any) ‑> None +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def _forward_unimplemented(self, *input: Any) -> None:
+    r"""Defines the computation performed at every call.
+
+    Should be overridden by all subclasses.
+
+    .. note::
+        Although the recipe for forward pass needs to be defined within
+        this function, one should call the :class:`Module` instance afterwards
+        instead of this since the former takes care of running the
+        registered hooks while the latter silently ignores them.
+    """
+    raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
+
+
+
+
+
+class LUTConditioner +(n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0) +
+
+

Lookup table TextConditioner.

+

Args

+
+
n_bins : int
+
Number of bins.
+
dim : int
+
Hidden dim of the model (text-encoder/LUT).
+
output_dim : int
+
Output dim of the conditioner.
+
tokenizer : str
+
Name of the tokenizer.
+
pad_idx : int, optional
+
Index for padding token. Defaults to 0.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class LUTConditioner(TextConditioner):
+    """Lookup table TextConditioner.
+
+    Args:
+        n_bins (int): Number of bins.
+        dim (int): Hidden dim of the model (text-encoder/LUT).
+        output_dim (int): Output dim of the conditioner.
+        tokenizer (str): Name of the tokenizer.
+        pad_idx (int, optional): Index for padding token. Defaults to 0.
+    """
+    def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
+        super().__init__(dim, output_dim)
+        self.embed = nn.Embedding(n_bins, dim)
+        self.tokenizer: Tokenizer
+        if tokenizer == "whitespace":
+            self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
+        elif tokenizer == "noop":
+            self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
+        else:
+            raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
+
+    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        device = self.embed.weight.device
+        tokens, mask = self.tokenizer(x)
+        tokens, mask = tokens.to(device), mask.to(device)
+        return tokens, mask
+
+    def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
+        tokens, mask = inputs
+        embeds = self.embed(tokens)
+        embeds = self.output_proj(embeds)
+        embeds = (embeds * mask.unsqueeze(-1))
+        return embeds, mask
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class NoopTokenizer +(n_bins: int, pad_idx: int = 0) +
+
+

This tokenizer should be used for global conditioners such as: artist, genre, key, etc. +The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split +strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will +split it to ["Jeff", "Buckley"] and return an index per word.

+

For example: +["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101] +["Metal", "Rock", "Classical"] => [0, 223, 51]

+
+ +Expand source code + +
class NoopTokenizer(Tokenizer):
+    """This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
+    The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
+    strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will
+    split it to ["Jeff", "Buckley"] and return an index per word.
+
+    For example:
+    ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
+    ["Metal", "Rock", "Classical"] => [0, 223, 51]
+    """
+    def __init__(self, n_bins: int, pad_idx: int = 0):
+        self.n_bins = n_bins
+        self.pad_idx = pad_idx
+
+    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]:
+        output, lengths = [], []
+        for text in texts:
+            # if current sample doesn't have a certain attribute, replace with pad token
+            if text is None:
+                output.append(self.pad_idx)
+                lengths.append(0)
+            else:
+                output.append(hash_trick(text, self.n_bins))
+                lengths.append(1)
+
+        tokens = torch.LongTensor(output).unsqueeze(1)
+        mask = length_to_mask(torch.IntTensor(lengths)).int()
+        return tokens, mask
+
+

Ancestors

+ +
+
+class SegmentWithAttributes +(meta: AudioMeta, seek_time: float, n_frames: int, total_frames: int, sample_rate: int) +
+
+

Base class for all dataclasses that are used for conditioning. +All child classes should implement to_condition_attributes that converts +the existing attributes to a dataclass of type ConditioningAttributes.

+
+ +Expand source code + +
class SegmentWithAttributes(SegmentInfo):
+    """Base class for all dataclasses that are used for conditioning.
+    All child classes should implement `to_condition_attributes` that converts
+    the existing attributes to a dataclass of type ConditioningAttributes.
+    """
+    def to_condition_attributes(self) -> ConditioningAttributes:
+        raise NotImplementedError()
+
+

Ancestors

+ +

Class variables

+
+
var metaAudioMeta
+
+
+
+
var n_frames : int
+
+
+
+
var sample_rate : int
+
+
+
+
var seek_time : float
+
+
+
+
var total_frames : int
+
+
+
+
+

Methods

+
+
+def to_condition_attributes(self) ‑> ConditioningAttributes +
+
+
+
+ +Expand source code + +
def to_condition_attributes(self) -> ConditioningAttributes:
+    raise NotImplementedError()
+
+
+
+
+
+class T5Conditioner +(name: str, output_dim: int, finetune: bool, device: str, autocast_dtype: Optional[str] = 'float32', word_dropout: float = 0.0, normalize_text: bool = False) +
+
+

T5-based TextConditioner.

+

Args

+
+
name : str
+
Name of the T5 model.
+
output_dim : int
+
Output dim of the conditioner.
+
finetune : bool
+
Whether to fine-tune T5 at train time.
+
device : str
+
Device for T5 Conditioner.
+
autocast_dtype : tp.Optional[str], optional
+
Autocast dtype.
+
word_dropout : float, optional
+
Word dropout probability.
+
normalize_text : bool, optional
+
Whether to apply text normalization.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class T5Conditioner(TextConditioner):
+    """T5-based TextConditioner.
+
+    Args:
+        name (str): Name of the T5 model.
+        output_dim (int): Output dim of the conditioner.
+        finetune (bool): Whether to fine-tune T5 at train time.
+        device (str): Device for T5 Conditioner.
+        autocast_dtype (tp.Optional[str], optional): Autocast dtype.
+        word_dropout (float, optional): Word dropout probability.
+        normalize_text (bool, optional): Whether to apply text normalization.
+    """
+    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
+              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
+              "google/flan-t5-xl", "google/flan-t5-xxl"]
+    MODELS_DIMS = {
+        "t5-small": 512,
+        "t5-base": 768,
+        "t5-large": 1024,
+        "t5-3b": 1024,
+        "t5-11b": 1024,
+        "google/flan-t5-small": 512,
+        "google/flan-t5-base": 768,
+        "google/flan-t5-large": 1024,
+        "google/flan-t5-3b": 1024,
+        "google/flan-t5-11b": 1024,
+    }
+
+    def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
+                 autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
+                 normalize_text: bool = False):
+        assert name in self.MODELS, f"unrecognized t5 model name (should in {self.MODELS})"
+        super().__init__(self.MODELS_DIMS[name], output_dim)
+        self.device = device
+        self.name = name
+        self.finetune = finetune
+        self.word_dropout = word_dropout
+
+        if autocast_dtype is None or self.device == 'cpu':
+            self.autocast = TorchAutocast(enabled=False)
+            if self.device != 'cpu':
+                logger.warning("T5 has no autocast, this might lead to NaN")
+        else:
+            dtype = getattr(torch, autocast_dtype)
+            assert isinstance(dtype, torch.dtype)
+            logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
+            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
+        # Let's disable logging temporarily because T5 will vomit some errors otherwise.
+        # thanks https://gist.github.com/simon-weber/7853144
+        previous_level = logging.root.manager.disable
+        logging.disable(logging.ERROR)
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            try:
+                self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
+                t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
+            finally:
+                logging.disable(previous_level)
+        if finetune:
+            self.t5 = t5
+        else:
+            # this makes sure that the t5 models is not part
+            # of the saved checkpoint
+            self.__dict__["t5"] = t5.to(device)
+
+        self.normalize_text = normalize_text
+        if normalize_text:
+            self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)
+
+    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
+        # if current sample doesn't have a certain attribute, replace with empty string
+        entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
+        if self.normalize_text:
+            _, _, entries = self.text_normalizer(entries, return_text=True)
+        if self.word_dropout > 0. and self.training:
+            new_entries = []
+            for entry in entries:
+                words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
+                new_entries.append(" ".join(words))
+            entries = new_entries
+
+        empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
+
+        inputs = self.t5_tokenizer(entries, return_tensors="pt", padding=True).to(self.device)
+        mask = inputs["attention_mask"]
+        mask[empty_idx, :] = 0  # zero-out index where the input is non-existant
+        return inputs
+
+    def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
+        mask = inputs["attention_mask"]
+        with torch.set_grad_enabled(self.finetune), self.autocast:
+            embeds = self.t5(**inputs).last_hidden_state
+        embeds = self.output_proj(embeds.to(self.output_proj.weight))
+        embeds = (embeds * mask.unsqueeze(-1))
+        return embeds, mask
+
+

Ancestors

+ +

Class variables

+
+
var MODELS
+
+
+
+
var MODELS_DIMS
+
+
+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class TextConditioner +(dim, output_dim) +
+
+

Base model for all conditioner modules. We allow the output dim to be different +than the hidden dim for two reasons: 1) keep our LUTs small when the vocab is large; +2) make all condition dims consistent.

+

Args

+
+
dim : int
+
Hidden dim of the model (text-encoder/LUT).
+
output_dim : int
+
Output dim of the conditioner.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class TextConditioner(BaseConditioner):
+    ...
+
+

Ancestors

+ +

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class Tokenizer +
+
+

Base class for all tokenizers +(in case we want to introduce more advances tokenizers in the future).

+
+ +Expand source code + +
class Tokenizer:
+    """Base class for all tokenizers
+    (in case we want to introduce more advances tokenizers in the future).
+    """
+    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]:
+        raise NotImplementedError()
+
+

Subclasses

+ +
+
+class WavCondition +(wav: torch.Tensor, length: torch.Tensor, path: List[Optional[str]] = []) +
+
+

WavCondition(wav, length, path)

+
+ +Expand source code + +
class WavCondition(tp.NamedTuple):
+    wav: Tensor
+    length: Tensor
+    path: tp.List[tp.Optional[str]] = []
+
+

Ancestors

+
    +
  • builtins.tuple
  • +
+

Instance variables

+
+
var length : torch.Tensor
+
+

Alias for field number 1

+
+
var path : List[Optional[str]]
+
+

Alias for field number 2

+
+
var wav : torch.Tensor
+
+

Alias for field number 0

+
+
+
+
+class WaveformConditioner +(dim: int, output_dim: int, device: Union[torch.device, str]) +
+
+

Base class for all conditioners that take a waveform as input. +Classes that inherit must implement _get_wav_embedding that outputs +a continuous tensor, and _downsampling_factor that returns the down-sampling +factor of the embedding model.

+

Args

+
+
dim : int
+
The internal representation dimension.
+
output_dim : int
+
Output dimension.
+
device : tp.Union[torch.device, str]
+
Device.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class WaveformConditioner(BaseConditioner):
+    """Base class for all conditioners that take a waveform as input.
+    Classes that inherit must implement `_get_wav_embedding` that outputs
+    a continuous tensor, and `_downsampling_factor` that returns the down-sampling
+    factor of the embedding model.
+
+    Args:
+        dim (int): The internal representation dimension.
+        output_dim (int): Output dimension.
+        device (tp.Union[torch.device, str]): Device.
+    """
+    def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
+        super().__init__(dim, output_dim)
+        self.device = device
+
+    def tokenize(self, wav_length: WavCondition) -> WavCondition:
+        wav, length, path = wav_length
+        assert length is not None
+        return WavCondition(wav.to(self.device), length.to(self.device), path)
+
+    def _get_wav_embedding(self, wav: Tensor) -> Tensor:
+        """Gets as input a wav and returns a dense vector of conditions."""
+        raise NotImplementedError()
+
+    def _downsampling_factor(self):
+        """Returns the downsampling factor of the embedding model."""
+        raise NotImplementedError()
+
+    def forward(self, inputs: WavCondition) -> ConditionType:
+        """
+        Args:
+            input (WavCondition): Tuple of (waveform, lengths).
+        Returns:
+            ConditionType: Dense vector representing the conditioning along with its' mask.
+        """
+        wav, lengths, path = inputs
+        with torch.no_grad():
+            embeds = self._get_wav_embedding(wav)
+        embeds = embeds.to(self.output_proj.weight)
+        embeds = self.output_proj(embeds)
+
+        if lengths is not None:
+            lengths = lengths / self._downsampling_factor()
+            mask = length_to_mask(lengths, max_len=embeds.shape[1]).int()  # type: ignore
+        else:
+            mask = torch.ones_like(embeds)
+        embeds = (embeds * mask.unsqueeze(2).to(self.device))
+
+        return embeds, mask
+
+

Ancestors

+ +

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, inputs: WavCondition) ‑> Tuple[torch.Tensor, torch.Tensor] +
+
+

Args

+
+
input : WavCondition
+
Tuple of (waveform, lengths).
+
+

Returns

+
+
ConditionType
+
Dense vector representing the conditioning along with its' mask.
+
+
+ +Expand source code + +
def forward(self, inputs: WavCondition) -> ConditionType:
+    """
+    Args:
+        input (WavCondition): Tuple of (waveform, lengths).
+    Returns:
+        ConditionType: Dense vector representing the conditioning along with its' mask.
+    """
+    wav, lengths, path = inputs
+    with torch.no_grad():
+        embeds = self._get_wav_embedding(wav)
+    embeds = embeds.to(self.output_proj.weight)
+    embeds = self.output_proj(embeds)
+
+    if lengths is not None:
+        lengths = lengths / self._downsampling_factor()
+        mask = length_to_mask(lengths, max_len=embeds.shape[1]).int()  # type: ignore
+    else:
+        mask = torch.ones_like(embeds)
+    embeds = (embeds * mask.unsqueeze(2).to(self.device))
+
+    return embeds, mask
+
+
+
+

Inherited members

+ +
+
+class WhiteSpaceTokenizer +(n_bins: int, pad_idx: int = 0, language: str = 'en_core_web_sm', lemma: bool = True, stopwords: bool = True) +
+
+

This tokenizer should be used for natural language descriptions. +For example: +["he didn't, know he's going home.", 'shorter sentence'] => +[[78, 62, 31, +4, 78, 25, 19, 34], +[59, 77, +0, +0, +0, +0, +0, +0]]

+
+ +Expand source code + +
class WhiteSpaceTokenizer(Tokenizer):
+    """This tokenizer should be used for natural language descriptions.
+    For example:
+    ["he didn't, know he's going home.", 'shorter sentence'] =>
+    [[78, 62, 31,  4, 78, 25, 19, 34],
+    [59, 77,  0,  0,  0,  0,  0,  0]]
+    """
+    PUNCTUATIONS = "?:!.,;"
+
+    def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
+                 lemma: bool = True, stopwords: bool = True) -> None:
+        self.n_bins = n_bins
+        self.pad_idx = pad_idx
+        self.lemma = lemma
+        self.stopwords = stopwords
+        try:
+            self.nlp = spacy.load(language)
+        except IOError:
+            spacy.cli.download(language)  # type: ignore
+            self.nlp = spacy.load(language)
+
+    @tp.no_type_check
+    def __call__(
+        self,
+        texts: tp.List[tp.Optional[str]],
+        return_text: bool = False
+    ) -> tp.Tuple[Tensor, Tensor]:
+        """Take a list of strings and convert them to a tensor of indices.
+
+        Args:
+            texts (tp.List[str]): List of strings.
+            return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
+        Returns:
+            tp.Tuple[Tensor, Tensor]:
+                - Indices of words in the LUT.
+                - And a mask indicating where the padding tokens are
+        """
+        output, lengths = [], []
+        texts = deepcopy(texts)
+        for i, text in enumerate(texts):
+            # if current sample doesn't have a certain attribute, replace with pad token
+            if text is None:
+                output.append(Tensor([self.pad_idx]))
+                lengths.append(0)
+                continue
+
+            # convert numbers to words
+            text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text)  # type: ignore
+            # normalize text
+            text = self.nlp(text)  # type: ignore
+            # remove stopwords
+            if self.stopwords:
+                text = [w for w in text if not w.is_stop]  # type: ignore
+            # remove punctuations
+            text = [w for w in text if w.text not in self.PUNCTUATIONS]  # type: ignore
+            # lemmatize if needed
+            text = [getattr(t, "lemma_" if self.lemma else "text") for t in text]  # type: ignore
+
+            texts[i] = " ".join(text)
+            lengths.append(len(text))
+            # convert to tensor
+            tokens = Tensor([hash_trick(w, self.n_bins) for w in text])
+            output.append(tokens)
+
+        mask = length_to_mask(torch.IntTensor(lengths)).int()
+        padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
+        if return_text:
+            return padded_output, mask, texts  # type: ignore
+        return padded_output, mask
+
+

Ancestors

+ +

Class variables

+
+
var PUNCTUATIONS
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/modules/conv.html b/docs/audiocraft/modules/conv.html new file mode 100644 index 00000000..0c6281f0 --- /dev/null +++ b/docs/audiocraft/modules/conv.html @@ -0,0 +1,1048 @@ + + + + + + +audiocraft.modules.conv API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.conv

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import typing as tp
+import warnings
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm, weight_norm
+
+
+CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
+                                 'time_group_norm'])
+
+
+def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
+    assert norm in CONV_NORMALIZATIONS
+    if norm == 'weight_norm':
+        return weight_norm(module)
+    elif norm == 'spectral_norm':
+        return spectral_norm(module)
+    else:
+        # We already check was in CONV_NORMALIZATION, so any other choice
+        # doesn't need reparametrization.
+        return module
+
+
+def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
+    """Return the proper normalization module. If causal is True, this will ensure the returned
+    module is causal, or return an error if the normalization doesn't support causal evaluation.
+    """
+    assert norm in CONV_NORMALIZATIONS
+    if norm == 'time_group_norm':
+        if causal:
+            raise ValueError("GroupNorm doesn't support causal evaluation.")
+        assert isinstance(module, nn.modules.conv._ConvNd)
+        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+    else:
+        return nn.Identity()
+
+
+def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
+                                 padding_total: int = 0) -> int:
+    """See `pad_for_conv1d`.
+    """
+    length = x.shape[-1]
+    n_frames = (length - kernel_size + padding_total) / stride + 1
+    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+    return ideal_length - length
+
+
+def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
+    """Pad for a convolution to make sure that the last window is full.
+    Extra padding is added at the end. This is required to ensure that we can rebuild
+    an output of the same length, as otherwise, even with padding, some time steps
+    might get removed.
+    For instance, with total padding = 4, kernel size = 4, stride = 2:
+        0 0 1 2 3 4 5 0 0   # (0s are padding)
+        1   2   3           # (output frames of a convolution, last 0 is never used)
+        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
+            1 2 3 4         # once you removed padding, we are missing one time step !
+    """
+    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+    return F.pad(x, (0, extra_padding))
+
+
+def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
+    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+    If this is the case, we insert extra 0 padding to the right before the reflection happen.
+    """
+    length = x.shape[-1]
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    if mode == 'reflect':
+        max_pad = max(padding_left, padding_right)
+        extra_pad = 0
+        if length <= max_pad:
+            extra_pad = max_pad - length + 1
+            x = F.pad(x, (0, extra_pad))
+        padded = F.pad(x, paddings, mode, value)
+        end = padded.shape[-1] - extra_pad
+        return padded[..., :end]
+    else:
+        return F.pad(x, paddings, mode, value)
+
+
+def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+    """Remove padding from x, handling properly zero padding. Only for 1d!
+    """
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    assert (padding_left + padding_right) <= x.shape[-1]
+    end = x.shape[-1] - padding_right
+    return x[..., padding_left: end]
+
+
+class NormConv1d(nn.Module):
+    """Wrapper around Conv1d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, causal: bool = False, norm: str = 'none',
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
+
+
+class NormConv2d(nn.Module):
+    """Wrapper around Conv2d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
+
+
+class NormConvTranspose1d(nn.Module):
+    """Wrapper around ConvTranspose1d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, causal: bool = False, norm: str = 'none',
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.convtr(x)
+        x = self.norm(x)
+        return x
+
+
+class NormConvTranspose2d(nn.Module):
+    """Wrapper around ConvTranspose2d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
+
+    def forward(self, x):
+        x = self.convtr(x)
+        x = self.norm(x)
+        return x
+
+
+class StreamableConv1d(nn.Module):
+    """Conv1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+    def __init__(self, in_channels: int, out_channels: int,
+                 kernel_size: int, stride: int = 1, dilation: int = 1,
+                 groups: int = 1, bias: bool = True, causal: bool = False,
+                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
+                 pad_mode: str = 'reflect'):
+        super().__init__()
+        # warn user on unusual setup between dilation and stride
+        if stride > 1 and dilation > 1:
+            warnings.warn('StreamableConv1d has been initialized with stride > 1 and dilation > 1'
+                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
+        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
+                               dilation=dilation, groups=groups, bias=bias, causal=causal,
+                               norm=norm, norm_kwargs=norm_kwargs)
+        self.causal = causal
+        self.pad_mode = pad_mode
+
+    def forward(self, x):
+        B, C, T = x.shape
+        kernel_size = self.conv.conv.kernel_size[0]
+        stride = self.conv.conv.stride[0]
+        dilation = self.conv.conv.dilation[0]
+        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
+        padding_total = kernel_size - stride
+        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+        if self.causal:
+            # Left padding for causal
+            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
+        return self.conv(x)
+
+
+class StreamableConvTranspose1d(nn.Module):
+    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+    def __init__(self, in_channels: int, out_channels: int,
+                 kernel_size: int, stride: int = 1, causal: bool = False,
+                 norm: str = 'none', trim_right_ratio: float = 1.,
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}):
+        super().__init__()
+        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
+                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
+        self.causal = causal
+        self.trim_right_ratio = trim_right_ratio
+        assert self.causal or self.trim_right_ratio == 1., \
+            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
+        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
+
+    def forward(self, x):
+        kernel_size = self.convtr.convtr.kernel_size[0]
+        stride = self.convtr.convtr.stride[0]
+        padding_total = kernel_size - stride
+
+        y = self.convtr(x)
+
+        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+        # removed at the very end, when keeping only the right length for the output,
+        # as removing it here would require also passing the length at the matching layer
+        # in the encoder.
+        if self.causal:
+            # Trim the padding on the right according to the specified ratio
+            # if trim_right_ratio = 1.0, trim everything from right
+            padding_right = math.ceil(padding_total * self.trim_right_ratio)
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        return y
+
+
+
+
+
+
+
+

Functions

+
+
+def apply_parametrization_norm(module: torch.nn.modules.module.Module, norm: str = 'none') +
+
+
+
+ +Expand source code + +
def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
+    assert norm in CONV_NORMALIZATIONS
+    if norm == 'weight_norm':
+        return weight_norm(module)
+    elif norm == 'spectral_norm':
+        return spectral_norm(module)
+    else:
+        # We already check was in CONV_NORMALIZATION, so any other choice
+        # doesn't need reparametrization.
+        return module
+
+
+
+def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0) ‑> int +
+
+ +
+ +Expand source code + +
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
+                                 padding_total: int = 0) -> int:
+    """See `pad_for_conv1d`.
+    """
+    length = x.shape[-1]
+    n_frames = (length - kernel_size + padding_total) / stride + 1
+    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+    return ideal_length - length
+
+
+
+def get_norm_module(module: torch.nn.modules.module.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) +
+
+

Return the proper normalization module. If causal is True, this will ensure the returned +module is causal, or return an error if the normalization doesn't support causal evaluation.

+
+ +Expand source code + +
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
+    """Return the proper normalization module. If causal is True, this will ensure the returned
+    module is causal, or return an error if the normalization doesn't support causal evaluation.
+    """
+    assert norm in CONV_NORMALIZATIONS
+    if norm == 'time_group_norm':
+        if causal:
+            raise ValueError("GroupNorm doesn't support causal evaluation.")
+        assert isinstance(module, nn.modules.conv._ConvNd)
+        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+    else:
+        return nn.Identity()
+
+
+
+def pad1d(x: torch.Tensor, paddings: Tuple[int, int], mode: str = 'constant', value: float = 0.0) +
+
+

Tiny wrapper around F.pad, just to allow for reflect padding on small input. +If this is the case, we insert extra 0 padding to the right before the reflection happen.

+
+ +Expand source code + +
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
+    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+    If this is the case, we insert extra 0 padding to the right before the reflection happen.
+    """
+    length = x.shape[-1]
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    if mode == 'reflect':
+        max_pad = max(padding_left, padding_right)
+        extra_pad = 0
+        if length <= max_pad:
+            extra_pad = max_pad - length + 1
+            x = F.pad(x, (0, extra_pad))
+        padded = F.pad(x, paddings, mode, value)
+        end = padded.shape[-1] - extra_pad
+        return padded[..., :end]
+    else:
+        return F.pad(x, paddings, mode, value)
+
+
+
+def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0) +
+
+

Pad for a convolution to make sure that the last window is full. +Extra padding is added at the end. This is required to ensure that we can rebuild +an output of the same length, as otherwise, even with padding, some time steps +might get removed. +For instance, with total padding = 4, kernel size = 4, stride = 2: +0 0 1 2 3 4 5 0 0 +# (0s are padding) +1 +2 +3 +# (output frames of a convolution, last 0 is never used) +0 0 1 2 3 4 5 0 +# (output of tr. conv., but pos. 5 is going to get removed as padding) +1 2 3 4 +# once you removed padding, we are missing one time step !

+
+ +Expand source code + +
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
+    """Pad for a convolution to make sure that the last window is full.
+    Extra padding is added at the end. This is required to ensure that we can rebuild
+    an output of the same length, as otherwise, even with padding, some time steps
+    might get removed.
+    For instance, with total padding = 4, kernel size = 4, stride = 2:
+        0 0 1 2 3 4 5 0 0   # (0s are padding)
+        1   2   3           # (output frames of a convolution, last 0 is never used)
+        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
+            1 2 3 4         # once you removed padding, we are missing one time step !
+    """
+    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+    return F.pad(x, (0, extra_padding))
+
+
+
+def unpad1d(x: torch.Tensor, paddings: Tuple[int, int]) +
+
+

Remove padding from x, handling properly zero padding. Only for 1d!

+
+ +Expand source code + +
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+    """Remove padding from x, handling properly zero padding. Only for 1d!
+    """
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    assert (padding_left + padding_right) <= x.shape[-1]
+    end = x.shape[-1] - padding_right
+    return x[..., padding_left: end]
+
+
+
+
+
+

Classes

+
+
+class NormConv1d +(*args, causal: bool = False, norm: str = 'none', norm_kwargs: Dict[str, Any] = {}, **kwargs) +
+
+

Wrapper around Conv1d and normalization applied to this conv +to provide a uniform interface across normalization approaches.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class NormConv1d(nn.Module):
+    """Wrapper around Conv1d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, causal: bool = False, norm: str = 'none',
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = self.conv(x)
+    x = self.norm(x)
+    return x
+
+
+
+
+
+class NormConv2d +(*args, norm: str = 'none', norm_kwargs: Dict[str, Any] = {}, **kwargs) +
+
+

Wrapper around Conv2d and normalization applied to this conv +to provide a uniform interface across normalization approaches.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class NormConv2d(nn.Module):
+    """Wrapper around Conv2d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = self.conv(x)
+    x = self.norm(x)
+    return x
+
+
+
+
+
+class NormConvTranspose1d +(*args, causal: bool = False, norm: str = 'none', norm_kwargs: Dict[str, Any] = {}, **kwargs) +
+
+

Wrapper around ConvTranspose1d and normalization applied to this conv +to provide a uniform interface across normalization approaches.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class NormConvTranspose1d(nn.Module):
+    """Wrapper around ConvTranspose1d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, causal: bool = False, norm: str = 'none',
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.convtr(x)
+        x = self.norm(x)
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = self.convtr(x)
+    x = self.norm(x)
+    return x
+
+
+
+
+
+class NormConvTranspose2d +(*args, norm: str = 'none', norm_kwargs: Dict[str, Any] = {}, **kwargs) +
+
+

Wrapper around ConvTranspose2d and normalization applied to this conv +to provide a uniform interface across normalization approaches.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class NormConvTranspose2d(nn.Module):
+    """Wrapper around ConvTranspose2d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
+
+    def forward(self, x):
+        x = self.convtr(x)
+        x = self.norm(x)
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = self.convtr(x)
+    x = self.norm(x)
+    return x
+
+
+
+
+
+class StreamableConv1d +(in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, dilation: int = 1, groups: int = 1, bias: bool = True, causal: bool = False, norm: str = 'none', norm_kwargs: Dict[str, Any] = {}, pad_mode: str = 'reflect') +
+
+

Conv1d with some builtin handling of asymmetric or causal padding +and normalization.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class StreamableConv1d(nn.Module):
+    """Conv1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+    def __init__(self, in_channels: int, out_channels: int,
+                 kernel_size: int, stride: int = 1, dilation: int = 1,
+                 groups: int = 1, bias: bool = True, causal: bool = False,
+                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
+                 pad_mode: str = 'reflect'):
+        super().__init__()
+        # warn user on unusual setup between dilation and stride
+        if stride > 1 and dilation > 1:
+            warnings.warn('StreamableConv1d has been initialized with stride > 1 and dilation > 1'
+                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
+        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
+                               dilation=dilation, groups=groups, bias=bias, causal=causal,
+                               norm=norm, norm_kwargs=norm_kwargs)
+        self.causal = causal
+        self.pad_mode = pad_mode
+
+    def forward(self, x):
+        B, C, T = x.shape
+        kernel_size = self.conv.conv.kernel_size[0]
+        stride = self.conv.conv.stride[0]
+        dilation = self.conv.conv.dilation[0]
+        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
+        padding_total = kernel_size - stride
+        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+        if self.causal:
+            # Left padding for causal
+            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
+        return self.conv(x)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    B, C, T = x.shape
+    kernel_size = self.conv.conv.kernel_size[0]
+    stride = self.conv.conv.stride[0]
+    dilation = self.conv.conv.dilation[0]
+    kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
+    padding_total = kernel_size - stride
+    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+    if self.causal:
+        # Left padding for causal
+        x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+    else:
+        # Asymmetric padding required for odd strides
+        padding_right = padding_total // 2
+        padding_left = padding_total - padding_right
+        x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
+    return self.conv(x)
+
+
+
+
+
+class StreamableConvTranspose1d +(in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, causal: bool = False, norm: str = 'none', trim_right_ratio: float = 1.0, norm_kwargs: Dict[str, Any] = {}) +
+
+

ConvTranspose1d with some builtin handling of asymmetric or causal padding +and normalization.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class StreamableConvTranspose1d(nn.Module):
+    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+    def __init__(self, in_channels: int, out_channels: int,
+                 kernel_size: int, stride: int = 1, causal: bool = False,
+                 norm: str = 'none', trim_right_ratio: float = 1.,
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}):
+        super().__init__()
+        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
+                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
+        self.causal = causal
+        self.trim_right_ratio = trim_right_ratio
+        assert self.causal or self.trim_right_ratio == 1., \
+            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
+        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
+
+    def forward(self, x):
+        kernel_size = self.convtr.convtr.kernel_size[0]
+        stride = self.convtr.convtr.stride[0]
+        padding_total = kernel_size - stride
+
+        y = self.convtr(x)
+
+        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+        # removed at the very end, when keeping only the right length for the output,
+        # as removing it here would require also passing the length at the matching layer
+        # in the encoder.
+        if self.causal:
+            # Trim the padding on the right according to the specified ratio
+            # if trim_right_ratio = 1.0, trim everything from right
+            padding_right = math.ceil(padding_total * self.trim_right_ratio)
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        return y
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    kernel_size = self.convtr.convtr.kernel_size[0]
+    stride = self.convtr.convtr.stride[0]
+    padding_total = kernel_size - stride
+
+    y = self.convtr(x)
+
+    # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+    # removed at the very end, when keeping only the right length for the output,
+    # as removing it here would require also passing the length at the matching layer
+    # in the encoder.
+    if self.causal:
+        # Trim the padding on the right according to the specified ratio
+        # if trim_right_ratio = 1.0, trim everything from right
+        padding_right = math.ceil(padding_total * self.trim_right_ratio)
+        padding_left = padding_total - padding_right
+        y = unpad1d(y, (padding_left, padding_right))
+    else:
+        # Asymmetric padding required for odd strides
+        padding_right = padding_total // 2
+        padding_left = padding_total - padding_right
+        y = unpad1d(y, (padding_left, padding_right))
+    return y
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/modules/index.html b/docs/audiocraft/modules/index.html new file mode 100644 index 00000000..f012a824 --- /dev/null +++ b/docs/audiocraft/modules/index.html @@ -0,0 +1,131 @@ + + + + + + +audiocraft.modules API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .conv import (
+    NormConv1d,
+    NormConv2d,
+    NormConvTranspose1d,
+    NormConvTranspose2d,
+    StreamableConv1d,
+    StreamableConvTranspose1d,
+    pad_for_conv1d,
+    pad1d,
+    unpad1d,
+)
+from .lstm import StreamableLSTM
+from .seanet import SEANetEncoder, SEANetDecoder
+
+
+
+

Sub-modules

+
+
audiocraft.modules.activations
+
+
+
+
audiocraft.modules.codebooks_patterns
+
+
+
+
audiocraft.modules.conditioners
+
+
+
+
audiocraft.modules.conv
+
+
+
+
audiocraft.modules.lstm
+
+
+
+
audiocraft.modules.rope
+
+
+
+
audiocraft.modules.seanet
+
+
+
+
audiocraft.modules.streaming
+
+

Streaming module API that should be implemented by all Streaming components,

+
+
audiocraft.modules.transformer
+
+

Transformer model, with streaming support, xformer attention support +and easy causal attention with a potentially finite receptive field …

+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/modules/lstm.html b/docs/audiocraft/modules/lstm.html new file mode 100644 index 00000000..ad20d54e --- /dev/null +++ b/docs/audiocraft/modules/lstm.html @@ -0,0 +1,177 @@ + + + + + + +audiocraft.modules.lstm API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.lstm

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch import nn
+
+
+class StreamableLSTM(nn.Module):
+    """LSTM without worrying about the hidden state, nor the layout of the data.
+    Expects input as convolutional layout.
+    """
+    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
+        super().__init__()
+        self.skip = skip
+        self.lstm = nn.LSTM(dimension, dimension, num_layers)
+
+    def forward(self, x):
+        x = x.permute(2, 0, 1)
+        y, _ = self.lstm(x)
+        if self.skip:
+            y = y + x
+        y = y.permute(1, 2, 0)
+        return y
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class StreamableLSTM +(dimension: int, num_layers: int = 2, skip: bool = True) +
+
+

LSTM without worrying about the hidden state, nor the layout of the data. +Expects input as convolutional layout.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class StreamableLSTM(nn.Module):
+    """LSTM without worrying about the hidden state, nor the layout of the data.
+    Expects input as convolutional layout.
+    """
+    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
+        super().__init__()
+        self.skip = skip
+        self.lstm = nn.LSTM(dimension, dimension, num_layers)
+
+    def forward(self, x):
+        x = x.permute(2, 0, 1)
+        y, _ = self.lstm(x)
+        if self.skip:
+            y = y + x
+        y = y.permute(1, 2, 0)
+        return y
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    x = x.permute(2, 0, 1)
+    y, _ = self.lstm(x)
+    if self.skip:
+        y = y + x
+    y = y.permute(1, 2, 0)
+    return y
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/modules/rope.html b/docs/audiocraft/modules/rope.html new file mode 100644 index 00000000..57e7eb5e --- /dev/null +++ b/docs/audiocraft/modules/rope.html @@ -0,0 +1,595 @@ + + + + + + +audiocraft.modules.rope API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.rope

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+from torch import nn
+import torch
+
+
+class XPos(nn.Module):
+    """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
+    This applies an exponential decay to the RoPE rotation matrix.
+
+    Args:
+        dim (int): Embedding dimension.
+        smoothing (float): Smoothing factor applied to the decay rates.
+        base_scale (int): Base decay rate, given in terms of scaling time.
+        device (torch.device or None): Device on which to initialize the module.
+        dtype (torch.dtype): dtype to use to generate the embedding.
+    """
+    def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
+                 device=None, dtype: torch.dtype = torch.float32):
+        super().__init__()
+        assert dim % 2 == 0
+        assert dtype in [torch.float64, torch.float32]
+        self.dtype = dtype
+        self.base_scale = base_scale
+
+        half_dim = dim // 2
+        adim = torch.arange(half_dim, device=device, dtype=dtype)
+        decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing)
+        self.register_buffer("decay_rates", decay_rates)
+        self.decay: tp.Optional[torch.Tensor] = None
+
+    def get_decay(self, start: int, end: int):
+        """Create complex decay tensor, cache values for fast computation.
+        """
+        if self.decay is None or end > self.decay.shape[0]:
+            assert isinstance(self.decay_rates, torch.Tensor)  # Satisfy type checker.
+            idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
+            power = idx / self.base_scale
+            scale = self.decay_rates ** power.unsqueeze(-1)
+            self.decay = torch.polar(scale, torch.zeros_like(scale))
+        return self.decay[start:end]  # [T, C/2]
+
+
+class RotaryEmbedding(nn.Module):
+    """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
+
+    Args:
+        dim (int): Embedding dimension (twice the number of frequencies).
+        max_period (float): Maximum period of the rotation frequencies.
+        xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
+        scale (float): Scale of positional embedding, set to 0 to deactivate.
+        device (torch.device or None): Device on which to initialize the module.
+        dtype (torch.dtype): dtype to use to generate the embedding.
+    """
+    def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
+                 scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
+        super().__init__()
+        assert dim % 2 == 0
+        self.scale = scale
+        assert dtype in [torch.float64, torch.float32]
+        self.dtype = dtype
+
+        adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
+        frequencies = 1.0 / (max_period ** (adim / dim))
+        self.register_buffer("frequencies", frequencies)
+        self.rotation: tp.Optional[torch.Tensor] = None
+
+        self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None
+
+    def get_rotation(self, start: int, end: int):
+        """Create complex rotation tensor, cache values for fast computation.
+        """
+        if self.rotation is None or end > self.rotation.shape[0]:
+            assert isinstance(self.frequencies, torch.Tensor)  # Satisfy type checker.
+            idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
+            angles = torch.outer(idx, self.frequencies)
+            self.rotation = torch.polar(torch.ones_like(angles), angles)
+        return self.rotation[start:end]
+
+    def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
+        """Apply rope rotation to query or key tensor.
+        """
+        T = x.shape[1]
+        rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2)
+
+        if self.xpos:
+            decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2)
+        else:
+            decay = 1.0
+
+        if invert_decay:
+            decay = decay ** -1
+
+        x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
+        scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
+        x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)
+
+        return x_out.type_as(x)
+
+    def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
+        """ Apply rope rotation to both query and key tensors.
+        Supports streaming mode, in which query and key are not expected to have the same shape.
+        In streaming mode, key will be of legnth [P + C] with P the cached past timesteps, but
+        query will be [C] (typically C == 1).
+
+        Args:
+            query (torch.Tensor): Query to rotate.
+            key (torch.Tensor): Key to rotate.
+            start (int): Start index of the sequence for time offset.
+        """
+        query_timesteps = query.shape[1]
+        key_timesteps = key.shape[1]
+        streaming_offset = key_timesteps - query_timesteps
+
+        query_out = self.rotate(query, start + streaming_offset)
+        key_out = self.rotate(key, start, invert_decay=True)
+
+        return query_out, key_out
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class RotaryEmbedding +(dim: int, max_period: float = 10000.0, xpos: bool = False, scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32) +
+
+

Rotary positional embedding (RoPE) from Su et al 2022.

+

Args

+
+
dim : int
+
Embedding dimension (twice the number of frequencies).
+
max_period : float
+
Maximum period of the rotation frequencies.
+
xpos : bool
+
Use xPos, applies an exponential decay to rotation matrix.
+
scale : float
+
Scale of positional embedding, set to 0 to deactivate.
+
device : torch.device or None
+
Device on which to initialize the module.
+
dtype : torch.dtype
+
dtype to use to generate the embedding.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class RotaryEmbedding(nn.Module):
+    """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
+
+    Args:
+        dim (int): Embedding dimension (twice the number of frequencies).
+        max_period (float): Maximum period of the rotation frequencies.
+        xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
+        scale (float): Scale of positional embedding, set to 0 to deactivate.
+        device (torch.device or None): Device on which to initialize the module.
+        dtype (torch.dtype): dtype to use to generate the embedding.
+    """
+    def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
+                 scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
+        super().__init__()
+        assert dim % 2 == 0
+        self.scale = scale
+        assert dtype in [torch.float64, torch.float32]
+        self.dtype = dtype
+
+        adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
+        frequencies = 1.0 / (max_period ** (adim / dim))
+        self.register_buffer("frequencies", frequencies)
+        self.rotation: tp.Optional[torch.Tensor] = None
+
+        self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None
+
+    def get_rotation(self, start: int, end: int):
+        """Create complex rotation tensor, cache values for fast computation.
+        """
+        if self.rotation is None or end > self.rotation.shape[0]:
+            assert isinstance(self.frequencies, torch.Tensor)  # Satisfy type checker.
+            idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
+            angles = torch.outer(idx, self.frequencies)
+            self.rotation = torch.polar(torch.ones_like(angles), angles)
+        return self.rotation[start:end]
+
+    def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
+        """Apply rope rotation to query or key tensor.
+        """
+        T = x.shape[1]
+        rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2)
+
+        if self.xpos:
+            decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2)
+        else:
+            decay = 1.0
+
+        if invert_decay:
+            decay = decay ** -1
+
+        x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
+        scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
+        x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)
+
+        return x_out.type_as(x)
+
+    def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
+        """ Apply rope rotation to both query and key tensors.
+        Supports streaming mode, in which query and key are not expected to have the same shape.
+        In streaming mode, key will be of legnth [P + C] with P the cached past timesteps, but
+        query will be [C] (typically C == 1).
+
+        Args:
+            query (torch.Tensor): Query to rotate.
+            key (torch.Tensor): Key to rotate.
+            start (int): Start index of the sequence for time offset.
+        """
+        query_timesteps = query.shape[1]
+        key_timesteps = key.shape[1]
+        streaming_offset = key_timesteps - query_timesteps
+
+        query_out = self.rotate(query, start + streaming_offset)
+        key_out = self.rotate(key, start, invert_decay=True)
+
+        return query_out, key_out
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, *input: Any) ‑> None +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def _forward_unimplemented(self, *input: Any) -> None:
+    r"""Defines the computation performed at every call.
+
+    Should be overridden by all subclasses.
+
+    .. note::
+        Although the recipe for forward pass needs to be defined within
+        this function, one should call the :class:`Module` instance afterwards
+        instead of this since the former takes care of running the
+        registered hooks while the latter silently ignores them.
+    """
+    raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
+
+
+
+def get_rotation(self, start: int, end: int) +
+
+

Create complex rotation tensor, cache values for fast computation.

+
+ +Expand source code + +
def get_rotation(self, start: int, end: int):
+    """Create complex rotation tensor, cache values for fast computation.
+    """
+    if self.rotation is None or end > self.rotation.shape[0]:
+        assert isinstance(self.frequencies, torch.Tensor)  # Satisfy type checker.
+        idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
+        angles = torch.outer(idx, self.frequencies)
+        self.rotation = torch.polar(torch.ones_like(angles), angles)
+    return self.rotation[start:end]
+
+
+
+def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False) +
+
+

Apply rope rotation to query or key tensor.

+
+ +Expand source code + +
def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
+    """Apply rope rotation to query or key tensor.
+    """
+    T = x.shape[1]
+    rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2)
+
+    if self.xpos:
+        decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2)
+    else:
+        decay = 1.0
+
+    if invert_decay:
+        decay = decay ** -1
+
+    x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
+    scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
+    x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)
+
+    return x_out.type_as(x)
+
+
+
+def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0) +
+
+

Apply rope rotation to both query and key tensors. +Supports streaming mode, in which query and key are not expected to have the same shape. +In streaming mode, key will be of legnth [P + C] with P the cached past timesteps, but +query will be [C] (typically C == 1).

+

Args

+
+
query : torch.Tensor
+
Query to rotate.
+
key : torch.Tensor
+
Key to rotate.
+
start : int
+
Start index of the sequence for time offset.
+
+
+ +Expand source code + +
def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
+    """ Apply rope rotation to both query and key tensors.
+    Supports streaming mode, in which query and key are not expected to have the same shape.
+    In streaming mode, key will be of legnth [P + C] with P the cached past timesteps, but
+    query will be [C] (typically C == 1).
+
+    Args:
+        query (torch.Tensor): Query to rotate.
+        key (torch.Tensor): Key to rotate.
+        start (int): Start index of the sequence for time offset.
+    """
+    query_timesteps = query.shape[1]
+    key_timesteps = key.shape[1]
+    streaming_offset = key_timesteps - query_timesteps
+
+    query_out = self.rotate(query, start + streaming_offset)
+    key_out = self.rotate(key, start, invert_decay=True)
+
+    return query_out, key_out
+
+
+
+
+
+class XPos +(dim: int, smoothing: float = 0.4, base_scale: int = 512, device=None, dtype: torch.dtype = torch.float32) +
+
+

Length-extrapolatable positional embedding (xPos) from Sun et al 2022. +This applies an exponential decay to the RoPE rotation matrix.

+

Args

+
+
dim : int
+
Embedding dimension.
+
smoothing : float
+
Smoothing factor applied to the decay rates.
+
base_scale : int
+
Base decay rate, given in terms of scaling time.
+
device : torch.device or None
+
Device on which to initialize the module.
+
dtype : torch.dtype
+
dtype to use to generate the embedding.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class XPos(nn.Module):
+    """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
+    This applies an exponential decay to the RoPE rotation matrix.
+
+    Args:
+        dim (int): Embedding dimension.
+        smoothing (float): Smoothing factor applied to the decay rates.
+        base_scale (int): Base decay rate, given in terms of scaling time.
+        device (torch.device or None): Device on which to initialize the module.
+        dtype (torch.dtype): dtype to use to generate the embedding.
+    """
+    def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
+                 device=None, dtype: torch.dtype = torch.float32):
+        super().__init__()
+        assert dim % 2 == 0
+        assert dtype in [torch.float64, torch.float32]
+        self.dtype = dtype
+        self.base_scale = base_scale
+
+        half_dim = dim // 2
+        adim = torch.arange(half_dim, device=device, dtype=dtype)
+        decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing)
+        self.register_buffer("decay_rates", decay_rates)
+        self.decay: tp.Optional[torch.Tensor] = None
+
+    def get_decay(self, start: int, end: int):
+        """Create complex decay tensor, cache values for fast computation.
+        """
+        if self.decay is None or end > self.decay.shape[0]:
+            assert isinstance(self.decay_rates, torch.Tensor)  # Satisfy type checker.
+            idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
+            power = idx / self.base_scale
+            scale = self.decay_rates ** power.unsqueeze(-1)
+            self.decay = torch.polar(scale, torch.zeros_like(scale))
+        return self.decay[start:end]  # [T, C/2]
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, *input: Any) ‑> None +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def _forward_unimplemented(self, *input: Any) -> None:
+    r"""Defines the computation performed at every call.
+
+    Should be overridden by all subclasses.
+
+    .. note::
+        Although the recipe for forward pass needs to be defined within
+        this function, one should call the :class:`Module` instance afterwards
+        instead of this since the former takes care of running the
+        registered hooks while the latter silently ignores them.
+    """
+    raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
+
+
+
+def get_decay(self, start: int, end: int) +
+
+

Create complex decay tensor, cache values for fast computation.

+
+ +Expand source code + +
def get_decay(self, start: int, end: int):
+    """Create complex decay tensor, cache values for fast computation.
+    """
+    if self.decay is None or end > self.decay.shape[0]:
+        assert isinstance(self.decay_rates, torch.Tensor)  # Satisfy type checker.
+        idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
+        power = idx / self.base_scale
+        scale = self.decay_rates ** power.unsqueeze(-1)
+        self.decay = torch.polar(scale, torch.zeros_like(scale))
+    return self.decay[start:end]  # [T, C/2]
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/modules/seanet.html b/docs/audiocraft/modules/seanet.html new file mode 100644 index 00000000..831a462b --- /dev/null +++ b/docs/audiocraft/modules/seanet.html @@ -0,0 +1,879 @@ + + + + + + +audiocraft.modules.seanet API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.seanet

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import numpy as np
+import torch.nn as nn
+
+from .conv import StreamableConv1d, StreamableConvTranspose1d
+from .lstm import StreamableLSTM
+
+
+class SEANetResnetBlock(nn.Module):
+    """Residual block from SEANet model.
+
+    Args:
+        dim (int): Dimension of the input/output.
+        kernel_sizes (list): List of kernel sizes for the convolutions.
+        dilations (list): List of dilations for the convolutions.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection.
+    """
+    def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
+                 activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
+                 pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
+        super().__init__()
+        assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
+        act = getattr(nn, activation)
+        hidden = dim // compress
+        block = []
+        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+            in_chs = dim if i == 0 else hidden
+            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+            block += [
+                act(**activation_params),
+                StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
+                                 norm=norm, norm_kwargs=norm_params,
+                                 causal=causal, pad_mode=pad_mode),
+            ]
+        self.block = nn.Sequential(*block)
+        self.shortcut: nn.Module
+        if true_skip:
+            self.shortcut = nn.Identity()
+        else:
+            self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
+                                             causal=causal, pad_mode=pad_mode)
+
+    def forward(self, x):
+        return self.shortcut(x) + self.block(x)
+
+
+class SEANetEncoder(nn.Module):
+    """SEANet encoder.
+
+    Args:
+        channels (int): Audio channels.
+        dimension (int): Intermediate representation dimension.
+        n_filters (int): Base width for the model.
+        n_residual_layers (int): nb of residual layers.
+        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
+            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
+            that must match the decoder order. We use the decoder order as some models may only employ the decoder.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        kernel_size (int): Kernel size for the initial convolution.
+        last_kernel_size (int): Kernel size for the initial convolution.
+        residual_kernel_size (int): Kernel size for the residual layers.
+        dilation_base (int): How much to increase the dilation with each layer.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection in the residual network blocks.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        lstm (int): Number of LSTM layers at the end of the encoder.
+        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
+            For the encoder, it corresponds to the N first blocks.
+    """
+    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
+                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
+                 disable_norm_outer_blocks: int = 0):
+        super().__init__()
+        self.channels = channels
+        self.dimension = dimension
+        self.n_filters = n_filters
+        self.ratios = list(reversed(ratios))
+        del ratios
+        self.n_residual_layers = n_residual_layers
+        self.hop_length = np.prod(self.ratios)
+        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
+        self.disable_norm_outer_blocks = disable_norm_outer_blocks
+        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
+            "Number of blocks for which to disable norm is invalid." \
+            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
+
+        act = getattr(nn, activation)
+        mult = 1
+        model: tp.List[nn.Module] = [
+            StreamableConv1d(channels, mult * n_filters, kernel_size,
+                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
+                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+        ]
+        # Downsample to raw audio scale
+        for i, ratio in enumerate(self.ratios):
+            block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
+            # Add residual layers
+            for j in range(n_residual_layers):
+                model += [
+                    SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
+                                      dilations=[dilation_base ** j, 1],
+                                      norm=block_norm, norm_params=norm_params,
+                                      activation=activation, activation_params=activation_params,
+                                      causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+            # Add downsampling layers
+            model += [
+                act(**activation_params),
+                StreamableConv1d(mult * n_filters, mult * n_filters * 2,
+                                 kernel_size=ratio * 2, stride=ratio,
+                                 norm=block_norm, norm_kwargs=norm_params,
+                                 causal=causal, pad_mode=pad_mode),
+            ]
+            mult *= 2
+
+        if lstm:
+            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
+
+        model += [
+            act(**activation_params),
+            StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
+                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
+                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+        ]
+
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class SEANetDecoder(nn.Module):
+    """SEANet decoder.
+
+    Args:
+        channels (int): Audio channels.
+        dimension (int): Intermediate representation dimension.
+        n_filters (int): Base width for the model.
+        n_residual_layers (int): nb of residual layers.
+        ratios (Sequence[int]): kernel size and stride ratios.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        final_activation (str): Final activation function after all convolutions.
+        final_activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        kernel_size (int): Kernel size for the initial convolution.
+        last_kernel_size (int): Kernel size for the initial convolution.
+        residual_kernel_size (int): Kernel size for the residual layers.
+        dilation_base (int): How much to increase the dilation with each layer.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        true_skip (bool): Whether to use true skip connection or a simple.
+            (streamable) convolution as the skip connection in the residual network blocks.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        lstm (int): Number of LSTM layers at the end of the encoder.
+        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
+            For the decoder, it corresponds to the N last blocks.
+        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
+            If equal to 1.0, it means that all the trimming is done at the right.
+    """
+    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
+                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
+                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
+                 disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
+        super().__init__()
+        self.dimension = dimension
+        self.channels = channels
+        self.n_filters = n_filters
+        self.ratios = ratios
+        del ratios
+        self.n_residual_layers = n_residual_layers
+        self.hop_length = np.prod(self.ratios)
+        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
+        self.disable_norm_outer_blocks = disable_norm_outer_blocks
+        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
+            "Number of blocks for which to disable norm is invalid." \
+            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
+
+        act = getattr(nn, activation)
+        mult = int(2 ** len(self.ratios))
+        model: tp.List[nn.Module] = [
+            StreamableConv1d(dimension, mult * n_filters, kernel_size,
+                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
+                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+        ]
+
+        if lstm:
+            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
+
+        # Upsample to raw audio scale
+        for i, ratio in enumerate(self.ratios):
+            block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
+            # Add upsampling layers
+            model += [
+                act(**activation_params),
+                StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
+                                          kernel_size=ratio * 2, stride=ratio,
+                                          norm=block_norm, norm_kwargs=norm_params,
+                                          causal=causal, trim_right_ratio=trim_right_ratio),
+            ]
+            # Add residual layers
+            for j in range(n_residual_layers):
+                model += [
+                    SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
+                                      dilations=[dilation_base ** j, 1],
+                                      activation=activation, activation_params=activation_params,
+                                      norm=block_norm, norm_params=norm_params, causal=causal,
+                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+            mult //= 2
+
+        # Add final layers
+        model += [
+            act(**activation_params),
+            StreamableConv1d(n_filters, channels, last_kernel_size,
+                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
+                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+        ]
+        # Add optional final activation to decoder (eg. tanh)
+        if final_activation is not None:
+            final_act = getattr(nn, final_activation)
+            final_activation_params = final_activation_params or {}
+            model += [
+                final_act(**final_activation_params)
+            ]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, z):
+        y = self.model(z)
+        return y
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class SEANetDecoder +(channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, ratios: List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, final_activation: Optional[str] = None, final_activation_params: Optional[dict] = None, norm: str = 'none', norm_params: Dict[str, Any] = {}, kernel_size: int = 7, last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0) +
+
+

SEANet decoder.

+

Args

+
+
channels : int
+
Audio channels.
+
dimension : int
+
Intermediate representation dimension.
+
n_filters : int
+
Base width for the model.
+
n_residual_layers : int
+
nb of residual layers.
+
ratios : Sequence[int]
+
kernel size and stride ratios.
+
activation : str
+
Activation function.
+
activation_params : dict
+
Parameters to provide to the activation function.
+
final_activation : str
+
Final activation function after all convolutions.
+
final_activation_params : dict
+
Parameters to provide to the activation function.
+
norm : str
+
Normalization method.
+
norm_params : dict
+
Parameters to provide to the underlying normalization used along with the convolution.
+
kernel_size : int
+
Kernel size for the initial convolution.
+
last_kernel_size : int
+
Kernel size for the initial convolution.
+
residual_kernel_size : int
+
Kernel size for the residual layers.
+
dilation_base : int
+
How much to increase the dilation with each layer.
+
causal : bool
+
Whether to use fully causal convolution.
+
pad_mode : str
+
Padding mode for the convolutions.
+
true_skip : bool
+
Whether to use true skip connection or a simple. +(streamable) convolution as the skip connection in the residual network blocks.
+
compress : int
+
Reduced dimensionality in residual branches (from Demucs v3).
+
lstm : int
+
Number of LSTM layers at the end of the encoder.
+
disable_norm_outer_blocks : int
+
Number of blocks for which we don't apply norm. +For the decoder, it corresponds to the N last blocks.
+
trim_right_ratio : float
+
Ratio for trimming at the right of the transposed convolution under the causal setup. +If equal to 1.0, it means that all the trimming is done at the right.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class SEANetDecoder(nn.Module):
+    """SEANet decoder.
+
+    Args:
+        channels (int): Audio channels.
+        dimension (int): Intermediate representation dimension.
+        n_filters (int): Base width for the model.
+        n_residual_layers (int): nb of residual layers.
+        ratios (Sequence[int]): kernel size and stride ratios.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        final_activation (str): Final activation function after all convolutions.
+        final_activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        kernel_size (int): Kernel size for the initial convolution.
+        last_kernel_size (int): Kernel size for the initial convolution.
+        residual_kernel_size (int): Kernel size for the residual layers.
+        dilation_base (int): How much to increase the dilation with each layer.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        true_skip (bool): Whether to use true skip connection or a simple.
+            (streamable) convolution as the skip connection in the residual network blocks.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        lstm (int): Number of LSTM layers at the end of the encoder.
+        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
+            For the decoder, it corresponds to the N last blocks.
+        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
+            If equal to 1.0, it means that all the trimming is done at the right.
+    """
+    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
+                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
+                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
+                 disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
+        super().__init__()
+        self.dimension = dimension
+        self.channels = channels
+        self.n_filters = n_filters
+        self.ratios = ratios
+        del ratios
+        self.n_residual_layers = n_residual_layers
+        self.hop_length = np.prod(self.ratios)
+        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
+        self.disable_norm_outer_blocks = disable_norm_outer_blocks
+        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
+            "Number of blocks for which to disable norm is invalid." \
+            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
+
+        act = getattr(nn, activation)
+        mult = int(2 ** len(self.ratios))
+        model: tp.List[nn.Module] = [
+            StreamableConv1d(dimension, mult * n_filters, kernel_size,
+                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
+                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+        ]
+
+        if lstm:
+            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
+
+        # Upsample to raw audio scale
+        for i, ratio in enumerate(self.ratios):
+            block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
+            # Add upsampling layers
+            model += [
+                act(**activation_params),
+                StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
+                                          kernel_size=ratio * 2, stride=ratio,
+                                          norm=block_norm, norm_kwargs=norm_params,
+                                          causal=causal, trim_right_ratio=trim_right_ratio),
+            ]
+            # Add residual layers
+            for j in range(n_residual_layers):
+                model += [
+                    SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
+                                      dilations=[dilation_base ** j, 1],
+                                      activation=activation, activation_params=activation_params,
+                                      norm=block_norm, norm_params=norm_params, causal=causal,
+                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+            mult //= 2
+
+        # Add final layers
+        model += [
+            act(**activation_params),
+            StreamableConv1d(n_filters, channels, last_kernel_size,
+                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
+                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+        ]
+        # Add optional final activation to decoder (eg. tanh)
+        if final_activation is not None:
+            final_act = getattr(nn, final_activation)
+            final_activation_params = final_activation_params or {}
+            model += [
+                final_act(**final_activation_params)
+            ]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, z):
+        y = self.model(z)
+        return y
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, z) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, z):
+    y = self.model(z)
+    return y
+
+
+
+
+
+class SEANetEncoder +(channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, ratios: List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, norm: str = 'none', norm_params: Dict[str, Any] = {}, kernel_size: int = 7, last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, disable_norm_outer_blocks: int = 0) +
+
+

SEANet encoder.

+

Args

+
+
channels : int
+
Audio channels.
+
dimension : int
+
Intermediate representation dimension.
+
n_filters : int
+
Base width for the model.
+
n_residual_layers : int
+
nb of residual layers.
+
ratios : Sequence[int]
+
kernel size and stride ratios. The encoder uses downsampling ratios instead of +upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here +that must match the decoder order. We use the decoder order as some models may only employ the decoder.
+
activation : str
+
Activation function.
+
activation_params : dict
+
Parameters to provide to the activation function.
+
norm : str
+
Normalization method.
+
norm_params : dict
+
Parameters to provide to the underlying normalization used along with the convolution.
+
kernel_size : int
+
Kernel size for the initial convolution.
+
last_kernel_size : int
+
Kernel size for the initial convolution.
+
residual_kernel_size : int
+
Kernel size for the residual layers.
+
dilation_base : int
+
How much to increase the dilation with each layer.
+
causal : bool
+
Whether to use fully causal convolution.
+
pad_mode : str
+
Padding mode for the convolutions.
+
true_skip : bool
+
Whether to use true skip connection or a simple +(streamable) convolution as the skip connection in the residual network blocks.
+
compress : int
+
Reduced dimensionality in residual branches (from Demucs v3).
+
lstm : int
+
Number of LSTM layers at the end of the encoder.
+
disable_norm_outer_blocks : int
+
Number of blocks for which we don't apply norm. +For the encoder, it corresponds to the N first blocks.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class SEANetEncoder(nn.Module):
+    """SEANet encoder.
+
+    Args:
+        channels (int): Audio channels.
+        dimension (int): Intermediate representation dimension.
+        n_filters (int): Base width for the model.
+        n_residual_layers (int): nb of residual layers.
+        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
+            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
+            that must match the decoder order. We use the decoder order as some models may only employ the decoder.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        kernel_size (int): Kernel size for the initial convolution.
+        last_kernel_size (int): Kernel size for the initial convolution.
+        residual_kernel_size (int): Kernel size for the residual layers.
+        dilation_base (int): How much to increase the dilation with each layer.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection in the residual network blocks.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        lstm (int): Number of LSTM layers at the end of the encoder.
+        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
+            For the encoder, it corresponds to the N first blocks.
+    """
+    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
+                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
+                 disable_norm_outer_blocks: int = 0):
+        super().__init__()
+        self.channels = channels
+        self.dimension = dimension
+        self.n_filters = n_filters
+        self.ratios = list(reversed(ratios))
+        del ratios
+        self.n_residual_layers = n_residual_layers
+        self.hop_length = np.prod(self.ratios)
+        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
+        self.disable_norm_outer_blocks = disable_norm_outer_blocks
+        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
+            "Number of blocks for which to disable norm is invalid." \
+            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
+
+        act = getattr(nn, activation)
+        mult = 1
+        model: tp.List[nn.Module] = [
+            StreamableConv1d(channels, mult * n_filters, kernel_size,
+                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
+                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+        ]
+        # Downsample to raw audio scale
+        for i, ratio in enumerate(self.ratios):
+            block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
+            # Add residual layers
+            for j in range(n_residual_layers):
+                model += [
+                    SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
+                                      dilations=[dilation_base ** j, 1],
+                                      norm=block_norm, norm_params=norm_params,
+                                      activation=activation, activation_params=activation_params,
+                                      causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+            # Add downsampling layers
+            model += [
+                act(**activation_params),
+                StreamableConv1d(mult * n_filters, mult * n_filters * 2,
+                                 kernel_size=ratio * 2, stride=ratio,
+                                 norm=block_norm, norm_kwargs=norm_params,
+                                 causal=causal, pad_mode=pad_mode),
+            ]
+            mult *= 2
+
+        if lstm:
+            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
+
+        model += [
+            act(**activation_params),
+            StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
+                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
+                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+        ]
+
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        return self.model(x)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    return self.model(x)
+
+
+
+
+
+class SEANetResnetBlock +(dim: int, kernel_sizes: List[int] = [3, 1], dilations: List[int] = [1, 1], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, norm: str = 'none', norm_params: Dict[str, Any] = {}, causal: bool = False, pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True) +
+
+

Residual block from SEANet model.

+

Args

+
+
dim : int
+
Dimension of the input/output.
+
kernel_sizes : list
+
List of kernel sizes for the convolutions.
+
dilations : list
+
List of dilations for the convolutions.
+
activation : str
+
Activation function.
+
activation_params : dict
+
Parameters to provide to the activation function.
+
norm : str
+
Normalization method.
+
norm_params : dict
+
Parameters to provide to the underlying normalization used along with the convolution.
+
causal : bool
+
Whether to use fully causal convolution.
+
pad_mode : str
+
Padding mode for the convolutions.
+
compress : int
+
Reduced dimensionality in residual branches (from Demucs v3).
+
true_skip : bool
+
Whether to use true skip connection or a simple +(streamable) convolution as the skip connection.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class SEANetResnetBlock(nn.Module):
+    """Residual block from SEANet model.
+
+    Args:
+        dim (int): Dimension of the input/output.
+        kernel_sizes (list): List of kernel sizes for the convolutions.
+        dilations (list): List of dilations for the convolutions.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection.
+    """
+    def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
+                 activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
+                 pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
+        super().__init__()
+        assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
+        act = getattr(nn, activation)
+        hidden = dim // compress
+        block = []
+        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+            in_chs = dim if i == 0 else hidden
+            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+            block += [
+                act(**activation_params),
+                StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
+                                 norm=norm, norm_kwargs=norm_params,
+                                 causal=causal, pad_mode=pad_mode),
+            ]
+        self.block = nn.Sequential(*block)
+        self.shortcut: nn.Module
+        if true_skip:
+            self.shortcut = nn.Identity()
+        else:
+            self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
+                                             causal=causal, pad_mode=pad_mode)
+
+    def forward(self, x):
+        return self.shortcut(x) + self.block(x)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    return self.shortcut(x) + self.block(x)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/modules/streaming.html b/docs/audiocraft/modules/streaming.html new file mode 100644 index 00000000..9e2df1f9 --- /dev/null +++ b/docs/audiocraft/modules/streaming.html @@ -0,0 +1,573 @@ + + + + + + +audiocraft.modules.streaming API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.streaming

+
+
+

Streaming module API that should be implemented by all Streaming components,

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Streaming module API that should be implemented by all Streaming components,
+"""
+
+from contextlib import contextmanager
+import typing as tp
+from torch import nn
+import torch
+
+
+State = tp.Dict[str, torch.Tensor]
+
+
+class StreamingModule(nn.Module):
+    """Common API for streaming components.
+
+    Each streaming component has a streaming state, which is just a dict[str, Tensor].
+    By convention, the first dim of each tensor must be the batch size.
+    Don't use dots in the key names, as this would clash with submodules
+    (like in state_dict).
+
+    If `self._is_streaming` is True, the component should use and remember
+    the proper state inside `self._streaming_state`.
+
+    To set a streaming component in streaming state, use
+
+        with module.streaming():
+            ...
+
+    This will automatically reset the streaming state when exiting the context manager.
+    This also automatically propagates to all streaming children module.
+
+    Some module might also implement the `StreamingModule.flush` method, although
+    this one is trickier, as all parents module must be StreamingModule and implement
+    it as well for it to work properly. See `StreamingSequential` after.
+    """
+    def __init__(self) -> None:
+        super().__init__()
+        self._streaming_state: State = {}
+        self._is_streaming = False
+
+    def _apply_named_streaming(self, fn: tp.Any):
+        for name, module in self.named_modules():
+            if isinstance(module, StreamingModule):
+                fn(name, module)
+
+    def _set_streaming(self, streaming: bool):
+        def _set_streaming(name, module):
+            module._is_streaming = streaming
+        self._apply_named_streaming(_set_streaming)
+
+    @contextmanager
+    def streaming(self):
+        """Context manager to enter streaming mode. Reset streaming state on exit.
+        """
+        self._set_streaming(True)
+        try:
+            yield
+        finally:
+            self._set_streaming(False)
+            self.reset_streaming()
+
+    def reset_streaming(self):
+        """Reset the streaming state.
+        """
+        def _reset(name: str, module: StreamingModule):
+            module._streaming_state.clear()
+
+        self._apply_named_streaming(_reset)
+
+    def get_streaming_state(self) -> State:
+        """Return the streaming state, including that of sub-modules.
+        """
+        state: State = {}
+
+        def _add(name: str, module: StreamingModule):
+            if name:
+                name += "."
+            for key, value in module._streaming_state.items():
+                state[name + key] = value
+
+        self._apply_named_streaming(_add)
+        return state
+
+    def set_streaming_state(self, state: State):
+        """Set the streaming state, including that of sub-modules.
+        """
+        state = dict(state)
+
+        def _set(name: str, module: StreamingModule):
+            if name:
+                name += "."
+            module._streaming_state.clear()
+            for key, value in list(state.items()):
+                # complexity is not ideal here, but probably fine.
+                if key.startswith(name):
+                    local_key = key[len(name):]
+                    if '.' not in local_key:
+                        module._streaming_state[local_key] = value
+                        del state[key]
+
+        self._apply_named_streaming(_set)
+        assert len(state) == 0, list(state.keys())
+
+    def flush(self, x: tp.Optional[torch.Tensor] = None):
+        """Flush any remaining outputs that were waiting for completion.
+        Typically, for convolutions, this will add the final padding
+        and process the last buffer.
+
+        This should take an optional argument `x`, which will be provided
+        if a module before this one in the streaming pipeline has already
+        spitted out a flushed out buffer.
+        """
+        if x is None:
+            return None
+        else:
+            return self(x)
+
+
+class StreamingSequential(StreamingModule, nn.Sequential):
+    """A streaming compatible alternative of `nn.Sequential`.
+    """
+    def flush(self, x: tp.Optional[torch.Tensor] = None):
+        for module in self:
+            if isinstance(module, StreamingModule):
+                x = module.flush(x)
+            elif x is not None:
+                x = module(x)
+        return x
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class StreamingModule +
+
+

Common API for streaming components.

+

Each streaming component has a streaming state, which is just a dict[str, Tensor]. +By convention, the first dim of each tensor must be the batch size. +Don't use dots in the key names, as this would clash with submodules +(like in state_dict).

+

If self._is_streaming is True, the component should use and remember +the proper state inside self._streaming_state.

+

To set a streaming component in streaming state, use

+
with module.streaming():
+    ...
+
+

This will automatically reset the streaming state when exiting the context manager. +This also automatically propagates to all streaming children module.

+

Some module might also implement the StreamingModule.flush() method, although +this one is trickier, as all parents module must be StreamingModule and implement +it as well for it to work properly. See StreamingSequential after.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class StreamingModule(nn.Module):
+    """Common API for streaming components.
+
+    Each streaming component has a streaming state, which is just a dict[str, Tensor].
+    By convention, the first dim of each tensor must be the batch size.
+    Don't use dots in the key names, as this would clash with submodules
+    (like in state_dict).
+
+    If `self._is_streaming` is True, the component should use and remember
+    the proper state inside `self._streaming_state`.
+
+    To set a streaming component in streaming state, use
+
+        with module.streaming():
+            ...
+
+    This will automatically reset the streaming state when exiting the context manager.
+    This also automatically propagates to all streaming children module.
+
+    Some module might also implement the `StreamingModule.flush` method, although
+    this one is trickier, as all parents module must be StreamingModule and implement
+    it as well for it to work properly. See `StreamingSequential` after.
+    """
+    def __init__(self) -> None:
+        super().__init__()
+        self._streaming_state: State = {}
+        self._is_streaming = False
+
+    def _apply_named_streaming(self, fn: tp.Any):
+        for name, module in self.named_modules():
+            if isinstance(module, StreamingModule):
+                fn(name, module)
+
+    def _set_streaming(self, streaming: bool):
+        def _set_streaming(name, module):
+            module._is_streaming = streaming
+        self._apply_named_streaming(_set_streaming)
+
+    @contextmanager
+    def streaming(self):
+        """Context manager to enter streaming mode. Reset streaming state on exit.
+        """
+        self._set_streaming(True)
+        try:
+            yield
+        finally:
+            self._set_streaming(False)
+            self.reset_streaming()
+
+    def reset_streaming(self):
+        """Reset the streaming state.
+        """
+        def _reset(name: str, module: StreamingModule):
+            module._streaming_state.clear()
+
+        self._apply_named_streaming(_reset)
+
+    def get_streaming_state(self) -> State:
+        """Return the streaming state, including that of sub-modules.
+        """
+        state: State = {}
+
+        def _add(name: str, module: StreamingModule):
+            if name:
+                name += "."
+            for key, value in module._streaming_state.items():
+                state[name + key] = value
+
+        self._apply_named_streaming(_add)
+        return state
+
+    def set_streaming_state(self, state: State):
+        """Set the streaming state, including that of sub-modules.
+        """
+        state = dict(state)
+
+        def _set(name: str, module: StreamingModule):
+            if name:
+                name += "."
+            module._streaming_state.clear()
+            for key, value in list(state.items()):
+                # complexity is not ideal here, but probably fine.
+                if key.startswith(name):
+                    local_key = key[len(name):]
+                    if '.' not in local_key:
+                        module._streaming_state[local_key] = value
+                        del state[key]
+
+        self._apply_named_streaming(_set)
+        assert len(state) == 0, list(state.keys())
+
+    def flush(self, x: tp.Optional[torch.Tensor] = None):
+        """Flush any remaining outputs that were waiting for completion.
+        Typically, for convolutions, this will add the final padding
+        and process the last buffer.
+
+        This should take an optional argument `x`, which will be provided
+        if a module before this one in the streaming pipeline has already
+        spitted out a flushed out buffer.
+        """
+        if x is None:
+            return None
+        else:
+            return self(x)
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def flush(self, x: Optional[torch.Tensor] = None) +
+
+

Flush any remaining outputs that were waiting for completion. +Typically, for convolutions, this will add the final padding +and process the last buffer.

+

This should take an optional argument x, which will be provided +if a module before this one in the streaming pipeline has already +spitted out a flushed out buffer.

+
+ +Expand source code + +
def flush(self, x: tp.Optional[torch.Tensor] = None):
+    """Flush any remaining outputs that were waiting for completion.
+    Typically, for convolutions, this will add the final padding
+    and process the last buffer.
+
+    This should take an optional argument `x`, which will be provided
+    if a module before this one in the streaming pipeline has already
+    spitted out a flushed out buffer.
+    """
+    if x is None:
+        return None
+    else:
+        return self(x)
+
+
+
+def forward(self, *input: Any) ‑> None +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def _forward_unimplemented(self, *input: Any) -> None:
+    r"""Defines the computation performed at every call.
+
+    Should be overridden by all subclasses.
+
+    .. note::
+        Although the recipe for forward pass needs to be defined within
+        this function, one should call the :class:`Module` instance afterwards
+        instead of this since the former takes care of running the
+        registered hooks while the latter silently ignores them.
+    """
+    raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
+
+
+
+def get_streaming_state(self) ‑> Dict[str, torch.Tensor] +
+
+

Return the streaming state, including that of sub-modules.

+
+ +Expand source code + +
def get_streaming_state(self) -> State:
+    """Return the streaming state, including that of sub-modules.
+    """
+    state: State = {}
+
+    def _add(name: str, module: StreamingModule):
+        if name:
+            name += "."
+        for key, value in module._streaming_state.items():
+            state[name + key] = value
+
+    self._apply_named_streaming(_add)
+    return state
+
+
+
+def reset_streaming(self) +
+
+

Reset the streaming state.

+
+ +Expand source code + +
def reset_streaming(self):
+    """Reset the streaming state.
+    """
+    def _reset(name: str, module: StreamingModule):
+        module._streaming_state.clear()
+
+    self._apply_named_streaming(_reset)
+
+
+
+def set_streaming_state(self, state: Dict[str, torch.Tensor]) +
+
+

Set the streaming state, including that of sub-modules.

+
+ +Expand source code + +
def set_streaming_state(self, state: State):
+    """Set the streaming state, including that of sub-modules.
+    """
+    state = dict(state)
+
+    def _set(name: str, module: StreamingModule):
+        if name:
+            name += "."
+        module._streaming_state.clear()
+        for key, value in list(state.items()):
+            # complexity is not ideal here, but probably fine.
+            if key.startswith(name):
+                local_key = key[len(name):]
+                if '.' not in local_key:
+                    module._streaming_state[local_key] = value
+                    del state[key]
+
+    self._apply_named_streaming(_set)
+    assert len(state) == 0, list(state.keys())
+
+
+
+def streaming(self) +
+
+

Context manager to enter streaming mode. Reset streaming state on exit.

+
+ +Expand source code + +
@contextmanager
+def streaming(self):
+    """Context manager to enter streaming mode. Reset streaming state on exit.
+    """
+    self._set_streaming(True)
+    try:
+        yield
+    finally:
+        self._set_streaming(False)
+        self.reset_streaming()
+
+
+
+
+
+class StreamingSequential +
+
+

A streaming compatible alternative of nn.Sequential.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class StreamingSequential(StreamingModule, nn.Sequential):
+    """A streaming compatible alternative of `nn.Sequential`.
+    """
+    def flush(self, x: tp.Optional[torch.Tensor] = None):
+        for module in self:
+            if isinstance(module, StreamingModule):
+                x = module.flush(x)
+            elif x is not None:
+                x = module(x)
+        return x
+
+

Ancestors

+
    +
  • StreamingModule
  • +
  • torch.nn.modules.container.Sequential
  • +
  • torch.nn.modules.module.Module
  • +
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/modules/transformer.html b/docs/audiocraft/modules/transformer.html new file mode 100644 index 00000000..d6a79793 --- /dev/null +++ b/docs/audiocraft/modules/transformer.html @@ -0,0 +1,2012 @@ + + + + + + +audiocraft.modules.transformer API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.modules.transformer

+
+
+

Transformer model, with streaming support, xformer attention support +and easy causal attention with a potentially finite receptive field.

+

See StreamingTransformer for more information.

+

Unlike regular PyTorch Transformer, we make the hard choice that batches are first.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Transformer model, with streaming support, xformer attention support
+and easy causal attention with a potentially finite receptive field.
+
+See `StreamingTransformer` for more information.
+
+Unlike regular PyTorch Transformer, we make the hard choice that batches are first.
+"""
+
+import typing as tp
+
+from einops import rearrange
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.utils.checkpoint import checkpoint as torch_checkpoint
+from xformers import ops
+
+from .rope import RotaryEmbedding
+from .streaming import StreamingModule
+
+_efficient_attention_backend: str = 'torch'
+
+
+def set_efficient_attention_backend(backend: str = 'torch'):
+    # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
+    global _efficient_attention_backend
+    assert _efficient_attention_backend in ['xformers', 'torch']
+    _efficient_attention_backend = backend
+
+
+def _get_attention_time_dimension() -> int:
+    if _efficient_attention_backend == 'torch':
+        return 2
+    else:
+        return 1
+
+
+def _is_profiled() -> bool:
+    # Return true if we are currently running with a xformers profiler activated.
+    try:
+        from xformers.profiler import profiler
+    except ImportError:
+        return False
+    return profiler._Profiler._CURRENT_PROFILER is not None
+
+
+def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
+    """Create normalization module for transformer encoder layer.
+
+    Args:
+        norm_type (str): Normalization method.
+        dim (int): Dimension of the normalized layer.
+        **kwargs (dict): Additional parameters for normalization layer.
+    Returns:
+        nn.Module: Normalization module.
+    """
+    if norm_type == 'layer_norm':
+        return nn.LayerNorm(dim, eps=1e-5, **kwargs)
+    else:
+        raise ValueError(f"Unknown norm type: {norm_type}")
+
+
+def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
+                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    """Create sinusoidal positional embedding, with shape `[B, T, C]`.
+
+    Args:
+        positions (torch.Tensor): LongTensor of positions.
+        dim (int): Dimension of the embedding.
+        max_period (float): Maximum period of the cosine/sine functions.
+        dtype (torch.dtype or str): dtype to use to generate the embedding.
+    Returns:
+        torch.Tensor: Sinusoidal positional embedding.
+    """
+    # We aim for BTC format
+    assert dim % 2 == 0
+    half_dim = dim // 2
+    positions = positions.to(dtype)
+    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
+    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)  # avoid sync point
+    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
+    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
+
+
+def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers"""
+    if n_rep == 1:
+        return x
+    if _efficient_attention_backend == 'torch':
+        bs, n_kv_heads, slen, head_dim = x.shape
+        return (
+            x[:, :, None, :, :]
+            .expand(bs, n_kv_heads, n_rep, slen, head_dim)
+            .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
+        )
+    else:
+        bs, slen, n_kv_heads, head_dim = x.shape
+        return (
+            x[:, :, :, None, :]
+            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+            .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+        )
+
+
+class LayerScale(nn.Module):
+    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
+    This rescales diagonaly the residual outputs close to 0, with a learnt scale.
+
+    Args:
+        channels (int): Number of channels.
+        init (float): Initial scale.
+        channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
+        device (torch.device or None): Device on which to initialize the module.
+        dtype (torch.dtype or None): dtype to use to initialize the module.
+    """
+    def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
+                 device=None, dtype=None):
+        super().__init__()
+        self.channel_last = channel_last
+        self.scale = nn.Parameter(
+            torch.full((channels,), init,
+                       requires_grad=True, device=device, dtype=dtype))
+
+    def forward(self, x: torch.Tensor):
+        if self.channel_last:
+            return self.scale * x
+        else:
+            return self.scale[:, None] * x
+
+
+class StreamingMultiheadAttention(StreamingModule):
+    """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.
+
+    Args:
+        embed_dim (int): Dimension to project to.
+        num_heads (int): Number of heads.
+        dropout (float): Dropout level.
+        bias (bool): Use bias in projections.
+        causal (bool): Causal mask applied automatically.
+        past_context (int or None): Receptive field for the causal mask, infinite if None.
+        custom (bool): Use custom MHA implementation, for testing / benchmarking.
+        memory_efficient (bool): Use xformers based memory efficient attention.
+        attention_as_float32 (bool): Perform the attention as float32
+            (especially important with memory_efficient as autocast won't do this automatically).
+        rope (`RotaryEmbedding` or None): Rope embedding to use.
+        cross_attention: Should be true when used as a cross attention.
+            All keys and values must be available at once, streaming is only for the queries.
+            Cannot be used with `causal` or `rope` (as it wouldn't make sens to
+            intepret the time steps in the keys relative to those in the queries).
+        safe_streaming (bool): Bug fix, will go away with xformers update.
+        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
+        kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
+            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
+        device (torch.device or None): Sevice on which to initialize.
+        dtype (torch.dtype or None): dtype to use.
+    """
+    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
+                 causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
+                 memory_efficient: bool = False, attention_as_float32: bool = False,
+                 rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
+                 safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
+                 device=None, dtype=None):
+        super().__init__()
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        if past_context is not None:
+            assert causal
+
+        self.embed_dim = embed_dim
+        self.causal = causal
+        self.past_context = past_context
+        self.memory_efficient = memory_efficient
+        self.attention_as_float32 = attention_as_float32
+        self.rope = rope
+        self.cross_attention = cross_attention
+        self.safe_streaming = safe_streaming
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.kv_repeat = kv_repeat
+        if cross_attention:
+            assert not causal, "Causal cannot work with cross attention."
+            assert rope is None, "Rope cannot work with cross attention."
+
+        if memory_efficient:
+            _verify_xformers_memory_efficient_compat()
+
+        self.custom = _is_custom(custom, memory_efficient)
+        if self.custom:
+            out_dim = embed_dim
+            assert num_heads % kv_repeat == 0
+            assert not cross_attention or kv_repeat == 1
+            num_kv = num_heads // kv_repeat
+            kv_dim = (embed_dim // num_heads) * num_kv
+            out_dim += 2 * kv_dim
+            in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
+            # We try to follow the default PyTorch MHA convention, to easily compare results.
+            self.in_proj_weight = in_proj.weight
+            self.in_proj_bias = in_proj.bias
+            if bias:
+                self.in_proj_bias.data.zero_()  # Following Pytorch convention
+            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
+            if bias:
+                self.out_proj.bias.data.zero_()
+        else:
+            assert not qk_layer_norm
+            assert kv_repeat == 1
+            self.mha = nn.MultiheadAttention(
+                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
+                **factory_kwargs)
+        self.qk_layer_norm = qk_layer_norm
+        if qk_layer_norm:
+            assert self.custom
+            assert kv_repeat == 1
+            ln_dim = embed_dim
+            self.q_layer_norm = nn.LayerNorm(ln_dim)
+            self.k_layer_norm = nn.LayerNorm(ln_dim)
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        if not self.custom:
+            # Support compat with regular MHA
+            keys = [n for n, _ in self.mha.named_parameters()]
+            for key in keys:
+                if prefix + key in state_dict:
+                    state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+    def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
+        # Return a causal mask, accounting for potentially stored past keys/values
+        # We actually return a bias for the attention score, as this has the same
+        # convention both in the builtin MHA in Pytorch, and Xformers functions.
+        time_dim = _get_attention_time_dimension()
+        if self.memory_efficient:
+            from xformers.ops import LowerTriangularMask
+            if current_steps == 1:
+                # If we only have one step, then we do not need a mask.
+                return None
+            elif 'past_keys' in self._streaming_state:
+                raise RuntimeError('Not supported at the moment')
+            else:
+                # Then we can safely use a lower triangular mask
+                return LowerTriangularMask()
+        if self._streaming_state:
+            past_keys = self._streaming_state['past_keys']
+            past_steps = past_keys.shape[time_dim]
+        else:
+            past_steps = 0
+
+        queries_pos = torch.arange(
+            past_steps, current_steps + past_steps, device=device).view(-1, 1)
+        keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
+        delta = queries_pos - keys_pos
+        valid = delta >= 0
+        if self.past_context is not None:
+            valid &= (delta <= self.past_context)
+        return torch.where(
+            valid,
+            torch.zeros([], device=device, dtype=dtype),
+            torch.full([], float('-inf'), device=device, dtype=dtype))
+
+    def _complete_kv(self, k, v):
+        time_dim = _get_attention_time_dimension()
+        if self.cross_attention:
+            # With cross attention we assume all keys and values
+            # are already available, and streaming is with respect
+            # to the queries only.
+            return k, v
+        # Complete the key/value pair using the streaming state.
+        if self._streaming_state:
+            pk = self._streaming_state['past_keys']
+            nk = torch.cat([pk, k], dim=time_dim)
+            if v is k:
+                nv = nk
+            else:
+                pv = self._streaming_state['past_values']
+                nv = torch.cat([pv, v], dim=time_dim)
+        else:
+            nk = k
+            nv = v
+
+        assert nk.shape[time_dim] == nv.shape[time_dim]
+        offset = 0
+        if self.past_context is not None:
+            offset = max(0, nk.shape[time_dim] - self.past_context)
+        if self._is_streaming:
+            self._streaming_state['past_keys'] = nk[:, offset:]
+            if v is not k:
+                self._streaming_state['past_values'] = nv[:, offset:]
+            if 'offset' in self._streaming_state:
+                self._streaming_state['offset'] += offset
+            else:
+                self._streaming_state['offset'] = torch.tensor(0)
+        return nk, nv
+
+    def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
+        # TODO: fix and verify layout.
+        assert _efficient_attention_backend == 'xformers', 'Rope not supported with torch attn.'
+        # Apply rope embeddings to query and key tensors.
+        assert self.rope is not None
+        if 'past_keys' in self._streaming_state:
+            past_keys_offset = self._streaming_state['past_keys'].shape[1]
+        else:
+            past_keys_offset = 0
+        if 'offset' in self._streaming_state:
+            past_context_offset = int(self._streaming_state['offset'].item())
+        else:
+            past_context_offset = 0
+        streaming_offset = past_context_offset + past_keys_offset
+        return self.rope.rotate_qk(query, key, start=streaming_offset)
+
+    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
+                key_padding_mask=None, need_weights=False, attn_mask=None,
+                average_attn_weights=True, is_causal=False):
+        assert attn_mask is None
+        assert not is_causal, ("new param added in torch 2.0.1 not supported, "
+                               "use the causal args in the constructor.")
+
+        time_dim = _get_attention_time_dimension()
+        if time_dim == 2:
+            layout = "b h t d"
+        else:
+            layout = "b t h d"
+        dtype = query.dtype
+        if self._is_streaming:
+            assert self.causal or self.cross_attention, \
+                "Streaming only available for causal or cross attention"
+
+        if self.causal:
+            # At the moment we specialize only for the self-attention case.
+            assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
+            assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
+            attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)
+
+        if self.custom:
+            # custom implementation
+            assert need_weights is False
+            assert key_padding_mask is None
+            if self.cross_attention:
+                # Different queries, keys, values, we have to spit manually the weights
+                # before applying the linear.
+                dim = self.in_proj_weight.shape[0] // 3
+                if self.in_proj_bias is None:
+                    bias_q, bias_k, bias_v = None, None, None
+                else:
+                    bias_q = self.in_proj_bias[:dim]
+                    bias_k = self.in_proj_bias[dim: 2 * dim]
+                    bias_v = self.in_proj_bias[2 * dim:]
+                q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
+                # todo: when streaming, we could actually save k, v and check the shape actually match.
+                k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
+                v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
+                if self.qk_layer_norm is True:
+                    q = self.q_layer_norm(q)
+                    k = self.k_layer_norm(k)
+                q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
+            else:
+                if not _is_profiled():
+                    # profiling breaks that propertysomehow.
+                    assert query is key, "specialized implementation"
+                    assert value is key, "specialized implementation"
+                projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
+                if self.kv_repeat == 1:
+                    if time_dim == 2:
+                        bound_layout = "b h p t d"
+                    else:
+                        bound_layout = "b t p h d"
+                    packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
+                    q, k, v = ops.unbind(packed, dim=2)
+                else:
+                    embed_dim = self.embed_dim
+                    per_head_dim = (embed_dim // self.num_heads)
+                    kv_heads = self.num_heads // self.kv_repeat
+                    q = projected[:, :, :embed_dim]
+                    start = embed_dim
+                    end = start + per_head_dim * kv_heads
+                    k = projected[:, :, start: end]
+                    v = projected[:, :, end:]
+                    q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
+                    k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
+                    v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)
+
+                if self.qk_layer_norm is True:
+                    assert self.kv_repeat == 1
+                    q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
+                    q = self.q_layer_norm(q)
+                    k = self.k_layer_norm(k)
+                    q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
+                if self.rope:
+                    q, k = self._apply_rope(q, k)
+                k, v = self._complete_kv(k, v)
+                if self.kv_repeat > 1:
+                    k = expand_repeated_kv(k, self.kv_repeat)
+                    v = expand_repeated_kv(v, self.kv_repeat)
+            if self.attention_as_float32:
+                q, k, v = [x.float() for x in [q, k, v]]
+            if self.memory_efficient:
+                p = self.dropout if self.training else 0
+                if _efficient_attention_backend == 'torch':
+                    x = torch.nn.functional.scaled_dot_product_attention(
+                        q, k, v, is_causal=attn_mask is not None, dropout_p=p)
+                else:
+                    x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
+            else:
+                # We include the dot product as float32, for consistency
+                # with the other implementations that include that step
+                # as part of the attention. Note that when using `autocast`,
+                # the einsums would be done as bfloat16, but the softmax
+                # would be done as bfloat16, so `attention_as_float32` will
+                # extend a bit the range of operations done in float32,
+                # although this should make no difference.
+                q = q / q.shape[-1] ** 0.5
+                key_layout = layout.replace('t', 'k')
+                query_layout = layout
+                if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
+                    with torch.autocast(device_type=q.device.type, dtype=torch.float32):
+                        pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
+                else:
+                    pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
+                if attn_mask is not None:
+                    pre_w = pre_w + attn_mask
+                w = torch.softmax(pre_w, dim=-1)
+                w = F.dropout(w, self.dropout, training=self.training).to(v)
+                # Key and value have the same format.
+                x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
+            x = x.to(dtype)
+            x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
+            x = self.out_proj(x)
+        else:
+            key, value = self._complete_kv(key, value)
+            if self.attention_as_float32:
+                query, key, value = [x.float() for x in [query, key, value]]
+            x, _ = self.mha(
+                query, key, value, key_padding_mask,
+                need_weights, attn_mask, average_attn_weights)
+            x = x.to(dtype)
+
+        return x, None
+
+
+class StreamingTransformerLayer(nn.TransformerEncoderLayer):
+    """TransformerLayer with Streaming / Causal support.
+    This also integrates cross_attention, when passing `cross_attention=True`,
+    rather than having two separate classes like in PyTorch.
+
+    Args:
+        d_model (int): Dimension of the data.
+        num_heads (int): Number of heads.
+        dim_feedforward (int): Intermediate dimension of FF module.
+        dropout (float): Dropout both for MHA and FF.
+        bias_ff (bool): Use bias for FF.
+        bias_attn (bool): Use bias for MHA.
+        causal (bool): Causal mask applied automatically.
+        past_context (int or None): Receptive field for the causal mask, infinite if None.
+        custom (bool): Use custom MHA implementation, for testing / benchmarking.
+        memory_efficient (bool): Use xformers based memory efficient attention.
+        attention_as_float32 (bool): Perform the attention as float32
+            (especially important with memory_efficient as autocast won't do this automatically).
+        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
+        qk_layer_norm_cross (bool): Same for the cross attention.
+        cross_attention (bool): If True, expect to get secondary input for cross-attention.
+            Cross attention will use the default MHA, as it typically won't require
+            special treatment.
+        layer_scale (float or None): If not None, LayerScale will be used with
+            the given value as initial scale.
+        rope (`RotaryEmbedding` or None): Rope embedding to use.
+        attention_dropout (float or None): If not None, separate the value of the dimension dropout
+            in FFN and of the attention dropout.
+        kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
+            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
+        device (torch.device or None): Device on which to initialize.
+        dtype (torch.dtype or None): dtype to use.
+        **kwargs: See `nn.TransformerEncoderLayer`.
+    """
+    def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
+                 bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
+                 past_context: tp.Optional[int] = None, custom: bool = False,
+                 memory_efficient: bool = False, attention_as_float32: bool = False,
+                 qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
+                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
+                 rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
+                 kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
+        super().__init__(d_model, num_heads, dim_feedforward, dropout,
+                         device=device, dtype=dtype, batch_first=True, **kwargs)
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        # Redefine self_attn to our streaming multi-head attention
+        attn_kwargs: tp.Dict[str, tp.Any] = {
+            'embed_dim': d_model,
+            'num_heads': num_heads,
+            'dropout': dropout if attention_dropout is None else attention_dropout,
+            'bias': bias_attn,
+            'custom': custom,
+            'memory_efficient': memory_efficient,
+            'attention_as_float32': attention_as_float32,
+        }
+        self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
+            causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
+            kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs)  # type: ignore
+        # Redefine feedforward layers to expose bias parameter
+        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
+        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
+
+        self.layer_scale_1: nn.Module
+        self.layer_scale_2: nn.Module
+        if layer_scale is None:
+            self.layer_scale_1 = nn.Identity()
+            self.layer_scale_2 = nn.Identity()
+        else:
+            self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)
+            self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)
+
+        self.cross_attention: tp.Optional[nn.Module] = None
+        if cross_attention:
+            self.cross_attention = StreamingMultiheadAttention(
+                cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
+                **attn_kwargs, **factory_kwargs)
+            # Norm and dropout
+            self.dropout_cross = nn.Dropout(dropout)
+            # eps value matching that used in PyTorch reference implementation.
+            self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
+            self.layer_scale_cross: nn.Module
+            if layer_scale is None:
+                self.layer_scale_cross = nn.Identity()
+            else:
+                self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs)
+        self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
+        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
+
+    def _cross_attention_block(self, src: torch.Tensor,
+                               cross_attention_src: torch.Tensor) -> torch.Tensor:
+        assert self.cross_attention is not None
+        # queries are from src, keys and values from cross_attention_src.
+        x = self.cross_attention(
+            src, cross_attention_src, cross_attention_src, need_weights=False)[0]
+        return self.dropout_cross(x)  # type: ignore
+
+    def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None,  # type: ignore
+                src_key_padding_mask: tp.Optional[torch.Tensor] = None,
+                cross_attention_src: tp.Optional[torch.Tensor] = None):
+        if self.cross_attention is None:
+            assert cross_attention_src is None
+        else:
+            assert cross_attention_src is not None
+        x = src
+        if self.norm_first:
+            x = x + self.layer_scale_1(
+                self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
+            if cross_attention_src is not None:
+                x = x + self.layer_scale_cross(
+                    self._cross_attention_block(
+                        self.norm_cross(x), cross_attention_src))
+            x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
+        else:
+            x = self.norm1(x + self.layer_scale_1(
+                self._sa_block(x, src_mask, src_key_padding_mask)))
+            if cross_attention_src is not None:
+                x = self.norm_cross(
+                    x + self.layer_scale_cross(
+                        self._cross_attention_block(src, cross_attention_src)))
+            x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
+        return x
+
+
+class StreamingTransformer(StreamingModule):
+    """Transformer with Streaming / Causal support.
+
+    Args:
+        d_model (int): Dimension of the data.
+        num_heads (int): Number of heads.
+        dim_feedforward (int): Intermediate dimension of FF module.
+        dropout (float): Dropout both for MHA and FF.
+        bias_ff (bool): Use bias for FF.
+        bias_attn (bool): Use bias for MHA.
+        causal (bool): Causal mask applied automatically.
+        past_context (int or None): Receptive field for the causal mask, infinite if None.
+        custom (bool): Use custom MHA implementation, for testing / benchmarking.
+        memory_efficient (bool): Use xformers based memory efficient attention.
+        attention_as_float32 (bool): Perform the attention as float32
+            (especially important with memory_efficient as autocast won't do this automatically).
+        cross_attention (bool): If True, expect to get secondary input for cross-attention.
+        layer_scale (float or None): If not None, LayerScale will be used
+            with the given value as initial scale.
+        positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
+        max_period (float): Maximum period of the time embedding.
+        positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
+        xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
+        lr (float or None): learning rate override through the `make_optim_group` API.
+        weight_decay (float or None): Weight_decay override through the `make_optim_group` API.
+        layer_class: (subclass of `StreamingTransformerLayer): class to use
+            to initialize the layers, allowing further customization outside of Audiocraft.
+        checkpointing (str): Checkpointing strategy to reduce memory usage.
+            No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
+            if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
+            minimal memory usage, but maximal runtime). Finally, `xformers_default` provide
+            a policy for opting-out some operations of the checkpointing like
+            linear layers and attention, providing a middle ground between speed and memory.
+        device (torch.device or None): Device on which to initialize.
+        dtype (torch.dtype or None): dtype to use.
+        **kwargs: See `nn.TransformerEncoderLayer`.
+    """
+    def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
+                 dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
+                 causal: bool = False, past_context: tp.Optional[int] = None,
+                 custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
+                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
+                 positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
+                 xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
+                 layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
+                 checkpointing: str = 'none', device=None, dtype=None, **kwargs):
+        super().__init__()
+        assert d_model % num_heads == 0
+
+        self.positional_embedding = positional_embedding
+        self.max_period = max_period
+        self.positional_scale = positional_scale
+        self.weight_decay = weight_decay
+        self.lr = lr
+
+        assert positional_embedding in ['sin', 'rope', 'sin_rope']
+        self.rope: tp.Optional[RotaryEmbedding] = None
+        if self.positional_embedding in ['rope', 'sin_rope']:
+            assert _is_custom(custom, memory_efficient)
+            self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
+                                        xpos=xpos, scale=positional_scale, device=device)
+
+        self.checkpointing = checkpointing
+
+        assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
+        if self.checkpointing.startswith('xformers'):
+            _verify_xformers_internal_compat()
+
+        self.layers = nn.ModuleList()
+        for idx in range(num_layers):
+            self.layers.append(
+                layer_class(
+                    d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
+                    dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
+                    causal=causal, past_context=past_context, custom=custom,
+                    memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
+                    cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
+                    device=device, dtype=dtype, **kwargs))
+
+        if self.checkpointing != 'none':
+            for layer in self.layers:
+                # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
+                # backward hook inside of FSDP...
+                layer._magma_checkpointed = True  # type: ignore
+                assert layer.layer_drop == 0., "Need further checking"  # type: ignore
+
+    def _apply_layer(self, layer, *args, **kwargs):
+        method = self.checkpointing
+        if method == 'none':
+            return layer(*args, **kwargs)
+        elif method == 'torch':
+            return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
+        elif method.startswith('xformers'):
+            from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
+            if method == 'xformers_default':
+                # those operations will be saved, and not recomputed.
+                # According to Francisco we can get smarter policies but this is a good start.
+                allow_list = [
+                    "xformers.efficient_attention_forward_cutlass.default",
+                    "xformers_flash.flash_fwd.default",
+                    "aten.addmm.default",
+                    "aten.mm.default",
+                ]
+            elif method == 'xformers_mm':
+                # those operations will be saved, and not recomputed.
+                # According to Francisco we can get smarter policies but this is a good start.
+                allow_list = [
+                    "aten.addmm.default",
+                    "aten.mm.default",
+                ]
+            else:
+                raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
+            policy_fn = _get_default_policy(allow_list)
+            return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
+        else:
+            raise ValueError(f"Checkpointing method {method} is unknown.")
+
+    def forward(self, x: torch.Tensor, *args, **kwargs):
+        B, T, C = x.shape
+
+        if 'offsets' in self._streaming_state:
+            offsets = self._streaming_state['offsets']
+        else:
+            offsets = torch.zeros(B, dtype=torch.long, device=x.device)
+
+        if self.positional_embedding in ['sin', 'sin_rope']:
+            positions = torch.arange(T, device=x.device).view(1, -1, 1)
+            positions = positions + offsets.view(-1, 1, 1)
+            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
+            x = x + self.positional_scale * pos_emb
+
+        for layer in self.layers:
+            x = self._apply_layer(layer, x, *args, **kwargs)
+
+        if self._is_streaming:
+            self._streaming_state['offsets'] = offsets + T
+
+        return x
+
+    def make_optim_group(self):
+        group = {"params": list(self.parameters())}
+        if self.lr is not None:
+            group["lr"] = self.lr
+        if self.weight_decay is not None:
+            group["weight_decay"] = self.weight_decay
+        return group
+
+
+# special attention attention related function
+
+def _verify_xformers_memory_efficient_compat():
+    try:
+        from xformers.ops import memory_efficient_attention, LowerTriangularMask  # noqa
+    except ImportError:
+        raise ImportError(
+            "xformers is not installed. Please install it and try again.\n"
+            "To install on AWS and Azure, run \n"
+            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
+            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
+            "To install on FAIR Cluster, run \n"
+            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
+            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
+
+
+def _verify_xformers_internal_compat():
+    try:
+        from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy  # noqa
+    except ImportError:
+        raise ImportError(
+            "Francisco's fairinternal xformers is not installed. Please install it and try again.\n"
+            "To install on AWS and Azure, run \n"
+            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
+            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
+            "To install on FAIR Cluster, run \n"
+            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
+            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
+
+
+def _is_custom(custom: bool, memory_efficient: bool):
+    return custom or memory_efficient
+
+
+
+
+
+
+
+

Functions

+
+
+def create_norm_fn(norm_type: str, dim: int, **kwargs) ‑> torch.nn.modules.module.Module +
+
+

Create normalization module for transformer encoder layer.

+

Args

+
+
norm_type : str
+
Normalization method.
+
dim : int
+
Dimension of the normalized layer.
+
**kwargs : dict
+
Additional parameters for normalization layer.
+
+

Returns

+
+
nn.Module
+
Normalization module.
+
+
+ +Expand source code + +
def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
+    """Create normalization module for transformer encoder layer.
+
+    Args:
+        norm_type (str): Normalization method.
+        dim (int): Dimension of the normalized layer.
+        **kwargs (dict): Additional parameters for normalization layer.
+    Returns:
+        nn.Module: Normalization module.
+    """
+    if norm_type == 'layer_norm':
+        return nn.LayerNorm(dim, eps=1e-5, **kwargs)
+    else:
+        raise ValueError(f"Unknown norm type: {norm_type}")
+
+
+
+def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000, dtype: torch.dtype = torch.float32) ‑> torch.Tensor +
+
+

Create sinusoidal positional embedding, with shape [B, T, C].

+

Args

+
+
positions : torch.Tensor
+
LongTensor of positions.
+
dim : int
+
Dimension of the embedding.
+
max_period : float
+
Maximum period of the cosine/sine functions.
+
dtype : torch.dtype or str
+
dtype to use to generate the embedding.
+
+

Returns

+
+
torch.Tensor
+
Sinusoidal positional embedding.
+
+
+ +Expand source code + +
def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
+                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
+    """Create sinusoidal positional embedding, with shape `[B, T, C]`.
+
+    Args:
+        positions (torch.Tensor): LongTensor of positions.
+        dim (int): Dimension of the embedding.
+        max_period (float): Maximum period of the cosine/sine functions.
+        dtype (torch.dtype or str): dtype to use to generate the embedding.
+    Returns:
+        torch.Tensor: Sinusoidal positional embedding.
+    """
+    # We aim for BTC format
+    assert dim % 2 == 0
+    half_dim = dim // 2
+    positions = positions.to(dtype)
+    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
+    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)  # avoid sync point
+    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
+    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
+
+
+
+def expand_repeated_kv(x: torch.Tensor, n_rep: int) ‑> torch.Tensor +
+
+

torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers

+
+ +Expand source code + +
def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers"""
+    if n_rep == 1:
+        return x
+    if _efficient_attention_backend == 'torch':
+        bs, n_kv_heads, slen, head_dim = x.shape
+        return (
+            x[:, :, None, :, :]
+            .expand(bs, n_kv_heads, n_rep, slen, head_dim)
+            .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
+        )
+    else:
+        bs, slen, n_kv_heads, head_dim = x.shape
+        return (
+            x[:, :, :, None, :]
+            .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+            .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+        )
+
+
+
+def set_efficient_attention_backend(backend: str = 'torch') +
+
+
+
+ +Expand source code + +
def set_efficient_attention_backend(backend: str = 'torch'):
+    # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
+    global _efficient_attention_backend
+    assert _efficient_attention_backend in ['xformers', 'torch']
+    _efficient_attention_backend = backend
+
+
+
+
+
+

Classes

+
+
+class LayerScale +(channels: int, init: float = 0.0001, channel_last: bool = True, device=None, dtype=None) +
+
+

Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). +This rescales diagonaly the residual outputs close to 0, with a learnt scale.

+

Args

+
+
channels : int
+
Number of channels.
+
init : float
+
Initial scale.
+
channel_last : bool
+
If True, expect [*, C] shaped tensors, otherwise, [*, C, T].
+
device : torch.device or None
+
Device on which to initialize the module.
+
dtype : torch.dtype or None
+
dtype to use to initialize the module.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class LayerScale(nn.Module):
+    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
+    This rescales diagonaly the residual outputs close to 0, with a learnt scale.
+
+    Args:
+        channels (int): Number of channels.
+        init (float): Initial scale.
+        channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
+        device (torch.device or None): Device on which to initialize the module.
+        dtype (torch.dtype or None): dtype to use to initialize the module.
+    """
+    def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
+                 device=None, dtype=None):
+        super().__init__()
+        self.channel_last = channel_last
+        self.scale = nn.Parameter(
+            torch.full((channels,), init,
+                       requires_grad=True, device=device, dtype=dtype))
+
+    def forward(self, x: torch.Tensor):
+        if self.channel_last:
+            return self.scale * x
+        else:
+            return self.scale[:, None] * x
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, x: torch.Tensor) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x: torch.Tensor):
+    if self.channel_last:
+        return self.scale * x
+    else:
+        return self.scale[:, None] * x
+
+
+
+
+
+class StreamingMultiheadAttention +(embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True, causal: bool = False, past_context: Optional[int] = None, custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False, rope: Optional[RotaryEmbedding] = None, cross_attention: bool = False, safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1, device=None, dtype=None) +
+
+

Similar to nn.MultiheadAttention but with support for streaming, causal evaluation.

+

Args

+
+
embed_dim : int
+
Dimension to project to.
+
num_heads : int
+
Number of heads.
+
dropout : float
+
Dropout level.
+
bias : bool
+
Use bias in projections.
+
causal : bool
+
Causal mask applied automatically.
+
past_context : int or None
+
Receptive field for the causal mask, infinite if None.
+
custom : bool
+
Use custom MHA implementation, for testing / benchmarking.
+
memory_efficient : bool
+
Use xformers based memory efficient attention.
+
attention_as_float32 : bool
+
Perform the attention as float32 +(especially important with memory_efficient as autocast won't do this automatically).
+
rope (RotaryEmbedding or None): Rope embedding to use.
+
cross_attention
+
Should be true when used as a cross attention. +All keys and values must be available at once, streaming is only for the queries. +Cannot be used with causal or rope (as it wouldn't make sens to +intepret the time steps in the keys relative to those in the queries).
+
safe_streaming : bool
+
Bug fix, will go away with xformers update.
+
qk_layer_norm : bool
+
Layer normalization applied to queries and keys before dot product.
+
kv_repeat : int
+
If > 1, will repeat keys and queries multiple times (need to divide num_heads). +This will lead to faster decoding time on A100 or other GPUs with tensorcore.
+
device : torch.device or None
+
Sevice on which to initialize.
+
dtype : torch.dtype or None
+
dtype to use.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class StreamingMultiheadAttention(StreamingModule):
+    """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.
+
+    Args:
+        embed_dim (int): Dimension to project to.
+        num_heads (int): Number of heads.
+        dropout (float): Dropout level.
+        bias (bool): Use bias in projections.
+        causal (bool): Causal mask applied automatically.
+        past_context (int or None): Receptive field for the causal mask, infinite if None.
+        custom (bool): Use custom MHA implementation, for testing / benchmarking.
+        memory_efficient (bool): Use xformers based memory efficient attention.
+        attention_as_float32 (bool): Perform the attention as float32
+            (especially important with memory_efficient as autocast won't do this automatically).
+        rope (`RotaryEmbedding` or None): Rope embedding to use.
+        cross_attention: Should be true when used as a cross attention.
+            All keys and values must be available at once, streaming is only for the queries.
+            Cannot be used with `causal` or `rope` (as it wouldn't make sens to
+            intepret the time steps in the keys relative to those in the queries).
+        safe_streaming (bool): Bug fix, will go away with xformers update.
+        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
+        kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
+            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
+        device (torch.device or None): Sevice on which to initialize.
+        dtype (torch.dtype or None): dtype to use.
+    """
+    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
+                 causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
+                 memory_efficient: bool = False, attention_as_float32: bool = False,
+                 rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
+                 safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
+                 device=None, dtype=None):
+        super().__init__()
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        if past_context is not None:
+            assert causal
+
+        self.embed_dim = embed_dim
+        self.causal = causal
+        self.past_context = past_context
+        self.memory_efficient = memory_efficient
+        self.attention_as_float32 = attention_as_float32
+        self.rope = rope
+        self.cross_attention = cross_attention
+        self.safe_streaming = safe_streaming
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.kv_repeat = kv_repeat
+        if cross_attention:
+            assert not causal, "Causal cannot work with cross attention."
+            assert rope is None, "Rope cannot work with cross attention."
+
+        if memory_efficient:
+            _verify_xformers_memory_efficient_compat()
+
+        self.custom = _is_custom(custom, memory_efficient)
+        if self.custom:
+            out_dim = embed_dim
+            assert num_heads % kv_repeat == 0
+            assert not cross_attention or kv_repeat == 1
+            num_kv = num_heads // kv_repeat
+            kv_dim = (embed_dim // num_heads) * num_kv
+            out_dim += 2 * kv_dim
+            in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
+            # We try to follow the default PyTorch MHA convention, to easily compare results.
+            self.in_proj_weight = in_proj.weight
+            self.in_proj_bias = in_proj.bias
+            if bias:
+                self.in_proj_bias.data.zero_()  # Following Pytorch convention
+            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
+            if bias:
+                self.out_proj.bias.data.zero_()
+        else:
+            assert not qk_layer_norm
+            assert kv_repeat == 1
+            self.mha = nn.MultiheadAttention(
+                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
+                **factory_kwargs)
+        self.qk_layer_norm = qk_layer_norm
+        if qk_layer_norm:
+            assert self.custom
+            assert kv_repeat == 1
+            ln_dim = embed_dim
+            self.q_layer_norm = nn.LayerNorm(ln_dim)
+            self.k_layer_norm = nn.LayerNorm(ln_dim)
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        if not self.custom:
+            # Support compat with regular MHA
+            keys = [n for n, _ in self.mha.named_parameters()]
+            for key in keys:
+                if prefix + key in state_dict:
+                    state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+    def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
+        # Return a causal mask, accounting for potentially stored past keys/values
+        # We actually return a bias for the attention score, as this has the same
+        # convention both in the builtin MHA in Pytorch, and Xformers functions.
+        time_dim = _get_attention_time_dimension()
+        if self.memory_efficient:
+            from xformers.ops import LowerTriangularMask
+            if current_steps == 1:
+                # If we only have one step, then we do not need a mask.
+                return None
+            elif 'past_keys' in self._streaming_state:
+                raise RuntimeError('Not supported at the moment')
+            else:
+                # Then we can safely use a lower triangular mask
+                return LowerTriangularMask()
+        if self._streaming_state:
+            past_keys = self._streaming_state['past_keys']
+            past_steps = past_keys.shape[time_dim]
+        else:
+            past_steps = 0
+
+        queries_pos = torch.arange(
+            past_steps, current_steps + past_steps, device=device).view(-1, 1)
+        keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
+        delta = queries_pos - keys_pos
+        valid = delta >= 0
+        if self.past_context is not None:
+            valid &= (delta <= self.past_context)
+        return torch.where(
+            valid,
+            torch.zeros([], device=device, dtype=dtype),
+            torch.full([], float('-inf'), device=device, dtype=dtype))
+
+    def _complete_kv(self, k, v):
+        time_dim = _get_attention_time_dimension()
+        if self.cross_attention:
+            # With cross attention we assume all keys and values
+            # are already available, and streaming is with respect
+            # to the queries only.
+            return k, v
+        # Complete the key/value pair using the streaming state.
+        if self._streaming_state:
+            pk = self._streaming_state['past_keys']
+            nk = torch.cat([pk, k], dim=time_dim)
+            if v is k:
+                nv = nk
+            else:
+                pv = self._streaming_state['past_values']
+                nv = torch.cat([pv, v], dim=time_dim)
+        else:
+            nk = k
+            nv = v
+
+        assert nk.shape[time_dim] == nv.shape[time_dim]
+        offset = 0
+        if self.past_context is not None:
+            offset = max(0, nk.shape[time_dim] - self.past_context)
+        if self._is_streaming:
+            self._streaming_state['past_keys'] = nk[:, offset:]
+            if v is not k:
+                self._streaming_state['past_values'] = nv[:, offset:]
+            if 'offset' in self._streaming_state:
+                self._streaming_state['offset'] += offset
+            else:
+                self._streaming_state['offset'] = torch.tensor(0)
+        return nk, nv
+
+    def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
+        # TODO: fix and verify layout.
+        assert _efficient_attention_backend == 'xformers', 'Rope not supported with torch attn.'
+        # Apply rope embeddings to query and key tensors.
+        assert self.rope is not None
+        if 'past_keys' in self._streaming_state:
+            past_keys_offset = self._streaming_state['past_keys'].shape[1]
+        else:
+            past_keys_offset = 0
+        if 'offset' in self._streaming_state:
+            past_context_offset = int(self._streaming_state['offset'].item())
+        else:
+            past_context_offset = 0
+        streaming_offset = past_context_offset + past_keys_offset
+        return self.rope.rotate_qk(query, key, start=streaming_offset)
+
+    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
+                key_padding_mask=None, need_weights=False, attn_mask=None,
+                average_attn_weights=True, is_causal=False):
+        assert attn_mask is None
+        assert not is_causal, ("new param added in torch 2.0.1 not supported, "
+                               "use the causal args in the constructor.")
+
+        time_dim = _get_attention_time_dimension()
+        if time_dim == 2:
+            layout = "b h t d"
+        else:
+            layout = "b t h d"
+        dtype = query.dtype
+        if self._is_streaming:
+            assert self.causal or self.cross_attention, \
+                "Streaming only available for causal or cross attention"
+
+        if self.causal:
+            # At the moment we specialize only for the self-attention case.
+            assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
+            assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
+            attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)
+
+        if self.custom:
+            # custom implementation
+            assert need_weights is False
+            assert key_padding_mask is None
+            if self.cross_attention:
+                # Different queries, keys, values, we have to spit manually the weights
+                # before applying the linear.
+                dim = self.in_proj_weight.shape[0] // 3
+                if self.in_proj_bias is None:
+                    bias_q, bias_k, bias_v = None, None, None
+                else:
+                    bias_q = self.in_proj_bias[:dim]
+                    bias_k = self.in_proj_bias[dim: 2 * dim]
+                    bias_v = self.in_proj_bias[2 * dim:]
+                q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
+                # todo: when streaming, we could actually save k, v and check the shape actually match.
+                k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
+                v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
+                if self.qk_layer_norm is True:
+                    q = self.q_layer_norm(q)
+                    k = self.k_layer_norm(k)
+                q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
+            else:
+                if not _is_profiled():
+                    # profiling breaks that propertysomehow.
+                    assert query is key, "specialized implementation"
+                    assert value is key, "specialized implementation"
+                projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
+                if self.kv_repeat == 1:
+                    if time_dim == 2:
+                        bound_layout = "b h p t d"
+                    else:
+                        bound_layout = "b t p h d"
+                    packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
+                    q, k, v = ops.unbind(packed, dim=2)
+                else:
+                    embed_dim = self.embed_dim
+                    per_head_dim = (embed_dim // self.num_heads)
+                    kv_heads = self.num_heads // self.kv_repeat
+                    q = projected[:, :, :embed_dim]
+                    start = embed_dim
+                    end = start + per_head_dim * kv_heads
+                    k = projected[:, :, start: end]
+                    v = projected[:, :, end:]
+                    q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
+                    k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
+                    v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)
+
+                if self.qk_layer_norm is True:
+                    assert self.kv_repeat == 1
+                    q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
+                    q = self.q_layer_norm(q)
+                    k = self.k_layer_norm(k)
+                    q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
+                if self.rope:
+                    q, k = self._apply_rope(q, k)
+                k, v = self._complete_kv(k, v)
+                if self.kv_repeat > 1:
+                    k = expand_repeated_kv(k, self.kv_repeat)
+                    v = expand_repeated_kv(v, self.kv_repeat)
+            if self.attention_as_float32:
+                q, k, v = [x.float() for x in [q, k, v]]
+            if self.memory_efficient:
+                p = self.dropout if self.training else 0
+                if _efficient_attention_backend == 'torch':
+                    x = torch.nn.functional.scaled_dot_product_attention(
+                        q, k, v, is_causal=attn_mask is not None, dropout_p=p)
+                else:
+                    x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
+            else:
+                # We include the dot product as float32, for consistency
+                # with the other implementations that include that step
+                # as part of the attention. Note that when using `autocast`,
+                # the einsums would be done as bfloat16, but the softmax
+                # would be done as bfloat16, so `attention_as_float32` will
+                # extend a bit the range of operations done in float32,
+                # although this should make no difference.
+                q = q / q.shape[-1] ** 0.5
+                key_layout = layout.replace('t', 'k')
+                query_layout = layout
+                if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
+                    with torch.autocast(device_type=q.device.type, dtype=torch.float32):
+                        pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
+                else:
+                    pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
+                if attn_mask is not None:
+                    pre_w = pre_w + attn_mask
+                w = torch.softmax(pre_w, dim=-1)
+                w = F.dropout(w, self.dropout, training=self.training).to(v)
+                # Key and value have the same format.
+                x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
+            x = x.to(dtype)
+            x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
+            x = self.out_proj(x)
+        else:
+            key, value = self._complete_kv(key, value)
+            if self.attention_as_float32:
+                query, key, value = [x.float() for x in [query, key, value]]
+            x, _ = self.mha(
+                query, key, value, key_padding_mask,
+                need_weights, attn_mask, average_attn_weights)
+            x = x.to(dtype)
+
+        return x, None
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Inherited members

+ +
+
+class StreamingTransformer +(d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048, dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True, causal: bool = False, past_context: Optional[int] = None, custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False, cross_attention: bool = False, layer_scale: Optional[float] = None, positional_embedding: str = 'sin', max_period: float = 10000, positional_scale: float = 1.0, xpos: bool = False, lr: Optional[float] = None, weight_decay: Optional[float] = None, layer_class: Type[StreamingTransformerLayer] = audiocraft.modules.transformer.StreamingTransformerLayer, checkpointing: str = 'none', device=None, dtype=None, **kwargs) +
+
+

Transformer with Streaming / Causal support.

+

Args

+
+
d_model : int
+
Dimension of the data.
+
num_heads : int
+
Number of heads.
+
dim_feedforward : int
+
Intermediate dimension of FF module.
+
dropout : float
+
Dropout both for MHA and FF.
+
bias_ff : bool
+
Use bias for FF.
+
bias_attn : bool
+
Use bias for MHA.
+
causal : bool
+
Causal mask applied automatically.
+
past_context : int or None
+
Receptive field for the causal mask, infinite if None.
+
custom : bool
+
Use custom MHA implementation, for testing / benchmarking.
+
memory_efficient : bool
+
Use xformers based memory efficient attention.
+
attention_as_float32 : bool
+
Perform the attention as float32 +(especially important with memory_efficient as autocast won't do this automatically).
+
cross_attention : bool
+
If True, expect to get secondary input for cross-attention.
+
layer_scale : float or None
+
If not None, LayerScale will be used +with the given value as initial scale.
+
positional_embedding : str
+
Positional embedding strategy (sin, rope, or sin_rope).
+
max_period : float
+
Maximum period of the time embedding.
+
positional_scale : float
+
Scale of positional embedding, set to 0 to deactivate.
+
xpos : bool
+
Apply xpos exponential decay to positional embedding (rope only).
+
lr : float or None
+
learning rate override through the make_optim_group API.
+
weight_decay : float or None
+
Weight_decay override through the make_optim_group API.
+
layer_class
+
(subclass of `StreamingTransformerLayer): class to use +to initialize the layers, allowing further customization outside of Audiocraft.
+
checkpointing : str
+
Checkpointing strategy to reduce memory usage. +No checkpointing if set to 'none'. Per layer checkpointing using PyTorch +if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice, +minimal memory usage, but maximal runtime). Finally, xformers_default provide +a policy for opting-out some operations of the checkpointing like +linear layers and attention, providing a middle ground between speed and memory.
+
device : torch.device or None
+
Device on which to initialize.
+
dtype : torch.dtype or None
+
dtype to use.
+
**kwargs
+
See nn.TransformerEncoderLayer.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class StreamingTransformer(StreamingModule):
+    """Transformer with Streaming / Causal support.
+
+    Args:
+        d_model (int): Dimension of the data.
+        num_heads (int): Number of heads.
+        dim_feedforward (int): Intermediate dimension of FF module.
+        dropout (float): Dropout both for MHA and FF.
+        bias_ff (bool): Use bias for FF.
+        bias_attn (bool): Use bias for MHA.
+        causal (bool): Causal mask applied automatically.
+        past_context (int or None): Receptive field for the causal mask, infinite if None.
+        custom (bool): Use custom MHA implementation, for testing / benchmarking.
+        memory_efficient (bool): Use xformers based memory efficient attention.
+        attention_as_float32 (bool): Perform the attention as float32
+            (especially important with memory_efficient as autocast won't do this automatically).
+        cross_attention (bool): If True, expect to get secondary input for cross-attention.
+        layer_scale (float or None): If not None, LayerScale will be used
+            with the given value as initial scale.
+        positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
+        max_period (float): Maximum period of the time embedding.
+        positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
+        xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
+        lr (float or None): learning rate override through the `make_optim_group` API.
+        weight_decay (float or None): Weight_decay override through the `make_optim_group` API.
+        layer_class: (subclass of `StreamingTransformerLayer): class to use
+            to initialize the layers, allowing further customization outside of Audiocraft.
+        checkpointing (str): Checkpointing strategy to reduce memory usage.
+            No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
+            if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
+            minimal memory usage, but maximal runtime). Finally, `xformers_default` provide
+            a policy for opting-out some operations of the checkpointing like
+            linear layers and attention, providing a middle ground between speed and memory.
+        device (torch.device or None): Device on which to initialize.
+        dtype (torch.dtype or None): dtype to use.
+        **kwargs: See `nn.TransformerEncoderLayer`.
+    """
+    def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
+                 dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
+                 causal: bool = False, past_context: tp.Optional[int] = None,
+                 custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
+                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
+                 positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
+                 xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
+                 layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
+                 checkpointing: str = 'none', device=None, dtype=None, **kwargs):
+        super().__init__()
+        assert d_model % num_heads == 0
+
+        self.positional_embedding = positional_embedding
+        self.max_period = max_period
+        self.positional_scale = positional_scale
+        self.weight_decay = weight_decay
+        self.lr = lr
+
+        assert positional_embedding in ['sin', 'rope', 'sin_rope']
+        self.rope: tp.Optional[RotaryEmbedding] = None
+        if self.positional_embedding in ['rope', 'sin_rope']:
+            assert _is_custom(custom, memory_efficient)
+            self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
+                                        xpos=xpos, scale=positional_scale, device=device)
+
+        self.checkpointing = checkpointing
+
+        assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
+        if self.checkpointing.startswith('xformers'):
+            _verify_xformers_internal_compat()
+
+        self.layers = nn.ModuleList()
+        for idx in range(num_layers):
+            self.layers.append(
+                layer_class(
+                    d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
+                    dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
+                    causal=causal, past_context=past_context, custom=custom,
+                    memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
+                    cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
+                    device=device, dtype=dtype, **kwargs))
+
+        if self.checkpointing != 'none':
+            for layer in self.layers:
+                # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
+                # backward hook inside of FSDP...
+                layer._magma_checkpointed = True  # type: ignore
+                assert layer.layer_drop == 0., "Need further checking"  # type: ignore
+
+    def _apply_layer(self, layer, *args, **kwargs):
+        method = self.checkpointing
+        if method == 'none':
+            return layer(*args, **kwargs)
+        elif method == 'torch':
+            return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
+        elif method.startswith('xformers'):
+            from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
+            if method == 'xformers_default':
+                # those operations will be saved, and not recomputed.
+                # According to Francisco we can get smarter policies but this is a good start.
+                allow_list = [
+                    "xformers.efficient_attention_forward_cutlass.default",
+                    "xformers_flash.flash_fwd.default",
+                    "aten.addmm.default",
+                    "aten.mm.default",
+                ]
+            elif method == 'xformers_mm':
+                # those operations will be saved, and not recomputed.
+                # According to Francisco we can get smarter policies but this is a good start.
+                allow_list = [
+                    "aten.addmm.default",
+                    "aten.mm.default",
+                ]
+            else:
+                raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
+            policy_fn = _get_default_policy(allow_list)
+            return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
+        else:
+            raise ValueError(f"Checkpointing method {method} is unknown.")
+
+    def forward(self, x: torch.Tensor, *args, **kwargs):
+        B, T, C = x.shape
+
+        if 'offsets' in self._streaming_state:
+            offsets = self._streaming_state['offsets']
+        else:
+            offsets = torch.zeros(B, dtype=torch.long, device=x.device)
+
+        if self.positional_embedding in ['sin', 'sin_rope']:
+            positions = torch.arange(T, device=x.device).view(1, -1, 1)
+            positions = positions + offsets.view(-1, 1, 1)
+            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
+            x = x + self.positional_scale * pos_emb
+
+        for layer in self.layers:
+            x = self._apply_layer(layer, x, *args, **kwargs)
+
+        if self._is_streaming:
+            self._streaming_state['offsets'] = offsets + T
+
+        return x
+
+    def make_optim_group(self):
+        group = {"params": list(self.parameters())}
+        if self.lr is not None:
+            group["lr"] = self.lr
+        if self.weight_decay is not None:
+            group["weight_decay"] = self.weight_decay
+        return group
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def make_optim_group(self) +
+
+
+
+ +Expand source code + +
def make_optim_group(self):
+    group = {"params": list(self.parameters())}
+    if self.lr is not None:
+        group["lr"] = self.lr
+    if self.weight_decay is not None:
+        group["weight_decay"] = self.weight_decay
+    return group
+
+
+
+

Inherited members

+ +
+
+class StreamingTransformerLayer +(d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True, causal: bool = False, past_context: Optional[int] = None, custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False, qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False, cross_attention: bool = False, layer_scale: Optional[float] = None, rope: Optional[RotaryEmbedding] = None, attention_dropout: Optional[float] = None, kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs) +
+
+

TransformerLayer with Streaming / Causal support. +This also integrates cross_attention, when passing cross_attention=True, +rather than having two separate classes like in PyTorch.

+

Args

+
+
d_model : int
+
Dimension of the data.
+
num_heads : int
+
Number of heads.
+
dim_feedforward : int
+
Intermediate dimension of FF module.
+
dropout : float
+
Dropout both for MHA and FF.
+
bias_ff : bool
+
Use bias for FF.
+
bias_attn : bool
+
Use bias for MHA.
+
causal : bool
+
Causal mask applied automatically.
+
past_context : int or None
+
Receptive field for the causal mask, infinite if None.
+
custom : bool
+
Use custom MHA implementation, for testing / benchmarking.
+
memory_efficient : bool
+
Use xformers based memory efficient attention.
+
attention_as_float32 : bool
+
Perform the attention as float32 +(especially important with memory_efficient as autocast won't do this automatically).
+
qk_layer_norm : bool
+
Layer normalization applied to queries and keys before dot product in attention.
+
qk_layer_norm_cross : bool
+
Same for the cross attention.
+
cross_attention : bool
+
If True, expect to get secondary input for cross-attention. +Cross attention will use the default MHA, as it typically won't require +special treatment.
+
layer_scale : float or None
+
If not None, LayerScale will be used with +the given value as initial scale.
+
rope (RotaryEmbedding or None): Rope embedding to use.
+
attention_dropout : float or None
+
If not None, separate the value of the dimension dropout +in FFN and of the attention dropout.
+
kv_repeat : int
+
If > 1, will repeat keys and queries multiple times (need to divide num_heads). +This will lead to faster decoding time on A100 or other GPUs with tensorcore.
+
device : torch.device or None
+
Device on which to initialize.
+
dtype : torch.dtype or None
+
dtype to use.
+
**kwargs
+
See nn.TransformerEncoderLayer.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class StreamingTransformerLayer(nn.TransformerEncoderLayer):
+    """TransformerLayer with Streaming / Causal support.
+    This also integrates cross_attention, when passing `cross_attention=True`,
+    rather than having two separate classes like in PyTorch.
+
+    Args:
+        d_model (int): Dimension of the data.
+        num_heads (int): Number of heads.
+        dim_feedforward (int): Intermediate dimension of FF module.
+        dropout (float): Dropout both for MHA and FF.
+        bias_ff (bool): Use bias for FF.
+        bias_attn (bool): Use bias for MHA.
+        causal (bool): Causal mask applied automatically.
+        past_context (int or None): Receptive field for the causal mask, infinite if None.
+        custom (bool): Use custom MHA implementation, for testing / benchmarking.
+        memory_efficient (bool): Use xformers based memory efficient attention.
+        attention_as_float32 (bool): Perform the attention as float32
+            (especially important with memory_efficient as autocast won't do this automatically).
+        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
+        qk_layer_norm_cross (bool): Same for the cross attention.
+        cross_attention (bool): If True, expect to get secondary input for cross-attention.
+            Cross attention will use the default MHA, as it typically won't require
+            special treatment.
+        layer_scale (float or None): If not None, LayerScale will be used with
+            the given value as initial scale.
+        rope (`RotaryEmbedding` or None): Rope embedding to use.
+        attention_dropout (float or None): If not None, separate the value of the dimension dropout
+            in FFN and of the attention dropout.
+        kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
+            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
+        device (torch.device or None): Device on which to initialize.
+        dtype (torch.dtype or None): dtype to use.
+        **kwargs: See `nn.TransformerEncoderLayer`.
+    """
+    def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
+                 bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
+                 past_context: tp.Optional[int] = None, custom: bool = False,
+                 memory_efficient: bool = False, attention_as_float32: bool = False,
+                 qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
+                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
+                 rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
+                 kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
+        super().__init__(d_model, num_heads, dim_feedforward, dropout,
+                         device=device, dtype=dtype, batch_first=True, **kwargs)
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        # Redefine self_attn to our streaming multi-head attention
+        attn_kwargs: tp.Dict[str, tp.Any] = {
+            'embed_dim': d_model,
+            'num_heads': num_heads,
+            'dropout': dropout if attention_dropout is None else attention_dropout,
+            'bias': bias_attn,
+            'custom': custom,
+            'memory_efficient': memory_efficient,
+            'attention_as_float32': attention_as_float32,
+        }
+        self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
+            causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
+            kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs)  # type: ignore
+        # Redefine feedforward layers to expose bias parameter
+        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
+        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
+
+        self.layer_scale_1: nn.Module
+        self.layer_scale_2: nn.Module
+        if layer_scale is None:
+            self.layer_scale_1 = nn.Identity()
+            self.layer_scale_2 = nn.Identity()
+        else:
+            self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)
+            self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)
+
+        self.cross_attention: tp.Optional[nn.Module] = None
+        if cross_attention:
+            self.cross_attention = StreamingMultiheadAttention(
+                cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
+                **attn_kwargs, **factory_kwargs)
+            # Norm and dropout
+            self.dropout_cross = nn.Dropout(dropout)
+            # eps value matching that used in PyTorch reference implementation.
+            self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
+            self.layer_scale_cross: nn.Module
+            if layer_scale is None:
+                self.layer_scale_cross = nn.Identity()
+            else:
+                self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs)
+        self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
+        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
+
+    def _cross_attention_block(self, src: torch.Tensor,
+                               cross_attention_src: torch.Tensor) -> torch.Tensor:
+        assert self.cross_attention is not None
+        # queries are from src, keys and values from cross_attention_src.
+        x = self.cross_attention(
+            src, cross_attention_src, cross_attention_src, need_weights=False)[0]
+        return self.dropout_cross(x)  # type: ignore
+
+    def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None,  # type: ignore
+                src_key_padding_mask: tp.Optional[torch.Tensor] = None,
+                cross_attention_src: tp.Optional[torch.Tensor] = None):
+        if self.cross_attention is None:
+            assert cross_attention_src is None
+        else:
+            assert cross_attention_src is not None
+        x = src
+        if self.norm_first:
+            x = x + self.layer_scale_1(
+                self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
+            if cross_attention_src is not None:
+                x = x + self.layer_scale_cross(
+                    self._cross_attention_block(
+                        self.norm_cross(x), cross_attention_src))
+            x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
+        else:
+            x = self.norm1(x + self.layer_scale_1(
+                self._sa_block(x, src_mask, src_key_padding_mask)))
+            if cross_attention_src is not None:
+                x = self.norm_cross(
+                    x + self.layer_scale_cross(
+                        self._cross_attention_block(src, cross_attention_src)))
+            x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
+        return x
+
+

Ancestors

+
    +
  • torch.nn.modules.transformer.TransformerEncoderLayer
  • +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None, cross_attention_src: Optional[torch.Tensor] = None) ‑> Callable[..., Any] +
+
+

Pass the input through the encoder layer.

+

Args

+
+
src
+
the sequence to the encoder layer (required).
+
src_mask
+
the mask for the src sequence (optional).
+
is_causal
+
If specified, applies a causal mask as src_mask. +Default: False.
+
src_key_padding_mask
+
the mask for the src keys per batch (optional).
+
+

Shape

+

see the docs in Transformer class.

+
+ +Expand source code + +
def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None,  # type: ignore
+            src_key_padding_mask: tp.Optional[torch.Tensor] = None,
+            cross_attention_src: tp.Optional[torch.Tensor] = None):
+    if self.cross_attention is None:
+        assert cross_attention_src is None
+    else:
+        assert cross_attention_src is not None
+    x = src
+    if self.norm_first:
+        x = x + self.layer_scale_1(
+            self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
+        if cross_attention_src is not None:
+            x = x + self.layer_scale_cross(
+                self._cross_attention_block(
+                    self.norm_cross(x), cross_attention_src))
+        x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
+    else:
+        x = self.norm1(x + self.layer_scale_1(
+            self._sa_block(x, src_mask, src_key_padding_mask)))
+        if cross_attention_src is not None:
+            x = self.norm_cross(
+                x + self.layer_scale_cross(
+                    self._cross_attention_block(src, cross_attention_src)))
+        x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
+    return x
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/quantization/base.html b/docs/audiocraft/quantization/base.html new file mode 100644 index 00000000..efe5b397 --- /dev/null +++ b/docs/audiocraft/quantization/base.html @@ -0,0 +1,566 @@ + + + + + + +audiocraft.quantization.base API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.quantization.base

+
+
+

Base class for all quantizers.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Base class for all quantizers.
+"""
+
+from dataclasses import dataclass, field
+import typing as tp
+
+import torch
+from torch import nn
+
+
+@dataclass
+class QuantizedResult:
+    x: torch.Tensor
+    codes: torch.Tensor
+    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
+    penalty: tp.Optional[torch.Tensor] = None
+    metrics: dict = field(default_factory=dict)
+
+
+class BaseQuantizer(nn.Module):
+    """Base class for quantizers.
+    """
+
+    def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
+        """
+        Given input tensor x, returns first the quantized (or approximately quantized)
+        representation along with quantized codes, bandwidth, and any penalty term for the loss.
+        Finally, this returns a dict of metrics to update logging etc.
+        Frame rate must be passed so that the bandwidth is properly computed.
+        """
+        raise NotImplementedError()
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """Encode a given input tensor with the specified sample rate at the given bandwidth.
+        """
+        raise NotImplementedError()
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation.
+        """
+        raise NotImplementedError()
+
+    @property
+    def total_codebooks(self):
+        """Total number of codebooks.
+        """
+        raise NotImplementedError()
+
+    @property
+    def num_codebooks(self):
+        """Number of active codebooks.
+        """
+        raise NotImplementedError()
+
+    def set_num_codebooks(self, n: int):
+        """Set the number of active codebooks.
+        """
+        raise NotImplementedError()
+
+
+class DummyQuantizer(BaseQuantizer):
+    """Fake quantizer that actually does not perform any quantization.
+    """
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor, frame_rate: int):
+        q = x.unsqueeze(1)
+        return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """Encode a given input tensor with the specified sample rate at the given bandwidth.
+        In the case of the DummyQuantizer, the codes are actually identical
+        to the input and resulting quantized representation as no quantization is done.
+        """
+        return x.unsqueeze(1)
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation.
+        In the case of the DummyQuantizer, the codes are actually identical
+        to the input and resulting quantized representation as no quantization is done.
+        """
+        return codes.squeeze(1)
+
+    @property
+    def total_codebooks(self):
+        """Total number of codebooks.
+        """
+        return 1
+
+    @property
+    def num_codebooks(self):
+        """Total number of codebooks.
+        """
+        return self.total_codebooks
+
+    def set_num_codebooks(self, n: int):
+        """Set the number of active codebooks.
+        """
+        raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class BaseQuantizer +(*args, **kwargs) +
+
+

Base class for quantizers.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class BaseQuantizer(nn.Module):
+    """Base class for quantizers.
+    """
+
+    def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
+        """
+        Given input tensor x, returns first the quantized (or approximately quantized)
+        representation along with quantized codes, bandwidth, and any penalty term for the loss.
+        Finally, this returns a dict of metrics to update logging etc.
+        Frame rate must be passed so that the bandwidth is properly computed.
+        """
+        raise NotImplementedError()
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """Encode a given input tensor with the specified sample rate at the given bandwidth.
+        """
+        raise NotImplementedError()
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation.
+        """
+        raise NotImplementedError()
+
+    @property
+    def total_codebooks(self):
+        """Total number of codebooks.
+        """
+        raise NotImplementedError()
+
+    @property
+    def num_codebooks(self):
+        """Number of active codebooks.
+        """
+        raise NotImplementedError()
+
+    def set_num_codebooks(self, n: int):
+        """Set the number of active codebooks.
+        """
+        raise NotImplementedError()
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Subclasses

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Instance variables

+
+
var num_codebooks
+
+

Number of active codebooks.

+
+ +Expand source code + +
@property
+def num_codebooks(self):
+    """Number of active codebooks.
+    """
+    raise NotImplementedError()
+
+
+
var total_codebooks
+
+

Total number of codebooks.

+
+ +Expand source code + +
@property
+def total_codebooks(self):
+    """Total number of codebooks.
+    """
+    raise NotImplementedError()
+
+
+
+

Methods

+
+
+def decode(self, codes: torch.Tensor) ‑> torch.Tensor +
+
+

Decode the given codes to the quantized representation.

+
+ +Expand source code + +
def decode(self, codes: torch.Tensor) -> torch.Tensor:
+    """Decode the given codes to the quantized representation.
+    """
+    raise NotImplementedError()
+
+
+
+def encode(self, x: torch.Tensor) ‑> torch.Tensor +
+
+

Encode a given input tensor with the specified sample rate at the given bandwidth.

+
+ +Expand source code + +
def encode(self, x: torch.Tensor) -> torch.Tensor:
+    """Encode a given input tensor with the specified sample rate at the given bandwidth.
+    """
+    raise NotImplementedError()
+
+
+
+def forward(self, x: torch.Tensor, frame_rate: int) ‑> QuantizedResult +
+
+

Given input tensor x, returns first the quantized (or approximately quantized) +representation along with quantized codes, bandwidth, and any penalty term for the loss. +Finally, this returns a dict of metrics to update logging etc. +Frame rate must be passed so that the bandwidth is properly computed.

+
+ +Expand source code + +
def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
+    """
+    Given input tensor x, returns first the quantized (or approximately quantized)
+    representation along with quantized codes, bandwidth, and any penalty term for the loss.
+    Finally, this returns a dict of metrics to update logging etc.
+    Frame rate must be passed so that the bandwidth is properly computed.
+    """
+    raise NotImplementedError()
+
+
+
+def set_num_codebooks(self, n: int) +
+
+

Set the number of active codebooks.

+
+ +Expand source code + +
def set_num_codebooks(self, n: int):
+    """Set the number of active codebooks.
+    """
+    raise NotImplementedError()
+
+
+
+
+
+class DummyQuantizer +
+
+

Fake quantizer that actually does not perform any quantization.

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class DummyQuantizer(BaseQuantizer):
+    """Fake quantizer that actually does not perform any quantization.
+    """
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor, frame_rate: int):
+        q = x.unsqueeze(1)
+        return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """Encode a given input tensor with the specified sample rate at the given bandwidth.
+        In the case of the DummyQuantizer, the codes are actually identical
+        to the input and resulting quantized representation as no quantization is done.
+        """
+        return x.unsqueeze(1)
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation.
+        In the case of the DummyQuantizer, the codes are actually identical
+        to the input and resulting quantized representation as no quantization is done.
+        """
+        return codes.squeeze(1)
+
+    @property
+    def total_codebooks(self):
+        """Total number of codebooks.
+        """
+        return 1
+
+    @property
+    def num_codebooks(self):
+        """Total number of codebooks.
+        """
+        return self.total_codebooks
+
+    def set_num_codebooks(self, n: int):
+        """Set the number of active codebooks.
+        """
+        raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Instance variables

+
+
var num_codebooks
+
+

Total number of codebooks.

+
+ +Expand source code + +
@property
+def num_codebooks(self):
+    """Total number of codebooks.
+    """
+    return self.total_codebooks
+
+
+
+

Methods

+
+
+def decode(self, codes: torch.Tensor) ‑> torch.Tensor +
+
+

Decode the given codes to the quantized representation. +In the case of the DummyQuantizer, the codes are actually identical +to the input and resulting quantized representation as no quantization is done.

+
+ +Expand source code + +
def decode(self, codes: torch.Tensor) -> torch.Tensor:
+    """Decode the given codes to the quantized representation.
+    In the case of the DummyQuantizer, the codes are actually identical
+    to the input and resulting quantized representation as no quantization is done.
+    """
+    return codes.squeeze(1)
+
+
+
+def encode(self, x: torch.Tensor) ‑> torch.Tensor +
+
+

Encode a given input tensor with the specified sample rate at the given bandwidth. +In the case of the DummyQuantizer, the codes are actually identical +to the input and resulting quantized representation as no quantization is done.

+
+ +Expand source code + +
def encode(self, x: torch.Tensor) -> torch.Tensor:
+    """Encode a given input tensor with the specified sample rate at the given bandwidth.
+    In the case of the DummyQuantizer, the codes are actually identical
+    to the input and resulting quantized representation as no quantization is done.
+    """
+    return x.unsqueeze(1)
+
+
+
+

Inherited members

+ +
+
+class QuantizedResult +(x: torch.Tensor, codes: torch.Tensor, bandwidth: torch.Tensor, penalty: Optional[torch.Tensor] = None, metrics: dict = <factory>) +
+
+

QuantizedResult(x: torch.Tensor, codes: torch.Tensor, bandwidth: torch.Tensor, penalty: Union[torch.Tensor, NoneType] = None, metrics: dict = )

+
+ +Expand source code + +
class QuantizedResult:
+    x: torch.Tensor
+    codes: torch.Tensor
+    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
+    penalty: tp.Optional[torch.Tensor] = None
+    metrics: dict = field(default_factory=dict)
+
+

Class variables

+
+
var bandwidth : torch.Tensor
+
+
+
+
var codes : torch.Tensor
+
+
+
+
var metrics : dict
+
+
+
+
var penalty : Optional[torch.Tensor]
+
+
+
+
var x : torch.Tensor
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/quantization/core_vq.html b/docs/audiocraft/quantization/core_vq.html new file mode 100644 index 00000000..99610654 --- /dev/null +++ b/docs/audiocraft/quantization/core_vq.html @@ -0,0 +1,1538 @@ + + + + + + +audiocraft.quantization.core_vq API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.quantization.core_vq

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+from einops import rearrange, repeat
+import flashy
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+
+def exists(val: tp.Optional[tp.Any]) -> bool:
+    return val is not None
+
+
+def default(val: tp.Any, d: tp.Any) -> tp.Any:
+    return val if exists(val) else d
+
+
+def l2norm(t):
+    return F.normalize(t, p=2, dim=-1)
+
+
+def ema_inplace(moving_avg, new, decay: float):
+    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
+    return (x + epsilon) / (x.sum() + n_categories * epsilon)
+
+
+def uniform_init(*shape: int):
+    t = torch.empty(shape)
+    nn.init.kaiming_uniform_(t)
+    return t
+
+
+def sample_vectors(samples, num: int):
+    num_samples, device = samples.shape[0], samples.device
+
+    if num_samples >= num:
+        indices = torch.randperm(num_samples, device=device)[:num]
+    else:
+        indices = torch.randint(0, num_samples, (num,), device=device)
+
+    return samples[indices]
+
+
+def kmeans(samples, num_clusters: int, num_iters: int = 10):
+    dim, dtype = samples.shape[-1], samples.dtype
+
+    means = sample_vectors(samples, num_clusters)
+
+    for _ in range(num_iters):
+        diffs = rearrange(samples, "n d -> n () d") - rearrange(
+            means, "c d -> () c d"
+        )
+        dists = -(diffs ** 2).sum(dim=-1)
+
+        buckets = dists.max(dim=-1).indices
+        bins = torch.bincount(buckets, minlength=num_clusters)
+        zero_mask = bins == 0
+        bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
+        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
+        new_means = new_means / bins_min_clamped[..., None]
+
+        means = torch.where(zero_mask[..., None], means, new_means)
+
+    return means, bins
+
+
+def orthgonal_loss_fn(t):
+    # eq (2) from https://arxiv.org/abs/2112.00384
+    n = t.shape[0]
+    normed_codes = l2norm(t)
+    identity = torch.eye(n, device=t.device)
+    cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
+    return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
+
+
+class EuclideanCodebook(nn.Module):
+    """Codebook with Euclidean distance.
+
+    Args:
+        dim (int): Dimension.
+        codebook_size (int): Codebook size.
+        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
+            If set to true, run the k-means algorithm on the first training batch and use
+            the learned centroids as initialization.
+        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
+        decay (float): Decay for exponential moving average over the codebooks.
+        epsilon (float): Epsilon value for numerical stability.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            randomly selected vector from the current batch.
+    """
+    def __init__(
+        self,
+        dim: int,
+        codebook_size: int,
+        kmeans_init: int = False,
+        kmeans_iters: int = 10,
+        decay: float = 0.8,
+        epsilon: float = 1e-5,
+        threshold_ema_dead_code: int = 2,
+    ):
+        super().__init__()
+        self.decay = decay
+        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
+        embed = init_fn(codebook_size, dim)
+
+        self.codebook_size = codebook_size
+
+        self.kmeans_iters = kmeans_iters
+        self.epsilon = epsilon
+        self.threshold_ema_dead_code = threshold_ema_dead_code
+
+        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
+        self.register_buffer("cluster_size", torch.zeros(codebook_size))
+        self.register_buffer("embed", embed)
+        self.register_buffer("embed_avg", embed.clone())
+
+    @torch.jit.ignore
+    def init_embed_(self, data):
+        if self.inited:
+            return
+
+        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+        self.embed.data.copy_(embed)
+        self.embed_avg.data.copy_(embed.clone())
+        self.cluster_size.data.copy_(cluster_size)
+        self.inited.data.copy_(torch.Tensor([True]))
+        # Make sure all buffers across workers are in sync after initialization
+        flashy.distrib.broadcast_tensors(self.buffers())
+
+    def replace_(self, samples, mask):
+        modified_codebook = torch.where(
+            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
+        )
+        self.embed.data.copy_(modified_codebook)
+
+    def expire_codes_(self, batch_samples):
+        if self.threshold_ema_dead_code == 0:
+            return
+
+        expired_codes = self.cluster_size < self.threshold_ema_dead_code
+        if not torch.any(expired_codes):
+            return
+
+        batch_samples = rearrange(batch_samples, "... d -> (...) d")
+        self.replace_(batch_samples, mask=expired_codes)
+        flashy.distrib.broadcast_tensors(self.buffers())
+
+    def preprocess(self, x):
+        x = rearrange(x, "... d -> (...) d")
+        return x
+
+    def quantize(self, x):
+        embed = self.embed.t()
+        dist = -(
+            x.pow(2).sum(1, keepdim=True)
+            - 2 * x @ embed
+            + embed.pow(2).sum(0, keepdim=True)
+        )
+        embed_ind = dist.max(dim=-1).indices
+        return embed_ind
+
+    def postprocess_emb(self, embed_ind, shape):
+        return embed_ind.view(*shape[:-1])
+
+    def dequantize(self, embed_ind):
+        quantize = F.embedding(embed_ind, self.embed)
+        return quantize
+
+    def encode(self, x):
+        shape = x.shape
+        # pre-process
+        x = self.preprocess(x)
+        # quantize
+        embed_ind = self.quantize(x)
+        # post-process
+        embed_ind = self.postprocess_emb(embed_ind, shape)
+        return embed_ind
+
+    def decode(self, embed_ind):
+        quantize = self.dequantize(embed_ind)
+        return quantize
+
+    def forward(self, x):
+        shape, dtype = x.shape, x.dtype
+        x = self.preprocess(x)
+        self.init_embed_(x)
+
+        embed_ind = self.quantize(x)
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+        embed_ind = self.postprocess_emb(embed_ind, shape)
+        quantize = self.dequantize(embed_ind)
+
+        if self.training:
+            # We do the expiry of code at that point as buffers are in sync
+            # and all the workers will take the same decision.
+            self.expire_codes_(x)
+            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+            embed_sum = x.t() @ embed_onehot
+            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+            cluster_size = (
+                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
+                * self.cluster_size.sum()
+            )
+            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+            self.embed.data.copy_(embed_normalized)
+
+        return quantize, embed_ind
+
+
+class VectorQuantization(nn.Module):
+    """Vector quantization implementation.
+    Currently supports only euclidean distance.
+
+    Args:
+        dim (int): Dimension
+        codebook_size (int): Codebook size
+        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
+        decay (float): Decay for exponential moving average over the codebooks.
+        epsilon (float): Epsilon value for numerical stability.
+        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+        kmeans_iters (int): Number of iterations used for kmeans initialization.
+        threshold_ema_dead_code (int):
+        channels_last (bool): Channels are the last dimension in the input tensors.
+        commitment_weight (float): Weight for commitment loss.
+        orthogonal_reg_weight (float): Orthogonal regularization weights.
+        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
+        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
+            for orthogonal regulariation.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            randomly selected vector from the current batch.
+    """
+    def __init__(
+        self,
+        dim: int,
+        codebook_size: int,
+        codebook_dim: tp.Optional[int] = None,
+        decay: float = 0.8,
+        epsilon: float = 1e-5,
+        kmeans_init: bool = False,
+        kmeans_iters: int = 10,
+        threshold_ema_dead_code: int = 2,
+        channels_last: bool = False,
+        commitment_weight: float = 1.,
+        orthogonal_reg_weight: float = 0.0,
+        orthogonal_reg_active_codes_only: bool = False,
+        orthogonal_reg_max_codes: tp.Optional[int] = None,
+    ):
+        super().__init__()
+        _codebook_dim: int = default(codebook_dim, dim)
+
+        requires_projection = _codebook_dim != dim
+        self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
+        self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
+
+        self.epsilon = epsilon
+        self.commitment_weight = commitment_weight
+
+        self.orthogonal_reg_weight = orthogonal_reg_weight
+        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
+        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
+
+        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
+                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
+                                           decay=decay, epsilon=epsilon,
+                                           threshold_ema_dead_code=threshold_ema_dead_code)
+        self.codebook_size = codebook_size
+
+        self.channels_last = channels_last
+
+    @property
+    def codebook(self):
+        return self._codebook.embed
+
+    @property
+    def inited(self):
+        return self._codebook.inited
+
+    def _preprocess(self, x):
+        if not self.channels_last:
+            x = rearrange(x, "b d n -> b n d")
+        return x
+
+    def _postprocess(self, quantize):
+        if not self.channels_last:
+            quantize = rearrange(quantize, "b n d -> b d n")
+        return quantize
+
+    def encode(self, x):
+        x = self._preprocess(x)
+        x = self.project_in(x)
+        embed_in = self._codebook.encode(x)
+        return embed_in
+
+    def decode(self, embed_ind):
+        quantize = self._codebook.decode(embed_ind)
+        quantize = self.project_out(quantize)
+        quantize = self._postprocess(quantize)
+        return quantize
+
+    def forward(self, x):
+        device = x.device
+        x = self._preprocess(x)
+
+        x = self.project_in(x)
+        quantize, embed_ind = self._codebook(x)
+
+        if self.training:
+            quantize = x + (quantize - x).detach()
+
+        loss = torch.tensor([0.0], device=device, requires_grad=self.training)
+
+        if self.training:
+            if self.commitment_weight > 0:
+                commit_loss = F.mse_loss(quantize.detach(), x)
+                loss = loss + commit_loss * self.commitment_weight
+
+            if self.orthogonal_reg_weight > 0:
+                codebook = self.codebook
+
+                if self.orthogonal_reg_active_codes_only:
+                    # only calculate orthogonal loss for the activated codes for this batch
+                    unique_code_ids = torch.unique(embed_ind)
+                    codebook = codebook[unique_code_ids]
+
+                num_codes = codebook.shape[0]
+                if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
+                    rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
+                    codebook = codebook[rand_ids]
+
+                orthogonal_reg_loss = orthgonal_loss_fn(codebook)
+                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
+
+        quantize = self.project_out(quantize)
+        quantize = self._postprocess(quantize)
+
+        return quantize, embed_ind, loss
+
+
+class ResidualVectorQuantization(nn.Module):
+    """Residual vector quantization implementation.
+
+    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
+    """
+    def __init__(self, *, num_quantizers, **kwargs):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+        )
+
+    def forward(self, x, n_q: tp.Optional[int] = None):
+        quantized_out = 0.0
+        residual = x
+
+        all_losses = []
+        all_indices = []
+
+        n_q = n_q or len(self.layers)
+
+        for i, layer in enumerate(self.layers[:n_q]):
+            quantized, indices, loss = layer(residual)
+            residual = residual - quantized
+            quantized_out = quantized_out + quantized
+            all_indices.append(indices)
+            all_losses.append(loss)
+
+        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+        return quantized_out, out_indices, out_losses
+
+    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+        residual = x
+        all_indices = []
+        n_q = n_q or len(self.layers)
+        for layer in self.layers[:n_q]:
+            indices = layer.encode(residual)
+            quantized = layer.decode(indices)
+            residual = residual - quantized
+            all_indices.append(indices)
+        out_indices = torch.stack(all_indices)
+        return out_indices
+
+    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+        quantized_out = torch.tensor(0.0, device=q_indices.device)
+        for i, indices in enumerate(q_indices):
+            layer = self.layers[i]
+            quantized = layer.decode(indices)
+            quantized_out = quantized_out + quantized
+        return quantized_out
+
+
+
+
+
+
+
+

Functions

+
+
+def default(val: Any, d: Any) ‑> Any +
+
+
+
+ +Expand source code + +
def default(val: tp.Any, d: tp.Any) -> tp.Any:
+    return val if exists(val) else d
+
+
+
+def ema_inplace(moving_avg, new, decay: float) +
+
+
+
+ +Expand source code + +
def ema_inplace(moving_avg, new, decay: float):
+    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+
+def exists(val: Optional[Any]) ‑> bool +
+
+
+
+ +Expand source code + +
def exists(val: tp.Optional[tp.Any]) -> bool:
+    return val is not None
+
+
+
+def kmeans(samples, num_clusters: int, num_iters: int = 10) +
+
+
+
+ +Expand source code + +
def kmeans(samples, num_clusters: int, num_iters: int = 10):
+    dim, dtype = samples.shape[-1], samples.dtype
+
+    means = sample_vectors(samples, num_clusters)
+
+    for _ in range(num_iters):
+        diffs = rearrange(samples, "n d -> n () d") - rearrange(
+            means, "c d -> () c d"
+        )
+        dists = -(diffs ** 2).sum(dim=-1)
+
+        buckets = dists.max(dim=-1).indices
+        bins = torch.bincount(buckets, minlength=num_clusters)
+        zero_mask = bins == 0
+        bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
+        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
+        new_means = new_means / bins_min_clamped[..., None]
+
+        means = torch.where(zero_mask[..., None], means, new_means)
+
+    return means, bins
+
+
+
+def l2norm(t) +
+
+
+
+ +Expand source code + +
def l2norm(t):
+    return F.normalize(t, p=2, dim=-1)
+
+
+
+def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-05) +
+
+
+
+ +Expand source code + +
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
+    return (x + epsilon) / (x.sum() + n_categories * epsilon)
+
+
+
+def orthgonal_loss_fn(t) +
+
+
+
+ +Expand source code + +
def orthgonal_loss_fn(t):
+    # eq (2) from https://arxiv.org/abs/2112.00384
+    n = t.shape[0]
+    normed_codes = l2norm(t)
+    identity = torch.eye(n, device=t.device)
+    cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
+    return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
+
+
+
+def sample_vectors(samples, num: int) +
+
+
+
+ +Expand source code + +
def sample_vectors(samples, num: int):
+    num_samples, device = samples.shape[0], samples.device
+
+    if num_samples >= num:
+        indices = torch.randperm(num_samples, device=device)[:num]
+    else:
+        indices = torch.randint(0, num_samples, (num,), device=device)
+
+    return samples[indices]
+
+
+
+def uniform_init(*shape: int) +
+
+
+
+ +Expand source code + +
def uniform_init(*shape: int):
+    t = torch.empty(shape)
+    nn.init.kaiming_uniform_(t)
+    return t
+
+
+
+
+
+

Classes

+
+
+class EuclideanCodebook +(dim: int, codebook_size: int, kmeans_init: int = False, kmeans_iters: int = 10, decay: float = 0.8, epsilon: float = 1e-05, threshold_ema_dead_code: int = 2) +
+
+

Codebook with Euclidean distance.

+

Args

+
+
dim : int
+
Dimension.
+
codebook_size : int
+
Codebook size.
+
kmeans_init : bool
+
Whether to use k-means to initialize the codebooks. +If set to true, run the k-means algorithm on the first training batch and use +the learned centroids as initialization.
+
kmeans_iters : int
+
Number of iterations used for k-means algorithm at initialization.
+
decay : float
+
Decay for exponential moving average over the codebooks.
+
epsilon : float
+
Epsilon value for numerical stability.
+
threshold_ema_dead_code : int
+
Threshold for dead code expiration. Replace any codes +that have an exponential moving average cluster size less than the specified threshold with +randomly selected vector from the current batch.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class EuclideanCodebook(nn.Module):
+    """Codebook with Euclidean distance.
+
+    Args:
+        dim (int): Dimension.
+        codebook_size (int): Codebook size.
+        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
+            If set to true, run the k-means algorithm on the first training batch and use
+            the learned centroids as initialization.
+        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
+        decay (float): Decay for exponential moving average over the codebooks.
+        epsilon (float): Epsilon value for numerical stability.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            randomly selected vector from the current batch.
+    """
+    def __init__(
+        self,
+        dim: int,
+        codebook_size: int,
+        kmeans_init: int = False,
+        kmeans_iters: int = 10,
+        decay: float = 0.8,
+        epsilon: float = 1e-5,
+        threshold_ema_dead_code: int = 2,
+    ):
+        super().__init__()
+        self.decay = decay
+        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
+        embed = init_fn(codebook_size, dim)
+
+        self.codebook_size = codebook_size
+
+        self.kmeans_iters = kmeans_iters
+        self.epsilon = epsilon
+        self.threshold_ema_dead_code = threshold_ema_dead_code
+
+        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
+        self.register_buffer("cluster_size", torch.zeros(codebook_size))
+        self.register_buffer("embed", embed)
+        self.register_buffer("embed_avg", embed.clone())
+
+    @torch.jit.ignore
+    def init_embed_(self, data):
+        if self.inited:
+            return
+
+        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+        self.embed.data.copy_(embed)
+        self.embed_avg.data.copy_(embed.clone())
+        self.cluster_size.data.copy_(cluster_size)
+        self.inited.data.copy_(torch.Tensor([True]))
+        # Make sure all buffers across workers are in sync after initialization
+        flashy.distrib.broadcast_tensors(self.buffers())
+
+    def replace_(self, samples, mask):
+        modified_codebook = torch.where(
+            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
+        )
+        self.embed.data.copy_(modified_codebook)
+
+    def expire_codes_(self, batch_samples):
+        if self.threshold_ema_dead_code == 0:
+            return
+
+        expired_codes = self.cluster_size < self.threshold_ema_dead_code
+        if not torch.any(expired_codes):
+            return
+
+        batch_samples = rearrange(batch_samples, "... d -> (...) d")
+        self.replace_(batch_samples, mask=expired_codes)
+        flashy.distrib.broadcast_tensors(self.buffers())
+
+    def preprocess(self, x):
+        x = rearrange(x, "... d -> (...) d")
+        return x
+
+    def quantize(self, x):
+        embed = self.embed.t()
+        dist = -(
+            x.pow(2).sum(1, keepdim=True)
+            - 2 * x @ embed
+            + embed.pow(2).sum(0, keepdim=True)
+        )
+        embed_ind = dist.max(dim=-1).indices
+        return embed_ind
+
+    def postprocess_emb(self, embed_ind, shape):
+        return embed_ind.view(*shape[:-1])
+
+    def dequantize(self, embed_ind):
+        quantize = F.embedding(embed_ind, self.embed)
+        return quantize
+
+    def encode(self, x):
+        shape = x.shape
+        # pre-process
+        x = self.preprocess(x)
+        # quantize
+        embed_ind = self.quantize(x)
+        # post-process
+        embed_ind = self.postprocess_emb(embed_ind, shape)
+        return embed_ind
+
+    def decode(self, embed_ind):
+        quantize = self.dequantize(embed_ind)
+        return quantize
+
+    def forward(self, x):
+        shape, dtype = x.shape, x.dtype
+        x = self.preprocess(x)
+        self.init_embed_(x)
+
+        embed_ind = self.quantize(x)
+        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+        embed_ind = self.postprocess_emb(embed_ind, shape)
+        quantize = self.dequantize(embed_ind)
+
+        if self.training:
+            # We do the expiry of code at that point as buffers are in sync
+            # and all the workers will take the same decision.
+            self.expire_codes_(x)
+            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+            embed_sum = x.t() @ embed_onehot
+            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+            cluster_size = (
+                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
+                * self.cluster_size.sum()
+            )
+            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+            self.embed.data.copy_(embed_normalized)
+
+        return quantize, embed_ind
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def decode(self, embed_ind) +
+
+
+
+ +Expand source code + +
def decode(self, embed_ind):
+    quantize = self.dequantize(embed_ind)
+    return quantize
+
+
+
+def dequantize(self, embed_ind) +
+
+
+
+ +Expand source code + +
def dequantize(self, embed_ind):
+    quantize = F.embedding(embed_ind, self.embed)
+    return quantize
+
+
+
+def encode(self, x) +
+
+
+
+ +Expand source code + +
def encode(self, x):
+    shape = x.shape
+    # pre-process
+    x = self.preprocess(x)
+    # quantize
+    embed_ind = self.quantize(x)
+    # post-process
+    embed_ind = self.postprocess_emb(embed_ind, shape)
+    return embed_ind
+
+
+
+def expire_codes_(self, batch_samples) +
+
+
+
+ +Expand source code + +
def expire_codes_(self, batch_samples):
+    if self.threshold_ema_dead_code == 0:
+        return
+
+    expired_codes = self.cluster_size < self.threshold_ema_dead_code
+    if not torch.any(expired_codes):
+        return
+
+    batch_samples = rearrange(batch_samples, "... d -> (...) d")
+    self.replace_(batch_samples, mask=expired_codes)
+    flashy.distrib.broadcast_tensors(self.buffers())
+
+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    shape, dtype = x.shape, x.dtype
+    x = self.preprocess(x)
+    self.init_embed_(x)
+
+    embed_ind = self.quantize(x)
+    embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+    embed_ind = self.postprocess_emb(embed_ind, shape)
+    quantize = self.dequantize(embed_ind)
+
+    if self.training:
+        # We do the expiry of code at that point as buffers are in sync
+        # and all the workers will take the same decision.
+        self.expire_codes_(x)
+        ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+        embed_sum = x.t() @ embed_onehot
+        ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+        cluster_size = (
+            laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
+            * self.cluster_size.sum()
+        )
+        embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+        self.embed.data.copy_(embed_normalized)
+
+    return quantize, embed_ind
+
+
+
+def init_embed_(self, data) +
+
+
+
+ +Expand source code + +
@torch.jit.ignore
+def init_embed_(self, data):
+    if self.inited:
+        return
+
+    embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+    self.embed.data.copy_(embed)
+    self.embed_avg.data.copy_(embed.clone())
+    self.cluster_size.data.copy_(cluster_size)
+    self.inited.data.copy_(torch.Tensor([True]))
+    # Make sure all buffers across workers are in sync after initialization
+    flashy.distrib.broadcast_tensors(self.buffers())
+
+
+
+def postprocess_emb(self, embed_ind, shape) +
+
+
+
+ +Expand source code + +
def postprocess_emb(self, embed_ind, shape):
+    return embed_ind.view(*shape[:-1])
+
+
+
+def preprocess(self, x) +
+
+
+
+ +Expand source code + +
def preprocess(self, x):
+    x = rearrange(x, "... d -> (...) d")
+    return x
+
+
+
+def quantize(self, x) +
+
+
+
+ +Expand source code + +
def quantize(self, x):
+    embed = self.embed.t()
+    dist = -(
+        x.pow(2).sum(1, keepdim=True)
+        - 2 * x @ embed
+        + embed.pow(2).sum(0, keepdim=True)
+    )
+    embed_ind = dist.max(dim=-1).indices
+    return embed_ind
+
+
+
+def replace_(self, samples, mask) +
+
+
+
+ +Expand source code + +
def replace_(self, samples, mask):
+    modified_codebook = torch.where(
+        mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
+    )
+    self.embed.data.copy_(modified_codebook)
+
+
+
+
+
+class ResidualVectorQuantization +(*, num_quantizers, **kwargs) +
+
+

Residual vector quantization implementation.

+

Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf

+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ResidualVectorQuantization(nn.Module):
+    """Residual vector quantization implementation.
+
+    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
+    """
+    def __init__(self, *, num_quantizers, **kwargs):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+        )
+
+    def forward(self, x, n_q: tp.Optional[int] = None):
+        quantized_out = 0.0
+        residual = x
+
+        all_losses = []
+        all_indices = []
+
+        n_q = n_q or len(self.layers)
+
+        for i, layer in enumerate(self.layers[:n_q]):
+            quantized, indices, loss = layer(residual)
+            residual = residual - quantized
+            quantized_out = quantized_out + quantized
+            all_indices.append(indices)
+            all_losses.append(loss)
+
+        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+        return quantized_out, out_indices, out_losses
+
+    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+        residual = x
+        all_indices = []
+        n_q = n_q or len(self.layers)
+        for layer in self.layers[:n_q]:
+            indices = layer.encode(residual)
+            quantized = layer.decode(indices)
+            residual = residual - quantized
+            all_indices.append(indices)
+        out_indices = torch.stack(all_indices)
+        return out_indices
+
+    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+        quantized_out = torch.tensor(0.0, device=q_indices.device)
+        for i, indices in enumerate(q_indices):
+            layer = self.layers[i]
+            quantized = layer.decode(indices)
+            quantized_out = quantized_out + quantized
+        return quantized_out
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def decode(self, q_indices: torch.Tensor) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+    quantized_out = torch.tensor(0.0, device=q_indices.device)
+    for i, indices in enumerate(q_indices):
+        layer = self.layers[i]
+        quantized = layer.decode(indices)
+        quantized_out = quantized_out + quantized
+    return quantized_out
+
+
+
+def encode(self, x: torch.Tensor, n_q: Optional[int] = None) ‑> torch.Tensor +
+
+
+
+ +Expand source code + +
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+    residual = x
+    all_indices = []
+    n_q = n_q or len(self.layers)
+    for layer in self.layers[:n_q]:
+        indices = layer.encode(residual)
+        quantized = layer.decode(indices)
+        residual = residual - quantized
+        all_indices.append(indices)
+    out_indices = torch.stack(all_indices)
+    return out_indices
+
+
+
+def forward(self, x, n_q: Optional[int] = None) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x, n_q: tp.Optional[int] = None):
+    quantized_out = 0.0
+    residual = x
+
+    all_losses = []
+    all_indices = []
+
+    n_q = n_q or len(self.layers)
+
+    for i, layer in enumerate(self.layers[:n_q]):
+        quantized, indices, loss = layer(residual)
+        residual = residual - quantized
+        quantized_out = quantized_out + quantized
+        all_indices.append(indices)
+        all_losses.append(loss)
+
+    out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+    return quantized_out, out_indices, out_losses
+
+
+
+
+
+class VectorQuantization +(dim: int, codebook_size: int, codebook_dim: Optional[int] = None, decay: float = 0.8, epsilon: float = 1e-05, kmeans_init: bool = False, kmeans_iters: int = 10, threshold_ema_dead_code: int = 2, channels_last: bool = False, commitment_weight: float = 1.0, orthogonal_reg_weight: float = 0.0, orthogonal_reg_active_codes_only: bool = False, orthogonal_reg_max_codes: Optional[int] = None) +
+
+

Vector quantization implementation. +Currently supports only euclidean distance.

+

Args

+
+
dim : int
+
Dimension
+
codebook_size : int
+
Codebook size
+
codebook_dim : int
+
Codebook dimension. If not defined, uses the specified dimension in dim.
+
decay : float
+
Decay for exponential moving average over the codebooks.
+
epsilon : float
+
Epsilon value for numerical stability.
+
kmeans_init : bool
+
Whether to use kmeans to initialize the codebooks.
+
kmeans_iters : int
+
Number of iterations used for kmeans initialization.
+
threshold_ema_dead_code (int):
+
channels_last : bool
+
Channels are the last dimension in the input tensors.
+
commitment_weight : float
+
Weight for commitment loss.
+
orthogonal_reg_weight : float
+
Orthogonal regularization weights.
+
orthogonal_reg_active_codes_only : bool
+
Apply orthogonal regularization only on active codes.
+
orthogonal_reg_max_codes : optional int
+
Maximum number of codes to consider +for orthogonal regulariation.
+
threshold_ema_dead_code : int
+
Threshold for dead code expiration. Replace any codes +that have an exponential moving average cluster size less than the specified threshold with +randomly selected vector from the current batch.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class VectorQuantization(nn.Module):
+    """Vector quantization implementation.
+    Currently supports only euclidean distance.
+
+    Args:
+        dim (int): Dimension
+        codebook_size (int): Codebook size
+        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
+        decay (float): Decay for exponential moving average over the codebooks.
+        epsilon (float): Epsilon value for numerical stability.
+        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+        kmeans_iters (int): Number of iterations used for kmeans initialization.
+        threshold_ema_dead_code (int):
+        channels_last (bool): Channels are the last dimension in the input tensors.
+        commitment_weight (float): Weight for commitment loss.
+        orthogonal_reg_weight (float): Orthogonal regularization weights.
+        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
+        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
+            for orthogonal regulariation.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            randomly selected vector from the current batch.
+    """
+    def __init__(
+        self,
+        dim: int,
+        codebook_size: int,
+        codebook_dim: tp.Optional[int] = None,
+        decay: float = 0.8,
+        epsilon: float = 1e-5,
+        kmeans_init: bool = False,
+        kmeans_iters: int = 10,
+        threshold_ema_dead_code: int = 2,
+        channels_last: bool = False,
+        commitment_weight: float = 1.,
+        orthogonal_reg_weight: float = 0.0,
+        orthogonal_reg_active_codes_only: bool = False,
+        orthogonal_reg_max_codes: tp.Optional[int] = None,
+    ):
+        super().__init__()
+        _codebook_dim: int = default(codebook_dim, dim)
+
+        requires_projection = _codebook_dim != dim
+        self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
+        self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
+
+        self.epsilon = epsilon
+        self.commitment_weight = commitment_weight
+
+        self.orthogonal_reg_weight = orthogonal_reg_weight
+        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
+        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
+
+        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
+                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
+                                           decay=decay, epsilon=epsilon,
+                                           threshold_ema_dead_code=threshold_ema_dead_code)
+        self.codebook_size = codebook_size
+
+        self.channels_last = channels_last
+
+    @property
+    def codebook(self):
+        return self._codebook.embed
+
+    @property
+    def inited(self):
+        return self._codebook.inited
+
+    def _preprocess(self, x):
+        if not self.channels_last:
+            x = rearrange(x, "b d n -> b n d")
+        return x
+
+    def _postprocess(self, quantize):
+        if not self.channels_last:
+            quantize = rearrange(quantize, "b n d -> b d n")
+        return quantize
+
+    def encode(self, x):
+        x = self._preprocess(x)
+        x = self.project_in(x)
+        embed_in = self._codebook.encode(x)
+        return embed_in
+
+    def decode(self, embed_ind):
+        quantize = self._codebook.decode(embed_ind)
+        quantize = self.project_out(quantize)
+        quantize = self._postprocess(quantize)
+        return quantize
+
+    def forward(self, x):
+        device = x.device
+        x = self._preprocess(x)
+
+        x = self.project_in(x)
+        quantize, embed_ind = self._codebook(x)
+
+        if self.training:
+            quantize = x + (quantize - x).detach()
+
+        loss = torch.tensor([0.0], device=device, requires_grad=self.training)
+
+        if self.training:
+            if self.commitment_weight > 0:
+                commit_loss = F.mse_loss(quantize.detach(), x)
+                loss = loss + commit_loss * self.commitment_weight
+
+            if self.orthogonal_reg_weight > 0:
+                codebook = self.codebook
+
+                if self.orthogonal_reg_active_codes_only:
+                    # only calculate orthogonal loss for the activated codes for this batch
+                    unique_code_ids = torch.unique(embed_ind)
+                    codebook = codebook[unique_code_ids]
+
+                num_codes = codebook.shape[0]
+                if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
+                    rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
+                    codebook = codebook[rand_ids]
+
+                orthogonal_reg_loss = orthgonal_loss_fn(codebook)
+                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
+
+        quantize = self.project_out(quantize)
+        quantize = self._postprocess(quantize)
+
+        return quantize, embed_ind, loss
+
+

Ancestors

+
    +
  • torch.nn.modules.module.Module
  • +
+

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Instance variables

+
+
var codebook
+
+
+
+ +Expand source code + +
@property
+def codebook(self):
+    return self._codebook.embed
+
+
+
var inited
+
+
+
+ +Expand source code + +
@property
+def inited(self):
+    return self._codebook.inited
+
+
+
+

Methods

+
+
+def decode(self, embed_ind) +
+
+
+
+ +Expand source code + +
def decode(self, embed_ind):
+    quantize = self._codebook.decode(embed_ind)
+    quantize = self.project_out(quantize)
+    quantize = self._postprocess(quantize)
+    return quantize
+
+
+
+def encode(self, x) +
+
+
+
+ +Expand source code + +
def encode(self, x):
+    x = self._preprocess(x)
+    x = self.project_in(x)
+    embed_in = self._codebook.encode(x)
+    return embed_in
+
+
+
+def forward(self, x) ‑> Callable[..., Any] +
+
+

Defines the computation performed at every call.

+

Should be overridden by all subclasses.

+
+

Note

+

Although the recipe for forward pass needs to be defined within +this function, one should call the :class:Module instance afterwards +instead of this since the former takes care of running the +registered hooks while the latter silently ignores them.

+
+
+ +Expand source code + +
def forward(self, x):
+    device = x.device
+    x = self._preprocess(x)
+
+    x = self.project_in(x)
+    quantize, embed_ind = self._codebook(x)
+
+    if self.training:
+        quantize = x + (quantize - x).detach()
+
+    loss = torch.tensor([0.0], device=device, requires_grad=self.training)
+
+    if self.training:
+        if self.commitment_weight > 0:
+            commit_loss = F.mse_loss(quantize.detach(), x)
+            loss = loss + commit_loss * self.commitment_weight
+
+        if self.orthogonal_reg_weight > 0:
+            codebook = self.codebook
+
+            if self.orthogonal_reg_active_codes_only:
+                # only calculate orthogonal loss for the activated codes for this batch
+                unique_code_ids = torch.unique(embed_ind)
+                codebook = codebook[unique_code_ids]
+
+            num_codes = codebook.shape[0]
+            if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
+                rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
+                codebook = codebook[rand_ids]
+
+            orthogonal_reg_loss = orthgonal_loss_fn(codebook)
+            loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
+
+    quantize = self.project_out(quantize)
+    quantize = self._postprocess(quantize)
+
+    return quantize, embed_ind, loss
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/quantization/index.html b/docs/audiocraft/quantization/index.html new file mode 100644 index 00000000..3224d9aa --- /dev/null +++ b/docs/audiocraft/quantization/index.html @@ -0,0 +1,89 @@ + + + + + + +audiocraft.quantization API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.quantization

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .vq import ResidualVectorQuantizer
+from .base import BaseQuantizer, DummyQuantizer, QuantizedResult
+
+
+
+

Sub-modules

+
+
audiocraft.quantization.base
+
+

Base class for all quantizers.

+
+
audiocraft.quantization.core_vq
+
+
+
+
audiocraft.quantization.vq
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/quantization/vq.html b/docs/audiocraft/quantization/vq.html new file mode 100644 index 00000000..9bd7694e --- /dev/null +++ b/docs/audiocraft/quantization/vq.html @@ -0,0 +1,390 @@ + + + + + + +audiocraft.quantization.vq API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.quantization.vq

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import typing as tp
+
+import torch
+
+from .base import BaseQuantizer, QuantizedResult
+from .core_vq import ResidualVectorQuantization
+
+
+class ResidualVectorQuantizer(BaseQuantizer):
+    """Residual Vector Quantizer.
+
+    Args:
+        dimension (int): Dimension of the codebooks.
+        n_q (int): Number of residual vector quantizers used.
+        q_dropout (bool): Random quantizer drop out at train time.
+        bins (int): Codebook size.
+        decay (float): Decay for exponential moving average over the codebooks.
+        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+        kmeans_iters (int): Number of iterations used for kmeans initialization.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            randomly selected vector from the current batch.
+        orthogonal_reg_weight (float): Orthogonal regularization weights.
+        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
+        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider.
+            for orthogonal regulariation.
+    """
+    def __init__(
+        self,
+        dimension: int = 256,
+        n_q: int = 8,
+        q_dropout: bool = False,
+        bins: int = 1024,
+        decay: float = 0.99,
+        kmeans_init: bool = True,
+        kmeans_iters: int = 10,
+        threshold_ema_dead_code: int = 2,
+        orthogonal_reg_weight: float = 0.0,
+        orthogonal_reg_active_codes_only: bool = False,
+        orthogonal_reg_max_codes: tp.Optional[int] = None,
+    ):
+        super().__init__()
+        self.max_n_q = n_q
+        self.n_q = n_q
+        self.q_dropout = q_dropout
+        self.dimension = dimension
+        self.bins = bins
+        self.decay = decay
+        self.kmeans_init = kmeans_init
+        self.kmeans_iters = kmeans_iters
+        self.threshold_ema_dead_code = threshold_ema_dead_code
+        self.orthogonal_reg_weight = orthogonal_reg_weight
+        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
+        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
+        self.vq = ResidualVectorQuantization(
+            dim=self.dimension,
+            codebook_size=self.bins,
+            num_quantizers=self.n_q,
+            decay=self.decay,
+            kmeans_init=self.kmeans_init,
+            kmeans_iters=self.kmeans_iters,
+            threshold_ema_dead_code=self.threshold_ema_dead_code,
+            orthogonal_reg_weight=self.orthogonal_reg_weight,
+            orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
+            orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
+            channels_last=False
+        )
+
+    def forward(self, x: torch.Tensor, frame_rate: int):
+        n_q = self.n_q
+        if self.training and self.q_dropout:
+            n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
+        bw_per_q = math.log2(self.bins) * frame_rate / 1000
+        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
+        codes = codes.transpose(0, 1)
+        # codes is [B, K, T], with T frames, K nb of codebooks.
+        bw = torch.tensor(n_q * bw_per_q).to(x)
+        return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """Encode a given input tensor with the specified frame rate at the given bandwidth.
+        The RVQ encode method sets the appropriate number of quantizer to use
+        and returns indices for each quantizer.
+        """
+        n_q = self.n_q
+        codes = self.vq.encode(x, n_q=n_q)
+        codes = codes.transpose(0, 1)
+        # codes is [B, K, T], with T frames, K nb of codebooks.
+        return codes
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation.
+        """
+        # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
+        codes = codes.transpose(0, 1)
+        quantized = self.vq.decode(codes)
+        return quantized
+
+    @property
+    def total_codebooks(self):
+        return self.max_n_q
+
+    @property
+    def num_codebooks(self):
+        return self.n_q
+
+    def set_num_codebooks(self, n: int):
+        assert n > 0 and n <= self.max_n_q
+        self.n_q = n
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class ResidualVectorQuantizer +(dimension: int = 256, n_q: int = 8, q_dropout: bool = False, bins: int = 1024, decay: float = 0.99, kmeans_init: bool = True, kmeans_iters: int = 10, threshold_ema_dead_code: int = 2, orthogonal_reg_weight: float = 0.0, orthogonal_reg_active_codes_only: bool = False, orthogonal_reg_max_codes: Optional[int] = None) +
+
+

Residual Vector Quantizer.

+

Args

+
+
dimension : int
+
Dimension of the codebooks.
+
n_q : int
+
Number of residual vector quantizers used.
+
q_dropout : bool
+
Random quantizer drop out at train time.
+
bins : int
+
Codebook size.
+
decay : float
+
Decay for exponential moving average over the codebooks.
+
kmeans_init : bool
+
Whether to use kmeans to initialize the codebooks.
+
kmeans_iters : int
+
Number of iterations used for kmeans initialization.
+
threshold_ema_dead_code : int
+
Threshold for dead code expiration. Replace any codes +that have an exponential moving average cluster size less than the specified threshold with +randomly selected vector from the current batch.
+
orthogonal_reg_weight : float
+
Orthogonal regularization weights.
+
orthogonal_reg_active_codes_only : bool
+
Apply orthogonal regularization only on active codes.
+
orthogonal_reg_max_codes : optional int
+
Maximum number of codes to consider. +for orthogonal regulariation.
+
+

Initializes internal Module state, shared by both nn.Module and ScriptModule.

+
+ +Expand source code + +
class ResidualVectorQuantizer(BaseQuantizer):
+    """Residual Vector Quantizer.
+
+    Args:
+        dimension (int): Dimension of the codebooks.
+        n_q (int): Number of residual vector quantizers used.
+        q_dropout (bool): Random quantizer drop out at train time.
+        bins (int): Codebook size.
+        decay (float): Decay for exponential moving average over the codebooks.
+        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+        kmeans_iters (int): Number of iterations used for kmeans initialization.
+        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+            that have an exponential moving average cluster size less than the specified threshold with
+            randomly selected vector from the current batch.
+        orthogonal_reg_weight (float): Orthogonal regularization weights.
+        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
+        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider.
+            for orthogonal regulariation.
+    """
+    def __init__(
+        self,
+        dimension: int = 256,
+        n_q: int = 8,
+        q_dropout: bool = False,
+        bins: int = 1024,
+        decay: float = 0.99,
+        kmeans_init: bool = True,
+        kmeans_iters: int = 10,
+        threshold_ema_dead_code: int = 2,
+        orthogonal_reg_weight: float = 0.0,
+        orthogonal_reg_active_codes_only: bool = False,
+        orthogonal_reg_max_codes: tp.Optional[int] = None,
+    ):
+        super().__init__()
+        self.max_n_q = n_q
+        self.n_q = n_q
+        self.q_dropout = q_dropout
+        self.dimension = dimension
+        self.bins = bins
+        self.decay = decay
+        self.kmeans_init = kmeans_init
+        self.kmeans_iters = kmeans_iters
+        self.threshold_ema_dead_code = threshold_ema_dead_code
+        self.orthogonal_reg_weight = orthogonal_reg_weight
+        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
+        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
+        self.vq = ResidualVectorQuantization(
+            dim=self.dimension,
+            codebook_size=self.bins,
+            num_quantizers=self.n_q,
+            decay=self.decay,
+            kmeans_init=self.kmeans_init,
+            kmeans_iters=self.kmeans_iters,
+            threshold_ema_dead_code=self.threshold_ema_dead_code,
+            orthogonal_reg_weight=self.orthogonal_reg_weight,
+            orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
+            orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
+            channels_last=False
+        )
+
+    def forward(self, x: torch.Tensor, frame_rate: int):
+        n_q = self.n_q
+        if self.training and self.q_dropout:
+            n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
+        bw_per_q = math.log2(self.bins) * frame_rate / 1000
+        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
+        codes = codes.transpose(0, 1)
+        # codes is [B, K, T], with T frames, K nb of codebooks.
+        bw = torch.tensor(n_q * bw_per_q).to(x)
+        return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """Encode a given input tensor with the specified frame rate at the given bandwidth.
+        The RVQ encode method sets the appropriate number of quantizer to use
+        and returns indices for each quantizer.
+        """
+        n_q = self.n_q
+        codes = self.vq.encode(x, n_q=n_q)
+        codes = codes.transpose(0, 1)
+        # codes is [B, K, T], with T frames, K nb of codebooks.
+        return codes
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation.
+        """
+        # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
+        codes = codes.transpose(0, 1)
+        quantized = self.vq.decode(codes)
+        return quantized
+
+    @property
+    def total_codebooks(self):
+        return self.max_n_q
+
+    @property
+    def num_codebooks(self):
+        return self.n_q
+
+    def set_num_codebooks(self, n: int):
+        assert n > 0 and n <= self.max_n_q
+        self.n_q = n
+
+

Ancestors

+ +

Class variables

+
+
var call_super_init : bool
+
+
+
+
var dump_patches : bool
+
+
+
+
var training : bool
+
+
+
+
+

Methods

+
+
+def encode(self, x: torch.Tensor) ‑> torch.Tensor +
+
+

Encode a given input tensor with the specified frame rate at the given bandwidth. +The RVQ encode method sets the appropriate number of quantizer to use +and returns indices for each quantizer.

+
+ +Expand source code + +
def encode(self, x: torch.Tensor) -> torch.Tensor:
+    """Encode a given input tensor with the specified frame rate at the given bandwidth.
+    The RVQ encode method sets the appropriate number of quantizer to use
+    and returns indices for each quantizer.
+    """
+    n_q = self.n_q
+    codes = self.vq.encode(x, n_q=n_q)
+    codes = codes.transpose(0, 1)
+    # codes is [B, K, T], with T frames, K nb of codebooks.
+    return codes
+
+
+
+

Inherited members

+ +
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/utils/autocast.html b/docs/audiocraft/utils/autocast.html new file mode 100644 index 00000000..bbf4554e --- /dev/null +++ b/docs/audiocraft/utils/autocast.html @@ -0,0 +1,163 @@ + + + + + + +audiocraft.utils.autocast API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.autocast

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+class TorchAutocast:
+    """TorchAutocast utility class.
+    Allows you to enable and disable autocast. This is specially useful
+    when dealing with different architectures and clusters with different
+    levels of support.
+
+    Args:
+        enabled (bool): Whether to enable torch.autocast or not.
+        args: Additional args for torch.autocast.
+        kwargs: Additional kwargs for torch.autocast
+    """
+    def __init__(self, enabled: bool, *args, **kwargs):
+        self.autocast = torch.autocast(*args, **kwargs) if enabled else None
+
+    def __enter__(self):
+        if self.autocast is None:
+            return
+        try:
+            self.autocast.__enter__()
+        except RuntimeError:
+            device = self.autocast.device
+            dtype = self.autocast.fast_dtype
+            raise RuntimeError(
+                f"There was an error autocasting with dtype={dtype} device={device}\n"
+                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
+            )
+
+    def __exit__(self, *args, **kwargs):
+        if self.autocast is None:
+            return
+        self.autocast.__exit__(*args, **kwargs)
+
+
+
+
+
+
+
+
+
+

Classes

+
+
+class TorchAutocast +(enabled: bool, *args, **kwargs) +
+
+

TorchAutocast utility class. +Allows you to enable and disable autocast. This is specially useful +when dealing with different architectures and clusters with different +levels of support.

+

Args

+
+
enabled : bool
+
Whether to enable torch.autocast or not.
+
args
+
Additional args for torch.autocast.
+
kwargs
+
Additional kwargs for torch.autocast
+
+
+ +Expand source code + +
class TorchAutocast:
+    """TorchAutocast utility class.
+    Allows you to enable and disable autocast. This is specially useful
+    when dealing with different architectures and clusters with different
+    levels of support.
+
+    Args:
+        enabled (bool): Whether to enable torch.autocast or not.
+        args: Additional args for torch.autocast.
+        kwargs: Additional kwargs for torch.autocast
+    """
+    def __init__(self, enabled: bool, *args, **kwargs):
+        self.autocast = torch.autocast(*args, **kwargs) if enabled else None
+
+    def __enter__(self):
+        if self.autocast is None:
+            return
+        try:
+            self.autocast.__enter__()
+        except RuntimeError:
+            device = self.autocast.device
+            dtype = self.autocast.fast_dtype
+            raise RuntimeError(
+                f"There was an error autocasting with dtype={dtype} device={device}\n"
+                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
+            )
+
+    def __exit__(self, *args, **kwargs):
+        if self.autocast is None:
+            return
+        self.autocast.__exit__(*args, **kwargs)
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/utils/export.html b/docs/audiocraft/utils/export.html new file mode 100644 index 00000000..70e932e5 --- /dev/null +++ b/docs/audiocraft/utils/export.html @@ -0,0 +1,168 @@ + + + + + + +audiocraft.utils.export API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.export

+
+
+

Utility to export a training checkpoint to a lightweight release checkpoint.

+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utility to export a training checkpoint to a lightweight release checkpoint.
+"""
+
+from pathlib import Path
+import typing as tp
+
+from omegaconf import OmegaConf, DictConfig
+import torch
+
+
+def _clean_lm_cfg(cfg: DictConfig):
+    OmegaConf.set_struct(cfg, False)
+    # This used to be set automatically in the LM solver, need a more robust solution
+    # for the future.
+    cfg['transformer_lm']['card'] = 2048
+    cfg['transformer_lm']['n_q'] = 4
+    # Experimental params no longer supported.
+    bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
+                  'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
+    for name in bad_params:
+        del cfg['transformer_lm'][name]
+    OmegaConf.set_struct(cfg, True)
+    return cfg
+
+
+def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
+    sig = Path(checkpoint_path).parent.name
+    assert len(sig) == 8, "Not a valid Dora signature"
+    pkg = torch.load(checkpoint_path, 'cpu')
+    new_pkg = {
+        'best_state': pkg['ema']['state']['model'],
+        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
+    }
+    out_file = Path(out_folder) / f'{sig}.th'
+    torch.save(new_pkg, out_file)
+    return out_file
+
+
+def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
+    sig = Path(checkpoint_path).parent.name
+    assert len(sig) == 8, "Not a valid Dora signature"
+    pkg = torch.load(checkpoint_path, 'cpu')
+    new_pkg = {
+        'best_state': pkg['fsdp_best_state']['model'],
+        'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
+    }
+    out_file = Path(out_folder) / f'{sig}.th'
+    torch.save(new_pkg, out_file)
+    return out_file
+
+
+
+
+
+
+
+

Functions

+
+
+def export_encodec(checkpoint_path: Union[str, pathlib.Path], out_folder: Union[str, pathlib.Path]) +
+
+
+
+ +Expand source code + +
def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
+    sig = Path(checkpoint_path).parent.name
+    assert len(sig) == 8, "Not a valid Dora signature"
+    pkg = torch.load(checkpoint_path, 'cpu')
+    new_pkg = {
+        'best_state': pkg['ema']['state']['model'],
+        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
+    }
+    out_file = Path(out_folder) / f'{sig}.th'
+    torch.save(new_pkg, out_file)
+    return out_file
+
+
+
+def export_lm(checkpoint_path: Union[str, pathlib.Path], out_folder: Union[str, pathlib.Path]) +
+
+
+
+ +Expand source code + +
def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
+    sig = Path(checkpoint_path).parent.name
+    assert len(sig) == 8, "Not a valid Dora signature"
+    pkg = torch.load(checkpoint_path, 'cpu')
+    new_pkg = {
+        'best_state': pkg['fsdp_best_state']['model'],
+        'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
+    }
+    out_file = Path(out_folder) / f'{sig}.th'
+    torch.save(new_pkg, out_file)
+    return out_file
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/utils/index.html b/docs/audiocraft/utils/index.html new file mode 100644 index 00000000..f6515be6 --- /dev/null +++ b/docs/audiocraft/utils/index.html @@ -0,0 +1,90 @@ + + + + + + +audiocraft.utils API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+
+

Sub-modules

+
+
audiocraft.utils.autocast
+
+
+
+
audiocraft.utils.export
+
+

Utility to export a training checkpoint to a lightweight release checkpoint.

+
+
audiocraft.utils.notebook
+
+
+
+
audiocraft.utils.utils
+
+
+
+
+
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/utils/notebook.html b/docs/audiocraft/utils/notebook.html new file mode 100644 index 00000000..075a78d7 --- /dev/null +++ b/docs/audiocraft/utils/notebook.html @@ -0,0 +1,133 @@ + + + + + + +audiocraft.utils.notebook API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.notebook

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+try:
+    import IPython.display as ipd  # type: ignore
+except ImportError:
+    # Note in a notebook...
+    pass
+
+
+import torch
+
+
+def display_audio(samples: torch.Tensor, sample_rate: int):
+    """Renders an audio player for the given audio samples.
+
+    Args:
+        samples (torch.Tensor): a Tensor of decoded audio samples
+            with shapes [B, C, T] or [C, T]
+        sample_rate (int): sample rate audio should be displayed with.
+    """
+    assert samples.dim() == 2 or samples.dim() == 3
+
+    samples = samples.detach().cpu()
+    if samples.dim() == 2:
+        samples = samples[None, ...]
+
+    for audio in samples:
+        ipd.display(ipd.Audio(audio, rate=sample_rate))
+
+
+
+
+
+
+
+

Functions

+
+
+def display_audio(samples: torch.Tensor, sample_rate: int) +
+
+

Renders an audio player for the given audio samples.

+

Args

+
+
samples : torch.Tensor
+
a Tensor of decoded audio samples +with shapes [B, C, T] or [C, T]
+
sample_rate : int
+
sample rate audio should be displayed with.
+
+
+ +Expand source code + +
def display_audio(samples: torch.Tensor, sample_rate: int):
+    """Renders an audio player for the given audio samples.
+
+    Args:
+        samples (torch.Tensor): a Tensor of decoded audio samples
+            with shapes [B, C, T] or [C, T]
+        sample_rate (int): sample rate audio should be displayed with.
+    """
+    assert samples.dim() == 2 or samples.dim() == 3
+
+    samples = samples.detach().cpu()
+    if samples.dim() == 2:
+        samples = samples[None, ...]
+
+    for audio in samples:
+        ipd.display(ipd.Audio(audio, rate=sample_rate))
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file diff --git a/docs/audiocraft/utils/utils.html b/docs/audiocraft/utils/utils.html new file mode 100644 index 00000000..837913bc --- /dev/null +++ b/docs/audiocraft/utils/utils.html @@ -0,0 +1,796 @@ + + + + + + +audiocraft.utils.utils API documentation + + + + + + + + + + + +
+
+
+

Module audiocraft.utils.utils

+
+
+
+ +Expand source code + +
# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from concurrent.futures import ProcessPoolExecutor
+from functools import wraps
+import hashlib
+import logging
+import typing as tp
+
+import flashy
+import flashy.distrib
+import omegaconf
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+logger = logging.getLogger(__name__)
+
+
+def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
+    """Convenience function to map an omegaconf configuration to a dictionary.
+
+    Args:
+        cfg (omegaconf.DictConfig): Original configuration to map to dict.
+    Returns:
+        dict: Config as dictionary object.
+    """
+    dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
+    assert isinstance(dct, dict)
+    return dct
+
+
+def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
+    if max_samples >= len(dataset):
+        return dataset
+
+    generator = torch.Generator().manual_seed(seed)
+    perm = torch.randperm(len(dataset), generator=generator)
+    return torch.utils.data.Subset(dataset, perm[:max_samples].tolist())
+
+
+def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
+               num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
+    """Convenience function to load dataset into a dataloader with optional subset sampling.
+
+    Args:
+        dataset: Dataset to load.
+        num_samples (Optional[int]): Number of samples to limit subset size.
+        batch_size (int): Batch size.
+        num_workers (int): Number of workers for data loading.
+        seed (int): Random seed.
+    """
+    if num_samples is not None:
+        dataset = random_subset(dataset, num_samples, seed)
+
+    dataloader = flashy.distrib.loader(
+        dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        **kwargs
+    )
+    return dataloader
+
+
+def get_dataset_from_loader(dataloader):
+    dataset = dataloader.dataset
+    if isinstance(dataset, torch.utils.data.Subset):
+        return dataset.dataset
+    else:
+        return dataset
+
+
+def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
+    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
+
+    Args:
+        input (torch.Tensor): The input tensor containing probabilities.
+        num_samples (int): Number of samples to draw.
+        replacement (bool): Whether to draw with replacement or not.
+    Keywords args:
+        generator (torch.Generator): A pseudorandom number generator for sampling.
+    Returns:
+        torch.Tensor: Last dimension contains num_samples indices
+            sampled from the multinomial probability distribution
+            located in the last dimension of tensor input.
+    """
+    input_ = input.reshape(-1, input.shape[-1])
+    output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
+    output = output_.reshape(*list(input.shape[:-1]), -1)
+    return output
+
+
+def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
+    """Sample next token from top K values along the last dimension of the input probs tensor.
+
+    Args:
+        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+        k (int): The k in “top-k”.
+    Returns:
+        torch.Tensor: Sampled tokens.
+    """
+    top_k_value, _ = torch.topk(probs, k, dim=-1)
+    min_value_top_k = top_k_value[..., [-1]]
+    probs *= (probs >= min_value_top_k).float()
+    probs.div_(probs.sum(dim=-1, keepdim=True))
+    next_token = multinomial(probs, num_samples=1)
+    return next_token
+
+
+def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
+    """Sample next token from top P probabilities along the last dimension of the input probs tensor.
+
+    Args:
+        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+        p (int): The p in “top-p”.
+    Returns:
+        torch.Tensor: Sampled tokens.
+    """
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort *= (~mask).float()
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token
+
+
+class DummyPoolExecutor:
+    """Dummy pool executor to use when we actually have only 1 worker.
+    (e.g. instead of ProcessPoolExecutor).
+    """
+    class DummyResult:
+        def __init__(self, func, *args, **kwargs):
+            self.func = func
+            self.args = args
+            self.kwargs = kwargs
+
+        def result(self):
+            return self.func(*self.args, **self.kwargs)
+
+    def __init__(self, workers, mp_context=None):
+        pass
+
+    def submit(self, func, *args, **kwargs):
+        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        return
+
+
+def get_pool_executor(num_workers: int, mp_context=None):
+    return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1)
+
+
+def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
+    """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
+    For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
+
+    Args:
+        lengths (torch.Tensor): tensor with lengths
+        max_len (int): can set the max length manually. Defaults to None.
+    Returns:
+        torch.Tensor: mask with 0s where there is pad tokens else 1s
+    """
+    assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
+    final_length = lengths.max().item() if not max_len else max_len
+    final_length = max(final_length, 1)  # if all seqs are of len zero we don't want a zero-size tensor
+    return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None]
+
+
+def hash_trick(word: str, vocab_size: int) -> int:
+    """Hash trick to pair each word with an index
+
+    Args:
+        word (str): word we wish to convert to an index
+        vocab_size (int): size of the vocabulary
+    Returns:
+        int: index of the word in the embedding LUT
+    """
+    hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16)
+    return hash % vocab_size
+
+
+def with_rank_rng(base_seed: int = 1234):
+    """Decorator for a function so that the function will use a Random Number Generator
+    whose state depend on the GPU rank. The original RNG state is restored upon returning.
+
+    Args:
+        base_seed (int): Random seed.
+    """
+    def _decorator(fun: tp.Callable):
+        @wraps(fun)
+        def _decorated(*args, **kwargs):
+            state = torch.get_rng_state()
+            seed = base_seed ^ flashy.distrib.rank()
+            torch.manual_seed(seed)
+            logger.debug('Rank dependent seed set to %d', seed)
+            try:
+                return fun(*args, **kwargs)
+            finally:
+                torch.set_rng_state(state)
+                logger.debug('RNG state restored.')
+        return _decorated
+    return _decorator
+
+
+def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+    """Get a list of tensors and collate them to a single tensor. according to the following logic:
+    - `dim` specifies the time dimension which will be stacked and padded.
+    - The output will contain 1 new dimension (dimension index 0) which will be the size of
+    of the original list.
+
+    Args:
+        tensors (tp.List[torch.Tensor]): List of tensors to collate.
+        dim (int): Dimension which will be stacked and padded.
+    Returns:
+        tp.Tuple[torch.Tensor, torch.Tensor]:
+            torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
+                (dimension index 0) which will be the size of the original list.
+            torch.Tensor: Tensor containing length of original tensor sizes (without padding).
+    """
+    tensors = [x.transpose(0, dim) for x in tensors]
+    lens = torch.LongTensor([len(x) for x in tensors])
+    padded_tensors = pad_sequence(tensors)
+    padded_tensors = padded_tensors.transpose(0, 1)
+    padded_tensors = padded_tensors.transpose(1, dim + 1)
+    return padded_tensors, lens
+
+
+
+
+
+
+
+

Functions

+
+
+def collate(tensors: List[torch.Tensor], dim: int = 0) ‑> Tuple[torch.Tensor, torch.Tensor] +
+
+

Get a list of tensors and collate them to a single tensor. according to the following logic: +- dim specifies the time dimension which will be stacked and padded. +- The output will contain 1 new dimension (dimension index 0) which will be the size of +of the original list.

+

Args

+
+
tensors : tp.List[torch.Tensor]
+
List of tensors to collate.
+
dim : int
+
Dimension which will be stacked and padded.
+
+

Returns

+
+
tp.Tuple[torch.Tensor, torch.Tensor]:
+
+torch.Tensor
+
Stacked and padded tensor. The output will contain 1 new dimension +(dimension index 0) which will be the size of the original list. +torch.Tensor: Tensor containing length of original tensor sizes (without padding).
+
+
+ +Expand source code + +
def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+    """Get a list of tensors and collate them to a single tensor. according to the following logic:
+    - `dim` specifies the time dimension which will be stacked and padded.
+    - The output will contain 1 new dimension (dimension index 0) which will be the size of
+    of the original list.
+
+    Args:
+        tensors (tp.List[torch.Tensor]): List of tensors to collate.
+        dim (int): Dimension which will be stacked and padded.
+    Returns:
+        tp.Tuple[torch.Tensor, torch.Tensor]:
+            torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
+                (dimension index 0) which will be the size of the original list.
+            torch.Tensor: Tensor containing length of original tensor sizes (without padding).
+    """
+    tensors = [x.transpose(0, dim) for x in tensors]
+    lens = torch.LongTensor([len(x) for x in tensors])
+    padded_tensors = pad_sequence(tensors)
+    padded_tensors = padded_tensors.transpose(0, 1)
+    padded_tensors = padded_tensors.transpose(1, dim + 1)
+    return padded_tensors, lens
+
+
+
+def dict_from_config(cfg: omegaconf.dictconfig.DictConfig) ‑> dict +
+
+

Convenience function to map an omegaconf configuration to a dictionary.

+

Args

+
+
cfg : omegaconf.DictConfig
+
Original configuration to map to dict.
+
+

Returns

+
+
dict
+
Config as dictionary object.
+
+
+ +Expand source code + +
def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
+    """Convenience function to map an omegaconf configuration to a dictionary.
+
+    Args:
+        cfg (omegaconf.DictConfig): Original configuration to map to dict.
+    Returns:
+        dict: Config as dictionary object.
+    """
+    dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
+    assert isinstance(dct, dict)
+    return dct
+
+
+
+def get_dataset_from_loader(dataloader) +
+
+
+
+ +Expand source code + +
def get_dataset_from_loader(dataloader):
+    dataset = dataloader.dataset
+    if isinstance(dataset, torch.utils.data.Subset):
+        return dataset.dataset
+    else:
+        return dataset
+
+
+
+def get_loader(dataset, num_samples: Optional[int], batch_size: int, num_workers: int, seed: int, **kwargs) ‑> torch.utils.data.dataloader.DataLoader +
+
+

Convenience function to load dataset into a dataloader with optional subset sampling.

+

Args

+
+
dataset
+
Dataset to load.
+
num_samples : Optional[int]
+
Number of samples to limit subset size.
+
batch_size : int
+
Batch size.
+
num_workers : int
+
Number of workers for data loading.
+
seed : int
+
Random seed.
+
+
+ +Expand source code + +
def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
+               num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
+    """Convenience function to load dataset into a dataloader with optional subset sampling.
+
+    Args:
+        dataset: Dataset to load.
+        num_samples (Optional[int]): Number of samples to limit subset size.
+        batch_size (int): Batch size.
+        num_workers (int): Number of workers for data loading.
+        seed (int): Random seed.
+    """
+    if num_samples is not None:
+        dataset = random_subset(dataset, num_samples, seed)
+
+    dataloader = flashy.distrib.loader(
+        dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        **kwargs
+    )
+    return dataloader
+
+
+
+def get_pool_executor(num_workers: int, mp_context=None) +
+
+
+
+ +Expand source code + +
def get_pool_executor(num_workers: int, mp_context=None):
+    return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1)
+
+
+
+def hash_trick(word: str, vocab_size: int) ‑> int +
+
+

Hash trick to pair each word with an index

+

Args

+
+
word : str
+
word we wish to convert to an index
+
vocab_size : int
+
size of the vocabulary
+
+

Returns

+
+
int
+
index of the word in the embedding LUT
+
+
+ +Expand source code + +
def hash_trick(word: str, vocab_size: int) -> int:
+    """Hash trick to pair each word with an index
+
+    Args:
+        word (str): word we wish to convert to an index
+        vocab_size (int): size of the vocabulary
+    Returns:
+        int: index of the word in the embedding LUT
+    """
+    hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16)
+    return hash % vocab_size
+
+
+
+def length_to_mask(lengths: torch.Tensor, max_len: Optional[int] = None) ‑> torch.Tensor +
+
+

Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences). +For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]

+

Args

+
+
lengths : torch.Tensor
+
tensor with lengths
+
max_len : int
+
can set the max length manually. Defaults to None.
+
+

Returns

+
+
torch.Tensor
+
mask with 0s where there is pad tokens else 1s
+
+
+ +Expand source code + +
def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
+    """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
+    For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
+
+    Args:
+        lengths (torch.Tensor): tensor with lengths
+        max_len (int): can set the max length manually. Defaults to None.
+    Returns:
+        torch.Tensor: mask with 0s where there is pad tokens else 1s
+    """
+    assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
+    final_length = lengths.max().item() if not max_len else max_len
+    final_length = max(final_length, 1)  # if all seqs are of len zero we don't want a zero-size tensor
+    return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None]
+
+
+
+def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None) +
+
+

torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.

+

Args

+
+
input : torch.Tensor
+
The input tensor containing probabilities.
+
num_samples : int
+
Number of samples to draw.
+
replacement : bool
+
Whether to draw with replacement or not.
+
+

Keywords args: +generator (torch.Generator): A pseudorandom number generator for sampling.

+

Returns

+
+
torch.Tensor
+
Last dimension contains num_samples indices +sampled from the multinomial probability distribution +located in the last dimension of tensor input.
+
+
+ +Expand source code + +
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
+    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
+
+    Args:
+        input (torch.Tensor): The input tensor containing probabilities.
+        num_samples (int): Number of samples to draw.
+        replacement (bool): Whether to draw with replacement or not.
+    Keywords args:
+        generator (torch.Generator): A pseudorandom number generator for sampling.
+    Returns:
+        torch.Tensor: Last dimension contains num_samples indices
+            sampled from the multinomial probability distribution
+            located in the last dimension of tensor input.
+    """
+    input_ = input.reshape(-1, input.shape[-1])
+    output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
+    output = output_.reshape(*list(input.shape[:-1]), -1)
+    return output
+
+
+
+def random_subset(dataset, max_samples: int, seed: int = 42) ‑> torch.utils.data.dataset.Subset +
+
+
+
+ +Expand source code + +
def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
+    if max_samples >= len(dataset):
+        return dataset
+
+    generator = torch.Generator().manual_seed(seed)
+    perm = torch.randperm(len(dataset), generator=generator)
+    return torch.utils.data.Subset(dataset, perm[:max_samples].tolist())
+
+
+
+def sample_top_k(probs: torch.Tensor, k: int) ‑> torch.Tensor +
+
+

Sample next token from top K values along the last dimension of the input probs tensor.

+

Args

+
+
probs : torch.Tensor
+
Input probabilities with token candidates on the last dimension.
+
k : int
+
The k in “top-k”.
+
+

Returns

+
+
torch.Tensor
+
Sampled tokens.
+
+
+ +Expand source code + +
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
+    """Sample next token from top K values along the last dimension of the input probs tensor.
+
+    Args:
+        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+        k (int): The k in “top-k”.
+    Returns:
+        torch.Tensor: Sampled tokens.
+    """
+    top_k_value, _ = torch.topk(probs, k, dim=-1)
+    min_value_top_k = top_k_value[..., [-1]]
+    probs *= (probs >= min_value_top_k).float()
+    probs.div_(probs.sum(dim=-1, keepdim=True))
+    next_token = multinomial(probs, num_samples=1)
+    return next_token
+
+
+
+def sample_top_p(probs: torch.Tensor, p: float) ‑> torch.Tensor +
+
+

Sample next token from top P probabilities along the last dimension of the input probs tensor.

+

Args

+
+
probs : torch.Tensor
+
Input probabilities with token candidates on the last dimension.
+
p : int
+
The p in “top-p”.
+
+

Returns

+
+
torch.Tensor
+
Sampled tokens.
+
+
+ +Expand source code + +
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
+    """Sample next token from top P probabilities along the last dimension of the input probs tensor.
+
+    Args:
+        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+        p (int): The p in “top-p”.
+    Returns:
+        torch.Tensor: Sampled tokens.
+    """
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort *= (~mask).float()
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token
+
+
+
+def with_rank_rng(base_seed: int = 1234) +
+
+

Decorator for a function so that the function will use a Random Number Generator +whose state depend on the GPU rank. The original RNG state is restored upon returning.

+

Args

+
+
base_seed : int
+
Random seed.
+
+
+ +Expand source code + +
def with_rank_rng(base_seed: int = 1234):
+    """Decorator for a function so that the function will use a Random Number Generator
+    whose state depend on the GPU rank. The original RNG state is restored upon returning.
+
+    Args:
+        base_seed (int): Random seed.
+    """
+    def _decorator(fun: tp.Callable):
+        @wraps(fun)
+        def _decorated(*args, **kwargs):
+            state = torch.get_rng_state()
+            seed = base_seed ^ flashy.distrib.rank()
+            torch.manual_seed(seed)
+            logger.debug('Rank dependent seed set to %d', seed)
+            try:
+                return fun(*args, **kwargs)
+            finally:
+                torch.set_rng_state(state)
+                logger.debug('RNG state restored.')
+        return _decorated
+    return _decorator
+
+
+
+
+
+

Classes

+
+
+class DummyPoolExecutor +(workers, mp_context=None) +
+
+

Dummy pool executor to use when we actually have only 1 worker. +(e.g. instead of ProcessPoolExecutor).

+
+ +Expand source code + +
class DummyPoolExecutor:
+    """Dummy pool executor to use when we actually have only 1 worker.
+    (e.g. instead of ProcessPoolExecutor).
+    """
+    class DummyResult:
+        def __init__(self, func, *args, **kwargs):
+            self.func = func
+            self.args = args
+            self.kwargs = kwargs
+
+        def result(self):
+            return self.func(*self.args, **self.kwargs)
+
+    def __init__(self, workers, mp_context=None):
+        pass
+
+    def submit(self, func, *args, **kwargs):
+        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        return
+
+

Class variables

+
+
var DummyResult
+
+
+
+
+

Methods

+
+
+def submit(self, func, *args, **kwargs) +
+
+
+
+ +Expand source code + +
def submit(self, func, *args, **kwargs):
+    return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
+
+
+
+
+
+
+
+ +
+ + + \ No newline at end of file