#ifndef RUY_RUY_KERNEL_X86_H_
#define RUY_RUY_KERNEL_X86_H_
#include <cstdint>
#include <cstring>
#include "ruy/kernel_common.h"
#include "ruy/mat.h"
#include "ruy/mul_params.h"
#include "ruy/opt_set.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/tune.h"
namespace ruy {
#if RUY_PLATFORM_X86
RUY_INHERIT_KERNEL(…)
RUY_INHERIT_KERNEL(…)
RUY_INHERIT_KERNEL(…)
// 8-bit AVX-512 kernel entry points. Only declared here; the definitions live
// outside this header (presumably in a translation unit built with AVX-512
// code generation enabled — confirm). Both take a KernelParams8bit<16, 16>;
// the <16, 16> arguments presumably fix the kernel's block dimensions —
// verify against kernel_common.h.
void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
// NOTE(review): by its name, this is presumably the variant used when the
// destination has a single column — confirm against the definition.
void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params);
Kernel<Path::kAvx512, std::int8_t, std::int8_t, std::int32_t, DstScalar>;
Kernel<Path::kAvx512, std::int8_t, std::int16_t, std::int32_t, DstScalar>;
// Float AVX-512 kernel entry points. Only declared here; the definitions live
// outside this header. Both take a KernelParamsFloat<16, 16>; the <16, 16>
// arguments presumably fix the kernel's block dimensions — verify against
// kernel_common.h.
void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params);
// NOTE(review): by its name, this is presumably the variant used when the
// destination has a single column — confirm against the definition.
void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params);
template <>
struct Kernel<Path::kAvx512, float, float, float, float> { … };
// 8-bit AVX2 kernel entry points. Only declared here; the definitions live
// outside this header (presumably in a translation unit built with AVX2/FMA
// code generation enabled — confirm). Both take a KernelParams8bit<8, 8>; the
// <8, 8> arguments presumably fix the kernel's block dimensions — verify
// against kernel_common.h.
void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params);
// NOTE(review): by its name, this is presumably the variant used when the
// destination has a single column — confirm against the definition.
void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params);
Kernel<Path::kAvx2Fma, std::int8_t, std::int8_t, std::int32_t, DstScalar>;
Kernel<Path::kAvx2Fma, std::int8_t, std::int16_t, std::int32_t, DstScalar>;
// Float AVX2 kernel entry points. Only declared here; the definitions live
// outside this header. Both take a KernelParamsFloat<8, 8>; the <8, 8>
// arguments presumably fix the kernel's block dimensions — verify against
// kernel_common.h.
void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params);
// NOTE(review): by its name, this is presumably the variant used when the
// destination has a single column — confirm against the definition.
void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params);
template <>
struct Kernel<Path::kAvx2Fma, float, float, float, float> { … };
// Float AVX (non-AVX2) kernel entry points. Only declared here; the
// definitions live outside this header. They take the same
// KernelParamsFloat<8, 8> as the AVX2 float kernels declared above.
void KernelFloatAvx(const KernelParamsFloat<8, 8>& params);
// NOTE(review): by its name, this is presumably the variant used when the
// destination has a single column — confirm against the definition.
void KernelFloatAvxSingleCol(const KernelParamsFloat<8, 8>& params);
template <>
struct Kernel<Path::kAvx, float, float, float, float> { … };
// 8-bit AVX (non-AVX2) kernel entry points. Only declared here; the
// definitions live outside this header. They take the same
// KernelParams8bit<8, 8> as the AVX2 8-bit kernels declared above.
void Kernel8bitAvx(const KernelParams8bit<8, 8>& params);
// NOTE(review): by its name, this is presumably the variant used when the
// destination has a single column — confirm against the definition.
void Kernel8bitAvxSingleCol(const KernelParams8bit<8, 8>& params);
Kernel<Path::kAvx, std::int8_t, std::int8_t, std::int32_t, DstScalar>;
#endif
}
#if ((RUY_PLATFORM_AVX || RUY_PLATFORM_AVX2_FMA) && RUY_OPT(ASM))
#include <immintrin.h>
namespace ruy {
namespace {
namespace intrin_utils {
template <Path path>
inline float mm256_get1_ps(const __m256 a, int i) { … }
template <Path path>
inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) { … }
template <Path path>
inline __m256 MulAdd(const __m256&, const __m256&, const __m256&) { … }
template <Path path>
inline __m256i mm256_shuffle_epi8(const __m256i&, const __m256i&) { … }
template <Path path>
inline void mm_storeu_si16(void* dst, __m128i v) { … }
template <Path path>
inline void mm_storeu_si32(void* dst, __m128i v) { … }
template <Path path>
inline __m128i mm_loadu_si32(const void* src) { … }
template <Path path>
inline __m128i mm256_extracti128_si256(const __m256i&, const int) { … }
template <Path path>
inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows,
const __m256i v) { … }
template <Path path>
inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256i v) { … }
template <Path path>
inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows,
const __m256i v) { … }
template <Path path>
inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256i v) { … }
template <Path path>
inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows,
const __m256i v) { … }
template <Path path>
inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256i v) { … }
template <Path path>
inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows,
const __m256i v) { … }
template <Path path>
inline void mm256_storeu_epi32(std::int32_t* dst, const __m256i v) { … }
template <Path path>
void mm256_transpose8x8_ps(__m256* v0, __m256* v1, __m256* v2, __m256* v3,
__m256* v4, __m256* v5, __m256* v6, __m256* v7) { … }
template <Path path>
void mm256_transpose8x8_epi32(__m256i* v0, __m256i* v1, __m256i* v2,
__m256i* v3, __m256i* v4, __m256i* v5,
__m256i* v6, __m256i* v7) { … }
}
}
template <Path path>
inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) { … }
template <Path path>
inline void KernelFloatAvxCommonSingleCol(
const KernelParamsFloat<8, 8>& params) { … }
}
#endif
#endif