Skip to content

Commit

Permalink
[BACKPORT]: Add missing overloads for thrust::pow (#1223)
Browse files Browse the repository at this point in the history
* Add missing overloads for thrust::pow

Also add proper type checks for all of those overloads so that we can ensure that we are

* Properly constraint the pow overloads and stop pulling in `cuda::std::pow`
  • Loading branch information
miscco committed Dec 16, 2023
1 parent b28d445 commit c4eda1a
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
34 changes: 32 additions & 2 deletions thrust/testing/complex.cu
Original file line number Diff line number Diff line change
Expand Up @@ -449,17 +449,18 @@ struct TestComplexBasicArithmetic
// Test the basic arithmetic functions against std

ASSERT_ALMOST_EQUAL(thrust::abs(a), std::abs(b));

ASSERT_ALMOST_EQUAL(thrust::arg(a), std::arg(b));

ASSERT_ALMOST_EQUAL(thrust::norm(a), std::norm(b));

ASSERT_EQUAL(thrust::conj(a), std::conj(b));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::conj(a))>::value, "");

ASSERT_ALMOST_EQUAL(thrust::polar(data[0], data[1]), std::polar(data[0], data[1]));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::polar(data[0], data[1]))>::value, "");

// random_samples does not seem to produce infinities so proj(z) == z
ASSERT_EQUAL(thrust::proj(a), a);
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::proj(a))>::value, "");
}
};
SimpleUnitTest<TestComplexBasicArithmetic, FloatingPointTypes> TestComplexBasicArithmeticInstance;
Expand Down Expand Up @@ -556,6 +557,9 @@ struct TestComplexExponentialFunctions
ASSERT_ALMOST_EQUAL(thrust::exp(a), std::exp(b));
ASSERT_ALMOST_EQUAL(thrust::log(a), std::log(b));
ASSERT_ALMOST_EQUAL(thrust::log10(a), std::log10(b));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::exp(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::log(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::log10(a))>::value, "");
}
};
SimpleUnitTest<TestComplexExponentialFunctions, FloatingPointTypes>
Expand All @@ -575,16 +579,24 @@ struct TestComplexPowerFunctions
const std::complex<T> b_std(b_thrust);

ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust), std::pow(a_std, b_std));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::pow(a_thrust, b_thrust))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust.real()), std::pow(a_std, b_std.real()));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::pow(a_thrust, b_thrust.real()))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust.real(), b_thrust), std::pow(a_std.real(), b_std));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::pow(a_thrust.real(), b_thrust))>::value, "");

ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, 4), std::pow(a_std, 4));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::pow(a_thrust, 4))>::value, "");

ASSERT_ALMOST_EQUAL(thrust::sqrt(a_thrust), std::sqrt(a_std));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::sqrt(a_thrust))>::value, "");
}

// Test power functions with promoted types.
{
using T0 = T;
using T1 = other_floating_point_type_t<T0>;
using promoted = typename thrust::detail::promoted_numerical_type<T0, T1>::type;

thrust::host_vector<T0> data = unittest::random_samples<T0>(4);

Expand All @@ -594,11 +606,17 @@ struct TestComplexPowerFunctions
const std::complex<T0> b_std(data[2], data[3]);

ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust), std::pow(a_std, b_std));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(a_thrust, b_thrust))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(b_thrust, a_thrust), std::pow(b_std, a_std));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(b_thrust, a_thrust))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust.real()), std::pow(a_std, b_std.real()));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(a_thrust, b_thrust.real()))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(b_thrust, a_thrust.real()), std::pow(b_std, a_std.real()));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(b_thrust, a_thrust.real()))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust.real(), b_thrust), std::pow(a_std.real(), b_std));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(a_thrust.real(), b_thrust))>::value, "");
ASSERT_ALMOST_EQUAL(thrust::pow(b_thrust.real(), a_thrust), std::pow(b_std.real(), a_std));
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(b_thrust.real(), a_thrust))>::value, "");
}
}
};
Expand All @@ -617,20 +635,32 @@ struct TestComplexTrigonometricFunctions
ASSERT_ALMOST_EQUAL(thrust::cos(a), std::cos(c));
ASSERT_ALMOST_EQUAL(thrust::sin(a), std::sin(c));
ASSERT_ALMOST_EQUAL(thrust::tan(a), std::tan(c));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::cos(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::sin(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::tan(a))>::value, "");

ASSERT_ALMOST_EQUAL(thrust::cosh(a), std::cosh(c));
ASSERT_ALMOST_EQUAL(thrust::sinh(a), std::sinh(c));
ASSERT_ALMOST_EQUAL(thrust::tanh(a), std::tanh(c));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::cosh(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::sinh(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::tanh(a))>::value, "");

#if THRUST_CPP_DIALECT >= 2011

ASSERT_ALMOST_EQUAL(thrust::acos(a), std::acos(c));
ASSERT_ALMOST_EQUAL(thrust::asin(a), std::asin(c));
ASSERT_ALMOST_EQUAL(thrust::atan(a), std::atan(c));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::acos(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::asin(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::atan(a))>::value, "");

ASSERT_ALMOST_EQUAL(thrust::acosh(a), std::acosh(c));
ASSERT_ALMOST_EQUAL(thrust::asinh(a), std::asinh(c));
ASSERT_ALMOST_EQUAL(thrust::atanh(a), std::atanh(c));
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::acosh(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::asinh(a))>::value, "");
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::atanh(a))>::value, "");

#endif
}
Expand Down
21 changes: 16 additions & 5 deletions thrust/thrust/complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,8 @@ using ::cuda::std::proj;
using ::cuda::std::exp;
using ::cuda::std::log;
using ::cuda::std::log10;
using ::cuda::std::pow;
// pow always returns a complex.
// using ::cuda::std::pow;
using ::cuda::std::sqrt;

using ::cuda::std::acos;
Expand Down Expand Up @@ -516,15 +517,25 @@ template<class T>
__host__ __device__ complex<T> log10(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::log10(c));
}
template<class T>
__host__ __device__ complex<T> pow(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::pow(c));
template<class T0, class T1>
__host__ __device__ complex<typename detail::promoted_numerical_type<T0, T1>::type>
pow(const complex<T0>& x, const complex<T1>& y) {
return static_cast<complex<typename detail::promoted_numerical_type<T0, T1>::type>>(::cuda::std::pow(x, y));
}
template<class T0, class T1, ::cuda::std::__enable_if_t<::cuda::std::is_arithmetic<T1>::value, int> = 0>
__host__ __device__ complex<typename detail::promoted_numerical_type<T0, T1>::type>
pow(const complex<T0>& x, const T1& y) {
return static_cast<complex<typename detail::promoted_numerical_type<T0, T1>::type>>(::cuda::std::pow(x, y));
}
template<class T0, class T1, ::cuda::std::__enable_if_t<::cuda::std::is_arithmetic<T0>::value, int> = 0>
__host__ __device__ complex<typename detail::promoted_numerical_type<T0, T1>::type>
pow(const T0& x, const complex<T1>& y) {
return static_cast<complex<typename detail::promoted_numerical_type<T0, T1>::type>>(::cuda::std::pow(x, y));
}
template<class T>
__host__ __device__ complex<T> sqrt(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::sqrt(c));
}

template<class T>
__host__ __device__ complex<T> acos(const complex<T>& c) {
return static_cast<complex<T>>(::cuda::std::acos(c));
Expand Down

0 comments on commit c4eda1a

Please sign in to comment.