// llvm/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp

//===- BufferDeallocationSimplification.cpp -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for optimizing `bufferization.dealloc` operations
// that requires more analysis than what can be supported by regular
// canonicalization patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATION
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

usingnamespacemlir;
usingnamespacemlir::bufferization;

//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//

/// Given a memref value, return the "base" value by skipping over all
/// ViewLikeOpInterface ops (if any) in the reverse use-def chain. If `value`
/// is not produced by such an op, there is nothing to skip over.
///
/// NOTE(review): the body is empty in this view, so this non-void function
/// currently falls off the end without returning a `Value`. Confirm that the
/// implementation was not lost.
static Value getViewBase(Value value) {}

/// Helper that takes a dealloc operation, a candidate list of `memrefs`, a
/// candidate list of `conditions`, and the pattern rewriter, and reports its
/// outcome as a `LogicalResult`.
///
/// NOTE(review): judging by the name, this presumably replaces the memref and
/// condition operands of `deallocOp` only when the given lists differ from the
/// current ones, and signals failure when nothing changed -- confirm. The body
/// is empty in this view, so no value is actually returned.
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
                                            ValueRange memrefs,
                                            ValueRange conditions,
                                            PatternRewriter &rewriter) {}

/// Return "true" if the given values are guaranteed to be different (and
/// non-aliasing) allocations based on the fact that one value is the result
/// of an allocation and the other value is a block argument of a parent block.
///
/// Note: This is a best-effort analysis that will eventually be replaced by a
/// proper "is same allocation" analysis. A "false" result is therefore not a
/// proof of aliasing: this function may return "false" even though the two
/// values are distinct allocations.
///
/// NOTE(review): the body is empty in this view, so this non-void function
/// currently returns no value. Confirm that the implementation was not lost.
static bool distinctAllocAndBlockArgument(Value v1, Value v2) {}

