#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"
#include <optional>
namespace mlir {
namespace sparse_tensor {
enum class ExpArity { … };
static ExpArity getExpArity(TensorExp::Kind k) { … }
TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
Operation *o, Attribute a)
: … { … }
Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
unsigned maxLvlRank)
: … { … }
ExprId Merger::addTensorExp(TensorId t) { … }
ExprId Merger::addLoopVarExp(LoopId i) { … }
ExprId Merger::addInvariantExp(Value v) { … }
ExprId Merger::addSynZeroExp() { … }
ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op,
Attribute attr) { … }
ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op,
Attribute attr) { … }
LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) { … }
LatPointId Merger::addLat(const BitVector &bits, ExprId e) { … }
LatSetId Merger::addSet() { … }
LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1,
Operation *op) { … }
LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) { … }
LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) { … }
LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) { … }
LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
bool includeLeft, TensorExp::Kind ltrans,
Operation *opleft, bool includeRight,
TensorExp::Kind rtrans, Operation *opright) { … }
LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
Operation *op, Attribute a) { … }
LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) { … }
LatSetId Merger::optimizeSet(LatSetId s0) { … }
BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) { … }
bool Merger::latGT(LatPointId i, LatPointId j) const { … }
bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const { … }
bool Merger::expContainsTensor(ExprId e, TensorId t) const { … }
bool Merger::hasNegateOnOut(ExprId e) const { … }
bool Merger::isSingleCondition(TensorId t, ExprId e) const { … }
bool Merger::hasAnySparse(const BitVector &bits) const { … }
bool Merger::hasSparseIdxReduction(const BitVector &bits) const { … }
#ifndef NDEBUG
/// Returns a short, human-readable symbol for the given expression kind.
/// `Merger::dumpExp` uses it to render operators when printing tensor
/// expressions to the debug stream.
static const char *kindToOpSymbol(TensorExp::Kind kind) {
  switch (kind) {
  // Leaf expressions.
  case TensorExp::Kind::kTensor:
    return "tensor";
  case TensorExp::Kind::kInvariant:
    return "invariant";
  case TensorExp::Kind::kLoopVar:
    return "index";
  case TensorExp::Kind::kSynZero:
    return "0";
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
    return "abs";
  case TensorExp::Kind::kCeilF:
    return "ceil";
  case TensorExp::Kind::kFloorF:
    return "floor";
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
    return "sqrt";
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
    return "expm1";
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
    return "log1p";
  case TensorExp::Kind::kRelu:
    return "relu";
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
    return "sin";
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
    return "tanh";
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return "-";
  // All conversion-like kinds share one symbol. These cases previously fell
  // through into kCIm and were mislabeled as "complex.im".
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kBitCast:
    return "cast";
  case TensorExp::Kind::kCIm:
    return "complex.im";
  case TensorExp::Kind::kCRe:
    return "complex.re";
  case TensorExp::Kind::kBinaryBranch:
    return "binary_branch";
  case TensorExp::Kind::kUnary:
    return "unary";
  case TensorExp::Kind::kSelect:
    return "select";
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
    return "*";
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    return "/";
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return "+";
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return "-";
  case TensorExp::Kind::kAndI:
    return "&";
  case TensorExp::Kind::kOrI:
    return "|";
  case TensorExp::Kind::kXorI:
    return "^";
  case TensorExp::Kind::kShrS:
    return "a>>";
  case TensorExp::Kind::kShrU:
    return ">>";
  case TensorExp::Kind::kShlI:
    return "<<";
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    return "cmp";
  case TensorExp::Kind::kBinary:
    return "binary";
  case TensorExp::Kind::kReduce:
    return "reduce";
  case TensorExp::Kind::kDenseOp:
    return "dense";
  }
  llvm_unreachable("unexpected kind for symbol");
}
/// Recursively prints expression `e` to the debug stream.
///  - Leaves print as a name, e.g. "tensor_<id>" (with a "synthetic_" or
///    "output_" prefix for those special tensors), "invariant", "0", or
///    "loopvar_<loop>".
///  - Unary kinds print as "<symbol> <operand>".
///  - Binary kinds print as "(<lhs> <symbol>[{attr}] <rhs>)".
void Merger::dumpExp(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf expressions.
  case TensorExp::Kind::kTensor:
    if (expr.tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (expr.tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << expr.tensor;
    break;
  case TensorExp::Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case TensorExp::Kind::kSynZero:
    llvm::dbgs() << "0";
    break;
  case TensorExp::Kind::kLoopVar:
    llvm::dbgs() << "loopvar_" << expr.loop;
    break;
  // Unary operations: "<symbol> <operand>".
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kRelu:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
    llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
    dumpExp(expr.children.e0);
    break;
  // Binary operations: "(<lhs> <symbol>[{attr}] <rhs>)". A kDenseOp may have
  // no second operand, in which case the right-hand side is omitted.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
  case TensorExp::Kind::kDenseOp:
    llvm::dbgs() << "(";
    dumpExp(expr.children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
    if (expr.attr)
      llvm::dbgs() << "{" << expr.attr << "}";
    if (expr.children.e1 != detail::kInvalidId) {
      llvm::dbgs() << " ";
      dumpExp(expr.children.e1);
    } else {
      // Only kDenseOp is expected to lack a second operand.
      assert(expr.kind == TensorExp::Kind::kDenseOp);
    }
    // Close the parenthesis on both paths so the printed form stays
    // balanced, including the single-operand kDenseOp case.
    llvm::dbgs() << ")";
    break;
  }
}
/// Prints lattice point `p` to the debug stream in the form
///   "lat(<bits> :<simple bits> : <expression> )"
/// followed by a newline. The `bits` and `simple` bit vectors of the point are
/// rendered with dumpBits, and its expression with dumpExp.
void Merger::dumpLat(LatPointId p) const {
  const auto &latPoint = lat(p);
  auto &os = llvm::dbgs();
  os << "lat(";
  dumpBits(latPoint.bits);
  os << " :";
  dumpBits(latPoint.simple);
  os << " : ";
  dumpExp(latPoint.exp);
  os << " )\n";
}
/// Prints lattice set `s` to the debug stream. The output is a header
/// "{ #<count>", then one indented line per lattice point (via dumpLat),
/// then a closing "}".
void Merger::dumpSet(LatSetId s) const {
  const auto &latPoints = set(s);
  auto &os = llvm::dbgs();
  os << "{ #" << latPoints.size() << "\n";
  for (const LatPointId point : latPoints) {
    os << " ";
    dumpLat(point);
  }
  os << "}\n";
}
void Merger::dumpBits(const BitVector &bits) const {
for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
if (bits[b]) {
const TensorId t = tensor(b);
const LoopId i = loop(b);
const auto lt = lvlTypes[t][i];
if (isLvlWithNonTrivialIdxExp(b))
llvm::dbgs() << " DEP_" << t << "_" << i;
else
llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);
}
}
}
#endif
LatSetId Merger::buildLattices(ExprId e, LoopId i) { … }
std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) { … }
static bool isCertainZero(Value val) { … }
bool Merger::maybeZero(ExprId e) const { … }
Type Merger::inferType(ExprId e, Value src) const { … }
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) { … }
static bool isAdmissibleBranch(Operation *op, Region ®ion) { … }
static bool isGreater(TensorExp::Kind kind, Attribute attr) { … }
std::pair<std::optional<ExprId>, bool>
Merger::buildTensorExp(linalg::GenericOp op, Value v) { … }
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region ®ion,
ValueRange vals) { … }
static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
Operation *op, Value v0) { … }
static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
Operation *op, Value v0, Value v1) { … }
static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0,
Attribute attr) { … }
Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
Value v1) const { … }
}
}