
//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cassert>
#include <cstdint>
#include <functional>
#include <utility>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"


// Pattern helpers

// NOTE(review): every helper below has an empty body in this view. Each one
// returns IntegerAttr, so flowing off the end is undefined behavior if it is
// ever called; presumably the implementations were elided and must be
// restored before this file is built.

/// Presumably combines the integer attributes `lhs` and `rhs` with `binFn`
/// and returns the result as an IntegerAttr created through `builder`, with
/// `res` supplying the result type — confirm against the original.
static IntegerAttr
applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
                    Attribute rhs,
                    function_ref<APInt(const APInt &, const APInt &)> binFn) {}

/// Presumably `applyToIntegerAttrs` with integer addition.
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {}

/// Presumably `applyToIntegerAttrs` with integer subtraction.
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {}

/// Presumably `applyToIntegerAttrs` with integer multiplication.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {}

/// Merge overflow flags from 2 ops, selecting the most conservative
/// combination.
/// NOTE(review): empty body in this view; flowing off the end of a function
/// returning IntegerOverflowFlagsAttr is undefined behavior if it is called.
static IntegerOverflowFlagsAttr
mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
                   IntegerOverflowFlagsAttr val2) {}

/// Invert an integer comparison predicate.
///
/// This is a qualified (out-of-line) definition, so the function must already
/// be declared in namespace `arith` (presumably in Arith.h).
/// NOTE(review): empty body in this view; flowing off the end of a function
/// returning arith::CmpIPredicate is undefined behavior if it is called.
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {}

/// Equivalent to
/// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)).
/// This cannot be implemented as a chain of those calls, because that would
/// introduce a circular dependency with MLIRArithAttrToLLVMConversion and make
/// arith depend on the LLVM dialect and on translation to LLVM.
/// NOTE(review): empty body in this view; flowing off the end of a function
/// returning llvm::RoundingMode is undefined behavior if it is called.
static llvm::RoundingMode
convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {}

// NOTE(review): every helper below is non-void and has an empty body in this
// view; calling any of them is undefined behavior until the implementations
// are restored.

/// Attribute overload of `arith::invertPredicate`; presumably wraps the
/// inverted predicate of `pred` in a new CmpIPredicateAttr — confirm.
static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {}

/// Presumably the bit width of `type`, or of its element type when `type` is
/// shaped (inferred from the name — confirm).
static int64_t getScalarOrElementWidth(Type type) {}

/// Presumably the `Type` overload applied to `value`'s type — confirm.
static int64_t getScalarOrElementWidth(Value value) {}

/// Presumably extracts the APInt held by an integer attribute or by a splat
/// of one, and returns failure otherwise — confirm.
static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {}

/// Presumably builds a boolean attribute of `type` holding `value` (scalar or
/// splat) — confirm.
static Attribute getBoolAttribute(Type type, bool value) {}

// TableGen'd canonicalization patterns

// The generated file is included inside an anonymous namespace, so everything
// it defines gets internal linkage and cannot clash with names in other
// translation units.
// NOTE(review): presumably it defines the TableGen'd rewrite patterns that the
// getCanonicalizationPatterns hooks in this file register — confirm.
namespace {
#include "ArithCanonicalization.inc"
} // namespace

// Common helpers

/// Return the type of the same shape (scalar, vector or tensor) containing i1.
/// NOTE(review): empty body in this view; flowing off the end of a function
/// returning Type is undefined behavior if it is called.
static Type getI1SameShape(Type type) {}

// ConstantOp

// NOTE(review): every body in this section is empty in this view. For the
// non-void functions (verify, isBuildableWith, materialize, fold), flowing off
// the end is undefined behavior if they are called; presumably the
// implementations were elided and must be restored.

/// Asm-printer hook that suggests names for ConstantOp results. With the
/// empty body shown here, `setNameFn` is never called.
void arith::ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

/// Op verifier hook for ConstantOp.
/// TODO: disallow arith.constant to return anything other than signless integer
/// or float like.
LogicalResult arith::ConstantOp::verify() {}

/// Presumably reports whether `value` can be materialized as a ConstantOp of
/// `type` — confirm against the original implementation.
bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {}

/// Presumably creates a ConstantOp holding `value` of `type` at `loc` using
/// `builder`. Confirm this, including what is returned when `value` is not
/// buildable.
ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
                                          Type type, Location loc) {}

