// llvm/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp

//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTSHAPETOSTANDARD
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

usingnamespacemlir;
usingnamespacemlir::shape;
usingnamespacemlir::scf;

/// Conversion patterns.
namespace {
/// Conversion pattern for `shape.any`. The rewrite hook is defined out of line
/// in this file.
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  // Inherit the base constructors so the pattern can be built from a context.
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Rewrites `shape.any` by forwarding its first converted operand.
///
/// `shape.any` may yield any one of its operands, so the first input is always
/// a valid replacement. Always succeeds.
LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  Value firstInput = adaptor.getInputs().front();
  rewriter.replaceOp(op, ValueRange{firstInput});
  return success();
}

namespace {
/// Generic pattern that replaces a binary Shape dialect op `SrcOpTy` with the
/// binary op `DstOpTy`, forwarding the converted `lhs` and `rhs` operands.
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  // Inherit the base constructors so the pattern can be built from a context.
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only error-free result types are handled here: a `!shape.size` result
    // may be an invalid value, which this lowering does not model.
    if (isa<SizeType>(op.getType()))
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
                                         adaptor.getRhs());
    return success();
  }
};
} // namespace

namespace {
/// Conversion pattern for `shape.broadcast`. The rewrite hook is defined out
/// of line in this file.
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  // Inherit the base constructors so the pattern can be built from a context.
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

// Get the resulting extent in a given dimension. This is computed with any
// number of extent tensors and shifted offsets into them.
//
// `extentTensors` and `rankDiffs` are walked in lockstep; `rankDiffs[i]` is the
// number of leading output dimensions that `extentTensors[i]` does not cover.
// Starting from 1, every operand that covers `outputDimension` overrides the
// running result with its own extent unless that extent is 1. Returns the
// resulting `index` value.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  Type indexTy = lb.getIndexType();
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Value broadcastedDim = one;
  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
    Value shape = std::get<0>(tup);
    Value rankDiff = std::get<1>(tup);
    // The operand has no extent for output dimensions below its rank offset.
    Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
                                                 outputDimension, rankDiff);
    broadcastedDim =
        lb.create<IfOp>(
              outOfBounds,
              [&](OpBuilder &b, Location loc) {
                // Operand does not reach this dimension: keep the result.
                b.create<scf::YieldOp>(loc, broadcastedDim);
              },
              [&](OpBuilder &b, Location loc) {
                // An extent of 1 defers to the extent found so far; any other
                // extent is taken as-is. This stays correct for zero extents.
                Value operandDimension = b.create<arith::SubIOp>(
                    loc, indexTy, outputDimension, rankDiff);
                Value operandExtent = b.create<tensor::ExtractOp>(
                    loc, shape, ValueRange{operandDimension});
                Value extentIsOne = b.create<arith::CmpIOp>(
                    loc, arith::CmpIPredicate::eq, operandExtent, one);
                Value dim = b.create<arith::SelectOp>(
                    loc, extentIsOne, broadcastedDim, operandExtent);
                b.create<scf::YieldOp>(loc, dim);
              })
            .getResult(0);
  }
  return broadcastedDim;
}
} // namespace

