llvm/libcxx/test/std/containers/views/mdspan/CustomTestLayouts.h

// -*- C++ -*-
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//                        Kokkos v. 4.0
//       Copyright (2022) National Technology & Engineering
//               Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
//===---------------------------------------------------------------------===//

#ifndef TEST_STD_CONTAINERS_VIEWS_MDSPAN_CUSTOM_TEST_LAYOUTS_H
#define TEST_STD_CONTAINERS_VIEWS_MDSPAN_CUSTOM_TEST_LAYOUTS_H

#include <algorithm>
#include <array>
#include <cassert>
#include <cinttypes>
#include <concepts>
#include <cstddef>
#include <limits>
#include <mdspan>
#include <type_traits>
#include <utility>

// Layout that wraps indices to test some idiosyncratic behavior
// - basically it is a layout_left where indices are first wrapped i.e. i%Wrap
// - only accepts integers as indices
// - is_always_strided and is_always_unique are false
// - is_strided and is_unique are true if no extent is larger than Wrap
// - not default constructible
// - not extents constructible
// - not trivially copyable
// - does not check dynamic to static extent conversion in converting ctor
// - checks via side effects that mdspan::swap calls the mapping's swap via ADL

// Empty tag type. It is the second parameter of the (const extents_type&, tag) constructor of
// layout_wrapping_integral<...>::mapping and carries no data of its own.
struct not_extents_constructible_tag {};

// Layout policy parameterized on the wrap value. The policy itself holds no state; all behavior
// lives in the nested mapping template, which reduces every index modulo Wrap.
template <size_t Wrap>
class layout_wrapping_integral {
public:
  // Only declared here; the definition is provided out-of-line.
  template <class Extents>
  class mapping;
};

// Mapping of layout_wrapping_integral<WrapArg>.
//
// Indices are reduced modulo Wrap and then combined like a layout_left mapping whose size in
// dimension r is min(extent(r), Wrap). The mapping is deliberately awkward to construct:
// it has no default constructor, a user-provided (hence non-trivial) copy constructor, and the
// set of available constructors depends on whether Wrap == 8.
template <size_t WrapArg>
template <class Extents>
class layout_wrapping_integral<WrapArg>::mapping {
  // The wrap value expressed in the extents' index_type so that it can be compared and combined
  // with extents and indices without mixing integer types.
  static constexpr typename Extents::index_type Wrap = static_cast<typename Extents::index_type>(WrapArg);

public:
  using extents_type = Extents;
  using index_type   = typename extents_type::index_type;
  using size_type    = typename extents_type::size_type;
  using rank_type    = typename extents_type::rank_type;
  // Refer to the enclosing layout through its own template argument, not through the value that
  // was converted to index_type (the conversion may change the value for narrow index types).
  using layout_type = layout_wrapping_integral<WrapArg>;

private:
  // Size of dimension r as seen by the mapping: min(ext.extent(r), Wrap).
  static constexpr index_type wrapped_extent(const extents_type& ext, rank_type r) noexcept {
    return ext.extent(r) < Wrap ? ext.extent(r) : Wrap;
  }

  // Returns true if the product of all wrapped extents can be computed in index_type without
  // overflow. Every dimension, including the first, is clamped to Wrap, matching
  // required_span_size().
  static constexpr bool required_span_size_is_representable(const extents_type& ext) {
    if constexpr (extents_type::rank() == 0) {
      return true;
    } else {
      index_type prod = wrapped_extent(ext, 0);
      for (rank_type r = 1; r < extents_type::rank(); r++) {
        if (__builtin_mul_overflow(prod, wrapped_extent(ext, r), &prod))
          return false;
      }
      return true;
    }
  }

  // True when no extent is larger than Wrap, i.e. when wrapping never folds two valid indices
  // of the same dimension onto each other.
  constexpr bool no_extent_exceeds_wrap() const noexcept {
    for (rank_type r = 0; r < extents_type::rank(); r++) {
      if (extents_.extent(r) > Wrap)
        return false;
    }
    return true;
  }