/// Fold hook for ConstantOp.
OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {}

// NOTE(review): all bodies here are empty in this view. The build overloads
// therefore leave `result` untouched. `classof` is non-void, so calling it is
// undefined behavior; presumably the implementations were elided.

/// Builder for ConstantIntOp from `value`; `width` presumably selects a
/// signless integer type of that many bits — confirm.
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 int64_t value, unsigned width) {}

/// Builder for ConstantIntOp from `value`, presumably with result type
/// `type` — confirm.
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 int64_t value, Type type) {}

/// LLVM-style RTTI hook (used by isa/dyn_cast): whether `op` can be viewed as
/// a ConstantIntOp.
bool arith::ConstantIntOp::classof(Operation *op) {}

// NOTE(review): both bodies are empty in this view. `build` leaves `result`
// untouched, and calling the non-void `classof` is undefined behavior.

/// Builder for ConstantFloatOp from `value`, presumably with result type
/// `type` — confirm.
void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
                                   const APFloat &value, FloatType type) {}

/// LLVM-style RTTI hook (used by isa/dyn_cast): whether `op` can be viewed as
/// a ConstantFloatOp.
bool arith::ConstantFloatOp::classof(Operation *op) {}

// NOTE(review): both bodies are empty in this view. `build` leaves `result`
// untouched, and calling the non-void `classof` is undefined behavior.

/// Builder for ConstantIndexOp from `value` (presumably an `index`-typed
/// constant — confirm).
void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
                                   int64_t value) {}

/// LLVM-style RTTI hook (used by isa/dyn_cast): whether `op` can be viewed as
/// a ConstantIndexOp.
bool arith::ConstantIndexOp::classof(Operation *op) {}

// AddIOp

/// Fold hook for AddIOp.
/// NOTE(review): empty body in this view; flowing off the end of a function
/// returning OpFoldResult is undefined behavior if the hook is invoked.
OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {}

/// Registers AddIOp canonicalization patterns into `patterns`. The body is
/// empty here, so as written nothing is registered.
void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {}

// AddUIExtendedOp

/// Unrolling-shape hook for AddUIExtendedOp (presumably the shape of a
/// vector-typed result, std::nullopt otherwise — confirm).
/// NOTE(review): empty body in this view; flowing off the end of a non-void
/// function is undefined behavior if it is called.
std::optional<SmallVector<int64_t, 4>>
arith::AddUIExtendedOp::getShapeForUnroll() {}

// Returns the overflow bit, assuming that `sum` is the result of unsigned
// addition of `operand` and another number. (For reference, unsigned addition
// wrapped exactly when `sum` is unsigned-less-than `operand`.)
// NOTE(review): empty body in this view; flowing off the end of a function
// returning APInt is undefined behavior if it is called.
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {}

arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {}

/// Registers AddUIExtendedOp canonicalization patterns into `patterns`. The
/// body is empty here, so as written nothing is registered.
void arith::AddUIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {}

// SubIOp

/// Fold hook for SubIOp.
/// NOTE(review): empty body in this view; flowing off the end of a function
/// returning OpFoldResult is undefined behavior if the hook is invoked.
OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {}

/// Registers SubIOp canonicalization patterns into `patterns`. The body is
/// empty here, so as written nothing is registered.
void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {}

// MulIOp

/// Fold hook for MulIOp.
/// NOTE(review): empty body in this view; flowing off the end of a function
/// returning OpFoldResult is undefined behavior if the hook is invoked.
OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {}

/// Asm-printer hook that suggests names for MulIOp results. With the empty
/// body shown here, `setNameFn` is never called.
void arith::MulIOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {}

/// Registers MulIOp canonicalization patterns into `patterns`. The body is
/// empty here, so as written nothing is registered.
void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {}

// MulSIExtendedOp

/// Unrolling-shape hook for MulSIExtendedOp (presumably the shape of a
/// vector-typed result, std::nullopt otherwise — confirm).
/// NOTE(review): empty body in this view; flowing off the end of a non-void
/// function is undefined behavior if it is called.
std::optional<SmallVector<int64_t, 4>>
arith::MulSIExtendedOp::getShapeForUnroll() {}

arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {}

