#include <cstdint>
#include <cstring>
#include "ruy/check_macros.h"
#include "ruy/opt_set.h"
#include "ruy/pack_x86.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
#if RUY_PLATFORM_AVX2_FMA && RUY_OPT(INTRINSICS)
#include <immintrin.h>
#endif
namespace ruy {
#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM))
// Fallback definition of Pack8bitColMajorForAvx2, compiled only when
// RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) is false (see the enclosing #if).
// Every parameter is unnamed and ignored; the function exists so the symbol
// is still defined in builds without the AVX2/FMA path.
// NOTE(review): presumably runtime path selection prevents callers from ever
// reaching this stub - confirm against the dispatch code.
void Pack8bitColMajorForAvx2(const std::int8_t*, std::int8_t,
const std::int8_t*, int, int, int, std::int8_t*,
std::int32_t*) {
// The check is handed a constant-false condition, so getting here is treated
// as a programming error (presumably only in debug builds - confirm the
// definition of RUY_DCHECK in ruy/check_macros.h).
RUY_DCHECK(false);
}
// Fallback definition of PackFloatColMajorForAvx2, compiled only when
// RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) is false (see the enclosing #if).
// Every parameter is unnamed and ignored.
// NOTE(review): presumably unreachable at runtime because the AVX2/FMA path is
// never selected in such builds - confirm against the dispatch code.
void PackFloatColMajorForAvx2(const float*, const float*, int, int, int,
float*) {
// Constant-false check: reaching this stub is treated as a programming error.
RUY_DCHECK(false);
}
// Fallback definition of Pack8bitRowMajorForAvx2, compiled only when
// RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) is false (see the enclosing #if).
// Every parameter is unnamed and ignored.
// NOTE(review): presumably unreachable at runtime because the AVX2/FMA path is
// never selected in such builds - confirm against the dispatch code.
void Pack8bitRowMajorForAvx2(const std::uint8_t*, int, int, std::int8_t*, int,
int, int, int, int, int, int, std::int32_t*) {
// Constant-false check: reaching this stub is treated as a programming error.
RUY_DCHECK(false);
}
#else
// NOTE(review): the two lines below are bare identifiers followed by ';',
// which is not a valid declaration at namespace scope. They look like the
// remains of type-alias declarations (e.g. `using PackImpl8bitAvx2 = ...;`)
// whose right-hand sides have been lost - TODO restore the full aliases from
// the project history before building this branch.
PackImpl8bitAvx2;
PackImplFloatAvx2;
namespace {
// File-local helper (it lives in the anonymous namespace). It takes the same
// eight arguments as Pack8bitColMajorForAvx2, in the same order, plus one
// extra non-const std::int8_t* named trailing_buf.
//
// Parameters, as far as the signature shows:
//   src_ptr            - read-only 8-bit source data.
//   input_xor          - 8-bit value; presumably XOR-ed into each source byte
//                        to convert between signed/unsigned - TODO confirm.
//   zerobuf            - read-only 8-bit buffer; presumably stands in for
//                        columns past the end of the source - TODO confirm.
//   src_stride         - integer stride of the source (units: TODO confirm).
//   remaining_src_cols - number of source columns still to be packed.
//   src_rows           - number of source rows.
//   packed_ptr         - writable 8-bit destination for packed data.
//   sums_ptr           - writable 32-bit buffer; NOTE(review): may be allowed
//                        to be null - confirm in the body.
//   trailing_buf       - writable 8-bit scratch buffer; presumably holds a
//                        partial final block of rows - TODO confirm.
inline void Pack8bitColMajorForAvx2Packer(
const std::int8_t* src_ptr, std::int8_t input_xor,
const std::int8_t* zerobuf, int src_stride, int remaining_src_cols,
int src_rows, std::int8_t* packed_ptr, std::int32_t* sums_ptr,
std::int8_t* trailing_buf) { … }
// Explicit specialization of the CompareGreaterThan function template for
// Path::kAvx2Fma. It takes two 256-bit integer vectors by const reference and
// returns a 256-bit integer vector. The primary template is declared
// elsewhere (presumably in ruy/pack_x86.h - TODO confirm).
// NOTE(review): lane width and the encoding of the result mask depend on the
// intrinsic used in the body - confirm before relying on either.
// NOTE(review): __m256i needs <immintrin.h>. This file includes that header
// only under RUY_OPT(INTRINSICS), while this branch is guarded by
// RUY_OPT(ASM); confirm the two options cannot diverge or that another
// included header provides the type.
template <>
inline __m256i CompareGreaterThan<Path::kAvx2Fma>(const __m256i& a,
const __m256i& b) { … }
}
// AVX2/FMA definition of Pack8bitColMajorForAvx2, compiled when
// RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) holds. Its parameter types match the
// fallback stub in the other preprocessor branch.
//
// Parameters, as far as the signature shows:
//   src_ptr            - read-only 8-bit source data.
//   input_xor          - 8-bit value; presumably XOR-ed into each source byte
//                        - TODO confirm.
//   zerobuf            - read-only 8-bit buffer; presumably used in place of
//                        out-of-range source columns - TODO confirm.
//   src_stride         - integer stride of the source (units: TODO confirm).
//   remaining_src_cols - number of source columns still to be packed.
//   src_rows           - number of source rows.
//   packed_ptr         - writable 8-bit destination for packed data.
//   sums_ptr           - writable 32-bit buffer; presumably per-column sums,
//                        possibly nullable - TODO confirm.
// NOTE(review): presumably delegates to Pack8bitColMajorForAvx2Packer, which
// takes these same arguments plus a trailing buffer - confirm in the body.
void Pack8bitColMajorForAvx2(const std::int8_t* src_ptr, std::int8_t input_xor,
const std::int8_t* zerobuf, int src_stride,
int remaining_src_cols, int src_rows,
std::int8_t* packed_ptr, std::int32_t* sums_ptr) { … }
// AVX2/FMA definition of PackFloatColMajorForAvx2, compiled when
// RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) holds. Its parameter types match the
// fallback stub in the other preprocessor branch.
//
// Parameters, as far as the signature shows:
//   src_ptr            - read-only float source data.
//   zerobuf            - read-only float buffer; presumably used in place of
//                        out-of-range source columns - TODO confirm.
//   src_stride         - integer stride of the source (units: TODO confirm).
//   remaining_src_cols - number of source columns still to be packed.
//   src_rows           - number of source rows.
//   packed_ptr         - writable float destination for packed data.
// Unlike the 8-bit column-major variant, there is no sums output parameter.
void PackFloatColMajorForAvx2(const float* src_ptr, const float* zerobuf,
int src_stride, int remaining_src_cols,
int src_rows, float* packed_ptr) { … }
// AVX2/FMA definition of Pack8bitRowMajorForAvx2, compiled when
// RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) holds. Its parameter types match the
// fallback stub in the other preprocessor branch.
//
// Differences from the column-major 8-bit entry point that are visible in the
// signature: the source pointer is unsigned (std::uint8_t) while the packed
// destination is signed (std::int8_t), and input_xor is a plain int here.
//
// Parameters, as far as the signature shows:
//   src_ptr        - read-only unsigned 8-bit source data.
//   src_stride     - integer stride of the source (units: TODO confirm).
//   src_zero_point - integer zero point of the source; presumably used to
//                    fill padding - TODO confirm.
//   packed_ptr     - writable signed 8-bit destination for packed data.
//   packed_stride  - integer stride of the packed destination.
//   start_col,
//   end_col        - column range to process; NOTE(review): presumably
//                    half-open [start_col, end_col) - confirm in the body.
//   src_cols       - total number of source columns.
//   block_row      - integer row position of the block being packed
//                    (exact meaning: TODO confirm).
//   src_rows       - total number of source rows.
//   input_xor      - integer value; presumably XOR-ed into each source byte
//                    - TODO confirm.
//   sums           - writable 32-bit buffer; presumably accumulates
//                    per-column sums - TODO confirm.
void Pack8bitRowMajorForAvx2(const std::uint8_t* src_ptr, int src_stride,
int src_zero_point, std::int8_t* packed_ptr,
int packed_stride, int start_col, int end_col,
int src_cols, int block_row, int src_rows,
int input_xor, std::int32_t* sums) { … }
#endif
}