// chromium/third_party/eigen3/src/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h

// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <[email protected]>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H

// IWYU pragma: private
#include "./InternalHeaderCheck.h"

namespace Eigen {

namespace internal {

// NOTE(review): this unnamed enumeration has no enumerators, so it declares
// nothing usable (compilers may diagnose it as "declaration does not declare
// anything"). Its enumerator list was presumably elided -- TODO restore or remove.
enum {};

/*
 * Implementation of the Eigen blas_data_mapper class for tensors.
 */
/// The MakePointer_ template parameter is used by SYCL to build the mapper class on the device. On other platforms the
/// default MakePointer is used, which gives CoeffLoader a plain `Scalar*`.
/// Forward declaration of CoeffLoader.
///
/// The default for `MakePointer_` is supplied here, and only here: the later
/// definition of the primary template must not repeat it.
template <class Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer>
struct CoeffLoader;

/// Forward declaration of BaseTensorContractionMapper.
///
/// The `MakePointer` default for the last parameter lives on this declaration;
/// the out-of-line definition further down omits it, since a template parameter
/// may not be given a default argument by two declarations in the same scope.
template <class Scalar, class Index, int side, class Tensor, class nocontract_t, class contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          template <class> class MakePointer_ = MakePointer>
class BaseTensorContractionMapper;

/// Primary definition of CoeffLoader; its body is empty in this file.
/// The `MakePointer_` default is inherited from the forward declaration above.
template <class Tensor, bool HasRawAccess, template <class> class MakePointer_>
struct CoeffLoader
{
};

// NOTE(review): not a valid declaration as written -- `Tensor` and `MakePointer_`
// are not declared at namespace scope (they only appear as template parameters
// of the CoeffLoader declarations above). This presumably stands for a partial
// specialization of CoeffLoader for HasRawAccess == true whose
// `template <typename Tensor, template <class> class MakePointer_>` header and
// body were elided -- TODO restore.
CoeffLoader<Tensor, true, MakePointer_>;

/// SimpleTensorContractionMapper: the public base of BaseTensorContractionMapper
/// (see that class's base clause). Unlike BaseTensorContractionMapper it takes no
/// `inner_dim_reordered` parameter. Its body is empty in this file.
template <class Scalar, class Index, int side, class Tensor, class nocontract_t, class contract_t,
          int packet_size, bool inner_dim_contiguous, int Alignment,
          template <class> class MakePointer_ = MakePointer>
class SimpleTensorContractionMapper
{
};

/// Definition of BaseTensorContractionMapper.
///
/// It adds no members of its own here and publicly inherits
/// SimpleTensorContractionMapper, forwarding every template argument except
/// `inner_dim_reordered`, which the simpler mapper does not accept.
/// (No default for `MakePointer_`: that is given on the forward declaration.)
template <class Scalar, class Index, int side, class Tensor, class nocontract_t, class contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          template <class> class MakePointer_>
class BaseTensorContractionMapper
    : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t,
                                           packet_size, inner_dim_contiguous, Alignment, MakePointer_>
{
};

// NOTE(review): not a valid declaration as written -- `Scalar`, `Index`, `side`,
// `Tensor`, `nocontract_t`, `contract_t`, `inner_dim_contiguous`,
// `inner_dim_reordered`, `Alignment` and `MakePointer_` are not declared at
// namespace scope. This presumably stands for a partial specialization of
// BaseTensorContractionMapper with packet_size == 1 whose template header (and
// any base clause / body) were elided -- TODO restore.
BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>;

/// TensorContractionSubMapper; its body is empty in this file.
/// Its template parameters mirror BaseTensorContractionMapper's, including the
/// `MakePointer` default for the last one.
template <class Scalar, class Index, int side, class Tensor, class nocontract_t, class contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          template <class> class MakePointer_ = MakePointer>
class TensorContractionSubMapper
{
};

/// TensorContractionInputMapper: publicly derives from BaseTensorContractionMapper,
/// forwarding all of its template arguments unchanged (the scalar type is spelled
/// `Scalar_` here). Its body is empty in this file.
template <class Scalar_, class Index, int side, class Tensor, class nocontract_t, class contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          template <class> class MakePointer_ = MakePointer>
class TensorContractionInputMapper
    : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t,
                                         packet_size, inner_dim_contiguous, inner_dim_reordered,
                                         Alignment, MakePointer_>
{
};

/// Trait template, presumably queried with a TensorContractionInputMapper type
/// (the line below names such a form) -- confirm. Only declared here: the
/// primary template has no definition, so any T without a specialization
/// yields an incomplete type.
template <class T>
struct TensorContractionInputMapperTrait;

// NOTE(review): not a valid declaration as written -- `Scalar_`, `Index_`, `side_`,
// `Tensor_`, `nocontract_t_`, `contract_t_`, `packet_size_`,
// `inner_dim_contiguous_`, `inner_dim_reordered_`, `Alignment_` and
// `MakePointer_` are not declared at namespace scope. This presumably stands for
// a partial specialization of TensorContractionInputMapperTrait for
// TensorContractionInputMapper whose template header and body were elided --
// TODO restore.
TensorContractionInputMapperTrait<TensorContractionInputMapper<Scalar_, Index_, side_, Tensor_, nocontract_t_, contract_t_, packet_size_, inner_dim_contiguous_, inner_dim_reordered_, Alignment_, MakePointer_>>;

}  // end namespace internal
}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H