  // Builds an extents_type from another extents object by copying only the values of the
  // dimensions that are dynamic in extents_type. Static extents are intentionally not checked
  // against the source.
  template <class OtherExtents>
  static constexpr extents_type convert_extents(const OtherExtents& other) noexcept {
    std::array<index_type, extents_type::rank_dynamic()> dyn_extents{};
    rank_type count = 0;
    for (rank_type r = 0; r < extents_type::rank(); r++) {
      if (extents_type::static_extent(r) == std::dynamic_extent) {
        dyn_extents[count++] = other.extent(r);
      }
    }
    return extents_type(dyn_extents);
  }

public:
  constexpr mapping() noexcept = delete;
  // User-provided on purpose: keeps the mapping from being trivially copyable.
  constexpr mapping(const mapping& other) noexcept : extents_(other.extents()) {}
  // Construction from an rvalue extents object is only available for Wrap == 8.
  constexpr mapping(extents_type&& ext) noexcept
    requires(Wrap == 8)
      : extents_(ext) {}
  // Construction from extents requires the additional tag argument.
  constexpr mapping(const extents_type& ext, not_extents_constructible_tag) noexcept : extents_(ext) {}

  // Converting constructor from a const lvalue mapping; only available for Wrap != 8.
  template <class OtherExtents>
    requires(std::is_constructible_v<extents_type, OtherExtents> && (Wrap != 8))
  constexpr explicit(!std::is_convertible_v<OtherExtents, extents_type>)
      mapping(const mapping<OtherExtents>& other) noexcept
      : extents_(convert_extents(other.extents())) {}

  // Converting constructor from an rvalue mapping; only available for Wrap == 8.
  template <class OtherExtents>
    requires(std::is_constructible_v<extents_type, OtherExtents> && (Wrap == 8))
  constexpr explicit(!std::is_convertible_v<OtherExtents, extents_type>)
      mapping(mapping<OtherExtents>&& other) noexcept
      : extents_(convert_extents(other.extents())) {}

  constexpr mapping& operator=(const mapping& other) noexcept {
    extents_ = other.extents_;
    return *this;
  }

  constexpr const extents_type& extents() const noexcept { return extents_; }

  // Product of the wrapped extents (1 for rank 0).
  constexpr index_type required_span_size() const noexcept {
    index_type size = 1;
    for (rank_type r = 0; r < extents_type::rank(); r++)
      size *= wrapped_extent(extents_, r);
    return size;
  }

  // Maps a multidimensional index to an offset. Each index is converted to index_type and
  // reduced modulo Wrap; the results are combined as
  //   i0 + e0 * (i1 + e1 * (i2 + ...))   with e_r = min(extent(r), Wrap).
  template <std::integral... Indices>
    requires((sizeof...(Indices) == extents_type::rank()) && (std::is_convertible_v<Indices, index_type> && ...) &&
             (std::is_nothrow_constructible_v<index_type, Indices> && ...))
  constexpr index_type operator()(Indices... idx) const noexcept {
    std::array<index_type, extents_type::rank()> idx_a{static_cast<index_type>(static_cast<index_type>(idx) % Wrap)...};
    return [&]<size_t... Pos>(std::index_sequence<Pos...>) {
      index_type res = 0;
      // Horner-style evaluation starting from the last dimension.
      ((res = idx_a[extents_type::rank() - 1 - Pos] + wrapped_extent(extents_, extents_type::rank() - 1 - Pos) * res),
       ...);
      return res;
    }(std::make_index_sequence<sizeof...(Indices)>());
  }

  static constexpr bool is_always_unique() noexcept { return false; }
  static constexpr bool is_always_exhaustive() noexcept { return true; }
  static constexpr bool is_always_strided() noexcept { return false; }

  constexpr bool is_unique() const noexcept { return no_extent_exceeds_wrap(); }
  static constexpr bool is_exhaustive() noexcept { return true; }
  constexpr bool is_strided() const noexcept { return no_extent_exceeds_wrap(); }

  // Stride of dimension r. operator() lays indices out like layout_left, so the stride of
  // dimension r is the product of the (wrapped) extents of all dimensions before r.
  constexpr index_type stride(rank_type r) const noexcept
    requires(extents_type::rank() > 0)
  {
    index_type s = 1;
    for (rank_type i = 0; i < r; i++)
      s *= wrapped_extent(extents_, i);
    return s;
  }

