Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup the semaphore headers #2441

Merged
merged 5 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX___CUDA_SEMAPHORE_H
#define _LIBCUDACXX___CUDA_SEMAPHORE_H
#ifndef _CUDA___SEMAPHORE_COUNTING_SEMAPHORE_H
#define _CUDA___SEMAPHORE_COUNTING_SEMAPHORE_H

#include <cuda/std/detail/__config>

Expand All @@ -21,16 +21,21 @@
# pragma system_header
#endif // no system header

#include <cuda/std/__semaphore/atomic_semaphore.h>
#include <cuda/std/cstdint>

_CCCL_PUSH_MACROS

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA

template <thread_scope _Sco, ptrdiff_t __least_max_value = INT_MAX>
class counting_semaphore : public _CUDA_VSTD::__semaphore_base<__least_max_value, _Sco>
class counting_semaphore : public _CUDA_VSTD::__atomic_semaphore<_Sco, __least_max_value>
{
static_assert(__least_max_value <= _CUDA_VSTD::__semaphore_base<__least_max_value, _Sco>::max(), "");
static_assert(__least_max_value <= _CUDA_VSTD::__atomic_semaphore<_Sco, __least_max_value>::max(), "");

public:
_LIBCUDACXX_HIDE_FROM_ABI constexpr counting_semaphore(ptrdiff_t __count = 0)
: _CUDA_VSTD::__semaphore_base<__least_max_value, _Sco>(__count)
: _CUDA_VSTD::__atomic_semaphore<_Sco, __least_max_value>(__count)
{}
_CCCL_HIDE_FROM_ABI ~counting_semaphore() = default;

Expand All @@ -43,4 +48,6 @@ using binary_semaphore = counting_semaphore<_Sco, 1>;

_LIBCUDACXX_END_NAMESPACE_CUDA

#endif // _LIBCUDACXX___CUDA_SEMAPHORE_H
_CCCL_POP_MACROS

#endif // _CUDA___SEMAPHORE_COUNTING_SEMAPHORE_H
15 changes: 15 additions & 0 deletions libcudacxx/include/cuda/semaphore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,21 @@
#ifndef _CUDA_SEMAPHORE
#define _CUDA_SEMAPHORE

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
# error "CUDA synchronization primitives are only supported for sm_70 and up."
#endif

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/__semaphore/counting_semaphore.h>
#include <cuda/std/semaphore>

#endif // _CUDA_SEMAPHORE
234 changes: 234 additions & 0 deletions libcudacxx/include/cuda/std/__semaphore/atomic_semaphore.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX___SEMAPHORE_ATOMIC_SEMAPHORE_H
#define _LIBCUDACXX___SEMAPHORE_ATOMIC_SEMAPHORE_H

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/std/atomic>
#include <cuda/std/chrono>
#include <cuda/std/cstdint>

_CCCL_PUSH_MACROS

_LIBCUDACXX_BEGIN_NAMESPACE_STD

template <thread_scope _Sco, ptrdiff_t __least_max_value>
class __atomic_semaphore
{
__atomic_impl<ptrdiff_t, _Sco> __count;

_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI bool __fetch_sub_if_slow(ptrdiff_t __old)
{
while (__old != 0)
{
if (__count.compare_exchange_weak(__old, __old - 1, memory_order_acquire, memory_order_relaxed))
{
return true;
}
}
return false;
}

_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI bool __fetch_sub_if()
{
ptrdiff_t __old = __count.load(memory_order_acquire);
if (__old == 0)
{
return false;
}
if (__count.compare_exchange_weak(__old, __old - 1, memory_order_acquire, memory_order_relaxed))
{
return true;
}
return __fetch_sub_if_slow(__old); // fail only if not __available
}

_LIBCUDACXX_HIDE_FROM_ABI void __wait_slow()
{
while (1)
{
ptrdiff_t const __old = __count.load(memory_order_acquire);
if (__old != 0)
{
break;
}
__count.wait(__old, memory_order_relaxed);
}
}

_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI bool __acquire_slow_timed(chrono::nanoseconds const& __rel_time)
{
return __libcpp_thread_poll_with_backoff(
[this]() {
ptrdiff_t const __old = __count.load(memory_order_acquire);
return __old != 0 && __fetch_sub_if_slow(__old);
},
__rel_time);
}

public:
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI static constexpr ptrdiff_t max() noexcept
{
return numeric_limits<ptrdiff_t>::max();
}

_LIBCUDACXX_HIDE_FROM_ABI constexpr __atomic_semaphore(ptrdiff_t __count) noexcept
: __count(__count)
{}

_CCCL_HIDE_FROM_ABI ~__atomic_semaphore() = default;

__atomic_semaphore(__atomic_semaphore const&) = delete;
__atomic_semaphore& operator=(__atomic_semaphore const&) = delete;

_LIBCUDACXX_HIDE_FROM_ABI void release(ptrdiff_t __update = 1)
{
__count.fetch_add(__update, memory_order_release);
if (__update > 1)
{
__count.notify_all();
}
else
{
__count.notify_one();
}
}

_LIBCUDACXX_HIDE_FROM_ABI void acquire()
{
while (!try_acquire())
{
__wait_slow();
}
}

_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI bool try_acquire() noexcept
{
return __fetch_sub_if();
}

template <class Clock, class Duration>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI bool try_acquire_until(chrono::time_point<Clock, Duration> const& __abs_time)
{
if (try_acquire())
{
return true;
}
else
{
return __acquire_slow_timed(__abs_time - Clock::now());
}
}

template <class Rep, class Period>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI bool try_acquire_for(chrono::duration<Rep, Period> const& __rel_time)
{
if (try_acquire())
{
return true;
}
else
{
return __acquire_slow_timed(__rel_time);
}
}
};

template <thread_scope _Sco>
class __atomic_semaphore<_Sco, 1>
{
__atomic_impl<int, _Sco> __available;

_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI bool __acquire_slow_timed(chrono::nanoseconds const& __rel_time)
{
return __libcpp_thread_poll_with_backoff(
[this]() {
return try_acquire();
},
__rel_time);
}

public:
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI static constexpr ptrdiff_t max() noexcept
{
return 1;
}

_LIBCUDACXX_HIDE_FROM_ABI constexpr __atomic_semaphore(ptrdiff_t __available)
: __available(__available)
{}

_CCCL_HIDE_FROM_ABI ~__atomic_semaphore() = default;

__atomic_semaphore(__atomic_semaphore const&) = delete;
__atomic_semaphore& operator=(__atomic_semaphore const&) = delete;

_LIBCUDACXX_HIDE_FROM_ABI void release(ptrdiff_t __update = 1)
{
_CCCL_ASSERT(__update == 1, "");
__available.store(1, memory_order_release);
__available.notify_one();
(void) __update;
}

_LIBCUDACXX_HIDE_FROM_ABI void acquire()
{
while (!try_acquire())
{
__available.wait(0, memory_order_relaxed);
}
}

_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI bool try_acquire() noexcept
{
return 1 == __available.exchange(0, memory_order_acquire);
}

template <class Clock, class Duration>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI bool try_acquire_until(chrono::time_point<Clock, Duration> const& __abs_time)
{
if (try_acquire())
{
return true;
}
else
{
return __acquire_slow_timed(__abs_time - Clock::now());
}
}

template <class Rep, class Period>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI bool try_acquire_for(chrono::duration<Rep, Period> const& __rel_time)
{
if (try_acquire())
{
return true;
}
else
{
return __acquire_slow_timed(__rel_time);
}
}
};

_LIBCUDACXX_END_NAMESPACE_STD

_CCCL_POP_MACROS

#endif // _LIBCUDACXX___SEMAPHORE_ATOMIC_SEMAPHORE_H
51 changes: 51 additions & 0 deletions libcudacxx/include/cuda/std/__semaphore/counting_semaphore.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX___SEMAPHORE_COUNTING_SEMAPHORE_H
#define _LIBCUDACXX___SEMAPHORE_COUNTING_SEMAPHORE_H

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/std/__semaphore/atomic_semaphore.h>
#include <cuda/std/cstdint>

_CCCL_PUSH_MACROS

_LIBCUDACXX_BEGIN_NAMESPACE_STD

template <ptrdiff_t __least_max_value = INT_MAX>
class counting_semaphore : public __atomic_semaphore<thread_scope_system, __least_max_value>
{
static_assert(__least_max_value <= __atomic_semaphore<thread_scope_system, __least_max_value>::max(), "");

public:
_LIBCUDACXX_HIDE_FROM_ABI constexpr counting_semaphore(ptrdiff_t __count = 0)
: __atomic_semaphore<thread_scope_system, __least_max_value>(__count)
{}
_CCCL_HIDE_FROM_ABI ~counting_semaphore() = default;

counting_semaphore(const counting_semaphore&) = delete;
counting_semaphore& operator=(const counting_semaphore&) = delete;
};

using binary_semaphore = counting_semaphore<1>;

_LIBCUDACXX_END_NAMESPACE_STD

_CCCL_POP_MACROS

#endif // _LIBCUDACXX___SEMAPHORE_COUNTING_SEMAPHORE_H
Loading
Loading