llvm/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp

//===- InlineElementals.cpp - Inline chained hlfir.elemental ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Chained elemental operations like a + b + c can inline the first elemental
// at the hlfir.apply in the body of the second one (as described in
// docs/HighLevelFIR.md). This has to be done in a pass rather than in lowering
// so that it happens after the HLFIR intrinsic simplification pass.
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <iterator>

namespace hlfir {
#define GEN_PASS_DEF_INLINEELEMENTALS
#include "flang/Optimizer/HLFIR/Passes.h.inc"
} // namespace hlfir

/// Look for the (hlfir.apply, hlfir.destroy) pair that forms the complete set
/// of users of \p elemental.
///
/// The pair is returned only when all of the following hold:
///  - \p elemental has exactly two users,
///  - hlfir::elementalOpMustProduceTemp(elemental) is false,
///  - one user is an hlfir.apply and the other is an hlfir.destroy,
///  - the result type of the hlfir.apply is the same as the type of the
///    element value yielded by the elemental body.
/// In every other case std::nullopt is returned.
static std::optional<std::pair<hlfir::ApplyOp, hlfir::DestroyOp>>
getTwoUses(hlfir::ElementalOp elemental) {
  // Only a single "real" use (the hlfir.apply) plus the hlfir.destroy is
  // accepted; any other number of users disqualifies the elemental.
  mlir::Operation::user_range elementalUsers = elemental->getUsers();
  const auto userCount =
      std::distance(elementalUsers.begin(), elementalUsers.end());
  if (userCount != 2)
    return std::nullopt;

  // An elemental that is required to materialize a temporary (e.g. for
  // finalization purposes) cannot be inlined.
  if (hlfir::elementalOpMustProduceTemp(elemental))
    return std::nullopt;

  // Classify the two users. Users that are neither an apply nor a destroy
  // leave the corresponding handle null and are rejected just below.
  hlfir::ApplyOp applyUser;
  hlfir::DestroyOp destroyUser;
  for (mlir::Operation *user : elementalUsers) {
    if (auto asApply = mlir::dyn_cast<hlfir::ApplyOp>(user))
      applyUser = asApply;
    else if (auto asDestroy = mlir::dyn_cast<hlfir::DestroyOp>(user))
      destroyUser = asDestroy;
  }
  if (!(applyUser && destroyUser))
    return std::nullopt;

  // The value produced by the hlfir.apply is going to be replaced by the
  // yielded element value, so both must have the same type.
  mlir::Operation &lastBodyOp = elemental.getRegion().back().back();
  auto yield = mlir::dyn_cast_or_null<hlfir::YieldElementOp>(lastBodyOp);
  assert(yield && "hlfir.elemental should always end with a yield");
  mlir::Type appliedType = applyUser.getResult().getType();
  mlir::Type yieldedType = yield.getElementValue().getType();
  if (appliedType != yieldedType)
    return std::nullopt;

  return std::pair{applyUser, destroyUser};
}

namespace {
/// Rewrite pattern that inlines an hlfir.elemental at the position of its
/// only hlfir.apply user.
///
/// The pattern applies when getTwoUses() finds an (hlfir.apply,
/// hlfir.destroy) pair for the elemental and the elemental is not ordered.
/// On success the elemental is inlined right after the hlfir.apply using the
/// indices of the apply, the result of the apply is replaced by the inlined
/// element value, and the apply, the destroy, the original elemental and the
/// hlfir.yield_element returned by the inlining helper are all erased.
class InlineElementalConversion
    : public mlir::OpRewritePattern<hlfir::ElementalOp> {
public:
  using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;

  /// Try to inline \p elemental; returns failure (with a diagnostic note for
  /// the rewriter) when the elemental does not qualify.
  llvm::LogicalResult
  matchAndRewrite(hlfir::ElementalOp elemental,
                  mlir::PatternRewriter &rewriter) const override {
    auto inlineSite = getTwoUses(elemental);
    if (!inlineSite.has_value())
      return rewriter.notifyMatchFailure(
          elemental, "hlfir.elemental does not have two uses");

    // An ordered elemental may only be inlined into a loop-like construct
    // that visits the indices in order and has no side effects of its own.
    // Stay conservative for now and never inline ordered elementals.
    if (elemental.isOrdered())
      return rewriter.notifyMatchFailure(elemental,
                                         "hlfir.elemental is ordered");

    hlfir::ApplyOp applyOp = inlineSite->first;
    hlfir::DestroyOp destroyOp = inlineSite->second;

    assert(elemental.getRegion().hasOneBlock() &&
           "expect elemental region to have one block");

    // Inline the elemental immediately after the hlfir.apply, using the
    // indices the apply was invoked with.
    fir::FirOpBuilder firBuilder{rewriter, elemental.getOperation()};
    firBuilder.setInsertionPointAfter(applyOp);
    hlfir::YieldElementOp inlinedYield = hlfir::inlineElementalOp(
        elemental.getLoc(), firBuilder, elemental, applyOp.getIndices());

    // Forward the inlined element value to the users of the apply, then drop
    // everything that is no longer needed.
    rewriter.replaceAllUsesWith(applyOp.getResult(),
                                inlinedYield.getElementValue());
    rewriter.eraseOp(inlinedYield);
    rewriter.eraseOp(applyOp);
    rewriter.eraseOp(destroyOp);
    rewriter.eraseOp(elemental);

    return mlir::success();
  }
};

/// Pass driver: greedily applies InlineElementalConversion to the operation
/// the pass runs on, and signals pass failure if the greedy driver reports
/// failure.
class InlineElementalsPass
    : public hlfir::impl::InlineElementalsBase<InlineElementalsPass> {
public:
  void runOnOperation() override {
    mlir::MLIRContext *ctx = &getContext();

    mlir::RewritePatternSet inlinePatterns(ctx);
    inlinePatterns.insert<InlineElementalConversion>(ctx);

    // Region simplification is switched off so that the pattern driver does
    // not merge blocks.
    mlir::GreedyRewriteConfig driverConfig;
    driverConfig.enableRegionSimplification =
        mlir::GreedySimplifyRegionLevel::Disabled;

    const bool converged = mlir::succeeded(mlir::applyPatternsAndFoldGreedily(
        getOperation(), std::move(inlinePatterns), driverConfig));
    if (converged)
      return;

    mlir::emitError(getOperation()->getLoc(),
                    "failure in HLFIR elemental inlining");
    signalPassFailure();
  }
};
} // namespace