// llvm/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp

//===- ExpandRealloc.cpp - Expand memref.realloc ops into its components --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace memref {
// GEN_PASS_DEF_EXPANDREALLOC is defined immediately before Passes.h.inc is
// included so that the include expands the ExpandRealloc section inside
// `mlir::memref` (presumably the `impl::ExpandReallocBase` template that
// ExpandReallocPass derives from below -- confirm against Passes.td).
#define GEN_PASS_DEF_EXPANDREALLOC
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

// Lets the rest of this file refer to `mlir::` names (OpRewritePattern,
// RewritePatternSet, Pass, ...) without qualification.
using namespace mlir;

namespace {

/// The `realloc` operation performs a conditional allocation and copy to
/// increase the size of a buffer if necessary. This pattern converts the
/// `realloc` operation into the sequence of simpler operations shown below.
///
/// Example of an expansion:
/// ```mlir
/// %realloc = memref.realloc %alloc (%size) : memref<?xf32> to memref<?xf32>
/// ```
/// is expanded to
/// ```mlir
/// %c0 = arith.constant 0 : index
/// %dim = memref.dim %alloc, %c0 : memref<?xf32>
/// %is_old_smaller = arith.cmpi ult, %dim, %size
/// %realloc = scf.if %is_old_smaller -> (memref<?xf32>) {
///   %new_alloc = memref.alloc(%size) : memref<?xf32>
///   %subview = memref.subview %new_alloc[0] [%dim] [1]
///   memref.copy %alloc, %subview
///   memref.dealloc %alloc
///   scf.yield %new_alloc : memref<?xf32>
/// } else {
///   %reinterpret_cast = memref.reinterpret_cast %alloc to
///     offset: [0], sizes: [%size], strides: [1]
///   scf.yield %reinterpret_cast : memref<?xf32>
/// }
/// ```
///
/// NOTE(review): as written here the struct declares no members, so the
/// rewrite described above is not implemented by this definition.
struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {};

/// Pass type for the `memref.realloc` expansion. It derives from
/// `memref::impl::ExpandReallocBase` (presumably the base class brought in by
/// the `GEN_PASS_DEF_EXPANDREALLOC` include at the top of this file -- confirm
/// against Passes.td).
///
/// NOTE(review): as written here the struct adds no members of its own, so it
/// overrides nothing from its base.
struct ExpandReallocPass
    : public memref::impl::ExpandReallocBase<ExpandReallocPass> {};

} // namespace

/// Going by its name, this is the hook through which callers add the
/// realloc-expansion pattern(s) to `patterns`.
///
/// \param patterns     Pattern set handed in by the caller.
/// \param emitDeallocs Presumably selects whether the expansion emits a
///                     `memref.dealloc` for the old buffer -- TODO confirm
///                     against the declaration in Transforms.h.
///
/// NOTE(review): the body is empty, so neither parameter is read and
/// `patterns` is left exactly as it was passed in.
void mlir::memref::populateExpandReallocPatterns(RewritePatternSet &patterns,
                                                 bool emitDeallocs) {}

/// Going by its name and return type, this is the factory that hands out the
/// realloc-expansion pass as a generic `Pass`.
///
/// \param emitDeallocs Presumably forwarded to the pass to control whether the
///                     old buffer is deallocated -- TODO confirm.
///
/// NOTE(review): the body is empty. `emitDeallocs` is never read, and control
/// reaches the end of a function that returns `std::unique_ptr<Pass>` without
/// a `return` statement, which is undefined behavior if this is ever called.
std::unique_ptr<Pass> mlir::memref::createExpandReallocPass(bool emitDeallocs) {}