Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes integer overflows in index computation when indexes approach numeric_limits<OffsetT>::max() #1419

Merged
merged 20 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 57 additions & 35 deletions cub/cub/agent/agent_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ struct AgentPartition
CompareOpT compare_op;
OffsetT target_merged_tiles_number;
int items_per_tile;
OffsetT num_partitions;

_CCCL_DEVICE _CCCL_FORCEINLINE AgentPartition(bool ping,
KeyIteratorT keys_ping,
Expand All @@ -322,7 +323,8 @@ struct AgentPartition
OffsetT *merge_partitions,
CompareOpT compare_op,
OffsetT target_merged_tiles_number,
int items_per_tile)
int items_per_tile,
OffsetT num_partitions)
: ping(ping)
, keys_ping(keys_ping)
, keys_pong(keys_pong)
Expand All @@ -332,6 +334,7 @@ struct AgentPartition
, compare_op(compare_op)
, target_merged_tiles_number(target_merged_tiles_number)
, items_per_tile(items_per_tile)
, num_partitions(num_partitions)
{}

_CCCL_DEVICE _CCCL_FORCEINLINE void Process()
Expand All @@ -352,27 +355,37 @@ struct AgentPartition
OffsetT local_tile_idx = mask & partition_idx;

OffsetT keys1_beg = (cub::min)(keys_count, start);
OffsetT keys1_end = (cub::min)(keys_count, start + size);
OffsetT keys1_end = (cub::min)(keys_count, detail::safe_add_bound_to_max(start, size));
OffsetT keys2_beg = keys1_end;
OffsetT keys2_end = (cub::min)(keys_count, keys2_beg + size);

OffsetT partition_at = (cub::min)(keys2_end - keys1_beg,
items_per_tile * local_tile_idx);

OffsetT partition_diag = ping ? MergePath<KeyT>(keys_ping + keys1_beg,
keys_ping + keys2_beg,
keys1_end - keys1_beg,
keys2_end - keys2_beg,
partition_at,
compare_op)
: MergePath<KeyT>(keys_pong + keys1_beg,
keys_pong + keys2_beg,
keys1_end - keys1_beg,
keys2_end - keys2_beg,
partition_at,
compare_op);

merge_partitions[partition_idx] = keys1_beg + partition_diag;
OffsetT keys2_end = (cub::min)(keys_count, detail::safe_add_bound_to_max(keys2_beg, size));

// The last partition (which is one-past-the-last-tile) is only to mark the end of keys1_end for the merge stage
if (partition_idx + 1 == num_partitions)
{
merge_partitions[partition_idx] = keys1_end;
}
else
{
OffsetT partition_at = (cub::min)(keys2_end - keys1_beg, items_per_tile * local_tile_idx);

OffsetT partition_diag =
ping ? MergePath<KeyT>(
keys_ping + keys1_beg,
keys_ping + keys2_beg,
keys1_end - keys1_beg,
keys2_end - keys2_beg,
partition_at,
compare_op)
: MergePath<KeyT>(
keys_pong + keys1_beg,
keys_pong + keys2_beg,
keys1_end - keys1_beg,
keys2_end - keys2_beg,
partition_at,
compare_op);

merge_partitions[partition_idx] = keys1_beg + partition_diag;
}
}
};

Expand Down Expand Up @@ -525,16 +538,25 @@ struct AgentMerge

OffsetT diag = ITEMS_PER_TILE * tile_idx - start;

OffsetT keys1_beg = partition_beg;
OffsetT keys1_end = partition_end;
OffsetT keys2_beg = (cub::min)(keys_count, 2 * start + size + diag - partition_beg);
OffsetT keys2_end = (cub::min)(keys_count, 2 * start + size + diag + ITEMS_PER_TILE - partition_end);
OffsetT keys1_beg = partition_beg - start;
OffsetT keys1_end = partition_end - start;

OffsetT keys_end_dist_from_start = keys_count - start;
OffsetT max_keys2 = (keys_end_dist_from_start > size) ? (keys_end_dist_from_start - size) : 0;

// We have the following invariants:
// diag >= keys1_beg, because diag is the distance of the total merge path so far (keys1 + keys2)
// diag + ITEMS_PER_TILE >= keys1_end, because diag + ITEMS_PER_TILE is the distance of the merge path for the next
// tile and keys1_end is keys1's component of that path
OffsetT keys2_beg = (cub::min)(max_keys2, diag - keys1_beg);
OffsetT keys2_end =
(cub::min)(max_keys2, detail::safe_add_bound_to_max(diag, static_cast<OffsetT>(ITEMS_PER_TILE)) - keys1_end);

// Check if it's the last tile in the tile group being merged
if (mask == (mask & tile_idx))
{
keys1_end = (cub::min)(keys_count, start + size);
keys2_end = (cub::min)(keys_count, start + size * 2);
keys1_end = (cub::min)(keys_count-start, size);
keys2_end = (cub::min)(max_keys2, size);
}

