Skip to content

Commit

Permalink
Add multi gpu batched inference (#12)
Browse files Browse the repository at this point in the history
* add more aug

* add multi gpu inference
  • Loading branch information
loubbrad committed Mar 7, 2024
1 parent f6f5fbb commit e49951c
Show file tree
Hide file tree
Showing 11 changed files with 595 additions and 285 deletions.
30 changes: 23 additions & 7 deletions amt/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,18 @@ class AudioTransform(torch.nn.Module):
def __init__(
self,
reverb_factor: int = 1,
min_snr: int = 10,
max_snr: int = 40,
min_snr: int = 20,
max_snr: int = 50,
max_dist_gain: int = 25,
min_dist_gain: int = 0,
):
super().__init__()
self.tokenizer = AmtTokenizer()
self.reverb_factor = reverb_factor
self.min_snr = min_snr
self.max_snr = max_snr
self.max_dist_gain = max_dist_gain
self.min_dist_gain = min_dist_gain

self.config = load_config()["audio"]
self.sample_rate = self.config["sample_rate"]
Expand Down Expand Up @@ -230,10 +234,10 @@ def __init__(
)
self.spec_aug = torch.nn.Sequential(
torchaudio.transforms.FrequencyMasking(
freq_mask_param=15, iid_masks=True
freq_mask_param=10, iid_masks=True
),
torchaudio.transforms.TimeMasking(
time_mask_param=500, iid_masks=True
time_mask_param=1000, iid_masks=True
),
)

Expand Down Expand Up @@ -309,6 +313,12 @@ def apply_noise(self, wav: torch.tensor):

return AF.add_noise(waveform=wav, noise=noise, snr=snr_dbs)

def apply_distortion(self, wav: torch.tensor):
gain = random.randint(self.min_dist_gain, self.max_dist_gain)
colour = random.randint(5, 95)

return AF.overdrive(wav, gain=gain, colour=colour)

def shift_spec(self, specs: torch.Tensor, shift: int):
if shift == 0:
return specs
Expand All @@ -335,7 +345,13 @@ def shift_spec(self, specs: torch.Tensor, shift: int):
return shifted_specs

def aug_wav(self, wav: torch.Tensor):
return self.apply_reverb(self.apply_noise(wav))
# Only apply distortion in 20% of cases
if random.random() > 0.20:
return self.apply_reverb(self.apply_noise(wav))
else:
return self.apply_reverb(
self.apply_distortion(self.apply_noise(wav))
)

def norm_mel(self, mel_spec: torch.Tensor):
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
Expand Down Expand Up @@ -364,8 +380,8 @@ def forward(self, wav: torch.Tensor, shift: int = 0):
# Spec & pitch shift
log_mel = self.log_mel(wav, shift)

# Spec aug
if random.random() > 0.2:
# Spec aug in 20% of cases
if random.random() > 0.20:
log_mel = self.spec_aug(log_mel)

return log_mel
9 changes: 7 additions & 2 deletions amt/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,25 @@


def get_wav_mid_segments(
audio_path: str, mid_path: str = "", return_json: bool = False
audio_path: str,
mid_path: str = "",
return_json: bool = False,
stride_factor: int | None = None,
):
"""This function yields tuples of matched log mel spectrograms and
tokenized sequences (np.array, list). If it is given only an audio path
then it will return an empty list for the mid_feature
"""
tokenizer = AmtTokenizer()
config = load_config()
stride_factor = config["data"]["stride_factor"]
sample_rate = config["audio"]["sample_rate"]
chunk_len = config["audio"]["chunk_len"]
num_samples = sample_rate * chunk_len
samples_per_ms = sample_rate // 1000

if not stride_factor:
stride_factor = config["data"]["stride_factor"]

if not os.path.isfile(audio_path):
return None

Expand Down
Loading

0 comments on commit e49951c

Please sign in to comment.