//===- IntNarrowing.cpp - Integer bitwidth reduction optimizations --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstdint>

// Pulls in the TableGen-generated pass base class
// (`impl::ArithIntNarrowingBase`) used by `ArithIntNarrowingPass`.
namespace mlir::arith {
#define GEN_PASS_DEF_ARITHINTNARROWING
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith

namespace mlir::arith {
namespace {
//===----------------------------------------------------------------------===//
// Common Helpers
//===----------------------------------------------------------------------===//

/// The base for integer bitwidth narrowing patterns.
/// A rewrite pattern rooted at `SourceOp`. The other narrowing patterns in
/// this file derive from it, either directly or through
/// `BinaryOpNarrowingPattern`.
template <typename SourceOp>
struct NarrowingPattern : OpRewritePattern<SourceOp> { … };

/// Returns the integer bitwidth required to represent `type`.
/// Returns failure when no such width can be determined. This is presumably
/// for types that are neither integers nor shaped types of integers
/// (TODO confirm against the body).
FailureOr<unsigned> calculateBitsRequired(Type type) { … }

/// Distinguishes the two integer extension kinds this file handles: sign
/// extension (`arith::ExtSIOp`) and zero extension (`arith::ExtUIOp`).
/// NOTE(review): enumerator names are not visible here.
enum class ExtensionKind { … };

/// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts away
/// the exact op type. Exposes helper functions to query the types, operands,
/// and the result. This is so that we can handle both extension kinds without
/// needing to use templates or branching.
class ExtensionOp { … };

/// Returns the integer bitwidth required to represent the constant `value`.
/// `lookThroughExtension` selects how `value` is interpreted. Presumably it is
/// treated as signed for sign extension and unsigned for zero extension
/// (TODO confirm against the body).
unsigned calculateBitsRequired(const APInt &value,
                               ExtensionKind lookThroughExtension) { … }

/// Returns the integer bitwidth required to represent `value`.
/// Looks through either sign- or zero-extension as specified by
/// `lookThroughExtension`. Returns failure if no width can be determined.
FailureOr<unsigned> calculateBitsRequired(Value value,
                                          ExtensionKind lookThroughExtension) { … }

/// Base pattern for arith binary ops.
/// Example:
/// ```
/// %lhs = arith.extsi %a : i8 to i32
/// %rhs = arith.extsi %b : i8 to i32
/// %r = arith.addi %lhs, %rhs : i32
/// ==>
/// %lhs = arith.extsi %a : i8 to i16
/// %rhs = arith.extsi %b : i8 to i16
/// %add = arith.addi %lhs, %rhs : i16
/// %r = arith.extsi %add : i16 to i32
/// ```
template <typename BinaryOp>
struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> { … };

//===----------------------------------------------------------------------===//
// AddIOp Pattern
//===----------------------------------------------------------------------===//

/// `BinaryOpNarrowingPattern` instance for `arith::AddIOp`.
struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> { … };

//===----------------------------------------------------------------------===//
// SubIOp Pattern
//===----------------------------------------------------------------------===//

/// `BinaryOpNarrowingPattern` instance for `arith::SubIOp`.
struct SubIPattern final : BinaryOpNarrowingPattern<arith::SubIOp> { … };

//===----------------------------------------------------------------------===//
// MulIOp Pattern
//===----------------------------------------------------------------------===//

/// `BinaryOpNarrowingPattern` instance for `arith::MulIOp`.
struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> { … };

//===----------------------------------------------------------------------===//
// DivSIOp Pattern
//===----------------------------------------------------------------------===//

/// `BinaryOpNarrowingPattern` instance for `arith::DivSIOp`.
struct DivSIPattern final : BinaryOpNarrowingPattern<arith::DivSIOp> { … };
//===----------------------------------------------------------------------===// // DivUIOp Pattern //===----------------------------------------------------------------------===// struct DivUIPattern final : BinaryOpNarrowingPattern<arith::DivUIOp> { … }; //===----------------------------------------------------------------------===// // Min/Max Patterns //===----------------------------------------------------------------------===// template <typename MinMaxOp, ExtensionKind Kind> struct MinMaxPattern final : BinaryOpNarrowingPattern<MinMaxOp> { … }; MaxSIPattern; MaxUIPattern; MinSIPattern; MinUIPattern; //===----------------------------------------------------------------------===// // *IToFPOp Patterns //===----------------------------------------------------------------------===// template <typename IToFPOp, ExtensionKind Extension> struct IToFPPattern final : NarrowingPattern<IToFPOp> { … }; SIToFPPattern; UIToFPPattern; //===----------------------------------------------------------------------===// // Index Cast Patterns //===----------------------------------------------------------------------===// // These rely on the `ValueBounds` interface for index values. For example, we // can often statically tell index value bounds of loop induction variables. 
/// Narrowing pattern for an index cast op `CastOp`. `Kind` is the extension
/// kind the instantiation works with.
template <typename CastOp, ExtensionKind Kind>
struct IndexCastPattern final : NarrowingPattern<CastOp> { … };

// NOTE(review): these are not valid C++ as written. They look like truncated
// aliases of `IndexCastPattern<...>`, one per extension kind.
// TODO: confirm and restore the full declarations.
IndexCastSIPattern;
IndexCastUIPattern;

//===----------------------------------------------------------------------===//
// Patterns to Commute Extension Ops
//===----------------------------------------------------------------------===//

/// Commutes extension ops with `vector::BroadcastOp`.
struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> { … };

/// Commutes extension ops with `vector::ExtractOp`.
struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> { … };

/// Commutes extension ops with `vector::ExtractElementOp`.
struct ExtensionOverExtractElement final
    : NarrowingPattern<vector::ExtractElementOp> { … };

/// Commutes extension ops with `vector::ExtractStridedSliceOp`.
struct ExtensionOverExtractStridedSlice final
    : NarrowingPattern<vector::ExtractStridedSliceOp> { … };

/// Base pattern shared by the vector insertion-op narrowing patterns. It is
/// instantiated below for `vector::InsertOp`, `vector::InsertElementOp`, and
/// `vector::InsertStridedSliceOp`.
template <typename InsertionOp>
struct ExtensionOverInsertionPattern : NarrowingPattern<InsertionOp> { … };

/// Insertion-op narrowing for `vector::InsertOp`.
struct ExtensionOverInsert final
    : ExtensionOverInsertionPattern<vector::InsertOp> { … };

/// Insertion-op narrowing for `vector::InsertElementOp`.
struct ExtensionOverInsertElement final
    : ExtensionOverInsertionPattern<vector::InsertElementOp> { … };

/// Insertion-op narrowing for `vector::InsertStridedSliceOp`.
struct ExtensionOverInsertStridedSlice final
    : ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> { … };

/// Commutes extension ops with `vector::ShapeCastOp`.
struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> { … };

/// Commutes extension ops with `vector::TransposeOp`.
struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> { … };

/// Commutes extension ops with `vector::FlatTransposeOp`.
struct ExtensionOverFlatTranspose final
    : NarrowingPattern<vector::FlatTransposeOp> { … };

//===----------------------------------------------------------------------===//
// Pass Definitions
//===----------------------------------------------------------------------===//

/// The integer-narrowing pass. Its base class `impl::ArithIntNarrowingBase` is
/// generated via the `GEN_PASS_DEF_ARITHINTNARROWING` include at the top of
/// this file.
struct ArithIntNarrowingPass final
    : impl::ArithIntNarrowingBase<ArithIntNarrowingPass> { … };

} // namespace

//===----------------------------------------------------------------------===//
// Public API
//===----------------------------------------------------------------------===//

/// Public entry point for the integer-narrowing patterns. Presumably it adds
/// the patterns defined in this file to `patterns`, configured by `options`
/// (TODO confirm against the body).
void populateArithIntNarrowingPatterns(
    RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) { … }

} // namespace mlir::arith