scatter/gather - check that inputs are of the same dimensionality (#41890)

Co-authored-by: Nikita Vedeneev <[email protected]>
gchanan and nikitaved authored Jul 23, 2020
1 parent a2922f5 commit 7c7c9c3
Showing 2 changed files with 36 additions and 1 deletion.
19 changes: 18 additions & 1 deletion aten/src/ATen/native/ScatterGatherChecks.h
@@ -34,12 +34,17 @@ static void scatter_gather_dtype_check(
// Test:
// 1. index.size(d) == self.size(d) for all d != dim
// 2. index.size(d) <= src.size(d) for all d != dim
// 3. index.dim() == self.dim() == src.dim()
static void gather_shape_check(const Tensor& self, int64_t dim,
  const Tensor& index, const Tensor& src
) {
  auto self_dims = ensure_nonempty_dim(self.dim());

  TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()),
    "Index tensor must have the same number of dimensions as out tensor"
  );

  auto src_dims = ensure_nonempty_dim(src.dim());
  TORCH_CHECK(src_dims == ensure_nonempty_dim(index.dim()),
    "Index tensor must have the same number of dimensions as input tensor"
  );
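For context, here is a minimal sketch (not part of the diff; shapes are illustrative) of the user-facing behavior the new gather check enforces:

    import torch

    src = torch.rand(3, 4)                        # 2-D source
    idx = torch.zeros(3, 4, 1, dtype=torch.long)  # 3-D index

    # With this change the mismatch fails fast with:
    # RuntimeError: Index tensor must have the same number of dimensions as input tensor
    torch.gather(src, 0, idx)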

@@ -66,10 +71,16 @@ static void gather_shape_check(const Tensor& self, int64_t dim,
// Tests:
// 1. index.size(d) <= self.size(d) for all d != dim
// 2. index.size(d) <= src.size(d) for all d if src is a Tensor
// 3. index.dim() == self.dim() == src.dim()
static void scatter_shape_check(
  const Tensor& self, int64_t dim, const Tensor& index,
  const c10::optional<Tensor>& src_opt = c10::nullopt
) {
  TORCH_CHECK(
    ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
    "Index tensor must have the same number of dimensions as self tensor"
  );

  bool is_wrong_shape = false;
  int64_t self_dims = ensure_nonempty_dim(self.dim());
@@ -97,6 +108,12 @@ static void scatter_shape_check(

  if (src_opt.has_value()) {
    auto src = src_opt.value();

    TORCH_CHECK(
      ensure_nonempty_dim(src.dim()) == ensure_nonempty_dim(index.dim()),
      "Index tensor must have the same number of dimensions as src tensor"
    );

    TORCH_CHECK(!is_wrong_shape,
      "Expected index ", index.sizes(),
      " to be smaller than self ", self.sizes(),
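Similarly, a minimal sketch (again illustrative, not from the diff) of the scatter-side check:

    import torch

    base = torch.zeros(3, 4)
    idx = torch.zeros(3, 4, dtype=torch.long)
    src = torch.ones(3, 4, 1)  # one more dimension than index

    # With this change the mismatch fails fast with:
    # RuntimeError: Index tensor must have the same number of dimensions as src tensor
    base.scatter_(0, idx, src)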
18 changes: 18 additions & 0 deletions test/test_torch.py
@@ -2643,6 +2643,13 @@ def _test_gather(self, cast, test_bounds=True):
        with self.assertRaisesRegex(RuntimeError, 'Expected self.dtype to be equal to src.dtype'):
            torch.gather(src, dim, idx, out=expected.to(torch.int))

        # checks for the same dimensionality
        with self.assertRaisesRegex(RuntimeError, 'Index tensor must have the same number of dimensions as input tensor'):
            torch.gather(src, dim, idx.unsqueeze(-1))

        with self.assertRaisesRegex(RuntimeError, 'Index tensor must have the same number of dimensions as input tensor'):
            torch.gather(src.unsqueeze(-1), dim, idx)

        if test_bounds:
            idx[0][0][0] = 23
            self.assertRaises(RuntimeError, lambda: torch.gather(src, dim, idx))
@@ -2728,6 +2735,17 @@ def _test_scatter_base(self, cast, method, is_scalar=False, test_bounds=True, *,
        with self.assertRaisesRegex(RuntimeError, 'Expected dtype int64 for index'):
            getattr(base.clone(), method)(dim, idx.type(torch.int), src)

        # check for the same dimensionality
        with self.assertRaisesRegex(RuntimeError, 'Index tensor must have the same number of dimensions as self tensor'):
            getattr(base.clone().unsqueeze(-1), method)(dim, idx, src)

        with self.assertRaisesRegex(RuntimeError, 'Index tensor must have the same number of dimensions as self tensor'):
            getattr(base.clone(), method)(dim, idx.unsqueeze(-1), src)

        if not is_scalar:
            with self.assertRaisesRegex(RuntimeError, 'Index tensor must have the same number of dimensions as src tensor'):
                getattr(base.clone(), method)(dim, idx, src.unsqueeze(-1))

        if test_bounds:
            idx[0][0][0] = 34
            with self.assertRaises(RuntimeError):