//===- SparseBufferRewriting.cpp - Sparse buffer rewriting rules ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements rewriting rules that are specific to sparse tensor // primitives with memref operands. // //===----------------------------------------------------------------------===// #include "Utils/CodegenUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Support/LLVM.h" usingnamespacemlir; usingnamespacemlir::sparse_tensor; //===---------------------------------------------------------------------===// // Helper methods for the actual rewriting rules. 
//===---------------------------------------------------------------------===// static constexpr uint64_t loIdx = …; static constexpr uint64_t hiIdx = …; static constexpr uint64_t xStartIdx = …; static constexpr const char kPartitionFuncNamePrefix[] = …; static constexpr const char kBinarySearchFuncNamePrefix[] = …; static constexpr const char kHybridQuickSortFuncNamePrefix[] = …; static constexpr const char kSortStableFuncNamePrefix[] = …; static constexpr const char kShiftDownFuncNamePrefix[] = …; static constexpr const char kHeapSortFuncNamePrefix[] = …; static constexpr const char kQuickSortFuncNamePrefix[] = …; FuncGeneratorType; /// Constructs a function name with this format to facilitate quick sort: /// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort /// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands) { … } /// Looks up a function that is appropriate for the given operands being /// sorted, and creates such a function if it doesn't exist yet. The /// parameters `xPerm` and `ny` tell the number of x and y values provided /// by the buffer in xStartIdx. // // All sorting function generators take (lo, hi, xs, ys) in `operands` as // parameters for the sorting functions. Other parameters, such as the recursive // call depth, are appended to the end of the parameter list as // "trailing parameters". static FlatSymbolRefAttr getMangledSortHelperFunc( OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes, StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands, FuncGeneratorType createFunc, uint32_t nTrailingP = 0) { … } /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting. /// The code to process the value pairs is generated by `bodyBuilder`. 
static void forEachIJPairInXs( OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { … } /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting. /// The code to process the value pairs is generated by `bodyBuilder`. static void forEachIJPairInAllBuffers( OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) { … } /// Creates a code block for swapping the values in index i and j for all the /// buffers. // // The generated IR corresponds to this C like algorithm: // swap(x0[i], x0[j]); // swap(x1[i], x1[j]); // ... // swap(xn[i], xn[j]); // swap(y0[i], y0[j]); // ... // swap(yn[i], yn[j]); static void createSwap(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny) { … } /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare /// each pair is created via `compareBuilder`. static Value createInlinedCompareImplementation( OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)> compareBuilder) { … } /// Generates code to compare whether x[i] is equal to x[j] and returns the /// result of the comparison. static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim) { … } /// Creates code to compare whether xs[i] is equal to xs[j]. // // The generated IR corresponds to this C like algorithm: // if (x0[i] != x0[j]) // return false; // else // if (x1[i] != x1[j]) // return false; // else if (x2[i] != x2[j]) // and so on ... 
static Value createInlinedEqCompare(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP = 0) { … } /// Generates code to compare whether x[i] is less than x[j] and returns the /// result of the comparison. static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i, Value j, Value x, bool isFirstDim, bool isLastDim) { … } /// Creates code to compare whether xs[i] is less than xs[j]. // // The generated IR corresponds to this C like algorithm: // if (x0[i] != x0[j]) // return x0[i] < x0[j]; // else if (x1[j] != x1[i]) // return x1[i] < x1[j]; // else // and so on ... static Value createInlinedLessThan(OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP = 0) { … } /// Creates a function to use a binary search to find the insertion point for /// inserting xs[hi] to the sorted values xs[lo..hi). // // The generated IR corresponds to this C like algorithm: // p = hi // while (lo < hi) // mid = (lo + hi) >> 1 // if (xs[p] < xs[mid]) // hi = mid // else // lo = mid + 1 // return lo; // static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP = 0) { … } /// Creates code to advance i in a loop based on xs[p] as follows: /// while (xs[i] < xs[p]) i += step (step > 0) /// or /// while (xs[i] > xs[p]) i += step (step < 0) /// The routine returns i as well as a boolean value to indicate whether /// xs[i] == xs[p]. static std::pair<Value, Value> createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange xs, Value i, Value p, AffineMap xPerm, uint64_t ny, int step) { … } /// Creates and returns an IfOp to compare two elements and swap the elements /// if compareFunc(data[b], data[a]) returns true. The new insertion point is /// right after the swap instructions. 
static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl<Value> &swapOperands, SmallVectorImpl<Value> &compareOperands, Value a, Value b) { … } /// Creates code to insert the 3rd element to a list of two sorted elements. static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl<Value> &swapOperands, SmallVectorImpl<Value> &compareOperands, Value v0, Value v1, Value v2) { … } /// Creates code to sort 3 elements. static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl<Value> &swapOperands, SmallVectorImpl<Value> &compareOperands, Value v0, Value v1, Value v2) { … } /// Creates code to sort 5 elements. static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm, uint64_t ny, SmallVectorImpl<Value> &swapOperands, SmallVectorImpl<Value> &compareOperands, Value v0, Value v1, Value v2, Value v3, Value v4) { … } /// Creates a code block to swap the values in indices lo, mi, and hi so that /// data[lo], data[mi] and data[hi] are sorted in non-decreasing values. When /// the number of values in range [lo, hi) is more than a threshold, we also /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values. static void createChoosePivot(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, Value lo, Value hi, Value mi, ValueRange args) { … } /// Creates a function to perform quick sort partition on the values in the /// range of index [lo, hi), assuming lo < hi. 
// // The generated IR corresponds to this C like algorithm: // int partition(lo, hi, xs) { // p = (lo+hi)/2 // pivot index // i = lo // j = hi-1 // while (true) do { // while (xs[i] < xs[p]) i ++; // i_eq = (xs[i] == xs[p]); // while (xs[j] > xs[p]) j --; // j_eq = (xs[j] == xs[p]); // // if (i >= j) return j + 1; // // if (i < j) { // swap(xs[i], xs[j]) // if (i == p) { // p = j; // } else if (j == p) { // p = i; // } // if (i_eq && j_eq) { // ++i; // --j; // } // } // } // } static void createPartitionFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP = 0) { … } /// Computes (n-2)/2, assuming n has index type. static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc, Value n) { … } /// Creates a function to heapify the subtree with root `start` within the full /// binary tree in the range of index [first, first + n). // // The generated IR corresponds to this C like algorithm: // void shiftDown(first, start, n, data) { // if (n >= 2) { // child = start - first // if ((n-2)/2 >= child) { // // Left child exists. // child = child * 2 + 1 // Initialize the bigger child to left child. // childIndex = child + first // if (child+1 < n && data[childIndex] < data[childIndex+1]) // // Right child exists and is bigger. // childIndex++; child++; // // Shift data[start] down to where it belongs in the subtree. // while (data[start] < data[childIndex]) { // swap(data[start], data[childIndex]) // start = childIndex // if ((n - 2)/2 >= child) { // // Left child exists. // child = 2*child + 1 // childIndex = child + first // if (child + 1 < n && data[childIndex] < data[childIndex+1]) // childIndex++; child++; // } // } // } // } // } // static void createShiftDownFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP) { … } /// Creates a function to perform heap sort on the values in the range of index /// [lo, hi) with the assumption hi - lo >= 2. 
// // The generated IR corresponds to this C like algorithm: // void heapSort(lo, hi, data) { // n = hi - lo // for i = (n-2)/2 downto 0 // shiftDown(lo, lo+i, n) // // for l = n downto 2 // swap(data[lo], data[lo+l-1]) // shiftDown(lo, lo, l-1) // } static void createHeapSortFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP) { … } /// A helper for generating code to perform quick sort. It partitions [lo, hi), /// recursively calls quick sort to process the smaller partition and returns /// the bigger partition to be processed by the enclosed while-loop. static std::pair<Value, Value> createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, ValueRange args, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP) { … } /// Creates a function to perform insertion sort on the values in the range of /// index [lo, hi). // // The generated IR corresponds to this C like algorithm: // void insertionSort(lo, hi, data) { // for (i = lo+1; i < hi; i++) { // d = data[i]; // p = binarySearch(lo, i-1, data) // for (j = 0; j < i - p; j++) // data[i-j] = data[i-j-1] // data[p] = d // } // } static void createSortStableFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP) { … } /// Creates a function to perform quick sort or a hybrid quick sort on the /// values in the range of index [lo, hi). 
// // // When nTrailingP == 0, the generated IR corresponds to this C like algorithm: // void quickSort(lo, hi, data) { // while (lo + 1 < hi) { // p = partition(lo, hi, data); // if (len(lo, p) < len(p+1, hi)) { // quickSort(lo, p, data); // lo = p+1; // } else { // quickSort(p + 1, hi, data); // hi = p; // } // } // } // // When nTrailingP == 1, the generated IR corresponds to this C like algorithm: // void hybridQuickSort(lo, hi, data, depthLimit) { // while (lo + 1 < hi) { // len = hi - lo; // if (len <= limit) { // insertionSort(lo, hi, data); // } else { // depthLimit --; // if (depthLimit <= 0) { // heapSort(lo, hi, data); // } else { // p = partition(lo, hi, data); // if (len(lo, p) < len(p+1, hi)) { // hybridQuickSort(lo, p, data, depthLimit); // lo = p+1; // } else { // hybridQuickSort(p + 1, hi, data, depthLimit); // hi = p; // } // } // } // } // } // static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, func::FuncOp func, AffineMap xPerm, uint64_t ny, uint32_t nTrailingP) { … } /// Implements the rewriting for operator sort and sort_coo. template <typename OpTy> LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, uint64_t ny, PatternRewriter &rewriter) { … } //===---------------------------------------------------------------------===// // The actual sparse buffer rewriting rules. //===---------------------------------------------------------------------===// namespace { /// Sparse rewriting rule for the push_back operator. struct PushBackRewriter : OpRewritePattern<PushBackOp> { … }; /// Sparse rewriting rule for the sort_coo operator. struct SortRewriter : public OpRewritePattern<SortOp> { … }; } // namespace //===---------------------------------------------------------------------===// // Methods that add patterns described in this file to a pattern list. 
//===---------------------------------------------------------------------===// void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns, bool enableBufferInitialization) { … }