llvm/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp

//===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensor
// primitives with memref operands.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Support/LLVM.h"

usingnamespacemlir;
usingnamespacemlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

// Operand positions used by the sort helper functions, whose operands are
// documented below as (lo, hi, xs, ys). Presumably `loIdx` and `hiIdx` locate
// the range bounds and `xStartIdx` locates the first buffer operand.
// NOTE(review): the initializers are missing in this view of the file; the
// values must be confirmed against the (lo, hi, xs, ys) operand layout.
static constexpr uint64_t loIdx =;
static constexpr uint64_t hiIdx =;
static constexpr uint64_t xStartIdx =;

// Name prefixes for the generated helper functions; they fill the
// `<namePrefix>` slot of the mangled names described on
// getMangledSortHelperFuncName.
// NOTE(review): the string initializers are missing in this view of the file.
static constexpr const char kPartitionFuncNamePrefix[] =;
static constexpr const char kBinarySearchFuncNamePrefix[] =;
static constexpr const char kHybridQuickSortFuncNamePrefix[] =;
static constexpr const char kSortStableFuncNamePrefix[] =;
static constexpr const char kShiftDownFuncNamePrefix[] =;
static constexpr const char kHeapSortFuncNamePrefix[] =;
static constexpr const char kQuickSortFuncNamePrefix[] =;

// NOTE(review): this declaration appears truncated in this view of the file.
// Presumably `FuncGeneratorType` names the generator callback type that
// getMangledSortHelperFunc accepts as `createFunc` -- confirm before relying
// on it.
FuncGeneratorType;

/// Constructs a function name with this format to facilitate quick sort:
///   <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
///   <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
/// The function returns nothing and `nameOstream` is its only stream
/// parameter, so the name is presumably written into `nameOstream`
/// (TODO confirm against the body, which is not shown here).
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
                                         StringRef namePrefix, AffineMap xPerm,
                                         uint64_t ny, ValueRange operands) {}

/// Looks up a function that is appropriate for the given operands being
/// sorted, and creates such a function if it doesn't exist yet. The
/// parameters `xPerm` and `ny` tell the number of x and y values provided
/// by the buffer in xStartIdx.
/// `nTrailingP` (default 0) is presumably the number of "trailing parameters"
/// described below -- TODO confirm.
//
// All sorting function generators take (lo, hi, xs, ys) in `operands` as
// parameters for the sorting functions. Other parameters, such as the recursive
// call depth, are appended to the end of the parameter list as
// "trailing parameters".
static FlatSymbolRefAttr getMangledSortHelperFunc(
    OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
    StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
    FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {}

/// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
/// The code to process the value pairs is generated by `bodyBuilder`, which
/// is invoked with one integer and three values.
/// NOTE(review): the meaning of each `bodyBuilder` argument is not visible
/// here; confirm against the callers.
static void forEachIJPairInXs(
    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
    uint64_t ny,
    function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {}

/// Creates a code block to process each pair of (xys[i], xys[j]) for sorting,
/// i.e. the pairs drawn from both the x and the y values.
/// The code to process the value pairs is generated by `bodyBuilder`, which
/// has the same signature as the callback taken by forEachIJPairInXs.
static void forEachIJPairInAllBuffers(
    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
    uint64_t ny,
    function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {}

/// Creates a code block for swapping the values at index i and index j in all
/// the buffers.
//
// The generated IR corresponds to this C like algorithm:
//     swap(x0[i], x0[j]);
//     swap(x1[i], x1[j]);
//     ...
//     swap(xn[i], xn[j]);
//     swap(y0[i], y0[j]);
//     ...
//     swap(yn[i], yn[j]);
static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
                       AffineMap xPerm, uint64_t ny) {}

/// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
/// each pair is created via `compareBuilder`, whose signature matches that of
/// createEqCompare and createLessThanCompare.
static Value createInlinedCompareImplementation(
    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
    uint64_t ny,
    function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
        compareBuilder) {}

/// Generates code to compare whether x[i] is equal to x[j] and returns the
/// result of the comparison.
/// NOTE(review): how `isFirstDim` and `isLastDim` affect the generated code is
/// not visible here; presumably they mark the position of `x` among the
/// compared dimensions -- confirm.
static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
                             Value x, bool isFirstDim, bool isLastDim) {}

/// Creates code to compare whether xs[i] is equal to xs[j].
//
// The generated IR corresponds to this C like algorithm:
//   if (x0[i] != x0[j])
//     return false;
//   else
//     if (x1[i] != x1[j])
//       return false;
//     else if (x2[i] != x2[j])
//       and so on ...
static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
                                    ValueRange args, AffineMap xPerm,
                                    uint64_t ny, uint32_t nTrailingP = 0) {}

/// Generates code to compare whether x[i] is less than x[j] and returns the
/// result of the comparison.
/// NOTE(review): how `isFirstDim` and `isLastDim` affect the generated code is
/// not visible here; presumably they mark the position of `x` among the
/// compared dimensions -- confirm.
static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
                                   Value j, Value x, bool isFirstDim,
                                   bool isLastDim) {}

