// llvm/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

//===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <numeric>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE

usingnamespacemlir;
usingnamespacemlir::vector;

/// Trims leading one dimensions from `oldType` and returns the result type.
/// Returns `vector<1xT>` if `oldType` only has one element.
///
/// A leading dimension is dropped only when it is a fixed-size unit dimension;
/// a scalable unit dimension (`[1]`) has an unknown runtime extent and is
/// therefore kept, which also stops the trimming at that position.
///
/// \param oldType The vector type to trim.
/// \returns A vector type with the same element type as `oldType` and the
///          trimmed shape (scalability flags of the kept dimensions are
///          preserved).
static VectorType trimLeadingOneDims(VectorType oldType) {
  ArrayRef<int64_t> oldShape = oldType.getShape();
  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();

  ArrayRef<int64_t> newShape = oldShape;
  ArrayRef<bool> newScalableDims = oldScalableDims;

  // Peel off fixed-size unit dimensions from the front, keeping the shape and
  // the scalability flags in lock-step.
  while (!newShape.empty() && newShape.front() == 1 &&
         !newScalableDims.front()) {
    newShape = newShape.drop_front(1);
    newScalableDims = newScalableDims.drop_front(1);
  }

  // Every dimension was a unit dimension: keep the innermost one so that the
  // result is `vector<1xT>` rather than a rank-0 vector.
  if (newShape.empty()) {
    newShape = oldShape.take_back();
    newScalableDims = oldScalableDims.take_back();
  }

  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

/// Returns a SmallVector of size `rank` containing all zeros.
///
/// \param rank Number of zero entries to produce; expected to be non-negative
///             since it is used as the vector's element count.
/// \returns A vector holding `rank` copies of `0`.
static SmallVector<int64_t> splatZero(int64_t rank) {
  return SmallVector<int64_t>(rank, 0);
}
namespace {

// Casts away leading one dimensions in vector.extract_strided_slice's vector
// input by inserting vector.broadcast.
//
// Rewrite pattern rooted on vector::ExtractStridedSliceOp (via the
// OpRewritePattern base class).
// NOTE(review): the struct body is empty in this view -- neither inherited
// constructors nor a matchAndRewrite override are visible. Confirm the
// implementation was not dropped.
struct CastAwayExtractStridedSliceLeadingOneDim
    : public OpRewritePattern<vector::ExtractStridedSliceOp> {};

// Casts away leading one dimensions in vector.insert_strided_slice's vector
// inputs by inserting vector.broadcast.
//
// Rewrite pattern rooted on vector::InsertStridedSliceOp (via the
// OpRewritePattern base class).
// NOTE(review): the struct body is empty in this view -- neither inherited
// constructors nor a matchAndRewrite override are visible. Confirm the
// implementation was not dropped.
struct CastAwayInsertStridedSliceLeadingOneDim
    : public OpRewritePattern<vector::InsertStridedSliceOp> {};

// Casts away leading one dimensions in vector.insert's vector inputs by
// inserting vector.broadcast.
//
// Rewrite pattern rooted on vector::InsertOp (via the OpRewritePattern base
// class).
// NOTE(review): the struct body is empty in this view -- neither inherited
// constructors nor a matchAndRewrite override are visible. Confirm the
// implementation was not dropped.
struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {};

// Helper that, judging by its name and signature, produces a version of `mask`
// with unit dimensions removed -- presumably for use by the transfer
// read/write patterns in this file; verify against the callers.
//
// Parameters:
//   b           - builder available for creating any new operations.
//   loc         - location to attach to created operations.
//   mask        - the mask value to adjust.
//   newType     - presumably the vector type after leading unit dims were
//                 dropped -- TODO confirm.
//   newMap      - presumably the permutation map matching `newType` -- TODO
//                 confirm.
//   oldMaskType - presumably the vector type of the incoming `mask` -- TODO
//                 confirm.
//
// NOTE(review): the body is empty in this view, so this non-void function
// currently flows off the end without returning a Value. Confirm the
// implementation was not dropped.
static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
                                  VectorType newType, AffineMap newMap,
                                  VectorType oldMaskType) {}

// Turns vector.transfer_read on vector with leading 1 dimensions into
// vector.shape_cast followed by vector.transfer_read on vector without leading
// 1 dimensions.
//
// Rewrite pattern rooted on vector::TransferReadOp (via the OpRewritePattern
// base class).
// NOTE(review): the struct body is empty in this view -- neither inherited
// constructors nor a matchAndRewrite override are visible. Confirm the
// implementation was not dropped.
struct CastAwayTransferReadLeadingOneDim
    : public OpRewritePattern<vector::TransferReadOp> {};

// Turns vector.transfer_write on vector with leading 1 dimensions into
// vector.shape_cast followed by vector.transfer_write on vector without leading
// 1 dimensions.
//
// Rewrite pattern rooted on vector::TransferWriteOp (via the OpRewritePattern
// base class).
// NOTE(review): the struct body is empty in this view -- neither inherited
// constructors nor a matchAndRewrite override are visible. Confirm the
// implementation was not dropped.
struct CastAwayTransferWriteLeadingOneDim
    : public OpRewritePattern<vector::TransferWriteOp> {};

} // namespace

