Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for stream operators for complex #1538

Merged
merged 1 commit into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ _CCCL_DIAG_POP
# include "../__type_traits/is_same.h"
# include "../cmath"

# if !defined(_CCCL_COMPILER_NVRTC)
# include <sstream> // for std::basic_ostringstream
# endif // !_CCCL_COMPILER_NVRTC

_LIBCUDACXX_BEGIN_NAMESPACE_STD

template <>
Expand Down Expand Up @@ -72,8 +76,8 @@ class _LIBCUDACXX_TEMPLATE_VIS _LIBCUDACXX_COMPLEX_ALIGNAS(alignof(__nv_bfloat16
: __repr(__re, __im)
{}

_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_MSVC(4244) // narrowing conversions
_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_MSVC(4244) // narrowing conversions

_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<float>& __c)
: __repr(__c.real(), __c.imag())
Expand All @@ -82,7 +86,7 @@ _CCCL_DIAG_SUPPRESS_MSVC(4244) // narrowing conversions
: __repr(__c.real(), __c.imag())
{}

_CCCL_DIAG_POP
_CCCL_DIAG_POP

# if !defined(_CCCL_COMPILER_NVRTC)
template <class _Up>
Expand Down Expand Up @@ -228,6 +232,25 @@ inline _LIBCUDACXX_INLINE_VISIBILITY complex<__nv_bfloat16> acos(const complex<_
return complex<__nv_bfloat16>{_CUDA_VSTD::acos(complex<float>{__x.real(), __x.imag()})};
}

# if !defined(_CCCL_COMPILER_NVRTC)
template <class _CharT, class _Traits>
::std::basic_istream<_CharT, _Traits>&
operator>>(::std::basic_istream<_CharT, _Traits>& __is, complex<__nv_bfloat16>& __x)
{
::std::complex<float> __temp;
__is >> __temp;
__x = __temp;
return __is;
}

template <class _CharT, class _Traits>
::std::basic_ostream<_CharT, _Traits>&
operator<<(::std::basic_ostream<_CharT, _Traits>& __os, const complex<__nv_bfloat16>& __x)
{
return __os << complex<float>{__x.real(), __x.imag()};
}
# endif // !_CCCL_COMPILER_NVRTC

_LIBCUDACXX_END_NAMESPACE_STD

#endif /// _LIBCUDACXX_HAS_NVBF16
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
# include "../__type_traits/is_same.h"
# include "../cmath"

# if !defined(_CCCL_COMPILER_NVRTC)
# include <sstream> // for std::basic_ostringstream
# endif // !_CCCL_COMPILER_NVRTC

_LIBCUDACXX_BEGIN_NAMESPACE_STD

template <>
Expand Down Expand Up @@ -69,8 +73,8 @@ class _LIBCUDACXX_TEMPLATE_VIS _LIBCUDACXX_COMPLEX_ALIGNAS(alignof(__half2)) com
: __repr(__re, __im)
{}

_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_MSVC(4244) // narrowing conversions
_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_MSVC(4244) // narrowing conversions

_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<float>& __c)
: __repr(__c.real(), __c.imag())
Expand All @@ -79,7 +83,7 @@ _CCCL_DIAG_SUPPRESS_MSVC(4244) // narrowing conversions
: __repr(__c.real(), __c.imag())
{}

_CCCL_DIAG_POP
_CCCL_DIAG_POP

# if !defined(_CCCL_COMPILER_NVRTC)
template <class _Up>
Expand Down Expand Up @@ -225,6 +229,24 @@ inline _LIBCUDACXX_INLINE_VISIBILITY complex<__half> acos(const complex<__half>&
return complex<__half>{_CUDA_VSTD::acos(complex<float>{__x.real(), __x.imag()})};
}

# if !defined(_LIBCUDACXX_HAS_NO_LOCALIZATION) && !defined(_CCCL_COMPILER_NVRTC)
template <class _CharT, class _Traits>
::std::basic_istream<_CharT, _Traits>& operator>>(::std::basic_istream<_CharT, _Traits>& __is, complex<__half>& __x)
{
::std::complex<float> __temp;
__is >> __temp;
__x = __temp;
return __is;
}

template <class _CharT, class _Traits>
::std::basic_ostream<_CharT, _Traits>&
operator<<(::std::basic_ostream<_CharT, _Traits>& __os, const complex<__half>& __x)
{
return __os << complex<float>{__x.real(), __x.imag()};
}
# endif // !_LIBCUDACXX_HAS_NO_LOCALIZATION && !_CCCL_COMPILER_NVRTC

_LIBCUDACXX_END_NAMESPACE_STD

#endif /// _LIBCUDACXX_HAS_NVFP16
Expand Down
21 changes: 20 additions & 1 deletion libcudacxx/include/cuda/std/detail/libcxx/include/complex
Original file line number Diff line number Diff line change
Expand Up @@ -1829,7 +1829,26 @@ tan(const complex<_Tp>& __x)
return complex<_Tp>(__z.imag(), -__z.real());
}

#ifndef __cuda_std__
#ifdef __cuda_std__
# if !defined(_CCCL_COMPILER_NVRTC)
template <class _Tp, class _CharT, class _Traits>
::std::basic_istream<_CharT, _Traits>&
operator>>(::std::basic_istream<_CharT, _Traits>& __is, complex<_Tp>& __x)
{
::std::complex<_Tp> __temp;
__is >> __temp;
__x = __temp;
return __is;
}

template <class _Tp, class _CharT, class _Traits>
::std::basic_ostream<_CharT, _Traits>&
operator<<(::std::basic_ostream<_CharT, _Traits>& __os, const complex<_Tp>& __x)
{
return __os << static_cast<::std::complex<_Tp>>(__x);
}
# endif // !_CCCL_COMPILER_NVRTC
#else // ^^^ __cuda_std__ ^^^ / vvv !__cuda_std__ vvv
#if !defined(_LIBCUDACXX_HAS_NO_LOCALIZATION)
template<class _Tp, class _CharT, class _Traits>
basic_istream<_CharT, _Traits>&
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
//===----------------------------------------------------------------------===//
//
// Part of the libcu++ Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#include <cuda/std/complex>
#include <cuda/std/cassert>

#include "test_macros.h"

template <class T, class U>
__host__ __device__ void test_assignment() {
cuda::std::complex<T> from_only_real{static_cast<T>(-1.0),
static_cast<T>(1.0)};
cuda::std::complex<T> from_only_imag{static_cast<T>(-1.0),
static_cast<T>(1.0)};
cuda::std::complex<T> from_real_imag{static_cast<T>(-1.0),
static_cast<T>(1.0)};

const cuda::std::complex<U> only_real{static_cast<U>(42.0), static_cast<U>(0.0)};
const cuda::std::complex<U> only_imag{static_cast<U>(0.0), static_cast<U>(42.0)};
const cuda::std::complex<U> real_imag{static_cast<U>(42.0),
static_cast<U>(112.0)};

from_only_real = only_real;
from_only_imag = only_imag;
from_real_imag = real_imag;

assert(from_only_real.real() == static_cast<T>(42.0));
assert(from_only_real.imag() == static_cast<T>(0.0));
assert(from_only_imag.real() == static_cast<T>(0.0));
assert(from_only_imag.imag() == static_cast<T>(42.0));
assert(from_real_imag.real() == static_cast<T>(42.0));
assert(from_real_imag.imag() == static_cast<T>(112.0));
}

__host__ __device__ void test() {
#ifdef _LIBCUDACXX_HAS_NVFP16
test_assignment<__half, float>();
test_assignment<__half, double>();
test_assignment<float, __half>();
test_assignment<double, __half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#ifdef _LIBCUDACXX_HAS_NVBF16
test_assignment<__nv_bfloat16, float>();
test_assignment<__nv_bfloat16, double>();
test_assignment<float, __nv_bfloat16>();
test_assignment<double, __nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
}

int main(int arg, char** argv) {
test();
return 0;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//===----------------------------------------------------------------------===//
//
// Part of the libcu++ Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#include <cuda/std/complex>
#include <cuda/std/cassert>

#include "test_macros.h"

template <class T, class U>
__host__ __device__ void test_construction() {
const cuda::std::complex<U> only_real{static_cast<U>(42.0), static_cast<U>(0.0)};
const cuda::std::complex<U> only_imag{static_cast<U>(0.0), static_cast<U>(42.0)};
const cuda::std::complex<U> real_imag{static_cast<U>(42.0),
static_cast<U>(112.0)};

const cuda::std::complex<T> from_only_real{only_real};
const cuda::std::complex<T> from_only_imag{only_imag};
const cuda::std::complex<T> from_real_imag{real_imag};

assert(from_only_real.real() == static_cast<T>(42.0));
assert(from_only_real.imag() == static_cast<T>(0.0));
assert(from_only_imag.real() == static_cast<T>(0.0));
assert(from_only_imag.imag() == static_cast<T>(42.0));
assert(from_real_imag.real() == static_cast<T>(42.0));
assert(from_real_imag.imag() == static_cast<T>(112.0));
}

__host__ __device__ void test() {
#ifdef _LIBCUDACXX_HAS_NVFP16
test_construction<__half, float>();
test_construction<__half, double>();
test_construction<float, __half>();
test_construction<double, __half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#ifdef _LIBCUDACXX_HAS_NVBF16
test_construction<__nv_bfloat16, float>();
test_construction<__nv_bfloat16, double>();
test_construction<float, __nv_bfloat16>();
test_construction<double, __nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
}

int main(int arg, char** argv) {
test();
return 0;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// UNSUPPORTED: no-localization
// UNSUPPORTED: nvrtc

// <complex>

// template<class T, class charT, class traits>
// basic_istream<charT, traits>&
// operator>>(basic_istream<charT, traits>& is, complex<T>& x);

#include <cuda/std/complex>
#include <cuda/std/cassert>

#include <sstream>

#include "test_macros.h"

template <class T>
void test() {
{
std::istringstream is("5");
cuda::std::complex<T> c;
is >> c;
assert(c == cuda::std::complex<T>(5, 0));
assert(is.eof());
}
{
std::istringstream is(" 5 ");
cuda::std::complex<T> c;
is >> c;
assert(c == cuda::std::complex<T>(5, 0));
assert(is.good());
}
{
std::istringstream is(" 5, ");
cuda::std::complex<T> c;
is >> c;
assert(c == cuda::std::complex<T>(5, 0));
assert(is.good());
}
{
std::istringstream is(" , 5, ");
cuda::std::complex<T> c;
is >> c;
assert(c == cuda::std::complex<T>(0, 0));
assert(is.fail());
}
{
std::istringstream is("5.5 ");
cuda::std::complex<T> c;
is >> c;
assert(c == cuda::std::complex<T>(5.5, 0));
assert(is.good());
}
{
std::istringstream is(" ( 5.5 ) ");
cuda::std::complex<T> c;
is >> c;
assert(c == cuda::std::complex<T>(5.5, 0));
assert(is.good());
}
{
std::istringstream is(" 5.5)");
cuda::std::complex<T> c;
is >> c;
assert(c == cuda::std::complex<T>(5.5, 0));
assert(is.good());
}
{
std::istringstream is("(5.5 ");
cuda::std::complex<T> c;
is >> c;
assert(c == cuda::std::complex<T>(0, 0));
assert(is.fail());
}
{
std::istringstream is("(5.5,");
cuda::std::complex<T> c;
is >> c;
assert(c == cuda::std::complex<T>(0, 0));
assert(is.fail());
}
{
std::istringstream is("( -5.5 , -6.5 )");
cuda::std::complex<T> c;
is >> c;
assert(c == cuda::std::complex<T>(-5.5, -6.5));
assert(!is.eof());
}
{
std::istringstream is("(-5.5,-6.5)");
cuda::std::complex<T> c;
is >> c;
assert(c == cuda::std::complex<T>(-5.5, -6.5));
assert(!is.eof());
}
}

void test() {
test<float>();
test<double>();
#ifdef _LIBCUDACXX_HAS_NVFP16
test<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#ifdef _LIBCUDACXX_HAS_NVBF16
test<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
}

int main(int, char**) {
NV_IF_TARGET(NV_IS_HOST, test();)
return 0;
}
Loading
Loading