/// Creates code to compare whether xs[i] is less than xs[j].
//
// The generated IR corresponds to this C like algorithm:
//   if (x0[i] != x0[j])
//     return x0[i] < x0[j];
//   else if (x1[i] != x1[j])
//     return x1[i] < x1[j];
//   else
//       and so on ...
static Value createInlinedLessThan(OpBuilder &builder, Location loc,
                                   ValueRange args, AffineMap xPerm,
                                   uint64_t ny, uint32_t nTrailingP = 0) {}

/// Creates a function to use a binary search to find the insertion point for
/// inserting xs[hi] to the sorted values xs[lo..hi).
//
// The generated IR corresponds to this C like algorithm:
//   p = hi
//   while (lo < hi)
//      mid = (lo + hi) >> 1
//      if (xs[p] < xs[mid])
//        hi = mid
//      else
//        lo = mid + 1   // must move past mid, or the loop cannot terminate
//   return lo;
//
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
                                   func::FuncOp func, AffineMap xPerm,
                                   uint64_t ny, uint32_t nTrailingP = 0) {}

/// Creates code to advance i in a loop based on xs[p] as follows:
///   while (xs[i] < xs[p]) i += step (step > 0)
/// or
///   while (xs[i] > xs[p]) i += step (step < 0)
/// The sign of `step` therefore selects both the scan direction and the
/// comparison used. The routine returns i as well as a boolean value to
/// indicate whether xs[i] == xs[p].
static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
                                              ModuleOp module,
                                              func::FuncOp func, ValueRange xs,
                                              Value i, Value p, AffineMap xPerm,
                                              uint64_t ny, int step) {}

/// Creates and returns an IfOp to compare two elements and swap the elements
/// if compareFunc(data[b], data[a]) returns true. The new insertion point is
/// right after the swap instructions.
/// NOTE(review): `swapOperands` and `compareOperands` are taken by mutable
/// reference; whether and how they are modified is not visible here.
static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
                                       AffineMap xPerm, uint64_t ny,
                                       SmallVectorImpl<Value> &swapOperands,
                                       SmallVectorImpl<Value> &compareOperands,
                                       Value a, Value b) {}

/// Creates code to insert the 3rd element to a list of two sorted elements.
/// Presumably `v0` and `v1` index the two already-sorted elements and `v2`
/// indexes the element being inserted -- TODO confirm.
static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
                            uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                            SmallVectorImpl<Value> &compareOperands, Value v0,
                            Value v1, Value v2) {}

/// Creates code to sort 3 elements, identified by `v0`, `v1` and `v2`.
static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
                        uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                        SmallVectorImpl<Value> &compareOperands, Value v0,
                        Value v1, Value v2) {}

/// Creates code to sort 5 elements, identified by `v0` through `v4`.
static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
                        uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                        SmallVectorImpl<Value> &compareOperands, Value v0,
                        Value v1, Value v2, Value v3, Value v4) {}

/// Creates a code block to swap the values in indices lo, mi, and hi so that
/// data[lo], data[mi] and data[hi] are sorted in non-decreasing order. When
/// the number of values in range [lo, hi) is more than a threshold, we also
/// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
static void createChoosePivot(OpBuilder &builder, ModuleOp module,
                              func::FuncOp func, AffineMap xPerm, uint64_t ny,
                              Value lo, Value hi, Value mi, ValueRange args) {}

/// Creates a function to perform quick sort partition on the values in the
/// range of index [lo, hi), assuming lo < hi.
//
// The generated IR corresponds to this C like algorithm:
// int partition(lo, hi, xs) {
//   p = (lo+hi)/2  // pivot index
//   i = lo
//   j = hi-1
//   while (true) {
//     while (xs[i] < xs[p]) i ++;
//     i_eq = (xs[i] == xs[p]);
//     while (xs[j] > xs[p]) j --;
//     j_eq = (xs[j] == xs[p]);
//
//     if (i >= j) return j + 1;
//
//     if (i < j) {
//       swap(xs[i], xs[j])
//       // Keep p pointing at the pivot value if the swap moved it.
//       if (i == p) {
//         p = j;
//       } else if (j == p) {
//         p = i;
//       }
//       if (i_eq && j_eq) {
//         ++i;
//         --j;
//       }
//     }
//   }
// }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                uint32_t nTrailingP = 0) {}

/// Computes (n-2)/2, assuming n has index type.
static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
                                      Value n) {}

