llvm/pstl/include/pstl/internal/omp/parallel_scan.h

// -*- C++ -*-
// -*-===----------------------------------------------------------------------===//
//
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
//
//===----------------------------------------------------------------------===//

#ifndef _PSTL_INTERNAL_OMP_PARALLEL_SCAN_H
#define _PSTL_INTERNAL_OMP_PARALLEL_SCAN_H

#include "parallel_invoke.h"

namespace __pstl
{
namespace __omp_backend
{

// Returns the largest power of two that is strictly less than __m, or 1 when __m <= 2.
// The sweep functions in this file use the result to cut a range of __m tiles into a
// left part of __k tiles and a right part of __m - __k tiles.
//
// The loop bound is ceil(__m / 2), computed once up front, rather than a comparison of
// 2 * __k against __m. For an integer __k the two conditions "2 * __k < __m" and
// "__k < ceil(__m / 2)" select the same values, but the second one cannot overflow when
// __m lies in the upper half of the range of _Index (the doubled value would wrap for
// unsigned types and be undefined for signed ones).
template <typename _Index>
_Index
__split(_Index __m)
{
    // ceil(__m / 2) for non-negative __m; non-positive for __m <= 0, so the loop is skipped.
    const _Index __half_rounded_up = __m / 2 + __m % 2;
    _Index __k = 1;
    // Invariant: __k < ceil(__m / 2) implies 2 * __k <= __m, so the doubling always fits.
    while (__k < __half_rounded_up)
        __k *= 2;
    return __k;
}

// Up-sweep pass over __m consecutive tiles, the first of which is tile number __i.
//
//   __i         number of the first tile covered by this call
//   __m         how many tiles this call covers
//   __tilesize  element count used for every tile except the last one of the range
//   __r         result slots for the covered tiles; __r[0] belongs to tile __i
//   __lastsize  element count of the last tile of the range
//   __reduce    invoked as __reduce(first_element_offset, element_count)
//   __combine   binary functor applied as __combine(left_result, right_result)
//
// A single tile writes its reduction into __r[0]. A longer range is cut at
// __split(__m) and both parts are passed to __parallel_invoke_body.
template <typename _Index, typename _Tp, typename _Rp, typename _Cp>
void
__upsweep(_Index __i, _Index __m, _Index __tilesize, _Tp* __r, _Index __lastsize, _Rp __reduce, _Cp __combine)
{
    if (__m == 1)
    {
        __r[0] = __reduce(__i * __tilesize, __lastsize);
        return;
    }

    const _Index __left_tiles = __split(__m);
    __omp_backend::__parallel_invoke_body(
        // Left part: more tiles follow it, so its own last tile is a full __tilesize.
        [=] { __omp_backend::__upsweep(__i, __left_tiles, __tilesize, __r, __tilesize, __reduce, __combine); },
        // Right part: ends where this range ends, so it inherits __lastsize.
        [=] {
            __omp_backend::__upsweep(__i + __left_tiles, __m - __left_tiles, __tilesize, __r + __left_tiles,
                                     __lastsize, __reduce, __combine);
        });

    // Only when both parts hold the same number of tiles is the last slot of the left part
    // folded into the last slot of the whole range. This reads slots written by the two
    // calls above, so it relies on both of them having completed.
    if (__m == 2 * __left_tiles)
        __r[__m - 1] = __combine(__r[__left_tiles - 1], __r[__m - 1]);
}

// Down-sweep pass over __m consecutive tiles, the first of which is tile number __i.
//
//   __i         number of the first tile covered by this call
//   __m         how many tiles this call covers
//   __tilesize  element count used for every tile except the last one of the range
//   __r         result slots for the covered tiles (read only here)
//   __lastsize  element count of the last tile of the range
//   __initial   value handed to __scan as the starting value of the first covered tile
//   __combine   binary functor applied as __combine(left_value, right_value)
//   __scan      invoked as __scan(first_element_offset, element_count, starting_value)
//
// A single tile is scanned directly. A longer range is cut at __split(__m); the left
// part keeps __initial, the right part starts from __combine(__initial, __r[__k - 1])
// where __k is the number of tiles in the left part.
template <typename _Index, typename _Tp, typename _Cp, typename _Sp>
void
__downsweep(_Index __i, _Index __m, _Index __tilesize, _Tp* __r, _Index __lastsize, _Tp __initial, _Cp __combine,
            _Sp __scan)
{
    if (__m == 1)
    {
        __scan(__i * __tilesize, __lastsize, __initial);
        return;
    }

    const _Index __left_tiles = __split(__m);
    __omp_backend::__parallel_invoke_body(
        // Left part: more tiles follow it, so its own last tile is a full __tilesize.
        [=] {
            __omp_backend::__downsweep(__i, __left_tiles, __tilesize, __r, __tilesize, __initial, __combine, __scan);
        },
        // Right part: ends where this range ends, so it inherits __lastsize.
        // Assumes that __combine never throws.
        // TODO: Consider adding a requirement for user functors to be constant.
        // NOTE(review): __combine is captured by reference here, presumably so that a
        // non-const call operator can be invoked from inside the lambda - confirm.
        [=, &__combine]
        {
            __omp_backend::__downsweep(__i + __left_tiles, __m - __left_tiles, __tilesize, __r + __left_tiles,
                                       __lastsize, __combine(__initial, __r[__left_tiles - 1]), __combine, __scan);
        });
}

template <typename _ExecutionPolicy, typename _Index, typename _Tp, typename _Rp, typename _Cp, typename _Sp,
          typename _Ap>
void
__parallel_strict_scan_body(_Index __n, _Tp __initial, _Rp __reduce, _Cp __combine, _Sp __scan, _Ap __apex)
{
    _Index __p = omp_get_num_threads();
    const _Index __slack = 4;
    _Index __tilesize = (__n - 1) / (__slack * __p) + 1;
    _Index __m = (__n - 1) / __tilesize;
    __buffer<_Tp> __buf(__m + 1);
    _Tp* __r = __buf.get();

    __omp_backend::__upsweep(_Index(0), _Index(__m + 1), __tilesize, __r, __n - __m * __tilesize, __reduce, __combine);

    std::size_t __k = __m + 1;
    _Tp __t = __r[__k - 1];
    while ((__k &= __k - 1))
    {
        __t = __combine(__r[__k - 1], __t);
    }

    __apex(__combine(__initial, __t));
    __omp_backend::__downsweep(_Index(0), _Index(__m + 1), __tilesize, __r, __n - __m * __tilesize, __initial,
                               __combine, __scan);
}

// Entry point of the strict scan for the OpenMP backend tag.
//
//   __n        number of elements
//   __initial  starting value of the scan
//   __reduce   invoked as __reduce(first_element_offset, element_count)
//   __combine  binary functor applied as __combine(left_value, right_value)
//   __scan     invoked as __scan(first_element_offset, element_count, starting_value)
//   __apex     receives the overall result exactly once
//
// Inputs of at most __default_chunk_size elements are processed directly on the calling
// thread. Larger inputs go to __parallel_strict_scan_body, either directly when
// omp_in_parallel() reports an active parallel region, or from a single construct inside
// a newly opened parallel region otherwise.
template <class _ExecutionPolicy, typename _Index, typename _Tp, typename _Rp, typename _Cp, typename _Sp, typename _Ap>
void
__parallel_strict_scan(__pstl::__internal::__openmp_backend_tag, _ExecutionPolicy&&, _Index __n, _Tp __initial,
                       _Rp __reduce, _Cp __combine, _Sp __scan, _Ap __apex)
{
    // Small input: treat the whole range as one tile.
    if (__n <= __default_chunk_size)
    {
        const bool __has_elements = __n != 0;

        _Tp __total = __initial;
        if (__has_elements)
            __total = __combine(__total, __reduce(_Index(0), __n));

        // __apex is called even for an empty range; it then receives __initial unchanged.
        __apex(__total);

        if (__has_elements)
            __scan(_Index(0), __n, __initial);
        return;
    }

    // Already inside a parallel region: run the body on the current thread.
    if (omp_in_parallel())
    {
        __pstl::__omp_backend::__parallel_strict_scan_body<_ExecutionPolicy>(__n, __initial, __reduce, __combine,
                                                                             __scan, __apex);
        return;
    }

    // Otherwise open a parallel region and let a single thread start the body.
    _PSTL_PRAGMA(omp parallel)
    _PSTL_PRAGMA(omp single nowait)
    {
        __pstl::__omp_backend::__parallel_strict_scan_body<_ExecutionPolicy>(__n, __initial, __reduce, __combine,
                                                                             __scan, __apex);
    }
}

} // namespace __omp_backend
} // namespace __pstl
#endif // _PSTL_INTERNAL_OMP_PARALLEL_SCAN_H