diff --git a/asteroid/dsp/beamforming.py b/asteroid/dsp/beamforming.py index 0b6baec6b..6ab8fd30b 100644 --- a/asteroid/dsp/beamforming.py +++ b/asteroid/dsp/beamforming.py @@ -94,7 +94,7 @@ def forward( """ # TODO: Implement several RTF estimation strategies, and choose one here, or expose all. # Get relative transfer function (1st PCA of Σss) - e_val, e_vec = torch.symeig(target_scm.permute(0, 3, 1, 2), eigenvectors=True) + e_val, e_vec = torch.linalg.eigh(target_scm.permute(0, 3, 1, 2)) rtf_vect = e_vec[..., -1] # bfm return self.from_rtf_vect(mix=mix, rtf_vec=rtf_vect.transpose(-1, -2), noise_scm=noise_scm) @@ -471,7 +471,7 @@ def _generalized_eigenvalue_decomposition(a, b): # Compute C matrix L⁻1 A L^-T cmat = inv_cholesky @ a @ inv_cholesky.conj().transpose(-1, -2) # Performing the eigenvalue decomposition - e_val, e_vec = torch.symeig(cmat, eigenvectors=True) + e_val, e_vec = torch.linalg.eigh(cmat) # Collecting the eigenvectors e_vec = torch.matmul(inv_cholesky.conj().transpose(-1, -2), e_vec) return e_val, e_vec diff --git a/asteroid/engine/schedulers.py b/asteroid/engine/schedulers.py index 9c4b4298a..5472371fd 100644 --- a/asteroid/engine/schedulers.py +++ b/asteroid/engine/schedulers.py @@ -180,7 +180,7 @@ class SinkPITBetaScheduler(pl.callbacks.Callback): def __init__(self, cooling_schedule=sinkpit_default_beta_schedule): self.cooling_schedule = cooling_schedule - def on_epoch_start(self, trainer, pl_module): + def on_train_epoch_start(self, trainer, pl_module): assert isinstance(pl_module.loss_func, SinkPITLossWrapper) assert trainer.current_epoch == pl_module.current_epoch # same epoch = pl_module.current_epoch diff --git a/asteroid/engine/system.py b/asteroid/engine/system.py index f014ae3f8..23481aaff 100644 --- a/asteroid/engine/system.py +++ b/asteroid/engine/system.py @@ -163,7 +163,7 @@ def configure_optimizers(self): epoch_schedulers.append(sched) return [self.optimizer], epoch_schedulers - def lr_scheduler_step(self, scheduler, optimizer_idx, metric): + def lr_scheduler_step(self, scheduler, metric): if metric is None: scheduler.step() else: diff --git a/asteroid/models/x_umx.py b/asteroid/models/x_umx.py index 599a6e69d..1c29fd6b1 100755 --- a/asteroid/models/x_umx.py +++ b/asteroid/models/x_umx.py @@ -7,6 +7,13 @@ class XUMX(BaseModel): + def __init__(self, *args, **kwargs): + raise RuntimeError( + "XUMX is broken in torch 2.0, use torch<2.0 with asteroid<0.7 to use it until it's fixed." + ) + + +class BrokenXUMX(BaseModel): r"""CrossNet-Open-Unmix (X-UMX) for Music Source Separation introduced in [1]. There are two notable contributions with no effect on inference: a) Multi Domain Losses diff --git a/notebooks/03_PITLossWrapper.ipynb b/notebooks/03_PITLossWrapper.ipynb index 587cb4568..7672973fc 100644 --- a/notebooks/03_PITLossWrapper.ipynb +++ b/notebooks/03_PITLossWrapper.ipynb @@ -187,7 +187,7 @@ " return pw_loss.mean(dim=mean_over)\n", "# Compute pairwise losses using broadcasting (+ unit test equality)\n", "direct_pairwise_losses = pairwise_mse(estimate_sources, sources)\n", - "torch.testing.assert_allclose(pairwise_losses, direct_pairwise_losses)\n", + "torch.testing.assert_close(pairwise_losses, direct_pairwise_losses)\n", "# Plot the pairwise losses\n", "ax = plt.imshow(direct_pairwise_losses[0].data.numpy())" ] diff --git a/requirements/install.txt b/requirements/install.txt index ab521c439..6ff8e89c2 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -2,7 +2,7 @@ -r ./torchhub.txt PyYAML>=5.0 pandas>=0.23.4 -pytorch-lightning>=1.5.0,<=1.7.7 +pytorch-lightning>=2.0.0 torchmetrics<=0.11.4 torchaudio>=0.8.0 pb_bss_eval>=0.0.2 diff --git a/requirements/torchhub.txt b/requirements/torchhub.txt index c6cd82628..d9795a641 100644 --- a/requirements/torchhub.txt +++ b/requirements/torchhub.txt @@ -2,7 +2,7 @@ # Note that Asteroid itself is not required to be installed. numpy>=1.16.4 scipy>=1.10.1 -torch>=1.8.0,<2.0.0 +torch>=2.0.0 asteroid-filterbanks>=0.4.0 requests filelock diff --git a/setup.py b/setup.py index 0b287c21c..727582050 100644 --- a/setup.py +++ b/setup.py @@ -39,14 +39,14 @@ def find_version(*file_paths): # From requirements/torchhub.txt "numpy>=1.16.4", "scipy>=1.10.1", - "torch>=1.8.0,<2.0.0", + "torch>=2.0.0", "asteroid-filterbanks>=0.4.0", "SoundFile>=0.10.2", "huggingface_hub>=0.0.2", # From requirements/install.txt "PyYAML>=5.0", "pandas>=0.23.4", - "pytorch-lightning>=1.5.0,<=1.7.7", + "pytorch-lightning>=2.0.0", "torchmetrics<=0.11.4", "torchaudio>=0.5.0", "pb_bss_eval>=0.0.2", diff --git a/tests/complex_nn_test.py b/tests/complex_nn_test.py index a06491fd9..9ee5be21b 100644 --- a/tests/complex_nn_test.py +++ b/tests/complex_nn_test.py @@ -1,5 +1,5 @@ import torch -from torch.testing import assert_allclose +from torch.testing import assert_close import pytest import math @@ -17,23 +17,23 @@ def test_torch_complex_from_magphase(): mag = torch.randn(shape).abs() phase = torch.remainder(torch.randn(shape), math.pi) out = cnn.torch_complex_from_magphase(mag, phase) - assert_allclose(torch.abs(out), mag) - assert_allclose(out.angle(), phase) + assert_close(torch.abs(out), mag) + assert_close(out.angle(), phase) def test_torch_complex_from_reim(): comp = torch.randn(10, 12, dtype=torch.complex64) - assert_allclose(cnn.torch_complex_from_reim(comp.real, comp.imag), comp) + assert_close(cnn.torch_complex_from_reim(comp.real, comp.imag), comp) def test_onreim(): inp = torch.randn(10, 10, dtype=torch.complex64) # Identity fn = cnn.on_reim(lambda x: x) - assert_allclose(fn(inp), inp) + assert_close(fn(inp), inp) # Top right quadrant fn = cnn.on_reim(lambda x: x.abs()) - assert_allclose(fn(inp), cnn.torch_complex_from_reim(inp.real.abs(), inp.imag.abs())) + assert_close(fn(inp), cnn.torch_complex_from_reim(inp.real.abs(), inp.imag.abs())) def test_on_reim_class(): @@ -48,16 +48,16 @@ def forward(self, x): return x + self.a fn = cnn.OnReIm(Identity, 0) - assert_allclose(fn(inp), inp) + assert_close(fn(inp), inp) fn = cnn.OnReIm(Identity, 1) - assert_allclose(fn(inp), cnn.torch_complex_from_reim(inp.real + 1, inp.imag + 1)) + assert_close(fn(inp), cnn.torch_complex_from_reim(inp.real + 1, inp.imag + 1)) def test_complex_mul_wrapper(): a = torch.randn(10, 10, dtype=torch.complex64) fn = cnn.ComplexMultiplicationWrapper(torch.nn.ReLU) - assert_allclose( + assert_close( fn(a), cnn.torch_complex_from_reim( torch.relu(a.real) - torch.relu(a.imag), torch.relu(a.real) + torch.relu(a.imag) @@ -86,4 +86,4 @@ def test_complexsinglernn(n_layers): reim = layer.re_module(inp.imag) imre = layer.im_module(inp.real) inp = cnn.torch_complex_from_reim(rere - imim, reim + imre) - assert_allclose(out, inp) + assert_close(out, inp) diff --git a/tests/dsp/consistency_test.py b/tests/dsp/consistency_test.py index 2327e6eff..763fcdb30 100644 --- a/tests/dsp/consistency_test.py +++ b/tests/dsp/consistency_test.py @@ -1,5 +1,5 @@ import torch -from torch.testing import assert_allclose +from torch.testing import assert_close import pytest from asteroid.dsp.consistency import mixture_consistency @@ -13,7 +13,7 @@ def test_consistency_noweight(mix_shape, dim, n_src): est_shape = mix_shape[:dim] + [n_src] + mix_shape[dim:] est_sources = torch.randn(est_shape) consistent_est_sources = mixture_consistency(mix, est_sources, dim=dim) - assert_allclose(mix, consistent_est_sources.sum(dim)) + assert_close(mix, consistent_est_sources.sum(dim)) @pytest.mark.parametrize("mix_shape", [[2, 1600], [2, 130, 10]]) @@ -30,7 +30,7 @@ def test_consistency_withweight(mix_shape, dim, n_src): src_weights = torch.softmax(torch.randn(src_weights_shape), dim=dim) # Apply mixture consitency consistent_est_sources = mixture_consistency(mix, est_sources, src_weights=src_weights, dim=dim) - assert_allclose(mix, consistent_est_sources.sum(dim)) + assert_close(mix, consistent_est_sources.sum(dim)) def test_consistency_raise(): diff --git a/tests/dsp/overlap_add_test.py b/tests/dsp/overlap_add_test.py index 18572f78a..ff6df0dea 100644 --- a/tests/dsp/overlap_add_test.py +++ b/tests/dsp/overlap_add_test.py @@ -1,5 +1,5 @@ import torch -from torch.testing import assert_allclose +from torch.testing import assert_close import pytest from asteroid.dsp.overlap_add import LambdaOverlapAdd @@ -16,4 +16,4 @@ def test_overlap_add(length, batch_size, n_src, window, window_size, hop_size): nnet = lambda x: x.unsqueeze(1).repeat(1, n_src, 1) oladd = LambdaOverlapAdd(nnet, n_src, window_size, hop_size, window) oladded = oladd(mix) - assert_allclose(mix.repeat(1, n_src, 1), oladded) + assert_close(mix.repeat(1, n_src, 1), oladded) diff --git a/tests/jit/jit_filterbanks_test.py b/tests/jit/jit_filterbanks_test.py index 178328640..910a6e011 100644 --- a/tests/jit/jit_filterbanks_test.py +++ b/tests/jit/jit_filterbanks_test.py @@ -1,6 +1,6 @@ import torch import pytest -from torch.testing import assert_allclose +from torch.testing import assert_close from asteroid_filterbanks import make_enc_dec from asteroid.models.base_models import BaseEncoderMaskerDecoder @@ -30,7 +30,7 @@ def test_jit_filterbanks(filter_bank_name, inference_data): with torch.no_grad(): res = model(inference_data) out = traced(inference_data) - assert_allclose(res, out) + assert_close(res, out) class DummyModel(BaseEncoderMaskerDecoder): diff --git a/tests/jit/jit_masknn_test.py b/tests/jit/jit_masknn_test.py index 36176d34f..5c5f9251b 100644 --- a/tests/jit/jit_masknn_test.py +++ b/tests/jit/jit_masknn_test.py @@ -1,6 +1,6 @@ import pytest import torch -from torch.testing import assert_allclose +from torch.testing import assert_close from asteroid.masknn import norms @@ -13,10 +13,10 @@ def test_lns(cls): traced = torch.jit.trace(model, x) y = torch.randn(3, chan_size, 18, 12) - assert_allclose(traced(y), model(y)) + assert_close(traced(y), model(y)) y = torch.randn(2, chan_size, 10, 5, 4) - assert_allclose(traced(y), model(y)) + assert_close(traced(y), model(y)) def test_cumln(): @@ -27,4 +27,4 @@ def test_cumln(): traced = torch.jit.trace(model, x) y = torch.randn(3, chan_size, 100) - assert_allclose(traced(y), model(y)) + assert_close(traced(y), model(y)) diff --git a/tests/jit/jit_models_test.py b/tests/jit/jit_models_test.py index 7b17c40a7..362b7464d 100644 --- a/tests/jit/jit_models_test.py +++ b/tests/jit/jit_models_test.py @@ -2,7 +2,7 @@ import torch import pytest -from torch.testing import assert_allclose +from torch.testing import assert_close from asteroid.models import ( DCCRNet, DCUNet, @@ -20,7 +20,7 @@ def assert_consistency(model, traced, tensor): ref = model(tensor) out = traced(tensor) - assert_allclose(ref, out) + assert_close(ref, out) @pytest.fixture(scope="module") diff --git a/tests/losses/loss_functions_test.py b/tests/losses/loss_functions_test.py index e92f549cd..4e14500d4 100644 --- a/tests/losses/loss_functions_test.py +++ b/tests/losses/loss_functions_test.py @@ -1,6 +1,6 @@ import pytest import torch -from torch.testing import assert_allclose +from torch.testing import assert_close import warnings from asteroid_filterbanks import STFTFB, Encoder, transforms @@ -68,15 +68,15 @@ def test_sisdr_and_mse(n_src, loss): w_src_wrapper = PITLossWrapper(multisrc, pit_from="perm_avg") # Circular tests on value - assert_allclose(pw_wrapper(est_targets, targets), wo_src_wrapper(est_targets, targets)) - assert_allclose(wo_src_wrapper(est_targets, targets), w_src_wrapper(est_targets, targets)) + assert_close(pw_wrapper(est_targets, targets), wo_src_wrapper(est_targets, targets)) + assert_close(wo_src_wrapper(est_targets, targets), w_src_wrapper(est_targets, targets)) # Circular tests on returned estimates - assert_allclose( + assert_close( pw_wrapper(est_targets, targets, return_est=True)[1], wo_src_wrapper(est_targets, targets, return_est=True)[1], ) - assert_allclose( + assert_close( wo_src_wrapper(est_targets, targets, return_est=True)[1], w_src_wrapper(est_targets, targets, return_est=True)[1], ) @@ -151,7 +151,7 @@ def test_pmsqe(sample_rate): assert loss_value.shape[0] == ref.shape[0] # Assert support for transposed inputs. tr_loss_value = loss_func(est_spec.transpose(1, 2), ref_spec.transpose(1, 2)) - assert_allclose(loss_value, tr_loss_value) + assert_close(loss_value, tr_loss_value) @pytest.mark.parametrize("n_src", [2, 3]) diff --git a/tests/losses/pit_wrapper_test.py b/tests/losses/pit_wrapper_test.py index cb67fbf9f..2debdf87e 100644 --- a/tests/losses/pit_wrapper_test.py +++ b/tests/losses/pit_wrapper_test.py @@ -1,7 +1,7 @@ import pytest import itertools import torch -from torch.testing import assert_allclose +from torch.testing import assert_close from asteroid.losses import PITLossWrapper, pairwise_mse @@ -71,7 +71,7 @@ def test_permutation(perm): loss_value, reordered = loss_func(est_sources, sources, return_est=True) assert loss_value.item() == 0 - assert_allclose(sources, reordered) + assert_close(sources, reordered) def test_permreduce(): @@ -95,8 +95,8 @@ def test_permreduce(): w_mean = w_mean_reduce(est_sources, sources) w_sum = w_sum_reduce(est_sources, sources) - assert_allclose(wo, w_mean) - assert_allclose(wo, w_sum / n_src) + assert_close(wo, w_mean) + assert_close(wo, w_sum / n_src) def test_permreduce_args(): @@ -123,8 +123,8 @@ def test_best_perm_match(n_src): min_loss, min_idx = PITLossWrapper.find_best_perm_factorial(pwl) min_loss_hun, min_idx_hun = PITLossWrapper.find_best_perm_hungarian(pwl) - assert_allclose(min_loss, min_loss_hun) - assert_allclose(min_idx, min_idx_hun) + assert_close(min_loss, min_loss_hun) + assert_close(min_idx, min_idx_hun) def test_raises_wrong_pit_from(): diff --git a/tests/losses/sinkpit_wrapper_test.py b/tests/losses/sinkpit_wrapper_test.py index 49366b0f0..536232fcc 100644 --- a/tests/losses/sinkpit_wrapper_test.py +++ b/tests/losses/sinkpit_wrapper_test.py @@ -3,7 +3,7 @@ import torch from torch import nn, optim from torch.utils import data -from torch.testing import assert_allclose +from torch.testing import assert_close import pytorch_lightning as pl from pytorch_lightning import Trainer @@ -90,7 +90,7 @@ def test_proximity_sinkhorn_hungarian(batch_size, n_src, beta, n_iter, function_ mean_loss_hungarian = loss_hungarian(est_targets, targets, return_est=False) # compare - assert_allclose(mean_loss_sinkhorn, mean_loss_hungarian) + assert_close(mean_loss_sinkhorn, mean_loss_hungarian) class _TestCallback(pl.callbacks.Callback): @@ -99,7 +99,7 @@ def __init__(self, function, total, batch_size): self.epoch = 0 self.n_batch = total // batch_size - def on_batch_end(self, trainer, pl_module): + def on_train_batch_end(self, trainer, *args, **kwargs): step = trainer.global_step assert self.epoch * self.n_batch <= step assert step <= (self.epoch + 1) * self.n_batch diff --git a/tests/masknn/activations_test.py b/tests/masknn/activations_test.py index a209a0934..d734bbb8d 100644 --- a/tests/masknn/activations_test.py +++ b/tests/masknn/activations_test.py @@ -2,7 +2,7 @@ # list of strings / function pair in parametrize? import pytest import torch -from torch.testing import assert_allclose +from torch.testing import assert_close from asteroid.masknn import activations from torch import nn @@ -27,14 +27,14 @@ def test_activations(activation_tuple): asteroid_act = activations.get(asteroid_act)() inp = torch.randn(10, 11, 12) - assert_allclose(torch_act(inp), asteroid_act(inp)) + assert_close(torch_act(inp), asteroid_act(inp)) def test_softmax(): torch_softmax = nn.Softmax(dim=-1) asteroid_softmax = activations.get("softmax")(dim=-1) inp = torch.randn(10, 11, 12) - assert_allclose(torch_softmax(inp), asteroid_softmax(inp)) + assert_close(torch_softmax(inp), asteroid_softmax(inp)) assert torch_softmax == activations.get(torch_softmax) diff --git a/tests/models/demask_test.py b/tests/models/demask_test.py index 92b5189a4..5971625a3 100644 --- a/tests/models/demask_test.py +++ b/tests/models/demask_test.py @@ -4,7 +4,7 @@ import pytest import torch -from torch.testing import assert_allclose +from torch.testing import assert_close from asteroid.models import DeMask @@ -55,7 +55,7 @@ # expected_data = torch.load(ref_file) # with torch.no_grad(): # output = model(in_data) -# assert_allclose(output, expected_data) +# assert_close(output, expected_data) # # # def test_get_model_args(model): diff --git a/tests/models/models_test.py b/tests/models/models_test.py index dc9915dba..6ff4884e7 100644 --- a/tests/models/models_test.py +++ b/tests/models/models_test.py @@ -1,6 +1,6 @@ import torch import pytest -from torch.testing import assert_allclose +from torch.testing import assert_close import numpy as np import soundfile as sf import asteroid @@ -257,7 +257,7 @@ def _default_test_model(model, input_samples=801, test_input=None): model_conf = model.serialize() reconstructed_model = model.__class__.from_pretrained(model_conf) - assert_allclose(model(test_input), reconstructed_model(test_input)) + assert_close(model(test_input), reconstructed_model(test_input)) # Load with and without SR sr = model_conf["model_args"].pop("sample_rate") @@ -310,7 +310,7 @@ def test_demask(fb): model_conf = model.serialize() reconstructed_model = DeMask.from_pretrained(model_conf) - assert_allclose(model(test_input), reconstructed_model(test_input)) + assert_close(model(test_input), reconstructed_model(test_input)) def test_separate(): diff --git a/tests/models/xumx_test.py b/tests/models/xumx_test.py index a5f6a94cf..b4fb160c8 100755 --- a/tests/models/xumx_test.py +++ b/tests/models/xumx_test.py @@ -11,6 +11,7 @@ ] +@pytest.mark.skip(reason="XUMX is not broken in torch 2.x") @pytest.mark.parametrize("nb_channels", (1, 2)) @pytest.mark.parametrize("sources", sources) @pytest.mark.parametrize("bidirectional", (True, False)) @@ -45,6 +46,7 @@ def test_forward(nb_channels, sources, bidirectional, spec_power, return_time_si x_umx(data) +@pytest.mark.skip(reason="XUMX is not broken in torch 2.x") def test_get_model_args(): sources_tmp = ["vocals"] x_umx = XUMX(sources=sources_tmp, window_length=4096) @@ -67,6 +69,7 @@ def test_get_model_args(): assert x_umx.get_model_args() == expected +@pytest.mark.skip(reason="XUMX is not broken in torch 2.x") def test_model_loading(): sources_tmp = ["bass", "drums", "vocals", "other"] model = XUMX(sources=sources_tmp) diff --git a/tests/utils/utils_test.py b/tests/utils/utils_test.py index 02e455cd9..86ab80fbd 100644 --- a/tests/utils/utils_test.py +++ b/tests/utils/utils_test.py @@ -2,7 +2,7 @@ import argparse from collections.abc import MutableMapping import torch -from torch.testing import assert_allclose +from torch.testing import assert_close import pytest import numpy as np @@ -64,7 +64,7 @@ def test_boolean(parser): ) def test_transfer(tensors): if isinstance(tensors, torch.Tensor): - assert_allclose(utils.tensors_to_device(tensors, "cpu"), tensors) + assert_close(utils.tensors_to_device(tensors, "cpu"), tensors) if isinstance(tensors, list): assert list(utils.tensors_to_device(tensors, "cpu")) == list(tensors) if isinstance(tensors, dict):