Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Fix regression in transform_inclusive_scan.
Browse files Browse the repository at this point in the history
We were deducing a reference type when we needed a value type for the
transform iterator instantition.

Fixes #1332 and adds a regression test.
  • Loading branch information
alliepiper committed Nov 9, 2020
1 parent f8eadce commit 8d50a02
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 6 deletions.
52 changes: 52 additions & 0 deletions testing/transform_scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,55 @@ struct TestTransformScanToDiscardIterator
};
VariableUnitTest<TestTransformScanToDiscardIterator, IntegralTypes> TestTransformScanToDiscardIteratorInstance;

// Regression test for https://github.com/NVIDIA/thrust/issues/1332
// The issue was the internal transform_input_iterator_t created by the
// transform_inclusive_scan implementation was instantiated using a reference
// type for the value_type.
template <typename T>
void TestValueCategoryDeduction()
{
thrust::device_vector<T> vec;

T a_h[10] = {5, 0, 5, 8, 6, 7, 5, 3, 0, 9};
vec.assign((T*)a_h, a_h + 10);


thrust::transform_inclusive_scan(thrust::device,
vec.cbegin(),
vec.cend(),
vec.begin(),
thrust::identity<>{},
thrust::maximum<>{});

ASSERT_EQUAL(T{5}, vec[0]);
ASSERT_EQUAL(T{5}, vec[1]);
ASSERT_EQUAL(T{5}, vec[2]);
ASSERT_EQUAL(T{8}, vec[3]);
ASSERT_EQUAL(T{8}, vec[4]);
ASSERT_EQUAL(T{8}, vec[5]);
ASSERT_EQUAL(T{8}, vec[6]);
ASSERT_EQUAL(T{8}, vec[7]);
ASSERT_EQUAL(T{8}, vec[8]);
ASSERT_EQUAL(T{9}, vec[9]);

vec.assign((T*)a_h, a_h + 10);
thrust::transform_exclusive_scan(thrust::device,
vec.cbegin(),
vec.cend(),
vec.begin(),
thrust::identity<>{},
T{},
thrust::maximum<>{});

ASSERT_EQUAL(T{0}, vec[0]);
ASSERT_EQUAL(T{5}, vec[1]);
ASSERT_EQUAL(T{5}, vec[2]);
ASSERT_EQUAL(T{5}, vec[3]);
ASSERT_EQUAL(T{8}, vec[4]);
ASSERT_EQUAL(T{8}, vec[5]);
ASSERT_EQUAL(T{8}, vec[6]);
ASSERT_EQUAL(T{8}, vec[7]);
ASSERT_EQUAL(T{8}, vec[8]);
ASSERT_EQUAL(T{8}, vec[9]);
}
DECLARE_GENERIC_UNITTEST(TestValueCategoryDeduction);
8 changes: 5 additions & 3 deletions thrust/system/cuda/detail/transform_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,19 @@ transform_inclusive_scan(execution_policy<Derived> &policy,
TransformOp transform_op,
ScanOp scan_op)
{
// Use the input iterator's value type per https://wg21.link/P0571
// Use the transformed input iterator's value type per https://wg21.link/P0571
using input_type = typename thrust::iterator_value<InputIt>::type;
#if THRUST_CPP_DIALECT < 2017
using result_type = typename std::result_of<TransformOp(input_type)>::type;
#else
using result_type = std::invoke_result_t<TransformOp, input_type>;
#endif

using value_type = typename std::remove_reference<result_type>::type;

typedef typename iterator_traits<InputIt>::difference_type size_type;
size_type num_items = static_cast<size_type>(thrust::distance(first, last));
typedef transform_input_iterator_t<result_type,
typedef transform_input_iterator_t<value_type,
InputIt,
TransformOp>
transformed_iterator_t;
Expand Down Expand Up @@ -88,7 +90,7 @@ transform_exclusive_scan(execution_policy<Derived> &policy,
ScanOp scan_op)
{
// Use the initial value type per https://wg21.link/P0571
using result_type = InitialValueType;
using result_type = typename std::remove_reference<InitialValueType>::type;

typedef typename iterator_traits<InputIt>::difference_type size_type;
size_type num_items = static_cast<size_type>(thrust::distance(first, last));
Expand Down
7 changes: 4 additions & 3 deletions thrust/system/detail/generic/transform_scan.inl
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ __host__ __device__
// Use the input iterator's value type per https://wg21.link/P0571
using InputType = typename thrust::iterator_value<InputIterator>::type;
#if THRUST_CPP_DIALECT < 2017
using ValueType = typename std::result_of<UnaryFunction(InputType)>::type;
using ResultType = typename std::result_of<UnaryFunction(InputType)>::type;
#else
using ValueType = std::invoke_result_t<UnaryFunction, InputType>;
using ResultType = std::invoke_result_t<UnaryFunction, InputType>;
#endif
using ValueType = typename std::remove_reference<ResultType>::type;

thrust::transform_iterator<UnaryFunction, InputIterator, ValueType> _first(first, unary_op);
thrust::transform_iterator<UnaryFunction, InputIterator, ValueType> _last(last, unary_op);
Expand All @@ -79,7 +80,7 @@ __host__ __device__
AssociativeOperator binary_op)
{
// Use the initial value type per https://wg21.link/P0571
using ValueType = InitialValueType;
using ValueType = typename std::remove_reference<InitialValueType>::type;

thrust::transform_iterator<UnaryFunction, InputIterator, ValueType> _first(first, unary_op);
thrust::transform_iterator<UnaryFunction, InputIterator, ValueType> _last(last, unary_op);
Expand Down

0 comments on commit 8d50a02

Please sign in to comment.