#include <utility>
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
usingnamespacemlir;
usingnamespacemlir::shape;
#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
// Generated canonicalization code included from ShapeCanonicalization.inc
// (presumably mlir-tblgen output for declarative rewrite rules — confirm
// against the build setup). Wrapping the include in an anonymous namespace
// gives everything it defines internal linkage, keeping it local to this file.
namespace {
#include "ShapeCanonicalization.inc"
} // namespace
RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { … }
bool shape::isExtentTensorType(Type type) { … }
LogicalResult shape::getShapeVec(Value input,
SmallVectorImpl<int64_t> &shapeValues) { … }
static bool isErrorPropagationPossible(TypeRange operandTypes) { … }
static LogicalResult verifySizeOrIndexOp(Operation *op) { … }
static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { … }
template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { … }
template <typename... Ty, typename... ranges>
static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) { … }
namespace {
struct ShapeInlinerInterface : public DialectInlinerInterface { … };
}
void ShapeDialect::initialize() { … }
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) { … }
LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attribute) { … }
OpFoldResult AnyOp::fold(FoldAdaptor adaptor) { … }
ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) { … }
void AssumingOp::print(OpAsmPrinter &p) { … }
namespace {
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> { … };
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { … };
}
void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
void AssumingOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { … }
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
PatternRewriter &rewriter) { … }
void AssumingOp::build(
OpBuilder &builder, OperationState &result, Value witness,
function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) { … }
LogicalResult mlir::shape::AddOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { … }
bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { … }
OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) { … }
LogicalResult shape::AddOp::verify() { … }
namespace {
struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> { … };
struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> { … };
struct AssumingAllToCstrEqCanonicalization
: public OpRewritePattern<AssumingAllOp> { … };
template <typename OpTy>
struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> { … };
}
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) { … }
LogicalResult AssumingAllOp::verify() { … }
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { … }
LogicalResult BroadcastOp::verify() { … }
namespace {
template <typename OpTy>
struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> { … };
struct BroadcastForwardSingleOperandPattern
: public OpRewritePattern<BroadcastOp> { … };
struct BroadcastFoldConstantOperandsPattern
: public OpRewritePattern<BroadcastOp> { … };
template <typename OpTy>
struct CanonicalizeCastExtentTensorOperandsPattern
: public OpRewritePattern<OpTy> { … };
struct BroadcastConcretizeResultTypePattern
: public OpRewritePattern<BroadcastOp> { … };
}
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) { … }
void ConstShapeOp::print(OpAsmPrinter &p) { … }
ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) { … }
OpFoldResult ConstShapeOp::fold(FoldAdaptor) { … }
void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { … }
bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
TypeRange r) { … }
void CstrBroadcastableOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) { … }
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) { … }
OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) { … }
LogicalResult CstrBroadcastableOp::verify() { … }
void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) { … }
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
int64_t value) { … }
OpFoldResult ConstSizeOp::fold(FoldAdaptor) { … }
void ConstSizeOp::getAsmResultNames(
llvm::function_ref<void(Value, StringRef)> setNameFn) { … }
OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { … }
OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) { … }
std::optional<int64_t> DimOp::getConstantIndex() { … }
OpFoldResult DimOp::fold(FoldAdaptor adaptor) { … }
LogicalResult mlir::shape::DimOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { … }
bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { … }
OpFoldResult DivOp::fold(FoldAdaptor adaptor) { … }
LogicalResult mlir::shape::DivOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { … }
bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { … }
LogicalResult DivOp::verify() { … }
OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) { … }
OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) { … }
void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) { … }
void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
StringRef name) { … }
FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) { … }
ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
OperationState &result) { … }
void FunctionLibraryOp::print(OpAsmPrinter &p) { … }
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs) { … }
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
Operation::dialect_attr_range attrs) { … }
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs,
ArrayRef<DictionaryAttr> argAttrs) { … }
void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
FunctionType type, ArrayRef<NamedAttribute> attrs,
ArrayRef<DictionaryAttr> argAttrs) { … }
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { … }
void FuncOp::print(OpAsmPrinter &p) { … }
std::optional<int64_t> GetExtentOp::getConstantDim() { … }
OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) { … }
void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
int64_t dim) { … }
LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { … }
bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
TypeRange r) { … }
LogicalResult GetExtentOp::verify() { … }
void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) { … }
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { … }
bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { … }
OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) { … }
namespace {
struct RankShapeOfCanonicalizationPattern
: public OpRewritePattern<shape::RankOp> { … };
}
void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
LogicalResult mlir::shape::RankOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { … }
bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { … }
LogicalResult shape::RankOp::verify() { … }
OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) { … }
LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
NumElementsOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) { … }
bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
TypeRange r) { … }
LogicalResult shape::NumElementsOp::verify() { … }
OpFoldResult MaxOp::fold(FoldAdaptor adaptor) { … }
LogicalResult mlir::shape::MaxOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { … }
bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { … }
OpFoldResult MinOp::fold(FoldAdaptor adaptor) { … }
LogicalResult mlir::shape::MinOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { … }
bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { … }
OpFoldResult MulOp::fold(FoldAdaptor adaptor) { … }
LogicalResult mlir::shape::MulOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { … }
bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { … }
LogicalResult shape::MulOp::verify() { … }
namespace {
struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> { … };
struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> { … };
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> { … };
}
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location,
ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) { … }
bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { … }
LogicalResult shape::ShapeOfOp::verify() { … }
OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) { … }
void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) { … }
bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … }
LogicalResult shape::YieldOp::verify() { … }
LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) { … }
OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) { … }
bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … }
void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
ValueRange initVals) { … }
LogicalResult ReduceOp::verify() { … }
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { … }
void ReduceOp::print(OpAsmPrinter &p) { … }
#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"