/// Checks if `memref` may potentially alias a MemRef in `otherList`. It is
/// often a requirement of optimization patterns that there cannot be any
/// aliasing memref in order to perform the desired simplification.
///
/// `analysis` is the buffer origin analysis handed in by the caller;
/// `otherList` is the list of values to compare `memref` against.
///
/// NOTE(review): the body is empty in this view, so this non-void function
/// currently returns no value. Confirm that the implementation was not lost.
static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis,
                                     ValueRange otherList, Value memref) {}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {

/// Remove values from the `memref` operand list that are also present in the
/// `retained` list (or a guaranteed alias of it) because they will never
/// actually be deallocated. However, we also need to be certain about which
/// other memrefs in the `retained` list can alias, i.e., there must not be any
/// may-aliasing memref. This is necessary because the `dealloc` operation is
/// defined to return one `i1` value per memref in the `retained` list which
/// represents the disjunction of the condition values corresponding to all
/// aliasing values in the `memref` list. In particular, this means that if
/// there is some value R in the `retained` list which aliases with a value M in
/// the `memref` list (but can only be statically determined to may-alias) and M
/// is also present in the `retained` list, then it would be illegal to remove M
/// because the result corresponding to R would be computed incorrectly
/// afterwards. Because we require an alias analysis, this pattern cannot be
/// applied as a regular canonicalization pattern.
///
/// Example:
/// ```mlir
/// %0:3 = bufferization.dealloc (%m0 : ...) if (%cond0)
///                     retain (%m0, %r0, %r1 : ...)
/// ```
/// is canonicalized to
/// ```mlir
/// // bufferization.dealloc without memrefs and conditions returns %false for
/// // every retained value
/// %0:3 = bufferization.dealloc retain (%m0, %r0, %r1 : ...)
/// %1 = arith.ori %0#0, %cond0 : i1
/// // replace %0#0 with %1
/// ```
/// given that `%r0` and `%r1` are guaranteed to not alias with `%m0`.
///
/// NOTE(review): the struct body is empty in this view (no constructor and no
/// rewrite logic are visible). Confirm that the implementation was not lost.
struct RemoveDeallocMemrefsContainedInRetained
    : public OpRewritePattern<DeallocOp> {};

/// Remove memrefs from the `retained` list which are guaranteed to not alias
/// any memref in the `memrefs` list. The corresponding result value can be
/// replaced with `false` in that case according to the operation description.
///
/// Example:
/// ```mlir
/// %0:2 = bufferization.dealloc (%m : memref<2xi32>) if (%cond)
///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
/// return %0#0, %0#1
/// ```
/// can be canonicalized to the following, given that `%r0` and `%r1` do not
/// alias `%m`:
/// ```mlir
/// bufferization.dealloc (%m : memref<2xi32>) if (%cond)
/// return %false, %false
/// ```
///
/// NOTE(review): the struct body is empty in this view (no constructor and no
/// rewrite logic are visible). Confirm that the implementation was not lost.
struct RemoveRetainedMemrefsGuaranteedToNotAlias
    : public OpRewritePattern<DeallocOp> {};

/// Split off memrefs to separate dealloc operations to reduce the number of
/// runtime checks required and enable further canonicalization of the new and
/// simpler dealloc operations. A memref can be split off if it is guaranteed to
/// not alias with any other memref in the `memref` operand list. The results
/// of the old and the new dealloc operation have to be combined by computing
/// their element-wise disjunction.
///
/// Example:
/// ```mlir
/// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xi32>, memref<2xi32>)
///                           if (%cond0, %cond1)
///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
/// return %0#0, %0#1
/// ```
/// Given that `%m0` is guaranteed to never alias with `%m1`, the above IR is
/// canonicalized to the following, thus reducing the number of runtime alias
/// checks by 1 and potentially enabling further canonicalization of the new
/// split-up dealloc operations.
/// ```mlir
/// %0:2 = bufferization.dealloc (%m0 : memref<2xi32>) if (%cond0)
///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
/// %1:2 = bufferization.dealloc (%m1 : memref<2xi32>) if (%cond1)
///                       retain (%r0, %r1 : memref<2xi32>, memref<2xi32>)
/// %2 = arith.ori %0#0, %1#0
/// %3 = arith.ori %0#1, %1#1
/// return %2, %3
/// ```
///
/// NOTE(review): the struct body is empty in this view (no constructor and no
/// rewrite logic are visible). Confirm that the implementation was not lost.
struct SplitDeallocWhenNotAliasingAnyOther
    : public OpRewritePattern<DeallocOp> {};

/// Check for every retained memref if a must-aliasing memref exists in the
/// 'memref' operand list with constant 'true' condition. If so, we can replace
/// the operation result corresponding to that retained memref with 'true'. If
/// this condition holds for all retained memrefs, we can also remove the
/// aliasing memrefs and their conditions: they will never be deallocated
/// because of the must-alias, and they are no longer needed to compute the
/// result value since it got replaced with 'true'.
///
/// Example:
/// ```mlir
/// %0:2 = bufferization.dealloc (%arg0, %arg1, %arg2 : ...)
///                           if (%true, %true, %true)
///                       retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
/// ```
/// becomes
/// ```mlir
/// %0:2 = bufferization.dealloc (%arg2 : memref<2xi32>) if (%true)
///                       retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
/// // replace %0#0 with %true
/// // replace %0#1 with %true
/// ```
/// Note that the dealloc operation will still have the result values, but they
/// don't have uses anymore.
///
/// NOTE(review): the struct body is empty in this view (no constructor and no
/// rewrite logic are visible). Confirm that the implementation was not lost.
struct RetainedMemrefAliasingAlwaysDeallocatedMemref
    : public OpRewritePattern<DeallocOp> {};

} // namespace

//===----------------------------------------------------------------------===//
// BufferDeallocationSimplificationPass
//===----------------------------------------------------------------------===//

namespace {

/// Pass definition for the buffer deallocation simplification. Per the header
/// of this file, the purpose of this file is to optimize
/// `bufferization.dealloc` operations in ways that need more analysis than
/// regular canonicalization patterns can provide.
///
/// NOTE(review): the previous description here (inserting and moving dealloc
/// nodes, inserting clones) did not match that stated purpose and looks like it
/// was copied from the buffer deallocation pass. The struct body is empty in
/// this view, so what the pass actually runs cannot be confirmed from here --
/// presumably it applies the dealloc simplification patterns; verify.
struct BufferDeallocationSimplificationPass
    : public bufferization::impl::BufferDeallocationSimplificationBase<
          BufferDeallocationSimplificationPass> {};

} // namespace

/// Factory function for the buffer deallocation simplification pass; declared
/// to hand back ownership of a `Pass` through `std::unique_ptr`.
///
/// NOTE(review): presumably this returns a newly created
/// `BufferDeallocationSimplificationPass` -- confirm. The body is empty in this
/// view, so no value is actually returned.
std::unique_ptr<Pass>
mlir::bufferization::createBufferDeallocationSimplificationPass() {}