// llvm/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

//===- EmulateNarrowType.cpp - Narrow type emulation ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Converts a memref::ReinterpretCastOp to the converted type. The result
/// MemRefType of the old op must have a rank and stride of 1, with static
/// offset and size. The number of bits in the offset must evenly divide the
/// bitwidth of the new converted type.
///
/// "Rank of 1" presumably means at most one-dimensional (rank 0 or 1); the
/// ConvertMemRefReinterpretCast pattern states that only 0-D and 1-D outputs
/// are supported — TODO confirm against the implementation.
///
/// \param rewriter  Conversion rewriter, presumably used to replace `op`.
/// \param adaptor   Operand adaptor for `op`.
/// \param op        The reinterpret_cast being emulated.
/// \param newTy     The narrow-type-converted result type.
/// \returns a LogicalResult signalling whether the conversion applied.
///
/// NOTE(review): the body is empty in this revision. Flowing off the end of a
/// function that returns `LogicalResult` is undefined behavior if it is ever
/// called. Presumably the implementation was stripped; TODO restore it.
static LogicalResult
convertCastingOp(ConversionPatternRewriter &rewriter,
                 memref::ReinterpretCastOp::Adaptor adaptor,
                 memref::ReinterpretCastOp op, MemRefType newTy) {}

/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
/// treated as an array of elements of width `sourceBits`.
/// Returns the bit offset of the value at position `srcIdx` within its
/// containing `targetBits`-wide word. For example, if `sourceBits` equals 4
/// and `targetBits` equals 8, the x-th element is located at bit (x % 2) * 4,
/// because one i8 holds two 4-bit elements.
///
/// Generalizing the example, the offset is presumably
/// (srcIdx % (targetBits / sourceBits)) * sourceBits, which assumes
/// `targetBits` is a multiple of `sourceBits` — TODO confirm.
///
/// NOTE(review): the body is empty in this revision. Flowing off the end of a
/// function that returns `Value` is undefined behavior if it is ever called.
/// TODO restore the implementation.
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {}

/// When writing a subbyte size, masked bitwise operations are used to only
/// modify the relevant bits. This function returns an AND mask for clearing
/// the destination bits in a subbyte write. E.g., when writing to the second
/// i4 in an i32, 0xFFFFFF0F is created: every bit of the `dstBits`-wide word
/// is set except the `srcBits` bits of the element being written.
///
/// `bitwidthOffset` is presumably the bit position of that element inside the
/// word (e.g. as produced by getOffsetForBitwidth), and `linearizedIndices`
/// its linearized index — TODO confirm; neither use is visible here.
///
/// NOTE(review): the body is empty in this revision. Flowing off the end of a
/// function that returns `Value` is undefined behavior if it is ever called.
/// TODO restore the implementation.
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {}

/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
/// the returned index has the granularity of `dstBits`.
///
/// For example, with `srcBits` = 4 and `dstBits` = 8, two source elements
/// share one destination element. The index is then presumably divided by 2,
/// rounding down — TODO confirm against the implementation.
///
/// NOTE(review): the body is empty in this revision. Flowing off the end of a
/// function that returns `Value` is undefined behavior if it is ever called.
/// TODO restore the implementation.
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {}

/// Presumably linearizes the multi-dimensional `indices` into `memref` into a
/// single index counted in `srcBits`-wide elements — TODO confirm. The body
/// is not present in this revision.
///
/// NOTE(review): the body is empty. Flowing off the end of a function that
/// returns `OpFoldResult` is undefined behavior if it is ever called.
/// Separately, taking `indices` as `const SmallVector<OpFoldResult> &` forces
/// callers to materialize a SmallVector; `ArrayRef<OpFoldResult>` would
/// accept any contiguous range.
static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {}

