diff --git a/testing/sequence.cu b/testing/sequence.cu index 57285a404..9f2bff6ed 100644 --- a/testing/sequence.cu +++ b/testing/sequence.cu @@ -124,3 +124,36 @@ void TestSequenceComplex() thrust::sequence(m.begin(), m.end()); } DECLARE_UNITTEST(TestSequenceComplex); + +// A class that doesnt accept conversion from size_t but can be multiplied by a scalar +struct Vector +{ + Vector() = default; + // Explicitly disable construction from size_t + Vector(std::size_t) = delete; + __host__ __device__ Vector(int x_, int y_) : x{x_}, y{y_} {} + Vector(const Vector&) = default; + Vector &operator=(const Vector&) = default; + + int x, y; +}; + +// Vector-Vector addition +__host__ __device__ Vector operator+(const Vector a, const Vector b) { return Vector{a.x + b.x, a.y + b.y}; } +// Vector-Scalar Multiplication +__host__ __device__ Vector operator*(const int a, const Vector b) { return Vector{a * b.x, a * b.y}; } +__host__ __device__ Vector operator*(const Vector b, const int a) { return Vector{a * b.x, a * b.y}; } + +void TestSequenceNoSizeTConversion() +{ + thrust::device_vector m(64); + thrust::sequence(m.begin(), m.end(), ::Vector{0, 0}, ::Vector{1, 2}); + + for (std::size_t i = 0; i < m.size(); ++i) + { + const ::Vector v = m[i]; + ASSERT_EQUAL(static_cast(v.x), i); + ASSERT_EQUAL(static_cast(v.y), 2 * i); + } +} +DECLARE_UNITTEST(TestSequenceNoSizeTConversion); diff --git a/thrust/system/detail/generic/sequence.inl b/thrust/system/detail/generic/sequence.inl index 711fb5c7e..0fe372931 100644 --- a/thrust/system/detail/generic/sequence.inl +++ b/thrust/system/detail/generic/sequence.inl @@ -52,12 +52,25 @@ __host__ __device__ namespace detail { -template +template struct compute_sequence_value { T init; T step; + __thrust_exec_check_disable__ + __host__ __device__ + T operator()(std::size_t i) const + { + return init + step * i; + } +}; +template +struct compute_sequence_value::value>::type> +{ + T init; + T step; + __thrust_exec_check_disable__ __host__ __device__ T operator()(std::size_t i) const