#ifndef RUY_RUY_KERNEL_ARM_H_
#define RUY_RUY_KERNEL_ARM_H_
#include <cstddef>
#include <cstdint>
#include "ruy/asm_helpers.h"
#include "ruy/kernel_common.h"
#include "ruy/mat.h"
#include "ruy/mul_params.h"
#include "ruy/opt_set.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
#include "ruy/side_pair.h"
#include "ruy/size_util.h"
#include "ruy/tune.h"
namespace ruy {
#if RUY_PLATFORM_NEON && RUY_OPT(ASM)
// Kernel inheritance between paths. RUY_INHERIT_KERNEL is defined outside this
// header (presumably in ruy/kernel_common.h -- TODO confirm); the expected
// effect is that the second path reuses the first path's kernel for any scalar
// type combination that has no explicit Kernel specialization below.
RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon)
RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod)
// Declarations of the out-of-line 8-bit kernel entry points; their definitions
// are not in this header. The KernelParams8bit template arguments match the
// <LhsLayout::kCols, RhsLayout::kCols> pairs used by the Kernel specializations
// below, which is why the plain NEON kernels are declared with different
// parameter types on 64-bit (4, 4) and 32-bit (4, 2) builds.
#if RUY_PLATFORM_NEON_64
void Kernel8bitNeon(const KernelParams8bit<4, 4>& params);
void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params);
#elif RUY_PLATFORM_NEON_32
void Kernel8bitNeon(const KernelParams8bit<4, 2>& params);
void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params);
#endif
// The remaining 8-bit kernels are declared unconditionally, but within this
// header they are only called from RUY_PLATFORM_NEON_64 specializations.
void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params);
void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params);
void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params);
void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params);
void Kernel8bitNeonDotprodX1(const KernelParams8bit<8, 8>& params);
#if RUY_PLATFORM_NEON_64
// 64-bit NEON kernel specialization for int8 LHS and RHS operands with
// MulParams<std::int32_t, DstScalar> and a DstScalar destination.
//
// Run() packs its arguments into a KernelParams8bit structure and hands that
// structure to one of the out-of-line Kernel8bitNeon* routines declared in
// this header.
template <typename DstScalar>
struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
  static constexpr Path kPath = Path::kNeon;
  // Packed operand layouts. Their kCols values (4 and 4) select the
  // KernelParams8bit<4, 4> type that the 64-bit Kernel8bitNeon* routines take.
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
  // Selects which kernel variant Run() dispatches to.
  Tuning tuning = Tuning::kAuto;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}

  // Fills a KernelParams8bit from the arguments and runs the matching kernel.
  //
  // lhs, rhs, mul_params, the [start_row, end_row) x [start_col, end_col)
  // bounds and dst are forwarded unchanged to MakeKernelParams8bit
  // (presumably the bounds delimit the destination block to compute --
  // TODO confirm against ruy/kernel_common.h).
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    // The last argument is the output: the address of the local params struct.
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    // A single-column destination with per-row channel parameters has a
    // dedicated kernel, chosen regardless of tuning.
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitNeon1Col(params);
      return;
    }
    // __builtin_expect marks the kA55ish branch as the expected one.
    if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
      Kernel8bitNeonA55ish(params);
    } else {
      Kernel8bitNeon(params);
    }
  }
};
#endif
#if RUY_PLATFORM_NEON_32
// 32-bit NEON kernel specialization for int8 LHS and RHS operands with
// MulParams<std::int32_t, DstScalar> and a DstScalar destination.
//
// Run() packs its arguments into a KernelParams8bit structure and hands that
// structure to one of the out-of-line 32-bit Kernel8bitNeon* routines declared
// in this header.
template <typename DstScalar>
struct Kernel<Path::kNeon, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
  static constexpr Path kPath = Path::kNeon;
  // Packed operand layouts. Their kCols values (4 and 2) select the
  // KernelParams8bit<4, 2> type that the 32-bit Kernel8bitNeon* routines take.
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 2>;
  // Stored for interface parity with the other kernels; Run() does not read
  // it in this specialization.
  Tuning tuning = Tuning::kAuto;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}

  // Fills a KernelParams8bit from the arguments and runs the matching kernel.
  //
  // lhs, rhs, mul_params, the [start_row, end_row) x [start_col, end_col)
  // bounds and dst are forwarded unchanged to MakeKernelParams8bit
  // (presumably the bounds delimit the destination block to compute --
  // TODO confirm against ruy/kernel_common.h).
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    // The last argument is the output: the address of the local params struct.
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    // A single-column destination with per-row channel parameters has a
    // dedicated kernel.
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitNeon1Col(params);
      return;
    }
    Kernel8bitNeon(params);
  }
};
#endif
#if RUY_PLATFORM_NEON_64
// 64-bit NEON dot-product-path kernel specialization for int8 LHS and RHS
// operands with MulParams<std::int32_t, DstScalar> and a DstScalar
// destination.
//
// Run() packs its arguments into a KernelParams8bit structure and hands that
// structure to one of the out-of-line Kernel8bitNeonDotprod* routines declared
// in this header.
template <typename DstScalar>
struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, std::int32_t,
              DstScalar> {
  static constexpr Path kPath = Path::kNeonDotprod;
  // Selects which kernel variant Run() dispatches to.
  Tuning tuning = Tuning::kAuto;
  // Packed operand layouts. Their kCols values (8 and 8) select the
  // KernelParams8bit<8, 8> type that the Kernel8bitNeonDotprod* routines take.
  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}

  // Fills a KernelParams8bit from the arguments and runs the matching kernel.
  //
  // lhs, rhs, mul_params, the [start_row, end_row) x [start_col, end_col)
  // bounds and dst are forwarded unchanged to MakeKernelParams8bit
  // (presumably the bounds delimit the destination block to compute --
  // TODO confirm against ruy/kernel_common.h).
  void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
           const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
    // The last argument is the output: the address of the local params struct.
    MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                         end_col, dst, &params);
    // Dispatch order: the single-column / per-row-channel case wins over any
    // tuning; otherwise the tuning value picks the variant.
    if (dst->layout.cols == 1 &&
        mul_params.channel_dimension() == ChannelDimension::kRow) {
      Kernel8bitNeonDotprod1Col(params);
    } else if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
      // __builtin_expect marks the kA55ish branch as the expected one.
      Kernel8bitNeonDotprodA55ish(params);
    } else if (tuning == Tuning::kX1) {
      Kernel8bitNeonDotprodX1(params);
    } else {
      Kernel8bitNeonDotprod(params);
    }
  }
};
#endif
// Declarations of the out-of-line float kernel entry points; their definitions
// are not in this header. The <8, 8> variants are called by the 64-bit float
// Kernel specializations below, and KernelFloat32Neon (<8, 4>) by the 32-bit
// one.
void KernelFloatNeon(const KernelParamsFloat<8, 8>& params);
void KernelFloatNeonX1(const KernelParamsFloat<8, 8>& params);
void KernelFloatNeonA55ish(const KernelParamsFloat<8, 8>& params);
void KernelFloat32Neon(const KernelParamsFloat<8, 4>& params);
void KernelFloatNeonDotprodA55ish(const KernelParamsFloat<8, 8>& params);
#if RUY_PLATFORM_NEON_64
// 64-bit NEON kernel specialization for all-float operands and destination.
//
// Run() packs its arguments into a KernelParamsFloat structure and hands that
// structure to one of the out-of-line KernelFloatNeon* routines declared in
// this header.
template <>
struct Kernel<Path::kNeon, float, float, float, float> {
  static constexpr Path kPath = Path::kNeon;
  // Selects which kernel variant Run() dispatches to.
  Tuning tuning = Tuning::kAuto;
  // Packed operand layouts. Their kCols values (8 and 8) select the
  // KernelParamsFloat<8, 8> type that the KernelFloatNeon* routines take.
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}

  // Fills a KernelParamsFloat from the arguments and runs the matching kernel.
  //
  // lhs, rhs, mul_params, the [start_row, end_row) x [start_col, end_col)
  // bounds and dst are forwarded unchanged to MakeKernelParamsFloat
  // (presumably the bounds delimit the destination block to compute --
  // TODO confirm against ruy/kernel_common.h).
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    // The last argument is the output: the address of the local params struct.
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    // __builtin_expect marks the kA55ish branch as the expected one.
    if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
      KernelFloatNeonA55ish(params);
    } else if (tuning == Tuning::kX1) {
      KernelFloatNeonX1(params);
    } else {
      KernelFloatNeon(params);
    }
  }
};
#endif
#if RUY_PLATFORM_NEON_32
// 32-bit NEON kernel specialization for all-float operands and destination.
//
// Run() packs its arguments into a KernelParamsFloat structure and hands that
// structure to the out-of-line KernelFloat32Neon routine declared in this
// header.
template <>
struct Kernel<Path::kNeon, float, float, float, float> {
  static constexpr Path kPath = Path::kNeon;
  // Stored for interface parity with the other kernels; Run() does not read
  // it in this specialization.
  Tuning tuning = Tuning::kAuto;
  // Packed operand layouts. Their kCols values (8 and 4) select the
  // KernelParamsFloat<8, 4> type that KernelFloat32Neon takes.
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 4>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}

  // Fills a KernelParamsFloat from the arguments and runs the kernel.
  //
  // lhs, rhs, mul_params, the [start_row, end_row) x [start_col, end_col)
  // bounds and dst are forwarded unchanged to MakeKernelParamsFloat
  // (presumably the bounds delimit the destination block to compute --
  // TODO confirm against ruy/kernel_common.h).
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    // Derived from the layouts (8 and 4) rather than hard-coded, so the params
    // type cannot drift from the layouts declared above.
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    // The last argument is the output: the address of the local params struct.
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    KernelFloat32Neon(params);
  }
};
#endif
// NEON dot-product-path kernel specialization for all-float operands and
// destination.
//
// Run() packs its arguments into a KernelParamsFloat structure and hands that
// structure to an out-of-line float kernel declared in this header. Only the
// kA55ish tuning has a dotprod-specific routine here; the other tunings call
// the same KernelFloatNeonX1 / KernelFloatNeon routines as the Path::kNeon
// float kernel.
template <>
struct Kernel<Path::kNeonDotprod, float, float, float, float> {
  static constexpr Path kPath = Path::kNeonDotprod;
  // Selects which kernel variant Run() dispatches to.
  Tuning tuning = Tuning::kAuto;
  // Packed operand layouts. Their kCols values (8 and 8) select the
  // KernelParamsFloat<8, 8> type that the float kernel routines take.
  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
  // Alias for the Path::kNeon float kernel. Not referenced inside this struct;
  // kept because it is part of the public member set.
  using Base = Kernel<Path::kNeon, float, float, float, float>;
  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}

  // Fills a KernelParamsFloat from the arguments and runs the matching kernel.
  //
  // lhs, rhs, mul_params, the [start_row, end_row) x [start_col, end_col)
  // bounds and dst are forwarded unchanged to MakeKernelParamsFloat
  // (presumably the bounds delimit the destination block to compute --
  // TODO confirm against ruy/kernel_common.h).
  void Run(const PMat<float>& lhs, const PMat<float>& rhs,
           const MulParams<float, float>& mul_params, int start_row,
           int start_col, int end_row, int end_col, Mat<float>* dst) const {
    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
    // The last argument is the output: the address of the local params struct.
    MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
    // __builtin_expect marks the kA55ish branch as the expected one.
    if (__builtin_expect(tuning == Tuning::kA55ish, true)) {
      KernelFloatNeonDotprodA55ish(params);
    } else if (tuning == Tuning::kX1) {
      KernelFloatNeonX1(params);
    } else {
      KernelFloatNeon(params);
    }
  }
};
#endif
}
#endif