Skip to content

Commit

Permalink
[pt2] add meta for cholesky (pytorch#106115)
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaretnikov authored and bobby-palmer committed Jul 29, 2023
1 parent 4739dc5 commit 3f1d0df
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 9 deletions.
3 changes: 0 additions & 3 deletions test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2746,9 +2746,6 @@ def forward(self, x):
skip('as_strided_scatter'),
skip('as_strided', 'partial_views'), # flaky

# Too annoying to generate random inputs
xfail('cholesky'),

# Given input size: (s0xs1x2). Calculated output size: ...
skip('max_pool2d_with_indices_backward'),

Expand Down
3 changes: 0 additions & 3 deletions test/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,6 @@ def run_meta_crossref(
torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
torch.nn.functional.one_hot : {i64},
torch._segment_reduce : {f64, f16, bf16, f32},
torch.cholesky : {f64, f32, c128, c64},
torch.cholesky_inverse : {f64, f32, c128, c64},
torch.linalg.eig : {f64, f32, c128, c64},
torch.linalg.eigvals : {f64, f32, c128, c64},
Expand Down Expand Up @@ -804,8 +803,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
# these always fail
meta_dispatch_expected_failures = {
aten.allclose.default: {f16, bf16, f32, f64, c64, c128}, # NotImplementedError: 'aten::_local_scalar_dense'
aten.cholesky.default : {c64, c128, f64, f32},
aten.cholesky.out : {c64, c128, f64, f32},
aten.cholesky_inverse.default : {c64, c128, f64, f32},
aten.cholesky_inverse.out : {c64, c128, f64, f32},
aten.geqrf.default : {c64, c128, f64, f32},
Expand Down
2 changes: 0 additions & 2 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1935,7 +1935,6 @@ def test_refs_are_in_decomp_table(self, op):

fake_skips = (
"aminmax", # failing input
"cholesky", # Could not run 'aten::cholesky' with arguments from the 'Meta' backend
"cholesky_inverse", # Could not run 'aten::cholesky' with arguments from the 'Meta' backend
"cov", # aweights cannot be negtaive
"istft", # window overlap add min: 0
Expand Down Expand Up @@ -2014,7 +2013,6 @@ def test_refs_are_in_decomp_table(self, op):
"roll",
"svd_lowrank",
"sgn",
"cholesky",
}

fake_backward_xfails = {skip(s) for s in fake_backward_skips} | {
Expand Down
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,7 +1521,6 @@ def f(t):
fake_tensor_failures = {
# FakeTensor fallback doesn't work
xfail('_segment_reduce', 'lengths'),
xfail('cholesky'),
xfail('cholesky_inverse'),
# cannot do these as they rely on tensor data
xfail('repeat_interleave'),
Expand Down
9 changes: 9 additions & 0 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,15 @@ def cholesky_solve(self: Tensor, A: Tensor, upper: bool = False) -> Tensor:
return _cholesky_solve_helper(self_broadcasted, A_broadcasted, upper)


@register_meta(aten.cholesky)
@out_wrapper()
def cholesky(self: Tensor, upper: bool = False) -> Tensor:
if self.numel() == 0:
return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
squareCheckInputs(self, "cholesky")
return cloneBatchedColumnMajor(self)


# From aten/src/ATen/native/BatchLinearAlgebra.cpp
@register_meta(aten.linalg_cholesky_ex.default)
def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):
Expand Down

0 comments on commit 3f1d0df

Please sign in to comment.