namespace {

// File-local conversion patterns for memref narrow-type emulation.
//
// NOTE(review): every pattern struct below has an empty body in this
// revision. None declares a constructor (inherited or otherwise) or a
// `matchAndRewrite` override, so none of them rewrites anything as written.
// Without a constructor they presumably also cannot be built from the
// (type converter, context) arguments conversion patterns are normally
// registered with. Presumably the implementations were stripped from this
// copy of the file; TODO restore them.

//===----------------------------------------------------------------------===//
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

/// Conversion pattern for allocation-like memref ops; the op class is the
/// `OpTy` template argument (presumably memref::AllocOp / memref::AllocaOp —
/// TODO confirm at the registration site).
template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {};

//===----------------------------------------------------------------------===//
// ConvertMemRefAssumeAlignment
//===----------------------------------------------------------------------===//

/// Conversion pattern for memref::AssumeAlignmentOp.
struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {};

//===----------------------------------------------------------------------===//
// ConvertMemRefCopy
//===----------------------------------------------------------------------===//

/// Conversion pattern for memref::CopyOp.
struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {};

//===----------------------------------------------------------------------===//
// ConvertMemRefDealloc
//===----------------------------------------------------------------------===//

/// Conversion pattern for memref::DeallocOp.
struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {};

//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//

/// Conversion pattern for memref::LoadOp.
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {};

//===----------------------------------------------------------------------===//
// ConvertMemRefMemorySpaceCast
//===----------------------------------------------------------------------===//

/// Conversion pattern for memref::MemorySpaceCastOp.
struct ConvertMemRefMemorySpaceCast final
    : OpConversionPattern<memref::MemorySpaceCastOp> {};

//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Output types should be at most one dimensional, so only the 0 or 1
/// dimensional cases are supported.
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {};

//===----------------------------------------------------------------------===//
// ConvertMemrefStore
//===----------------------------------------------------------------------===//

/// Conversion pattern for memref::StoreOp.
///
/// NOTE(review): this name spells `Memref` while every sibling pattern uses
/// `MemRef`. Consider renaming to `ConvertMemRefStore` and updating the
/// registration site at the same time.
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {};

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Emulating narrow ints on subview has limited support: only static offsets
/// and sizes with a stride of 1 are supported. Ideally, the subview should be
/// folded away before running narrow type emulation, and this pattern should
/// only run for cases that can't be folded.
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {};

//===----------------------------------------------------------------------===//
// ConvertMemRefCollapseShape
//===----------------------------------------------------------------------===//

/// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
/// that we flatten memrefs to a single dimension as part of the emulation and
/// there is no dimension to collapse any further.
struct ConvertMemRefCollapseShape final
    : OpConversionPattern<memref::CollapseShapeOp> {};

//===----------------------------------------------------------------------===//
// ConvertMemRefExpandShape
//===----------------------------------------------------------------------===//

/// Emulating a `memref.expand_shape` becomes a no-op after emulation given
/// that we flatten memrefs to a single dimension as part of the emulation and
/// the expansion would just have been undone.
struct ConvertMemRefExpandShape final
    : OpConversionPattern<memref::ExpandShapeOp> {};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

/// Intended (per its name) to add the memref narrow-type emulation patterns,
/// driven by `typeConverter`, to `patterns`.
///
/// NOTE(review): the body is empty in this revision, so `patterns` is left
/// unchanged. None of the file-local conversion patterns defined in this file
/// is ever registered. TODO confirm whether the registration was stripped.
void memref::populateMemRefNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {}

/// Presumably computes the shape of `ty` after it is flattened to a single
/// dimension, with its element count rescaled from `srcBits`-wide to
/// `dstBits`-wide elements — TODO confirm. The body is not present in this
/// revision.
///
/// NOTE(review): the body is empty. Flowing off the end of a function that
/// returns `SmallVector<int64_t>` is undefined behavior if it is ever called.
/// Also, `srcBits`/`dstBits` are `int` here (as in getOffsetForBitwidth),
/// whereas most other helpers in this file take `int64_t`.
static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {}

/// Intended (per its name) to register, on `typeConverter`, the type
/// conversions used by memref narrow-type emulation.
///
/// NOTE(review): the body is empty in this revision, so `typeConverter` is
/// left unchanged, and the file-local getLinearizedShape helper is unused.
/// TODO confirm whether the implementation was stripped.
void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {}