llvm/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h

//===- Merger.h - Utilities for defining lattices ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities for dealing with iteration lattices.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_
#define MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/BitVector.h"

#include <optional>

namespace mlir {
namespace sparse_tensor {

namespace detail {
/// A constant serving as the canonically invalid identifier,
/// regardless of the identifier type.
///
/// The value has all bits set, i.e. it is the largest `unsigned` value
/// (the same value as `-1u`), which is the conventional sentinel for
/// unsigned identifiers. It is `inline constexpr` rather than
/// `static constexpr` so that this header-scope constant has a single
/// definition program-wide instead of one internal-linkage copy per
/// translation unit.
inline constexpr unsigned kInvalidId = ~0u;
} // namespace detail

// NOTE(review): each identifier/pair name below is followed directly by `;`
// with no `using Name = ...;` right-hand side, so as written none of these
// lines is a valid C++ declaration. The aliased types cannot be recovered
// from this header alone. Presumably the `*Id` types are `unsigned`, matching
// the type of `detail::kInvalidId` and the `BitVector` indexing described for
// `TensorLoopId` — confirm against the original definitions.

/// Tensor identifiers, chosen to be the `BlockArgument::getArgNumber`
/// of the value passed to `Merger::buildTensorExp`.
TensorId;

/// Loop identifiers.
LoopId;

/// A compressed representation of `std::pair<TensorId, LoopId>`.
/// The compression scheme is such that this also serves as an index
/// into the bitvector stored in `LatPoint` (since that bitvector is
/// just the implementation for a set of `TensorLoopId` values).
TensorLoopId;

/// `TensorExp` identifiers. These are allocated by `Merger::addExp`,
/// and serve as unique identifiers for the corresponding `TensorExp` object.
ExprId;

/// `LatPoint` identifiers. These are allocated by `Merger::addLat`,
/// and serve as unique identifiers for the corresponding `LatPoint` object.
LatPointId;

/// `LatSet` identifiers.  These are allocated by `Merger::addSet` (and
/// by other methods calling that one), and serve as unique identifiers
/// for the corresponding `SmallVector<LatPointId>` object.
LatSetId;

/// A pair of level and its corresponding LevelType of a tensor.
LvlLTPair;

/// A pair of a loop id and its coefficient. E.g., for the affine expression
/// `2 * d0` in an affine map, the loop id is 0 and the coefficient is 2.
LoopCoeffPair;

// NOTE(review): the struct body is empty as shown, yet `TensorExp::Kind` is
// defined out of line in this header. That definition is only valid if this
// struct declares `enum class Kind;` in its body — confirm the full body.
/// Tensor expression. Represents an MLIR expression in tensor index notation.
struct TensorExp final {};

// NOTE(review): the enumerator list is empty as shown, although the
// documentation below refers to the `kLoopVar` and `kSynZero` kinds. An
// out-of-line `enum class TensorExp::Kind` definition also requires a matching
// `enum class Kind;` declaration inside `TensorExp` — confirm both against the
// full definition.
/// Tensor expression kind.
///
/// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`.
/// That is, its argument is a `LoopId` identifying the loop-variable
/// in question, and its value will be the current iteration's value.
/// The `kSynZero` leaf kind is for representing a synthetic zero value,
/// which can be introduced when sparsifying operations like `arith::cmp`
/// to generate `arith::cmp %lhs, %syn_zero` when the rhs operand is absent.
enum class TensorExp::Kind {};

/// Lattice point.  Each lattice point consists of a formal conjunction
/// of `TensorLoopId`s, together with the identifier of the corresponding
/// tensor expression.  The formal conjunction is represented as a set of
/// `TensorLoopId`, where that set is implemented as a `BitVector`.
///
/// Lattice points are allocated by `Merger::addLat`, which hands out the
/// `LatPointId` that identifies each one.
struct LatPoint final {};

// NOTE(review): the class body is empty as shown, although the identifier
// documentation in this header refers to `Merger::addExp`, `Merger::addLat`,
// `Merger::addSet` and `Merger::buildTensorExp` — confirm the full body.
/// A class to handle all iteration lattice operations. This class abstracts
/// away from some implementation details of storing iteration lattices and
/// tensor expressions. This allows for fine-tuning performance characteristics
/// independently from the basic algorithm if bottlenecks are identified.
///
/// Per the identifier documentation in this header, the merger allocates
/// `ExprId`s (`addExp`), `LatPointId`s (`addLat`) and `LatSetId`s (`addSet`),
/// and builds tensor expressions from block arguments (`buildTensorExp`).
class Merger {};

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_UTILS_MERGER_H_