// -*- C++ -*-
//===--------------------------- barrier ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX_BARRIER
#define _LIBCUDACXX_BARRIER

/*
    barrier synopsis

namespace std
{

  template<class CompletionFunction = see below>
  class barrier
  {
  public:
    using arrival_token = see below;

    constexpr explicit barrier(ptrdiff_t phase_count,
                               CompletionFunction f = CompletionFunction());
    ~barrier();

    barrier(const barrier&) = delete;
    barrier& operator=(const barrier&) = delete;

    [[nodiscard]] arrival_token arrive(ptrdiff_t update = 1);
    void wait(arrival_token&& arrival) const;

    void arrive_and_wait();
    void arrive_and_drop();

  private:
    CompletionFunction __completion; // exposition only
  };

}

*/

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cuda/std/__new_>
#include <cuda/std/atomic>
#include <cuda/std/chrono>
#include <cuda/std/cstddef>
#include <cuda/std/detail/libcxx/include/__assert> // all public C++ headers provide the assertion handler
#include <cuda/std/detail/libcxx/include/__debug>

_CCCL_PUSH_MACROS

#ifdef _LIBCUDACXX_HAS_NO_THREADS
#  error <barrier> is not supported on this single threaded system
#endif

_LIBCUDACXX_BEGIN_NAMESPACE_STD

// Completion function that does nothing when invoked. It is the default
// completion type of the barrier templates in this header.
struct __empty_completion
{
  // Takes no arguments, performs no work and cannot throw.
  _LIBCUDACXX_HIDE_FROM_ABI void operator()() noexcept
  {
  }
};

#ifndef _LIBCUDACXX_HAS_NO_TREE_BARRIER

// Tree-structured barrier.
//
// Arrivals are paired up through per-node "tickets": round r uses ticket r of
// a node. The first arrival at a ticket is done; the arrival that completes
// the ticket moves on to the next round. The arrival that is left when only a
// single participant remains completes the phase.
//
// Phases are identified by a wrapping 8-bit counter that advances by two per
// completed phase; the intermediate odd value marks a half-filled ticket.
template <class _CompletionF = __empty_completion, int _Sco = 0>
class alignas(64) __barrier_base
{
  // Number of arrivals that complete a phase. After construction it is only
  // written by the arrival that completes a phase (see arrive()).
  ptrdiff_t __expected;
  // Pending change to __expected, accumulated by arrive_and_drop() and folded
  // into __expected when a phase completes.
  __atomic_impl<ptrdiff_t, _Sco> __expected_adjustment;
  // Callable invoked by the arrival that completes a phase.
  _CompletionF __completion;

  using __phase_t = uint8_t;
  // Current phase; wait() compares it against the caller's arrival token.
  __atomic_impl<__phase_t, _Sco> __phase;

  // One tree node: 64 tickets (one per round), each holding the phase value
  // that ticket was last advanced to.
  struct alignas(64) __state_t
  {
    struct
    {
      __atomic_impl<__phase_t, _Sco> __phase = LIBCUDACXX_ATOMIC_VAR_INIT(0);
    } __tickets[64];
  };
  // One node per pair of expected arrivals, rounded up (see the constructor).
  ::std::vector<__state_t> __state;

  // Registers a single arrival for the phase identified by __old_phase.
  //
  // Returns true when this arrival is the one that completes the phase; the
  // caller (arrive()) is then responsible for running the completion function
  // and publishing the next phase. Returns false when the arrival was absorbed
  // as the first of a pair.
  _LIBCUDACXX_HIDE_FROM_ABI bool __arrive(__phase_t const __old_phase)
  {
    // A ticket moves __old_phase -> __half_step -> __full_step within a phase.
    __phase_t const __half_step = __old_phase + 1, __full_step = __old_phase + 2;
#  ifndef _LIBCUDACXX_HAS_NO_THREAD_FAVORITE_BARRIER_INDEX
    ptrdiff_t __current = __libcpp_thread_favorite_barrier_index,
#  else
    ptrdiff_t __current = 0,
#  endif
              __current_expected = __expected, __last_node = (__current_expected >> 1);
    for (size_t __round = 0;; ++__round)
    {
      _LIBCUDACXX_ASSERT(__round <= 63, "");
      if (__current_expected == 1)
      {
        // Only one participant left at this level: the phase is complete.
        return true;
      }
      for (;; ++__current)
      {
#  ifndef _LIBCUDACXX_HAS_NO_THREAD_FAVORITE_BARRIER_INDEX
        if (0 == __round)
        {
          // In the first round the starting node wraps around and is recorded
          // as the starting point for later arrivals.
          if (__current >= __current_expected)
          {
            __current = 0;
          }
          __libcpp_thread_favorite_barrier_index = __current;
        }
#  endif
        _LIBCUDACXX_ASSERT(__current <= __last_node, "");
        __phase_t expect = __old_phase;
        if (__current == __last_node && (__current_expected & 1))
        {
          // Odd participant count: the last node has nobody to pair with, so a
          // single arrival completes its ticket.
          if (__state[__current].__tickets[__round].__phase.compare_exchange_strong(
                expect, __full_step, memory_order_acq_rel))
          {
            break; // I'm 1 in 1, go to next __round
          }
          _LIBCUDACXX_ASSERT(expect == __full_step, "");
        }
        else if (__state[__current].__tickets[__round].__phase.compare_exchange_strong(
                   expect, __half_step, memory_order_acq_rel))
        {
          return false; // I'm 1 in 2, done with arrival
        }
        else if (expect == __half_step)
        {
          if (__state[__current].__tickets[__round].__phase.compare_exchange_strong(
                expect, __full_step, memory_order_acq_rel))
          {
            break; // I'm 2 in 2, go to next __round
          }
          _LIBCUDACXX_ASSERT(expect == __full_step, "");
        }
        // The ticket was already completed for this phase, so the next node is
        // tried. The assertion states that this is only expected in round 0.
        _LIBCUDACXX_ASSERT(__round == 0 && expect == __full_step, "");
      }
      // Next level: half as many participants (rounded up).
      __current_expected = (__current_expected >> 1) + (__current_expected & 1);
      // Clear bit __round of both indices. The mask is built in ptrdiff_t:
      // `1 << __round` is an int shift, which is undefined once __round reaches
      // the width of int and would also wipe the upper index bits at round 31.
      ptrdiff_t const __round_mask = ~(static_cast<ptrdiff_t>(1) << __round);
      __current &= __round_mask;
      __last_node &= __round_mask;
    }
  }

public:
  // Token returned by arrive(): the phase value observed at arrival time.
  using arrival_token = __phase_t;

  // Creates a barrier whose phases complete after __expected arrivals and that
  // invokes __completion at the end of each phase. One tree node is allocated
  // per pair of expected arrivals, rounded up. The completion function is
  // moved (not copied) from the by-value parameter into the member.
  _LIBCUDACXX_HIDE_FROM_ABI __barrier_base(ptrdiff_t __expected, _CompletionF __completion = _CompletionF())
      : __expected(__expected)
      , __expected_adjustment(0)
      , __completion(_CUDA_VSTD::move(__completion))
      , __phase(0)
      , __state((__expected + 1) >> 1)
  {
    _LIBCUDACXX_ASSERT(__expected >= 0, "");
  }

  _CCCL_HIDE_FROM_ABI ~__barrier_base() = default;

  __barrier_base(__barrier_base const&)            = delete;
  __barrier_base& operator=(__barrier_base const&) = delete;

