[CUDA][Linalg] Patch crash of linalg.eigh when input matrix is ill-conditioned, in some cusolver version (pytorch#107082)

Related: pytorch#94772, pytorch#105359

I can locally reproduce this crash with pytorch 2.0.1 stable pip binary. The test already passes with the latest cuda 12.2 release.

Re: pytorch#94772 (comment)
> From discussion in triage review:

- [x] we should add a test to prevent regressions
- [x] properly document support wrt different CUDA versions
- [x] possibly add support using MAGMA
Pull Request resolved: pytorch#107082
Approved by: https://github.com/lezcano
xwang233 authored and summerdo committed Aug 17, 2023
1 parent 3fa80b3 commit 6cd0970
Showing 3 changed files with 38 additions and 2 deletions.
13 changes: 11 additions & 2 deletions aten/src/ATen/cuda/Exceptions.h
@@ -69,6 +69,13 @@ const char *cusparseGetErrorString(cusparseStatus_t status);

namespace at::cuda::solver {
C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);

constexpr const char* _cusolver_backend_suggestion = \
"If you keep seeing this error, you may use " \
"`torch.backends.cuda.preferred_linalg_library()` to try " \
"linear algebra operators with other supported backends. " \
"See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";

} // namespace at::cuda::solver

// When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan.
@@ -85,13 +92,15 @@ C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
"cusolver error: ", \
at::cuda::solver::cusolverGetErrorMessage(__err), \
", when calling `" #EXPR "`", \
- ". This error may appear if the input matrix contains NaN."); \
+ ". This error may appear if the input matrix contains NaN. ", \
+ at::cuda::solver::_cusolver_backend_suggestion); \
} else { \
TORCH_CHECK( \
__err == CUSOLVER_STATUS_SUCCESS, \
"cusolver error: ", \
at::cuda::solver::cusolverGetErrorMessage(__err), \
- ", when calling `" #EXPR "`"); \
+ ", when calling `" #EXPR "`. ", \
+ at::cuda::solver::_cusolver_backend_suggestion); \
} \
} while (0)

20 changes: 20 additions & 0 deletions test/test_linalg.py
@@ -988,6 +988,26 @@ def test_eigh_errors_and_warnings(self, device, dtype):
        with self.assertRaisesRegex(RuntimeError, "tensors to be on the same device"):
            torch.linalg.eigh(a, out=(out_w, out_v))

    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double)
    @unittest.skipIf(_get_torch_cuda_version() < (12, 1), "Test is fixed on cuda 12.1 update 1.")
    def test_eigh_svd_illcondition_matrix_input_should_not_crash(self, device, dtype):
        # See https://github.com/pytorch/pytorch/issues/94772, https://github.com/pytorch/pytorch/issues/105359
        # This test crashes with `cusolver error: CUSOLVER_STATUS_EXECUTION_FAILED` on cuda 11.8,
        # but passes on cuda 12.1 update 1 or later.
        a = torch.ones(512, 512, dtype=dtype, device=device)
        a[0, 0] = 1.0e-5
        a[-1, -1] = 1.0e5

        eigh_out = torch.linalg.eigh(a)
        svd_out = torch.linalg.svd(a)

        # Matrix input a is too ill-conditioned.
        # We'll just compare the first two singular values/eigenvalues. They are 1.0e5 and 511.0
        # The precision override with tolerance of 1.0 makes sense since ill-conditioned inputs are hard to converge
        # to exact values.
        self.assertEqual(eigh_out.eigenvalues.sort(descending=True).values[:2], [1.0e5, 511.0], atol=1.0, rtol=1.0e-2)
        self.assertEqual(svd_out.S[:2], [1.0e5, 511.0], atol=1.0, rtol=1.0e-2)

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
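
The new test's comment states that the two largest eigenvalues of its input are about 1.0e5 and 511.0. A minimal, torch-free sketch of why: the test matrix can be written as `D + e e^T`, with `e` the all-ones vector and `D` a diagonal holding the shifts at `[0, 0]` and `[-1, -1]`, and the eigenvalues of such a rank-one update that differ from every entry of `D` satisfy the secular equation `sum(1 / (x - d_i)) = 1`. The helper name `secular_root` and the bracketing intervals below are illustrative choices made for this sketch, not part of the commit.

```python
def secular_root(diag, lo, hi, iters=200):
    """Bisection on h(x) = sum(1 / (x - d)) - 1.

    h is strictly decreasing between consecutive poles, so a bracket that
    contains no entry of `diag` holds exactly one root.
    """
    def h(x):
        return sum(1.0 / (x - d) for d in diag) - 1.0

    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if h(mid) > 0.0:
            lo = mid  # h is decreasing, so the root lies to the right of mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


n = 512
# a = ones(n, n) with a[0, 0] = 1e-5 and a[-1, -1] = 1e5 equals D + e e^T,
# where D is zero except for the two shifted diagonal entries.
diag = [0.0] * n
diag[0] = 1.0e-5 - 1.0
diag[-1] = 1.0e5 - 1.0

# Largest eigenvalue: above the largest pole, at most that pole plus e^T e = n.
top = secular_root(diag, diag[-1] + 1e-9, diag[-1] + n)
# Second largest: strictly between the pole at 0 and the pole at 1e5 - 1.
second = secular_root(diag, 1e-9, diag[-1] - 1e-9)
print(top, second)
```

Both roots land within the test's `atol=1.0` of 1.0e5 and 511.0. The remaining eigenvalues are all smaller than 1 in magnitude, which is why the same two numbers are also the leading singular values checked through `svd_out.S`.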
7 changes: 7 additions & 0 deletions torch/linalg/__init__.py
@@ -647,6 +647,13 @@
    :math:`\lambda_i` through the computation of
    :math:`\frac{1}{\min_{i \neq j} \lambda_i - \lambda_j}`.

.. warning:: User may see pytorch crashes if running `eigh` on CUDA devices with CUDA versions before 12.1 update 1
             with large ill-conditioned matrices as inputs.
             Refer to :ref:`Linear Algebra Numerical Stability<Linear Algebra Stability>` for more details.
             If this is the case, user may (1) tune their matrix inputs to be less ill-conditioned,
             or (2) use :func:`torch.backends.cuda.preferred_linalg_library` to
             try other supported backends.

.. seealso::

        :func:`torch.linalg.eigvalsh` computes only the eigenvalues of a Hermitian matrix.
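
Option (2) in the warning, and the suggestion appended to the cusolver error message, can be sketched as a small retry loop. This is only an illustration under stated assumptions: `eigh_with_backend_fallback` is a hypothetical helper, not a PyTorch API, and it assumes the cusolver failure surfaces as a `RuntimeError` (as `TORCH_CHECK` raises) and that retrying in the same process is safe. The solver and the backend setter are passed in as callables so the sketch does not depend on a CUDA build.

```python
def eigh_with_backend_fallback(matrix, eigh, set_backend,
                               backends=("cusolver", "magma")):
    """Try `eigh(matrix)` under each backend in turn (hypothetical helper).

    `eigh` is expected to behave like torch.linalg.eigh and `set_backend`
    like torch.backends.cuda.preferred_linalg_library; both are injected.
    """
    last_err = None
    for name in backends:
        set_backend(name)
        try:
            return eigh(matrix)
        except RuntimeError as err:  # assumption: cusolver errors arrive as RuntimeError
            last_err = err
    raise last_err


# Assumed wiring with PyTorch on a CUDA device:
#   import torch
#   out = eigh_with_backend_fallback(
#       a, torch.linalg.eigh, torch.backends.cuda.preferred_linalg_library)
```

The helper leaves the last backend it tried as the process-wide preference; calling `torch.backends.cuda.preferred_linalg_library("default")` afterwards restores PyTorch's own heuristic.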