/// Creates a function to heapify the subtree with root `start` within the full
/// binary tree in the range of index [first, first + n).
//
// The generated IR corresponds to this C like algorithm:
// void shiftDown(first, start, n, data) {
//   if (n >= 2) {
//     child = start - first
//     if ((n-2)/2 >= child) {
//       // Left child exists.
//       child = child * 2 + 1 // Initialize the bigger child to left child.
//       childIndex = child + first
//       if (child+1 < n && data[childIndex] < data[childIndex+1])
//         // Right child exists and is bigger.
//         childIndex++; child++;
//       // Shift data[start] down to where it belongs in the subtree.
//       while (data[start] < data[childIndex]) {
//         swap(data[start], data[childIndex])
//         start = childIndex
//         if ((n - 2)/2 >= child) {
//           // Left child exists.
//           child = 2*child + 1
//           childIndex = child + first
//           if (child + 1 < n && data[childIndex] < data[childIndex+1])
//             childIndex++; child++;
//         }
//       }
//     }
//   }
// }
//
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                uint32_t nTrailingP) {}

/// Creates a function to perform heap sort on the values in the range of index
/// [lo, hi) with the assumption hi - lo >= 2.
//
// The generated IR corresponds to this C like algorithm:
// void heapSort(lo, hi, data) {
//   n = hi - lo
//   // Build the heap.
//   for i = (n-2)/2 downto 0
//     shiftDown(lo, lo+i, n, data)
//
//   // Repeatedly move the largest value to the end of the shrinking heap.
//   for l = n downto 2
//      swap(data[lo], data[lo+l-1])
//      shiftDown(lo, lo, l-1, data)
// }
static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
                               func::FuncOp func, AffineMap xPerm, uint64_t ny,
                               uint32_t nTrailingP) {}

/// A helper for generating code to perform quick sort. It partitions [lo, hi),
/// recursively calls quick sort to process the smaller partition and returns
/// the bigger partition to be processed by the enclosing while-loop.
/// The bigger partition is returned as a pair of values, presumably its
/// (lo, hi) bounds -- TODO confirm.
static std::pair<Value, Value>
createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
                ValueRange args, AffineMap xPerm, uint64_t ny,
                uint32_t nTrailingP) {}

/// Creates a function to perform insertion sort on the values in the range of
/// index [lo, hi).
//
// The generated IR corresponds to this C like algorithm:
// void insertionSort(lo, hi, data) {
//   for (i = lo+1; i < hi; i++) {
//      d = data[i];
//      p = binarySearch(lo, i-1, data)
//      // Shift data[p..i) up by one slot to make room at p.
//      for (j = 0; j < i - p; j++)
//        data[i-j] = data[i-j-1]
//      data[p] = d
//   }
// }
static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
                                 func::FuncOp func, AffineMap xPerm,
                                 uint64_t ny, uint32_t nTrailingP) {}

/// Creates a function to perform quick sort or a hybrid quick sort on the
/// values in the range of index [lo, hi).
//
// When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
// void quickSort(lo, hi, data) {
//   while (lo + 1 < hi) {
//        p = partition(lo, hi, data);
//        if (len(lo, p) < len(p+1, hi)) {
//          quickSort(lo, p, data);
//          lo = p+1;
//        } else {
//          quickSort(p + 1, hi, data);
//          hi = p;
//        }
//   }
// }
//
// When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
// void hybridQuickSort(lo, hi, data, depthLimit) {
//   while (lo + 1 < hi) {
//     len = hi - lo;
//     if (len <= limit) {
//       insertionSort(lo, hi, data);
//     } else {
//       depthLimit --;
//       if (depthLimit <= 0) {
//         heapSort(lo, hi, data);
//       } else {
//          p = partition(lo, hi, data);
//          if (len(lo, p) < len(p+1, hi)) {
//            hybridQuickSort(lo, p, data, depthLimit);
//            lo = p+1;
//          } else {
//            hybridQuickSort(p + 1, hi, data, depthLimit);
//            hi = p;
//          }
//       }
//     }
//   }
// }
//
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                uint32_t nTrailingP) {}

/// Implements the rewriting for operator sort and sort_coo.
/// `OpTy` is the concrete operation type being rewritten. Presumably `xys`
/// holds the buffers to sort and `xPerm`/`ny` carry the same meaning as in the
/// sort helper functions -- TODO confirm.
template <typename OpTy>
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
                                    uint64_t ny, PatternRewriter &rewriter) {}

//===---------------------------------------------------------------------===//
// The actual sparse buffer rewriting rules.
//===---------------------------------------------------------------------===//

namespace {
/// Sparse rewriting rule for the push_back operator.
/// NOTE(review): the pattern body is not visible in this view of the file.
struct PushBackRewriter : OpRewritePattern<PushBackOp> {};

/// Sparse rewriting rule for the sort_coo operator.
/// NOTE(review): the pattern body is not visible in this view of the file.
struct SortRewriter : public OpRewritePattern<SortOp> {};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

// Entry point for registering the rewriting rules of this file with
// `patterns`.
// NOTE(review): how `enableBufferInitialization` is used is not visible here;
// presumably it is forwarded to one of the patterns -- confirm.
void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
                                         bool enableBufferInitialization) {}