#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTSHAPETOSTANDARD
#include "mlir/Conversion/Passes.h.inc"
}
usingnamespacemlir;
usingnamespacemlir::shape;
usingnamespacemlir::scf;
namespace {
class AnyOpConversion : public OpConversionPattern<AnyOp> { … };
}
LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { … }
namespace {
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> { … };
}
namespace {
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> { … };
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
ValueRange rankDiffs, Value outputDimension) { … }
}
LogicalResult BroadcastOpConverter::matchAndRewrite(
BroadcastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { … }
namespace {
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> { … };
}
LogicalResult ConstShapeOpConverter::matchAndRewrite(
ConstShapeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { … }
namespace {
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> { … };
}
LogicalResult ConstSizeOpConversion::matchAndRewrite(
ConstSizeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { … }
namespace {
struct IsBroadcastableOpConverter
: public OpConversionPattern<IsBroadcastableOp> { … };
}
LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
IsBroadcastableOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { … }
namespace {
class DimOpConverter : public OpConversionPattern<DimOp> { … };
}
LogicalResult
DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { … }
namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> { … };
}
LogicalResult GetExtentOpConverter::matchAndRewrite(
GetExtentOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { … }
namespace {
class RankOpConverter : public OpConversionPattern<shape::RankOp> { … };
}
LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { … }
namespace {
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> { … };
}
LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { … }
namespace {
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> { … };
}
LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { … }
namespace {
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> { … };
}
LogicalResult ShapeOfOpConversion::matchAndRewrite(
ShapeOfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { … }
namespace {
class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> { … };
}
LogicalResult SplitAtOpConversion::matchAndRewrite(
SplitAtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const { … }
namespace {
class ToExtentTensorOpConversion
: public OpConversionPattern<ToExtentTensorOp> { … };
}
namespace {
#include "ShapeToStandard.cpp.inc"
}
namespace {
class ConvertShapeToStandardPass
: public impl::ConvertShapeToStandardBase<ConvertShapeToStandardPass> { … };
}
void ConvertShapeToStandardPass::runOnOperation() { … }
void mlir::populateShapeToStandardConversionPatterns(
RewritePatternSet &patterns) { … }
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() { … }