#include <algorithm>
#include <cstdint>
#include <cstring>
#include "ruy/check_macros.h"
#include "ruy/opt_set.h"
#include "ruy/pack_x86.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
#if RUY_PLATFORM_AVX512 && RUY_OPT(INTRINSICS)
#include <immintrin.h>
#endif  // RUY_PLATFORM_AVX512 && RUY_OPT(INTRINSICS)
namespace ruy {
#if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM))
void Pack8bitColMajorForAvx512(const std::int8_t*, std::int8_t,
                               const std::int8_t*, int, int, int, std::int8_t*,
                               std::int32_t*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Pack16bitColMajorForAvx512(const std::int16_t*, const std::int16_t*, int,
                                int, int, std::int16_t*, std::int32_t*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void PackFloatColMajorForAvx512(const float*, const float*, int, int, int,
                                float*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}

void Pack8bitRowMajorForAvx512(const std::uint8_t*, int, int, std::int8_t*, int,
                               int, int, int, int, int, int, std::int32_t*) {
  // CPU-ID-based checks should disable the path that would reach this point.
  RUY_DCHECK(false);
}
#else
using PackImpl8bitAvx512 =
PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
std::int8_t, std::int8_t, std::int32_t, Order::kColMajor>;
using PackImpl16bitAvx512 =
PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
std::int16_t, std::int16_t, std::int32_t, Order::kColMajor>;
namespace {
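
// Fills packed 4x16 blocks with the packed zero-point value for all fully
// chunked source rows. Callers pass packed_ptr offset to the second
// kHalfLayoutCols-wide half of each block, so this fills the 8 columns that
// HalfPack*Avx512 does not write when there is no second half-block of
// source columns. Trailing partial chunks are handled via trailing_buf.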
template <typename PackImplAvx512, typename Scalar>
inline void ZeroHalfAvx512(int src_rows, Scalar packed_zero_point,
Scalar* packed_ptr, int chunked_row_mask) {
using Layout = typename PackImplAvx512::Layout;
static constexpr int kHalfLayoutCols =
PackImplAvx512::kHalfLayoutCols;
RUY_DCHECK_EQ(kHalfLayoutCols, 8);
RUY_DCHECK_EQ(Layout::kCols, 16);
RUY_DCHECK_EQ(Layout::kRows, 4);
const int non_trailing_blocks = (src_rows & ~chunked_row_mask) >> 2;
for (int k = 0; k < non_trailing_blocks; ++k) {
for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) {
packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point;
}
}
}
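
// Concatenates two unaligned 256-bit loads into one 512-bit register:
// addr_lo supplies the lower half, addr_hi the upper half.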
template <typename Scalar>
inline __m512i LoaduTwo(const Scalar* addr_lo, const Scalar* addr_hi) {
__m512i lower_filled = _mm512_castsi256_si512(
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_lo)));
return _mm512_inserti32x8(
lower_filled,
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(addr_hi)), 1);
}
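
// Masked variants of LoaduTwo: lanes whose bit in row_mask is clear take
// their value from default_value_v (the duplicated zero point) instead of
// memory, so partial trailing chunks never read past the end of a column.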
inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v,
const std::int8_t* addr_lo,
const std::int8_t* addr_hi) {
const __m512i lower_filled = _mm512_castsi256_si512(
_mm256_mask_loadu_epi8(default_value_v, row_mask, addr_lo));
return _mm512_inserti32x8(
lower_filled, _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_hi),
1);
}
inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v,
const std::int16_t* addr_lo,
const std::int16_t* addr_hi) {
const __m512i lower_filled = _mm512_castsi256_si512(
_mm256_mask_loadu_epi16(default_value_v, row_mask, addr_lo));
return _mm512_inserti32x8(
lower_filled, _mm256_mask_loadu_epi16(default_value_v, row_mask, addr_hi),
1);
}
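
// Packs an 8-column half-block of column-major int8 source data. Each full
// iteration transposes a 32x8 source tile (kNumRowChunks = 8 chunks of
// Layout::kRows = 4) so that every 4 consecutive depth values of a source
// column become contiguous, and stores them as the first 8 columns of each
// 4x16 packed block; the caller packs the other 8 columns with a second call
// offset by half a block. input_xor is 0x80 when converting uint8 sources to
// int8, else 0. If sums_ptr is non-null, per-column sums of the packed
// values are accumulated. A final partial chunk is padded with the zero
// point and written to trailing_buf instead of packed_ptr.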
inline void HalfPack8bitAvx512(const std::int8_t* src_ptr,
std::int8_t input_xor,
const std::int8_t* zerobuf, int src_stride,
int remaining_src_cols, int src_rows,
std::int8_t* packed_ptr, std::int32_t* sums_ptr,
std::int8_t* trailing_buf) {
using Layout = PackImpl8bitAvx512::Layout;
RUY_DCHECK_EQ(Layout::kCols, 16);
RUY_DCHECK_EQ(Layout::kRows, 4);
constexpr int kNumRowChunks = 8;
constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
const std::int8_t* src_ptr0 = src_ptr;
const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
std::int64_t src_inc0 = kNumChunkedSrcRows;
std::int64_t src_inc1 = kNumChunkedSrcRows;
std::int64_t src_inc2 = kNumChunkedSrcRows;
std::int64_t src_inc3 = kNumChunkedSrcRows;
std::int64_t src_inc4 = kNumChunkedSrcRows;
std::int64_t src_inc5 = kNumChunkedSrcRows;
std::int64_t src_inc6 = kNumChunkedSrcRows;
std::int64_t src_inc7 = kNumChunkedSrcRows;
if (remaining_src_cols < 8) {
if (remaining_src_cols <= 0) {
src_ptr0 = zerobuf;
src_inc0 = 0;
}
if (remaining_src_cols <= 1) {
src_ptr1 = zerobuf;
src_inc1 = 0;
}
if (remaining_src_cols <= 2) {
src_ptr2 = zerobuf;
src_inc2 = 0;
}
if (remaining_src_cols <= 3) {
src_ptr3 = zerobuf;
src_inc3 = 0;
}
if (remaining_src_cols <= 4) {
src_ptr4 = zerobuf;
src_inc4 = 0;
}
if (remaining_src_cols <= 5) {
src_ptr5 = zerobuf;
src_inc5 = 0;
}
if (remaining_src_cols <= 6) {
src_ptr6 = zerobuf;
src_inc6 = 0;
}
src_ptr7 = zerobuf;
src_inc7 = 0;
}
const std::int8_t zero_point = zerobuf[0];
if (sums_ptr) {
for (int i = 0; i < 8; ++i) {
sums_ptr[i] = 0;
}
}
std::int32_t sums_adjustment = 0;
const __m512i ones_16bit = _mm512_set1_epi16(1);
__m512i sums_8x2_32bit = _mm512_set1_epi32(0);
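  // Each outer iteration consumes two 32-row half-chunks. A half-chunk is
  // either fully present (packed straight to packed_ptr), partially present
  // (masked loads, packed to trailing_buf), or wholly past src_rows
  // (skipped, though packed_ptr still advances).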
for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) {
for (int m = 0; m < 2; ++m) {
const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows;
if (available_src_rows >= kNumChunkedSrcRows) {
if (sums_ptr) {
__m512i t0, t1, t2, t3;
__m512i r0, r1, r2, r3;
const __m512i input_xor_v = _mm512_set1_epi8(input_xor);
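          // t0..t3 hold 32 rows of column pairs (0,4), (1,5), (2,6), (3,7).
          // The unpack/shuffle network below regroups 4-byte (4-row) chunks
          // so that each 256-bit half of r0..r3 ends up holding one 4-row
          // chunk for all 8 source columns, i.e. half of a 4x16 packed
          // block. The same network is repeated in the branches below.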
t0 = LoaduTwo(src_ptr0, src_ptr4);
t1 = LoaduTwo(src_ptr1, src_ptr5);
t2 = LoaduTwo(src_ptr2, src_ptr6);
t3 = LoaduTwo(src_ptr3, src_ptr7);
r0 = _mm512_unpacklo_epi32(t0, t1);
r2 = _mm512_unpackhi_epi32(t0, t1);
r1 = _mm512_unpacklo_epi32(t2, t3);
r3 = _mm512_unpackhi_epi32(t2, t3);
t0 = _mm512_unpacklo_epi64(r0, r1);
t2 = _mm512_unpackhi_epi64(r0, r1);
t1 = _mm512_unpacklo_epi64(r2, r3);
t3 = _mm512_unpackhi_epi64(r2, r3);
r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
r0 = _mm512_xor_si512(r0, input_xor_v);
r1 = _mm512_xor_si512(r1, input_xor_v);
r2 = _mm512_xor_si512(r2, input_xor_v);
r3 = _mm512_xor_si512(r3, input_xor_v);
const __m256i r0_0 = _mm512_castsi512_si256(r0);
const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
const __m256i r1_0 = _mm512_castsi512_si256(r1);
const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
const __m256i r2_0 = _mm512_castsi512_si256(r2);
const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
const __m256i r3_0 = _mm512_castsi512_si256(r3);
const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);
__m512i sums_8x4_16bit;
sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0);
sums_8x4_16bit =
_mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1));
sums_8x4_16bit =
_mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0));
sums_8x4_16bit =
_mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1));
sums_8x4_16bit =
_mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0));
sums_8x4_16bit =
_mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1));
sums_8x4_16bit =
_mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0));
sums_8x4_16bit =
_mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1));
const __m512i sums_8x2_32bit_new =
_mm512_madd_epi16(sums_8x4_16bit, ones_16bit);
sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(packed_ptr + 0 * 16 * 4), r0_0);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(packed_ptr + 2 * 16 * 4), r0_1);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(packed_ptr + 4 * 16 * 4), r1_0);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(packed_ptr + 6 * 16 * 4), r1_1);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(packed_ptr + 1 * 16 * 4), r2_0);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(packed_ptr + 3 * 16 * 4), r2_1);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(packed_ptr + 5 * 16 * 4), r3_0);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(packed_ptr + 7 * 16 * 4), r3_1);
} else {
__m512i t0, t1, t2, t3;
__m512i r0, r1, r2, r3;
const __m512i input_xor_v = _mm512_set1_epi8(input_xor);
t0 = LoaduTwo(src_ptr0, src_ptr4);
t1 = LoaduTwo(src_ptr1, src_ptr5);
t2 = LoaduTwo(src_ptr2, src_ptr6);
t3 = LoaduTwo(src_ptr3, src_ptr7);
r0 = _mm512_unpacklo_epi32(t0, t1);
r2 = _mm512_unpackhi_epi32(t0, t1);
r1 = _mm512_unpacklo_epi32(t2, t3);
r3 = _mm512_unpackhi_epi32(t2, t3);
t0 = _mm512_unpacklo_epi64(r0, r1);
t2 = _mm512_unpackhi_epi64(r0, r1);
t1 = _mm512_unpacklo_epi64(r2, r3);
t3 = _mm512_unpackhi_epi64(r2, r3);
r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
r0 = _mm512_xor_si512(r0, input_xor_v);
r1 = _mm512_xor_si512(r1, input_xor_v);
r2 = _mm512_xor_si512(r2, input_xor_v);
r3 = _mm512_xor_si512(r3, input_xor_v);
const __m256i r0_0 = _mm512_castsi512_si256(r0);
const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
const __m256i r1_0 = _mm512_castsi512_si256(r1);
const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
const __m256i r2_0 = _mm512_castsi512_si256(r2);
const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
const __m256i r3_0 = _mm512_castsi512_si256(r3);
const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(packed_ptr + 0 * 16 * 4), r0_0);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(packed_ptr + 2 * 16 * 4), r0_1);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(packed_ptr + 4 * 16 * 4), r1_0);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(packed_ptr + 6 * 16 * 4), r1_1);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(packed_ptr + 1 * 16 * 4), r2_0);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(packed_ptr + 3 * 16 * 4), r2_1);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(packed_ptr + 5 * 16 * 4), r3_0);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(packed_ptr + 7 * 16 * 4), r3_1);
}
} else if (available_src_rows > 0) {
RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows);
const __mmask32 row_mask =
(static_cast<std::uint64_t>(1) << available_src_rows) - 1;
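        // The masked loads below pad missing rows with the zero point, which
        // also flows into the sums. Pre-subtract the contribution of padding
        // rows that lie beyond the 4-row-aligned destination end, so the
        // final sums cover exactly (src_rows + 3) & ~3 rows.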
sums_adjustment += -(zero_point ^ input_xor) * 4 *
(8 - ((available_src_rows + 3) >> 2));
__m512i t0, t1, t2, t3;
__m512i r0, r1, r2, r3;
const __m512i input_xor_v = _mm512_set1_epi8(input_xor);
const __m256i zero_point_v = _mm256_set1_epi8(zero_point);
t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4);
t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5);
t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6);
t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7);
r0 = _mm512_unpacklo_epi32(t0, t1);
r2 = _mm512_unpackhi_epi32(t0, t1);
r1 = _mm512_unpacklo_epi32(t2, t3);
r3 = _mm512_unpackhi_epi32(t2, t3);
t0 = _mm512_unpacklo_epi64(r0, r1);
t2 = _mm512_unpackhi_epi64(r0, r1);
t1 = _mm512_unpacklo_epi64(r2, r3);
t3 = _mm512_unpackhi_epi64(r2, r3);
r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
r2 = _mm512_shuffle_i32x4(t2, t3, 0x88);
r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
r0 = _mm512_xor_si512(r0, input_xor_v);
r1 = _mm512_xor_si512(r1, input_xor_v);
r2 = _mm512_xor_si512(r2, input_xor_v);
r3 = _mm512_xor_si512(r3, input_xor_v);
const __m256i r0_0 = _mm512_castsi512_si256(r0);
const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1);
const __m256i r1_0 = _mm512_castsi512_si256(r1);
const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1);
const __m256i r2_0 = _mm512_castsi512_si256(r2);
const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1);
const __m256i r3_0 = _mm512_castsi512_si256(r3);
const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1);
__m512i sums_8x4_16bit;
sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0);
sums_8x4_16bit =
_mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1));
sums_8x4_16bit =
_mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0));
sums_8x4_16bit =
_mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1));
sums_8x4_16bit =
_mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0));
sums_8x4_16bit =
_mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1));
sums_8x4_16bit =
_mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0));
sums_8x4_16bit =
_mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1));
const __m512i sums_8x2_32bit_new =
_mm512_madd_epi16(sums_8x4_16bit, ones_16bit);
sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(trailing_buf + 0 * 16 * 4), r0_0);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(trailing_buf + 2 * 16 * 4), r0_1);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(trailing_buf + 4 * 16 * 4), r1_0);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(trailing_buf + 6 * 16 * 4), r1_1);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(trailing_buf + 1 * 16 * 4), r2_0);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(trailing_buf + 3 * 16 * 4), r2_1);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(trailing_buf + 5 * 16 * 4), r3_0);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(trailing_buf + 7 * 16 * 4), r3_1);
}
packed_ptr += 16 * kNumChunkedSrcRows;
src_ptr0 += src_inc0;
src_ptr1 += src_inc1;
src_ptr2 += src_inc2;
src_ptr3 += src_inc3;
src_ptr4 += src_inc4;
src_ptr5 += src_inc5;
src_ptr6 += src_inc6;
src_ptr7 += src_inc7;
}
}
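  // Reduce the interleaved pair accumulator to 8 per-column sums: the
  // permutation gathers the even-indexed 32-bit lanes into the low 256 bits
  // and the odd-indexed lanes into the high 256 bits, which are then added
  // together along with the padding adjustment.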
if (sums_ptr) {
const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);
__m256i sums =
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
const __m512i idx =
_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
const __m512i sums_2x8_32bit =
_mm512_permutexvar_epi32(idx, sums_8x2_32bit);
sums = _mm256_add_epi32(sums, sums_adjustment_v);
sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit));
sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1));
_mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
}
}
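
// 16-bit analogue of HalfPack8bitAvx512: 16-row chunks (kNumRowChunks = 4),
// no input_xor, and the transpose operates on 64-bit (4 x int16) groups via
// unpacks, blends and 64-bit permutes rather than 32-bit unpacks.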
inline void HalfPack16bitAvx512(const std::int16_t* src_ptr,
const std::int16_t* zerobuf, int src_stride,
int remaining_src_cols, int src_rows,
std::int16_t* packed_ptr,
std::int32_t* sums_ptr,
std::int16_t* trailing_buf) {
using Layout = PackImpl16bitAvx512::Layout;
RUY_DCHECK_EQ(Layout::kCols, 16);
RUY_DCHECK_EQ(Layout::kRows, 4);
constexpr int kNumRowChunks = 4;
constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
const std::int16_t* src_ptr0 = src_ptr;
const std::int16_t* src_ptr1 = src_ptr0 + src_stride;
const std::int16_t* src_ptr2 = src_ptr1 + src_stride;
const std::int16_t* src_ptr3 = src_ptr2 + src_stride;
const std::int16_t* src_ptr4 = src_ptr3 + src_stride;
const std::int16_t* src_ptr5 = src_ptr4 + src_stride;
const std::int16_t* src_ptr6 = src_ptr5 + src_stride;
const std::int16_t* src_ptr7 = src_ptr6 + src_stride;
std::int64_t src_inc0 = kNumChunkedSrcRows;
std::int64_t src_inc1 = kNumChunkedSrcRows;
std::int64_t src_inc2 = kNumChunkedSrcRows;
std::int64_t src_inc3 = kNumChunkedSrcRows;
std::int64_t src_inc4 = kNumChunkedSrcRows;
std::int64_t src_inc5 = kNumChunkedSrcRows;
std::int64_t src_inc6 = kNumChunkedSrcRows;
std::int64_t src_inc7 = kNumChunkedSrcRows;
if (remaining_src_cols < 8) {
if (remaining_src_cols <= 0) {
src_ptr0 = zerobuf;
src_inc0 = 0;
}
if (remaining_src_cols <= 1) {
src_ptr1 = zerobuf;
src_inc1 = 0;
}
if (remaining_src_cols <= 2) {
src_ptr2 = zerobuf;
src_inc2 = 0;
}
if (remaining_src_cols <= 3) {
src_ptr3 = zerobuf;
src_inc3 = 0;
}
if (remaining_src_cols <= 4) {
src_ptr4 = zerobuf;
src_inc4 = 0;
}
if (remaining_src_cols <= 5) {
src_ptr5 = zerobuf;
src_inc5 = 0;
}
if (remaining_src_cols <= 6) {
src_ptr6 = zerobuf;
src_inc6 = 0;
}
src_ptr7 = zerobuf;
src_inc7 = 0;
}
const std::int16_t zero_point = zerobuf[0];
if (sums_ptr) {
for (int i = 0; i < 8; ++i) {
sums_ptr[i] = 0;
}
}
std::int32_t sums_adjustment = 0;
const __m512i ones_16bit = _mm512_set1_epi16(1);
__m512i sums_8x2_32bit = _mm512_set1_epi32(0);
for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) {
for (int m = 0; m < 2; ++m) {
const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows;
if (available_src_rows > 0) {
__m512i t0, t1, t2, t3;
__m512i r0, r1, r2, r3;
std::int16_t* dst_ptr = packed_ptr;
if (available_src_rows >= kNumChunkedSrcRows) {
t0 = LoaduTwo(src_ptr0, src_ptr4);
t1 = LoaduTwo(src_ptr1, src_ptr5);
t2 = LoaduTwo(src_ptr2, src_ptr6);
t3 = LoaduTwo(src_ptr3, src_ptr7);
} else {
RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows);
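          // As in the 8-bit path: pre-subtract the sums contribution of
          // zero-point padding rows beyond the 4-row-aligned destination end.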
          sums_adjustment +=
              -(zero_point) * 4 * (4 - ((available_src_rows + 3) >> 2));
const __m256i zero_point_v = _mm256_set1_epi16(zero_point);
const __mmask32 row_mask =
(static_cast<std::uint64_t>(1) << available_src_rows) - 1;
t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4);
t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5);
t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6);
t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7);
dst_ptr = trailing_buf;
}
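        // Transpose 64-bit (4-row) groups: the unpacks pair columns, and the
        // permute/blend pairs swap quadrants so that each 512-bit t register
        // ends up holding one 4-row chunk for all 8 source columns.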
r0 = _mm512_unpacklo_epi64(t0, t1);
r2 = _mm512_unpackhi_epi64(t0, t1);
r1 = _mm512_unpacklo_epi64(t2, t3);
r3 = _mm512_unpackhi_epi64(t2, t3);
r1 = _mm512_permutex_epi64(r1, 0x4e);
r3 = _mm512_permutex_epi64(r3, 0x4e);
t0 = _mm512_mask_blend_epi64(0xcc, r0, r1);
t1 = _mm512_mask_blend_epi64(0x33, r0, r1);
t2 = _mm512_mask_blend_epi64(0xcc, r2, r3);
t3 = _mm512_mask_blend_epi64(0x33, r2, r3);
t1 = _mm512_permutex_epi64(t1, 0x4e);
t3 = _mm512_permutex_epi64(t3, 0x4e);
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 0 * 16 * 4),
t0);
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 2 * 16 * 4),
t1);
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 1 * 16 * 4),
t2);
_mm512_storeu_si512(reinterpret_cast<__m512i*>(dst_ptr + 3 * 16 * 4),
t3);
if (sums_ptr) {
sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
_mm512_madd_epi16(t0, ones_16bit));
sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
_mm512_madd_epi16(t1, ones_16bit));
sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
_mm512_madd_epi16(t2, ones_16bit));
sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit,
_mm512_madd_epi16(t3, ones_16bit));
}
}
packed_ptr += 16 * kNumChunkedSrcRows;
src_ptr0 += src_inc0;
src_ptr1 += src_inc1;
src_ptr2 += src_inc2;
src_ptr3 += src_inc3;
src_ptr4 += src_inc4;
src_ptr5 += src_inc5;
src_ptr6 += src_inc6;
src_ptr7 += src_inc7;
}
}
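  // Reduce the interleaved pair accumulator exactly as in HalfPack8bitAvx512.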
if (sums_ptr) {
const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);
__m256i sums =
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
const __m512i idx =
_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
const __m512i sums_2x8_32bit =
_mm512_permutexvar_epi32(idx, sums_8x2_32bit);
sums = _mm256_add_epi32(sums, sums_adjustment_v);
sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit));
sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1));
_mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums);
}
}
inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) {
const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo));
return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1);
}
inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo,
const float* addr_hi) {
const __m512 lower_filled =
_mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo));
return _mm512_insertf32x8(lower_filled,
_mm256_maskz_loadu_ps(row_mask, addr_hi), 1);
}
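
// Unpacks 64-bit granules (pairs of floats) from two vectors; together with
// _mm512_unpacklo_ps/_mm512_unpackhi_ps this builds the float transpose.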
inline __m512 Mm512UnpackloPsx2(const __m512 a, const __m512 b) {
return _mm512_castpd_ps(
_mm512_unpacklo_pd(_mm512_castps_pd(a), _mm512_castps_pd(b)));
}
inline __m512 Mm512UnpackhiPsx2(const __m512 a, const __m512 b) {
return _mm512_castpd_ps(
_mm512_unpackhi_pd(_mm512_castps_pd(a), _mm512_castps_pd(b)));
}
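
// Float half-pack: same structure as the 8-bit path, with 8 rows per chunk
// and no zero-point or sums handling. Each full iteration transposes an 8x8
// float tile and stores one 8-column row into each of 8 consecutive 16-wide
// packed rows.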
inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf,
int src_stride, int remaining_src_cols,
int src_rows, float* packed_ptr,
float* trailing_buf) {
const float* src_ptr0 = src_ptr;
const float* src_ptr1 = src_ptr0 + src_stride;
const float* src_ptr2 = src_ptr1 + src_stride;
const float* src_ptr3 = src_ptr2 + src_stride;
const float* src_ptr4 = src_ptr3 + src_stride;
const float* src_ptr5 = src_ptr4 + src_stride;
const float* src_ptr6 = src_ptr5 + src_stride;
const float* src_ptr7 = src_ptr6 + src_stride;
std::int64_t src_inc0 = 8;
std::int64_t src_inc1 = 8;
std::int64_t src_inc2 = 8;
std::int64_t src_inc3 = 8;
std::int64_t src_inc4 = 8;
std::int64_t src_inc5 = 8;
std::int64_t src_inc6 = 8;
std::int64_t src_inc7 = 8;
if (remaining_src_cols < 8) {
if (remaining_src_cols <= 0) {
src_ptr0 = zerobuf;
src_inc0 = 0;
}
if (remaining_src_cols <= 1) {
src_ptr1 = zerobuf;
src_inc1 = 0;
}
if (remaining_src_cols <= 2) {
src_ptr2 = zerobuf;
src_inc2 = 0;
}
if (remaining_src_cols <= 3) {
src_ptr3 = zerobuf;
src_inc3 = 0;
}
if (remaining_src_cols <= 4) {
src_ptr4 = zerobuf;
src_inc4 = 0;
}
if (remaining_src_cols <= 5) {
src_ptr5 = zerobuf;
src_inc5 = 0;
}
if (remaining_src_cols <= 6) {
src_ptr6 = zerobuf;
src_inc6 = 0;
}
src_ptr7 = zerobuf;
src_inc7 = 0;
}
for (int k = 0; k < src_rows; k += 16) {
for (int m = 0; m < 2; ++m) {
const int available_src_rows = src_rows - k - 8 * m;
if (available_src_rows > 7) {
__m512 t0, t1, t2, t3;
__m512 r0, r1, r2, r3;
t0 = LoaduTwo(src_ptr0, src_ptr4);
t1 = LoaduTwo(src_ptr1, src_ptr5);
t2 = LoaduTwo(src_ptr2, src_ptr6);
t3 = LoaduTwo(src_ptr3, src_ptr7);
r0 = _mm512_unpacklo_ps(t0, t1);
r2 = _mm512_unpackhi_ps(t0, t1);
r1 = _mm512_unpacklo_ps(t2, t3);
r3 = _mm512_unpackhi_ps(t2, t3);
t0 = Mm512UnpackloPsx2(r0, r1);
t2 = Mm512UnpackhiPsx2(r0, r1);
t1 = Mm512UnpackloPsx2(r2, r3);
t3 = Mm512UnpackhiPsx2(r2, r3);
r0 = _mm512_shuffle_f32x4(t0, t1, 0x88);
r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd);
r2 = _mm512_shuffle_f32x4(t2, t3, 0x88);
r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd);
_mm256_storeu_ps(packed_ptr + 0 * 16, _mm512_castps512_ps256(r0));
_mm256_storeu_ps(packed_ptr + 2 * 16, _mm512_extractf32x8_ps(r0, 1));
_mm256_storeu_ps(packed_ptr + 4 * 16, _mm512_castps512_ps256(r1));
_mm256_storeu_ps(packed_ptr + 6 * 16, _mm512_extractf32x8_ps(r1, 1));
_mm256_storeu_ps(packed_ptr + 1 * 16, _mm512_castps512_ps256(r2));
_mm256_storeu_ps(packed_ptr + 3 * 16, _mm512_extractf32x8_ps(r2, 1));
_mm256_storeu_ps(packed_ptr + 5 * 16, _mm512_castps512_ps256(r3));
_mm256_storeu_ps(packed_ptr + 7 * 16, _mm512_extractf32x8_ps(r3, 1));
} else if (available_src_rows > 0) {
const __mmask8 row_mask =
(static_cast<std::uint32_t>(1) << available_src_rows) - 1;
__m512 t0, t1, t2, t3;
__m512 r0, r1, r2, r3;
t0 = MaskLoaduTwo(row_mask, src_ptr0, src_ptr4);
t1 = MaskLoaduTwo(row_mask, src_ptr1, src_ptr5);
t2 = MaskLoaduTwo(row_mask, src_ptr2, src_ptr6);
t3 = MaskLoaduTwo(row_mask, src_ptr3, src_ptr7);
r0 = _mm512_unpacklo_ps(t0, t1);
r2 = _mm512_unpackhi_ps(t0, t1);
r1 = _mm512_unpacklo_ps(t2, t3);
r3 = _mm512_unpackhi_ps(t2, t3);
t0 = Mm512UnpackloPsx2(r0, r1);
t2 = Mm512UnpackhiPsx2(r0, r1);
t1 = Mm512UnpackloPsx2(r2, r3);
t3 = Mm512UnpackhiPsx2(r2, r3);
r0 = _mm512_shuffle_f32x4(t0, t1, 0x88);
r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd);
r2 = _mm512_shuffle_f32x4(t2, t3, 0x88);
r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd);
_mm256_storeu_ps(trailing_buf + 0 * 16, _mm512_castps512_ps256(r0));
_mm256_storeu_ps(trailing_buf + 2 * 16, _mm512_extractf32x8_ps(r0, 1));
_mm256_storeu_ps(trailing_buf + 4 * 16, _mm512_castps512_ps256(r1));
_mm256_storeu_ps(trailing_buf + 6 * 16, _mm512_extractf32x8_ps(r1, 1));
_mm256_storeu_ps(trailing_buf + 1 * 16, _mm512_castps512_ps256(r2));
_mm256_storeu_ps(trailing_buf + 3 * 16, _mm512_extractf32x8_ps(r2, 1));
_mm256_storeu_ps(trailing_buf + 5 * 16, _mm512_castps512_ps256(r3));
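        // The eighth store (the high half of r3) is intentionally omitted:
        // trailing_buf holds at most 7 trailing rows, and row 7 of a partial
        // chunk cannot exist since available_src_rows < 8 here.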
}
packed_ptr += 16 * 8;
src_ptr0 += src_inc0;
src_ptr1 += src_inc1;
src_ptr2 += src_inc2;
src_ptr3 += src_inc3;
src_ptr4 += src_inc4;
src_ptr5 += src_inc5;
src_ptr6 += src_inc6;
src_ptr7 += src_inc7;
}
}
}
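
// Zeroes the first 8 of every 16 packed floats starting at the given
// pointer, for all non-trailing rows. Callers pass packed_ptr + 8, so this
// clears the second 8-column half of each packed row when there is no second
// half-block of source columns.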
inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) {
const int non_trailing_rows = src_rows & ~7;
for (int k = 0; k < non_trailing_rows; ++k) {
for (int j = 0; j < 8; ++j) {
packed_ptr[j] = 0.0f;
}
packed_ptr += 16;
}
}
}  // namespace
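
// 8-bit column-major entry point. The 16-column destination block is packed
// as two interleaved 8-column half-blocks by separate HalfPack8bitAvx512
// calls. With at most 8 remaining source columns, the second half is filled
// with the packed zero point and its sums are computed in closed form.
// Trailing rows (src_rows not a multiple of 32) are staged in trailing_buf
// and copied out rounded up to a multiple of 4 rows.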
void Pack8bitColMajorForAvx512(const std::int8_t* src_ptr,
std::int8_t input_xor,
const std::int8_t* zerobuf, int src_stride,
int remaining_src_cols, int src_rows,
std::int8_t* packed_ptr,
std::int32_t* sums_ptr) {
profiler::ScopeLabel label("Pack kAvx512 8bit");
using Layout = PackImpl8bitAvx512::Layout;
constexpr int kHalfBlockOffset = 32;
RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols);
static constexpr int kHalfLayoutCols =
PackImpl8bitAvx512::kHalfLayoutCols;
RUY_DCHECK_EQ(kHalfLayoutCols, 8);
RUY_DCHECK_EQ(Layout::kCols, 16);
RUY_DCHECK_EQ(Layout::kRows, 4);
constexpr int kNumRowChunks = 8;
constexpr int kTrailingBufSize =
kNumRowChunks * Layout::kCols * Layout::kRows;
std::int8_t trailing_buf[kTrailingBufSize];
  std::memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));
constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
std::int32_t* second_sums_ptr =
sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
if (remaining_src_cols > kHalfLayoutCols) {
HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
remaining_src_cols, src_rows, packed_ptr, sums_ptr,
trailing_buf);
HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor,
zerobuf, src_stride,
remaining_src_cols - kHalfLayoutCols, src_rows,
packed_ptr + kHalfBlockOffset, second_sums_ptr,
trailing_buf + kHalfBlockOffset);
} else {
HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride,
remaining_src_cols, src_rows, packed_ptr, sums_ptr,
trailing_buf);
ZeroHalfAvx512<PackImpl8bitAvx512, std::int8_t>(
src_rows, zerobuf[0] ^ input_xor, packed_ptr + kHalfBlockOffset,
kChunkedRowMask);
if (second_sums_ptr) {
for (int i = 0; i < kHalfLayoutCols; ++i) {
second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3);
}
}
}
const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
if (trailing_data) {
const int non_trailing_rows = src_rows & ~kChunkedRowMask;
const int dst_rows = (src_rows + 3) & ~3;
const int trailing_rows = dst_rows - non_trailing_rows;
    std::memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
                Layout::kCols * trailing_rows * sizeof(std::int8_t));
}
}
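
// 16-bit analogue of Pack8bitColMajorForAvx512, with 16-row chunking
// (kNumRowChunks = 4) and no input_xor.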
void Pack16bitColMajorForAvx512(const std::int16_t* src_ptr,
const std::int16_t* zerobuf, int src_stride,
int remaining_src_cols, int src_rows,
std::int16_t* packed_ptr,
std::int32_t* sums_ptr) {
profiler::ScopeLabel label("Pack kAvx512 16bit");
using Layout = PackImpl16bitAvx512::Layout;
constexpr int kHalfBlockOffset = 32;
RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols);
static constexpr int kHalfLayoutCols =
PackImpl16bitAvx512::kHalfLayoutCols;
RUY_DCHECK_EQ(kHalfLayoutCols, 8);
RUY_DCHECK_EQ(Layout::kCols, 16);
RUY_DCHECK_EQ(Layout::kRows, 4);
constexpr int kNumRowChunks = 4;
constexpr int kTrailingBufSize =
kNumRowChunks * Layout::kCols * Layout::kRows;
std::int16_t trailing_buf[kTrailingBufSize] = {0};
constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
std::int32_t* second_sums_ptr =
sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr;
if (remaining_src_cols > kHalfLayoutCols) {
HalfPack16bitAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
src_rows, packed_ptr, sums_ptr, trailing_buf);
HalfPack16bitAvx512(src_ptr + src_stride * kHalfLayoutCols, zerobuf,
src_stride, remaining_src_cols - kHalfLayoutCols,
src_rows, packed_ptr + kHalfBlockOffset,
second_sums_ptr, trailing_buf + kHalfBlockOffset);
} else {
HalfPack16bitAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
src_rows, packed_ptr, sums_ptr, trailing_buf);
ZeroHalfAvx512<PackImpl16bitAvx512, std::int16_t>(
src_rows, zerobuf[0], packed_ptr + kHalfBlockOffset, kChunkedRowMask);
if (second_sums_ptr) {
for (int i = 0; i < kHalfLayoutCols; ++i) {
second_sums_ptr[i] = (zerobuf[0]) * ((src_rows + 3) & ~3);
}
}
}
const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
if (trailing_data) {
const int non_trailing_rows = src_rows & ~kChunkedRowMask;
const int dst_rows = (src_rows + 3) & ~3;
const int trailing_rows = dst_rows - non_trailing_rows;
    std::memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
                Layout::kCols * trailing_rows * sizeof(std::int16_t));
}
}
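
// Float entry point: two 8-column half-packs per 16-wide destination block.
// Trailing rows (src_rows & 7) are staged in trailing_buf and copied out.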
void PackFloatColMajorForAvx512(const float* src_ptr, const float* zerobuf,
int src_stride, int remaining_src_cols,
int src_rows, float* packed_ptr) {
profiler::ScopeLabel label("Pack kAvx512 float");
float trailing_buf[7 * 16];
if (remaining_src_cols > 8) {
HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
src_rows, packed_ptr, trailing_buf);
HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride,
remaining_src_cols - 8, src_rows, packed_ptr + 8,
trailing_buf + 8);
} else {
memset(trailing_buf, 0, sizeof(trailing_buf));
HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
src_rows, packed_ptr, trailing_buf);
ZeroHalfFloatAvx512(src_rows, packed_ptr + 8);
}
const int trailing_rows = src_rows & 7;
if (trailing_rows > 0) {
const int non_trailing_rows = src_rows & ~7;
    std::memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf,
                16 * trailing_rows * sizeof(float));
}
}
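
// Packs 4 rows of a row-major 8-bit source into column-major 4x16 blocks.
// The vectorized loop handles 16 source columns at a time: it loads 4 rows
// (padding missing rows with src_zero_point), XORs with input_xor (uint8 ->
// int8 conversion), accumulates per-column sums (assumed non-null on this
// path), and interleaves the rows into depth-major groups of 4. Scalar loops
// handle the remaining source columns and zero-fill columns from src_cols up
// to end_col.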
void Pack8bitRowMajorForAvx512(const std::uint8_t* src_ptr, int src_stride,
int src_zero_point, std::int8_t* packed_ptr,
int packed_stride, int start_col, int end_col,
int src_cols, int block_row, int src_rows,
int input_xor, std::int32_t* sums) {
int col = start_col;
int src_end_col = std::min(end_col, src_cols);
for (; col <= src_end_col - 16; col += 16) {
std::int8_t* dst_ptr = packed_ptr;
__m128i val0, val1, val2, val3;
__m128i input_xor_dup = _mm_set1_epi8(input_xor);
if (block_row + 4 <= src_rows) {
val0 = _mm_loadu_si128(
reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride));
val1 = _mm_loadu_si128(
reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride));
val2 = _mm_loadu_si128(
reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride));
val3 = _mm_loadu_si128(
reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride));
} else {
val0 = _mm_set1_epi8(src_zero_point);
val1 = val0;
val2 = val0;
val3 = val0;
if (block_row + 0 < src_rows)
val0 = _mm_loadu_si128(
reinterpret_cast<const __m128i*>(src_ptr + 0 * src_stride));
if (block_row + 1 < src_rows)
val1 = _mm_loadu_si128(
reinterpret_cast<const __m128i*>(src_ptr + 1 * src_stride));
if (block_row + 2 < src_rows)
val2 = _mm_loadu_si128(
reinterpret_cast<const __m128i*>(src_ptr + 2 * src_stride));
if (block_row + 3 < src_rows)
val3 = _mm_loadu_si128(
reinterpret_cast<const __m128i*>(src_ptr + 3 * src_stride));
}
val0 = _mm_xor_si128(val0, input_xor_dup);
val1 = _mm_xor_si128(val1, input_xor_dup);
val2 = _mm_xor_si128(val2, input_xor_dup);
val3 = _mm_xor_si128(val3, input_xor_dup);
__m256i val16_0 = _mm256_cvtepi8_epi16(val0);
__m256i val16_1 = _mm256_cvtepi8_epi16(val1);
__m256i val16_2 = _mm256_cvtepi8_epi16(val2);
__m256i val16_3 = _mm256_cvtepi8_epi16(val3);
__m256i new_sum16 = _mm256_add_epi16(_mm256_add_epi16(val16_0, val16_1),
_mm256_add_epi16(val16_2, val16_3));
__m512i sum =
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(sums + col));
sum = _mm512_add_epi32(sum, _mm512_cvtepi16_epi32(new_sum16));
_mm512_storeu_si512(reinterpret_cast<__m512i*>(sums + col), sum);
auto zip = [](__m128i x, __m128i y) {
auto perm_64_0_64_0 = [](__m128i x) {
return _mm256_permutexvar_epi64(_mm256_setr_epi64x(0, 2, 1, 3),
_mm256_castsi128_si256(x));
};
return _mm256_unpacklo_epi8(perm_64_0_64_0(x), perm_64_0_64_0(y));
};
__m256i t2_val0 = zip(val0, val1);
__m256i t2_val1 = zip(val2, val3);
__m256i t4_val0 = _mm256_unpacklo_epi16(t2_val0, t2_val1);
__m256i t4_val1 = _mm256_unpackhi_epi16(t2_val0, t2_val1);
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr),
_mm256_extractf128_si256(t4_val0, 0));
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 16),
_mm256_extractf128_si256(t4_val1, 0));
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 32),
_mm256_extractf128_si256(t4_val0, 1));
_mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr + 48),
_mm256_extractf128_si256(t4_val1, 1));
src_ptr += 16;
packed_ptr += packed_stride * 16;
}
for (; col < src_end_col; col++) {
std::int32_t accum = 0;
for (int r = 0; r < 4; r++) {
std::int8_t packed_val;
if (block_row + r < src_rows) {
packed_val = input_xor ^ src_ptr[r * src_stride];
} else {
packed_val = input_xor ^ src_zero_point;
}
accum += packed_val;
*packed_ptr++ = packed_val;
}
if (sums) {
sums[col] += accum;
}
src_ptr++;
}
for (; col < end_col; col++) {
std::memset(packed_ptr, 0, 4);
packed_ptr += 4;
}
}
#endif  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
}  // namespace ruy