//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::shape;

#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"

namespace {
#include "ShapeCanonicalization.inc"
} // namespace

RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {}

bool shape::isExtentTensorType(Type type) {}

LogicalResult shape::getShapeVec(Value input,
                                 SmallVectorImpl<int64_t> &shapeValues) {}
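
// A minimal sketch of what the two type helpers above compute, for
// illustration only (the `sketch`-prefixed names are hypothetical
// stand-ins): an extent tensor is a rank-1 tensor of `index` values whose
// single extent is the rank.
static RankedTensorType sketchGetExtentTensorType(MLIRContext *ctx,
                                                  int64_t rank) {
  return RankedTensorType::get({rank}, IndexType::get(ctx));
}

static bool sketchIsExtentTensorType(Type type) {
  auto rankedTy = llvm::dyn_cast<RankedTensorType>(type);
  return rankedTy && rankedTy.getRank() == 1 &&
         rankedTy.getElementType().isIndex();
}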

static bool isErrorPropagationPossible(TypeRange operandTypes) {}

static LogicalResult verifySizeOrIndexOp(Operation *op) {}

static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {}

template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {}

template <typename... Ty, typename... ranges>
static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {}
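
// A sketch of how the two overloads above can compose (hypothetical
// `sketch`-prefixed names): the base case checks that a single TypeRange
// holds exactly one type drawn from `Ty...`, and the variadic overload ANDs
// that test across all remaining ranges.
template <typename... Ty>
static bool sketchEachHasOnlyOneOfTypes(TypeRange typeRange) {
  return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
}

template <typename... Ty, typename... ranges>
static bool sketchEachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
  return sketchEachHasOnlyOneOfTypes<Ty...>(l) &&
         sketchEachHasOnlyOneOfTypes<Ty...>(rs...);
}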

//===----------------------------------------------------------------------===//
// InlinerInterface
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for inlining shape dialect ops.
struct ShapeInlinerInterface : public DialectInlinerInterface {};
} // namespace

void ShapeDialect::initialize() {}

Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {}

LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attribute) {}

//===----------------------------------------------------------------------===//
// AnyOp
//===----------------------------------------------------------------------===//

// TODO: Canonicalization should be implemented for shapes that can be
// determined through mixtures of the known dimensions of the inputs.
OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {}

//===----------------------------------------------------------------------===//
// AssumingOp
//===----------------------------------------------------------------------===//

ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {}

void AssumingOp::print(OpAsmPrinter &p) {}

namespace {
// Removes AssumingOp with a passing witness and inlines the region.
struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {};

struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {};
} // namespace

void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {}

// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}

void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {}

void AssumingOp::build(
    OpBuilder &builder, OperationState &result, Value witness,
    function_ref<SmallVector<Value, 2>(OpBuilder &, Location)> bodyBuilder) {}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::AddOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {}

bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {}

OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {}

LogicalResult shape::AddOp::verify() {}

//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//

namespace {

// Merge multiple `shape.assuming_all` operations together.
//
//   %0 = shape.assuming_all %w0, %w1
//   %1 = shape.assuming_all %w2, %0
//
// to:
//
//   %0 = shape.assuming_all %w0, %w1, %w2
struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {};
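
// A sketch of the merge rewrite described above (hypothetical `Sketch` name;
// assumes the generated `getInputs()` accessor for the witness operands):
// whenever an operand is itself produced by an `assuming_all`, splice that
// producer's witnesses in directly.
struct SketchMergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> operands;
    for (Value operand : op.getInputs()) {
      if (auto producer = operand.getDefiningOp<AssumingAllOp>())
        llvm::append_range(operands, producer.getInputs());
      else
        operands.push_back(operand);
    }
    // Nothing was merged; leave the op alone.
    if (operands.size() == op->getNumOperands())
      return failure();
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
    return success();
  }
};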

// Eliminate `cstr_broadcastable` operands of an `assuming_all` operation
// that are subsumed by others.
//
//   %0 = shape.cstr_broadcastable %shape0, %shape1
//   %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//
//   %2 = shape.cstr_broadcastable %shape3, %shape4
//   %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//
//   %4 = shape.assuming_all %0, %1, %2, %3
//
// to:
//
//   %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//   %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//   %2 = shape.assuming_all %0, %1
//
// In this example if shapes [0, 1, 2] are broadcastable, then it means that
// shapes [0, 1] are broadcastable too, and can be removed from the list of
// constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
// matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {};
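
// A sketch of the subsumption test behind this rewrite (hypothetical helper;
// assumes the generated `getShapes()` accessor): witness `a` is redundant if
// some other witness `b` constrains a superset of `a`'s shape operands,
// because broadcastability of the larger set implies broadcastability of any
// subset.
static bool sketchSubsumes(CstrBroadcastableOp a, CstrBroadcastableOp b) {
  llvm::DenseSet<Value> bShapes;
  for (Value shape : b.getShapes())
    bShapes.insert(shape);
  return llvm::all_of(a.getShapes(),
                      [&](Value shape) { return bShapes.contains(shape); });
}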

struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {};

template <typename OpTy>
struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {};
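
// A sketch of the duplicate-removal idea (hypothetical `Sketch` name):
// collect the unique operands in order and, if any were dropped, rebuild the
// op with the reduced operand list.
template <typename OpTy>
struct SketchRemoveDuplicateOperands : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // SetVector keeps the first occurrence of each operand.
    llvm::SetVector<Value> unique(op->operand_begin(), op->operand_end());
    if (unique.size() == op->getNumOperands())
      return failure();
    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
                                      unique.takeVector(), op->getAttrs());
    return success();
  }
};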
} // namespace

void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {}

OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {}

LogicalResult AssumingAllOp::verify() {}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {}

LogicalResult BroadcastOp::verify() {}

namespace {
template <typename OpTy>
struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {};

struct BroadcastForwardSingleOperandPattern
    : public OpRewritePattern<BroadcastOp> {};

struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {};

template <typename OpTy>
struct CanonicalizeCastExtentTensorOperandsPattern
    : public OpRewritePattern<OpTy> {};

struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//

OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {}

//===----------------------------------------------------------------------===//
// ConstShapeOp
//===----------------------------------------------------------------------===//

void ConstShapeOp::print(OpAsmPrinter &p) {}

ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {}

OpFoldResult ConstShapeOp::fold(FoldAdaptor) {}

void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {}

LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {}

bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
                                                        TypeRange r) {}

//===----------------------------------------------------------------------===//
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//

void CstrBroadcastableOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {}

// Return true if there is at most one attribute not representing a scalar
// broadcast.
static bool hasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {}
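
// A sketch of the helper above (hypothetical name; assumes each non-null
// attribute is a DenseIntElementsAttr holding a constant shape): scalar
// shapes (zero extents) broadcast with anything, so at most one non-scalar
// constant shape keeps the fold trivial.
static bool sketchHasAtMostSingleNonScalar(ArrayRef<Attribute> attributes) {
  bool nonScalarSeen = false;
  for (Attribute a : attributes) {
    if (!a)
      return false; // Unknown shape: bail out conservatively.
    if (llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
      if (nonScalarSeen)
        return false;
      nonScalarSeen = true;
    }
  }
  return true;
}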

OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {}

LogicalResult CstrBroadcastableOp::verify() {}

//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//

void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {}

OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {}

//===----------------------------------------------------------------------===//
// ConstSizeOp
//===----------------------------------------------------------------------===//

void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {}

OpFoldResult ConstSizeOp::fold(FoldAdaptor) {}

void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {}

//===----------------------------------------------------------------------===//
// ConstWitnessOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstWitnessOp::fold(FoldAdaptor) {}

//===----------------------------------------------------------------------===//
// CstrRequireOp
//===----------------------------------------------------------------------===//

OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

std::optional<int64_t> DimOp::getConstantIndex() {}
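
// A sketch of the constant-index extraction (hypothetical helper; assumes
// the generated `getIndex()` accessor): peek through the defining op of the
// index operand for a known constant value.
static std::optional<int64_t> sketchGetConstantIndex(DimOp op) {
  if (auto constSize = op.getIndex().getDefiningOp<ConstSizeOp>())
    return constSize.getValue().getLimitedValue();
  if (auto constIndex = op.getIndex().getDefiningOp<arith::ConstantOp>())
    return llvm::cast<IntegerAttr>(constIndex.getValue()).getInt();
  return std::nullopt;
}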

OpFoldResult DimOp::fold(FoldAdaptor adaptor) {}

LogicalResult mlir::shape::DimOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {}

bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

OpFoldResult DivOp::fold(FoldAdaptor adaptor) {}

LogicalResult mlir::shape::DivOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {}

bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {}

LogicalResult DivOp::verify() {}

//===----------------------------------------------------------------------===//
// ShapeEqOp
//===----------------------------------------------------------------------===//

OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {}

//===----------------------------------------------------------------------===//
// IndexToSizeOp
//===----------------------------------------------------------------------===//

OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {}

void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {}

//===----------------------------------------------------------------------===//
// FromExtentsOp
//===----------------------------------------------------------------------===//

OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {}

//===----------------------------------------------------------------------===//
// FunctionLibraryOp
//===----------------------------------------------------------------------===//

void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
                              StringRef name) {}

FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {}

ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
                                     OperationState &result) {}

void FunctionLibraryOp::print(OpAsmPrinter &p) {}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs) {}
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      Operation::dialect_attr_range attrs) {}
FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
                      ArrayRef<NamedAttribute> attrs,
                      ArrayRef<DictionaryAttr> argAttrs) {}

void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                   FunctionType type, ArrayRef<NamedAttribute> attrs,
                   ArrayRef<DictionaryAttr> argAttrs) {}

ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {}

void FuncOp::print(OpAsmPrinter &p) {}

//===----------------------------------------------------------------------===//
// GetExtentOp
//===----------------------------------------------------------------------===//

std::optional<int64_t> GetExtentOp::getConstantDim() {}

OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {}

void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
                        int64_t dim) {}

LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {}

bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {}

LogicalResult GetExtentOp::verify() {}

//===----------------------------------------------------------------------===//
// IsBroadcastableOp
//===----------------------------------------------------------------------===//

void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {}

OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {}

//===----------------------------------------------------------------------===//
// MeetOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {}

bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//

OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {}

/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
/// Constant folding fails in cases where only the rank is constant, not the
/// shape itself.
/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
///
/// Example:
///
/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
/// %rank = shape.rank %shape
///
/// becomes
///
/// %rank = shape.const_size 3

namespace {
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {};
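
// A sketch of that match (hypothetical `Sketch` name; assumes the generated
// `getShape()` and `getArg()` accessors): when the shape operand comes from
// `shape.shape_of` of a ranked tensor, the rank is a compile-time constant
// and can be materialized for either an `index` or a `!shape.size` result.
struct SketchRankShapeOf : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTy =
        llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
    if (!rankedTy)
      return failure();
    int64_t rank = rankedTy.getRank();
    if (llvm::isa<IndexType>(op.getType()))
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, rank);
    else if (llvm::isa<SizeType>(op.getType()))
      rewriter.replaceOpWithNewOp<ConstSizeOp>(op, rank);
    else
      return failure();
    return success();
  }
};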
} // namespace

void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {}

LogicalResult mlir::shape::RankOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {}

bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {}

LogicalResult shape::RankOp::verify() {}

//===----------------------------------------------------------------------===//
// NumElementsOp
//===----------------------------------------------------------------------===//

OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {}

LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    NumElementsOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {}

bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
                                                         TypeRange r) {}

LogicalResult shape::NumElementsOp::verify() {}

//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//

OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {}

LogicalResult mlir::shape::MaxOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {}

bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {}

//===----------------------------------------------------------------------===//
// MinOp
//===----------------------------------------------------------------------===//

OpFoldResult MinOp::fold(FoldAdaptor adaptor) {}

LogicalResult mlir::shape::MinOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {}

bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

OpFoldResult MulOp::fold(FoldAdaptor adaptor) {}

LogicalResult mlir::shape::MulOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {}

bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {}

LogicalResult shape::MulOp::verify() {}

//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//

namespace {
/// Replace shape_of(x) where x has a constant shape with a const_shape op.
struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {};
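
// A sketch of that rewrite (hypothetical `Sketch` name; assumes the
// generated `getArg()` accessor): materialize the statically known shape as
// a `const_shape`, casting if the extent-tensor types differ.
struct SketchShapeOfToConstShape : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
    if (!type || !type.hasStaticShape())
      return failure();
    Value constShape = rewriter.create<ConstShapeOp>(
        op.getLoc(), rewriter.getIndexTensorAttr(type.getShape()));
    if (constShape.getType() != op.getType())
      constShape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                   constShape);
    rewriter.replaceOp(op, constShape);
    return success();
  }
};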

// Canonicalize
//
// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
//
// to
//
// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// %1 = %shape
//
struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {};
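
// A simplified sketch of the forwarding above (hypothetical `Sketch` name; a
// fuller version would insert a tensor.cast when the shape operand's type is
// less specific than the result type): the reshape's shape operand already
// is the shape of its result.
struct SketchShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    auto reshape = op.getArg().getDefiningOp<tensor::ReshapeOp>();
    if (!reshape)
      return failure();
    Value shape = reshape.getShape();
    // Only forward when no type adjustment is needed.
    if (shape.getType() != op.getType())
      return failure();
    rewriter.replaceOp(op, shape);
    return success();
  }
};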

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {};
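
// A sketch of that cast folding (hypothetical `Sketch` name; a fuller
// version would also verify that the cast's static extent matches the
// argument's rank): recreate the shape_of directly with the cast's result
// type.
struct SketchShapeOfCastExtentTensor
    : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
    if (!resultTy || resultTy.getRank() != 1)
      return failure();
    auto shapeOfOp = op.getSource().getDefiningOp<shape::ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op, resultTy,
                                                  shapeOfOp.getArg());
    return success();
  }
};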
} // namespace

void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {}

LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {}

bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {}

LogicalResult shape::ShapeOfOp::verify() {}

//===----------------------------------------------------------------------===//
// SizeToIndexOp
//===----------------------------------------------------------------------===//

OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {}

void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {}

bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {}
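
// A sketch of the compatibility rule (hypothetical helper): a size-to-index
// cast maps exactly one `!shape.size` or `index` value to one `index` value.
static bool sketchSizeToIndexCastCompatible(TypeRange inputs,
                                            TypeRange outputs) {
  return inputs.size() == 1 && outputs.size() == 1 &&
         llvm::isa<SizeType, IndexType>(inputs.front()) &&
         llvm::isa<IndexType>(outputs.front());
}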

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

LogicalResult shape::YieldOp::verify() {}

//===----------------------------------------------------------------------===//
// SplitAtOp
//===----------------------------------------------------------------------===//

LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {}

//===----------------------------------------------------------------------===//
// ToExtentTensorOp
//===----------------------------------------------------------------------===//

OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {}

bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                     ValueRange initVals) {}

LogicalResult ReduceOp::verify() {}

ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {}

void ReduceOp::print(OpAsmPrinter &p) {}

#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"