// llvm/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp

//===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass resolves `memref.dim` operations of result values in terms of
// shapes of their operands using the `InferShapedTypeOpInterface`.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace memref {
// The GEN_PASS_DEF_* macros are defined immediately before including
// Passes.h.inc, inside mlir::memref. The pass structs later in this file
// derive from memref::impl::ResolveRankedShapeTypeResultDimsBase and
// memref::impl::ResolveShapedTypeResultDimsBase, which are presumably the
// definitions these macros select from the generated include.
#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMS
#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

usingnamespacemlir;

namespace {
/// Rewrite pattern, templated on the op type `OpTy`, described as folding the
/// dim of an operation that implements the InferShapedTypeOpInterface.
///
/// NOTE(review): the struct body visible here is empty — it declares no
/// `matchAndRewrite` and does not inherit the base-class constructors — so no
/// folding logic is defined at this point. Confirm whether the implementation
/// was intentionally removed.
template <typename OpTy>
struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {};

/// Rewrite pattern, templated on the op type `OpTy`, described as folding the
/// dim of an operation that implements a shape-inference interface.
///
/// NOTE(review): the original comment named the InferShapedTypeOpInterface,
/// which is word-for-word the comment on DimOfShapedTypeOpInterface; judging by
/// this struct's name it presumably targets the
/// ReifyRankedShapedTypeOpInterface instead — confirm.
///
/// NOTE(review): the struct body visible here is empty — it declares no
/// `matchAndRewrite` and does not inherit the base-class constructors — so no
/// folding logic is defined at this point.
template <typename OpTy>
struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {};

/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   ...
/// }
/// ```
///
/// is folded to:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// scf.forall ... shared_outs(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
///   ...
/// }
/// ```
///
/// The pattern is rooted on `tensor::DimOp` (see the base class below).
///
/// NOTE(review): the struct body visible here is empty — it declares no
/// `matchAndRewrite` and does not inherit the base-class constructors — so the
/// rewrite described above is not implemented at this point. Confirm whether
/// the implementation was intentionally removed.
struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {};
} // namespace

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {
/// Pass type built on the generated
/// `memref::impl::ResolveRankedShapeTypeResultDimsBase` CRTP base class.
struct ResolveRankedShapeTypeResultDimsPass final
    : public memref::impl::ResolveRankedShapeTypeResultDimsBase<
          ResolveRankedShapeTypeResultDimsPass> {
  /// Pass entry point. Declared here because it is defined out of line later
  /// in this file; an out-of-line member definition requires a matching
  /// in-class declaration.
  void runOnOperation() override;
};

/// Pass type built on the generated
/// `memref::impl::ResolveShapedTypeResultDimsBase` CRTP base class.
struct ResolveShapedTypeResultDimsPass final
    : public memref::impl::ResolveShapedTypeResultDimsBase<
          ResolveShapedTypeResultDimsPass> {
  /// Pass entry point. Declared here because it is defined out of line later
  /// in this file; an out-of-line member definition requires a matching
  /// in-class declaration.
  void runOnOperation() override;
};

} // namespace

/// Takes the pattern set `patterns` by reference.
///
/// NOTE(review): the body is empty in this view, so nothing is added to
/// `patterns`. Judging by the function name it is presumably meant to register
/// the ranked-shape dim-resolution patterns defined earlier in this file —
/// confirm whether the registration was intentionally removed.
void memref::populateResolveRankedShapedTypeResultDimsPatterns(
    RewritePatternSet &patterns) {}

/// Takes the pattern set `patterns` by reference.
///
/// NOTE(review): the body is empty in this view, so nothing is added to
/// `patterns`. Judging by the function name it is presumably meant to register
/// the shaped-type dim-resolution patterns defined earlier in this file —
/// confirm whether the registration was intentionally removed.
void memref::populateResolveShapedTypeResultDimsPatterns(
    RewritePatternSet &patterns) {}

/// Out-of-line definition of the pass entry point.
/// NOTE(review): the body is empty, so as written this pass performs no
/// rewrite when run — confirm whether that is intended.
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {}

/// Out-of-line definition of the pass entry point.
/// NOTE(review): the body is empty, so as written this pass performs no
/// rewrite when run — confirm whether that is intended.
void ResolveShapedTypeResultDimsPass::runOnOperation() {}

std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() {}

std::unique_ptr<Pass> memref::createResolveRankedShapeTypeResultDimsPass() {}