llvm/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h

//===- SparseTensorIterator.h ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

namespace mlir {
namespace sparse_tensor {

// Forward declaration of `SparseIterator`; the class itself is defined later in
// this header.
class SparseIterator;

/// The base class for all kinds of sparse tensor levels. It provides interfaces
/// to query the loop range of a level (see `peekRangeAt`) and to look up the
/// coordinates stored in it (see `peekCrdAt`).
class SparseTensorLevel {};

/// Tag identifying the kind of a sparse iterator. The underlying type is
/// `uint8_t` to keep the tag compact.
/// NOTE(review): no enumerators are visible here — confirm the set of kinds
/// against the `SparseIterator` implementations.
enum class IterKind : uint8_t {};

/// A `SparseIterationSpace` represents a sparse set of coordinates defined by
/// (possibly multiple) levels of a specific sparse tensor.
/// TODO: remove `SparseTensorLevel` and switch to `SparseIterationSpace` once
/// it is feature complete.
class SparseIterationSpace {};

/// Helper class that generates loop conditions, etc., to traverse a sparse
/// tensor level.
class SparseIterator {};

/// Helper function to create a `SparseTensorLevel` object for level `lvl` of
/// the given tensor `t`.
///
/// \param b    Builder, presumably used to emit any IR needed to access the
///             level's storage.
/// \param l    Location attached to any created IR.
/// \param t    The sparse tensor value.
/// \param tid  Identifier of the tensor (presumably the loop emitter's tensor
///             id — confirm with callers).
/// \param lvl  The level of `t` to wrap.
/// \returns An owning pointer to the newly created level.
std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
                                                         Location l, Value t,
                                                         unsigned tid,
                                                         Level lvl);

/// Helper function to create a `SparseTensorLevel` object from already
/// available values rather than from a tensor.
///
/// \param lt       The level type of the level being created.
/// \param sz       The size of the level.
/// \param buffers  The values backing the level (presumably its storage
///                 buffers, e.g. positions/coordinates — confirm with callers).
/// \param tid      Identifier of the tensor the level belongs to.
/// \param l        The level index. Note that `l` is a `Level` here, not a
///                 `Location` as in the other overload.
/// \returns An owning pointer to the newly created level.
std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
                                                         ValueRange buffers,
                                                         unsigned tid, Level l);

/// Helper function to create a simple `SparseIterator` object that iterates
/// over the entire iteration space `iterSpace`.
///
/// \param b          Builder, presumably used to emit any IR the iterator
///                   needs at creation time.
/// \param l          Location attached to any created IR.
/// \param iterSpace  The iteration space to traverse.
/// \returns An owning pointer to the new iterator.
std::unique_ptr<SparseIterator>
makeSimpleIterator(OpBuilder &b, Location l,
                   const SparseIterationSpace &iterSpace);

/// Helper function to create a simple `SparseIterator` object that iterates
/// over the sparse tensor level `stl`.
/// TODO: switch to `SparseIterationSpace` (which supports N-D iterators) once
/// it is feature complete.
///
/// \param stl       The level to traverse.
/// \param strategy  Code emission strategy; defaults to
///                  `SparseEmitStrategy::kFunctional`.
/// \returns An owning pointer to the new iterator.
std::unique_ptr<SparseIterator> makeSimpleIterator(
    const SparseTensorLevel &stl,
    SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);

/// Helper function to create a synthetic `SparseIterator` object that iterates
/// over a dense space specified by [0, `sz`).
///
/// \param sz        Upper bound (exclusive) of the dense space.
/// \param tid       Identifier used for the synthetic tensor.
/// \param lvl       Level index used for the synthetic level.
/// \param strategy  Code emission strategy for the iterator.
/// \returns A pair of the synthetic level and an iterator over it. Both are
///          returned as owning pointers.
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
                        SparseEmitStrategy strategy);

/// Helper function to create a `SparseIterator` object that iterates over a
/// sliced space. The original space (before slicing) is traversed by `sit`.
///
/// \param sit       Iterator over the unsliced space. It is passed as an
///                  rvalue reference, presumably so it can be moved into the
///                  returned iterator.
/// \param offset    Offset of the slice.
/// \param stride    Stride of the slice.
/// \param size      Size of the slice.
/// \param strategy  Code emission strategy for the iterator.
/// \returns An owning pointer to the new sliced iterator.
std::unique_ptr<SparseIterator>
makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
                        Value stride, Value size, SparseEmitStrategy strategy);

/// Helper function to create a `SparseIterator` object that iterates over a
/// padded sparse level (the padded value must be zero).
///
/// \param sit       Iterator over the unpadded level. It is passed as an
///                  rvalue reference, presumably so it can be moved into the
///                  returned iterator.
/// \param padLow    Padding amount before the level (presumably the low-side
///                  pad size — confirm with callers).
/// \param padHigh   Padding amount after the level (presumably the high-side
///                  pad size — confirm with callers).
/// \param strategy  Code emission strategy for the iterator.
/// \returns An owning pointer to the new padded iterator.
std::unique_ptr<SparseIterator>
makePaddedIterator(std::unique_ptr<SparseIterator> &&sit, Value padLow,
                   Value padHigh, SparseEmitStrategy strategy);

/// Helper function to create a `SparseIterator` object that iterates over the
/// set of non-empty subsections.
///
/// \param b          Builder, presumably used to emit any IR needed at
///                   creation time.
/// \param l          Location attached to any created IR.
/// \param parent     Parent iterator. It is passed as a pointer, so it may
///                   presumably be null when there is no parent — confirm
///                   with the implementation.
/// \param loopBound  Bound of the enclosing loop.
/// \param delegate   Iterator that the new iterator delegates to. It is passed
///                   as an rvalue reference, presumably to be moved in.
/// \param size       Size of each subsection.
/// \param stride     Stride between subsections.
/// \param strategy   Code emission strategy for the iterator.
/// \returns An owning pointer to the new iterator.
std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
    OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
    std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
    SparseEmitStrategy strategy);

/// Helper function to create a `SparseIterator` object that iterates over a
/// non-empty subsection created by a `NonEmptySubSectIterator` (see
/// `makeNonEmptySubSectIterator`).
///
/// \param b            Builder, presumably used to emit any IR needed at
///                     creation time.
/// \param l            Location attached to any created IR.
/// \param subsectIter  The iterator over non-empty subsections that produced
///                     the subsection to traverse.
/// \param parent       Parent iterator.
/// \param wrap         Iterator being wrapped. It is passed as an rvalue
///                     reference, presumably to be moved in.
/// \param loopBound    Bound of the enclosing loop.
/// \param stride       Stride used when traversing the subsection.
/// \param strategy     Code emission strategy for the iterator.
/// \returns An owning pointer to the new iterator.
std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
    OpBuilder &b, Location l, const SparseIterator &subsectIter,
    const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
    Value loopBound, unsigned stride, SparseEmitStrategy strategy);

} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_