  // Registers `update` arrivals (must be > 0) for the current phase and returns
  // the phase value read before arriving, to be handed to wait().
  _CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI arrival_token arrive(ptrdiff_t update = 1)
  {
    _LIBCUDACXX_ASSERT(update > 0, "");
    auto __old_phase = __phase.load(memory_order_relaxed);
    for (; update; --update)
    {
      if (__arrive(__old_phase))
      {
        // This arrival completed the phase: run the completion function, apply
        // any pending drops to the expected count, then publish the new phase.
        __completion();
        __expected += __expected_adjustment.load(memory_order_relaxed);
        __expected_adjustment.store(0, memory_order_relaxed);
        __phase.store(__old_phase + 2, memory_order_release);
      }
    }
    return __old_phase;
  }
  // Hands __libcpp_thread_poll_with_backoff a predicate that becomes true once
  // the barrier's phase differs from the phase recorded in the token.
  _LIBCUDACXX_HIDE_FROM_ABI void wait(arrival_token&& __old_phase) const
  {
    __libcpp_thread_poll_with_backoff([=]() -> bool {
      return __phase.load(memory_order_acquire) != __old_phase;
    });
  }
  // Arrives once and waits for the phase of that arrival to complete.
  _LIBCUDACXX_HIDE_FROM_ABI void arrive_and_wait()
  {
    wait(arrive());
  }
  // Records that one fewer arrival is expected (applied when the current phase
  // completes), then arrives once without waiting.
  _LIBCUDACXX_HIDE_FROM_ABI void arrive_and_drop()
  {
    __expected_adjustment.fetch_sub(1, memory_order_relaxed);
    (void) arrive();
  }
};

#else

// Alignment specifier placed in front of the data members of the barrier
// implementations below. For ABI versions before 3 it expands to alignas(64);
// from ABI version 3 on it expands to nothing.
#  if _LIBCUDACXX_CUDA_ABI_VERSION < 3
#    define _LIBCUDACXX_BARRIER_ALIGNMENTS alignas(64)
#  else
#    define _LIBCUDACXX_BARRIER_ALIGNMENTS
#  endif

// Nullary predicate over a barrier and an arrival token. It keeps a pointer to
// the barrier plus the token, and every invocation returns the result of the
// barrier's __try_wait for that token.
template <class _Barrier>
class __barrier_poll_tester_phase
{
  using __token_t = typename _Barrier::arrival_token;

  _Barrier const* __barrier_ptr; // barrier being polled (not owned)
  __token_t __token; // arrival token the barrier is tested against

public:
  // Stores the barrier pointer and takes over the arrival token.
  _LIBCUDACXX_HIDE_FROM_ABI
  __barrier_poll_tester_phase(_Barrier const* __this_, typename _Barrier::arrival_token&& __phase_)
      : __barrier_ptr(__this_)
      , __token(_CUDA_VSTD::move(__phase_))
  {}

  // Returns whatever the barrier's __try_wait reports for the stored token.
  _LIBCUDACXX_HIDE_FROM_ABI bool operator()() const
  {
    _Barrier const& __target = *__barrier_ptr;
    return __target.__try_wait(__token);
  }
};

// Nullary predicate over a barrier and a phase parity. It keeps a pointer to
// the barrier plus the parity flag, and every invocation returns the result of
// the barrier's __try_wait_parity for that flag.
template <class _Barrier>
class __barrier_poll_tester_parity
{
  _Barrier const* __barrier_ptr; // barrier being polled (not owned)
  bool __wanted_parity; // parity passed to __try_wait_parity on each call

public:
  // Stores the barrier pointer and the parity to test.
  _LIBCUDACXX_HIDE_FROM_ABI __barrier_poll_tester_parity(_Barrier const* __this_, bool __parity_)
      : __barrier_ptr(__this_)
      , __wanted_parity(__parity_)
  {}

  // Returns whatever the barrier's __try_wait_parity reports for the stored
  // parity.
  _LIBCUDACXX_HIDE_FROM_ABI bool operator()() const
  {
    _Barrier const& __target = *__barrier_ptr;
    return __target.__try_wait_parity(__wanted_parity);
  }
};

// Free-function entry point to a barrier's __try_wait: the arrival token is
// moved into the call and the barrier's answer is returned unchanged.
template <class _Barrier>
_LIBCUDACXX_HIDE_FROM_ABI bool __call_try_wait(const _Barrier& __b, typename _Barrier::arrival_token&& __phase)
{
  bool const __phase_left = __b.__try_wait(_CUDA_VSTD::move(__phase));
  return __phase_left;
}

// Free-function entry point to a barrier's __try_wait_parity: the parity flag
// is passed through and the barrier's answer is returned unchanged.
template <class _Barrier>
_LIBCUDACXX_HIDE_FROM_ABI bool __call_try_wait_parity(const _Barrier& __b, bool __parity)
{
  bool const __phase_left = __b.__try_wait_parity(__parity);
  return __phase_left;
}

