llvm/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp

//===- ExtractAddressComputations.cpp - Extract address computations  -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This transformation pass rewrites loading/storing from/to a memref with
/// offsets into loading/storing from/to a subview and without any offset on
/// the instruction itself.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"

usingnamespacemlir;

namespace {

//===----------------------------------------------------------------------===//
// Helper functions for the `load base[off0...]`
//  => `load (subview base[off0...])[0...]` pattern.
//===----------------------------------------------------------------------===//

// Source-memref accessor for memref::LoadOp.
// Its signature matches the `getFailureOrSrcMemRef` template parameter of
// LoadStoreLikeOpRewriter, so it can be passed to that rewriter as-is.
// \see LoadStoreLikeOpRewriter.
static FailureOr<Value> getLoadOpSrcMemRef(memref::LoadOp loadOp) {}

// Rebuild callback for memref::LoadOp.
// Its signature matches the `rebuildOpFromAddressAndIndices` template
// parameter of LoadStoreLikeOpRewriter: given \p rewriter, the original
// \p loadOp, the new base \p srcMemRef and the new \p indices, it is expected
// to return the LoadOp that accesses srcMemRef[indices] and that will replace
// \p loadOp.
// \see LoadStoreLikeOpRewriter.
static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter,
                                    memref::LoadOp loadOp, Value srcMemRef,
                                    ArrayRef<Value> indices) {}

// View-size callback for memref::LoadOp.
// Its signature matches the `getViewSizeForEachDim` template parameter of
// LoadStoreLikeOpRewriter, whose contract is to return one size per dimension
// of the view, each at least as big as what \p loadOp actually accesses.
// \see LoadStoreLikeOpRewriter.
static SmallVector<OpFoldResult>
getLoadOpViewSizeForEachDim(RewriterBase &rewriter, memref::LoadOp loadOp) {}

//===----------------------------------------------------------------------===//
// Helper functions for the `store val, base[off0...]`
//  => `store val, (subview base[off0...])[0...]` pattern.
//===----------------------------------------------------------------------===//

// Source-memref accessor for memref::StoreOp.
// Its signature matches the `getFailureOrSrcMemRef` template parameter of
// LoadStoreLikeOpRewriter, so it can be passed to that rewriter as-is.
// \see LoadStoreLikeOpRewriter.
static FailureOr<Value> getStoreOpSrcMemRef(memref::StoreOp storeOp) {}

// Rebuild callback for memref::StoreOp.
// Its signature matches the `rebuildOpFromAddressAndIndices` template
// parameter of LoadStoreLikeOpRewriter: given \p rewriter, the original
// \p storeOp, the new base \p srcMemRef and the new \p indices, it is expected
// to return the StoreOp that accesses srcMemRef[indices] and that will
// replace \p storeOp.
// \see LoadStoreLikeOpRewriter.
static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter,
                                      memref::StoreOp storeOp, Value srcMemRef,
                                      ArrayRef<Value> indices) {}

// View-size callback for memref::StoreOp.
// Its signature matches the `getViewSizeForEachDim` template parameter of
// LoadStoreLikeOpRewriter, whose contract is to return one size per dimension
// of the view, each at least as big as what \p storeOp actually accesses.
// \see LoadStoreLikeOpRewriter.
static SmallVector<OpFoldResult>
getStoreOpViewSizeForEachDim(RewriterBase &rewriter, memref::StoreOp storeOp) {}

//===----------------------------------------------------------------------===//
// Helper functions for the `ldmatrix base[off0...]`
//  => `ldmatrix (subview base[off0...])[0...]` pattern.
//===----------------------------------------------------------------------===//

// Source-memref accessor for nvgpu::LdMatrixOp.
// Its signature matches the `getFailureOrSrcMemRef` template parameter of
// LoadStoreLikeOpRewriter, so it can be passed to that rewriter as-is.
// \see LoadStoreLikeOpRewriter.
static FailureOr<Value> getLdMatrixOpSrcMemRef(nvgpu::LdMatrixOp ldMatrixOp) {}

// Rebuild callback for nvgpu::LdMatrixOp.
// Its signature matches the `rebuildOpFromAddressAndIndices` template
// parameter of LoadStoreLikeOpRewriter: given \p rewriter, the original
// \p ldMatrixOp, the new base \p srcMemRef and the new \p indices, it is
// expected to return the LdMatrixOp that reads from srcMemRef[indices] and
// that will replace \p ldMatrixOp.
// \see LoadStoreLikeOpRewriter.
static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
                                           nvgpu::LdMatrixOp ldMatrixOp,
                                           Value srcMemRef,
                                           ArrayRef<Value> indices) {}

//===----------------------------------------------------------------------===//
// Helper functions for the `transfer_read base[off0...]`
//  => `transfer_read (subview base[off0...])[0...]` pattern.
//===----------------------------------------------------------------------===//

// Source-memref accessor for transfer-like ops.
// Templated on the op type, so each instantiation has the signature of the
// `getFailureOrSrcMemRef` template parameter of LoadStoreLikeOpRewriter for
// that op type.
// NOTE(review): the transfer_write section below defines no accessor of its
// own, so this template is presumably shared by TransferReadOp and
// TransferWriteOp -- confirm against the pattern registration.
// \see LoadStoreLikeOpRewriter.
template <typename TransferLikeOp>
static FailureOr<Value>
getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {}

// Rebuild callback for vector::TransferReadOp.
// Its signature matches the `rebuildOpFromAddressAndIndices` template
// parameter of LoadStoreLikeOpRewriter: given \p rewriter, the original
// \p transferReadOp, the new base \p srcMemRef and the new \p indices, it is
// expected to return the TransferReadOp that reads from srcMemRef[indices]
// and that will replace \p transferReadOp.
// \see LoadStoreLikeOpRewriter.
static vector::TransferReadOp
rebuildTransferReadOp(RewriterBase &rewriter,
                      vector::TransferReadOp transferReadOp, Value srcMemRef,
                      ArrayRef<Value> indices) {}

