llvm/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp

//===- ResolveShapedTypeResultDims.cpp - Resolve dim ops of result values -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass resolves `memref.dim` operations of result values in terms of
// shapes of their operands using the `InferShapedTypeOpInterface`.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_RESOLVERANKEDSHAPETYPERESULTDIMS
#define GEN_PASS_DEF_RESOLVESHAPEDTYPERESULTDIMS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

usingnamespacemlir;

namespace {
/// Fold `dim(result, index)` of an operation that implements the
/// `InferShapedTypeOpInterface`. The dim is replaced by a `tensor.extract`
/// from the shape tensor that the interface reifies for that result.
///
/// The pattern fails to match unless all of the following hold:
///   * the `dim` source is an op result (not a block argument);
///   * the dimension index is a constant;
///   * `reifyReturnTypeShapes` succeeds and yields one shape per result;
///   * the reified shape is a ranked tensor of `index` elements.
template <typename OpTy>
struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const override {
    OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
    if (!dimValue)
      return failure();
    auto shapedTypeOp =
        dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
    if (!shapedTypeOp)
      return failure();

    std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
    if (!dimIndex)
      return failure();

    SmallVector<Value> reifiedResultShapes;
    if (failed(shapedTypeOp.reifyReturnTypeShapes(
            rewriter, shapedTypeOp->getOperands(), reifiedResultShapes)))
      return failure();

    // Guard against implementations that do not report one shape per result;
    // indexing by result number below would otherwise be out of range.
    if (reifiedResultShapes.size() != shapedTypeOp->getNumResults())
      return failure();

    Value resultShape = reifiedResultShapes[dimValue.getResultNumber()];
    auto resultShapeType = dyn_cast<RankedTensorType>(resultShape.getType());
    if (!resultShapeType || !isa<IndexType>(resultShapeType.getElementType()))
      return failure();

    Location loc = dimOp->getLoc();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        dimOp, resultShape,
        rewriter.create<arith::ConstantIndexOp>(loc, *dimIndex).getResult());
    return success();
  }
};

/// Fold `dim(result, index)` of an operation that implements the
/// `ReifyRankedShapedTypeOpInterface`. The dim is replaced by the reified size
/// of that dimension, either as an existing SSA value or as a newly created
/// index constant.
template <typename OpTy>
struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  /// Reifying a result shape may create new `dim` ops that this pattern can
  /// match again, so declare the rewrite recursion as bounded.
  void initialize() { OpRewritePattern<OpTy>::setHasBoundedRewriteRecursion(); }

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const override {
    OpResult dimValue = dyn_cast<OpResult>(dimOp.getSource());
    if (!dimValue)
      return failure();
    std::optional<int64_t> dimIndex = dimOp.getConstantIndex();
    if (!dimIndex)
      return failure();

    ReifiedRankedShapedTypeDims reifiedResultShapes;
    if (failed(reifyResultShapes(rewriter, dimValue.getOwner(),
                                 reifiedResultShapes)))
      return failure();
    unsigned resultNumber = dimValue.getResultNumber();
    if (resultNumber >= reifiedResultShapes.size())
      return failure();
    // Do not apply the pattern to invalid IR (index outside the result rank).
    if (*dimIndex < 0 || static_cast<size_t>(*dimIndex) >=
                             reifiedResultShapes[resultNumber].size())
      return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
    Value replacement = getValueOrCreateConstantIndexOp(
        rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
    rewriter.replaceOp(dimOp, replacement);
    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {
/// Pass wrapper; `runOnOperation` is defined out of line below.
struct ResolveRankedShapeTypeResultDimsPass final
    : public memref::impl::ResolveRankedShapeTypeResultDimsBase<
          ResolveRankedShapeTypeResultDimsPass> {
  void runOnOperation() override;
};

/// Pass wrapper; `runOnOperation` is defined out of line below.
struct ResolveShapedTypeResultDimsPass final
    : public memref::impl::ResolveShapedTypeResultDimsBase<
          ResolveShapedTypeResultDimsPass> {
  void runOnOperation() override;
};

} // namespace

/// Registers the `ReifyRankedShapedTypeOpInterface`-based dim folding
/// patterns for both `memref.dim` and `tensor.dim`.
void memref::populateResolveRankedShapedTypeResultDimsPatterns(
    RewritePatternSet &patterns) {
  patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>,
               DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>(
      patterns.getContext());
}

/// Registers the `InferShapedTypeOpInterface`-based dim folding patterns for
/// both `memref.dim` and `tensor.dim`.
void memref::populateResolveShapedTypeResultDimsPatterns(
    RewritePatternSet &patterns) {
  // TODO: Move the tensor::DimOp pattern to the Tensor dialect.
  patterns.add<DimOfShapedTypeOpInterface<memref::DimOp>,
               DimOfShapedTypeOpInterface<tensor::DimOp>>(
      patterns.getContext());
}

void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {}

void ResolveShapedTypeResultDimsPass::runOnOperation() {}

std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() {}

std::unique_ptr<Pass> memref::createResolveRankedShapeTypeResultDimsPass() {}