diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 987aff50cf75fc..cd2735077be82e 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -2746,9 +2746,6 @@ def forward(self, x):
     skip('as_strided_scatter'),
     skip('as_strided', 'partial_views'),  # flaky
 
-    # Too annoying to generate random inputs
-    xfail('cholesky'),
-
     # Given input size: (s0xs1x2). Calculated output size: ...
     skip('max_pool2d_with_indices_backward'),
 
diff --git a/test/test_meta.py b/test/test_meta.py
index 09578b30332de6..6bf23d9ed8705b 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -633,7 +633,6 @@ def run_meta_crossref(
     torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
     torch.nn.functional.one_hot : {i64},
     torch._segment_reduce : {f64, f16, bf16, f32},
-    torch.cholesky : {f64, f32, c128, c64},
     torch.cholesky_inverse : {f64, f32, c128, c64},
     torch.linalg.eig : {f64, f32, c128, c64},
     torch.linalg.eigvals : {f64, f32, c128, c64},
@@ -804,8 +803,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
 # these always fail
 meta_dispatch_expected_failures = {
     aten.allclose.default: {f16, bf16, f32, f64, c64, c128},  # NotImplementedError: 'aten::_local_scalar_dense'
-    aten.cholesky.default : {c64, c128, f64, f32},
-    aten.cholesky.out : {c64, c128, f64, f32},
     aten.cholesky_inverse.default : {c64, c128, f64, f32},
     aten.cholesky_inverse.out : {c64, c128, f64, f32},
     aten.geqrf.default : {c64, c128, f64, f32},
diff --git a/test/test_ops.py b/test/test_ops.py
index ca93ab161a9c95..d11891d73802ea 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1935,7 +1935,6 @@ def test_refs_are_in_decomp_table(self, op):
 
 fake_skips = (
     "aminmax",  # failing input
-    "cholesky",  # Could not run 'aten::cholesky' with arguments from the 'Meta' backend
     "cholesky_inverse",  # Could not run 'aten::cholesky' with arguments from the 'Meta' backend
     "cov",  # aweights cannot be negtaive
     "istft",  # window overlap add min: 0
@@ -2014,7 +2013,6 @@ def test_refs_are_in_decomp_table(self, op):
     "roll",
     "svd_lowrank",
     "sgn",
-    "cholesky",
 }
 
 fake_backward_xfails = {skip(s) for s in fake_backward_skips} | {
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index c0ab88ba5cdc72..60e217471152c2 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1521,7 +1521,6 @@ def f(t):
 fake_tensor_failures = {
     # FakeTensor fallback doesn't work
     xfail('_segment_reduce', 'lengths'),
-    xfail('cholesky'),
     xfail('cholesky_inverse'),
     # cannot do these as they rely on tensor data
     xfail('repeat_interleave'),
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 10d18e7bc8866e..a947c288c10389 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -575,6 +575,15 @@ def cholesky_solve(self: Tensor, A: Tensor, upper: bool = False) -> Tensor:
     return _cholesky_solve_helper(self_broadcasted, A_broadcasted, upper)
 
 
+@register_meta(aten.cholesky)
+@out_wrapper()
+def cholesky(self: Tensor, upper: bool = False) -> Tensor:
+    if self.numel() == 0:
+        return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
+    squareCheckInputs(self, "cholesky")
+    return cloneBatchedColumnMajor(self)
+
+
 # From aten/src/ATen/native/BatchLinearAlgebra.cpp
 @register_meta(aten.linalg_cholesky_ex.default)
 def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):
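
For reference, a minimal smoke test of what the new registration enables; this snippet is not part of the patch and its shapes/assertions are illustrative assumptions. With a meta kernel for aten.cholesky, torch.cholesky can propagate shape, dtype, and device metadata on meta tensors instead of hitting the previous "Could not run 'aten::cholesky' with arguments from the 'Meta' backend" error, which is why the test skips/xfails above can be dropped.

import torch

# Illustrative only: exercises the meta registration added above.
# No numerical factorization runs on the meta device; only output
# metadata (same batched square shape, same dtype) is produced.
a = torch.empty(3, 4, 4, dtype=torch.float64, device="meta")
L = torch.cholesky(a)
assert L.shape == (3, 4, 4)
assert L.dtype == torch.float64
assert L.device.type == "meta"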