/// Registers MulSIExtendedOp canonicalization patterns into `patterns`. The
/// body is empty here, so as written nothing is registered.
void arith::MulSIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {}

// MulUIExtendedOp

/// Unrolling-shape hook for MulUIExtendedOp (presumably the shape of a
/// vector-typed result, std::nullopt otherwise — confirm).
/// NOTE(review): empty body in this view; flowing off the end of a non-void
/// function is undefined behavior if it is called.
std::optional<SmallVector<int64_t, 4>>
arith::MulUIExtendedOp::getShapeForUnroll() {}

arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {}

/// Registers MulUIExtendedOp canonicalization patterns into `patterns`. The
/// body is empty here, so as written nothing is registered.
void arith::MulUIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {}

// DivUIOp

// NOTE(review): every body in this section is empty in this view and every
// function is non-void; calling any of them is undefined behavior until the
// implementations are restored.

/// Fold hook for DivUIOp.
OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {}

/// Returns whether an unsigned division by `divisor` is speculatable.
static Speculation::Speculatability getDivUISpeculatability(Value divisor) {}

/// Speculatability hook for DivUIOp (presumably delegates to
/// `getDivUISpeculatability` on the divisor operand — confirm).
Speculation::Speculatability arith::DivUIOp::getSpeculatability() {}

// DivSIOp

// NOTE(review): every body in this section is empty in this view and every
// function is non-void; calling any of them is undefined behavior until the
// implementations are restored.

/// Fold hook for DivSIOp.
OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {}

/// Returns whether a signed division by `divisor` is speculatable. This
/// function conservatively treats every signed division by -1 as not
/// speculatable.
static Speculation::Speculatability getDivSISpeculatability(Value divisor) {}

/// Speculatability hook for DivSIOp (presumably delegates to
/// `getDivSISpeculatability` on the divisor operand — confirm).
Speculation::Speculatability arith::DivSIOp::getSpeculatability() {}

// Ceil and floor division folding helpers

/// Presumably the signed ceiling division of `a` by `b` for non-negative
/// inputs (per the name), with `overflow` set when the arithmetic overflows —
/// confirm against the original.
/// NOTE(review): empty body in this view; flowing off the end of a function
/// returning APInt is undefined behavior if it is called, and `overflow` is
/// never written.
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
                                    bool &overflow) {}

// CeilDivUIOp

// NOTE(review): every body in this section is empty in this view and every
// function is non-void; calling any of them is undefined behavior until the
// implementations are restored.

/// Fold hook for CeilDivUIOp.
OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {}

/// Speculatability hook for CeilDivUIOp (presumably the same divisor-based
/// rule as DivUIOp — confirm).
Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {}

// CeilDivSIOp

/// Fold hook for CeilDivSIOp.
OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {}

/// Speculatability hook for CeilDivSIOp (presumably the same divisor-based
/// rule as DivSIOp — confirm).
Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {}

// FloorDivSIOp

// NOTE(review): every fold hook in this section has an empty body in this
// view; flowing off the end of a function returning OpFoldResult is undefined
// behavior if the hook is invoked.

/// Fold hook for FloorDivSIOp.
OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {}

// RemUIOp

/// Fold hook for RemUIOp.
OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {}

// RemSIOp

/// Fold hook for RemSIOp.
OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {}

// AndIOp

// NOTE(review): every body in this section is empty in this view. For the
// non-void functions, flowing off the end is undefined behavior if they are
// called. The void canonicalization hook registers nothing as written.

/// Fold `and(a, and(a, b))` to `and(a, b)`
/// (presumably returns a null Value when the pattern does not match — confirm).
static Value foldAndIofAndI(arith::AndIOp op) {}

/// Fold hook for AndIOp.
OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {}

// OrIOp

/// Fold hook for OrIOp.
OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {}

// XOrIOp

/// Fold hook for XOrIOp.
OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {}

/// Registers XOrIOp canonicalization patterns into `patterns`.
void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {}

// NegFOp

// NOTE(review): every fold hook in this section has an empty body in this
// view; flowing off the end of a function returning OpFoldResult is undefined
// behavior if the hook is invoked.

/// Fold hook for NegFOp.
OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {}

// AddFOp

/// Fold hook for AddFOp.
OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {}

// SubFOp

/// Fold hook for SubFOp.
OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {}

