
//===- IntNarrowing.cpp - Integer bitwidth reduction optimizations --------===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstdint>

namespace mlir::arith {
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith

namespace mlir::arith {
namespace {
// Common Helpers

/// Common base class from which every integer bitwidth narrowing pattern in
/// this file derives; `SourceOp` is the op type the pattern is rooted at.
template <typename SourceOp>
struct NarrowingPattern : public OpRewritePattern<SourceOp> {};

/// Returns the integer bitwidth required to represent `type`.
///
/// Shaped types (e.g. vectors) are looked through, so the width of the
/// element type is returned. Returns failure when `type` (or its element
/// type) is not an `IntegerType`, e.g. a float or `index` type.
FailureOr<unsigned> calculateBitsRequired(Type type) {
  assert(type && "expected a non-null type");
  // Treat `vector<4xi32>` the same as `i32`: only the element width matters.
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(type)))
    return intTy.getWidth();

  // Every path must return a value; non-integer types are reported as
  // failure rather than falling off the end of the function.
  return failure();
}

/// Selects which kind of integer extension to reason about / look through:
/// sign-extension (`arith.extsi`) or zero-extension (`arith.extui`).
/// Without named enumerators no caller could designate either kind.
enum class ExtensionKind { Sign, Zero };

/// Abstraction over `arith::ExtSIOp` and `arith::ExtUIOp` that hides which of
/// the two extension ops is being handled, so callers can treat sign- and
/// zero-extension uniformly without templates or branching.
class ExtensionOp {
  // NOTE(review): this definition declares no members; the type/operand/result
  // query helpers the description refers to are not present here.
};

/// Returns the integer bitwidth required to represent `value`.
///
/// NOTE(review): the body below is empty, so control falls off the end of a
/// function returning `unsigned`; calling it is undefined behavior. It needs a
/// real implementation before any caller relies on it.
/// `lookThroughExtension` presumably selects whether `value` is interpreted as
/// sign- or zero-extended when counting required bits — confirm.
unsigned calculateBitsRequired(const APInt &value,
                               ExtensionKind lookThroughExtension) {}

/// Returns the integer bitwidth required to represent `value`.
/// Looks through either sign- or zero-extension as specified by
/// `lookThroughExtension`.
///
/// NOTE(review): the body below is empty, so control falls off the end of a
/// function returning `FailureOr<unsigned>`; calling it is undefined behavior.
/// It needs a real implementation (or at least a `return failure();`).
FailureOr<unsigned> calculateBitsRequired(Value value,
                                          ExtensionKind lookThroughExtension) {}

/// Shared base for narrowing patterns rooted at arith binary ops.
///
/// Intended rewrite, for example:
/// ```
///   %lhs = arith.extsi %a : i8 to i32
///   %rhs = arith.extsi %b : i8 to i32
///   %r = arith.addi %lhs, %rhs : i32
/// ==>
///   %lhs = arith.extsi %a : i8 to i16
///   %rhs = arith.extsi %b : i8 to i16
///   %add = arith.addi %lhs, %rhs : i16
///   %r = arith.extsi %add : i16 to i32
/// ```
template <typename BinaryOp>
struct BinaryOpNarrowingPattern : public NarrowingPattern<BinaryOp> {};

// AddIOp Pattern

/// Binary narrowing pattern rooted at `arith.addi`.
struct AddIPattern final : public BinaryOpNarrowingPattern<arith::AddIOp> {};

// SubIOp Pattern

/// Binary narrowing pattern rooted at `arith.subi`.
struct SubIPattern final : public BinaryOpNarrowingPattern<arith::SubIOp> {};

// MulIOp Pattern

/// Binary narrowing pattern rooted at `arith.muli`.
struct MulIPattern final : public BinaryOpNarrowingPattern<arith::MulIOp> {};

// DivSIOp Pattern

/// Binary narrowing pattern rooted at `arith.divsi`.
struct DivSIPattern final : public BinaryOpNarrowingPattern<arith::DivSIOp> {};

// DivUIOp Pattern

/// Binary narrowing pattern rooted at `arith.divui`.
struct DivUIPattern final : public BinaryOpNarrowingPattern<arith::DivUIOp> {};

// Min/Max Patterns

/// Binary narrowing pattern for an integer min/max op `MinMaxOp`,
/// parameterized by an `ExtensionKind`.
template <typename MinMaxOp, ExtensionKind Kind>
struct MinMaxPattern final : public BinaryOpNarrowingPattern<MinMaxOp> {};

// *IToFPOp Patterns

/// Narrowing pattern for an integer-to-float conversion op `IToFPOp`,
/// parameterized by an `ExtensionKind`.
template <typename IToFPOp, ExtensionKind Extension>
struct IToFPPattern final : public NarrowingPattern<IToFPOp> {};

// Index Cast Patterns

// These rely on the `ValueBounds` interface for index values. For example, we
// can often statically tell index value bounds of loop induction variables.

/// Narrowing pattern for an index cast op `CastOp`, parameterized by an
/// `ExtensionKind`.
template <typename CastOp, ExtensionKind Kind>
struct IndexCastPattern final : public NarrowingPattern<CastOp> {};

// Patterns to Commute Extension Ops

/// Extension-commuting pattern rooted at `vector.broadcast`.
struct ExtensionOverBroadcast final
    : public NarrowingPattern<vector::BroadcastOp> {};

/// Extension-commuting pattern rooted at `vector.extract`.
struct ExtensionOverExtract final
    : public NarrowingPattern<vector::ExtractOp> {};

/// Extension-commuting pattern rooted at `vector.extractelement`.
struct ExtensionOverExtractElement final
    : public NarrowingPattern<vector::ExtractElementOp> {};

/// Extension-commuting pattern rooted at `vector.extract_strided_slice`.
struct ExtensionOverExtractStridedSlice final
    : public NarrowingPattern<vector::ExtractStridedSliceOp> {};

/// Shared base for the vector insertion narrowing patterns
/// (`vector.insert`, `vector.insertelement`, `vector.insert_strided_slice`);
/// `InsertionOp` is the insertion op the pattern is rooted at.
template <typename InsertionOp>
struct ExtensionOverInsertionPattern : public NarrowingPattern<InsertionOp> {};

/// Insertion narrowing pattern rooted at `vector.insert`.
struct ExtensionOverInsert final
    : public ExtensionOverInsertionPattern<vector::InsertOp> {};

/// Insertion narrowing pattern rooted at `vector.insertelement`.
struct ExtensionOverInsertElement final
    : public ExtensionOverInsertionPattern<vector::InsertElementOp> {};

/// Insertion narrowing pattern rooted at `vector.insert_strided_slice`.
struct ExtensionOverInsertStridedSlice final
    : public ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> {};

/// Extension-commuting pattern rooted at `vector.shape_cast`.
struct ExtensionOverShapeCast final
    : public NarrowingPattern<vector::ShapeCastOp> {};

/// Extension-commuting pattern rooted at `vector.transpose`.
struct ExtensionOverTranspose final
    : public NarrowingPattern<vector::TransposeOp> {};

/// Extension-commuting pattern rooted at `vector.flat_transpose`.
struct ExtensionOverFlatTranspose final
    : public NarrowingPattern<vector::FlatTransposeOp> {};

// Pass Definitions

/// The integer narrowing pass. Its base class template
/// `impl::ArithIntNarrowingBase` presumably comes from the generated
/// `Passes.h.inc` included at the top of this file — confirm.
struct ArithIntNarrowingPass final
    : public impl::ArithIntNarrowingBase<ArithIntNarrowingPass> {};
} // namespace

// Public API

/// Entry point for registering the integer narrowing patterns into
/// `patterns`, configured by `options`. As defined here, it adds nothing.
void populateArithIntNarrowingPatterns(
    RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
  // Neither argument is used yet; the casts only silence unused-parameter
  // diagnostics and have no runtime effect.
  (void)patterns;
  (void)options;
}

} // namespace mlir::arith