llvm/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h

//===- Transforms.h - MemRef Dialect transformations ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This header declares functions that assist transformations in the MemRef
/// dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H

#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLFunctionalExtras.h"

namespace mlir {
// Forward declarations of the core MLIR classes referenced by the function
// signatures in this header.
class OpBuilder;
class RewritePatternSet;
class RewriterBase;
class Value;
class ValueRange;

namespace arith {
// Type converters taken by reference by the emulation `populate*` functions
// declared below in this header.
class WideIntEmulationConverter;
class NarrowTypeEmulationConverter;
} // namespace arith

namespace memref {
// Forward declarations of the memref ops used in the signatures below.
class AllocOp;
class AllocaOp;
class DeallocOp;

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

/// Appends to `patterns` a set of patterns that rewrite ops within the memref
/// dialect.
void populateExpandOpsPatterns(RewritePatternSet &patterns);

/// Appends to `patterns` the patterns that fold memref aliasing ops into
/// their consumer load/store ops.
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);

/// Appends patterns that resolve `memref.dim` operations with values that are
/// defined by operations that implement the
/// `ReifyRankedShapedTypeOpInterface`, in terms of the shapes of the defining
/// operation's input operands.
void populateResolveRankedShapedTypeResultDimsPatterns(
    RewritePatternSet &patterns);

/// Appends patterns that resolve `memref.dim` operations with values that are
/// defined by operations that implement the `InferShapedTypeOpInterface`, in
/// terms of the shapes of the defining operation's input operands.
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);

/// Appends patterns for expanding memref operations that modify the metadata
/// (sizes, offset, strides) of a memref into easier-to-analyze constructs.
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);

/// Appends patterns for resolving a `memref.extract_strided_metadata` op into
/// a `memref.extract_strided_metadata` of its source.
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns);

/// Appends patterns for expanding `memref.realloc` operations.
/// NOTE(review): `emitDeallocs` (defaults to true) presumably controls whether
/// the expansion also emits deallocations -- confirm against the
/// implementation.
void populateExpandReallocPatterns(RewritePatternSet &patterns,
                                   bool emitDeallocs = true);

/// Appends to `patterns` the patterns for emulating wide integer memref
/// operations with ops over narrower integer types, using `typeConverter`.
void populateMemRefWideIntEmulationPatterns(
    arith::WideIntEmulationConverter &typeConverter,
    RewritePatternSet &patterns);

/// Appends type conversions for emulating wide integer memref operations with
/// ops over narrower integer types.
void populateMemRefWideIntEmulationConversions(
    arith::WideIntEmulationConverter &typeConverter);

/// Appends to `patterns` the patterns for emulating memref operations over
/// narrow types with ops over wider types, using `typeConverter`.
void populateMemRefNarrowTypeEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns);

/// Appends to `typeConverter` the type conversions for emulating memref
/// operations over narrow types with ops over wider types.
void populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter);

/// Transformation to do multi-buffering/array expansion to remove dependencies
/// on the temporary allocation between consecutive loop iterations.
/// It returns the new allocation if the original allocation was multi-buffered
/// and returns failure() otherwise.
/// When `skipOverrideAnalysis` is set, the transformation is applied without
/// checking that the buffer is overridden at the beginning of each iteration.
/// This implies that the user knows that there is no data carried across loop
/// iterations.
/// NOTE(review): `multiplier` is presumably the number of copies of the buffer
/// (5 in the example below) -- confirm against the implementation.
///
/// Example:
/// ```
/// %0 = memref.alloc() : memref<4x128xf32>
/// scf.for %iv = %c1 to %c1024 step %c3 {
///   memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32>
///   "some_use"(%0) : (memref<4x128xf32>) -> ()
/// }
/// ```
/// into:
/// ```
/// %0 = memref.alloc() : memref<5x4x128xf32>
/// scf.for %iv = %c1 to %c1024 step %c3 {
///   %s = arith.subi %iv, %c1 : index
///   %d = arith.divsi %s, %c3 : index
///   %i = arith.remsi %d, %c5 : index
///   %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] :
///     memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
///   memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, strided<...>>
///   "some_use"(%sv) : (memref<4x128xf32, strided<...>>) -> ()
/// }
/// ```
FailureOr<memref::AllocOp> multiBuffer(RewriterBase &rewriter,
                                       memref::AllocOp allocOp,
                                       unsigned multiplier,
                                       bool skipOverrideAnalysis = false);
