#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
#ifdef EIGEN_USE_THREADS
#include "./InternalHeaderCheck.h"
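// Evaluator for tensor contractions on a ThreadPoolDevice. The contraction is
// decomposed into block packing (lhs/rhs) and GEBP kernel tasks that are
// scheduled on the device's thread pool.
//
// A minimal usage sketch (the pool size and tensor shapes below are
// illustrative, not prescribed by this header):
//
//   Eigen::ThreadPool pool(8);
//   Eigen::ThreadPoolDevice device(&pool, 8);
//   Eigen::Tensor<float, 2> a(64, 32), b(32, 48), c(64, 48);
//   a.setRandom();
//   b.setRandom();
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
//   c.device(device) = a.contract(b, dims);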
namespace Eigen {
template <typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>,
ThreadPoolDevice>
: public TensorContractionEvaluatorBase<TensorEvaluator<
const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice>> {
typedef ThreadPoolDevice Device;
typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
typedef TensorContractionEvaluatorBase<Self> Base;
typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
typedef std::remove_const_t<typename XprType::Scalar> Scalar;
typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
typedef std::conditional_t<static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>
EvalLeftArgType;
typedef std::conditional_t<static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>
EvalRightArgType;
static constexpr int LDims =
internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
static constexpr int RDims =
internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
static constexpr int ContractDims = internal::array_size<Indices>::value;
typedef array<Index, LDims> left_dim_mapper_t;
typedef array<Index, RDims> right_dim_mapper_t;
typedef array<Index, ContractDims> contract_t;
typedef array<Index, LDims - ContractDims> left_nocontract_t;
typedef array<Index, RDims - ContractDims> right_nocontract_t;
static constexpr int NumDims = LDims + RDims - 2 * ContractDims;
typedef DSizes<Index, NumDims> Dimensions;
typedef std::remove_const_t<typename EvalLeftArgType::Scalar> LhsScalar;
typedef std::remove_const_t<typename EvalRightArgType::Scalar> RhsScalar;
typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) {}
template <int Alignment>
void evalProduct(Scalar* buffer) const {
evalProductImpl<NoCallback, Alignment>(buffer, NoCallback());
}
template <typename EvalToCallback, int Alignment>
void evalProductAsync(Scalar* buffer, EvalToCallback done) const {
evalProductImpl<EvalToCallback, Alignment>(buffer, std::move(done));
}
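// Computes the contraction by sharding the output (or, when beneficial, the
// inner dimension) across the thread pool. The main steps are:
//  1. Pick the sharding dimension (rows vs. columns) and the kernel block
//     sizes (bm, bn, bk).
//  2. Estimate how many threads are worth using; trivially small problems fall
//     back to the sequential path, and outputs that are small relative to the
//     inner dimension may be sharded by k instead.
//  3. Coarsen block indices into grains (gm, gn) so each task is large enough
//     to amortize scheduling overhead.
//  4. Dispatch to a synchronous or asynchronous EvalParallelContext, which
//     drives the packing/kernel state machine.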
template <typename DoneCallback, int Alignment>
void evalProductImpl(Scalar* buffer, DoneCallback done) const {
static const bool IsEvalInSyncMode = std::is_same<DoneCallback, NoCallback>::value;
const Index m = this->m_i_size;
const Index n = this->m_j_size;
const Index k = this->m_k_size;
if (m == 0 || n == 0 || k == 0) return;
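// Compute a preliminary blocking (assuming only 2 threads) so that the cost
// model below has realistic block sizes to work with; the blocking is
// recomputed once the actual thread count is known.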
bool shard_by_col = shardByCol(m, n, 2);
Index bm, bn, bk;
if (shard_by_col) {
internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k, m, n,
2);
bm = blocking.mc();
bn = blocking.nc();
bk = blocking.kc();
} else {
internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByRow> blocking(k, m, n,
2);
bm = blocking.mc();
bn = blocking.nc();
bk = blocking.kc();
}
const TensorOpCost cost = contractionCost(m, n, bm, bn, bk, shard_by_col, false);
int num_threads =
TensorCostModel<ThreadPoolDevice>::numThreads(static_cast<double>(n) * m, cost, this->m_device.numThreads());
int num_threads_by_k = numThreadsInnerDim(m, n, k);
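// If the output is small relative to the inner dimension, it can be more
// profitable to shard by k: each task computes a partial contraction into its
// own buffer, and the partial results are summed afterwards.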
if (shardByInnerDim(m, n, k, num_threads, num_threads_by_k)) {
if (IsEvalInSyncMode) {
EvalShardedByInnerDimContext<DoneCallback> ctx(this, num_threads_by_k, buffer, m, n, k, std::move(done));
ctx.template run<Alignment>();
} else {
auto* ctx =
new EvalShardedByInnerDimContext<DoneCallback>(this, num_threads_by_k, buffer, m, n, k, std::move(done));
ctx->template runAsync<Alignment>();
}
return;
}
if (n == 1) num_threads = 1;
if (num_threads == 1) {
TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Unaligned, (buffer));
if (!IsEvalInSyncMode) done();
return;
}
shard_by_col = shardByCol(m, n, num_threads);
if (shard_by_col) {
internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(
k, m, n, num_threads);
bm = blocking.mc();
bn = blocking.nc();
bk = blocking.kc();
} else {
internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar, Index, internal::ShardByRow> blocking(
k, m, n, num_threads);
bm = blocking.mc();
bn = blocking.nc();
bk = blocking.kc();
}
Index nm0 = numext::div_ceil(m, bm);
Index nn0 = numext::div_ceil(n, bn);
Index nk = numext::div_ceil(k, bk);
Index gm = 1;
Index gn = 1;
if (shard_by_col) {
gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
} else {
gn = coarsenN(m, n, bm, bn, bk, gm, num_threads, shard_by_col);
gm = coarsenM(m, n, bm, bn, bk, gn, num_threads, shard_by_col);
}
Index nm = numext::div_ceil(nm0, gm);
Index nn = numext::div_ceil(nn0, gn);
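// Decide whether to parallelize only along the sharding dimension. This keeps
// the packed blocks of the non-sharded side thread-local, which reduces
// synchronization, but it only pays off when there are enough tasks along the
// sharding dimension to keep all worker threads busy.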
const Index sharding_dim_tasks = shard_by_col ? nn : nm;
const int num_worker_threads = this->m_device.numThreadsInPool();
const float oversharding_factor = num_worker_threads <= 4 ? 8.0
: num_worker_threads <= 8 ? 4.0
: num_worker_threads <= 16 ? 2.0
: num_worker_threads <= 32 ? 1.0
: num_worker_threads <= 64 ? 0.8
: 0.6;
const bool parallelize_by_sharding_dim_only = sharding_dim_tasks >= oversharding_factor * num_worker_threads;
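// Decide whether lhs and rhs blocks should be packed in parallel: do so when
// there are more threads than kernel tasks, or when all packed blocks of one
// k-slice fit into the combined L2 caches. Parallel packing is disabled again
// when the non-sharding dimension has a single task, or when parallelizing by
// the sharding dimension only.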
bool parallel_pack = num_threads >= nm * nn;
if (m * bk * Index(sizeof(LhsScalar)) + n * bk * Index(sizeof(RhsScalar)) <= l2CacheSize() * num_threads)
parallel_pack = true;
if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
if (parallelize_by_sharding_dim_only) parallel_pack = false;
if (IsEvalInSyncMode) {
// Constructor arguments for the parallel evaluation context, followed by the
// call that starts (and waits for) evaluation.
#define CONTEXT_ARGS                                                                                          \
  (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack, \
   parallelize_by_sharding_dim_only, NoCallback())                                                            \
      .run()
TENSOR_CONTRACTION_DISPATCH(SyncEvalParallelContext, Alignment, CONTEXT_ARGS);
#undef CONTEXT_ARGS
} else {
// Constructor arguments for the heap-allocated asynchronous context; run() is
// invoked by the dispatch macro.
#define CONTEXT_ARGS                                                                                          \
  (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, shard_by_col, parallel_pack, \
   parallelize_by_sharding_dim_only, std::move(done))
TENSOR_CONTRACTION_ASYNC_DISPATCH(EvalParallelContext, DoneCallback, Alignment, CONTEXT_ARGS, run());
#undef CONTEXT_ARGS
}
}
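// Tag type used as the "done" callback in synchronous mode; it must never be
// invoked, because synchronous evaluation waits on a notification instead.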
struct NoCallback {
void operator()() { eigen_assert(false && "NoCallback should never be called"); }
};
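// Notification used by EvalParallelContext to report completion. In
// synchronous mode (NoCallback) it wraps an Eigen::Notification that run()
// waits on; in asynchronous mode it deletes the context and then invokes the
// user-provided done callback.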
template <typename DoneCallback, typename Context>
class EvalParallelNotification;
template <typename Context>
class EvalParallelNotification<NoCallback, Context> {
public:
EvalParallelNotification(Context*, NoCallback) {}
void Notify() { done_.Notify(); }
void Wait() { done_.Wait(); }
private:
Eigen::Notification done_;
};
template <typename DoneCallback, typename Context>
class EvalParallelNotification {
public:
EvalParallelNotification(Context* ctx, DoneCallback done) : ctx_(ctx), done_(std::move(done)) {}
void Notify() {
DoneCallback done_copy = std::move(done_);
delete ctx_;
done_copy();
}
void Wait() {}
private:
Context* ctx_;
DoneCallback done_;
};
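// Context orchestrating the parallel evaluation. Work proceeds over slices of
// the inner dimension: for each k-slice the lhs and rhs blocks are packed, and
// GEBP kernels accumulate the corresponding partial products into the output.
// Per-slice atomic counters (state_packing_ready_, state_kernel_,
// state_switch_) track when packing may start, when a kernel's inputs are
// ready, and when the next slice may begin.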
template <typename DoneCallback, bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered, int Alignment>
class EvalParallelContext {
public:
typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
contract_t, internal::packet_traits<LhsScalar>::size,
lhs_inner_dim_contiguous, false, Unaligned>
LhsMapper;
typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
contract_t, internal::packet_traits<RhsScalar>::size,
rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
RhsMapper;
typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
typedef internal::TensorContractionKernel<Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
TensorContractionKernel;
typedef typename TensorContractionKernel::LhsBlock LhsBlock;
typedef typename TensorContractionKernel::RhsBlock RhsBlock;
typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
EvalParallelContext(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn, Index tk, Index bm,
Index bn, Index bk, Index nm, Index nn, Index nk, Index gm, Index gn, Index nm0, Index nn0,
bool shard_by_col, bool parallel_pack, bool parallelize_by_sharding_dim_only, DoneCallback done)
: created_by_thread_id_(std::this_thread::get_id()),
done_(this, std::move(done)),
device_(self->m_device),
lhs_(self->m_leftImpl, self->m_left_nocontract_strides, self->m_i_strides, self->m_left_contracting_strides,
self->m_k_strides),
rhs_(self->m_rightImpl, self->m_right_nocontract_strides, self->m_j_strides,
self->m_right_contracting_strides, self->m_k_strides),
buffer_(buffer),
output_(buffer, tm),
output_kernel_(self->m_output_kernel),
tensor_contraction_params_(self->m_tensor_contraction_params),
num_threads_(num_threads),
shard_by_col_(shard_by_col),
parallel_pack_(parallel_pack),
parallelize_by_sharding_dim_only_(parallelize_by_sharding_dim_only),
m_(tm),
n_(tn),
k_(tk),
bm_(bm),
bn_(bn),
bk_(bk),
nm_(nm),
nn_(nn),
nk_(nk),
gm_(gm),
gn_(gn),
nm0_(nm0),
nn0_(nn0),
kernel_(m_, k_, n_, bm_, bk_, bn_),
num_thread_local_allocations_(0),
thread_local_capacity(2 * (parallelize_by_sharding_dim_only_ ? device_.numThreadsInPool() : 0)),
lhs_thread_local_blocks_(shard_by_col_ ? 0 : thread_local_capacity, {*this}, {*this}),
rhs_thread_local_blocks_(shard_by_col_ ? thread_local_capacity : 0, {*this}, {*this}) {
eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only));
for (Index x = 0; x < P; x++) {
state_switch_[x] =
x == 0 ? 1 : (parallel_pack_ ? nn_ + nm_ : (shard_by_col_ ? nn_ : nm_)) + (x == P - 1 ? nm_ * nn_ : 0);
state_packing_ready_[x] = parallel_pack_ ? 0 : (shard_by_col_ ? nm_ : nn_);
state_kernel_[x] = new std::atomic<uint8_t>*[nm_];
for (Index m = 0; m < nm_; m++) {
state_kernel_[x][m] = new std::atomic<uint8_t>[nn_];
for (Index n = 0; n < nn_; n++)
state_kernel_[x][m][n].store((x == 0 ? 0 : 1) + (parallel_pack_ ? 2 : 1), std::memory_order_relaxed);
}
}
// Allocate memory for the packed lhs/rhs blocks; slice storage is reused with
// period P - 1.
packed_mem_ = kernel_.allocateSlices(device_,
    /*num_lhs=*/nm0_,
    /*num_rhs=*/nn0_,
    /*num_slices=*/std::min<Index>(nk_, P - 1),
    packed_lhs_, packed_rhs_);
if (parallelize_by_sharding_dim_only_) {
const int num_worker_threads = device_.numThreadsInPool();
if (shard_by_col) {
can_use_thread_local_packed_ = new std::atomic<bool>[nn_];
for (int i = 0; i < nn_; ++i) can_use_thread_local_packed_[i].store(true, std::memory_order_relaxed);
Index num_blocks = num_worker_threads * gn_;
thread_local_pre_alocated_mem_ = kernel_.allocateSlices(device_,
    /*num_lhs=*/0,
    /*num_rhs=*/num_blocks,
    /*num_slices=*/1,
    /*lhs_blocks=*/nullptr, &rhs_thread_local_pre_allocated_);
} else {
can_use_thread_local_packed_ = new std::atomic<bool>[nm_];
for (int i = 0; i < nm_; ++i) can_use_thread_local_packed_[i].store(true, std::memory_order_relaxed);
Index num_blocks = num_worker_threads * gm_;
thread_local_pre_alocated_mem_ = kernel_.allocateSlices(device_,
    /*num_lhs=*/num_blocks,
    /*num_rhs=*/0,
    /*num_slices=*/1, &lhs_thread_local_pre_allocated_,
    /*rhs_blocks=*/nullptr);
}
}
}
~EvalParallelContext() {
for (Index x = 0; x < P; x++) {
for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
delete[] state_kernel_[x];
}
kernel_.deallocate(device_, packed_mem_);
if (parallelize_by_sharding_dim_only_) {
kernel_.deallocate(device_, thread_local_pre_alocated_mem_);
delete[] can_use_thread_local_packed_;
}
}
void run() {
signal_switch(0, 1);
done_.Wait();
}
private:
std::thread::id created_by_thread_id_;
EvalParallelNotification<DoneCallback, EvalParallelContext> done_;
const Device& device_;
LhsMapper lhs_;
RhsMapper rhs_;
Scalar* const buffer_;
OutputMapper output_;
OutputKernelType output_kernel_;
TensorContractionParams tensor_contraction_params_;
const int num_threads_;
const bool shard_by_col_;
const bool parallel_pack_;
const bool parallelize_by_sharding_dim_only_;
const Index m_;
const Index n_;
const Index k_;
const Index bm_;
const Index bn_;
const Index bk_;
const Index nm_;
const Index nn_;
const Index nk_;
const Index gm_;
const Index gn_;
const Index nm0_;
const Index nn0_;
TensorContractionKernel kernel_;
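// Number of k-slice "generations" tracked by the atomic counters below.
// Packing of slice k+1 may overlap with kernels of slice k, and kernels of
// slice k also pre-signal the switch counter of slice k+2, so three
// generations of counters are needed; packed block memory is reused with
// period P - 1.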
static constexpr Index P = 3;
BlockMemHandle packed_mem_;
std::vector<LhsBlock> packed_lhs_[P - 1];
std::vector<RhsBlock> packed_rhs_[P - 1];
BlockMemHandle thread_local_pre_alocated_mem_;
std::vector<LhsBlock> lhs_thread_local_pre_allocated_;
std::vector<RhsBlock> rhs_thread_local_pre_allocated_;
std::atomic<int> num_thread_local_allocations_;
const int thread_local_capacity;
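// Storage for packed blocks owned by a single worker thread, used when
// parallelizing only by the sharding dimension. Blocks either point into the
// pre-allocated per-thread region or own their memory through a
// BlockMemHandle allocated on first use.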
template <typename BlockType>
class ThreadLocalBlocks {
public:
ThreadLocalBlocks() = default;
ThreadLocalBlocks(BlockType* base, size_t grain_size)
: is_pre_allocated_(true), thread_local_pre_allocated_base_(base), grain_size_(grain_size) {}
ThreadLocalBlocks(BlockMemHandle mem_handle, std::vector<BlockType> blocks)
: is_pre_allocated_(false), mem_handle_(std::move(mem_handle)), blocks_(std::move(blocks)) {}
BlockType& block(int grain_index) {
eigen_assert(grain_index >= 0);
eigen_assert(static_cast<size_t>(grain_index) < size());
return is_pre_allocated_ ? thread_local_pre_allocated_base_[grain_index] : blocks_[grain_index];
}
void Release(EvalParallelContext& ctx) const {
if (!is_pre_allocated_) {
ctx.kernel_.deallocate(ctx.device_, mem_handle_);
}
}
size_t size() const { return is_pre_allocated_ ? grain_size_ : blocks_.size(); }
private:
bool is_pre_allocated_;
BlockType* thread_local_pre_allocated_base_ = nullptr;
size_t grain_size_ = 0;
BlockMemHandle mem_handle_{};
std::vector<BlockType> blocks_;
};
template <typename BlockType, bool is_rhs>
class ThreadLocalBlocksInitialize {
static constexpr bool kIsLhs = !is_rhs && std::is_same<BlockType, LhsBlock>::value;
static constexpr bool kIsRhs = is_rhs && std::is_same<BlockType, RhsBlock>::value;
static_assert(kIsLhs || kIsRhs, "Unknown block type");
using Blocks = ThreadLocalBlocks<BlockType>;
public:
ThreadLocalBlocksInitialize(EvalParallelContext& ctx)
: ctx_(ctx), num_worker_threads_(ctx_.device_.numThreadsInPool()) {}
void operator()(Blocks& blocks) {
const int n = ctx_.num_thread_local_allocations_.fetch_add(1, std::memory_order_relaxed);
if (n >= num_worker_threads_) {
ThreadLocalBlocksAllocator<is_rhs>::allocate(ctx_, blocks);
} else {
ThreadLocalBlocksAllocator<is_rhs>::reuse(ctx_, n, blocks);
}
}
private:
template <bool pack_rhs, typename EvalCtx = EvalParallelContext>
struct ThreadLocalBlocksAllocator;
template <typename EvalCtx>
struct ThreadLocalBlocksAllocator<true, EvalCtx> {
static void allocate(EvalCtx& ctx, Blocks& blocks) {
std::vector<RhsBlock> rhs_blocks;
BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(ctx.device_,
    /*num_lhs=*/0,
    /*num_rhs=*/ctx.gn_,
    /*num_slices=*/1,
    /*lhs_blocks=*/nullptr, &rhs_blocks);
blocks = ThreadLocalBlocks<RhsBlock>(std::move(mem_handle), std::move(rhs_blocks));
}
static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
RhsBlock* ptr = &ctx.rhs_thread_local_pre_allocated_[ctx.gn_ * index];
blocks = ThreadLocalBlocks<RhsBlock>(ptr, ctx.gn_);
}
};
template <typename EvalCtx>
struct ThreadLocalBlocksAllocator<false, EvalCtx> {
static void allocate(EvalCtx& ctx, Blocks& blocks) {
std::vector<LhsBlock> lhs_blocks;
BlockMemHandle mem_handle = ctx.kernel_.allocateSlices(ctx.device_,
    /*num_lhs=*/ctx.gm_,
    /*num_rhs=*/0,
    /*num_slices=*/1, &lhs_blocks,
    /*rhs_blocks=*/nullptr);
blocks = ThreadLocalBlocks<LhsBlock>(std::move(mem_handle), std::move(lhs_blocks));
}
static void reuse(EvalCtx& ctx, int index, Blocks& blocks) {
LhsBlock* ptr = &ctx.lhs_thread_local_pre_allocated_[ctx.gm_ * index];
blocks = ThreadLocalBlocks<LhsBlock>(ptr, ctx.gm_);
}
};
EvalParallelContext& ctx_;
const int num_worker_threads_;
};
template <typename BlockType>
class ThreadLocalBlocksRelease {
public:
using Blocks = ThreadLocalBlocks<BlockType>;
ThreadLocalBlocksRelease(EvalParallelContext& ctx) : ctx_(ctx) {}
void operator()(Blocks& blocks) { blocks.Release(ctx_); }
private:
EvalParallelContext& ctx_;
};
using ThreadLocalLhsInit = ThreadLocalBlocksInitialize<LhsBlock, false>;
using ThreadLocalRhsInit = ThreadLocalBlocksInitialize<RhsBlock, true>;
using ThreadLocalLhsRelease = ThreadLocalBlocksRelease<LhsBlock>;
using ThreadLocalRhsRelease = ThreadLocalBlocksRelease<RhsBlock>;
Eigen::ThreadLocal<ThreadLocalBlocks<LhsBlock>, ThreadLocalLhsInit, ThreadLocalLhsRelease> lhs_thread_local_blocks_;
Eigen::ThreadLocal<ThreadLocalBlocks<RhsBlock>, ThreadLocalRhsInit, ThreadLocalRhsRelease> rhs_thread_local_blocks_;
std::atomic<bool>* can_use_thread_local_packed_;
std::atomic<uint8_t>** state_kernel_[P];
char pad_[128];  // Keep the frequently modified atomic counters below on a separate cache line from the read-mostly fields above.
std::atomic<Index> state_packing_ready_[P];
std::atomic<Index> state_switch_[P];
LhsBlock& packed_lhs(Index m, Index k, Index m1, bool use_thread_local) {
if (use_thread_local) {
eigen_assert(!shard_by_col_);
ThreadLocalBlocks<LhsBlock>& blocks = lhs_thread_local_blocks_.local();
Index grain_index = m1 - m * gm_;
return blocks.block(internal::convert_index<int>(grain_index));
} else {
return packed_lhs_[k % (P - 1)][m1];
}
}
RhsBlock& packed_rhs(Index n, Index k, Index n1, bool use_thread_local) {
if (use_thread_local) {
eigen_assert(shard_by_col_);
ThreadLocalBlocks<RhsBlock>& blocks = rhs_thread_local_blocks_.local();
Index grain_index = n1 - n * gn_;
return blocks.block(internal::convert_index<int>(grain_index));
} else {
return packed_rhs_[k % (P - 1)][n1];
}
}
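// Packs all lhs blocks of grain m for slice k, then signals the dependent
// kernels (and, depending on the configuration, the packing of the next
// slice). When parallelizing only by rows, the packed blocks may be placed in
// thread-local storage, provided no kernel can still be reading the
// previously packed thread-local data for this grain.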
void pack_lhs(Index m, Index k) {
bool use_thread_local = false;
if (parallelize_by_sharding_dim_only_ && !shard_by_col_ &&
can_use_thread_local_packed_[m].load(std::memory_order_relaxed)) {
if (state_kernel_[k % P][m][0].load(std::memory_order_relaxed) == 1) {
use_thread_local = true;
} else {
eigen_assert(k > 0);
can_use_thread_local_packed_[m].store(false, std::memory_order_relaxed);
}
}
const Index mend = m * gm_ + gm(m);
for (Index m1 = m * gm_; m1 < mend; m1++)
kernel_.packLhs(&packed_lhs(m, k, m1, use_thread_local), lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1));
if (!parallel_pack_ && shard_by_col_) {
eigen_assert(!use_thread_local);
signal_packing(k);
} else {
signal_switch(k + 1);
for (Index n = nn_ - 1; n >= 0; n--) {
bool sync = parallelize_by_sharding_dim_only_ || n == 0;
signal_kernel(m, n, k, sync, use_thread_local);
}
}
}
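// Same as pack_lhs, but for rhs blocks of grain n. When the kernel does not
// support beta, the corresponding output columns are zeroed on the first
// slice so that kernels can accumulate unconditionally.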
void pack_rhs(Index n, Index k) {
bool use_thread_local = false;
if (parallelize_by_sharding_dim_only_ && shard_by_col_ &&
can_use_thread_local_packed_[n].load(std::memory_order_relaxed)) {
if (state_kernel_[k % P][0][n].load(std::memory_order_relaxed) == 1) {
use_thread_local = true;
} else {
eigen_assert(k > 0);
can_use_thread_local_packed_[n].store(false, std::memory_order_relaxed);
}
}
const Index nend = n * gn_ + gn(n);
for (Index n1 = n * gn_; n1 < nend; n1++) {
if (!TensorContractionKernel::HasBeta && k == 0) {
std::fill_n(buffer_ + n1 * bn_ * m_, bn(n1) * m_, Scalar(0));
}
kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local), rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
}
if (parallel_pack_ || shard_by_col_) {
signal_switch(k + 1);
for (Index m = nm_ - 1; m >= 0; m--) {
bool sync = parallelize_by_sharding_dim_only_ || m == 0;
signal_kernel(m, n, k, sync, use_thread_local);
}
} else {
eigen_assert(!use_thread_local);
signal_packing(k);
}
}
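// Runs the GEBP kernels for grain (m, n) of slice k, accumulating into the
// output (beta is 0 on the first slice when the kernel supports it, so the
// output does not need to be pre-zeroed). The output kernel is applied after
// the last slice. The loop nest is ordered so that the innermost loop walks
// the non-sharded dimension.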
void kernel(Index m, Index n, Index k, bool use_thread_local) {
const Index nend = n * gn_ + gn(n);
const Index mend = m * gm_ + gm(m);
const Scalar alpha = Scalar(1);
const Scalar beta = (TensorContractionKernel::HasBeta && k == 0) ? Scalar(0) : Scalar(1);
if (shard_by_col_) {
for (Index n1 = n * gn_; n1 < nend; n1++) {
for (Index m1 = m * gm_; m1 < mend; m1++) {
const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
kernel_.invoke(output_mapper, packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1), bk(k), bn(n1), alpha, beta);
if (k + 1 == nk_) {
output_kernel_(output_mapper, tensor_contraction_params_, m1 * bm_, n1 * bn_, bm(m1), bn(n1));
}
}
}
} else {
for (Index m1 = m * gm_; m1 < mend; m1++)
for (Index n1 = n * gn_; n1 < nend; n1++) {
const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
kernel_.invoke(output_mapper, packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1), bk(k), bn(n1), alpha, beta);
if (k + 1 == nk_) {
output_kernel_(output_mapper, tensor_contraction_params_, m1 * bm_, n1 * bn_, bm(m1), bn(n1));
}
}
}
signal_kernel(m, n, k + 1, false, false);
signal_switch(k + 2);
}
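// The signal_* helpers decrement the per-slice atomic counters and, when a
// counter reaches zero, reset it for the next generation and schedule the
// work that has just become ready (packing, a kernel, or the switch to the
// next slice).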
void signal_packing(Index k) {
eigen_assert(!parallel_pack_);
Index s = state_packing_ready_[k % P].fetch_sub(1);
eigen_assert(s > 0);
if (s != 1) return;
state_packing_ready_[k % P] = shard_by_col_ ? nm_ : nn_;
enqueue_packing(k, shard_by_col_);
}
void signal_kernel(Index m, Index n, Index k, bool sync, bool use_thread_local) {
std::atomic<uint8_t>* state = &state_kernel_[k % P][m][n];
Index s = state->load();
eigen_assert(s > 0);
if (s != 1 && state->fetch_sub(1) != 1) {
eigen_assert(!use_thread_local);
return;
}
state->store(parallel_pack_ ? 3 : 2, std::memory_order_relaxed);
if (sync) {
kernel(m, n, k, use_thread_local);
} else {
eigen_assert(!use_thread_local);
device_.enqueueNoNotification([=]() { kernel(m, n, k, use_thread_local); });
}
}
void signal_switch(Index k, Index v = 1) {
Index s = state_switch_[k % P].fetch_sub(v);
eigen_assert(s >= v);
if (s != v) return;
state_switch_[k % P] = (parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_)) + nm_ * nn_;
if (k < nk_) {
if (parallel_pack_) {
enqueue_packing(k, !shard_by_col_);
enqueue_packing(k, shard_by_col_);
} else if (shard_by_col_) {
enqueue_packing(k, false);
} else {
enqueue_packing(k, true);
}
} else if (k == nk_) {
signal_switch(k + 1, parallel_pack_ ? nm_ + nn_ : (shard_by_col_ ? nn_ : nm_));
} else {
done_.Notify();
}
}
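// Enqueues packing of all grains along one dimension for slice k by
// recursively splitting the [start, end) range in half, so tasks fan out
// across the pool instead of being enqueued from a single loop.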
void enqueue_packing(Index k, bool rhs) { enqueue_packing_helper(0, rhs ? nn_ : nm_, k, rhs); }
void enqueue_packing_helper(Index start, Index end, Index k, bool rhs) {
if (end - start == 1) {
if (rhs)
pack_rhs(start, k);
else
pack_lhs(start, k);
} else {
while (end - start > 1) {
Index mid = (start + end) / 2;
device_.enqueueNoNotification([=]() { enqueue_packing_helper(mid, end, k, rhs); });
end = mid;
}
bool pack_async = (start == 0) && (parallelize_by_sharding_dim_only_ && shard_by_col_ == rhs) &&
(k > 0 || std::this_thread::get_id() == created_by_thread_id_);
if (pack_async) {
device_.enqueueNoNotification([=]() { enqueue_packing_helper(start, end, k, rhs); });
} else {
enqueue_packing_helper(start, end, k, rhs);
}
}
}
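// Sizes of the possibly partial last block (bm/bn/bk) and last grain (gm/gn)
// along each dimension.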
Index bm(Index m) const { return m + 1 < nm0_ ? bm_ : m_ + bm_ - bm_ * nm0_; }
Index bn(Index n) const { return n + 1 < nn0_ ? bn_ : n_ + bn_ - bn_ * nn0_; }
Index bk(Index k) const { return k + 1 < nk_ ? bk_ : k_ + bk_ - bk_ * nk_; }
Index gm(Index m) const { return m + 1 < nm_ ? gm_ : nm0_ + gm_ - gm_ * nm_; }
Index gn(Index n) const { return n + 1 < nn_ ? gn_ : nn0_ + gn_ - gn_ * nn_; }
EvalParallelContext(const EvalParallelContext&) = delete;
void operator=(const EvalParallelContext&) = delete;
};
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
using SyncEvalParallelContext = EvalParallelContext<NoCallback, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
rhs_inner_dim_reordered, Alignment>;
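// Context for evaluation sharded by the inner (contraction) dimension. The k
// range is split into blocks; each block computes a partial GEMM into its own
// buffer (the first block writes directly into the result), the partial
// buffers are then reduced in a two-level tree, and the output kernel is
// applied at the end.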
template <typename DoneCallback>
struct EvalShardedByInnerDimContext {
EvalShardedByInnerDimContext(const Self* self, int num_threads, Scalar* result_buffer, Index m_size, Index n_size,
Index k_size, DoneCallback done_callback)
: evaluator(self),
m_lhs_inner_dim_contiguous(evaluator->m_lhs_inner_dim_contiguous),
m_rhs_inner_dim_contiguous(evaluator->m_rhs_inner_dim_contiguous),
m_rhs_inner_dim_reordered(evaluator->m_rhs_inner_dim_reordered),
result(result_buffer),
m(m_size),
n(n_size),
k(k_size),
done(std::move(done_callback)),
buffer_size_bytes(m * n * sizeof(Scalar)),
block_size(blockSize(k, num_threads)),
num_blocks(numext::div_ceil<Index>(k, block_size)),
num_pending_blocks(internal::convert_index<int>(num_blocks)),
l0_ranges(numext::div_ceil<Index>(num_blocks, l0_size)),
l0_state(l0_ranges),
block_buffers(num_blocks) {
for (int i = 0; i < l0_ranges; ++i) {
const Index num_pending_tasks = actualRangeSize(l0_ranges, l0_size, i);
l0_state.emplace_back(internal::convert_index<int>(num_pending_tasks));
}
for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
Scalar* buf = block_idx == 0 ? result : static_cast<Scalar*>(evaluator->m_device.allocate(buffer_size_bytes));
block_buffers.emplace_back(buf);
}
}
~EvalShardedByInnerDimContext() {
for (Index i = 1; i < num_blocks; ++i) {
evaluator->m_device.deallocate(block_buffers[i]);
}
}
template <int Alignment>
void run() {
Barrier barrier(internal::convert_index<int>(num_blocks));
eval<Alignment>(barrier, 0, num_blocks);
barrier.Wait();
aggregateL0Blocks<Alignment>();
applyOutputKernel();
}
template <int Alignment>
void runAsync() {
evalAsync<Alignment>(0, num_blocks);
}
private:
static const Index packet_size = internal::packet_traits<RhsScalar>::size;
const Self* evaluator;
bool m_lhs_inner_dim_contiguous;
bool m_rhs_inner_dim_contiguous;
bool m_rhs_inner_dim_reordered;
Scalar* result;
Index m;
Index n;
Index k;
DoneCallback done;
Index buffer_size_bytes;
Index block_size;
Index num_blocks;
std::atomic<int> num_pending_blocks;
static const Index l0_size = 4;
Index l0_ranges;
MaxSizeVector<std::atomic<int>> l0_state;
MaxSizeVector<Scalar*> block_buffers;
template <int Alignment>
void processBlock(Index block_idx, Index begin, Index end) {
Scalar* buf = block_buffers[block_idx];
TENSOR_CONTRACTION_DISPATCH(evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,
(buf, begin, end, internal::convert_index<int>(num_blocks)));
const Index l0_index = block_idx / l0_size;
const int v = l0_state[l0_index].fetch_sub(1);
eigen_assert(v >= 1);
if (v == 1) {
const Index rng_size = actualRangeSize(l0_ranges, l0_size, l0_index);
const Index dst_block_idx = l0_index * l0_size;
if (rng_size == l0_size) {
addAllToBuffer<Alignment>(m * n,
    /*src_buf0=*/block_buffers[dst_block_idx + 1],
    /*src_buf1=*/block_buffers[dst_block_idx + 2],
    /*src_buf2=*/block_buffers[dst_block_idx + 3],
    /*dst_buf=*/block_buffers[dst_block_idx]);
} else {
for (int i = 1; i < rng_size; ++i) {
addToBuffer<Alignment>(m * n,
    /*src_buf=*/block_buffers[dst_block_idx + i],
    /*tgt_buf=*/block_buffers[dst_block_idx]);
}
}
}
}
template <int Alignment>
void aggregateL0Blocks() const {
Index l0_index = 1;
for (; l0_index + 2 < l0_ranges; l0_index += 3) {
addAllToBuffer<Alignment>(m * n,
    /*src_buf0=*/block_buffers[(l0_index + 0) * l0_size],
    /*src_buf1=*/block_buffers[(l0_index + 1) * l0_size],
    /*src_buf2=*/block_buffers[(l0_index + 2) * l0_size],
    /*dst_buf=*/block_buffers[0]);
}
for (; l0_index < l0_ranges; ++l0_index) {
addToBuffer<Alignment>(m * n, block_buffers[l0_index * l0_size], block_buffers[0]);
}
}
void applyOutputKernel() const {
typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
evaluator->m_output_kernel(OutputMapper(result, m), evaluator->m_tensor_contraction_params,
static_cast<Eigen::Index>(0), static_cast<Eigen::Index>(0), m, n);
}
Index actualBlockSize(Index block_idx) const {
return block_idx + 1 < num_blocks ? block_size : k + block_size - block_size * num_blocks;
}
Index actualRangeSize(Index num_ranges, Index range_size, Index range_idx) const {
eigen_assert(range_idx < num_ranges);
return range_idx + 1 < num_ranges ? range_size : num_blocks + range_size - range_size * num_ranges;
}
template <int Alignment>
EIGEN_STRONG_INLINE static void addToBuffer(size_t n, const Scalar* src_buf, Scalar* tgt_buf) {
const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
size_t i = 0;
const size_t num_packets = n / output_packet_size;
for (; i < output_packet_size * num_packets; i += output_packet_size) {
const PacketReturnType src_val = internal::pload<PacketReturnType>(src_buf + i);
const PacketReturnType tgt_val = internal::ploadt<PacketReturnType, Alignment>(tgt_buf + i);
const PacketReturnType sum = internal::padd(src_val, tgt_val);
internal::pstoret<Scalar, PacketReturnType, Alignment>(tgt_buf + i, sum);
}
for (; i < n; ++i) {
tgt_buf[i] += src_buf[i];
}
}
template <int Alignment>
EIGEN_STRONG_INLINE static void addAllToBuffer(size_t n, const Scalar* src_buf0, const Scalar* src_buf1,
const Scalar* src_buf2, Scalar* dst_buf) {
using ::Eigen::internal::padd;
using ::Eigen::internal::pload;
using ::Eigen::internal::ploadt;
using ::Eigen::internal::pstoret;
const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
size_t i = 0;
const size_t num_packets = n / output_packet_size;
for (; i < output_packet_size * num_packets; i += output_packet_size) {
const auto src_val0 = pload<PacketReturnType>(src_buf0 + i);
const auto src_val1 = pload<PacketReturnType>(src_buf1 + i);
const auto src_val2 = pload<PacketReturnType>(src_buf2 + i);
const auto dst_val = ploadt<PacketReturnType, Alignment>(dst_buf + i);
const auto sum = padd(padd(dst_val, src_val0), padd(src_val1, src_val2));
pstoret<Scalar, PacketReturnType, Alignment>(dst_buf + i, sum);
}
for (; i < n; ++i) {
dst_buf[i] += src_buf0[i] + src_buf1[i] + src_buf2[i];
}
}
template <int Alignment>
void eval(Barrier& barrier, Index start_block_idx, Index end_block_idx) {
while (end_block_idx - start_block_idx > 1) {
Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
evaluator->m_device.enqueueNoNotification([this, &barrier, mid_block_idx, end_block_idx]() {
eval<Alignment>(barrier, mid_block_idx, end_block_idx);
});
end_block_idx = mid_block_idx;
}
Index block_idx = start_block_idx;
Index block_start = block_idx * block_size;
Index block_end = block_start + actualBlockSize(block_idx);
processBlock<Alignment>(block_idx, block_start, block_end);
barrier.Notify();
}
template <int Alignment>
void evalAsync(Index start_block_idx, Index end_block_idx) {
while (end_block_idx - start_block_idx > 1) {
Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
evaluator->m_device.enqueueNoNotification(
[this, mid_block_idx, end_block_idx]() { evalAsync<Alignment>(mid_block_idx, end_block_idx); });
end_block_idx = mid_block_idx;
}
Index block_idx = start_block_idx;
Index block_start = block_idx * block_size;
Index block_end = block_start + actualBlockSize(block_idx);
processBlock<Alignment>(block_idx, block_start, block_end);
int v = num_pending_blocks.fetch_sub(1);
eigen_assert(v >= 1);
if (v == 1) {
aggregateL0Blocks<Alignment>();
applyOutputKernel();
DoneCallback done_copy = std::move(done);
delete this;
done_copy();
}
}
static Index blockSize(Index k, int num_threads) {
const auto round_up = [=](Index index) -> Index {
const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
return numext::div_ceil<Index>(index, kmultiple) * kmultiple;
};
const Index target_block_size = round_up(numext::div_ceil<Index>(k, num_threads));
const Index desired_min_block_size = 12 * packet_size;
return numext::mini<Index>(k, numext::maxi<Index>(desired_min_block_size, target_block_size));
}
EvalShardedByInnerDimContext(const EvalShardedByInnerDimContext&) = delete;
void operator=(const EvalShardedByInnerDimContext&) = delete;
};
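// Heuristic choice of the sharding dimension: prefer sharding by columns
// unless rows are plentiful and columns are scarce relative to the GEBP
// register block width (Traits::nr), or the matrix is extremely tall and
// skinny.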
static bool shardByCol(Index m, Index n, Index num_threads) {
if (m / num_threads >= Traits::nr &&
(n / num_threads < Traits::nr ||
(n / num_threads < 4 * Traits::nr && (n % (num_threads * Traits::nr)) != 0 &&
((m % (num_threads * Traits::nr)) == 0 ||
(m / n >= 6)))))
return false;
if (n / num_threads < 16 * Traits::nr && m > n * 32) return false;
return true;
}
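// Grain-size coarsening: grow the number of blocks per task (gm or gn) while
// tasks remain too small for the cost model, stopping once they become too
// large, and only accepting a larger grain if it does not hurt the achievable
// parallelism.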
Index coarsenM(Index m, Index n, Index bm, Index bn, Index bk, Index gn, int num_threads, bool shard_by_col) const {
Index gm = 1;
Index gm1 = 1;
Index nm0 = numext::div_ceil(m, bm);
Index nm1 = nm0;
for (;;) {
while (gm1 <= nm0 && nm1 == numext::div_ceil(nm0, gm1)) gm1++;
if (gm1 > nm0) break;
int res = checkGrain(m, n, bm, bn, bk, gm1, gn, gm, gn, num_threads, shard_by_col);
if (res < 0) break;
nm1 = numext::div_ceil(nm0, gm1);
if (res == 0) continue;
gm = gm1;
}
return gm;
}
Index coarsenN(Index m, Index n, Index bm, Index bn, Index bk, Index gm, int num_threads, bool shard_by_col) const {
Index gn = 1;
Index gn1 = 1;
Index nn0 = numext::div_ceil(n, bn);
Index nn1 = nn0;
for (;;) {
while (gn1 <= nn0 && nn1 == numext::div_ceil(nn0, gn1)) gn1++;
if (gn1 > nn0) break;
int res = checkGrain(m, n, bm, bn, bk, gm, gn1, gm, gn, num_threads, shard_by_col);
if (res < 0) break;
nn1 = numext::div_ceil(nn0, gn1);
if (res == 0) continue;
gn = gn1;
}
return gn;
}
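// Returns 1 if the coarser grain (gm, gn) should be accepted (the task is
// still too small for the cost model, or it improves parallelism), -1 if the
// task has become too large and coarsening must stop, and 0 if the grain is
// acceptable in size but does not improve parallelism.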
int checkGrain(Index m, Index n, Index bm, Index bn, Index bk, Index gm, Index gn, Index oldgm, Index oldgn,
int num_threads, bool shard_by_col) const {
const TensorOpCost cost = contractionCost(bm * gm, bn * gn, bm, bn, bk, shard_by_col, true);
double taskSize = TensorCostModel<ThreadPoolDevice>::taskSize(static_cast<double>(bm) * gm * bn * gn, cost);
if (taskSize < 1) return 1;
if (taskSize > 2) return -1;
Index nm0 = numext::div_ceil(m, bm);
Index nn0 = numext::div_ceil(n, bn);
Index new_tasks = numext::div_ceil(nm0, gm) * numext::div_ceil(nn0, gn);
double new_parallelism =
static_cast<double>(new_tasks) / (numext::div_ceil<Index>(new_tasks, num_threads) * num_threads);
Index old_tasks = numext::div_ceil(nm0, oldgm) * numext::div_ceil(nn0, oldgn);
double old_parallelism =
static_cast<double>(old_tasks) / (numext::div_ceil<Index>(old_tasks, num_threads) * num_threads);
if (new_parallelism > old_parallelism || new_parallelism == 1) return 1;
return 0;
}
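// Estimates the cost of computing an m x n output block with the given kernel
// block sizes: GEBP compute cost per output coefficient, output store cost,
// and (unless the inputs are pre-packed) the amortized cost of loading lhs and
// rhs coefficients.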
TensorOpCost contractionCost(Index m, Index n, Index bm, Index bn, Index bk, bool shard_by_col,
bool prepacked) const {
const int packed_size = std::min<int>(PacketType<LhsScalar, Device>::size, PacketType<RhsScalar, Device>::size);
const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
const double kd = static_cast<double>(bk);
double compute_bandwidth = computeBandwidth(false, bm, bn, bk);
TensorOpCost cost = TensorOpCost(0, 0, kd * compute_bandwidth, true, packed_size);
cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
if (prepacked) {
return cost;
}
TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * (kd / n);
TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * (kd / m);
if (shard_by_col)
lhsCost.dropMemoryCost();
else
rhsCost.dropMemoryCost();
return cost + lhsCost + rhsCost;
}
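// Decides whether sharding by the inner dimension is worthwhile: it requires
// enough k per thread, partial buffers that fit in the L3 cache, and either a
// very small output or substantially more usable threads than output sharding
// would get.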
static bool shardByInnerDim(Index m, Index n, Index k, int num_threads, int num_threads_by_k) {
std::ptrdiff_t bufsize = m * n * sizeof(Scalar);
bool shard_by_k = false;
if (n == 1 ||
num_threads_by_k < 2 ||
num_threads_by_k < num_threads ||
bufsize > l3CacheSize() / num_threads_by_k ||
k / num_threads_by_k < 2 * Traits::nr) {
shard_by_k = false;
} else if (numext::maxi(m, n) / num_threads < Traits::nr ||
(k / num_threads_by_k > 8 * Traits::nr &&
(numext::mini(m, n) < 2 * Traits::nr || num_threads_by_k > num_threads))) {
shard_by_k = true;
}
return shard_by_k;
}
TensorOpCost contractionCostPerInnerDim(Index m, Index n, Index k) const {
const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
TensorOpCost cost(0, 0, (computeBandwidth(true, m, n, k) * m) * n, true, output_packet_size);
cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * m;
TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * n;
lhsCost.dropMemoryCost();
return cost + lhsCost + rhsCost;
}
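// Chooses the number of threads for inner-dimension sharding by comparing the
// parallel compute cost against the fixed and per-thread overheads of the
// final reduction over partial buffers.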
int numThreadsInnerDim(Index m, Index n, Index k) const {
const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
TensorOpCost cost = contractionCostPerInnerDim(m, n, k);
double total_parallel_cost = TensorCostModel<ThreadPoolDevice>::totalCost(k, cost);
double reduction_cost =
TensorCostModel<ThreadPoolDevice>::totalCost(m * n, TensorOpCost(2, 1, 1, true, output_packet_size));
int num_threads = 1;
double min_cost = total_parallel_cost;
double kPerThreadOverHead = 3000;
double kFixedOverHead = 100000;
for (int nt = 2; nt <= this->m_device.numThreads(); nt += 2) {
double sequential_cost = kFixedOverHead + nt * (reduction_cost + kPerThreadOverHead);
double parallel_cost = total_parallel_cost / nt + sequential_cost;
if (parallel_cost < min_cost) {
num_threads = nt;
min_cost = parallel_cost;
}
}
return num_threads;
}
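// Rough cost, per inner-dimension step, of the GEBP kernel for the given
// block shape; small blocks and missing FMA support make each step more
// expensive.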
double computeBandwidth(bool shard_by_col, Index bm, Index bn, Index bk) const {
double computeBandwidth = bk == 1 ? 4.0
: (shard_by_col ? bn : bm) < Traits::nr || (shard_by_col ? bm : bn) < Traits::mr ? 2.0
: 0.5;
#ifndef EIGEN_VECTORIZE_FMA
if (computeBandwidth == 0.5) computeBandwidth = 1.0;
#endif
return computeBandwidth;
}
};
}  // namespace Eigen
#endif  // EIGEN_USE_THREADS
#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H