// Counter-based barrier for an arbitrary completion function.
//
// __arrived counts down from __expected. The arrival that brings it to zero
// runs the completion function, re-arms the counter from __expected and flips
// the boolean phase, which is what waiters observe.
template <class _CompletionF, thread_scope _Sco = thread_scope_system>
class __barrier_base
{
  // __expected: arrivals needed per phase (lowered by arrive_and_drop()).
  // __arrived:  arrivals still outstanding in the current phase.
  _LIBCUDACXX_BARRIER_ALIGNMENTS __atomic_impl<ptrdiff_t, _Sco> __expected, __arrived;
  // Callable invoked by the arrival that completes a phase.
  _LIBCUDACXX_BARRIER_ALIGNMENTS _CompletionF __completion;
  // Current phase; toggles every time a phase completes.
  _LIBCUDACXX_BARRIER_ALIGNMENTS __atomic_impl<bool, _Sco> __phase;

public:
  // Token returned by arrive(): the phase value observed at arrival time.
  using arrival_token = bool;

private:
  template <typename _Barrier>
  friend class __barrier_poll_tester_phase;
  template <typename _Barrier>
  friend class __barrier_poll_tester_parity;
  template <typename _Barrier>
  _LIBCUDACXX_HIDE_FROM_ABI friend bool __call_try_wait(const _Barrier& __b, typename _Barrier::arrival_token&& __phase);
  template <typename _Barrier>
  _LIBCUDACXX_HIDE_FROM_ABI friend bool __call_try_wait_parity(const _Barrier& __b, bool __parity);

  // Non-blocking test: true once the current phase differs from __old.
  _LIBCUDACXX_HIDE_FROM_ABI bool __try_wait(arrival_token __old) const
  {
    return __phase.load(memory_order_acquire) != __old;
  }
  // Non-blocking test by parity. The phase is itself a bool, so this is the
  // same comparison as __try_wait.
  _LIBCUDACXX_HIDE_FROM_ABI bool __try_wait_parity(bool __parity) const
  {
    return __try_wait(__parity);
  }

public:
  _CCCL_HIDE_FROM_ABI __barrier_base() = default;

  // Creates a barrier whose phases complete after __expected arrivals and that
  // invokes __completion at the end of each phase. The completion function is
  // moved (not copied) from the by-value parameter into the member.
  _LIBCUDACXX_HIDE_FROM_ABI __barrier_base(ptrdiff_t __expected, _CompletionF __completion = _CompletionF())
      : __expected(__expected)
      , __arrived(__expected)
      , __completion(_CUDA_VSTD::move(__completion))
      , __phase(false)
  {}

  _CCCL_HIDE_FROM_ABI ~__barrier_base() = default;

  __barrier_base(__barrier_base const&)            = delete;
  __barrier_base& operator=(__barrier_base const&) = delete;

  // Registers __update arrivals for the current phase and returns the phase
  // value read before arriving, to be handed to wait().
  _CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI arrival_token arrive(ptrdiff_t __update = 1)
  {
    auto const __old_phase    = __phase.load(memory_order_relaxed);
    // Outstanding arrivals left after this call's contribution.
    auto const __result       = __arrived.fetch_sub(__update, memory_order_acq_rel) - __update;
    // Read after the decrement; this is the value the counter is re-armed with
    // if this arrival completes the phase.
    auto const __new_expected = __expected.load(memory_order_relaxed);

    _LIBCUDACXX_DEBUG_ASSERT(__result >= 0, "");

    if (0 == __result)
    {
      // Last arrival of the phase: run the completion function, re-arm the
      // counter, publish the flipped phase and notify on the phase atomic.
      __completion();
      __arrived.store(__new_expected, memory_order_relaxed);
      __phase.store(!__old_phase, memory_order_release);
      __atomic_notify_all(&__phase.__a, __scope_to_tag<_Sco>{});
    }
    return __old_phase;
  }
  // Waits on the phase atomic while it still holds the phase recorded in the
  // token.
  _LIBCUDACXX_HIDE_FROM_ABI void wait(arrival_token&& __old_phase) const
  {
    __phase.wait(__old_phase, memory_order_acquire);
  }
  // Arrives once and waits for the phase of that arrival to complete.
  _LIBCUDACXX_HIDE_FROM_ABI void arrive_and_wait()
  {
    wait(arrive());
  }
  // Lowers the expected count by one, then arrives once without waiting.
  _LIBCUDACXX_HIDE_FROM_ABI void arrive_and_drop()
  {
    __expected.fetch_sub(1, memory_order_relaxed);
    (void) arrive();
  }

