// Source path: llvm/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
namespace sparse_tensor {

/// Arity classification for tensor-expression kinds.
/// NOTE(review): the enumerator list appears stripped in this excerpt.
enum class ExpArity {};

/// Returns the arity class of expression kind `k`. (Body elided in this excerpt.)
static ExpArity getExpArity(TensorExp::Kind k) {}

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

/// Constructs a tensor expression of kind `k`; `x`/`y` select the leaf payload
/// or child ids depending on the kind, `v` the invariant value, `o` the
/// attached operation, and `a` an optional attribute.
/// NOTE(review): the member-initializer list appears stripped in this excerpt;
/// the bare `:{}` below will not compile as-is.
TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
                     Operation *o, Attribute a)
    :{}

/// Constructs a merger for the given number of input/output tensors, loops,
/// and the maximum level rank among all tensors.
/// NOTE(review): the member-initializer list appears stripped in this excerpt;
/// the bare `:{}` below will not compile as-is.
Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
               unsigned maxLvlRank)
    :{}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

/// Adds a tensor-leaf expression for tensor `t`; returns its id. (Body elided in this excerpt.)
ExprId Merger::addTensorExp(TensorId t) {}

/// Adds a loop-variable leaf expression for loop `i`; returns its id. (Body elided in this excerpt.)
ExprId Merger::addLoopVarExp(LoopId i) {}

/// Adds an invariant leaf expression wrapping value `v`; returns its id. (Body elided in this excerpt.)
ExprId Merger::addInvariantExp(Value v) {}

/// Adds a synthetic-zero leaf expression; returns its id. (Body elided in this excerpt.)
ExprId Merger::addSynZeroExp() {}

/// Adds a binary expression of kind `k` with children `e0`/`e1`, optional
/// attached op and attribute; returns its id. (Body elided in this excerpt.)
ExprId Merger::addExp(TensorExp::Kind k, ExprId e0, ExprId e1, Operation *op,
                      Attribute attr) {}

/// Adds a unary expression of kind `k` with child `e` and value `v`, optional
/// attached op and attribute; returns its id. (Body elided in this excerpt.)
ExprId Merger::addExp(TensorExp::Kind k, ExprId e, Value v, Operation *op,
                      Attribute attr) {}

/// Adds a lattice point for tensor `t` at loop `i` with expression `e`;
/// returns its id. (Body elided in this excerpt.)
LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {}

/// Adds a lattice point with the given condition bits and expression `e`;
/// returns its id. (Body elided in this excerpt.)
LatPointId Merger::addLat(const BitVector &bits, ExprId e) {}

/// Adds a new (empty) lattice set; returns its id. (Body elided in this excerpt.)
LatSetId Merger::addSet() {}

/// Computes the conjunction of lattice points `p0` and `p1` for expression
/// `e` (presumably unioning their bits and combining their expressions —
/// body elided in this excerpt); returns the new point's id.
LatPointId Merger::conjLat(ExprId e, LatPointId p0, LatPointId p1,
                           Operation *op) {}

/// Conjunction of two lattice sets for expression `e`. (Body elided in this excerpt.)
LatSetId Merger::conjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {}

/// Disjunction of two lattice sets for expression `e`. (Body elided in this excerpt.)
LatSetId Merger::disjSet(ExprId e, LatSetId s0, LatSetId s1, Operation *op) {}

/// Disjunction of two lattice sets where the absent operand is treated as a
/// synthetic zero. (Body elided in this excerpt.)
LatSetId Merger::disjSetWithZero(ExprId e, LatSetId s0, LatSetId s1) {}

/// Combines two lattice sets for a custom binary operation `orig`, with
/// per-side inclusion flags and transformation kinds/ops for the
/// left-only and right-only regions. (Body elided in this excerpt.)
LatSetId Merger::combiSet(ExprId e, LatSetId s0, LatSetId s1, Operation *orig,
                          bool includeLeft, TensorExp::Kind ltrans,
                          Operation *opleft, bool includeRight,
                          TensorExp::Kind rtrans, Operation *opright) {}

/// Maps the unary operation of the given kind over every lattice point in
/// `s0`; returns the resulting set. (Body elided in this excerpt.)
LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
                        Operation *op, Attribute a) {}