/// Lowers `shape.broadcast` on extent tensors to a `tensor.generate` whose
/// body computes each output extent with `getBroadcastedDim`.
///
/// Fails to match when the result is a `!shape.shape`, since only extent
/// tensors are handled.
LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (isa<ShapeType>(op.getType()))
    return failure();

  ImplicitLocOpBuilder lb(op.getLoc(), rewriter);
  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Type indexTy = lb.getIndexType();

  // The rank of each operand is the extent of its (1-D) extent tensor.
  SmallVector<Value> ranks;
  for (Value shape : adaptor.getShapes()) {
    Value rank = lb.create<tensor::DimOp>(shape, zero);
    ranks.push_back(rank);
  }

  // The output rank is the maximum operand rank.
  Value maxRank = ranks.front();
  for (Value rank : llvm::drop_begin(ranks, 1))
    maxRank = lb.create<arith::MaxUIOp>(rank, maxRank);

  // Offset of each operand's dimensions within the output dimensions.
  SmallVector<Value> rankDiffs;
  for (Value rank : ranks) {
    Value diff = lb.create<arith::SubIOp>(indexTy, maxRank, rank);
    rankDiffs.push_back(diff);
  }

  Value replacement = lb.create<tensor::GenerateOp>(
      getExtentTensorType(lb.getContext()), ValueRange{maxRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value broadcastedDim =
            getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
                              rankDiffs, args[0]);
        b.create<tensor::YieldOp>(loc, broadcastedDim);
      });

  // The generated tensor is dynamically sized; cast if the op promised a
  // more specific type.
  if (replacement.getType() != op.getType())
    replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
  rewriter.replaceOp(op, replacement);
  return success();
}

namespace {
/// Conversion pattern for `shape.const_shape`. The rewrite hook is defined out
/// of line in this file.
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  // Inherit the base constructors so the pattern can be built from a context.
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Lowers `shape.const_shape` to index constants packed with
/// `tensor.from_elements`, followed by a `tensor.cast`.
///
/// Fails to match when the result is a `!shape.shape`; only extent tensors are
/// handled.
LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (isa<ShapeType>(op.getType()))
    return failure();

  Location loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.getShape()) {
    Value constant = rewriter.create<arith::ConstantIndexOp>(
        loc, extent.getLimitedValue());
    extentOperands.push_back(constant);
  }

  Type resultTy = RankedTensorType::get(
      {static_cast<int64_t>(extentOperands.size())}, rewriter.getIndexType());
  Value tensor =
      rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
  return success();
}

namespace {
/// Conversion pattern for `shape.const_size`. The rewrite hook is defined out
/// of line in this file.
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  // Inherit the base constructors so the pattern can be built from a context.
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Lowers `shape.const_size` to an `arith.constant` of `index` type holding
/// the same value. Always succeeds.
LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
      op, op.getValueAttr().getInt());
  return success();
}

namespace {
/// Conversion pattern for `shape.is_broadcastable`. The rewrite hook is
/// defined out of line in this file.
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  // Inherit the base constructors so the pattern can be built from a context.
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Lowers `shape.is_broadcastable` on extent tensors to an `scf.for` over the
/// output dimensions that accumulates an `i1` result.
///
/// A dimension is broadcastable when every operand that covers it has an
/// extent equal to 1 or to the broadcasted extent of that dimension. Fails to
/// match when any operand is a `!shape.shape`.
LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::any_of(op.getShapes(),
                   [](Value v) { return isa<ShapeType>(v.getType()); }))
    return failure();

  Location loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Type indexTy = lb.getIndexType();

  // The rank of each operand is the extent of its (1-D) extent tensor.
  SmallVector<Value> ranks;
  for (Value shape : adaptor.getShapes()) {
    Value rank = lb.create<tensor::DimOp>(shape, zero);
    ranks.push_back(rank);
  }

  // The output rank is the maximum operand rank.
  Value maxRank = ranks.front();
  for (Value rank : llvm::drop_begin(ranks, 1))
    maxRank = lb.create<arith::MaxUIOp>(rank, maxRank);

  // Offset of each operand's dimensions within the output dimensions.
  SmallVector<Value> rankDiffs;
  for (Value rank : ranks) {
    Value diff = lb.create<arith::SubIOp>(indexTy, maxRank, rank);
    rankDiffs.push_back(diff);
  }

  Type i1Ty = rewriter.getI1Type();
  Value trueVal =
      rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

  auto reduceResult = rewriter.create<ForOp>(
      loc, zero, maxRank, one, ValueRange{trueVal},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Recompute the broadcasted extent for this dimension rather than
        // sharing it with the `shape.broadcast` lowering.
        Value broadcastedDim = getBroadcastedDim(
            ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);

        Value broadcastable = iterArgs[0];
        for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
          Value shape = std::get<0>(tup);
          Value rankDiff = std::get<1>(tup);
          Value outOfBounds = b.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, iv, rankDiff);
          broadcastable =
              b.create<IfOp>(
                   loc, outOfBounds,
                   [&](OpBuilder &b, Location loc) {
                     // Dimensions the operand lacks are always broadcastable.
                     b.create<scf::YieldOp>(loc, broadcastable);
                   },
                   [&](OpBuilder &b, Location loc) {
                     // The extent must be 1 or match the broadcasted extent.
                     Value operandDimension =
                         b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
                     Value dimensionExtent = b.create<tensor::ExtractOp>(
                         loc, shape, ValueRange{operandDimension});
                     Value equalOne = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent, one);
                     Value equalBroadcasted = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent,
                         broadcastedDim);
                     Value eitherHolds =
                         b.create<arith::OrIOp>(loc, equalOne, equalBroadcasted);
                     Value result = b.create<arith::AndIOp>(loc, broadcastable,
                                                            eitherHolds);
                     b.create<scf::YieldOp>(loc, result);
                   })
                  .getResult(0);
        }

        b.create<scf::YieldOp>(loc, broadcastable);
      });

  rewriter.replaceOp(op, reduceResult.getResults().front());
  return success();
}

