/* Copyright 2019 Google LLC. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef RUY_RUY_MUL_PARAMS_H_ #define RUY_RUY_MUL_PARAMS_H_ #include <cstdint> #include <limits> #include <type_traits> #include "ruy/check_macros.h" #include "ruy/size_util.h" namespace ruy { // Enumeration to designate which dimension is the 'channels', for MulParams // features that are 'per-channel', namely the bias-vector and the quantized // multiplier. enum class ChannelDimension : std::int8_t { … }; namespace detail { template <typename tAccumScalar, typename tDstScalar> struct MulParamsStorage; } // MulParams describes all about a matrix multiplication that // isn't encoded in the LHS, RHS and destination matrices. Some of that // information is encoded as compile-time constants and types (for instance, the // choice of accumulator type, AccumScalar). Some of that information is encoded // as runtime values (for instance, the optional bias vector). // // Template parameters: // AccumScalar: Accumulator type. The type of accumulators used to compute the // dot-products before being ultimately casted to the destination type. // DstScalar: The destination scalar type. // // Constraints on these template parameters (see also the ruy::Mul comment): // * If DstScalar is floating-point then AccumScalar must also be. // * If DstScalar is integral then AccumScalar must be std::int32_t. Moreover // in that integral case, there is a mode switch: // - If DstScalar is std::int32_t then the multiplier_* fields are all // disabled, and ruy::Mul will just return raw (unscaled) accumulators. // - If DstScalar is not std::int32_t then the multiplier_* fields are // enabled, and ruy::Mul will use them to scale internal std::int32_t // accumulators before casting them to the DstScalar type. The default // values are such that the effective multiplier is 1 (no scaling). // // For the latter case (DstScalar integral and narrower than std::int32_t), // reference code can be found in the implementation of ruy::ApplyMultiplier. // If you look there, you'll find warnings like this: // // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // Warning: this code is not meant to be bit-exact-normative. // Please refer to the class comment of ruy::MulParams, in mul_params.h. // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // // The explanation of this warning is that as of early 2021, we still don't know // whether it is advisable to let this code as-is have normative value, or // whether that would become advisable after some specific final change. // // Ruy's CPU backends (x86 and ARM) as of early 2021 happen to conform // bit-exactly to this reference, but we also know that x86 could be faster if // it didn't, and so could NEON-less ARM (such as Cortex-M) (see [2]). We don't // know that this particular reference code is inherently better than other // forms that could perform better on these architectures --- in fact, the // alternative that was proposed in [2] as better performing on ARM Cortex-M // is also inherently more accurate thanks to rounding only once, but it would // perform worse on both ARM NEON, and x86. // // In fact, if we look at other hardware architectures beyond current Ruy // targets, namely "hardware accelerators", it becomes clear that there is no // hope for any form of this to be efficiently implementable simultaneously on // all current relevant hardware. Indeed, some accelerators prefer to perform // the multiplication in IEEE float32, others in IEEE float16, others in // bfloat16, others in 16-bit fixed-point... // // See: // [1] https://github.com/google/ruy/pull/227 // [2] https://github.com/tensorflow/tensorflow/issues/25087 template <typename tAccumScalar, typename tDstScalar> class MulParams final { … }; namespace detail { // Floating-point case. template <typename AccumScalar, typename DstScalar> struct MulParamsStorage final { … }; // Specialization for the integer-quantized type, with down-quantization of // int32 accumulators to a narrower destination scalar type. MulParamsStorage<std::int32_t, DstScalar>; // Specialization used in the integer case when outputting raw int32 // accumulators, without down-quantization to a narrower destination scalar // type. In this case, the feature of clamping destination values is not // available. template <> struct MulParamsStorage<std::int32_t, std::int32_t> final { … }; } // namespace detail } // namespace ruy #endif // RUY_RUY_MUL_PARAMS_H_