Add forward AD layout check for storage numel (pytorch#68631)
Summary:
Pull Request resolved: pytorch#68631

This PR:
- Adds a check that the storage numel of the base and tangent tensors is the same. This supports the case where as_strided reveals elements that are not indexable by the input tensor (see the sketch below).
- Skips the check when batched tensors are involved, because using as_strided to reveal elements that are not indexable by the input tensor is already not allowed in vmap.
- Adds tests for the above two cases, as well as edge cases for the conj and neg bits.

For functorch:
- we need to copy the batching rule implemented here
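
A minimal sketch of the first case, adapted from test_metadata_check_checks_storage_numel added in this PR (assertions describe the intended behavior):

```python
import torch
import torch.autograd.forward_ad as fwAD

# The primal is a view over a 5-element storage but only indexes 4 elements.
primal = torch.randn(5)[:4].detach()
assert len(primal.storage()) == 5   # storage numel: 5
tangent = torch.randn(4)            # storage numel: 4

with fwAD.dual_level():
    # Storage numels differ, so the layout check fails and the tangent is
    # copied into a tensor that matches the primal's metadata.
    dual = fwAD.make_dual(primal, tangent)
    _, unpacked_tangent = fwAD.unpack_dual(dual)
    assert unpacked_tangent is not tangent

    # as_strided can now reveal the fifth storage element of the dual
    # without reading out of bounds on the tangent.
    dual.as_strided((5,), (1,), 0)
```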

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D32899678

Pulled By: soulitzer

fbshipit-source-id: 54db9550dd2c93bc66b8fb2d36ce40799ebba794
soulitzer authored and facebook-github-bot committed Dec 14, 2021
1 parent 6078e12 commit 51033ec
Showing 8 changed files with 104 additions and 2 deletions.
10 changes: 10 additions & 0 deletions aten/src/ATen/BatchingRegistrations.cpp
@@ -563,6 +563,15 @@ Tensor _new_zeros_with_same_feature_meta_batching_rule(
return self_physical_view.getPhysicalToLogicalMap().apply(result);
}

bool _has_same_storage_numel_batching_rule(const Tensor& self, const Tensor& other) {
TORCH_CHECK(isBatchedTensor(self) && !isBatchedTensor(other),
"Only the 'batched grad' use case is supported in PyTorch core.");
// The _has_same_storage_numel check is skipped if the tangent is a batched
// tensor because using as_strided to access storage locations not indexable
// by the input tensor is not supported in vmap
return true;
}

// What are the semantics of as_strided inside of vmap?
// y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs)
// This returns a view on `x`, `y`, such that each y[i] has:
@@ -1060,6 +1069,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("_add_batch_dim", native::_add_batch_dim);
m.impl("_remove_batch_dim", native::_remove_batch_dim);
m.impl("_make_dual", _make_dual_batching_rule);
m.impl("_has_same_storage_numel", _has_same_storage_numel_batching_rule);
m.impl("is_same_size", native::is_same_size);
m.impl("_new_zeros_with_same_feature_meta", _new_zeros_with_same_feature_meta_batching_rule);

3 changes: 3 additions & 0 deletions aten/src/ATen/ConjugateFallback.cpp
@@ -31,6 +31,9 @@ TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
m.impl("resolve_conj", torch::CppFunction::makeFallthrough());
m.impl("resolve_neg", torch::CppFunction::makeFallthrough());

// See test_metadata_check_when_primal_has_conj_bit in test_autograd.py
m.impl("_has_same_storage_numel", torch::CppFunction::makeFallthrough());

// linear algebra functions
m.impl("dot", torch::CppFunction::makeFallthrough());
m.impl("vdot", torch::CppFunction::makeFallthrough());
4 changes: 4 additions & 0 deletions aten/src/ATen/native/AutogradComposite.cpp
@@ -73,6 +73,10 @@ Tensor _new_zeros_with_same_feature_meta(
return new_tensor.as_strided(out_sizes, out_strides, other_storage_offset);
}

bool _has_same_storage_numel(const at::Tensor& base, const at::Tensor& other) {
return base.storage().nbytes() / base.itemsize() == other.storage().nbytes() / other.itemsize();
}

} // namespace native

} // namespace at
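
For intuition, a rough Python equivalent of this comparison is sketched below, for illustration only: the native op is deliberately kept out of the Python API (see the gen_python_functions.py change further down), and the sketch relies on `len(t.storage())` giving the same quantity as the C++ `storage().nbytes() / itemsize()` computation.

```python
import torch

def has_same_storage_numel(base: torch.Tensor, other: torch.Tensor) -> bool:
    # len(t.storage()) counts the elements in the underlying storage,
    # the same quantity the C++ code derives from nbytes() / itemsize().
    return len(base.storage()) == len(other.storage())

a = torch.randn(5)[:4]   # a view over a 5-element storage
b = torch.randn(4)       # a contiguous tensor with a 4-element storage
print(has_same_storage_numel(a, a))  # True
print(has_same_storage_numel(a, b))  # False: 5 vs. 4 storage elements
```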
3 changes: 3 additions & 0 deletions aten/src/ATen/native/NegateFallback.cpp
@@ -29,6 +29,9 @@ TORCH_LIBRARY_IMPL(aten, Negative, m) {
m.impl("resolve_neg", torch::CppFunction::makeFallthrough());
m.impl("resolve_conj", torch::CppFunction::makeFallthrough());

// See test_metadata_check_when_primal_has_neg_bit in test_autograd.py
m.impl("_has_same_storage_numel", torch::CppFunction::makeFallthrough());

// linear algebra functions
m.impl("linalg_solve_triangular", torch::CppFunction::makeFallthrough());
m.impl("linalg_solve_triangular.out", torch::CppFunction::makeFallthrough());
11 changes: 11 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -132,6 +132,17 @@
dispatch:
CompositeExplicitAutograd: _new_zeros_with_same_feature_meta

# This function compares the storage numel of self with that of other, where
# storage numel is computed as: `other.storage().nbytes() / other.itemsize()`.
# We create this function for composite compliance purposes. The batching rule
# always returns true because vmapped as_strided does not support accessing
# storage locations not indexable by the input tensor.
# See the note above for more information.
- func: _has_same_storage_numel(Tensor self, Tensor other) -> bool
variants: function
dispatch:
CompositeExplicitAutograd: _has_same_storage_numel

- func: rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)
variants: method

70 changes: 69 additions & 1 deletion test/test_autograd.py
@@ -7440,7 +7440,6 @@ def jvp(tangent):
self.assertIs(x_tangent, tangent)
self.assertIs(view_tangent, tangent)


def test_inplace_on_view_not_same_layout(self):
input = torch.zeros([2, 2])
tangent = torch.zeros([2, 2, 2])
@@ -7457,6 +7456,28 @@ def jvp(tangent):
self.assertIs(x_tangent, tangent)
self.assertIsNot(view_tangent, tangent)

def test_metadata_check_for_storage_numel_skipped(self):
# See: test_metadata_check_checks_storage_numel for the reverse of this test
primal = torch.randn(5)[:4].detach()
self.assertEqual(len(primal.storage()), 5)
tangent = torch.randn(10, 4)

def jvp(tangent):
with fwAD.dual_level():
dual = fwAD.make_dual(primal, tangent)
_, unpacked_tangent = fwAD.unpack_dual(dual)

# No copy is made
self.assertIs(tangent, unpacked_tangent)

# as_strided raises
with self.assertRaisesRegex(RuntimeError, "can access memory outside of `tensor`"):
dual.as_strided((5,), (1,), 0)
return unpacked_tangent

torch._vmap_internals._vmap(jvp, 0, 0)(tangent)


class TestAutogradForwardMode(TestCase):
def tearDown(self):
# Ensure that a failing test won't make others fail
@@ -7510,6 +7531,53 @@ def test_size_check(self):

dual = fwAD.make_dual(foo, tangent[1:])

def test_metadata_check_checks_storage_numel(self):
primal = torch.randn(5)[:4].detach()
self.assertEqual(len(primal.storage()), 5)
tangent = torch.randn(4)

with fwAD.dual_level():
dual = fwAD.make_dual(primal, tangent)
_, unpacked_tangent = fwAD.unpack_dual(dual)

# Verify that mutating the unpacked tangent does not affect the original tangent
tangent_clone = tangent.clone()
unpacked_tangent *= 2
self.assertTrue(torch.allclose(tangent_clone, tangent))

# as_strided runs without error
dual.as_strided((5,), (1,), 0)

def test_metadata_check_when_primal_has_conj_bit(self):
# Make sure the _has_same_storage_numel is a fallthrough, so that
# conj bit does not materialize. If it materializes it would
# cause the layout check to fail for views that do not index the
# entire storage.
a = torch.randn(2, 2, dtype=torch.cdouble).conj()
b = torch.rand_like(a)

self.assertTrue(torch.is_conj(a))
self.assertEqual(len(a.storage()), len(b.storage()))

with fwAD.dual_level():
dual = fwAD.make_dual(a, b)
dual[1:]

def test_metadata_check_when_primal_has_neg_bit(self):
# Make sure the _has_same_storage_numel is a fallthrough, so that
# neg bit does not materialize. If it materializes it would
# cause the layout check to fail for views that do not index the
# entire storage.
a = torch.randn(2, 2, dtype=torch.cdouble).conj().imag
b = torch.randn(2, 2, dtype=torch.cdouble).imag

self.assertTrue(torch.is_neg(a))
self.assertEqual(len(a.storage()), len(b.storage()))

with fwAD.dual_level():
dual = fwAD.make_dual(a, b)
dual[1:]

# The following test functions want to ensure all the following behaviors:
# - Ensure that default level system in the python binding works
# - Ensure that only level 0 exists and nesting is properly disabled
2 changes: 1 addition & 1 deletion tools/autograd/gen_python_functions.py
@@ -93,7 +93,7 @@
'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retains_grad', 'set_',
'_fw_primal', 'fake_quantize_per_tensor_affine_cachemask',
'fake_quantize_per_channel_affine_cachemask',
'_new_zeros_with_same_feature_meta', # used for forward AD internals
'_new_zeros_with_same_feature_meta', '_has_same_storage_numel',  # used for forward AD internals
'_reshape_alias',
'replace_', # only used by the functionalization pass, doesn't need to be exposed to python
]
3 changes: 3 additions & 0 deletions torch/csrc/autograd/autograd_meta.cpp
@@ -85,6 +85,9 @@ namespace {
return false;
}
}
if (!at::_has_same_storage_numel(base, other)) {
return false;
}
return true;
}
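
This is the check that the fallthroughs and the batching rule above ultimately feed. Under vmap, the batching rule makes the check a no-op, so the batched tangent is reused as-is and the unsupported as_strided call is rejected by vmap itself rather than papered over by a copy. A condensed sketch of that path, adapted from test_metadata_check_for_storage_numel_skipped in this PR (it relies on the private torch._vmap_internals._vmap helper):

```python
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(5)[:4].detach()   # storage numel 5, only 4 elements indexable
batched_tangents = torch.randn(10, 4)  # a batch of 10 tangents

def jvp(tangent):
    with fwAD.dual_level():
        dual = fwAD.make_dual(primal, tangent)
        _, unpacked_tangent = fwAD.unpack_dual(dual)
        # Inside vmap each tangent is a BatchedTensor, so the storage-numel
        # check is skipped and no copy is made.
        assert unpacked_tangent is tangent
        return unpacked_tangent

torch._vmap_internals._vmap(jvp, 0, 0)(batched_tangents)
```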
