From 42d239d40a9d2b8a7f2db1f82907ec713191e326 Mon Sep 17 00:00:00 2001
From: Benjamin Freist
Date: Wed, 16 Mar 2022 13:45:00 +0100
Subject: [PATCH] [install] torch.complex32 has been removed from 1.11.0

According to https://github.com/pytorch/pytorch/issues/72721, complex32 will
be brought back in a later version; this check should be removed then.
---
 asteroid/dsp/beamforming.py | 20 ++++++++------------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/asteroid/dsp/beamforming.py b/asteroid/dsp/beamforming.py
index e58add446..374d43289 100644
--- a/asteroid/dsp/beamforming.py
+++ b/asteroid/dsp/beamforming.py
@@ -473,14 +473,6 @@ def _generalized_eigenvalue_decomposition(a, b):
     return e_val, e_vec
 
 
-_to_double_map = {
-    torch.float16: torch.float64,
-    torch.float32: torch.float64,
-    torch.complex32: torch.complex128,
-    torch.complex64: torch.complex128,
-}
-
-
 def _common_dtype(*args):
     all_dtypes = [a.dtype for a in args]
     if len(set(all_dtypes)) > 1:
@@ -502,20 +494,24 @@ def force_double_linalg():
 
 
 def _precision_mapping():
+    has_complex32 = hasattr(torch, "complex32")
     if USE_DOUBLE:
-        return {
+        precision_map = {
             torch.float16: torch.float64,
             torch.float32: torch.float64,
-            torch.complex32: torch.complex128,
             torch.complex64: torch.complex128,
         }
+        if has_complex32:
+            precision_map[torch.complex32] = torch.complex128
     else:
-        return {
+        precision_map = {
             torch.float16: torch.float16,
             torch.float32: torch.float32,
-            torch.complex32: torch.complex32,
             torch.complex64: torch.complex64,
         }
+        if has_complex32:
+            precision_map[torch.complex32] = torch.complex32
+    return precision_map
 
 
 # Legacy
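
For reference, a minimal standalone sketch of the guard pattern the patch
introduces: build the dtype map without torch.complex32, then add that entry
only when the running PyTorch build exposes the attribute. It assumes only an
importable torch; the names mirror the patch, but the snippet is illustrative
rather than the Asteroid code itself.

    import torch

    # Map low-precision dtypes to double precision. torch.complex32 is
    # absent on PyTorch 1.11.0 (see pytorch/pytorch#72721), so probe for
    # it instead of referencing the attribute unconditionally.
    precision_map = {
        torch.float16: torch.float64,
        torch.float32: torch.float64,
        torch.complex64: torch.complex128,
    }
    if hasattr(torch, "complex32"):
        precision_map[torch.complex32] = torch.complex128

    # On 1.11.0 the map simply lacks the complex32 entry.
    print(precision_map)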