//===----------------------------------------------------------------------===//
// Helper functions for the `transfer_write base[off0...]`
//  => `transfer_write (subview base[off0...])[0...]` pattern.
//===----------------------------------------------------------------------===//

// Rebuild callback for vector::TransferWriteOp.
// Its signature matches the `rebuildOpFromAddressAndIndices` template
// parameter of LoadStoreLikeOpRewriter: given \p rewriter, the original
// \p transferWriteOp, the new base \p srcMemRef and the new \p indices, it is
// expected to return the TransferWriteOp that accesses srcMemRef[indices]
// and that will replace \p transferWriteOp.
// \see LoadStoreLikeOpRewriter.
static vector::TransferWriteOp
rebuildTransferWriteOp(RewriterBase &rewriter,
                       vector::TransferWriteOp transferWriteOp, Value srcMemRef,
                       ArrayRef<Value> indices) {}

//===----------------------------------------------------------------------===//
// Generic helper functions used as default implementation in
// LoadStoreLikeOpRewriter.
//===----------------------------------------------------------------------===//

/// Helper function to get the src memref of \p loadStoreLikeOp as a plain
/// Value rather than a FailureOr<Value>.
/// It uses the already defined getFailureOrSrcMemRef but asserts
/// that the source is a memref.
/// Each instantiation has the signature expected by the `getSrcMemRef`
/// template parameter of getGenericOpViewSizeForEachDim.
///
/// \tparam LoadStoreLikeOp the operation type to query.
/// \tparam getFailureOrSrcMemRef the fallible accessor this helper wraps.
template <typename LoadStoreLikeOp,
          FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp)>
static Value getSrcMemRef(LoadStoreLikeOp loadStoreLikeOp) {}

/// Helper function to get the sizes of the resulting view.
/// This function gets the sizes of the source memref then subtracts the
/// offsets used within \p loadStoreLikeOp. This gives the maximal (for
/// inbound) sizes for the view.
/// The source memref is retrieved using getSrcMemRef on \p loadStoreLikeOp.
///
/// This is the default `getViewSizeForEachDim` of LoadStoreLikeOpRewriter.
///
/// \tparam LoadStoreLikeOp the operation type to query.
/// \tparam getSrcMemRef accessor returning the source memref of the op.
template <typename LoadStoreLikeOp, Value (*getSrcMemRef)(LoadStoreLikeOp)>
static SmallVector<OpFoldResult>
getGenericOpViewSizeForEachDim(RewriterBase &rewriter,
                               LoadStoreLikeOp loadStoreLikeOp) {}

/// Rewrite a store/load-like op so that all its indices are zeros.
/// E.g., %ld = memref.load %base[%off0]...[%offN]
/// =>
/// %new_base = subview %base[%off0,.., %offN][1,..,1][1,..,1]
/// %ld = memref.load %new_base[0,..,0] :
///    memref<1x..x1xTy, strided<[1,..,1], offset: ?>>
///
/// `LoadStoreLikeOp` is the operation type matched by this pattern.
///
/// `getFailureOrSrcMemRef` returns the source memref for the given load-like
/// operation, wrapped in a FailureOr.
/// NOTE(review): a failure presumably means the source is not a memref and
/// the op must be left untouched -- confirm against matchAndRewrite.
///
/// `getViewSizeForEachDim` returns the sizes of the view that is going to
/// feed the new operation. This must return one size per dimension of the
/// view. The sizes of the view need to be at least as big as what is actually
/// going to be accessed. Use the provided `loadStoreOp` to get the right
/// sizes. It defaults to getGenericOpViewSizeForEachDim instantiated with
/// getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>.
///
/// Using the given rewriter, `rebuildOpFromAddressAndIndices` creates a new
/// LoadStoreLikeOp that reads from srcMemRef[indices].
/// The returned operation will be used to replace loadStoreOp.
template <typename LoadStoreLikeOp,
          FailureOr<Value> (*getFailureOrSrcMemRef)(LoadStoreLikeOp),
          LoadStoreLikeOp (*rebuildOpFromAddressAndIndices)(
              RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/,
              Value /*srcMemRef*/, ArrayRef<Value> /*indices*/),
          SmallVector<OpFoldResult> (*getViewSizeForEachDim)(
              RewriterBase & /*rewriter*/, LoadStoreLikeOp /*loadStoreOp*/) =
              getGenericOpViewSizeForEachDim<
                  LoadStoreLikeOp,
                  getSrcMemRef<LoadStoreLikeOp, getFailureOrSrcMemRef>>>
struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
  // Inherit the constructors of OpRewritePattern<LoadStoreLikeOp>.
  using OpRewritePattern<LoadStoreLikeOp>::OpRewritePattern;

  /// Pattern entry point (overrides OpRewritePattern::matchAndRewrite).
  /// Receives the matched \p loadStoreLikeOp and the \p rewriter through
  /// which the rewrite described on the struct is to be performed.
  LogicalResult matchAndRewrite(LoadStoreLikeOp loadStoreLikeOp,
                                PatternRewriter &rewriter) const override {}
};
} // namespace

/// Out-of-line definition of the public entry point of this file.
/// NOTE(review): presumably adds to \p patterns the LoadStoreLikeOpRewriter
/// instantiations built from the per-op helpers defined above -- confirm
/// which ops are registered.
void memref::populateExtractAddressComputationsPatterns(
    RewritePatternSet &patterns) {}