#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H
#if defined(EIGEN_USE_GPU) && defined(EIGEN_GPUCC)
#include "./InternalHeaderCheck.h"
namespace Eigen {
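// General-purpose GPU contraction kernel (any Scalar type). Each block of
// 8x8x8 = 512 threads computes one 64x64 tile of the output, sweeping the
// contraction dimension in chunks of 64: every thread prefetches 8 lhs and
// 8 rhs values into registers, stages them in shared memory, accumulates its
// partial results, and the partial sums are finally reduced across the
// threadIdx.x lanes with warp shuffles before being written out.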
template <typename Scalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper,
bool needs_edge_check>
__device__ EIGEN_STRONG_INLINE void EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
const OutputMapper output, Scalar* lhs_shmem,
Scalar* rhs_shmem, const Index m_size,
const Index n_size, const Index k_size) {
const Index m_block_idx = blockIdx.x;
const Index n_block_idx = blockIdx.y;
const Index base_m = 64 * m_block_idx;
const Index base_n = 64 * n_block_idx;
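// Registers holding the values prefetched for the current k-chunk before they
// are copied into shared memory.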
Scalar lhs_pf0;
Scalar lhs_pf1;
Scalar lhs_pf2;
Scalar lhs_pf3;
Scalar lhs_pf4;
Scalar lhs_pf5;
Scalar lhs_pf6;
Scalar lhs_pf7;
Scalar rhs_pf0;
Scalar rhs_pf1;
Scalar rhs_pf2;
Scalar rhs_pf3;
Scalar rhs_pf4;
Scalar rhs_pf5;
Scalar rhs_pf6;
Scalar rhs_pf7;
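// Destination indices in shared memory for the prefetched values. The padded
// strides (72 instead of 64, and 9 instead of 8 for the lhs) presumably keep
// the stores free of shared-memory bank conflicts.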
const Index lhs_store_idx_base = threadIdx.y * 72 + threadIdx.x * 9 + threadIdx.z;
const Index rhs_store_idx_base = threadIdx.y * 72 + threadIdx.z * 8 + threadIdx.x;
const Index lhs_store_idx_0 = lhs_store_idx_base + 576 * 0;
const Index lhs_store_idx_1 = lhs_store_idx_base + 576 * 1;
const Index lhs_store_idx_2 = lhs_store_idx_base + 576 * 2;
const Index lhs_store_idx_3 = lhs_store_idx_base + 576 * 3;
const Index lhs_store_idx_4 = lhs_store_idx_base + 576 * 4;
const Index lhs_store_idx_5 = lhs_store_idx_base + 576 * 5;
const Index lhs_store_idx_6 = lhs_store_idx_base + 576 * 6;
const Index lhs_store_idx_7 = lhs_store_idx_base + 576 * 7;
const Index rhs_store_idx_0 = rhs_store_idx_base + 576 * 0;
const Index rhs_store_idx_1 = rhs_store_idx_base + 576 * 1;
const Index rhs_store_idx_2 = rhs_store_idx_base + 576 * 2;
const Index rhs_store_idx_3 = rhs_store_idx_base + 576 * 3;
const Index rhs_store_idx_4 = rhs_store_idx_base + 576 * 4;
const Index rhs_store_idx_5 = rhs_store_idx_base + 576 * 5;
const Index rhs_store_idx_6 = rhs_store_idx_base + 576 * 6;
const Index rhs_store_idx_7 = rhs_store_idx_base + 576 * 7;
const Index load_idx_vert = threadIdx.x + 8 * threadIdx.y;
const Index lhs_vert = base_m + load_idx_vert;
#define prefetchIntoRegisters …
#define writeRegToShmem …
#define res …
#define initResultRow …
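// Zero the per-thread accumulator registers (the elided res(i, j) names);
// conv presumably converts an integer zero into a properly typed Scalar for
// the elided initResultRow macro.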
internal::scalar_cast_op<int, Scalar> conv;
initResultRow(0);
initResultRow(1);
initResultRow(2);
initResultRow(3);
initResultRow(4);
initResultRow(5);
initResultRow(6);
initResultRow(7);
#undef initResultRow
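// Walk the contraction dimension in 64-wide chunks: prefetch the next lhs/rhs
// tiles into registers, stage them in shared memory, then accumulate.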
for (Index base_k = 0; base_k < k_size; base_k += 64) {
__syncthreads();
prefetchIntoRegisters(base_k);
writeRegToShmem();
#undef prefetchIntoRegisters
#undef writeRegToShmem
__syncthreads();
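// Per-thread registers (named through the elided lcol/rrow macros) holding one
// slice of the lhs tile and one slice of the rhs tile, consumed by the elided
// computePass macro below.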
#define lcol …
Scalar lcol(0);
Scalar lcol(1);
Scalar lcol(2);
Scalar lcol(3);
Scalar lcol(4);
Scalar lcol(5);
Scalar lcol(6);
Scalar lcol(7);
#define rrow …
Scalar rrow(0);
Scalar rrow(1);
Scalar rrow(2);
Scalar rrow(3);
Scalar rrow(4);
Scalar rrow(5);
Scalar rrow(6);
Scalar rrow(7);
const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];
#define lhs_element …
#define rhs_element …
#define loadData …
#define computeCol …
#define computePass …
computePass(0);
computePass(1);
computePass(2);
computePass(3);
computePass(4);
computePass(5);
computePass(6);
computePass(7);
#undef lcol
#undef rrow
#undef lhs_element
#undef rhs_element
#undef loadData
#undef computeCol
#undef computePass
}
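// Each output value is now spread across the 8 threads that share a
// (threadIdx.y, threadIdx.z) pair; sum them with warp shuffles. The elided
// shuffleInc macro mirrors the __shfl_xor / __shfl_xor_sync split used
// explicitly in the float kernels below.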
#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
#define shuffleInc …
#else
#define shuffleInc …
#endif
#define reduceRow …
#define reduceMatrix …
reduceMatrix(1);
reduceMatrix(2);
reduceMatrix(4);
#undef shuffleInc
#undef reduceRow
#undef reduceMatrix
__syncthreads();
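// Thread x == 0 of each (y, z) group now owns the reduced results; stage them
// in shared memory (lhs_shmem is reused as the staging buffer) before the
// final write to the output.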
#define writeResultShmem …
#define writeRow …
if (threadIdx.x == 0) {
writeRow(0);
writeRow(1);
writeRow(2);
writeRow(3);
writeRow(4);
writeRow(5);
writeRow(6);
writeRow(7);
}
#undef writeResultShmem
#undef writeRow
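// Copy the staged results from shared memory to the output mapper, clamping
// the row/column counts for tiles that extend past the m/n edges.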
const int max_i_write = numext::mini((int)((m_size - base_m - threadIdx.y + 7) / 8), 8);
const int max_j_write = numext::mini((int)((n_size - base_n - threadIdx.z + 7) / 8), 8);
if (threadIdx.x < max_i_write) {
if (max_j_write == 8) {
Scalar val0 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 0];
Scalar val1 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 1];
Scalar val2 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 2];
Scalar val3 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 3];
Scalar val4 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 4];
Scalar val5 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 5];
Scalar val6 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 6];
Scalar val7 = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * 7];
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 0) = val0;
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 1) = val1;
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 2) = val2;
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 3) = val3;
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 4) = val4;
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 5) = val5;
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 6) = val6;
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * 7) = val7;
} else {
#pragma unroll 7
for (int j = 0; j < max_j_write; j++) {
Scalar val = lhs_shmem[threadIdx.x + 8 * threadIdx.y + 64 * threadIdx.z + 512 * j];
output(base_m + threadIdx.y + 8 * threadIdx.x, base_n + threadIdx.z + 8 * j) = val;
}
}
}
#undef res
}
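// Kernel entry point: allocates the shared-memory tiles and dispatches to the
// edge-checked variant only for blocks whose 64x64 tile sticks out of the
// output.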
template <typename Scalar, typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(512, 1)
#else
__launch_bounds__(512)
#endif
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output, const Index m_size,
const Index n_size, const Index k_size) {
__shared__ Scalar lhs_shmem[72 * 64];
__shared__ Scalar rhs_shmem[72 * 64];
const Index m_block_idx = blockIdx.x;
const Index n_block_idx = blockIdx.y;
const Index base_m = 64 * m_block_idx;
const Index base_n = 64 * n_block_idx;
if (base_m + 63 < m_size && base_n + 63 < n_size) {
EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, false>(
lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
} else {
EigenContractionKernelInternal<Scalar, Index, LhsMapper, RhsMapper, OutputMapper, true>(
lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size);
}
}
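// Float specialization for small and medium problems. Each block of 16x16
// threads computes a 64x64 output tile; every thread accumulates a 4x4
// sub-block using float4 packet loads, sweeping the contraction dimension in
// chunks of 16.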
template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
bool CHECK_RHS_BOUNDARY>
__device__ __forceinline__ void EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rhs,
const OutputMapper output,
float2 lhs_shmem2[][16],
float2 rhs_shmem2[][8], const Index m_size,
const Index n_size, const Index k_size,
const Index base_m, const Index base_n) {
float4 lhs_pf0, rhs_pf0;
float4 results[4];
for (int i = 0; i < 4; i++) {
results[i].x = results[i].y = results[i].z = results[i].w = 0;
}
#define prefetch_lhs …
Index lhs_vert = base_m + threadIdx.x * 4;
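// Sweep the contraction dimension in chunks of 16, prefetching one float4 of
// lhs and one of rhs per thread (values past the k boundary stay at the zero
// fill).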
for (Index k = 0; k < k_size; k += 16) {
lhs_pf0 = internal::pset1<float4>(0);
rhs_pf0 = internal::pset1<float4>(0);
Index lhs_horiz = threadIdx.y + k;
prefetch_lhs(lhs_pf0, lhs_vert, lhs_horiz)
Index rhs_vert = k + (threadIdx.x % 4) * 4;
Index rhs_horiz0 = (threadIdx.x >> 2) + threadIdx.y * 4 + base_n;
if (!CHECK_RHS_BOUNDARY) {
if ((rhs_vert + 3) < k_size) {
rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
} else if (rhs_vert + 2 < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
} else if (rhs_vert + 1 < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
} else if (rhs_vert < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
}
} else {
if (rhs_horiz0 < n_size) {
if ((rhs_vert + 3) < k_size) {
rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
} else if ((rhs_vert + 2) < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
} else if ((rhs_vert + 1) < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
} else if (rhs_vert < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
}
}
}
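// Exchange half of the rhs components between lanes 4 apart (the lower lanes
// trade their .y/.w for the upper lanes' .x/.z via a warp shuffle), presumably
// to arrange the values for the float2 shared-memory stores below.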
float x1, x2;
if ((threadIdx.x % 8) < 4) {
x1 = rhs_pf0.y;
x2 = rhs_pf0.w;
} else {
x1 = rhs_pf0.x;
x2 = rhs_pf0.z;
}
#if defined(EIGEN_HIPCC) || (defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000)
x1 = __shfl_xor(x1, 4);
x2 = __shfl_xor(x2, 4);
#else
x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
#endif
if ((threadIdx.x % 8) < 4) {
rhs_pf0.y = x1;
rhs_pf0.w = x2;
} else {
rhs_pf0.x = x1;
rhs_pf0.z = x2;
}
rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2][threadIdx.x % 8] = make_float2(rhs_pf0.x, rhs_pf0.y);
rhs_shmem2[(threadIdx.x >> 3) + threadIdx.y * 2 + 32][threadIdx.x % 8] = make_float2(rhs_pf0.z, rhs_pf0.w);
lhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(lhs_pf0.x, lhs_pf0.y);
lhs_shmem2[threadIdx.y + 16][threadIdx.x] = make_float2(lhs_pf0.z, lhs_pf0.w);
#define add_vals …
__syncthreads();
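// Accumulate the 16 k-steps of this tile into the 4x4 per-thread results via
// the elided add_vals macro.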
#pragma unroll
for (int koff = 0; koff < 16; koff++) {
float2 fl1 = lhs_shmem2[koff][threadIdx.x];
float2 fl2 = lhs_shmem2[koff + 16][threadIdx.x];
int start_feature = threadIdx.y * 4;
float2 fr1 = rhs_shmem2[(start_feature >> 1) + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];
float2 fr2 = rhs_shmem2[(start_feature >> 1) + 1 + 32 * ((koff % 4) / 2)][koff / 4 + (koff % 2) * 4];
add_vals(fl1, fl2, fr1, fr2)
}
__syncthreads();
}
#undef prefetch_lhs
#undef add_vals
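// Write the 4x4 accumulator to the output, applying the m and/or n boundary
// checks only when the corresponding template flag requests them.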
Index horiz_base = threadIdx.y * 4 + base_n;
if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
for (int i = 0; i < 4; i++) {
output(lhs_vert, horiz_base + i) = results[i].x;
output(lhs_vert + 1, horiz_base + i) = results[i].y;
output(lhs_vert + 2, horiz_base + i) = results[i].z;
output(lhs_vert + 3, horiz_base + i) = results[i].w;
}
} else if (!CHECK_RHS_BOUNDARY) {
if (lhs_vert + 3 < m_size) {
for (int i = 0; i < 4; i++) {
output(lhs_vert, horiz_base + i) = results[i].x;
output(lhs_vert + 1, horiz_base + i) = results[i].y;
output(lhs_vert + 2, horiz_base + i) = results[i].z;
output(lhs_vert + 3, horiz_base + i) = results[i].w;
}
} else if (lhs_vert + 2 < m_size) {
for (int i = 0; i < 4; i++) {
output(lhs_vert, horiz_base + i) = results[i].x;
output(lhs_vert + 1, horiz_base + i) = results[i].y;
output(lhs_vert + 2, horiz_base + i) = results[i].z;
}
} else if (lhs_vert + 1 < m_size) {
for (int i = 0; i < 4; i++) {
output(lhs_vert, horiz_base + i) = results[i].x;
output(lhs_vert + 1, horiz_base + i) = results[i].y;
}
} else if (lhs_vert < m_size) {
for (int i = 0; i < 4; i++) {
output(lhs_vert, horiz_base + i) = results[i].x;
}
}
} else if (!CHECK_LHS_BOUNDARY) {
for (int i = 0; i < 4; i++) {
if (horiz_base + i < n_size) {
output(lhs_vert, horiz_base + i) = results[i].x;
output(lhs_vert + 1, horiz_base + i) = results[i].y;
output(lhs_vert + 2, horiz_base + i) = results[i].z;
output(lhs_vert + 3, horiz_base + i) = results[i].w;
}
}
} else {
for (int i = 0; i < 4; i++) {
if (horiz_base + i < n_size) {
if (lhs_vert < m_size) output(lhs_vert, horiz_base + i) = results[i].x;
if (lhs_vert + 1 < m_size) output(lhs_vert + 1, horiz_base + i) = results[i].y;
if (lhs_vert + 2 < m_size) output(lhs_vert + 2, horiz_base + i) = results[i].z;
if (lhs_vert + 3 < m_size) output(lhs_vert + 3, horiz_base + i) = results[i].w;
}
}
}
}
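// Float specialization for large problems. Each block of 8x32 threads computes
// a 128x64 output tile; every thread accumulates a 4x8 sub-block, sweeping the
// contraction dimension in chunks of 32.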
template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper, bool CHECK_LHS_BOUNDARY,
bool CHECK_RHS_BOUNDARY>
__device__ __forceinline__ void EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
const OutputMapper output, float2 lhs_shmem2[][32],
float2 rhs_shmem2[][8], const Index m_size,
const Index n_size, const Index k_size,
const Index base_m, const Index base_n) {
float4 lhs_pf0, lhs_pf1, lhs_pf2, lhs_pf3;
float4 rhs_pf0, rhs_pf1;
float4 results[8];
for (int i = 0; i < 8; i++) {
results[i].x = results[i].y = results[i].z = results[i].w = 0;
}
Index lhs_vert = base_m + threadIdx.x * 4 + (threadIdx.y % 4) * 32;
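// Sweep the contraction dimension in chunks of 32, prefetching four float4s of
// lhs and two of rhs per thread; values outside the valid m/n/k ranges stay at
// the zero fill when the boundary-check flags are set.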
for (Index k = 0; k < k_size; k += 32) {
lhs_pf0 = internal::pset1<float4>(0);
lhs_pf1 = internal::pset1<float4>(0);
lhs_pf2 = internal::pset1<float4>(0);
lhs_pf3 = internal::pset1<float4>(0);
rhs_pf0 = internal::pset1<float4>(0);
rhs_pf1 = internal::pset1<float4>(0);
if (!CHECK_LHS_BOUNDARY) {
if ((threadIdx.y / 4 + k + 24) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 24));
} else if ((threadIdx.y / 4 + k + 16) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
} else if ((threadIdx.y / 4 + k + 8) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
} else if ((threadIdx.y / 4 + k) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
}
} else {
if (lhs_vert + 3 < m_size) {
if ((threadIdx.y / 4 + k + 24) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
lhs_pf3 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 24));
} else if ((threadIdx.y / 4 + k + 16) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
lhs_pf2 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 16));
} else if ((threadIdx.y / 4 + k + 8) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
lhs_pf1 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k + 8));
} else if ((threadIdx.y / 4 + k) < k_size) {
lhs_pf0 = lhs.template loadPacket<float4, Unaligned>(lhs_vert, (threadIdx.y / 4 + k));
}
} else if (lhs_vert + 2 < m_size) {
if ((threadIdx.y / 4 + k + 24) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 16));
lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 24));
lhs_pf3.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 24));
} else if ((threadIdx.y / 4 + k + 16) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
lhs_pf2.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 16));
} else if ((threadIdx.y / 4 + k + 8) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
lhs_pf1.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k + 8));
} else if ((threadIdx.y / 4 + k) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
lhs_pf0.z = lhs(lhs_vert + 2, (threadIdx.y / 4 + k));
}
} else if (lhs_vert + 1 < m_size) {
if ((threadIdx.y / 4 + k + 24) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
lhs_pf3.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 24));
} else if ((threadIdx.y / 4 + k + 16) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
lhs_pf2.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 16));
} else if ((threadIdx.y / 4 + k + 8) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
lhs_pf1.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k + 8));
} else if ((threadIdx.y / 4 + k) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf0.y = lhs(lhs_vert + 1, (threadIdx.y / 4 + k));
}
} else if (lhs_vert < m_size) {
if ((threadIdx.y / 4 + k + 24) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
lhs_pf3.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 24));
} else if ((threadIdx.y / 4 + k + 16) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
lhs_pf2.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 16));
} else if ((threadIdx.y / 4 + k + 8) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
lhs_pf1.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k + 8));
} else if ((threadIdx.y / 4 + k) < k_size) {
lhs_pf0.x = lhs(lhs_vert + 0, (threadIdx.y / 4 + k));
}
}
}
__syncthreads();
Index rhs_vert = k + threadIdx.x * 4;
Index rhs_horiz0 = threadIdx.y * 2 + base_n;
Index rhs_horiz1 = threadIdx.y * 2 + 1 + base_n;
if (!CHECK_RHS_BOUNDARY) {
if ((rhs_vert + 3) < k_size) {
rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf1 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz1);
} else if (rhs_vert + 2 < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
} else if (rhs_vert + 1 < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
} else if (rhs_vert < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
}
} else {
if (rhs_horiz1 < n_size) {
if ((rhs_vert + 3) < k_size) {
rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf1 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz1);
} else if (rhs_vert + 2 < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
rhs_pf1.z = rhs(rhs_vert + 2, rhs_horiz1);
} else if (k + threadIdx.x * 4 + 1 < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
rhs_pf1.y = rhs(rhs_vert + 1, rhs_horiz1);
} else if (k + threadIdx.x * 4 < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf1.x = rhs(rhs_vert, rhs_horiz1);
}
} else if (rhs_horiz0 < n_size) {
if ((rhs_vert + 3) < k_size) {
rhs_pf0 = rhs.template loadPacket<float4, Unaligned>(rhs_vert, rhs_horiz0);
} else if ((rhs_vert + 2) < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
rhs_pf0.z = rhs(rhs_vert + 2, rhs_horiz0);
} else if ((rhs_vert + 1) < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
} else if (rhs_vert < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
}
}
}
__syncthreads();
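// Stage the prefetched rhs packets as four float2 rows and the lhs packets as
// eight float2 rows of shared memory.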
rhs_shmem2[threadIdx.y][threadIdx.x] = make_float2(rhs_pf0.x, rhs_pf1.x);
rhs_shmem2[threadIdx.y + 32][threadIdx.x] = make_float2(rhs_pf0.y, rhs_pf1.y);
rhs_shmem2[threadIdx.y + 64][threadIdx.x] = make_float2(rhs_pf0.z, rhs_pf1.z);
rhs_shmem2[threadIdx.y + 96][threadIdx.x] = make_float2(rhs_pf0.w, rhs_pf1.w);
#define add_vals …
lhs_shmem2[threadIdx.y / 4][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf0.x, lhs_pf0.y);
lhs_shmem2[threadIdx.y / 4 + 8][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf1.x, lhs_pf1.y);
lhs_shmem2[threadIdx.y / 4 + 16][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf2.x, lhs_pf2.y);
lhs_shmem2[threadIdx.y / 4 + 24][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf3.x, lhs_pf3.y);
lhs_shmem2[threadIdx.y / 4 + 32][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf0.z, lhs_pf0.w);
lhs_shmem2[threadIdx.y / 4 + 40][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf1.z, lhs_pf1.w);
lhs_shmem2[threadIdx.y / 4 + 48][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf2.z, lhs_pf2.w);
lhs_shmem2[threadIdx.y / 4 + 56][threadIdx.x + (threadIdx.y % 4) * 8] = make_float2(lhs_pf3.z, lhs_pf3.w);
__syncthreads();
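// Accumulate the 32 k-steps of this tile into the 4x8 per-thread results via
// the elided add_vals macro.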
#pragma unroll
for (int koff = 0; koff < 32; koff++) {
float2 a3 = lhs_shmem2[koff][threadIdx.x + (threadIdx.y % 4) * 8];
float2 a4 = lhs_shmem2[koff + 32][threadIdx.x + (threadIdx.y % 4) * 8];
int start_feature = (threadIdx.y / 4) * 8;
float2 br1 = rhs_shmem2[start_feature / 2 + (koff % 4) * 32][koff / 4];
float2 br2 = rhs_shmem2[start_feature / 2 + 1 + (koff % 4) * 32][koff / 4];
float2 br3 = rhs_shmem2[start_feature / 2 + 2 + (koff % 4) * 32][koff / 4];
float2 br4 = rhs_shmem2[start_feature / 2 + 3 + (koff % 4) * 32][koff / 4];
add_vals(a3, a4, br1, br2, br3, br4)
}
__syncthreads();
}
#undef add_vals
__syncthreads();
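// Write the 4x8 accumulator to the output, applying the m and/or n boundary
// checks only when the corresponding template flag requests them.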
Index horiz_base = (threadIdx.y / 4) * 8 + base_n;
if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
for (int i = 0; i < 8; i++) {
output(lhs_vert, horiz_base + i) = results[i].x;
output(lhs_vert + 1, horiz_base + i) = results[i].y;
output(lhs_vert + 2, horiz_base + i) = results[i].z;
output(lhs_vert + 3, horiz_base + i) = results[i].w;
}
} else if (!CHECK_RHS_BOUNDARY) {
if (lhs_vert + 3 < m_size) {
for (int i = 0; i < 8; i++) {
output(lhs_vert, horiz_base + i) = results[i].x;
output(lhs_vert + 1, horiz_base + i) = results[i].y;
output(lhs_vert + 2, horiz_base + i) = results[i].z;
output(lhs_vert + 3, horiz_base + i) = results[i].w;
}
} else if (lhs_vert + 2 < m_size) {
for (int i = 0; i < 8; i++) {
output(lhs_vert, horiz_base + i) = results[i].x;
output(lhs_vert + 1, horiz_base + i) = results[i].y;
output(lhs_vert + 2, horiz_base + i) = results[i].z;
}
} else if (lhs_vert + 1 < m_size) {
for (int i = 0; i < 8; i++) {
output(lhs_vert, horiz_base + i) = results[i].x;
output(lhs_vert + 1, horiz_base + i) = results[i].y;
}
} else if (lhs_vert < m_size) {
for (int i = 0; i < 8; i++) {
output(lhs_vert, horiz_base + i) = results[i].x;
}
}
} else if (!CHECK_LHS_BOUNDARY) {
for (int i = 0; i < 8; i++) {
if (horiz_base + i < n_size) {
output(lhs_vert, horiz_base + i) = results[i].x;
output(lhs_vert + 1, horiz_base + i) = results[i].y;
output(lhs_vert + 2, horiz_base + i) = results[i].z;
output(lhs_vert + 3, horiz_base + i) = results[i].w;
}
}
} else {
for (int i = 0; i < 8; i++) {
if (horiz_base + i < n_size) {
if (lhs_vert < m_size) output(lhs_vert, horiz_base + i) = results[i].x;
if (lhs_vert + 1 < m_size) output(lhs_vert + 1, horiz_base + i) = results[i].y;
if (lhs_vert + 2 < m_size) output(lhs_vert + 2, horiz_base + i) = results[i].z;
if (lhs_vert + 3 < m_size) output(lhs_vert + 3, horiz_base + i) = results[i].w;
}
}
}
}
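// Kernel entry point for the large-problem float kernel: 128x64 output tiles,
// dispatching to the boundary-checked variants only for tiles that stick out
// of the output.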
template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output, const Index m_size,
const Index n_size, const Index k_size) {
__shared__ float2 lhs_shmem[64 * 32];
__shared__ float2 rhs_shmem[128 * 8];
typedef float2 LHS_MEM[64][32];
typedef float2 RHS_MEM[128][8];
const Index m_block_idx = blockIdx.x;
const Index n_block_idx = blockIdx.y;
const Index base_m = 128 * m_block_idx;
const Index base_n = 64 * n_block_idx;
bool check_rhs = (base_n + 63) >= n_size;
bool check_lhs128 = (base_m + 127) >= m_size;
if (!check_rhs) {
if (!check_lhs128) {
EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
lhs, rhs, output, *((LHS_MEM*)lhs_shmem), *((RHS_MEM*)rhs_shmem), m_size, n_size, k_size, base_m, base_n);
} else {
EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
lhs, rhs, output, *((LHS_MEM*)lhs_shmem), *((RHS_MEM*)rhs_shmem), m_size, n_size, k_size, base_m, base_n);
}
} else {
if (!check_lhs128) {
EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
lhs, rhs, output, *((LHS_MEM*)lhs_shmem), *((RHS_MEM*)rhs_shmem), m_size, n_size, k_size, base_m, base_n);
} else {
EigenFloatContractionKernelInternal<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
lhs, rhs, output, *((LHS_MEM*)lhs_shmem), *((RHS_MEM*)rhs_shmem), m_size, n_size, k_size, base_m, base_n);
}
}
}
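// Kernel entry point for the 16x16 float kernel: 64x64 output tiles, again
// selecting the boundary-checked variants only where needed.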
template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
__global__ void
#if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1)
#else
__launch_bounds__(256)
#endif
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs, const OutputMapper output,
const Index m_size, const Index n_size, const Index k_size) {
__shared__ float2 lhs_shmem[32][16];
__shared__ float2 rhs_shmem[64][8];
const Index m_block_idx = blockIdx.x;
const Index n_block_idx = blockIdx.y;
const Index base_m = 64 * m_block_idx;
const Index base_n = 64 * n_block_idx;
if (base_m + 63 < m_size) {
if (base_n + 63 < n_size) {
EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, false>(
lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
} else {
EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, false, true>(
lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
}
} else {
if (base_n + 63 < n_size) {
EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, false>(
lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
} else {
EigenFloatContractionKernelInternal16x16<Index, LhsMapper, RhsMapper, OutputMapper, true, true>(
lhs, rhs, output, lhs_shmem, rhs_shmem, m_size, n_size, k_size, base_m, base_n);
}
}
}
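// GPU evaluator for tensor contraction expressions. It zero-fills the output
// and dispatches to one of the kernels above; output kernels other than
// NoOpOutputKernel are rejected at compile time.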
template <typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice>
: public TensorContractionEvaluatorBase<TensorEvaluator<
const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, GpuDevice> > {
typedef GpuDevice Device;
typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
typedef TensorContractionEvaluatorBase<Self> Base;
typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
typedef std::remove_const_t<typename XprType::Scalar> Scalar;
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename PacketType<CoeffReturnType, GpuDevice>::type PacketReturnType;
static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
typedef std::conditional_t<Layout == static_cast<int>(ColMajor), LeftArgType, RightArgType> EvalLeftArgType;
typedef std::conditional_t<Layout == static_cast<int>(ColMajor), RightArgType, LeftArgType> EvalRightArgType;
static constexpr int LDims =
internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
static constexpr int RDims =
internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
static constexpr int ContractDims = internal::array_size<Indices>::value;
typedef array<Index, LDims> left_dim_mapper_t;
typedef array<Index, RDims> right_dim_mapper_t;
typedef array<Index, ContractDims> contract_t;
typedef array<Index, LDims - ContractDims> left_nocontract_t;
typedef array<Index, RDims - ContractDims> right_nocontract_t;
static constexpr int NumDims = LDims + RDims - 2 * ContractDims;
typedef DSizes<Index, NumDims> Dimensions;
typedef std::remove_const_t<typename EvalLeftArgType::Scalar> LhsScalar;
typedef std::remove_const_t<typename EvalRightArgType::Scalar> RhsScalar;
typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
typedef typename LeftEvaluator::Dimensions LeftDimensions;
typedef typename RightEvaluator::Dimensions RightDimensions;
TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) {
EIGEN_STATIC_ASSERT((internal::is_same<OutputKernelType, const NoOpOutputKernel>::value),
GPU_TENSOR_CONTRACTION_DOES_NOT_SUPPORT_OUTPUT_KERNELS);
}
EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
this->m_leftImpl.evalSubExprsIfNeeded(NULL);
this->m_rightImpl.evalSubExprsIfNeeded(NULL);
if (data) {
evalTo(data);
return false;
} else {
this->m_result = static_cast<Scalar*>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
evalTo(this->m_result);
return true;
}
}
void evalTo(Scalar* buffer) const {
if (this->m_lhs_inner_dim_contiguous) {
if (this->m_rhs_inner_dim_contiguous) {
if (this->m_rhs_inner_dim_reordered) {
evalTyped<true, true, true, Unaligned>(buffer);
} else {
evalTyped<true, true, false, Unaligned>(buffer);
}
} else {
if (this->m_rhs_inner_dim_reordered) {
evalTyped<true, false, true, Unaligned>(buffer);
} else {
evalTyped<true, false, false, Unaligned>(buffer);
}
}
} else {
if (this->m_rhs_inner_dim_contiguous) {
if (this->m_rhs_inner_dim_reordered) {
evalTyped<false, true, true, Unaligned>(buffer);
} else {
evalTyped<false, true, false, Unaligned>(buffer);
}
} else {
if (this->m_rhs_inner_dim_reordered) {
evalTyped<false, false, true, Unaligned>(buffer);
} else {
evalTyped<false, false, false, Unaligned>(buffer);
}
}
}
}
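// Generic launcher: one 64x64 output tile per block with 8x8x8 threads, used
// for any scalar type without a dedicated specialization.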
template <typename LhsScalar, typename RhsScalar, typename Index, typename LhsMapper, typename RhsMapper,
typename OutputMapper>
struct LaunchKernels {
static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k,
const GpuDevice& device) {
const Index m_blocks = (m + 63) / 64;
const Index n_blocks = (n + 63) / 64;
const dim3 num_blocks(m_blocks, n_blocks, 1);
const dim3 block_size(8, 8, 8);
LAUNCH_GPU_KERNEL((EigenContractionKernel<Scalar, Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks,
block_size, 0, device, lhs, rhs, output, m, n, k);
}
};
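// Float specialization: the 16x16 kernel is chosen whenever either output
// dimension is below 768 (presumably where it is faster); otherwise the
// 128x64-tile kernel with 8x32 thread blocks is launched.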
template <typename Index, typename LhsMapper, typename RhsMapper, typename OutputMapper>
struct LaunchKernels<float, float, Index, LhsMapper, RhsMapper, OutputMapper> {
static void Run(const LhsMapper& lhs, const RhsMapper& rhs, const OutputMapper& output, Index m, Index n, Index k,
const GpuDevice& device) {
if (m < 768 || n < 768) {
const Index m_blocks = (m + 63) / 64;
const Index n_blocks = (n + 63) / 64;
const dim3 num_blocks(m_blocks, n_blocks, 1);
const dim3 block_size(16, 16, 1);
LAUNCH_GPU_KERNEL((EigenFloatContractionKernel16x16<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks,
block_size, 0, device, lhs, rhs, output, m, n, k);
} else {
const Index m_blocks = (m + 127) / 128;
const Index n_blocks = (n + 63) / 64;
const dim3 num_blocks(m_blocks, n_blocks, 1);
const dim3 block_size(8, 32, 1);
LAUNCH_GPU_KERNEL((EigenFloatContractionKernel<Index, LhsMapper, RhsMapper, OutputMapper>), num_blocks,
block_size, 0, device, lhs, rhs, output, m, n, k);
}
}
};
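// Zero-fill the output, build the lhs/rhs input mappers, select the eight-byte
// shared-memory bank configuration, and launch the appropriate kernel.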
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
void evalTyped(Scalar* buffer) const {
const Index k = this->m_k_size;
EIGEN_UNUSED_VARIABLE(k)
const Index m = this->m_i_size;
const Index n = this->m_j_size;
this->m_device.fill(buffer, buffer + m * n, Scalar(0));
typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
contract_t, 4, lhs_inner_dim_contiguous, false, Unaligned>
LhsMapper;
typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
contract_t, 4, rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
Unaligned>
RhsMapper;
typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
this->m_left_contracting_strides, this->m_k_strides);
RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
this->m_right_contracting_strides, this->m_k_strides);
OutputMapper output(buffer, m);
#if defined(EIGEN_USE_HIP)
setGpuSharedMemConfig(hipSharedMemBankSizeEightByte);
#else
setGpuSharedMemConfig(cudaSharedMemBankSizeEightByte);
#endif
LaunchKernels<LhsScalar, RhsScalar, Index, LhsMapper, RhsMapper, OutputMapper>::Run(lhs, rhs, output, m, n, k,
this->m_device);
}
};
}  // end namespace Eigen
#endif  // EIGEN_USE_GPU && EIGEN_GPUCC
#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_GPU_H