  template <class OtherExtents>
    requires(OtherExtents::rank() == extents_type::rank())
  friend constexpr bool operator==(const mapping& lhs, const mapping<OtherExtents>& rhs) noexcept {
    return lhs.extents() == rhs.extents();
  }

  // Found via ADL. Outside of constant evaluation every call bumps swap_counter() so that
  // callers can observe that this overload was used.
  friend constexpr void swap(mapping& x, mapping& y) noexcept {
    swap(x.extents_, y.extents_);
    if (!std::is_constant_evaluated()) {
      swap_counter()++;
    }
  }

  // Counter of runtime swap() calls for this mapping specialization.
  static int& swap_counter() {
    static int value = 0;
    return value;
  }

private:
  extents_type extents_{};
};

// Builds a std::layout_left mapping for the given extents; the first argument only selects
// this overload.
template <class Extents>
constexpr auto construct_mapping(std::layout_left, Extents exts) {
  using mapping_type = std::layout_left::mapping<Extents>;
  return mapping_type(exts);
}

// Builds a std::layout_right mapping for the given extents; the first argument only selects
// this overload.
template <class Extents>
constexpr auto construct_mapping(std::layout_right, Extents exts) {
  using mapping_type = std::layout_right::mapping<Extents>;
  return mapping_type(exts);
}

// Builds a layout_wrapping_integral<Wraps> mapping for the given extents. The mapping's
// (extents, tag) constructor is used, so the tag is supplied here on the caller's behalf.
template <size_t Wraps, class Extents>
constexpr auto construct_mapping(layout_wrapping_integral<Wraps>, Extents exts) {
  using mapping_type = typename layout_wrapping_integral<Wraps>::template mapping<Extents>;
  return mapping_type(exts, not_extents_constructible_tag{});
}

// This layout does not check convertibility of extents for its conversion ctor
// Allows triggering mdspan's ctor static assertion on convertibility of extents
// It also allows for negative strides and offsets via runtime arguments
class always_convertible_layout {
public:
  // Only declared here; the definition is provided out-of-line.
  template <class Extents>
  class mapping;
};

// Mapping of always_convertible_layout.
//
// Computes offset_ + scaling_ * (layout_left offset of the index). The converting constructor is
// unconstrained: it accepts a mapping over any other extents type, including one of a different
// rank.
template <class Extents>
class always_convertible_layout::mapping {
public:
  using extents_type = Extents;
  using index_type   = typename extents_type::index_type;
  using size_type    = typename extents_type::size_type;
  using rank_type    = typename extents_type::rank_type;
  using layout_type  = always_convertible_layout;

private:
  // Returns true if the product of all extents can be computed in index_type without overflow.
  static constexpr bool required_span_size_is_representable(const extents_type& ext) {
    if constexpr (extents_type::rank() == 0) {
      return true;
    } else {
      index_type prod = ext.extent(0);
      for (rank_type r = 1; r < extents_type::rank(); r++) {
        if (__builtin_mul_overflow(prod, ext.extent(r), &prod))
          return false;
      }
      return true;
    }
  }

  // Member-wise comparison against a mapping over possibly different extents. Kept as a member
  // function so that reading the other specialization's private members relies only on the
  // class-level friendship declared at the end of this class.
  template <class OtherExtents>
  constexpr bool is_equal_to(const mapping<OtherExtents>& rhs) const noexcept {
    // cmp_equal compares by value even when the two index types differ in signedness.
    return extents_ == rhs.extents_ && std::cmp_equal(offset_, rhs.offset_) &&
           std::cmp_equal(scaling_, rhs.scaling_);
  }

public:
  constexpr mapping() noexcept = delete;
  constexpr mapping(const mapping& other) noexcept
      : extents_(other.extents_), offset_(other.offset_), scaling_(other.scaling_) {}
  // offset and scaling are runtime values; by default the mapping is a plain layout_left.
  constexpr mapping(const extents_type& ext, index_type offset = 0, index_type scaling = 1) noexcept
      : extents_(ext), offset_(offset), scaling_(scaling) {}