/// Publicly visible (mlir::vector) entry point for casting away leading one
/// dimensions of a vector.contract operation.
///
/// \param contractOp The contraction operation to rewrite.
/// \param maskingOp  Masking operation interface associated with the
///                   contraction; presumably null when `contractOp` is not
///                   masked -- TODO confirm against callers.
/// \param rewriter   Rewriter used to create and replace operations.
/// \returns FailureOr<Value>; presumably failure when the rewrite does not
///          apply and the replacement value otherwise -- TODO confirm.
///
/// NOTE(review): the body is empty in this view, so this non-void function
/// currently flows off the end without returning. Confirm the implementation
/// was not dropped.
FailureOr<Value>
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                               MaskingOpInterface maskingOp,
                                               RewriterBase &rewriter) {}

namespace {

/// Turns vector.contract on vector with leading 1 dimensions into
/// vector.extract followed by vector.contract on vector without leading
/// 1 dimensions. Also performs a transpose of the lhs and rhs operands, if
/// required, prior to the extract.
///
/// Pattern rooted on vector::ContractionOp via MaskableOpRewritePattern.
/// NOTE(review): the struct body is empty in this view -- neither inherited
/// constructors nor a rewrite hook override are visible. Confirm the
/// implementation was not dropped.
struct CastAwayContractionLeadingOneDim
    : public MaskableOpRewritePattern<vector::ContractionOp> {};

/// Looks at elementwise operations on vectors with at least one leading
/// dimension equal to 1, e.g. vector<1x[4]x1xf32> (but not
/// vector<2x[4]x1xf32>), casts away the leading one dimensions (_plural_), and
/// then broadcasts the results back to the original type.
///
/// Example before:
///     %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
/// Example after:
///    %2 = arith.mulf %0, %1 : vector<4x1xf32>
///    %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32>
///
/// Does support scalable vectors.
///
/// Derives from the generic RewritePattern rather than an op-specific
/// OpRewritePattern.
/// NOTE(review): the class body is empty in this view -- no constructor or
/// matchAndRewrite override is visible. Confirm the implementation was not
/// dropped.
class CastAwayElementwiseLeadingOneDim : public RewritePattern {};

// Drops leading 1 dimensions from vector.constant_mask and inserts a
// vector.broadcast back to the original shape.
//
// Rewrite pattern rooted on vector::ConstantMaskOp (via the OpRewritePattern
// base class).
// NOTE(review): the struct body is empty in this view -- neither inherited
// constructors nor a matchAndRewrite override are visible. Confirm the
// implementation was not dropped.
struct CastAwayConstantMaskLeadingOneDim
    : public OpRewritePattern<vector::ConstantMaskOp> {};

} // namespace

/// Public (mlir::vector) hook for collecting the "cast away leading one
/// dimension" rewrite patterns.
///
/// \param patterns Pattern set that is presumably meant to receive the
///                 CastAway* patterns defined in this file -- TODO confirm.
/// \param benefit  Pattern benefit presumably meant to be assigned to each
///                 registered pattern -- TODO confirm.
///
/// NOTE(review): the body is empty in this view, so calling this function
/// currently registers nothing. Confirm the implementation was not dropped.
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {}