// chromium/third_party/ruy/src/ruy/create_trmul_params.h

/* Copyright 2020 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implementation of CreateTrMulParams, see function comment.

#ifndef RUY_RUY_CREATE_TRMUL_PARAMS_H_
#define RUY_RUY_CREATE_TRMUL_PARAMS_H_

#include <cstdint>
#include <cstring>
#include <type_traits>

#include "ruy/allocator.h"
#include "ruy/ctx.h"
#include "ruy/kernel.h"
#include "ruy/mat.h"
#include "ruy/mul_params.h"
#include "ruy/pack.h"
#include "ruy/path.h"
#include "ruy/performance_advisory.h"
#include "ruy/trace.h"
#include "ruy/trmul_params.h"

namespace ruy {
// While the only entry point to this file is CreateTrMulParams, its templatized
// nature requires putting more code in this header than we would like. This
// internal implementation code is enclosed in namespace 'detail'.
namespace detail {

inline void CreatePackedLayout(const MatLayout& src,
                               const KernelLayout& kernel_layout,
                               PMatLayout* packed_layout) {}

template <typename Scalar, typename PackedScalar>
void CreatePackedMatrix(Side side, const KernelLayout& kernel_layout,
                        TrMulParams* params) {}

template <typename KernelType>
struct CheckKernelPathImpl {};

#if RUY_DCHECK_IS_ENABLED
CheckKernelPathImpl<Kernel<ThePath, SrcScalar, SrcScalar, DstScalar, MulParams<AccumScalar, DstScalar>>>;
#endif

template <typename KernelType>
void CheckKernelPath(Path expected_path) {}

template <Path ThePath, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
void PopulateTrMulParams(TrMulParams* params) {}

// PopulateTrMulParamsAllCompiledPaths calls into one of multiple
// instantiations of PopulateTrMulParams. For each bit that is set in
// CompiledPaths, it statically instantiates PopulateTrMulParams with a Path
// corresponding to that single bit. The call to PopulateTrMulParams is
// guarded by a runtime check that it is in fact the dynamically selected path.
//
// PopulateTrMulParamsAllCompiledPaths is implemented with template
// metaprogramming by mutual recursion between PathSearchCountdown and
// PathSearchOnlyCompiledPaths.
//
// PopulateTrMulParamsAllCompiledPaths is logically implementing the following
// computation:
//
// template <Path CompiledPaths>
// void PopulateTrMulParamsAllCompiledPaths(Path the_path,
//                                            TrMulParams* params) {
//   for (int bit = 8 * sizeof(Path) - 1; bit != -1; bit--) { // [1]
//     Path current_path = static_cast<Path>(1 << bit);
//     if ((CompiledPaths & current_path) != Path::kNone) { // [2]
//       if (current_path == the_path) { // [3]
//         PopulateTrMulParams<current_path, ...>(params);
//         return;
//       }
//     }
//   }
// }
//
//
//
// [1] - Done by the main definition of PathSearchCountdown. The `bit--` is
// done in the recursion of PathSearchOnlyCompiledPaths.
// [2] - Done by PathSearchOnlyCompiledPaths's partial template
// specialization on InCompiledPaths. This is the check which necessitates
// doing the whole computation at C++ compile time.
// [3] - Done by the `if` in the main definition of
// PathSearchOnlyCompiledPaths.
//
// The template metaprogramming is necessary because:
// - In `PopulateTrMulParams<current_path, ...>`, current_path must be a C++
// compile-time constant.
// - PopulateTrMulParamsAllCompiledPaths must not instantiate
// inner loops for paths that are not in CompiledPaths, since that can result in
// bogus instantiations which cause a compile time failure.
template <Path CompiledPaths, int BitNumber, typename LhsScalar,
          typename RhsScalar, typename AccumScalar, typename DstScalar>
struct PathSearchCountdown;

template <Path CompiledPaths, bool InCompiledPaths, int BitNumber,
          typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar>
struct PathSearchOnlyCompiledPaths {};

// Skip this iteration if CompiledPaths doesn't contain the specified path.
PathSearchOnlyCompiledPaths<CompiledPaths, false, BitNumber, LhsScalar, RhsScalar, AccumScalar, DstScalar>;

template <Path CompiledPaths, int BitNumber, typename LhsScalar,
          typename RhsScalar, typename AccumScalar, typename DstScalar>
struct PathSearchCountdown {};

// Termination of the countdown. If the counter reaches -1, then we haven't
// found the specified path.
PathSearchCountdown<CompiledPaths, -1, LhsScalar, RhsScalar, AccumScalar, DstScalar>;

template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
void PopulateTrMulParamsAllCompiledPaths(Path the_path, TrMulParams* params) {}

template <typename AccumScalar, typename DstScalar>
void AssertThatExtraCapacityInPerChannelBuffersIsZeroInitialized(
    const MulParams<AccumScalar, DstScalar>& mul_params, int user_size,
    int user_capacity) {}

template <typename AccumScalar, typename DstScalar,
          bool HaveQuantizedMultipliers =
              std::is_same<AccumScalar, std::int32_t>::value &&
              !std::is_same<DstScalar, std::int32_t>::value>
struct EnsurePerChannelBuffersLargeEnoughImpl {};

EnsurePerChannelBuffersLargeEnoughImpl<AccumScalar, DstScalar, false>;

template <typename AccumScalar, typename DstScalar>
void EnsurePerChannelBuffersLargeEnough(
    const TrMulParams& params, Ctx* ctx,
    MulParams<AccumScalar, DstScalar>* mul_params) {}

// Ensures that `params->mul_params_bytes` contains MulParams data that's ready
// to be consumed by the kernel. As a first-order approximation, that is simply
// copying the user-provided `mul_params`, however there are a few changes.
//
//   1. The specified `channel_dimension` value overrides the channel_dimension
//      member in `mul_params`. The reason why `channel_dimension` is being
//      special-cased among MulParams members is that we will need to transpose
//      MulParams, and that consists just in toggling channel_dimension.
//   2. Per-channel buffers may be reallocated, see
//      EnsurePerChannelBuffersLargeEnough.
template <typename AccumScalar, typename DstScalar>
void FinalizeMulParams(const MulParams<AccumScalar, DstScalar>& mul_params,
                       ChannelDimension channel_dimension, Ctx* ctx,
                       TrMulParams* params) {}

// In this function, the `channel_dimension` parameter overrides the value
// of the channel_dimension member in the `mul_params` parameter. See the
// FinalizeMulParams comment.
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
void CreateTrMulParamsAssumingColMajorDst(
    const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
    const Mat<DstScalar>& dst,
    const MulParams<AccumScalar, DstScalar>& mul_params,
    ChannelDimension channel_dimension, Ctx* ctx, TrMulParams* params) {}

}  // namespace detail

inline ChannelDimension Transpose(ChannelDimension channel_dimension) {}

// CreateTrMulParams's output is a TrMulParams object that encodes
// all of the input information required_capacity by the middle-end, that is,
// the TrMul function.
//
// CreateTrMulParams performs the following tasks:
//   1. Reduce to the case of column-major destination, by transposing the
//      whole problem as needed.
//   2. Select the single code path to be taken, out of the set of paths
//      described by the `CompiledPaths` template parameter, based on the
//      runtime input parameter `the_path`.
//   3. Perform type-erasure, converting templatized typed input parameters
//      to the un-typed data stored in TrMulParams.
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
          typename AccumScalar, typename DstScalar>
void CreateTrMulParams(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
                       const Mat<DstScalar>& dst,
                       const MulParams<AccumScalar, DstScalar>& mul_params,
                       Ctx* ctx, TrMulParams* params) {}

}  // namespace ruy

#endif  // RUY_RUY_CREATE_TRMUL_PARAMS_H_