Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade dependencies to torch 2.x and lightning 2.x #682

Merged
merged 8 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions asteroid/dsp/beamforming.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def forward(
"""
# TODO: Implement several RTF estimation strategies, and choose one here, or expose all.
# Get relative transfer function (1st PCA of Σss)
e_val, e_vec = torch.symeig(target_scm.permute(0, 3, 1, 2), eigenvectors=True)
e_val, e_vec = torch.linalg.eigh(target_scm.permute(0, 3, 1, 2))
rtf_vect = e_vec[..., -1] # bfm
return self.from_rtf_vect(mix=mix, rtf_vec=rtf_vect.transpose(-1, -2), noise_scm=noise_scm)

Expand Down Expand Up @@ -471,7 +471,7 @@ def _generalized_eigenvalue_decomposition(a, b):
# Compute C matrix L⁻1 A L^-T
cmat = inv_cholesky @ a @ inv_cholesky.conj().transpose(-1, -2)
# Performing the eigenvalue decomposition
e_val, e_vec = torch.symeig(cmat, eigenvectors=True)
e_val, e_vec = torch.linalg.eigh(cmat)
# Collecting the eigenvectors
e_vec = torch.matmul(inv_cholesky.conj().transpose(-1, -2), e_vec)
return e_val, e_vec
Expand Down
2 changes: 1 addition & 1 deletion asteroid/engine/schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ class SinkPITBetaScheduler(pl.callbacks.Callback):
def __init__(self, cooling_schedule=sinkpit_default_beta_schedule):
self.cooling_schedule = cooling_schedule

def on_epoch_start(self, trainer, pl_module):
def on_train_epoch_start(self, trainer, pl_module):
assert isinstance(pl_module.loss_func, SinkPITLossWrapper)
assert trainer.current_epoch == pl_module.current_epoch # same
epoch = pl_module.current_epoch
Expand Down
2 changes: 1 addition & 1 deletion asteroid/engine/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def configure_optimizers(self):
epoch_schedulers.append(sched)
return [self.optimizer], epoch_schedulers

def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
def lr_scheduler_step(self, scheduler, metric):
if metric is None:
scheduler.step()
else:
Expand Down
7 changes: 7 additions & 0 deletions asteroid/models/x_umx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@


class XUMX(BaseModel):
def __init__(self, *args, **kwargs):
raise RuntimeError(
"XUMX is broken in torch 2.0, use torch<2.0 with asteroid<0.7 to use it until it's fixed."
)


class BrokenXUMX(BaseModel):
r"""CrossNet-Open-Unmix (X-UMX) for Music Source Separation introduced in [1].
There are two notable contributions with no effect on inference:
a) Multi Domain Losses
Expand Down
2 changes: 1 addition & 1 deletion notebooks/03_PITLossWrapper.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@
" return pw_loss.mean(dim=mean_over)\n",
"# Compute pairwise losses using broadcasting (+ unit test equality)\n",
"direct_pairwise_losses = pairwise_mse(estimate_sources, sources)\n",
"torch.testing.assert_allclose(pairwise_losses, direct_pairwise_losses)\n",
"torch.testing.assert_close(pairwise_losses, direct_pairwise_losses)\n",
"# Plot the pairwise losses\n",
"ax = plt.imshow(direct_pairwise_losses[0].data.numpy())"
]
Expand Down
2 changes: 1 addition & 1 deletion requirements/install.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
-r ./torchhub.txt
PyYAML>=5.0
pandas>=0.23.4
pytorch-lightning>=1.5.0,<=1.7.7
pytorch-lightning>=2.0.0
torchmetrics<=0.11.4
torchaudio>=0.8.0
pb_bss_eval>=0.0.2
Expand Down
2 changes: 1 addition & 1 deletion requirements/torchhub.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Note that Asteroid itself is not required to be installed.
numpy>=1.16.4
scipy>=1.10.1
torch>=1.8.0,<2.0.0
torch>=2.0.0
asteroid-filterbanks>=0.4.0
requests
filelock
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ def find_version(*file_paths):
# From requirements/torchhub.txt
"numpy>=1.16.4",
"scipy>=1.10.1",
"torch>=1.8.0,<2.0.0",
"torch>=2.0.0",
"asteroid-filterbanks>=0.4.0",
"SoundFile>=0.10.2",
"huggingface_hub>=0.0.2",
# From requirements/install.txt
"PyYAML>=5.0",
"pandas>=0.23.4",
"pytorch-lightning>=1.5.0,<=1.7.7",
"pytorch-lightning>=2.0.0",
"torchmetrics<=0.11.4",
"torchaudio>=0.5.0",
"pb_bss_eval>=0.0.2",
Expand Down
20 changes: 10 additions & 10 deletions tests/complex_nn_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from torch.testing import assert_allclose
from torch.testing import assert_close
import pytest
import math

