// chromium/third_party/tflite/src/tensorflow/lite/kernels/cpu_backend_gemm_params.h

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_PARAMS_H_
#define TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_PARAMS_H_

#include <cstdint>
#include <limits>
#include <type_traits>

#include "tensorflow/lite/kernels/internal/compatibility.h"

namespace tflite {

namespace cpu_backend_gemm {

// Matrix storage order: column-major or row-major.
enum class Order {
  // Column-major: the elements of each column are contiguous in memory.
  kColMajor,
  // Row-major: the elements of each row are contiguous in memory.
  kRowMajor
};

// Policy describing whether a Gemm back-end may keep a cached copy of a
// matrix between calls. The underlying type is a single byte.
enum class CachePolicy : std::uint8_t {
  // Never cache the matrix.
  kNeverCache,
  // Cache the matrix only when doing so is expected to give a large speedup.
  kCacheIfLargeSpeedup,
  // Always cache the matrix.
  kAlwaysCache,
};

// Returns the cache policy to use for a matrix when the caller has not
// picked one explicitly.
//
// is_constant_data: whether the matrix contents stay the same across Gemm
//   calls. Non-constant data is never cached by default.
//
// Returns CachePolicy::kCacheIfLargeSpeedup when is_constant_data is true,
// and CachePolicy::kNeverCache otherwise.
inline CachePolicy DefaultCachePolicy(bool is_constant_data) {
  return is_constant_data ? CachePolicy::kCacheIfLargeSpeedup
                          : CachePolicy::kNeverCache;
}

// MatrixParams encapsulates the parameters that Gemm needs about each
// matrix, besides the buffer data pointer.
// Compare to ruy::Matrix, which also encapsulates the data pointer.
// Rationale for leaving the data pointer out of here: doing so
// requires complicated const-correctness mechanics. See
// ruy::ConstCheckingPtr.
//
// Template parameter:
//   Scalar - presumably the element type of the matrix being described
//            (TODO confirm against callers).
// NOTE(review): no data members are visible in this view of the struct.
template <typename Scalar>
struct MatrixParams {};

// Enumeration of broad categories of Gemm.
//
// The primary reason for this to exist is to allow Gemm to compile
// only uniform-quantized or only per-channel-quantized code paths.
// This is unneeded with ruy as the back-end, as this is only a runtime
// difference in ruy, but with gemmlowp these really are separate code
// paths and templatizing in a QuantizationFlavor is necessary to avoid
// compiling unused gemmlowp code. Indeed, TFLite currently uses
// uint8 with uniform quantization and int8 with per-channel quantization,
// and does not use uint8 with per-channel. We want to avoid compiling
// the gemmlowp uint8 per-channel path when gemmlowp is the back-end.
//
// It's possible to drop this in the future if gemmlowp goes away and no
// other then-relevant backend library handles quantized paths in a way that
// requires knowing this at compile-time.
enum class QuantizationFlavor {
  // Floating-point Gemm. This is the flavor GemmParams selects by default
  // when its AccumScalar is a floating-point type.
  kFloatingPoint,
  // Quantized Gemm with uniform quantization (one multiplier shared by all
  // accumulators). This is the flavor GemmParams selects by default when its
  // AccumScalar is not a floating-point type.
  kIntegerWithUniformMultiplier,
  // Quantized Gemm with per-channel quantization (a separate multiplier per
  // row of the destination matrix). Never selected by default: call sites
  // that need per-channel quantization must request it explicitly.
  kIntegerWithPerRowMultiplier
};

// Additional parameters that Gemm needs, beyond what falls into
// the MatrixParams that it takes. Compare to ruy::Spec.
//
// Decoupling AccumScalar from DstScalar (rather than deducing it from that)
// is useful future-proofing. Think of a float16 path using float32 accum.
//
// QuantizationFlavor is passed here even though it's technically not used
// in this class. This is so that we retain the ability in the future to
// specialize this class for quantization flavor, and this allows for
// Gemm to be templatized in quantization_flavor via the GemmParams that it
// takes, allowing for automatic template parameter deduction to take place,
// so that most call sites don't need to specify a QuantizationFlavor
// (only those that need perchannel quantization do).
//
// Template parameters:
//   AccumScalar - scalar type of the accumulators.
//   DstScalar - scalar type of the destination matrix.
//   quantization_flavor - defaults to kFloatingPoint when AccumScalar is a
//     floating-point type, and to kIntegerWithUniformMultiplier otherwise.
// NOTE(review): no data members are visible in this view of the struct.
template <typename AccumScalar, typename DstScalar,
          QuantizationFlavor quantization_flavor =
              std::is_floating_point<AccumScalar>::value
                  ? QuantizationFlavor::kFloatingPoint
                  : QuantizationFlavor::kIntegerWithUniformMultiplier>
struct GemmParams {};

/* Convenience typedefs */

QuantizedGemmParams;

FloatGemmParams;

/* Validation functions */

// Note that this uses TFLITE_DCHECK from kernels/internal/compatibility.h
// and not TF_LITE_ASSERT from op_macros.h. We want these to be explicitly
// debug-build-only assertions so that there's no reason not to validate
// generously, and TF_LITE_ASSERT is actually, at the moment,
// a release-build assertion. See b/131587258.

// Validates self-consistency of GemmParams.
//
// params: the GemmParams instance to check. AccumScalar, DstScalar and
//   quantization_flavor are all deduced from its type.
// NOTE(review): the body is empty in this view, so no checks are performed
// here.
template <typename AccumScalar, typename DstScalar,
          QuantizationFlavor quantization_flavor>
void ValidateGemmParams(
    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {}

namespace detail {

// Primary template of ValidateTypes, parameterized on the four scalar types
// of a Gemm (LHS, RHS, accumulator, destination) and its quantization flavor.
// Presumably a compile-time check that the scalar types suit the flavor
// (inferred from the name - TODO confirm).
// NOTE(review): the body is empty in this view, so instantiating it checks
// nothing.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
struct ValidateTypes {};

ValidateTypes<LhsScalar, RhsScalar, AccumScalar, DstScalar, QuantizationFlavor::kFloatingPoint>;

}  // namespace detail

// Validates overall consistency of all the parameters taken by a Gemm call:
// the 3 MatrixParams and the GemmParams.
//
// lhs_params: MatrixParams of the left-hand-side matrix.
// rhs_params: MatrixParams of the right-hand-side matrix.
// dst_params: MatrixParams of the destination matrix; its scalar type is the
//   same DstScalar that parameterizes `params`.
// params: the GemmParams for the call.
// NOTE(review): the body is empty in this view, so no checks are performed
// here.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
          typename DstScalar, QuantizationFlavor quantization_flavor>
void ValidateParams(
    const MatrixParams<LhsScalar>& lhs_params,
    const MatrixParams<RhsScalar>& rhs_params,
    const MatrixParams<DstScalar>& dst_params,
    const GemmParams<AccumScalar, DstScalar, quantization_flavor>& params) {}

// Test if the Gemm is degenerate in some way, e.g. nonsensical dimensions.
//
// lhs_params: MatrixParams of the left-hand-side matrix.
// rhs_params: MatrixParams of the right-hand-side matrix.
// dst_params: MatrixParams of the destination matrix.
//
// NOTE(review): the name suggests that `true` means the Gemm is valid (i.e.
// not degenerate) - confirm against callers.
// NOTE(review): the body is empty in this view, so control reaches the end of
// a function returning bool without a return statement; calling it as written
// is undefined behavior.
template <typename LhsScalar, typename RhsScalar, typename DstScalar>
bool IsValidGemm(const MatrixParams<LhsScalar>& lhs_params,
                 const MatrixParams<RhsScalar>& rhs_params,
                 const MatrixParams<DstScalar>& dst_params) {}

}  // namespace cpu_backend_gemm

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_CPU_BACKEND_GEMM_PARAMS_H_