// MaximumFOp

// NOTE(review): every fold hook in this section has an empty body in this
// view; flowing off the end of a function returning OpFoldResult is undefined
// behavior if the hook is invoked.
// NOTE(review): unlike their float neighbours, the MaxSI/MaxUI/MinSI/MinUI
// definitions are not qualified with `arith::`. They rely on a
// `using namespace` for mlir::arith that is not visible in this view —
// confirm it exists, or qualify them for consistency.

/// Fold hook for MaximumFOp.
OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {}

// MaxNumFOp

/// Fold hook for MaxNumFOp.
OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {}

// MaxSIOp

/// Fold hook for MaxSIOp.
OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {}

// MaxUIOp

/// Fold hook for MaxUIOp.
OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {}

// MinimumFOp

/// Fold hook for MinimumFOp.
OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {}

// MinNumFOp

/// Fold hook for MinNumFOp.
OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {}

// MinSIOp

/// Fold hook for MinSIOp.
OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {}

// MinUIOp

/// Fold hook for MinUIOp.
OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {}

// MulFOp

// NOTE(review): every body in this section is empty in this view. The fold
// hooks are non-void, so invoking them is undefined behavior. The
// canonicalization hooks register nothing as written.

/// Fold hook for MulFOp.
OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {}

/// Registers MulFOp canonicalization patterns into `patterns`.
void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {}

// DivFOp

/// Fold hook for DivFOp.
OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {}

/// Registers DivFOp canonicalization patterns into `patterns`.
void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {}

// RemFOp

/// Fold hook for RemFOp.
OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {}

// Utility functions for verifying cast ops


/// Returns a non-null type only if the provided type is one of the allowed
/// types or one of the allowed shaped types of the allowed types. Returns the
/// element type if a valid shaped type is provided.
/// NOTE(review): this and every other helper in this section is non-void and
/// has an empty body in this view; calling any of them is undefined behavior
/// until the implementations are restored.
template <typename... ShapedTypes, typename... ElementTypes>
static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
                              type_list<ElementTypes...>) {}

/// Get allowed underlying types for vectors and tensors
/// (presumably `getUnderlyingType` with vector/tensor shapes — confirm).
template <typename... ElementTypes>
static Type getTypeIfLike(Type type) {}

/// Get allowed underlying types for vectors, tensors, and memrefs
/// (presumably `getUnderlyingType` with vector/tensor/memref shapes —
/// confirm).
template <typename... ElementTypes>
static Type getTypeIfLikeOrMemRef(Type type) {}

/// Return false if both types are ranked tensor with mismatching encoding.
static bool hasSameEncoding(Type typeA, Type typeB) {}

/// Presumably the common structural check shared by the cast ops'
/// areCastCompatible hooks (operand/result counts, encodings) — confirm.
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {}

// Verifiers for integer and floating point extension/truncation ops

// NOTE(review): every helper in this section is non-void and has an empty
// body in this view; calling any of them is undefined behavior until the
// implementations are restored.

// Extend ops can only extend to a wider type.
// `ValType` presumably names the element type class whose widths are
// compared — confirm.
template <typename ValType, typename Op>
static LogicalResult verifyExtOp(Op op) {}

// Truncate ops can only truncate to a shorter type.
template <typename ValType, typename Op>
static LogicalResult verifyTruncateOp(Op op) {}

/// Validate a cast that changes the width of a type.
/// `WidthComparator` is a single-parameter class template used to compare
/// widths (presumably std::greater / std::less — confirm).
template <template <typename> class WidthComparator, typename... ElementTypes>
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {}