/// Maps binary expression `e` over `s0` with the lhs (or rhs, per `lhsZero`)
/// replaced by a synthetic zero. (Body elided in this excerpt.)
LatSetId Merger::mapBinWithSynZeroSet(ExprId e, LatSetId s0, bool lhsZero) {}

/// Optimizes the iteration lattice set `s0` (presumably dropping redundant
/// points — body elided in this excerpt); returns the optimized set.
LatSetId Merger::optimizeSet(LatSetId s0) {}

/// Simplifies the condition bits of point `p0` within set `s0`; returns the
/// simplified bit vector. (Body elided in this excerpt.)
BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {}

/// Returns true if lattice point `i` strictly dominates point `j`. (Body elided in this excerpt.)
bool Merger::latGT(LatPointId i, LatPointId j) const {}

/// Returns true if points `i` and `j` differ only in dense conditions. (Body elided in this excerpt.)
bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {}

/// Returns true if expression `e` references tensor `t`. (Body elided in this excerpt.)
bool Merger::expContainsTensor(ExprId e, TensorId t) const {}

/// Returns true if expression `e` contains a negation of the output tensor.
/// (Body elided in this excerpt.)
bool Merger::hasNegateOnOut(ExprId e) const {}

/// Returns true if tensor `t` appears as a single condition within
/// expression `e`. (Body elided in this excerpt.)
bool Merger::isSingleCondition(TensorId t, ExprId e) const {}

/// Returns true if any set bit corresponds to a sparse level. (Body elided in this excerpt.)
bool Merger::hasAnySparse(const BitVector &bits) const {}

/// Returns true if any set bit corresponds to a sparse level with a
/// non-trivial index expression. (Body elided in this excerpt.)
bool Merger::hasSparseIdxReduction(const BitVector &bits) const {}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

/// Maps each tensor-expression kind to the short symbol used by the debug
/// printers below. Exhaustive over TensorExp::Kind; falling off the switch
/// is unreachable.
static const char *kindToOpSymbol(TensorExp::Kind kind) {
  switch (kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    return "tensor";
  case TensorExp::Kind::kInvariant:
    return "invariant";
  case TensorExp::Kind::kLoopVar:
    return "index";
  case TensorExp::Kind::kSynZero:
    return "0";
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
    return "abs";
  case TensorExp::Kind::kCeilF:
    return "ceil";
  case TensorExp::Kind::kFloorF:
    return "floor";
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
    return "sqrt";
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
    return "expm1";
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
    return "log1p";
  case TensorExp::Kind::kRelu:
    return "relu";
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
    return "sin";
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
    return "tanh";
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
    return "-";
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
    // All cast-like kinds print uniformly as "cast". Previously these cases
    // fell through to kCIm and were mislabeled "complex.im".
    return "cast";
  case TensorExp::Kind::kCIm:
    return "complex.im";
  case TensorExp::Kind::kCRe:
    return "complex.re";
  case TensorExp::Kind::kBitCast:
    return "cast";
  case TensorExp::Kind::kBinaryBranch:
    return "binary_branch";
  case TensorExp::Kind::kUnary:
    return "unary";
  case TensorExp::Kind::kSelect:
    return "select";
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
    return "*";
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
    return "/";
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
    return "+";
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
    return "-";
  case TensorExp::Kind::kAndI:
    return "&";
  case TensorExp::Kind::kOrI:
    return "|";
  case TensorExp::Kind::kXorI:
    return "^";
  case TensorExp::Kind::kShrS:
    // Arithmetic (sign-preserving) right shift.
    return "a>>";
  case TensorExp::Kind::kShrU:
    return ">>";
  case TensorExp::Kind::kShlI:
    return "<<";
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
    return "cmp";
  case TensorExp::Kind::kBinary:
    return "binary";
  case TensorExp::Kind::kReduce:
    return "reduce";
  case TensorExp::Kind::kDenseOp:
    return "dense";
  }
  llvm_unreachable("unexpected kind for symbol");
}