Expand All @@ -17,23 +17,23 @@ def test_torch_complex_from_magphase():
mag = torch.randn(shape).abs()
phase = torch.remainder(torch.randn(shape), math.pi)
out = cnn.torch_complex_from_magphase(mag, phase)
assert_allclose(torch.abs(out), mag)
assert_allclose(out.angle(), phase)
assert_close(torch.abs(out), mag)
assert_close(out.angle(), phase)


def test_torch_complex_from_reim():
comp = torch.randn(10, 12, dtype=torch.complex64)
assert_allclose(cnn.torch_complex_from_reim(comp.real, comp.imag), comp)
assert_close(cnn.torch_complex_from_reim(comp.real, comp.imag), comp)


def test_onreim():
inp = torch.randn(10, 10, dtype=torch.complex64)
# Identity
fn = cnn.on_reim(lambda x: x)
assert_allclose(fn(inp), inp)
assert_close(fn(inp), inp)
# Top right quadrant
fn = cnn.on_reim(lambda x: x.abs())
assert_allclose(fn(inp), cnn.torch_complex_from_reim(inp.real.abs(), inp.imag.abs()))
assert_close(fn(inp), cnn.torch_complex_from_reim(inp.real.abs(), inp.imag.abs()))


def test_on_reim_class():
Expand All @@ -48,16 +48,16 @@ def forward(self, x):
return x + self.a

fn = cnn.OnReIm(Identity, 0)
assert_allclose(fn(inp), inp)
assert_close(fn(inp), inp)
fn = cnn.OnReIm(Identity, 1)
assert_allclose(fn(inp), cnn.torch_complex_from_reim(inp.real + 1, inp.imag + 1))
assert_close(fn(inp), cnn.torch_complex_from_reim(inp.real + 1, inp.imag + 1))


def test_complex_mul_wrapper():
a = torch.randn(10, 10, dtype=torch.complex64)