/// Attempts to convert `sourceValue` to an APFloat value with
/// `targetSemantics` and `roundingMode`, without any information loss.
/// Presumably returns failure when the conversion would lose information —
/// confirm.
/// NOTE(review): empty body in this view; flowing off the end of a non-void
/// function is undefined behavior if it is called.
static FailureOr<APFloat> convertFloatValue(
    APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
    llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {}

// ExtUIOp

// NOTE(review): every body in this section is empty in this view. For the
// non-void hooks (fold, areCastCompatible, verify), flowing off the end is
// undefined behavior if they are invoked. The canonicalization hook registers
// nothing as written.

/// Fold hook for ExtUIOp.
OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {}

/// Cast-compatibility hook for ExtUIOp.
bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {}

/// Op verifier hook for ExtUIOp.
LogicalResult arith::ExtUIOp::verify() {}

// ExtSIOp

/// Fold hook for ExtSIOp.
OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {}

/// Cast-compatibility hook for ExtSIOp.
bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {}

/// Registers ExtSIOp canonicalization patterns into `patterns`.
void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {}

/// Op verifier hook for ExtSIOp.
LogicalResult arith::ExtSIOp::verify() {}

// ExtFOp

// NOTE(review): every body in this section is empty in this view and every
// hook is non-void; invoking any of them is undefined behavior until the
// implementations are restored.

/// Fold extension of float constants when there is no information loss due to
/// the difference in fp semantics.
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {}

/// Cast-compatibility hook for ExtFOp.
bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {}

/// Op verifier hook for ExtFOp.
LogicalResult arith::ExtFOp::verify() {}

// TruncIOp

// NOTE(review): every body in this section is empty in this view. For the
// non-void hooks, flowing off the end is undefined behavior if they are
// invoked. The canonicalization hook registers nothing as written.

/// Fold hook for TruncIOp.
OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {}

/// Cast-compatibility hook for TruncIOp.
bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {}

/// Registers TruncIOp canonicalization patterns into `patterns`.
void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {}

/// Op verifier hook for TruncIOp.
LogicalResult arith::TruncIOp::verify() {}

// TruncFOp

// NOTE(review): every body in this section is empty in this view and every
// hook is non-void; invoking any of them is undefined behavior until the
// implementations are restored.

/// Perform safe const propagation for truncf, i.e., only propagate if FP value
/// can be represented without precision loss.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {}

/// Cast-compatibility hook for TruncFOp.
bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {}

/// Op verifier hook for TruncFOp.
LogicalResult arith::TruncFOp::verify() {}

// AndIOp

/// Registers AndIOp canonicalization patterns into `patterns`. The body is
/// empty here, so as written nothing is registered.
void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {}

// OrIOp

/// Registers OrIOp canonicalization patterns into `patterns`. The body is
/// empty here, so as written nothing is registered.
void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {}

// Verifiers for casts between integers and floats.

// NOTE(review): every body in this section is empty in this view and every
// function is non-void; invoking any of them is undefined behavior until the
// implementations are restored.

/// Shared compatibility check for int<->float casts. `From` and `To`
/// presumably name the source and destination element type classes —
/// confirm.
template <typename From, typename To>
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {}


/// Cast-compatibility hook for UIToFPOp.
bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {}

/// Fold hook for UIToFPOp.
OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {}


/// Cast-compatibility hook for SIToFPOp.
bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {}

/// Fold hook for SIToFPOp.
OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {}


/// Cast-compatibility hook for FPToUIOp.
bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {}

/// Fold hook for FPToUIOp.
OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {}


/// Cast-compatibility hook for FPToSIOp.
bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {}

/// Fold hook for FPToSIOp.
OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {}

// IndexCastOp

// NOTE(review): every body in this section is empty in this view. For the
// non-void functions, flowing off the end is undefined behavior if they are
// invoked. The canonicalization hooks register nothing as written.

/// Presumably the shared compatibility check for index <-> integer casts,
/// used by IndexCastOp and IndexCastUIOp — confirm.
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {}

/// Cast-compatibility hook for IndexCastOp.
bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
                                           TypeRange outputs) {}

/// Fold hook for IndexCastOp.
OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {}

/// Registers IndexCastOp canonicalization patterns into `patterns`.
void arith::IndexCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {}

// IndexCastUIOp

/// Cast-compatibility hook for IndexCastUIOp.
bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
                                             TypeRange outputs) {}

/// Fold hook for IndexCastUIOp.
OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {}

/// Registers IndexCastUIOp canonicalization patterns into `patterns`.
void arith::IndexCastUIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {}

// BitcastOp

// NOTE(review): every body in this section is empty in this view. For the
// non-void hooks, flowing off the end is undefined behavior if they are
// invoked. The canonicalization hook registers nothing as written.

/// Cast-compatibility hook for BitcastOp.
bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {}

/// Fold hook for BitcastOp.
OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {}

/// Registers BitcastOp canonicalization patterns into `patterns`.
void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {}