  // Unconstrained converting constructor. For equal ranks only the dynamic extents are copied
  // from the source; for differing ranks the extents are value-initialized.
  template <class OtherExtents>
  constexpr mapping(const mapping<OtherExtents>& other) noexcept {
    if constexpr (extents_type::rank() == OtherExtents::rank()) {
      std::array<index_type, extents_type::rank_dynamic()> dyn_extents{};
      rank_type count = 0;
      for (rank_type r = 0; r < extents_type::rank(); r++) {
        if (extents_type::static_extent(r) == std::dynamic_extent) {
          dyn_extents[count++] = other.extents().extent(r);
        }
      }
      extents_ = extents_type(dyn_extents);
    } else {
      extents_ = extents_type();
    }
    offset_  = other.offset_;
    scaling_ = other.scaling_;
  }

  constexpr mapping& operator=(const mapping& other) noexcept {
    extents_ = other.extents_;
    offset_  = other.offset_;
    scaling_ = other.scaling_;
    return *this;
  }

  constexpr const extents_type& extents() const noexcept { return extents_; }

  // max(size * scaling + offset, offset), where size is the product of all extents.
  constexpr index_type required_span_size() const noexcept {
    index_type size = 1;
    for (rank_type r = 0; r < extents_type::rank(); r++)
      size *= extents_.extent(r);
    return std::max(size * scaling_ + offset_, offset_);
  }

  // Maps a multidimensional index to offset_ + scaling_ * (i0 + e0 * (i1 + e1 * (i2 + ...))).
  template <std::integral... Indices>
    requires((sizeof...(Indices) == extents_type::rank()) && (std::is_convertible_v<Indices, index_type> && ...) &&
             (std::is_nothrow_constructible_v<index_type, Indices> && ...))
  constexpr index_type operator()(Indices... idx) const noexcept {
    std::array<index_type, extents_type::rank()> idx_a{static_cast<index_type>(idx)...};
    return offset_ +
           scaling_ * ([&]<size_t... Pos>(std::index_sequence<Pos...>) {
             index_type res = 0;
             // Horner-style evaluation starting from the last dimension.
             ((res = idx_a[extents_type::rank() - 1 - Pos] + extents_.extent(extents_type::rank() - 1 - Pos) * res),
              ...);
             return res;
           }(std::make_index_sequence<sizeof...(Indices)>()));
  }

  static constexpr bool is_always_unique() noexcept { return true; }
  static constexpr bool is_always_exhaustive() noexcept { return true; }
  static constexpr bool is_always_strided() noexcept { return true; }

  static constexpr bool is_unique() noexcept { return true; }
  static constexpr bool is_exhaustive() noexcept { return true; }
  static constexpr bool is_strided() noexcept { return true; }

  // Stride of dimension r: product of the extents of all dimensions before r, times scaling_.
  constexpr index_type stride(rank_type r) const noexcept
    requires(extents_type::rank() > 0)
  {
    index_type s = 1;
    for (rank_type i = 0; i < r; i++)
      s *= extents_.extent(i);
    return s * scaling_;
  }

  template <class OtherExtents>
    requires(OtherExtents::rank() == extents_type::rank())
  friend constexpr bool operator==(const mapping& lhs, const mapping<OtherExtents>& rhs) noexcept {
    return lhs.is_equal_to(rhs);
  }

  // Found via ADL. Exchanges the complete state of both mappings; outside of constant
  // evaluation every call also bumps swap_counter().
  friend constexpr void swap(mapping& x, mapping& y) noexcept {
    swap(x.extents_, y.extents_);
    // index_type is a fundamental type, so ADL finds nothing for it; call std::swap explicitly.
    std::swap(x.offset_, y.offset_);
    std::swap(x.scaling_, y.scaling_);
    if (!std::is_constant_evaluated()) {
      swap_counter()++;
    }
  }

  // Counter of runtime swap() calls for this mapping specialization.
  static int& swap_counter() {
    static int value = 0;
    return value;
  }

private:
  // Every specialization may read the private members of every other one (needed by the
  // converting constructor and by is_equal_to).
  template <class>
  friend class mapping;

  extents_type extents_{};
  index_type offset_{};
  index_type scaling_{};
};
#endif // TEST_STD_CONTAINERS_VIEWS_MDSPAN_CUSTOM_TEST_LAYOUTS_H