// llvm/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp

//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

usingnamespacemlir;
usingnamespacemlir::bufferization;

/// Returns true if the type of `v` is a memref type (ranked or unranked),
/// i.e. derives from `BaseMemRefType`; returns false for every other value.
static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }

namespace {
/// While CondBranchOp also implements the BranchOpInterface, we add a
/// special-case implementation here because the BranchOpInterface does not
/// offer all of the functionality we need to insert dealloc operations in an
/// efficient way. More precisely, there is no way to extract the branch
/// condition without casting to CondBranchOp specifically. It is still
/// possible to implement deallocation for cases where we don't know to which
/// successor the terminator branches before the actual branch happens by
/// inserting auxiliary blocks and putting the dealloc op there, however, this
/// can lead to less efficient code.
/// This function inserts two dealloc operations (one for each successor) and
/// adjusts the dealloc conditions according to the branch condition, then the
/// ownerships of the retained MemRefs are updated by combining the result
/// values of the two dealloc operations.
///
/// NOTE(review): the struct body as visible here is empty, so the "function"
/// referred to above is not shown; presumably it is the interface method that
/// processes the `cf.cond_br` op -- confirm against the full implementation.
///
/// Example:
/// ```
/// ^bb1:
///   <more ops...>
///   cf.cond_br cond, ^bb2(<forward-to-bb2>), ^bb3(<forward-to-bb3>)
/// ```
/// becomes
/// ```
/// // let (m, c) = getMemrefsAndConditionsToDeallocate(bb1)
/// // let r0 = getMemrefsToRetain(bb1, bb2, <forward-to-bb2>)
/// // let r1 = getMemrefsToRetain(bb1, bb3, <forward-to-bb3>)
/// ^bb1:
///   <more ops...>
///   let thenCond = map(c, (c) -> arith.andi cond, c)
///   let elseCond = map(c, (c) -> arith.andi (arith.xori cond, true), c)
///   o0 = bufferization.dealloc m if thenCond retain r0
///   o1 = bufferization.dealloc m if elseCond retain r1
///   // replace ownership(r0) with o0 element-wise
///   // replace ownership(r1) with o1 element-wise
///   // let ownership0 := (r) -> o in o0 corresponding to r
///   // let ownership1 := (r) -> o in o1 corresponding to r
///   // let cmn := intersection(r0, r1)
///   foreach (a, b) in zip(map(cmn, ownership0), map(cmn, ownership1)):
///     forall r in r0: replace ownership0(r) with arith.select cond, a, b
///     forall r in r1: replace ownership1(r) with arith.select cond, a, b
///   cf.cond_br cond, ^bb2(<forward-to-bb2>, o0), ^bb3(<forward-to-bb3>, o1)
/// ```
struct CondBranchOpInterface
    : public BufferDeallocationOpInterface::ExternalModel<CondBranchOpInterface,
                                                          cf::CondBranchOp> {};

} // namespace

/// Presumably the entry point through which clients register the ControlFlow
/// dialect's BufferDeallocationOpInterface external models with `registry`.
///
/// NOTE(review): the body is empty as shown, so calling this function does
/// nothing and `CondBranchOpInterface` is never attached to
/// `cf::CondBranchOp` -- confirm whether the registration logic is missing.
void mlir::cf::registerBufferDeallocationOpInterfaceExternalModels(
    DialectRegistry &registry) {}