// CmpIOp

/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
/// This is a qualified definition, so it must already be declared in
/// namespace mlir::arith (presumably in Arith.h).
/// NOTE(review): this and every other body in this section is empty in this
/// view. For the non-void functions, flowing off the end is undefined
/// behavior if they are called. The canonicalization hook registers nothing
/// as written.
bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
                                    const APInt &lhs, const APInt &rhs) {}

/// Returns true if the predicate is true for two equal operands.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {}

/// Presumably the bit width of integer-like type `t`, or std::nullopt for
/// other types — confirm.
static std::optional<int64_t> getIntegerWidth(Type t) {}

/// Fold hook for CmpIOp.
OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {}

/// Registers CmpIOp canonicalization patterns into `patterns`.
void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {}

// CmpFOp

/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
/// comparison predicates.
/// This is a qualified definition, so it must already be declared in
/// namespace mlir::arith (presumably in Arith.h).
/// NOTE(review): this and the fold hook below have empty bodies in this view;
/// flowing off the end of a non-void function is undefined behavior if they
/// are called.
bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
                                    const APFloat &lhs, const APFloat &rhs) {}

/// Fold hook for CmpFOp.
OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {}

/// Rewrite pattern on CmpFOp (presumably rewrites a float compare against an
/// int-to-fp converted value and a constant — inferred from the name,
/// confirm).
/// NOTE(review): the class body is empty in this view. It neither inherits
/// OpRewritePattern's constructors nor overrides matchAndRewrite, so
/// presumably its members were elided. `CmpFOp` is used unqualified here,
/// which relies on a `using namespace` for mlir::arith not visible in this
/// view.
class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {};

/// Registers CmpFOp canonicalization patterns into `patterns`. The body is
/// empty here, so as written nothing is registered.
void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {}

// SelectOp

//  select %arg, %c1, %c0 => extui %arg
// NOTE(review): the struct body is empty in this view. It neither inherits
// OpRewritePattern's constructors nor overrides matchAndRewrite, so
// presumably its members were elided.
struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {};

// NOTE(review): every function body below is empty in this view. For the
// non-void hooks (fold, parse, verify), flowing off the end is undefined
// behavior if they are invoked. The void hooks do nothing as written.

/// Registers SelectOp canonicalization patterns into `results`.
void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {}

/// Fold hook for SelectOp.
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {}

/// Custom assembly parser for SelectOp.
/// NOTE(review): unlike its neighbours this definition is not qualified with
/// `arith::`. It relies on a `using namespace` for mlir::arith that is not
/// visible in this view — confirm.
ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {}

/// Custom assembly printer for SelectOp.
void arith::SelectOp::print(OpAsmPrinter &p) {}

/// Op verifier hook for SelectOp.
LogicalResult arith::SelectOp::verify() {}
// ShLIOp

// NOTE(review): every fold hook in this section has an empty body in this
// view; flowing off the end of a function returning OpFoldResult is undefined
// behavior if the hook is invoked.

/// Fold hook for ShLIOp.
OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {}

// ShRUIOp

/// Fold hook for ShRUIOp.
OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {}

// ShRSIOp

/// Fold hook for ShRSIOp.
OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {}

// Atomic Enum

// NOTE(review): every function in this section is non-void and has an empty
// body in this view; calling any of them is undefined behavior until the
// implementations are restored. All are qualified definitions, so each must
// already be declared in namespace mlir::arith (presumably in Arith.h).

/// Returns the identity value attribute associated with an AtomicRMWKind op.
/// `useOnlyFiniteValue` presumably selects finite identities instead of
/// infinities for min/max-style kinds — confirm.
TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                            OpBuilder &builder, Location loc,
                                            bool useOnlyFiniteValue) {}

/// Return the identity numeric value associated with the given op.
std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {}

/// Returns the identity value associated with an AtomicRMWKind op.
/// (Presumably materializes `getIdentityValueAttr` as a Value via `builder`
/// at `loc` — confirm.)
Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
                                    OpBuilder &builder, Location loc,
                                    bool useOnlyFiniteValue) {}

/// Return the value obtained by applying the reduction operation kind
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
                                  Location loc, Value lhs, Value rhs) {}

// TableGen'd op method definitions

#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"

// TableGen'd enum attribute definitions

#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"