llvm/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

//===- ExpandStridedMetadata.cpp - Simplify this operation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// The pass expands memref operations that modify the metadata of a memref
/// (sizes, offset, strides) into a sequence of easier to analyze constructs.
/// In particular, this pass transforms each such operation into an explicit
/// sequence of operations that model its effect on the different metadata.
/// This pass uses affine constructs to materialize these effects.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include <optional>

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_EXPANDSTRIDEDMETADATA
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

usingnamespacemlir;
usingnamespacemlir::affine;

namespace {

/// Value produced (wrapped in FailureOr) by the resolve*StridedMetadata
/// helpers below, whose documentation describes their result as
/// {baseBuffer, offset, sizes, strides}.
struct StridedMetadata {};

/// From `subview(memref, subOffset, subSizes, subStrides)` compute
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// and return {baseBuffer, offset, sizes, strides}
///
/// \p subview is the subview op whose metadata is resolved. The result is
/// wrapped in FailureOr, so callers must check it before using the metadata.
static FailureOr<StridedMetadata>
resolveSubviewStridedMetadata(RewriterBase &rewriter,
                              memref::SubViewOp subview) {}

/// Replace `dst = subview(memref, subOffset, subSizes, subStrides)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// dst = reinterpret_cast baseBuffer, offset, sizes, strides
/// \endverbatim
///
/// In other words, get rid of the subview in that expression and canonicalize
/// on its effects on the offset, the sizes, and the strides using affine.apply.
struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> {};

/// Pattern to replace `extract_strided_metadata(subview)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// with `baseBuffer`, `offset`, `sizes` and `strides` being
/// the replacements for the original `extract_strided_metadata`.
struct ExtractStridedMetadataOpSubviewFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {};

/// Compute the expanded sizes of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origSizes holds the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// sizes#i =
///     baseSizes#groupId / product(expandShapeSizes#j,
///                                  for j in group excluding reassIdx#i)
/// where reassIdx#i is the reassociation index at index i in \p groupId.
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
static SmallVector<OpFoldResult>
getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {}

/// Compute the expanded strides of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// strides#i =
///     origStrides#reassDim * product(expandShapeSizes#j, for j in
///                                    reassIdx#i+1..reassIdx#i+group.size-1)
///
/// where reassIdx#i is the reassociation index at index i in \p groupId
/// and expandShapeSizes#j is either:
/// - The constant size at dimension j, derived directly from the result type of
///   the expand_shape op, or
/// - An affine expression: baseSizes#reassDim / product of all constant sizes
///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
///   element.)
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
                                             OpBuilder &builder,
                                             ArrayRef<OpFoldResult> origSizes,
                                             ArrayRef<OpFoldResult> origStrides,
                                             unsigned groupId) {}

/// Produce an OpFoldResult object with \p builder at \p loc representing
/// `prod(valueOrConstant#i, for i in {indices})`,
/// where valueOrConstant#i is maybeConstants[i] when \p isDynamic is false,
/// values[i] otherwise.
///
/// \pre for all index in indices: index < values.size()
/// \pre for all index in indices: index < maybeConstants.size()
static OpFoldResult
getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
                   ArrayRef<int64_t> maybeConstants,
                   ArrayRef<OpFoldResult> values,
                   llvm::function_ref<bool(int64_t)> isDynamic) {}

/// Compute the collapsed size of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origSizes holds the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// Conceptually this helper function computes:
/// `prod(origSizes#i, for i in {reassociationGroup[groupId]})`.
///
/// \post result.size() == 1, in other words, each group collapses to one
/// dimension.
///
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
static SmallVector<OpFoldResult>
getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {}

/// Compute the collapsed stride of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// Conceptually this helper function returns the stride of the innermost
/// dimension of that group in the original shape.
///
/// \post result.size() == 1, in other words, each group collapses to one
/// dimension.
static SmallVector<OpFoldResult>
getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                   ArrayRef<OpFoldResult> origSizes,
                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {}

