// chromium/third_party/ruy/src/ruy/kernel_x86.h

/* Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_KERNEL_X86_H_
#define RUY_RUY_KERNEL_X86_H_

#include <cstdint>
#include <cstring>

#include "ruy/kernel_common.h"
#include "ruy/mat.h"
#include "ruy/mul_params.h"
#include "ruy/opt_set.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/tune.h"

namespace ruy {

#if RUY_PLATFORM_X86

RUY_INHERIT_KERNEL()
RUY_INHERIT_KERNEL()
RUY_INHERIT_KERNEL()

// Entry points to the optimized AVX-512 8-bit kernels (16x16 destination
// tile). NOTE(review): presumably implemented in the AVX-512 kernel .cc
// file -- the definitions are not visible in this header.
void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
// Fast-path variant for single-column (matrix*vector) destinations.
void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params);

Kernel<Path::kAvx512, std::int8_t, std::int8_t, std::int32_t, DstScalar>;

Kernel<Path::kAvx512, std::int8_t, std::int16_t, std::int32_t, DstScalar>;

// Entry points to the optimized AVX-512 float kernels (16x16 destination
// tile). NOTE(review): presumably implemented in the AVX-512 kernel .cc
// file -- the definitions are not visible in this header.
void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params);
// Fast-path variant for single-column (matrix*vector) destinations.
// (Parameter renamed `param` -> `params` for consistency with every sibling
// declaration in this header; declaration-only names do not affect callers.)
void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params);

template <>
struct Kernel<Path::kAvx512, float, float, float, float> {};

// Entry points to the optimized AVX2+FMA 8-bit kernels (8x8 destination
// tile). NOTE(review): presumably implemented in the AVX2 kernel .cc file --
// the definitions are not visible in this header.
void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params);
// Fast-path variant for single-column (matrix*vector) destinations.
void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params);

Kernel<Path::kAvx2Fma, std::int8_t, std::int8_t, std::int32_t, DstScalar>;

Kernel<Path::kAvx2Fma, std::int8_t, std::int16_t, std::int32_t, DstScalar>;

// Entry points to the optimized AVX2+FMA float kernels (8x8 destination
// tile). NOTE(review): presumably implemented in the AVX2 kernel .cc file --
// the definitions are not visible in this header.
void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params);
// Fast-path variant for single-column (matrix*vector) destinations.
void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params);

template <>
struct Kernel<Path::kAvx2Fma, float, float, float, float> {};

// Entry points to the optimized plain-AVX float kernels (8x8 destination
// tile). NOTE(review): presumably implemented in the AVX kernel .cc file --
// the definitions are not visible in this header.
void KernelFloatAvx(const KernelParamsFloat<8, 8>& params);
// Fast-path variant for single-column (matrix*vector) destinations.
void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params);

template <>
struct Kernel<Path::kAvx, float, float, float, float> {};

// Entry points to the optimized plain-AVX 8-bit kernels (8x8 destination
// tile). NOTE(review): presumably implemented in the AVX kernel .cc file --
// the definitions are not visible in this header.
void Kernel8bitAvx(const KernelParams8bit<8, 8>& params);
// Fast-path variant for single-column (matrix*vector) destinations.
void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params);

Kernel<Path::kAvx, std::int8_t, std::int8_t, std::int32_t, DstScalar>;

#endif  // RUY_PLATFORM_X86
}  // namespace ruy

#if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM))

#include <immintrin.h>  // IWYU pragma: keep

namespace ruy {
namespace {
namespace intrin_utils {

// Defined as a template so clang won't detect it as an unneeded
// definition.
//
// Returns lane `i` (must be in [0, 8)) of `a`. The original body was empty,
// which is undefined behavior for a value-returning function if it is ever
// instantiated and called.
template <Path path>
inline float mm256_get1_ps(const __m256 a, int i) {
  // Spill to memory: AVX has no direct lane-extract intrinsic with a
  // runtime index.
  float lanes[8];
  _mm256_storeu_ps(lanes, a);
  return lanes[i];
}

// Defined as a template so clang won't detect it as an unneeded
// definition.
//
// Stores the first `residual_rows` (in [0, 8]) lanes of `v` to `dst`.
// The original body was empty, silently dropping the store.
template <Path path>
inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
  float lanes[8];
  _mm256_storeu_ps(lanes, v);
  std::memcpy(dst, lanes, sizeof(float) * residual_rows);
}

// Stub primary template: the unnamed parameters indicate the real
// implementations are per-path specializations (presumably FMA on the
// kAvx2Fma path, mul+add on kAvx) defined in the kernel .cc files. The
// original body was empty, which is undefined behavior for a value-returning
// function if this stub is ever instantiated and called; return a value to
// keep it well-formed.
template <Path path>
inline __m256 MulAdd(const __m256&, const __m256&, const __m256&) {
  return _mm256_setzero_ps();
}

// Stub primary template for the 256-bit byte shuffle; the unnamed parameters
// indicate the real implementations are per-path specializations in the
// kernel .cc files (the plain-AVX path presumably emulates the AVX2
// intrinsic). The original empty body was undefined behavior for a
// value-returning function; return a value to keep it well-formed.
template <Path path>
inline __m256i mm256_shuffle_epi8(const __m256i&, const __m256i&) {
  return _mm256_setzero_si256();
}

// Polyfill for _mm_storeu_si16(dst, v): stores the low 16 bits of `v` to
// the (possibly unaligned) address `dst`. The original body was empty,
// silently dropping the store.
template <Path path>
inline void mm_storeu_si16(void* dst, __m128i v) {
  // x86 is little-endian, so the low 16 bits are the first two bytes of
  // v's object representation.
  std::memcpy(dst, &v, sizeof(std::int16_t));
}

// Polyfill for _mm_storeu_si32(dst, v): stores the low 32 bits of `v` to
// the (possibly unaligned) address `dst`. The original body was empty,
// silently dropping the store.
template <Path path>
inline void mm_storeu_si32(void* dst, __m128i v) {
  // x86 is little-endian, so the low 32 bits are the first four bytes of
  // v's object representation.
  std::memcpy(dst, &v, sizeof(std::int32_t));
}

// Polyfill for _mm_loadu_si32(src): loads 32 bits from the (possibly
// unaligned) address `src` into the low lane, zeroing the upper 96 bits.
// The original empty body was undefined behavior for a value-returning
// function.
template <Path path>
inline __m128i mm_loadu_si32(const void* src) {
  std::int32_t low;
  std::memcpy(&low, src, sizeof(low));  // safe unaligned load
  return _mm_cvtsi32_si128(low);        // zero-extends into the upper lanes
}

// Stub primary template for extracting one 128-bit half of a 256-bit
// vector; the unnamed parameters indicate the real implementations are
// per-path specializations in the kernel .cc files (the plain-AVX path
// presumably cannot use the AVX2 intrinsic directly). The original empty
// body was undefined behavior for a value-returning function; return a
// value to keep it well-formed.
template <Path path>
inline __m128i mm256_extracti128_si256(const __m256i&, const int) {
  return _mm_setzero_si128();
}

// Narrows each of the first `residual_rows` (in [0, 8]) 32-bit lanes of `v`
// to 8 bits and stores them to `dst`. The body appears stripped in this
// copy; per-path specializations presumably live in the kernel .cc files.
// NOTE(review): narrowing semantics (truncate vs. saturate) cannot be
// confirmed from this header -- verify against the path specializations.
template <Path path>
inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows,
                                         const __m256i v) {}

// Narrows all eight 32-bit lanes of `v` to 8 bits and stores them to `dst`.
// The body appears stripped in this copy; per-path specializations
// presumably live in the kernel .cc files. NOTE(review): truncate vs.
// saturate cannot be confirmed from this header.
template <Path path>
inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256i v) {}

// Signed-destination overload: narrows each of the first `residual_rows`
// 32-bit lanes of `v` to 8 bits and stores them to `dst`. The body appears
// stripped in this copy; per-path specializations presumably live in the
// kernel .cc files. NOTE(review): truncate vs. saturate cannot be confirmed
// from this header.
template <Path path>
inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows,
                                         const __m256i v) {}

// Signed-destination overload: narrows all eight 32-bit lanes of `v` to
// 8 bits and stores them to `dst`. The body appears stripped in this copy;
// per-path specializations presumably live in the kernel .cc files.
// NOTE(review): truncate vs. saturate cannot be confirmed from this header.
template <Path path>
inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256i v) {}

// Narrows each of the first `residual_rows` (in [0, 8]) 32-bit lanes of `v`
// to 16 bits and stores them to `dst`. The body appears stripped in this
// copy; per-path specializations presumably live in the kernel .cc files.
// NOTE(review): truncate vs. saturate cannot be confirmed from this header.
template <Path path>
inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows,
                                          const __m256i v) {}

// Narrows all eight 32-bit lanes of `v` to 16 bits and stores them to `dst`.
// The body appears stripped in this copy; per-path specializations
// presumably live in the kernel .cc files. NOTE(review): truncate vs.
// saturate cannot be confirmed from this header.
template <Path path>
inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256i v) {}

// Stores the first `residual_rows` (in [0, 8]) 32-bit lanes of `v` to
// `dst`. The original body was empty, silently dropping the store.
template <Path path>
inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows,
                                 const __m256i v) {
  std::int32_t lanes[8];
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(lanes), v);
  std::memcpy(dst, lanes, sizeof(std::int32_t) * residual_rows);
}

// Stores all eight 32-bit lanes of `v` to the (possibly unaligned) address
// `dst`. The original body was empty, silently dropping the store.
template <Path path>
inline void mm256_storeu_epi32(std::int32_t* dst, const __m256i v) {
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v);
}

// Transpose a 8x8 matrix of floats, held as eight row vectors. The original
// body was empty, silently leaving the rows untransposed. Restored to the
// canonical AVX 8x8 transpose: interleave pairs of rows (unpack), then pairs
// of 64-bit groups (shuffle), then swap 128-bit halves (permute2f128).
template <Path path>
void mm256_transpose8x8_ps(__m256* v0, __m256* v1, __m256* v2, __m256* v3,
                           __m256* v4, __m256* v5, __m256* v6, __m256* v7) {
  __m256 t0 = _mm256_unpacklo_ps(*v0, *v1);
  __m256 t1 = _mm256_unpackhi_ps(*v0, *v1);
  __m256 t2 = _mm256_unpacklo_ps(*v2, *v3);
  __m256 t3 = _mm256_unpackhi_ps(*v2, *v3);
  __m256 t4 = _mm256_unpacklo_ps(*v4, *v5);
  __m256 t5 = _mm256_unpackhi_ps(*v4, *v5);
  __m256 t6 = _mm256_unpacklo_ps(*v6, *v7);
  __m256 t7 = _mm256_unpackhi_ps(*v6, *v7);
  __m256 u0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 u1 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3, 2, 3, 2));
  __m256 u2 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 u3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3, 2, 3, 2));
  __m256 u4 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 u5 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(3, 2, 3, 2));
  __m256 u6 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(1, 0, 1, 0));
  __m256 u7 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(3, 2, 3, 2));
  // 0x20 keeps the low 128-bit halves; 0x31 keeps the high halves.
  *v0 = _mm256_permute2f128_ps(u0, u4, 0x20);
  *v1 = _mm256_permute2f128_ps(u1, u5, 0x20);
  *v2 = _mm256_permute2f128_ps(u2, u6, 0x20);
  *v3 = _mm256_permute2f128_ps(u3, u7, 0x20);
  *v4 = _mm256_permute2f128_ps(u0, u4, 0x31);
  *v5 = _mm256_permute2f128_ps(u1, u5, 0x31);
  *v6 = _mm256_permute2f128_ps(u2, u6, 0x31);
  *v7 = _mm256_permute2f128_ps(u3, u7, 0x31);
}

// Transpose a 8x8 matrix of int32's.
template <Path path>
void mm256_transpose8x8_epi32(__m256i* v0, __m256i* v1, __m256i* v2,
                              __m256i* v3, __m256i* v4, __m256i* v5,
                              __m256i* v6, __m256i* v7) {}

}  // namespace intrin_utils
}  // namespace

// Float kernel implementation shared by the AVX and AVX2+FMA paths,
// parameterized on `path` so each path's translation unit instantiates its
// own copy (using the path-specialized intrin_utils helpers above).
// NOTE(review): the body appears stripped in this copy of the header --
// upstream contains the full 8x8 register-tiled implementation here.
// TODO(review): restore from upstream.
template <Path path>
inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) {}

// Single-column (matrix*vector) variant of KernelFloatAvxCommon, shared by
// the AVX and AVX2+FMA paths. NOTE(review): the body appears stripped in
// this copy of the header -- upstream contains the full implementation.
// TODO(review): restore from upstream.
template <Path path>
inline void KernelFloatAvxCommonSingleCol(
    const KernelParamsFloat<8, 8>& params) {}
}  // namespace ruy
#endif  //  (RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM)

#endif  // RUY_RUY_KERNEL_X86_H_