Skip to content

Commit

Permalink
Fix normalization of impulse response in ImpulsePerturbation (NVIDIA#…
Browse files Browse the repository at this point in the history
…6505)

Signed-off-by: Ante Jukić <[email protected]>
Signed-off-by: hsiehjackson <[email protected]>
  • Loading branch information
anteju authored and hsiehjackson committed Jun 2, 2023
1 parent 319b191 commit 2dd91fa
Showing 1 changed file with 32 additions and 15 deletions.
47 changes: 32 additions & 15 deletions nemo/collections/asr/parts/preprocessing/perturb.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,16 +344,24 @@ class ImpulsePerturbation(Perturbation):
manifest_path (list): Manifest file for RIRs
audio_tar_filepaths (list): Tar files, if RIR audio files are tarred
shuffle_n (int): Shuffle parameter for shuffling buffered files from the tar files
normalize_impulse (bool): Normalize impulse response to zero mean and amplitude 1
shift_impulse (bool): Shift impulse response to adjust for delay at the beginning
rng (int): Random seed. Default is None
"""

def __init__(
self, manifest_path=None, audio_tar_filepaths=None, shuffle_n=128, shift_impulse=False, rng=None,
self,
manifest_path=None,
audio_tar_filepaths=None,
shuffle_n=128,
normalize_impulse=False,
shift_impulse=False,
rng=None,
):
self._manifest = collections.ASRAudioText(manifest_path, parser=parsers.make_parser([]), index_by_file_id=True)
self._audiodataset = None
self._tarred_audio = False
self._normalize_impulse = normalize_impulse
self._shift_impulse = shift_impulse
self._data_iterator = None

Expand All @@ -373,23 +381,32 @@ def perturb(self, data):
tarred_audio=self._tarred_audio,
audio_dataset=self._data_iterator,
)
if not self._shift_impulse:
impulse_norm = (impulse.samples - min(impulse.samples)) / (max(impulse.samples) - min(impulse.samples))
data._samples = signal.fftconvolve(data._samples, impulse_norm, "same")
data._samples = data._samples / max(
abs(data._samples)
) # normalize data samples to [-1,1] after rir convolution to avoid nans with fp16 training

# normalize if necessary
if self._normalize_impulse:
# normalize the impulse response to zero mean and amplitude 1
impulse_norm = impulse.samples - np.mean(impulse.samples)
impulse_norm /= max(abs(impulse_norm))
else:
# Find peak and shift peak to left
impulse_norm = (impulse.samples - min(impulse.samples)) / (max(impulse.samples) - min(impulse.samples))
impulse_norm = impulse.samples

# len of input data samples
len_data = len(data._samples)

# convolve with the full impulse response
data._samples = signal.fftconvolve(data._samples, impulse_norm, "full")

# compensate the dominant path propagation delay
if self._shift_impulse:
# Find the peak of the IR and shift the output to the left
max_ind = np.argmax(np.abs(impulse_norm))
data._samples = data._samples[max_ind:]

# trim to match the input data length
data._samples = data._samples[:len_data]

impulse_resp = impulse_norm[max_ind:]
delay_after = len(impulse_resp)
data._samples = signal.fftconvolve(data._samples, impulse_resp, "full")[:-delay_after]
data._samples = data._samples / max(
abs(data._samples)
) # normalize data samples to [-1,1] after rir convolution to avoid nans with fp16 training
# normalize data samples to [-1,1] after rir convolution to avoid nans with fp16 training
data._samples = data._samples / max(abs(data._samples))


class ShiftPerturbation(Perturbation):
Expand Down

0 comments on commit 2dd91fa

Please sign in to comment.