fn = cnn.ComplexMultiplicationWrapper(torch.nn.ReLU)
assert_allclose(
assert_close(
fn(a),
cnn.torch_complex_from_reim(
torch.relu(a.real) - torch.relu(a.imag), torch.relu(a.real) + torch.relu(a.imag)
Expand Down Expand Up @@ -86,4 +86,4 @@ def test_complexsinglernn(n_layers):
reim = layer.re_module(inp.imag)
imre = layer.im_module(inp.real)
inp = cnn.torch_complex_from_reim(rere - imim, reim + imre)
assert_allclose(out, inp)
assert_close(out, inp)
6 changes: 3 additions & 3 deletions tests/dsp/consistency_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from torch.testing import assert_allclose
from torch.testing import assert_close
import pytest

from asteroid.dsp.consistency import mixture_consistency
Expand All @@ -13,7 +13,7 @@ def test_consistency_noweight(mix_shape, dim, n_src):
est_shape = mix_shape[:dim] + [n_src] + mix_shape[dim:]
est_sources = torch.randn(est_shape)
consistent_est_sources = mixture_consistency(mix, est_sources, dim=dim)
assert_allclose(mix, consistent_est_sources.sum(dim))
assert_close(mix, consistent_est_sources.sum(dim))


@pytest.mark.parametrize("mix_shape", [[2, 1600], [2, 130, 10]])
Expand All @@ -30,7 +30,7 @@ def test_consistency_withweight(mix_shape, dim, n_src):
src_weights = torch.softmax(torch.randn(src_weights_shape), dim=dim)
# Apply mixture consitency
consistent_est_sources = mixture_consistency(mix, est_sources, src_weights=src_weights, dim=dim)
assert_allclose(mix, consistent_est_sources.sum(dim))
assert_close(mix, consistent_est_sources.sum(dim))


def test_consistency_raise():
Expand Down
4 changes: 2 additions & 2 deletions tests/dsp/overlap_add_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from torch.testing import assert_allclose
from torch.testing import assert_close
import pytest

from asteroid.dsp.overlap_add import LambdaOverlapAdd
Expand All @@ -16,4 +16,4 @@ def test_overlap_add(length, batch_size, n_src, window, window_size, hop_size):
nnet = lambda x: x.unsqueeze(1).repeat(1, n_src, 1)
oladd = LambdaOverlapAdd(nnet, n_src, window_size, hop_size, window)
oladded = oladd(mix)
assert_allclose(mix.repeat(1, n_src, 1), oladded)
assert_close(mix.repeat(1, n_src, 1), oladded)
4 changes: 2 additions & 2 deletions tests/jit/jit_filterbanks_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import pytest
from torch.testing import assert_allclose
from torch.testing import assert_close
from asteroid_filterbanks import make_enc_dec
from asteroid.models.base_models import BaseEncoderMaskerDecoder

Expand Down Expand Up @@ -30,7 +30,7 @@ def test_jit_filterbanks(filter_bank_name, inference_data):
with torch.no_grad():
res = model(inference_data)
out = traced(inference_data)
assert_allclose(res, out)
assert_close(res, out)


class DummyModel(BaseEncoderMaskerDecoder):
Expand Down
8 changes: 4 additions & 4 deletions tests/jit/jit_masknn_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import torch
from torch.testing import assert_allclose
from torch.testing import assert_close
from asteroid.masknn import norms


Expand All @@ -13,10 +13,10 @@ def test_lns(cls):
traced = torch.jit.trace(model, x)

y = torch.randn(3, chan_size, 18, 12)
assert_allclose(traced(y), model(y))
assert_close(traced(y), model(y))

y = torch.randn(2, chan_size, 10, 5, 4)
assert_allclose(traced(y), model(y))
assert_close(traced(y), model(y))


def test_cumln():
Expand All @@ -27,4 +27,4 @@ def test_cumln():
traced = torch.jit.trace(model, x)

y = torch.randn(3, chan_size, 100)
assert_allclose(traced(y), model(y))
assert_close(traced(y), model(y))
4 changes: 2 additions & 2 deletions tests/jit/jit_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
import pytest
from torch.testing import assert_allclose
from torch.testing import assert_close
from asteroid.models import (
DCCRNet,
DCUNet,
Expand All @@ -20,7 +20,7 @@
def assert_consistency(model, traced, tensor):
ref = model(tensor)
out = traced(tensor)
assert_allclose(ref, out)
assert_close(ref, out)


@pytest.fixture(scope="module")
Expand Down
12 changes: 6 additions & 6 deletions tests/losses/loss_functions_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import torch
from torch.testing import assert_allclose
from torch.testing import assert_close
import warnings

from asteroid_filterbanks import STFTFB, Encoder, transforms
Expand Down Expand Up @@ -68,15 +68,15 @@ def test_sisdr_and_mse(n_src, loss):
w_src_wrapper = PITLossWrapper(multisrc, pit_from="perm_avg")

# Circular tests on value
assert_allclose(pw_wrapper(est_targets, targets), wo_src_wrapper(est_targets, targets))
assert_allclose(wo_src_wrapper(est_targets, targets), w_src_wrapper(est_targets, targets))
assert_close(pw_wrapper(est_targets, targets), wo_src_wrapper(est_targets, targets))
assert_close(wo_src_wrapper(est_targets, targets), w_src_wrapper(est_targets, targets))

# Circular tests on returned estimates
assert_allclose(
assert_close(
pw_wrapper(est_targets, targets, return_est=True)[1],
wo_src_wrapper(est_targets, targets, return_est=True)[1],
)
assert_allclose(
assert_close(
wo_src_wrapper(est_targets, targets, return_est=True)[1],
w_src_wrapper(est_targets, targets, return_est=True)[1],
)
Expand Down Expand Up @@ -151,7 +151,7 @@ def test_pmsqe(sample_rate):
assert loss_value.shape[0] == ref.shape[0]
# Assert support for transposed inputs.
tr_loss_value = loss_func(est_spec.transpose(1, 2), ref_spec.transpose(1, 2))
assert_allclose(loss_value, tr_loss_value)
assert_close(loss_value, tr_loss_value)


@pytest.mark.parametrize("n_src", [2, 3])
Expand Down
12 changes: 6 additions & 6 deletions tests/losses/pit_wrapper_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import itertools
import torch
from torch.testing import assert_allclose
from torch.testing import assert_close

from asteroid.losses import PITLossWrapper, pairwise_mse

Expand Down Expand Up @@ -71,7 +71,7 @@ def test_permutation(perm):
loss_value, reordered = loss_func(est_sources, sources, return_est=True)

assert loss_value.item() == 0
assert_allclose(sources, reordered)
assert_close(sources, reordered)


def test_permreduce():
Expand All @@ -95,8 +95,8 @@ def test_permreduce():
w_mean = w_mean_reduce(est_sources, sources)
w_sum = w_sum_reduce(est_sources, sources)

assert_allclose(wo, w_mean)
assert_allclose(wo, w_sum / n_src)
assert_close(wo, w_mean)
assert_close(wo, w_sum / n_src)


def test_permreduce_args():
Expand All @@ -123,8 +123,8 @@ def test_best_perm_match(n_src):
min_loss, min_idx = PITLossWrapper.find_best_perm_factorial(pwl)
min_loss_hun, min_idx_hun = PITLossWrapper.find_best_perm_hungarian(pwl)

assert_allclose(min_loss, min_loss_hun)
assert_allclose(min_idx, min_idx_hun)
assert_close(min_loss, min_loss_hun)
assert_close(min_idx, min_idx_hun)


def test_raises_wrong_pit_from():
Expand Down
6 changes: 3 additions & 3 deletions tests/losses/sinkpit_wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch import nn, optim
from torch.utils import data
from torch.testing import assert_allclose
from torch.testing import assert_close

import pytorch_lightning as pl
from pytorch_lightning import Trainer
Expand Down Expand Up @@ -90,7 +90,7 @@ def test_proximity_sinkhorn_hungarian(batch_size, n_src, beta, n_iter, function_
mean_loss_hungarian = loss_hungarian(est_targets, targets, return_est=False)

# compare
assert_allclose(mean_loss_sinkhorn, mean_loss_hungarian)
assert_close(mean_loss_sinkhorn, mean_loss_hungarian)


class _TestCallback(pl.callbacks.Callback):
Expand All @@ -99,7 +99,7 @@ def __init__(self, function, total, batch_size):
self.epoch = 0
self.n_batch = total // batch_size

def on_batch_end(self, trainer, pl_module):
def on_train_batch_end(self, trainer, *args, **kwargs):
step = trainer.global_step
assert self.epoch * self.n_batch <= step
assert step <= (self.epoch + 1) * self.n_batch
Expand Down
6 changes: 3 additions & 3 deletions tests/masknn/activations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# list of strings / function pair in parametrize?
import pytest
import torch
from torch.testing import assert_allclose
from torch.testing import assert_close

from asteroid.masknn import activations
from torch import nn
Expand All @@ -27,14 +27,14 @@ def test_activations(activation_tuple):
asteroid_act = activations.get(asteroid_act)()

inp = torch.randn(10, 11, 12)
assert_allclose(torch_act(inp), asteroid_act(inp))
assert_close(torch_act(inp), asteroid_act(inp))


def test_softmax():
torch_softmax = nn.Softmax(dim=-1)
asteroid_softmax = activations.get("softmax")(dim=-1)
inp = torch.randn(10, 11, 12)
assert_allclose(torch_softmax(inp), asteroid_softmax(inp))
assert_close(torch_softmax(inp), asteroid_softmax(inp))
assert torch_softmax == activations.get(torch_softmax)


Expand Down
Loading
Loading