masked_scatter should accept only bool masks
Modify test_torch to check that a RuntimeError is raised in this case.

torch.uint8 mask support has been deprecated for a few releases; this diff
finally removes it completely.

Fixes #94634
malfet committed Mar 30, 2023
1 parent 4cce607 commit e76fda5
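For illustration (editorial note, not part of the commit itself), a minimal sketch of the user-visible change, assuming a build that includes this commit: a bool mask behaves as before, while a uint8 mask now raises a RuntimeError instead of the old deprecation warning.

```python
import torch

dest = torch.zeros(5)
src = torch.arange(1.0, 6.0)

# Boolean masks remain the supported path.
bool_mask = torch.tensor([True, False, True, False, True])
dest.masked_scatter_(bool_mask, src)  # masked slots are filled from src in order

# uint8 masks previously emitted a deprecation warning; they are now rejected.
uint8_mask = bool_mask.to(torch.uint8)
try:
    dest.masked_scatter_(uint8_mask, src)
except RuntimeError as e:
    print(e)  # e.g. "masked_scatter_ only supports boolean masks, but got mask with dtype ..."
```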
Showing 4 changed files with 61 additions and 93 deletions.
17 changes: 5 additions & 12 deletions aten/src/ATen/native/cpu/IndexKernel.cpp
@@ -329,9 +329,8 @@ void masked_fill_kernel(TensorIterator& iter, const Scalar& value) {
   });
 }
 
-template <typename scalar_t, typename mask_t>
+template <typename scalar_t>
 void cpu_masked_scatter_kernel(TensorIterator& iter, const TensorBase& source) {
-  auto is_mask_bool = std::is_same<mask_t, bool>::value;
   std::ptrdiff_t source_cntr = 0;
   scalar_t* source_ptr = source.data_ptr<scalar_t>();
   auto numel = source.numel();
@@ -342,10 +341,7 @@ void cpu_masked_scatter_kernel(TensorIterator& iter, const TensorBase& source) {
     char* mask = data[1];
     const int64_t mask_stride = strides[1];
     for (const auto i : c10::irange(n)) {
-      mask_t mask_value = *(mask_t*)(mask + mask_stride * i);
-      if (!is_mask_bool) {
-        TORCH_CHECK(mask_value <= static_cast<mask_t>(1), "Mask tensor can take 0 and 1 values only");
-      }
+      auto mask_value = *reinterpret_cast<bool*>(mask + mask_stride * i);
       if (mask_value) {
         TORCH_CHECK(source_cntr < numel, "Number of elements of source < number of ones in mask");
         *(scalar_t*)(dst + dst_stride * i) = *(source_ptr);
@@ -358,19 +354,16 @@ void cpu_masked_scatter_kernel(TensorIterator& iter, const TensorBase& source) {
 }
 
 void masked_scatter_kernel(TensorIterator& iter, const TensorBase& source) {
+  TORCH_CHECK(iter.input_dtype() == ScalarType::Bool, "masked_scatter_ only supports boolean masks, "
+    "but got mask with dtype ", iter.input_dtype());
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
       ScalarType::Bool,
       ScalarType::BFloat16,
       ScalarType::Half,
       iter.dtype(),
       "masked_scatter",
       [&] {
-        auto mask_dtype = iter.input_dtype(0);
-        if (mask_dtype == ScalarType::Bool) {
-          cpu_masked_scatter_kernel<scalar_t, bool>(iter, source);
-        } else {
-          cpu_masked_scatter_kernel<scalar_t, unsigned char>(iter, source);
-        }
+        cpu_masked_scatter_kernel<scalar_t>(iter, source);
       });
 }
9 changes: 3 additions & 6 deletions aten/src/ATen/native/cuda/IndexKernel.cpp
@@ -62,18 +62,15 @@ Tensor & masked_scatter__cuda(Tensor& self, const Tensor& mask, const Tensor& so
   at::assert_no_internal_overlap(self);
   TORCH_CHECK(
       self.scalar_type() == source.scalar_type(),
-      "masked_scatter: expected self and source to have same dtypes but got",
+      "masked_scatter_: expected self and source to have same dtypes but got",
       self.scalar_type(),
       " and ",
       source.scalar_type());
+  TORCH_CHECK(mask.dtype() == ScalarType::Bool, "masked_scatter_ only supports boolean masks, "
+    "but got mask with dtype ", mask.dtype());
 
   c10::MaybeOwned<Tensor> b_mask = expand_inplace(self, mask, "masked_scatter_");
 
-  if (b_mask->dtype() == ScalarType::Byte) {
-    TORCH_WARN("masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
-        "please use a mask with dtype torch.bool instead.");
-  }
-
   if (self.numel() == 0) {
     return self;
   }
24 changes: 6 additions & 18 deletions aten/src/ATen/native/cuda/IndexKernel.cu
@@ -344,15 +344,15 @@ void take_kernel(
 
 namespace {
 
-template <typename mask_t>
-__global__ void masked_scatter_size_check(int64_t *mask_exclusive_sum, mask_t *mask, int64_t srcSize) {
+__global__ void masked_scatter_size_check(int64_t *mask_exclusive_sum, bool *mask, int64_t srcSize) {
   // Convert exclusive sum to inclusive sum
   auto totalElements = *mask_exclusive_sum + *mask;
   CUDA_KERNEL_ASSERT(totalElements <= srcSize);
 }
 
-template <typename mask_t>
-void masked_scatter_cuda_impl(
+} // anonymous namespace
+
+void launch_masked_scatter_kernel(
     const TensorBase &self, const TensorBase &mask,
     const TensorBase &maskPrefixSum, const TensorBase &source) {
   auto srcSize = source.numel();
@@ -361,7 +361,7 @@ void masked_scatter_cuda_impl(
 
   // Use a prefix sum to determine the output locations of the masked elements
   auto maskPrefixSum_data = maskPrefixSum.data_ptr<int64_t>();
-  auto mask_data = mask_cont.data_ptr<mask_t>();
+  auto mask_data = mask_cont.data_ptr<bool>();
 
   at::cuda::cub::mask_exclusive_sum(
       mask_data, maskPrefixSum_data, mask_numel);
@@ -395,7 +395,7 @@ void masked_scatter_cuda_impl(
       [&]() {
         auto source_ptr = source_contig.data_ptr<scalar_t>();
         gpu_kernel(
-            iter, [=] GPU_LAMBDA(scalar_t a, mask_t mask, int64_t maskPrefixSum) -> scalar_t {
+            iter, [=] GPU_LAMBDA(scalar_t a, bool mask, int64_t maskPrefixSum) -> scalar_t {
               if (mask) {
                 return source_ptr[maskPrefixSum];
               }
@@ -405,18 +405,6 @@
       });
 }
 
-} // anonymous namespace
-
-void launch_masked_scatter_kernel(
-    const TensorBase &self, const TensorBase &mask,
-    const TensorBase &maskPrefixSum, const TensorBase &source) {
-  if (mask.scalar_type() == kBool) {
-    masked_scatter_cuda_impl<bool>(self, mask, maskPrefixSum, source);
-  } else {
-    masked_scatter_cuda_impl<uint8_t>(self, mask, maskPrefixSum, source);
-  }
-}
-
 template <typename scalar_t>
 void flip_kernel_impl(TensorIterator& iter) {
   if (!iter.can_use_32bit_indexing()) {
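As background for the CUDA hunks above (editorial sketch, not part of the diff): the kernel computes an exclusive prefix sum over the bool mask so that each masked output position knows which source element to read. A rough Python reference of that indexing logic, with a hypothetical helper name masked_scatter_reference:

```python
import torch

def masked_scatter_reference(dest, mask, src):
    # Exclusive prefix sum: number of True entries strictly before each position.
    # For a masked position, that running count is its index into src.
    m = mask.to(torch.int64)
    prefix = torch.cumsum(m, dim=0) - m
    out = dest.clone()
    for i in range(dest.numel()):
        if mask[i]:
            out[i] = src[prefix[i]]
    return out

dest = torch.zeros(6)
mask = torch.tensor([False, True, False, True, True, False])
src = torch.tensor([10.0, 20.0, 30.0])
print(masked_scatter_reference(dest, mask, src))  # tensor([ 0., 10.,  0., 20., 30.,  0.])
print(dest.masked_scatter(mask, src))             # the built-in op should agree
```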
104 changes: 47 additions & 57 deletions test/test_torch.py
@@ -3654,67 +3654,57 @@ def test_scatter_add_bool(self, device):
 
     # FIXME: find a test suite for the masked scatter operator
     @onlyNativeDeviceTypes
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
     def test_masked_scatter(self, device, dtype):
         dt = dtype
-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter("always")
-            for maskType in [torch.uint8, torch.bool]:
-                num_copy, num_dest = 3, 10
-                dest = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dt, device=device)
-                dest2 = dest.clone()
-                dest_ones = dest.clone()
-                dest_ones_expected = dest.clone()
-                src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dt, device=device)
-                src_ones = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=dt, device=device)
-                mask = torch.tensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0), dtype=maskType, device=device)
-
-                if dt == torch.bool:
-                    # torch.bool is a special case and is being tested
-                    # in a separate test
-                    return
-
-                dest.masked_scatter_(mask, src)
-                j = 0
-                for i in range(num_dest):
-                    if mask[i]:
-                        dest2[i] = src[j]
-                        dest_ones_expected[i] = src_ones[j]
-                        j += 1
-                self.assertEqual(dest, dest2, atol=0, rtol=0)
-
-                dest_ones.masked_scatter_(mask, src_ones)
-                self.assertEqual(dest_ones, dest_ones_expected, atol=0, rtol=0)
-
-                # Bound checking in CUDA is done inside a kernel
-                # in order to avoid synchronization, but this means
-                # we can not clear the failures. So there is no way
-                # to test it then recover.
-                if self.device_type != 'cuda':
-                    # make src smaller. this should fail
-                    src = torch.zeros(num_copy - 1, dtype=dt, device=device)
-                    with self.assertRaises(RuntimeError):
-                        dest.masked_scatter_(mask, src)
-
-                # empty tensor
-                dest = torch.empty((5, 0, 5), dtype=dt, device=device)
-                mask = torch.ones_like(dest, dtype=maskType, device=device)
-                src = torch.empty((0,), dtype=dt, device=device)
-                dest.masked_scatter_(mask, src)
-
-                dest = torch.empty((5, 0, 5), dtype=dt, device=device)
-                mask = torch.ones((5, 1, 5), dtype=maskType, device=device)
-                src = torch.empty((0,), dtype=dt, device=device)
-                dest.masked_scatter_(mask, src)
-
+        num_copy, num_dest = 3, 10
+        dest = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dt, device=device)
+        dest2 = dest.clone()
+        dest_ones = dest.clone()
+        dest_ones_expected = dest.clone()
+        src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dt, device=device)
+        src_ones = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=dt, device=device)
+        mask = torch.tensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0), dtype=torch.bool, device=device)
+
+        dest.masked_scatter_(mask, src)
+        j = 0
+        for i in range(num_dest):
+            if mask[i]:
+                dest2[i] = src[j]
+                dest_ones_expected[i] = src_ones[j]
+                j += 1
+        self.assertEqual(dest, dest2, atol=0, rtol=0)
+
+        dest_ones.masked_scatter_(mask, src_ones)
+        self.assertEqual(dest_ones, dest_ones_expected, atol=0, rtol=0)
+
+        # Bound checking in CUDA is done inside a kernel
+        # in order to avoid synchronization, but this means
+        # we can not clear the failures. So there is no way
+        # to test it then recover.
         if self.device_type != 'cuda':
-            self.assertEqual(len(w), 5)
-        else:
-            self.assertEqual(len(w), 4)
+            # make src smaller. this should fail
+            src = torch.zeros(num_copy - 1, dtype=dt, device=device)
+            with self.assertRaises(RuntimeError):
+                dest.masked_scatter_(mask, src)
 
-        warn = 'masked_scatter_ received a mask with dtype torch.uint8,'
-        for wi in w:
-            self.assertEqual(str(wi.message)[0:55], str(warn))
+        # empty tensor
+        dest = torch.empty((5, 0, 5), dtype=dt, device=device)
+        mask = torch.ones_like(dest, dtype=torch.bool, device=device)
+        src = torch.empty((0,), dtype=dt, device=device)
+        dest.masked_scatter_(mask, src)
+
+        dest = torch.empty((5, 0, 5), dtype=dt, device=device)
+        mask = torch.ones((5, 1, 5), dtype=torch.bool, device=device)
+        src = torch.empty((0,), dtype=dt, device=device)
+        dest.masked_scatter_(mask, src)
+
+        # test invalidate mask types
+        for mask_dtype in [torch.float, torch.uint8]:
+            dest = torch.empty(1, 3, dtype=dt, device=device)
+            source = torch.ones(3, 4, dtype=dt, device=device)
+            with self.assertRaisesRegex(RuntimeError, "masked_scatter_ only supports boolean masks"):
+                dest.masked_scatter_(torch.ones(1, 3, dtype=mask_dtype, device=device), source)
 
     # FIXME: find a test suite for the masked scatter operator
     @skipIfMps
