//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/APInt.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" #include <cassert> namespace mlir::arith { #define GEN_PASS_DEF_ARITHEMULATEWIDEINT #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" } // namespace mlir::arith using namespace mlir; //===----------------------------------------------------------------------===// // Common Helper Functions //===----------------------------------------------------------------------===// /// Returns N bottom and N top bits from `value`, where N = `newBitWidth`. /// Treats `value` as a 2*N bits-wide integer. /// The bottom bits are returned in the first pair element, while the top bits /// in the second one. static std::pair<APInt, APInt> getHalves(const APInt &value, unsigned newBitWidth) { … } /// Returns the type with the last (innermost) dimension reduced to x1. 
/// Scalarizes 1D vector inputs to match how we extract/insert vector values, /// e.g.: /// - vector<3x2xi16> --> vector<3x1xi16> /// - vector<2xi16> --> i16 static Type reduceInnermostDim(VectorType type) { … } /// Extracts the `input` vector slice with elements at the last dimension offset /// by `lastOffset`. Returns a value of vector type with the last dimension /// reduced to x1 or fully scalarized, e.g.: /// - vector<3x2xi16> --> vector<3x1xi16> /// - vector<2xi16> --> i16 static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value input, int64_t lastOffset) { … } /// Extracts two vector slices from the `input` whose type is `vector<...x2T>`, /// with the first element at offset 0 and the second element at offset 1. static std::pair<Value, Value> extractLastDimHalves(ConversionPatternRewriter &rewriter, Location loc, Value input) { … } /// Performs a vector shape cast to drop the trailing x1 dimension. If the /// `input` is a scalar, this is a noop. static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input) { … } /// Performs a vector shape cast to append an x1 dimension. If the /// `input` is a scalar, this is a noop. static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, Value input) { … } /// Inserts the `source` vector slice into the `dest` vector at offset /// `lastOffset` in the last dimension. `source` can be a scalar when `dest` is /// a 1D vector. static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, Location loc, Value source, Value dest, int64_t lastOffset) { … } /// Constructs a new vector of type `resultType` by creating a series of /// insertions of `resultComponents`, each at the next offset of the last vector /// dimension. /// When all `resultComponents` are scalars, the result type is `vector<NxT>`; /// when `resultComponents` are `vector<...x1xT>`s, the result type is /// `vector<...xNxT>`, where `N` is the number of `resultComponents`. 
static Value constructResultVector(ConversionPatternRewriter &rewriter, Location loc, VectorType resultType, ValueRange resultComponents) { … } namespace { //===----------------------------------------------------------------------===// // ConvertConstant //===----------------------------------------------------------------------===// struct ConvertConstant final : OpConversionPattern<arith::ConstantOp> { … }; //===----------------------------------------------------------------------===// // ConvertAddI //===----------------------------------------------------------------------===// struct ConvertAddI final : OpConversionPattern<arith::AddIOp> { … }; //===----------------------------------------------------------------------===// // ConvertBitwiseBinary //===----------------------------------------------------------------------===// /// Conversion pattern template for bitwise binary ops, e.g., `arith.andi`. template <typename BinaryOp> struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> { … }; //===----------------------------------------------------------------------===// // ConvertCmpI //===----------------------------------------------------------------------===// /// Returns the matching unsigned version of the given predicate `pred`, or the /// same predicate if `pred` is not a signed predicate. 
static arith::CmpIPredicate toUnsignedPredicate(arith::CmpIPredicate pred) { … } struct ConvertCmpI final : OpConversionPattern<arith::CmpIOp> { … }; //===----------------------------------------------------------------------===// // ConvertMulI //===----------------------------------------------------------------------===// struct ConvertMulI final : OpConversionPattern<arith::MulIOp> { … }; //===----------------------------------------------------------------------===// // ConvertExtSI //===----------------------------------------------------------------------===// struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> { … }; //===----------------------------------------------------------------------===// // ConvertExtUI //===----------------------------------------------------------------------===// struct ConvertExtUI final : OpConversionPattern<arith::ExtUIOp> { … }; //===----------------------------------------------------------------------===// // ConvertMaxMin //===----------------------------------------------------------------------===// template <typename SourceOp, arith::CmpIPredicate CmpPred> struct ConvertMaxMin final : OpConversionPattern<SourceOp> { … }; //===----------------------------------------------------------------------===// // Convert IndexCast ops //===----------------------------------------------------------------------===// /// Returns true iff the type is `index` or `vector<...index>`. 
static bool isIndexOrIndexVector(Type type) { … } template <typename CastOp> struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> { … }; template <typename CastOp, typename ExtensionOp> struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> { … }; //===----------------------------------------------------------------------===// // ConvertSelect //===----------------------------------------------------------------------===// struct ConvertSelect final : OpConversionPattern<arith::SelectOp> { … }; //===----------------------------------------------------------------------===// // ConvertShLI //===----------------------------------------------------------------------===// struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> { … }; //===----------------------------------------------------------------------===// // ConvertShRUI //===----------------------------------------------------------------------===// struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> { … }; //===----------------------------------------------------------------------===// // ConvertShRSI //===----------------------------------------------------------------------===// struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> { … }; //===----------------------------------------------------------------------===// // ConvertSIToFP //===----------------------------------------------------------------------===// struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> { … }; //===----------------------------------------------------------------------===// // ConvertUIToFP //===----------------------------------------------------------------------===// struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> { … }; //===----------------------------------------------------------------------===// // ConvertTruncI //===----------------------------------------------------------------------===// struct ConvertTruncI final : 
OpConversionPattern<arith::TruncIOp> { … }; //===----------------------------------------------------------------------===// // ConvertVectorPrint //===----------------------------------------------------------------------===// struct ConvertVectorPrint final : OpConversionPattern<vector::PrintOp> { … }; //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// struct EmulateWideIntPass final : arith::impl::ArithEmulateWideIntBase<EmulateWideIntPass> { … }; } // end anonymous namespace //===----------------------------------------------------------------------===// // Public Interface Definition //===----------------------------------------------------------------------===// arith::WideIntEmulationConverter::WideIntEmulationConverter( unsigned widestIntSupportedByTarget) : … { … } void arith::populateArithWideIntEmulationPatterns( const WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) { … }