  // Upper bound reported for the expected count.
  _LIBCUDACXX_HIDE_FROM_ABI static constexpr ptrdiff_t max() noexcept
  {
    return numeric_limits<ptrdiff_t>::max();
  }
};

// Barrier specialization for __empty_completion. All state lives in a single
// 64-bit atomic word:
//
//   bit  63      phase
//   bits 32..62  arrived field
//   bits  0..31  expected field
//
// __init() stores (2^31 - count) in both fields. Every arrival adds to the
// arrived field, and a carry out of that field toggles the phase bit.
template <thread_scope _Sco>
class __barrier_base<__empty_completion, _Sco>
{
  static constexpr uint64_t __expected_unit = 1ull;
  static constexpr uint64_t __arrived_unit  = 1ull << 32;
  static constexpr uint64_t __expected_mask = __arrived_unit - 1;
  static constexpr uint64_t __phase_bit     = 1ull << 63;
  static constexpr uint64_t __arrived_mask  = (__phase_bit - 1) & ~__expected_mask;

  // Packed phase / arrived / expected word described above.
  _LIBCUDACXX_BARRIER_ALIGNMENTS __atomic_impl<uint64_t, _Sco> __phase_arrived_expected;

public:
  // Token returned by arrive(): the word's phase bit as seen at arrival time.
  using arrival_token = uint64_t;

private:
  template <typename _Barrier>
  friend class __barrier_poll_tester_phase;
  template <typename _Barrier>
  friend class __barrier_poll_tester_parity;
  template <typename _Barrier>
  _LIBCUDACXX_HIDE_FROM_ABI friend bool __call_try_wait(const _Barrier& __b, typename _Barrier::arrival_token&& __phase);
  template <typename _Barrier>
  _LIBCUDACXX_HIDE_FROM_ABI friend bool __call_try_wait_parity(const _Barrier& __b, bool __parity);

  // Builds the initial packed word for __count expected arrivals: the value
  // (2^31 - __count) is placed in both the upper and the lower 32 bits.
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr uint64_t __init(ptrdiff_t __count) noexcept
  {
#  if _CCCL_STD_VER > 2011
    // A C++11 constexpr function is limited to a single statement, so the
    // check is only compiled for later standards.
    _LIBCUDACXX_DEBUG_ASSERT(__count >= 0, "Count must be non-negative.");
#  endif // _CCCL_STD_VER > 2011
    return (((1u << 31) - __count) << 32) | ((1u << 31) - __count);
  }
  // True once the phase bit currently stored differs from __phase, which must
  // be either 0 or __phase_bit.
  _LIBCUDACXX_HIDE_FROM_ABI bool __try_wait_phase(uint64_t __phase) const
  {
    uint64_t const __observed_phase = __phase_arrived_expected.load(memory_order_acquire) & __phase_bit;
    return __observed_phase != __phase;
  }
  // Non-blocking test against the phase bit carried by an arrival token.
  _LIBCUDACXX_HIDE_FROM_ABI bool __try_wait(arrival_token __old) const
  {
    uint64_t const __token_phase = __old & __phase_bit;
    return __try_wait_phase(__token_phase);
  }
  // Non-blocking test against a parity: true maps to the phase bit being set,
  // false to it being clear.
  _LIBCUDACXX_HIDE_FROM_ABI bool __try_wait_parity(bool __parity) const
  {
    uint64_t const __parity_phase = __parity ? __phase_bit : 0;
    return __try_wait_phase(__parity_phase);
  }

public:
  _CCCL_HIDE_FROM_ABI __barrier_base() = default;

