chromium/third_party/eigen3/src/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h

// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <[email protected]>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H

// IWYU pragma: private
#include "./InternalHeaderCheck.h"

namespace Eigen {

/** \class TensorContraction
 * \ingroup CXX11_Tensor_Module
 *
 * \brief Tensor contraction class.
 *
 *
 */
namespace internal {

// Traits entry for the TensorContractionOp expression type.
traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>>;

// `eval` entry for TensorContractionOp, with Eigen::Dense as its second
// template argument.
eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, Eigen::Dense>;

// `nested` entry for TensorContractionOp. Its second template argument is 1 and
// its third is the `type` member of the `eval` entry for the same expression.
nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>>::type>;

// Traits entry for the TensorEvaluator of a const TensorContractionOp on
// Device_.
traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_, OutputKernelType_>, Device_>>;

// Helper class that allocates and deallocates the temporary memory used for
// packed buffers.
//
// NOTE(review): LhsScalar/RhsScalar are presumably the scalar types of the
// packed lhs and rhs buffers - confirm against the users of this class.
template <typename LhsScalar, typename RhsScalar>
struct TensorContractionBlockMemAllocator {};

// WARNING: In this code we assume that the Lhs and Rhs tensor expressions are
// in ColMajor storage order. This property is guaranteed by the
// TensorContractionOp evaluator. TensorContractionKernel specifies how we pack
// blocks of the Lhs and Rhs tensor expressions, and how we invoke matrix
// multiplication for these blocks. The default tensor contraction uses
// gemm_pack_rhs, gemm_pack_lhs and gebp_kernel from Eigen Core (see
// GeneralBlockPanelKernel.h for details).
//
// By specializing contraction kernels we can use other low-level libraries to
// perform the matrix multiplication, and still rely on the Eigen contraction
// evaluator. This also includes full support in TensorContractionThreadPool,
// assuming that the underlying gemm does not use its own threading.
//
// Template parameters:
//
// - ResScalar/LhsScalar/RhsScalar - scalar types of the multiplication result,
//   the lhs tensor and the rhs tensor, respectively.
//
// - StorageIndex - index type for the tensor expressions. In practice it is
//   almost always Eigen::Index.
//
// - OutputMapper - provides access to the memory of the output matrix. In
//   practice it is always a column-major blas_data_mapper (it must be of
//   ResScalar type).
//
// - LhsMapper/RhsMapper - similarly to blas_data_mapper, these provide a
//   two-dimensional view into the Lhs/Rhs tensor expressions. In practice each
//   is a TensorContractionInputMapper, or some specialization of it based on
//   the type of the tensor expression (e.g. TensorImagePatchOp has an optimized
//   input mapper).
template <typename ResScalar, typename LhsScalar, typename RhsScalar, typename StorageIndex, typename OutputMapper,
          typename LhsMapper, typename RhsMapper>
struct TensorContractionKernel {};

}  // end namespace internal

// Tensor contraction parameters that make it possible to map 2-dimensional
// coordinates in the output matrix back to the dimensions of the output tensor.
struct TensorContractionParams {};

// An output kernel allows operations to be fused into the tensor contraction.
//
// Examples:
//   1. Elementwise Relu transformation following Conv2D.
//   2. AddBias to the Conv2D output channels dimension.
//
// NoOpOutputKernel is the output kernel that does absolutely nothing.
struct NoOpOutputKernel {};

// Expression type for a tensor contraction of LhsXprType with RhsXprType.
//
// It derives from TensorBase with ReadOnlyAccessors. OutputKernelType defaults
// to `const NoOpOutputKernel`, so no extra operation is fused into the
// contraction unless a different output kernel type is supplied.
//
// NOTE(review): Indices presumably describes which lhs/rhs dimensions are
// contracted - confirm against the class body and its callers.
template <typename Indices, typename LhsXprType, typename RhsXprType,
          typename OutputKernelType = const NoOpOutputKernel>
class TensorContractionOp
    : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>, ReadOnlyAccessors> {};

// Base for tensor contraction evaluators, parameterized on Derived.
//
// NOTE(review): the name `Derived` suggests that the concrete evaluator passes
// itself as the template argument - confirm against the evaluators that
// inherit from this struct.
template <typename Derived>
struct TensorContractionEvaluatorBase {};

// Evaluator for the default device: the TensorEvaluator of a const
// TensorContractionOp, parameterized on Device.
TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device>;

}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H