//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include <cassert> #include <cstdint> #include <functional> #include <utility> #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/FloatingPointMode.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" usingnamespacemlir; usingnamespacemlir::arith; //===----------------------------------------------------------------------===// // Pattern helpers //===----------------------------------------------------------------------===// static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref<APInt(const APInt &, const APInt &)> binFn) { … } static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { … } static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { … } static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { … } // Merge overflow flags from 2 ops, selecting the most conservative combination. 
static IntegerOverflowFlagsAttr mergeOverflowFlags(IntegerOverflowFlagsAttr val1, IntegerOverflowFlagsAttr val2) { … } /// Invert an integer comparison predicate. arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) { … } /// Equivalent to /// convertRoundingModeToLLVM(convertArithRoundingModeToLLVM(roundingMode)). /// /// Not possible to implement as chain of calls as this would introduce a /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend /// on the LLVM dialect and on translation to LLVM. static llvm::RoundingMode convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) { … } static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { … } static int64_t getScalarOrElementWidth(Type type) { … } static int64_t getScalarOrElementWidth(Value value) { … } static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) { … } static Attribute getBoolAttribute(Type type, bool value) { … } //===----------------------------------------------------------------------===// // TableGen'd canonicalization patterns //===----------------------------------------------------------------------===// namespace { #include "ArithCanonicalization.inc" } // namespace //===----------------------------------------------------------------------===// // Common helpers //===----------------------------------------------------------------------===// /// Return the type of the same shape (scalar, vector or tensor) containing i1. static Type getI1SameShape(Type type) { … } //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// void arith::ConstantOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { … } /// TODO: disallow arith.constant to return anything other than signless integer /// or float like. 
LogicalResult arith::ConstantOp::verify() { … } bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { … } ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value, Type type, Location loc) { … } OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { … } void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width) { … } void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, int64_t value, Type type) { … } bool arith::ConstantIntOp::classof(Operation *op) { … } void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type) { … } bool arith::ConstantFloatOp::classof(Operation *op) { … } void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, int64_t value) { … } bool arith::ConstantIndexOp::classof(Operation *op) { … } //===----------------------------------------------------------------------===// // AddIOp //===----------------------------------------------------------------------===// OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) { … } void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // AddUIExtendedOp //===----------------------------------------------------------------------===// std::optional<SmallVector<int64_t, 4>> arith::AddUIExtendedOp::getShapeForUnroll() { … } // Returns the overflow bit, assuming that `sum` is the result of unsigned // addition of `operand` and another number. 
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) { … } LogicalResult arith::AddUIExtendedOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { … } void arith::AddUIExtendedOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // SubIOp //===----------------------------------------------------------------------===// OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) { … } void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // MulIOp //===----------------------------------------------------------------------===// OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) { … } void arith::MulIOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { … } void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // MulSIExtendedOp //===----------------------------------------------------------------------===// std::optional<SmallVector<int64_t, 4>> arith::MulSIExtendedOp::getShapeForUnroll() { … } LogicalResult arith::MulSIExtendedOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { … } void arith::MulSIExtendedOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // MulUIExtendedOp //===----------------------------------------------------------------------===// std::optional<SmallVector<int64_t, 4>> arith::MulUIExtendedOp::getShapeForUnroll() { … } LogicalResult arith::MulUIExtendedOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { … } void 
arith::MulUIExtendedOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // DivUIOp //===----------------------------------------------------------------------===// OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) { … } /// Returns whether an unsigned division by `divisor` is speculatable. static Speculation::Speculatability getDivUISpeculatability(Value divisor) { … } Speculation::Speculatability arith::DivUIOp::getSpeculatability() { … } //===----------------------------------------------------------------------===// // DivSIOp //===----------------------------------------------------------------------===// OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) { … } /// Returns whether a signed division by `divisor` is speculatable. This /// function conservatively assumes that all signed division by -1 are not /// speculatable. static Speculation::Speculatability getDivSISpeculatability(Value divisor) { … } Speculation::Speculatability arith::DivSIOp::getSpeculatability() { … } //===----------------------------------------------------------------------===// // Ceil and floor division folding helpers //===----------------------------------------------------------------------===// static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow) { … } //===----------------------------------------------------------------------===// // CeilDivUIOp //===----------------------------------------------------------------------===// OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) { … } Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() { … } //===----------------------------------------------------------------------===// // CeilDivSIOp //===----------------------------------------------------------------------===// OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) { … } 
Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() { … } //===----------------------------------------------------------------------===// // FloorDivSIOp //===----------------------------------------------------------------------===// OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // RemUIOp //===----------------------------------------------------------------------===// OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // RemSIOp //===----------------------------------------------------------------------===// OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// /// Fold `and(a, and(a, b))` to `and(a, b)` static Value foldAndIofAndI(arith::AndIOp op) { … } OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // OrIOp //===----------------------------------------------------------------------===// OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // XOrIOp //===----------------------------------------------------------------------===// OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) { … } void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // NegFOp //===----------------------------------------------------------------------===// OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // 
AddFOp //===----------------------------------------------------------------------===// OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // SubFOp //===----------------------------------------------------------------------===// OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // MaximumFOp //===----------------------------------------------------------------------===// OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // MaxNumFOp //===----------------------------------------------------------------------===// OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // MaxSIOp //===----------------------------------------------------------------------===// OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // MaxUIOp //===----------------------------------------------------------------------===// OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // MinimumFOp //===----------------------------------------------------------------------===// OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // MinNumFOp //===----------------------------------------------------------------------===// OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // MinSIOp //===----------------------------------------------------------------------===// OpFoldResult 
MinSIOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // MinUIOp //===----------------------------------------------------------------------===// OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) { … } void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // DivFOp //===----------------------------------------------------------------------===// OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) { … } void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // RemFOp //===----------------------------------------------------------------------===// OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // Utility functions for verifying cast ops //===----------------------------------------------------------------------===// type_list; /// Returns a non-null type only if the provided type is one of the allowed /// types or one of the allowed shaped types of the allowed types. Returns the /// element type if a valid shaped type is provided. template <typename... ShapedTypes, typename... ElementTypes> static Type getUnderlyingType(Type type, type_list<ShapedTypes...>, type_list<ElementTypes...>) { … } /// Get allowed underlying types for vectors and tensors. template <typename... ElementTypes> static Type getTypeIfLike(Type type) { … } /// Get allowed underlying types for vectors, tensors, and memrefs. template <typename... 
ElementTypes> static Type getTypeIfLikeOrMemRef(Type type) { … } /// Return false if both types are ranked tensor with mismatching encoding. static bool hasSameEncoding(Type typeA, Type typeB) { … } static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { … } //===----------------------------------------------------------------------===// // Verifiers for integer and floating point extension/truncation ops //===----------------------------------------------------------------------===// // Extend ops can only extend to a wider type. template <typename ValType, typename Op> static LogicalResult verifyExtOp(Op op) { … } // Truncate ops can only truncate to a shorter type. template <typename ValType, typename Op> static LogicalResult verifyTruncateOp(Op op) { … } /// Validate a cast that changes the width of a type. template <template <typename> class WidthComparator, typename... ElementTypes> static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) { … } /// Attempts to convert `sourceValue` to an APFloat value with /// `targetSemantics` and `roundingMode`, without any information loss. 
static FailureOr<APFloat> convertFloatValue( APFloat sourceValue, const llvm::fltSemantics &targetSemantics, llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) { … } //===----------------------------------------------------------------------===// // ExtUIOp //===----------------------------------------------------------------------===// OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) { … } bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … } LogicalResult arith::ExtUIOp::verify() { … } //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) { … } bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … } void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { … } LogicalResult arith::ExtSIOp::verify() { … } //===----------------------------------------------------------------------===// // ExtFOp //===----------------------------------------------------------------------===// /// Fold extension of float constants when there is no information loss due to the /// difference in fp semantics.
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) { … } bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … } LogicalResult arith::ExtFOp::verify() { … } //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) { … } bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … } void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { … } LogicalResult arith::TruncIOp::verify() { … } //===----------------------------------------------------------------------===// // TruncFOp //===----------------------------------------------------------------------===// /// Perform safe const propagation for truncf, i.e., only propagate if FP value /// can be represented without precision loss. OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) { … } bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … } LogicalResult arith::TruncFOp::verify() { … } //===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // OrIOp //===----------------------------------------------------------------------===// void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // Verifiers for casts between integers and floats. 
//===----------------------------------------------------------------------===// template <typename From, typename To> static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) { … } //===----------------------------------------------------------------------===// // UIToFPOp //===----------------------------------------------------------------------===// bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … } OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // SIToFPOp //===----------------------------------------------------------------------===// bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … } OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // FPToUIOp //===----------------------------------------------------------------------===// bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … } OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // FPToSIOp //===----------------------------------------------------------------------===// bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … } OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) { … } bool arith::IndexCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … } OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) { … } void arith::IndexCastOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { … } 
//===----------------------------------------------------------------------===// // IndexCastUIOp //===----------------------------------------------------------------------===// bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … } OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) { … } void arith::IndexCastUIOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // BitcastOp //===----------------------------------------------------------------------===// bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { … } OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) { … } void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer /// comparison predicates. bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate, const APInt &lhs, const APInt &rhs) { … } /// Returns true if the predicate is true for two equal operands. static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { … } static std::optional<int64_t> getIntegerWidth(Type t) { … } OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) { … } void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // CmpFOp //===----------------------------------------------------------------------===// /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point /// comparison predicates. 
bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, const APFloat &lhs, const APFloat &rhs) { … } OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) { … } class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> { … }; void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { … } //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// // select %arg, %c1, %c0 => extui %arg struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> { … }; void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { … } OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) { … } ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) { … } void arith::SelectOp::print(OpAsmPrinter &p) { … } LogicalResult arith::SelectOp::verify() { … } //===----------------------------------------------------------------------===// // ShLIOp //===----------------------------------------------------------------------===// OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // ShRUIOp //===----------------------------------------------------------------------===// OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // ShRSIOp //===----------------------------------------------------------------------===// OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) { … } //===----------------------------------------------------------------------===// // Atomic Enum //===----------------------------------------------------------------------===// /// Returns the identity value attribute associated with an AtomicRMWKind op. 
TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue) { … } /// Return the identity numeric value associated with the given op. std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) { … } /// Returns the identity value associated with an AtomicRMWKind op. Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue) { … } /// Return the value obtained by applying the reduction operation kind /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs) { … } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc" //===----------------------------------------------------------------------===// // TableGen'd enum attribute definitions //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"