// number of keys per tile
Expand All @@ -547,16 +569,16 @@ struct AgentMerge
if (ping)
{
gmem_to_reg<IS_FULL_TILE>(keys_local,
keys_in_ping + keys1_beg,
keys_in_ping + keys2_beg,
keys_in_ping + start + keys1_beg,
keys_in_ping + start + size + keys2_beg,
num_keys1,
num_keys2);
}
else
{
gmem_to_reg<IS_FULL_TILE>(keys_local,
keys_in_pong + keys1_beg,
keys_in_pong + keys2_beg,
keys_in_pong + start + keys1_beg,
keys_in_pong + start + size + keys2_beg,
num_keys1,
num_keys2);
}
Expand All @@ -570,16 +592,16 @@ struct AgentMerge
if (ping)
{
gmem_to_reg<IS_FULL_TILE>(items_local,
items_in_ping + keys1_beg,
items_in_ping + keys2_beg,
items_in_ping + start + keys1_beg,
items_in_ping + start + size + keys2_beg,
num_keys1,
num_keys2);
}
else
{
gmem_to_reg<IS_FULL_TILE>(items_local,
items_in_pong + keys1_beg,
items_in_pong + keys2_beg,
items_in_pong + start + keys1_beg,
items_in_pong + start + size + keys2_beg,
num_keys1,
num_keys2);
}
Expand Down
23 changes: 23 additions & 0 deletions cub/cub/detail/choose_offset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,29 @@ struct choose_offset
template <typename NumItemsT>
using choose_offset_t = typename choose_offset<NumItemsT>::type;

/**
 * promote_small_offset checks NumItemsT, the type of the num_items parameter, and
 * promotes any integral type smaller than 32 bits to a signed 32-bit integer type.
 * Integral types that are already 32 bits or wider are passed through unchanged.
 */
template <typename NumItemsT>
struct promote_small_offset
{
// NumItemsT must be an integral type (but not bool).
static_assert(::cuda::std::is_integral<NumItemsT>::value
&& !::cuda::std::is_same<typename ::cuda::std::remove_cv<NumItemsT>::type, bool>::value,
"NumItemsT must be an integral type, but not bool");

// Offset type: std::int32_t (signed) when sizeof(NumItemsT) < 4, otherwise NumItemsT itself, so the
// signedness of types that are 32 bits or wider is preserved.
// NOTE(review): this line spells the type as std::int32_t while the traits above use ::cuda::std:: --
// confirm the header providing std::int32_t is included and usable wherever this trait is instantiated.
using type = typename ::cuda::std::conditional<sizeof(NumItemsT) < 4, std::int32_t, NumItemsT>::type;
};

/**
 * promote_small_offset_t is a convenience alias for `promote_small_offset<NumItemsT>::type`: integral types
 * smaller than 32 bits are promoted to a signed 32-bit integer type, all other integral types are kept as-is.
 */
template <typename NumItemsT>
using promote_small_offset_t = typename promote_small_offset<NumItemsT>::type;

/**
* common_iterator_value sets member type to the common_type of
* value_type for all argument types. used to get OffsetT in
Expand Down
29 changes: 22 additions & 7 deletions cub/cub/device/device_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
# pragma system_header
#endif // no system header

#include <cub/detail/choose_offset.cuh>
#include <cub/device/dispatch/dispatch_merge_sort.cuh>
#include <cub/util_deprecated.cuh>
#include <cub/util_namespace.cuh>
Expand Down Expand Up @@ -217,11 +218,13 @@ struct DeviceMergeSort
CompareOpT compare_op,
cudaStream_t stream = 0)
{
using PromotedOffsetT = detail::promote_small_offset_t<OffsetT>;

using DispatchMergeSortT = DispatchMergeSort<KeyIteratorT,
ValueIteratorT,
KeyIteratorT,
ValueIteratorT,
OffsetT,
PromotedOffsetT,
CompareOpT>;

return DispatchMergeSortT::Dispatch(d_temp_storage,
Expand Down Expand Up @@ -390,11 +393,13 @@ struct DeviceMergeSort
CompareOpT compare_op,
cudaStream_t stream = 0)
{
using PromotedOffsetT = detail::promote_small_offset_t<OffsetT>;

using DispatchMergeSortT = DispatchMergeSort<KeyInputIteratorT,
ValueInputIteratorT,
KeyIteratorT,
ValueIteratorT,
OffsetT,
PromotedOffsetT,
CompareOpT>;

return DispatchMergeSortT::Dispatch(d_temp_storage,
Expand Down Expand Up @@ -539,11 +544,13 @@ struct DeviceMergeSort
CompareOpT compare_op,
cudaStream_t stream = 0)
{
using PromotedOffsetT = detail::promote_small_offset_t<OffsetT>;

using DispatchMergeSortT = DispatchMergeSort<KeyIteratorT,
NullType *,
KeyIteratorT,
NullType *,
OffsetT,
PromotedOffsetT,
CompareOpT>;

return DispatchMergeSortT::Dispatch(d_temp_storage,
Expand Down Expand Up @@ -689,11 +696,13 @@ struct DeviceMergeSort
CompareOpT compare_op,
cudaStream_t stream = 0)
{
using PromotedOffsetT = detail::promote_small_offset_t<OffsetT>;

using DispatchMergeSortT = DispatchMergeSort<KeyInputIteratorT,
NullType *,
KeyIteratorT,
NullType *,
OffsetT,
PromotedOffsetT,
CompareOpT>;

return DispatchMergeSortT::Dispatch(d_temp_storage,
Expand Down Expand Up @@ -839,7 +848,9 @@ struct DeviceMergeSort
CompareOpT compare_op,
cudaStream_t stream = 0)
{
return SortPairs<KeyIteratorT, ValueIteratorT, OffsetT, CompareOpT>(
using PromotedOffsetT = detail::promote_small_offset_t<OffsetT>;

return SortPairs<KeyIteratorT, ValueIteratorT, PromotedOffsetT, CompareOpT>(
d_temp_storage,
temp_storage_bytes,
d_keys,
Expand Down Expand Up @@ -971,7 +982,9 @@ struct DeviceMergeSort
CompareOpT compare_op,
cudaStream_t stream = 0)
{
return SortKeys<KeyIteratorT, OffsetT, CompareOpT>(d_temp_storage,
using PromotedOffsetT = detail::promote_small_offset_t<OffsetT>;

return SortKeys<KeyIteratorT, PromotedOffsetT, CompareOpT>(d_temp_storage,
temp_storage_bytes,
d_keys,
num_items,
Expand Down Expand Up @@ -1112,7 +1125,9 @@ struct DeviceMergeSort
CompareOpT compare_op,
cudaStream_t stream = 0)
{
return SortKeysCopy<KeyInputIteratorT, KeyIteratorT, OffsetT, CompareOpT>(d_temp_storage,
using PromotedOffsetT = detail::promote_small_offset_t<OffsetT>;

return SortKeysCopy<KeyInputIteratorT, KeyIteratorT, PromotedOffsetT, CompareOpT>(d_temp_storage,
temp_storage_bytes,
d_input_keys,
d_output_keys,
Expand Down
3 changes: 2 additions & 1 deletion cub/cub/device/dispatch/dispatch_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceMergeSortPartitionKernel(
merge_partitions,
compare_op,
target_merged_tiles_number,
items_per_tile);
items_per_tile,
num_partitions);

agent.Process();
}
Expand Down
20 changes: 14 additions & 6 deletions cub/cub/util_math.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,23 @@ using is_integral_or_enum =
::cuda::std::integral_constant<bool,
::cuda::std::is_integral<T>::value || ::cuda::std::is_enum<T>::value>;

_CCCL_HOST_DEVICE _CCCL_FORCEINLINE constexpr ::cuda::std::size_t
VshmemSize(::cuda::std::size_t max_shmem,
::cuda::std::size_t shmem_per_block,
::cuda::std::size_t num_blocks)
/**
 * Computes lhs + rhs, but bounds the result to the maximum number representable by the given type, if the addition would
 * overflow.
 *
 * Effectively performs `min((lhs + rhs), ::cuda::std::numeric_limits<OffsetT>::max())`, but is robust against the case
 * where `(lhs + rhs)` would overflow.
 *
 * @tparam OffsetT Integral type of at least 32 bits (both enforced by static_assert)
 * @param lhs Left operand. Must be non-negative: for a negative lhs of a signed type, `max() - lhs` would itself
 *            overflow.
 * @param rhs Right operand.
 * @return `lhs + rhs` if that sum does not exceed `numeric_limits<OffsetT>::max()`, otherwise
 *         `numeric_limits<OffsetT>::max()`.
 */
template <typename OffsetT>
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE OffsetT safe_add_bound_to_max(OffsetT lhs, OffsetT rhs)
{
  static_assert(::cuda::std::is_integral<OffsetT>::value, "OffsetT must be an integral type");
  static_assert(sizeof(OffsetT) >= 4, "OffsetT must be at least 32 bits in size");

  // Clamp rhs to the headroom that is left above lhs, so the addition below can never exceed max().
  auto const capped_operand_rhs = (cub::min)(rhs, ::cuda::std::numeric_limits<OffsetT>::max() - lhs);
  return lhs + capped_operand_rhs;
}

}
} // namespace detail

/**
* Divide n by d, round up if any remainder, and return the result.
Expand Down
Loading
Loading