/// Call into `multiBuffer` with a locally constructed IRRewriter.
FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
                                       unsigned multiplier,
                                       bool skipOverrideAnalysis = false);

/// Appends patterns for extracting address computations from the instructions
/// with memory accesses such that these memory accesses use only a base
/// pointer.
///
/// For instance,
/// ```mlir
/// memref.load %base[%off0, ...]
/// ```
///
/// will be rewritten into:
/// ```mlir
/// %new_base = memref.subview %base[%off0,...][1,...][1,...]
/// memref.load %new_base[%c0,...]
/// ```
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);

/// Build a new memref::AllocaOp whose dynamic sizes are independent of all
/// given independencies. If the op is already independent of all
/// independencies, the same AllocaOp result is returned.
///
/// Failure indicates that no suitable upper bound for the dynamic sizes could
/// be found.
FailureOr<Value> buildIndependentOp(OpBuilder &b, AllocaOp allocaOp,
                                    ValueRange independencies);

/// Build a new memref::AllocaOp whose dynamic sizes are independent of all
/// given independencies. If the op is already independent of all
/// independencies, the same AllocaOp result is returned.
///
/// The original AllocaOp is replaced with the new one, wrapped in a SubviewOp.
/// The result type of the replacement is different from the original allocation
/// type: it has the same shape, but a different layout map. This function
/// updates all users that do not have a memref result or memref region block
/// argument, and some frequently used memref dialect ops (such as
/// memref.subview). It does not update other uses such as the init_arg of an
/// scf.for op. Such uses are wrapped in unrealized_conversion_cast.
///
/// Failure indicates that no suitable upper bound for the dynamic sizes could
/// be found.
///
/// Example (make independent of %iv):
/// ```
/// scf.for %iv = %c0 to %sz step %c1 {
///   %0 = memref.alloca(%iv) : memref<?xf32>
///   %1 = memref.subview %0[0][5][1] : ...
///   linalg.generic outs(%1 : ...) ...
///   %2 = scf.for ... iter_args(%arg0 = %0) ...
///   ...
/// }
/// ```
///
/// The above IR is rewritten to:
///
/// ```
/// scf.for %iv = %c0 to %sz step %c1 {
///   %0 = memref.alloca(%sz - 1) : memref<?xf32>
///   %0_subview = memref.subview %0[0][%iv][1]
///       : memref<?xf32> to memref<?xf32, #map>
///   %1 = memref.subview %0_subview[0][5][1] : ...
///   linalg.generic outs(%1 : ...) ...
///   %cast = unrealized_conversion_cast %0_subview
///       : memref<?xf32, #map> to memref<?xf32>
///   %2 = scf.for ... iter_args(%arg0 = %cast) ...
///   ...
/// }
/// ```
FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
                                          memref::AllocaOp allocaOp,
                                          ValueRange independencies);

/// Replaces the given `alloc` with the corresponding `alloca` and returns it if
/// the following conditions are met:
///   - the corresponding dealloc is available in the same block as the alloc;
///   - the filter, if provided, succeeds on the alloc/dealloc pair.
/// Otherwise returns nullptr and leaves the IR unchanged.
/// `filter` defaults to nullptr, i.e. no filter is provided.
memref::AllocaOp allocToAlloca(
    RewriterBase &rewriter, memref::AllocOp alloc,
    function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);

} // namespace memref
} // namespace mlir

#endif