#include "LoopEmitter.h"
#include "CodegenUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
usingnamespacemlir;
usingnamespacemlir::sparse_tensor;
#define CMPI …
#define C_IDX …
#define YIELD …
#define ADDI …
#define ANDI …
#define SUBI …
#define MULI …
#define REMUI(lhs, rhs) …
#define DIVUI(lhs, rhs) …
#define SELECT …
#ifndef NDEBUG
/// Debug-only helper (compiled out when NDEBUG is defined) that emits IR
/// calling the `printMemrefInd` runtime function on `memref`.
///
/// The call is emitted with no results and through the C interface
/// (`EmitCInterface::On`). Before the call, the argument is converted with a
/// `memref.cast` to an unranked memref whose element type is `index` (memory
/// space argument 0).
/// NOTE(review): the cast presumes `memref` already has an `index` element
/// type; confirm at call sites.
LLVM_ATTRIBUTE_UNUSED static void dumpIndexMemRef(OpBuilder &builder,
                                                  Location loc, Value memref) {
  // Unranked target type handed to the runtime entry point.
  const auto unrankedIndexTp =
      UnrankedMemRefType::get(builder.getIndexType(), 0);
  Value unrankedMemRef =
      builder.create<memref::CastOp>(loc, unrankedIndexTp, memref);
  createFuncCall(builder, loc, "printMemrefInd", TypeRange{},
                 ValueRange{unrankedMemRef}, EmitCInterface::On);
}
#endif
static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
Level lvl) { … }
static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
Level lvl) { … }
static bool isIntOrFPZero(Attribute attr) { … }
static Value unFoldOpIntResult(OpBuilder &builder, Location loc,
OpFoldResult ofr) { … }
static Value tryFoldTensors(Value t) { … }
LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
bool isSparseOut, unsigned numLoops,
DependentLvlGetter dimGetter,
SparseEmitStrategy emitStrategy) { … }
void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
bool isSparseOut, unsigned numLoops,
DependentLvlGetter dimGetter,
SparseEmitStrategy emitStrategy) { … }
std::unique_ptr<SparseIterator>
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
Level l) { … }
void LoopEmitter::initializeLoopEmit(
OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
LoopEmitter::SynTensorBoundSetter synSetter) { … }
void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) { … }
void LoopEmitter::categorizeIterators(
ArrayRef<TensorLevel> tidLvls, SmallVectorImpl<SparseIterator *> &raIters,
SmallVectorImpl<SparseIterator *> &spIters) { … }
void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
ArrayRef<TensorLevel> tidLvls) { … }
void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) { … }
Value LoopEmitter::genAffine(OpBuilder &builder, Location loc, AffineExpr a) { … }
std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
OpBuilder &builder, Location loc, SparseIterator &iter,
MutableArrayRef<Value> reduc, bool isParallel) { … }
std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
MutableArrayRef<Value> reduc, bool needsUniv) { … }
bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) { … }
Region *LoopEmitter::enterCurrentCoIterationCase(OpBuilder &builder,
Location loc,
I64BitSet caseBit,
unsigned caseIdx,
MutableArrayRef<Value> reduc) { … }
Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
unsigned numCases, MutableArrayRef<Value> reduc, bool tryParallel,
bool needsUniv) { … }
void LoopEmitter::locateLvlAtAffineAddress(OpBuilder &builder, Location loc,
TensorLevel tidLvl,
AffineExpr lvlExpr) { … }
void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
TensorId tid, Level lvl) { … }
void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
MutableArrayRef<Value> reduc) { … }
void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
MutableArrayRef<Value> reduc) { … }
void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
MutableArrayRef<Value> reduc) { … }
std::pair<Operation *, Value> sparse_tensor::genCoIteration(
OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
MutableArrayRef<Value> reduc, Value uniIdx, bool userReducFirst) { … }
// Undefine every file-local helper macro defined near the top of this file so
// none of them leak past the end of this translation unit (e.g. in unity /
// jumbo builds that concatenate sources). Keep this list in sync with the
// corresponding #define lines; REMUI and DIVUI must be included.
#undef CMPI
#undef C_IDX
#undef YIELD
#undef ADDI
#undef ANDI
#undef SUBI
#undef MULI
#undef REMUI
#undef DIVUI
#undef SELECT