/// Recursively prints expression `e` to llvm::dbgs() in a parenthesized
/// textual form: leaves print as their symbol, unary operations print in
/// prefix form, and binary operations print infix inside parentheses.
/// Debug builds only.
void Merger::dumpExp(ExprId e) const {
  const auto &expr = exp(e);
  switch (expr.kind) {
  // Leaf.
  case TensorExp::Kind::kTensor:
    if (expr.tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (expr.tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << expr.tensor;
    break;
  case TensorExp::Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case TensorExp::Kind::kSynZero:
    llvm::dbgs() << "0";
    break;
  case TensorExp::Kind::kLoopVar:
    llvm::dbgs() << "loopvar_" << expr.loop;
    break;
  // Unary operations.
  case TensorExp::Kind::kAbsF:
  case TensorExp::Kind::kAbsC:
  case TensorExp::Kind::kAbsI:
  case TensorExp::Kind::kCeilF:
  case TensorExp::Kind::kFloorF:
  case TensorExp::Kind::kSqrtF:
  case TensorExp::Kind::kSqrtC:
  case TensorExp::Kind::kExpm1F:
  case TensorExp::Kind::kExpm1C:
  case TensorExp::Kind::kLog1pF:
  case TensorExp::Kind::kLog1pC:
  case TensorExp::Kind::kRelu:
  case TensorExp::Kind::kSinF:
  case TensorExp::Kind::kSinC:
  case TensorExp::Kind::kTanhF:
  case TensorExp::Kind::kTanhC:
  case TensorExp::Kind::kNegF:
  case TensorExp::Kind::kNegC:
  case TensorExp::Kind::kNegI:
  case TensorExp::Kind::kTruncF:
  case TensorExp::Kind::kExtF:
  case TensorExp::Kind::kCastFS:
  case TensorExp::Kind::kCastFU:
  case TensorExp::Kind::kCastSF:
  case TensorExp::Kind::kCastUF:
  case TensorExp::Kind::kCastS:
  case TensorExp::Kind::kCastU:
  case TensorExp::Kind::kCastIdx:
  case TensorExp::Kind::kTruncI:
  case TensorExp::Kind::kCIm:
  case TensorExp::Kind::kCRe:
  case TensorExp::Kind::kBitCast:
  case TensorExp::Kind::kBinaryBranch:
  case TensorExp::Kind::kUnary:
  case TensorExp::Kind::kSelect:
    // Prefix form: "<symbol> <child>".
    llvm::dbgs() << kindToOpSymbol(expr.kind) << " ";
    dumpExp(expr.children.e0);
    break;
  // Binary operations.
  case TensorExp::Kind::kMulF:
  case TensorExp::Kind::kMulC:
  case TensorExp::Kind::kMulI:
  case TensorExp::Kind::kDivF:
  case TensorExp::Kind::kDivC:
  case TensorExp::Kind::kDivS:
  case TensorExp::Kind::kDivU:
  case TensorExp::Kind::kAddF:
  case TensorExp::Kind::kAddC:
  case TensorExp::Kind::kAddI:
  case TensorExp::Kind::kSubF:
  case TensorExp::Kind::kSubC:
  case TensorExp::Kind::kSubI:
  case TensorExp::Kind::kAndI:
  case TensorExp::Kind::kOrI:
  case TensorExp::Kind::kXorI:
  case TensorExp::Kind::kShrS:
  case TensorExp::Kind::kShrU:
  case TensorExp::Kind::kShlI:
  case TensorExp::Kind::kCmpF:
  case TensorExp::Kind::kCmpI:
  case TensorExp::Kind::kBinary:
  case TensorExp::Kind::kReduce:
  case TensorExp::Kind::kDenseOp:
    // Infix form: "(<lhs> <symbol>[{attr}] <rhs>)". A kDenseOp may have only
    // one child, in which case the rhs is omitted.
    llvm::dbgs() << "(";
    dumpExp(expr.children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
    if (expr.attr)
      llvm::dbgs() << "{" << expr.attr << "}";
    if (expr.children.e1 != detail::kInvalidId) {
      llvm::dbgs() << " ";
      dumpExp(expr.children.e1);
    } else {
      assert(expr.kind == TensorExp::Kind::kDenseOp);
    }
    // Close the parenthesis on both paths; previously a unary kDenseOp left
    // the opening "(" unbalanced in the debug output.
    llvm::dbgs() << ")";
    break;
  }
}

/// Prints lattice point `p` to llvm::dbgs() as
/// "lat(<bits> :<simple> : <exp> )". Debug builds only.
void Merger::dumpLat(LatPointId p) const {
  const auto &lp = lat(p);
  auto &os = llvm::dbgs();
  os << "lat(";
  dumpBits(lp.bits);
  os << " :";
  dumpBits(lp.simple);
  os << " : ";
  dumpExp(lp.exp);
  os << " )\n";
}

/// Prints lattice set `s` to llvm::dbgs(): a "{ #<size>" header, one indented
/// lattice point per line, then a closing "}". Debug builds only.
void Merger::dumpSet(LatSetId s) const {
  const auto &points = set(s);
  auto &os = llvm::dbgs();
  os << "{ #" << points.size() << "\n";
  for (const LatPointId pid : points) {
    os << "  ";
    dumpLat(pid);
  }
  os << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (TensorLoopId b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      const TensorId t = tensor(b);
      const LoopId i = loop(b);
      const auto lt = lvlTypes[t][i];
      if (isLvlWithNonTrivialIdxExp(b))
        llvm::dbgs() << " DEP_" << t << "_" << i;
      else
        llvm::dbgs() << " i_" << t << "_" << i << "_" << toMLIRString(lt);
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

/// Builds the iteration lattices for expression `e` at loop `i`; returns the
/// resulting lattice set. (Body elided in this excerpt.)
LatSetId Merger::buildLattices(ExprId e, LoopId i) {}

/// Builds a tensor-expression tree from the body of `op`; returns
/// std::nullopt when the op cannot be represented. (Body elided in this excerpt.)
std::optional<ExprId> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {}

/// Only returns true if we are certain this is a zero.
/// (Body elided in this excerpt.)
static bool isCertainZero(Value val) {}

/// Only returns false if we are certain this is a nonzero.
/// (Body elided in this excerpt.)
bool Merger::maybeZero(ExprId e) const {}

/// Infers the destination type for expression `e` from source value `src`.
/// (Body elided in this excerpt.)
Type Merger::inferType(ExprId e, Value src) const {}

/// Ensures that the sparsifier can generate code for expression.
/// (Body elided in this excerpt.)
static bool isAdmissibleBranchExp(Operation *op, Block *block, Value v) {}

/// Ensures that the sparsifier can generate code for branch.
/// (Body elided in this excerpt.)
static bool isAdmissibleBranch(Operation *op, Region &region) {}

// Recognizes a direct GT comparison.
// (Body elided in this excerpt.)
static bool isGreater(TensorExp::Kind kind, Attribute attr) {}

/// Builds a tensor expression for value `v` inside `op`; returns the
/// expression id (or std::nullopt) paired with a boolean flag.
/// (Body elided in this excerpt.)
std::pair<std::optional<ExprId>, bool>
Merger::buildTensorExp(linalg::GenericOp op, Value v) {}

/// Clones/inlines `region` at the insertion point and yields `vals`;
/// returns the produced value. (Body elided in this excerpt — TODO confirm
/// exact semantics against upstream.)
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
                           ValueRange vals) {}

/// Builds the "present" branch of a sparse_tensor.unary op for operand `v0`.
/// (Body elided in this excerpt.)
static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {}

/// Builds the "overlap" branch of a sparse_tensor.binary op for operands
/// `v0`/`v1`. (Body elided in this excerpt.)
static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {}

/// Builds IR computing a ReLU of `v0` (with `attr` presumably controlling
/// the comparison — body elided in this excerpt).
static Value buildRelu(RewriterBase &rewriter, Location loc, Value v0,
                       Attribute attr) {}

/// Rebuilds IR for expression `e` from operand values `v0`/`v1`; returns the
/// resulting value. (Body elided in this excerpt.)
Value Merger::buildExp(RewriterBase &rewriter, Location loc, ExprId e, Value v0,
                       Value v1) const {}

} // namespace sparse_tensor
} // namespace mlir