namespace {
/// Conversion pattern for `shape.dim`. The rewrite hook is defined out of line
/// in this file.
class DimOpConverter : public OpConversionPattern<DimOp> {
public:
  // Inherit the base constructors so the pattern can be built from a context.
  using OpConversionPattern<DimOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Rewrites `shape.dim(X, i)` as `shape.get_extent(shape.shape_of(X), i)`.
///
/// The produced Shape dialect ops are left for the other patterns in this
/// file to lower further. Always succeeds.
LogicalResult
DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  Value shapeOf =
      rewriter.create<shape::ShapeOfOp>(op.getLoc(), adaptor.getValue());
  rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
                                                  adaptor.getIndex());
  return success();
}

namespace {
/// Conversion pattern for `shape.get_extent`. The rewrite hook is defined out
/// of line in this file.
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
public:
  // Inherit the base constructors so the pattern can be built from a context.
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Lowers `shape.get_extent` to `tensor.dim` on the original shaped value when
/// the shape comes from `shape.shape_of`, and to `tensor.extract` on the extent
/// tensor otherwise.
///
/// Fails to match when the result is a `!shape.size`; only error-free types
/// are handled.
LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (isa<SizeType>(op.getType()))
    return failure();

  // Query the extent from the shaped value directly when possible, so the
  // shape never has to be materialized as a tensor.
  if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
    if (isa<ShapedType>(shapeOfOp.getArg().getType())) {
      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
                                                 adaptor.getDim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
                                                 adaptor.getShape(),
                                                 ValueRange{adaptor.getDim()});
  return success();
}

namespace {
/// Conversion pattern for `shape.rank`. The rewrite hook is defined out of
/// line in this file.
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  // Inherit the base constructors so the pattern can be built from a context.
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Lowers `shape.rank` to `tensor.dim` of dimension 0 of the extent tensor,
/// since the rank is the number of extents it holds.
///
/// Fails to match when the result is a `!shape.size`; only error-free types
/// are handled.
LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  if (isa<SizeType>(op.getType()))
    return failure();

  rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`. The rewrite hook is defined out of
/// line in this file.
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
  // Inherit the base constructors so the pattern can be built from a context.
  using OpConversionPattern<shape::ReduceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

/// Lowers `shape.reduce` on an extent tensor to an `scf.for` over its extents.
///
/// The reduction body is cloned into the loop with its block arguments mapped
/// to (induction variable, current extent, iteration arguments...). Fails to
/// match when the shape operand is a `!shape.shape`.
LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  if (isa<ShapeType>(op.getShape().getType()))
    return failure();

  Location loc = op.getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  Value rank = rewriter.create<tensor::DimOp>(loc, adaptor.getShape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, adaptor.getInitVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(),
                                                   ValueRange{iv});

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        // Clone everything but the terminator; its operands become the
        // values yielded by the loop.
        IRMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (Operation &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (Value result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their size and, if equal, checks every extent for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = arith.constant 0 : index
/// %0 = tensor.dim %arg0, %c0 : tensor<?xindex>
/// %1 = tensor.dim %arg1, %c0 : tensor<?xindex>
/// %2 = arith.cmpi eq, %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = arith.constant 1 : index
///   %true = arith.constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///     %7 = arith.cmpi eq, %5, %6 : index
///     %8 = arith.andi %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = arith.constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  // Inherit the base constructors so the pattern can be built from a context.
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Lowers `shape.shape_eq` on extent tensors.
///
/// Every operand after the first is compared against the first: ranks are
/// compared with `arith.cmpi`, and only when they match does an `scf.for`
/// compare the extents one by one. The per-operand results are and-ed
/// together. With zero or one operand the result is the constant `true`.
/// Fails to match when any operand is a `!shape.shape`.
LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (llvm::any_of(op.getShapes(),
                   [](Value v) { return isa<ShapeType>(v.getType()); }))
    return failure();

  Type i1Ty = rewriter.getI1Type();
  if (op.getShapes().size() <= 1) {
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
                                                   rewriter.getBoolAttr(true));
    return success();
  }

  Location loc = op.getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value firstShape = adaptor.getShapes().front();
  Value firstRank = rewriter.create<tensor::DimOp>(loc, firstShape, zero);
  Value result = nullptr;
  // Emit a linear chain of comparisons, each with `firstShape` as the lhs.
  for (Value shape : adaptor.getShapes().drop_front(1)) {
    Value rank = rewriter.create<tensor::DimOp>(loc, shape, zero);
    Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                  firstRank, rank);
    auto same = rewriter.create<IfOp>(
        loc, eqRank,
        [&](OpBuilder &b, Location loc) {
          // Equal ranks: fold extent equality over all dimensions.
          Value one = b.create<arith::ConstantIndexOp>(loc, 1);
          Value init =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
          auto loop = b.create<scf::ForOp>(
              loc, zero, firstRank, one, ValueRange{init},
              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
                Value conj = args[0];
                Value lhsExtent = b.create<tensor::ExtractOp>(
                    nestedLoc, firstShape, ValueRange{iv});
                Value rhsExtent = b.create<tensor::ExtractOp>(
                    nestedLoc, shape, ValueRange{iv});
                Value eqExtent = b.create<arith::CmpIOp>(
                    nestedLoc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
                Value conjNext =
                    b.create<arith::AndIOp>(nestedLoc, conj, eqExtent);
                b.create<scf::YieldOp>(nestedLoc, ValueRange{conjNext});
              });
          b.create<scf::YieldOp>(loc, loop.getResults());
        },
        [&](OpBuilder &b, Location loc) {
          // Different ranks can never be equal shapes.
          Value falseVal =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
          b.create<scf::YieldOp>(loc, falseVal);
        });
    if (!result)
      result = same.getResult(0);
    else
      result = rewriter.create<arith::AndIOp>(loc, result, same.getResult(0));
  }
  rewriter.replaceOp(op, result);
  return success();
}

namespace {
/// Conversion pattern for `shape.shape_of`. The rewrite hook is defined out of
/// line in this file.
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  // Inherit the base constructors so the pattern can be built from a context.
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Lowers `shape.shape_of` to an extent tensor.
///
/// Ranked tensor arguments become a `tensor.from_elements` of per-dimension
/// extents (constants for static dimensions, `tensor.dim` for dynamic ones)
/// followed by a `tensor.cast`; any other argument becomes a `tensor.generate`
/// of `tensor.dim` queries. Fails to match when the result is a
/// `!shape.shape`.
LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (isa<ShapeType>(op.getType()))
    return failure();

  Location loc = op.getLoc();
  Value tensor = adaptor.getArg();
  if (auto rankedTensorTy = dyn_cast<RankedTensorType>(tensor.getType())) {
    // The rank is known statically, so emit one value per extent.
    SmallVector<Value, 8> extentValues;
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; ++i) {
      Value extent;
      if (rankedTensorTy.isDynamicDim(i))
        extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
      else
        extent = rewriter.create<arith::ConstantIndexOp>(
            loc, rankedTensorTy.getDimSize(i));
      extentValues.push_back(extent);
    }

    Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
        loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
        extentValues);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                staticExtentTensor);
    return success();
  }

  // The rank is only known at runtime: generate the extent tensor.
  MLIRContext *ctx = rewriter.getContext();
  Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
        b.create<tensor::YieldOp>(loc, extent);
      });
  return success();
}

namespace {
/// Conversion pattern for `shape.split_at`. The rewrite hook is defined out of
/// line in this file.
class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
public:
  // Inherit the base constructors so the pattern can be built from a context.
  using OpConversionPattern<SplitAtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

/// Lowers `shape.split_at` on extent tensors to two `tensor.extract_slice`
/// ops: `[0, index)` for the head and `[index, rank)` for the tail.
///
/// A negative index counts from the back and is normalized by adding the
/// rank. Out-of-range indices are not diagnosed. Fails to match when the
/// operand or either result is a `!shape.shape`.
LogicalResult SplitAtOpConversion::matchAndRewrite(
    SplitAtOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
                   [](Value v) { return isa<ShapeType>(v.getType()); }))
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value one = b.create<arith::ConstantIndexOp>(1);
  Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);

  // index < 0 ? index + rank : index
  Value originalIndex = adaptor.getIndex();
  Value wrappedIndex = b.create<arith::AddIOp>(originalIndex, rank);
  Value indexIsNegative =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
  Value index =
      b.create<arith::SelectOp>(indexIsNegative, wrappedIndex, originalIndex);

  Value head = b.create<tensor::ExtractSliceOp>(
      adaptor.getOperand(), ValueRange{zero}, ValueRange{index},
      ValueRange{one});
  Value tailSize = b.create<arith::SubIOp>(rank, index);
  Value tail = b.create<tensor::ExtractSliceOp>(
      adaptor.getOperand(), ValueRange{index}, ValueRange{tailSize},
      ValueRange{one});
  rewriter.replaceOp(op, {head, tail});
  return success();
}

namespace {
/// Lowers `shape.to_extent_tensor` to a `tensor.cast` when the converted input
/// is already a ranked tensor.
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  // Inherit the base constructors so the pattern can be built from a context.
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Anything else (e.g. `!shape.shape`) has no tensor to cast from.
    if (!isa<RankedTensorType>(adaptor.getInput().getType()))
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                adaptor.getInput());
    return success();
  }
};
} // namespace

namespace {
/// Import the Shape Ops to Std Patterns.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
/// Conversion pass. Derives from the TableGen-generated base class; the pass
/// entry point `runOnOperation` is defined out of line in this file.
class ConvertShapeToStandardPass
    : public impl::ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
public:
  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {}

/// Adds every Shape-to-Standard conversion pattern defined in this file, plus
/// the declaratively generated patterns from `ShapeToStandard.cpp.inc`, to
/// `patterns`.
void mlir::populateShapeToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  populateWithGenerated(patterns);
  patterns.add<
      AnyOpConversion,
      BinaryOpConversion<AddOp, arith::AddIOp>,
      BinaryOpConversion<MulOp, arith::MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      DimOpConverter,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      SplitAtOpConversion,
      ToExtentTensorOpConversion>(patterns.getContext());
  // clang-format on
}

/// Creates a new instance of the Shape-to-Standard conversion pass. The caller
/// owns the returned pass.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}