/// From `reshape_like(memref, subSizes, subStrides)` compute
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// sizes = subSizes
/// \endverbatim
///
/// and return {baseBuffer, baseOffset, sizes, strides}
///
/// \p getReshapedSizes and \p getReshapedStrides are callbacks taking the
/// reshape op, a builder, the original sizes (and, for strides, the original
/// strides) and a reassociation group id; each returns a list of
/// OpFoldResult. Presumably they are invoked once per reassociation group to
/// produce the `sizes` and `strides` above -- TODO confirm.
/// The result is wrapped in FailureOr, so callers must check it before use.
template <typename ReassociativeReshapeLikeOp>
static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
    RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
    function_ref<SmallVector<OpFoldResult>(
        ReassociativeReshapeLikeOp, OpBuilder &,
        ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/)>
        getReshapedSizes,
    function_ref<SmallVector<OpFoldResult>(
        ReassociativeReshapeLikeOp, OpBuilder &,
        ArrayRef<OpFoldResult> /*origSizes*/,
        ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
        getReshapedStrides) {}

/// Replace `baseBuffer, offset, sizes, strides =
///              extract_strided_metadata(reshapeLike(memref))`
/// With
///
/// \verbatim
/// baseBuffer, offset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// sizes = getReshapedSizes(reshapeLike)
/// strides = getReshapedStrides(reshapeLike)
/// \endverbatim
///
/// Notice that `baseBuffer` and `offset` are unchanged.
///
/// In other words, get rid of the reshape-like op in that expression and
/// materialize its effects on the sizes and the strides using affine apply.
///
/// NOTE(review): the OpRewritePattern base is instantiated on
/// ReassociativeReshapeLikeOp, i.e. this pattern is rooted on the reshape-like
/// op itself rather than on extract_strided_metadata -- confirm that the
/// description above matches what the pattern emits.
///
/// \tparam ReassociativeReshapeLikeOp Op type this pattern is rooted on.
/// \tparam getReshapedSizes Function computing the reshaped sizes of one
/// reassociation group from the original sizes.
/// \tparam getReshapedStrides Function computing the reshaped strides of one
/// reassociation group from the original sizes and strides.
template <typename ReassociativeReshapeLikeOp,
          SmallVector<OpFoldResult> (*getReshapedSizes)(
              ReassociativeReshapeLikeOp, OpBuilder &,
              ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
          SmallVector<OpFoldResult> (*getReshapedStrides)(
              ReassociativeReshapeLikeOp, OpBuilder &,
              ArrayRef<OpFoldResult> /*origSizes*/,
              ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {};

/// Pattern to replace `extract_strided_metadata(collapse_shape)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subSizes#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// with `baseBuffer`, `offset`, `sizes` and `strides` being
/// the replacements for the original `extract_strided_metadata`.
///
/// NOTE(review): the formulas above refer to `subOffset`/`subSizes`, which are
/// subview operands and do not exist on collapse_shape; they look copied from
/// the subview folder. Presumably the sizes and strides come from
/// getCollapsedSize / getCollapsedStride and the offset is unchanged --
/// confirm and update the formulas.
struct ExtractStridedMetadataOpCollapseShapeFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {};

/// Pattern to replace `extract_strided_metadata(expand_shape)`
/// with the results of computing the sizes and strides on the expanded shape
/// and dividing up dimensions into static and dynamic parts as needed.
///
/// The pattern is rooted on memref::ExtractStridedMetadataOp.
struct ExtractStridedMetadataOpExpandShapeFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {};

/// Replace `base, offset, sizes, strides =
///              extract_strided_metadata(allocLikeOp)`
///
/// With
///
/// ```
/// base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref<eltTy>
/// offset = 0
/// sizes = allocSizes
/// strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
/// ```
///
/// The transformation only applies if the allocLikeOp has been normalized.
/// In other words, the affine_map must be an identity.
///
/// \tparam AllocLikeOp Type of the allocation-like op feeding the
/// `extract_strided_metadata` (presumably memref.alloc / memref.alloca --
/// TODO confirm the instantiations).
template <typename AllocLikeOp>
struct ExtractStridedMetadataOpAllocFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {};

/// Replace `base, offset, sizes, strides =
///              extract_strided_metadata(get_global)`
///
/// With
///
/// ```
/// base = reinterpret_cast get_global to a flat memref<eltTy>
/// offset = 0
/// sizes = allocSizes
/// strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
/// ```
///
/// It is expected that the memref.get_global op has static shapes
/// and identity affine_map for the layout.
///
/// NOTE(review): `allocSizes` here presumably denotes the static sizes of the
/// global's memref type (get_global takes no size operands) -- confirm.
struct ExtractStridedMetadataOpGetGlobalFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {};

/// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
/// source of the ViewLikeOp.
///
/// The pattern is rooted on memref::ExtractAlignedPointerAsIndexOp.
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
    : public OpRewritePattern<memref::ExtractAlignedPointerAsIndexOp> {};

/// Replace `base, offset, sizes, strides =
///              extract_strided_metadata(
///                 reinterpret_cast(src, srcOffset, srcSizes, srcStrides))`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = srcOffset
/// sizes = srcSizes
/// strides = srcStrides
/// ```
///
/// In other words, consume the `reinterpret_cast` and apply its effects
/// on the offset, sizes, and strides.
///
/// The pattern is rooted on memref::ExtractStridedMetadataOp.
class ExtractStridedMetadataOpReinterpretCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {};

/// Replace `base, offset, sizes, strides =
///              extract_strided_metadata(
///                 cast(src) to dstTy)`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = !dstTy.srcOffset.isDynamic()
///            ? dstTy.srcOffset
///            : extract_strided_metadata(src).offset
/// sizes = for each srcSize in dstTy.srcSizes:
///           !srcSize.isDynamic()
///             ? srcSize
///             : extract_strided_metadata(src).sizes[i]
/// strides = for each srcStride in dstTy.srcStrides:
///             !srcStride.isDynamic()
///               ? srcStride
///               : extract_strided_metadata(src).strides[i]
/// ```
///
/// In other words, consume the `cast` and apply its effects
/// on the offset, sizes, and strides or compute them directly from `src`.
class ExtractStridedMetadataOpCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {};

/// Replace `base, offset, sizes, strides = extract_strided_metadata(
///      memory_space_cast(src) to dstTy)`
/// with
/// ```
///    oldBase, offset, sizes, strides = extract_strided_metadata(src)
///    dstBaseTy = type(oldBase) with memory space from dstTy
///    base = memory_space_cast(oldBase) to dstBaseTy
/// ```
///
/// In other words, propagate metadata extraction across memory space casts.
class ExtractStridedMetadataOpMemorySpaceCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {};

/// Replace `base, offset =
///            extract_strided_metadata(extract_strided_metadata(src)#0)`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = 0
/// ```
///
/// Here `#0` selects the first result (the base buffer) of the inner
/// `extract_strided_metadata`.
class ExtractStridedMetadataOpExtractStridedMetadataFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {};
} // namespace

/// Definition of the `memref::populateExpandStridedMetadataPatterns` hook,
/// which receives the pattern set \p patterns to fill.
/// NOTE(review): presumably meant to register the rewrite patterns declared
/// earlier in this file -- confirm which ones belong here versus
/// populateResolveExtractStridedMetadataPatterns.
void memref::populateExpandStridedMetadataPatterns(
    RewritePatternSet &patterns) {}

/// Definition of the `memref::populateResolveExtractStridedMetadataPatterns`
/// hook, which receives the pattern set \p patterns to fill.
/// NOTE(review): judging by the name, presumably meant to register only the
/// patterns rooted on extract_strided_metadata (the ExtractStridedMetadataOp*
/// folders declared earlier in this file) -- confirm.
void memref::populateResolveExtractStridedMetadataPatterns(
    RewritePatternSet &patterns) {}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

/// Pass class for the expand-strided-metadata transformation. It derives from
/// `memref::impl::ExpandStridedMetadataBase`, the base pulled in through
/// GEN_PASS_DEF_EXPANDSTRIDEDMETADATA and Passes.h.inc at the top of this file.
struct ExpandStridedMetadataPass final
    : public memref::impl::ExpandStridedMetadataBase<
          ExpandStridedMetadataPass> {};

} // namespace

/// Pass entry point, defined out of line for ExpandStridedMetadataPass.
/// NOTE(review): presumably applies the patterns from
/// populateExpandStridedMetadataPatterns with the greedy rewrite driver
/// (GreedyPatternRewriteDriver.h is included by this file) -- confirm.
void ExpandStridedMetadataPass::runOnOperation() {}

std::unique_ptr<Pass> memref::createExpandStridedMetadataPass() {}