  // Creates a barrier expecting __count arrivals per phase. The completion
  // argument carries no state and is ignored.
  _LIBCUDACXX_HIDE_FROM_ABI _CCCL_CONSTEXPR_CXX14
  __barrier_base(ptrdiff_t __count, __empty_completion = __empty_completion())
      : __phase_arrived_expected(__init(__count))
  {
    _LIBCUDACXX_DEBUG_ASSERT(__count >= 0, "");
  }

  _CCCL_HIDE_FROM_ABI ~__barrier_base() = default;

  __barrier_base(__barrier_base const&)            = delete;
  __barrier_base& operator=(__barrier_base const&) = delete;

  // Registers __update arrivals and returns the phase bit that was current
  // before they were added.
  _CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI arrival_token arrive(ptrdiff_t __update = 1)
  {
    uint64_t const __delta  = __arrived_unit * __update;
    uint64_t const __before = __phase_arrived_expected.fetch_add(__delta, memory_order_acq_rel);
    uint64_t const __after  = __before + __delta;
    if (((__before ^ __after) & __phase_bit) != 0)
    {
      // This addition toggled the phase bit: add the expected field back into
      // the arrived field for the next phase and notify on the atomic.
      uint64_t const __rearm = (__before & __expected_mask) << 32;
      __phase_arrived_expected.fetch_add(__rearm, memory_order_relaxed);
      __phase_arrived_expected.notify_all();
    }
    return __before & __phase_bit;
  }
  // Hands __libcpp_thread_poll_with_backoff a predicate that is true once the
  // barrier has left the phase recorded in the token.
  _LIBCUDACXX_HIDE_FROM_ABI void wait(arrival_token&& __phase) const
  {
    using __tester_t = __barrier_poll_tester_phase<__barrier_base>;
    __libcpp_thread_poll_with_backoff(__tester_t(this, _CUDA_VSTD::move(__phase)));
  }
  // Same as wait(), but the phase to leave is identified by its parity.
  _LIBCUDACXX_HIDE_FROM_ABI void wait_parity(bool __parity) const
  {
    using __tester_t = __barrier_poll_tester_parity<__barrier_base>;
    __libcpp_thread_poll_with_backoff(__tester_t(this, __parity));
  }
  // Arrives once and waits for the phase of that arrival to complete.
  _LIBCUDACXX_HIDE_FROM_ABI void arrive_and_wait()
  {
    arrival_token __token = arrive();
    wait(_CUDA_VSTD::move(__token));
  }
  // Adds one unit to the expected field, then arrives once without waiting.
  _LIBCUDACXX_HIDE_FROM_ABI void arrive_and_drop()
  {
    __phase_arrived_expected.fetch_add(__expected_unit, memory_order_relaxed);
    (void) arrive();
  }

  // Upper bound reported for the expected count.
  _LIBCUDACXX_HIDE_FROM_ABI static constexpr ptrdiff_t max() noexcept
  {
    return numeric_limits<int32_t>::max();
  }
};

#endif //_LIBCUDACXX_HAS_NO_TREE_BARRIER

// Public barrier type: a thin wrapper that inherits its whole interface from
// __barrier_base<_CompletionF>.
template <class _CompletionF = __empty_completion>
class barrier : public __barrier_base<_CompletionF>
{
public:
  // Creates a barrier expecting __count arrivals per phase that invokes
  // __completion when a phase completes. The completion function is moved into
  // the base instead of being copied from the by-value parameter.
  //
  // NOTE(review): the synopsis at the top of this header declares this
  // constructor `explicit`; it is left implicit here so existing callers keep
  // compiling.
  _LIBCUDACXX_HIDE_FROM_ABI constexpr barrier(ptrdiff_t __count, _CompletionF __completion = _CompletionF())
      : __barrier_base<_CompletionF>(__count, _CUDA_VSTD::move(__completion))
  {}
};

_LIBCUDACXX_END_NAMESPACE_STD

#include <cuda/std/__cuda/barrier.h>

_CCCL_POP_MACROS

#endif //_LIBCUDACXX_BARRIER
