llvm/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp

//===- IntNarrowing.cpp - Integer bitwidth reduction optimizations --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstdint>

namespace mlir::arith {
#define GEN_PASS_DEF_ARITHINTNARROWING
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith

namespace mlir::arith {
namespace {
//===----------------------------------------------------------------------===//
// Common Helpers
//===----------------------------------------------------------------------===//

/// Root of the narrowing pattern hierarchy: a thin `OpRewritePattern` over the
/// matched op type `SourceOp`, from which every integer bitwidth narrowing
/// pattern in this file derives.
template <typename SourceOp>
struct NarrowingPattern : public OpRewritePattern<SourceOp> {};

/// Returns the integer bitwidth required to represent `type`.
///
/// For shaped types (e.g. vectors) this is the width of the integer element
/// type; for scalars it is the width of `type` itself. Returns failure for any
/// non-`IntegerType` element type, which includes `index`.
FailureOr<unsigned> calculateBitsRequired(Type type) {
  assert(type && "expected a non-null type");
  // Look through shaped types so that `vector<4xi32>` and `i32` agree.
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(type)))
    return intTy.getWidth();

  return failure();
}

enum class ExtensionKind {};

/// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts away
/// the exact op type. Exposes helper functions to query the types, operands,
/// and the result. This is so that we can handle both extension kinds without
/// needing to use templates or branching.
class ExtensionOp {};

/// Returns the integer bitwidth required to represent `value`.
unsigned calculateBitsRequired(const APInt &value,
                               ExtensionKind lookThroughExtension) {}

/// Returns the integer bitwidth required to represent `value`.
/// Looks through either sign- or zero-extension as specified by
/// `lookThroughExtension`.
FailureOr<unsigned> calculateBitsRequired(Value value,
                                          ExtensionKind lookThroughExtension) {}

/// Base pattern for arith binary ops.
/// Example:
/// ```
///   %lhs = arith.extsi %a : i8 to i32
///   %rhs = arith.extsi %b : i8 to i32
///   %r = arith.addi %lhs, %rhs : i32
/// ==>
///   %lhs = arith.extsi %a : i8 to i16
///   %rhs = arith.extsi %b : i8 to i16
///   %add = arith.addi %lhs, %rhs : i16
///   %r = arith.extsi %add : i16 to i32
/// ```
template <typename BinaryOp>
struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {};

//===----------------------------------------------------------------------===//
// AddIOp Pattern
//===----------------------------------------------------------------------===//

struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {};

//===----------------------------------------------------------------------===//
// SubIOp Pattern
//===----------------------------------------------------------------------===//

struct SubIPattern final : BinaryOpNarrowingPattern<arith::SubIOp> {};

//===----------------------------------------------------------------------===//
// MulIOp Pattern
//===----------------------------------------------------------------------===//

struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {};

//===----------------------------------------------------------------------===//
// DivSIOp Pattern
//===----------------------------------------------------------------------===//

struct DivSIPattern final : BinaryOpNarrowingPattern<arith::DivSIOp> {};

//===----------------------------------------------------------------------===//
// DivUIOp Pattern
//===----------------------------------------------------------------------===//

struct DivUIPattern final : BinaryOpNarrowingPattern<arith::DivUIOp> {};

//===----------------------------------------------------------------------===//
// Min/Max Patterns
//===----------------------------------------------------------------------===//

template <typename MinMaxOp, ExtensionKind Kind>
struct MinMaxPattern final : BinaryOpNarrowingPattern<MinMaxOp> {};
MaxSIPattern;
MaxUIPattern;
MinSIPattern;
MinUIPattern;

//===----------------------------------------------------------------------===//
// *IToFPOp Patterns
//===----------------------------------------------------------------------===//

template <typename IToFPOp, ExtensionKind Extension>
struct IToFPPattern final : NarrowingPattern<IToFPOp> {};
SIToFPPattern;
UIToFPPattern;

//===----------------------------------------------------------------------===//
// Index Cast Patterns
//===----------------------------------------------------------------------===//

// These rely on the `ValueBounds` interface for index values. For example, we
// can often statically tell index value bounds of loop induction variables.

template <typename CastOp, ExtensionKind Kind>
struct IndexCastPattern final : NarrowingPattern<CastOp> {};
IndexCastSIPattern;
IndexCastUIPattern;

//===----------------------------------------------------------------------===//
// Patterns to Commute Extension Ops
//===----------------------------------------------------------------------===//

/// Commutes an integer extension op with `vector.broadcast`.
struct ExtensionOverBroadcast final
    : public NarrowingPattern<vector::BroadcastOp> {};

/// Commutes an integer extension op with `vector.extract`.
struct ExtensionOverExtract final
    : public NarrowingPattern<vector::ExtractOp> {};

/// Commutes an integer extension op with `vector.extractelement`.
struct ExtensionOverExtractElement final
    : public NarrowingPattern<vector::ExtractElementOp> {};

/// Commutes an integer extension op with `vector.extract_strided_slice`.
struct ExtensionOverExtractStridedSlice final
    : public NarrowingPattern<vector::ExtractStridedSliceOp> {};

/// Shared base for the patterns that commute an integer extension op with a
/// vector insertion op (`vector.insert`, `vector.insertelement`, and
/// `vector.insert_strided_slice`, as instantiated below).
template <typename InsertionOp>
struct ExtensionOverInsertionPattern : public NarrowingPattern<InsertionOp> {};

/// Insertion-commuting pattern specialized for `vector.insert`.
struct ExtensionOverInsert final
    : public ExtensionOverInsertionPattern<vector::InsertOp> {};

/// Insertion-commuting pattern specialized for `vector.insertelement`.
struct ExtensionOverInsertElement final
    : public ExtensionOverInsertionPattern<vector::InsertElementOp> {};

/// Insertion-commuting pattern specialized for `vector.insert_strided_slice`.
struct ExtensionOverInsertStridedSlice final
    : public ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> {};

/// Commutes an integer extension op with `vector.shape_cast`.
struct ExtensionOverShapeCast final
    : public NarrowingPattern<vector::ShapeCastOp> {};

/// Commutes an integer extension op with `vector.transpose`.
struct ExtensionOverTranspose final
    : public NarrowingPattern<vector::TransposeOp> {};

/// Commutes an integer extension op with `vector.flat_transpose`.
struct ExtensionOverFlatTranspose final
    : public NarrowingPattern<vector::FlatTransposeOp> {};

//===----------------------------------------------------------------------===//
// Pass Definitions
//===----------------------------------------------------------------------===//

/// Integer narrowing pass. Its CRTP base `impl::ArithIntNarrowingBase` is the
/// generated pass boilerplate pulled in via `GEN_PASS_DEF_ARITHINTNARROWING`
/// at the top of this file.
struct ArithIntNarrowingPass final
    : public impl::ArithIntNarrowingBase<ArithIntNarrowingPass> {};
} // namespace

//===----------------------------------------------------------------------===//
// Public API
//===----------------------------------------------------------------------===//

/// Public entry point for registering this file's integer bitwidth narrowing
/// patterns into `patterns` (presumably configured via `options`).
// NOTE(review): the body is empty in this view. As written, nothing is added to
// `patterns` and `options` is unused. The registration code looks elided;
// confirm against the full source.
void populateArithIntNarrowingPatterns(
    RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {}

} // namespace mlir::arith