diff --git a/asteroid/dsp/beamforming.py b/asteroid/dsp/beamforming.py index e58add446..374d43289 100644 --- a/asteroid/dsp/beamforming.py +++ b/asteroid/dsp/beamforming.py @@ -473,14 +473,6 @@ def _generalized_eigenvalue_decomposition(a, b): return e_val, e_vec -_to_double_map = { - torch.float16: torch.float64, - torch.float32: torch.float64, - torch.complex32: torch.complex128, - torch.complex64: torch.complex128, -} - - def _common_dtype(*args): all_dtypes = [a.dtype for a in args] if len(set(all_dtypes)) > 1: @@ -502,20 +494,24 @@ def force_double_linalg(): def _precision_mapping(): + has_complex32 = hasattr(torch, "complex32") if USE_DOUBLE: - return { + precision_map = { torch.float16: torch.float64, torch.float32: torch.float64, - torch.complex32: torch.complex128, torch.complex64: torch.complex128, } + if has_complex32: + precision_map[torch.complex32] = torch.complex128 else: - return { + precision_map = { torch.float16: torch.float16, torch.float32: torch.float32, - torch.complex32: torch.complex32, torch.complex64: torch.complex64, } + if has_complex32: + precision_map[torch.complex32] = torch.complex32 + return precision_map # Legacy