//===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include <numeric> #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/TypeUtilities.h" #define DEBUG_TYPE … usingnamespacemlir; usingnamespacemlir::vector; // Trims leading one dimensions from `oldType` and returns the result type. // Returns `vector<1xT>` if `oldType` only has one element. static VectorType trimLeadingOneDims(VectorType oldType) { … } /// Return a smallVector of size `rank` containing all zeros. static SmallVector<int64_t> splatZero(int64_t rank) { … } namespace { // Casts away leading one dimensions in vector.extract_strided_slice's vector // input by inserting vector.broadcast. struct CastAwayExtractStridedSliceLeadingOneDim : public OpRewritePattern<vector::ExtractStridedSliceOp> { … }; // Casts away leading one dimensions in vector.insert_strided_slice's vector // inputs by inserting vector.broadcast. struct CastAwayInsertStridedSliceLeadingOneDim : public OpRewritePattern<vector::InsertStridedSliceOp> { … }; // Casts away leading one dimensions in vector.insert's vector inputs by // inserting vector.broadcast. struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> { … }; static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask, VectorType newType, AffineMap newMap, VectorType oldMaskType) { … } // Turns vector.transfer_read on vector with leading 1 dimensions into // vector.shape_cast followed by vector.transfer_read on vector without leading // 1 dimensions. struct CastAwayTransferReadLeadingOneDim : public OpRewritePattern<vector::TransferReadOp> { … }; // Turns vector.transfer_write on vector with leading 1 dimensions into // vector.shape_cast followed by vector.transfer_write on vector without leading // 1 dimensions. struct CastAwayTransferWriteLeadingOneDim : public OpRewritePattern<vector::TransferWriteOp> { … }; } // namespace FailureOr<Value> mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, MaskingOpInterface maskingOp, RewriterBase &rewriter) { … } namespace { /// Turns vector.contract on vector with leading 1 dimensions into /// vector.extract followed by vector.contract on vector without leading /// 1 dimensions. Also performs tranpose of lhs and rhs operands if required /// prior to extract. struct CastAwayContractionLeadingOneDim : public MaskableOpRewritePattern<vector::ContractionOp> { … }; /// Looks at elementwise operations on vectors with at least one leading /// dimension equal 1, e.g. vector<1x[4]x1xf32> (but not vector<2x[4]x1xf32>), /// and cast aways the leading one dimensions (_plural_) and then broadcasts /// the results. /// /// Example before: /// %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32> /// Example after: /// %2 = arith.mulf %0, %1 : vector<4x1xf32> /// %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32> /// /// Does support scalable vectors. class CastAwayElementwiseLeadingOneDim : public RewritePattern { … }; // Drops leading 1 dimensions from vector.constant_mask and inserts a // vector.broadcast back to the original shape. struct CastAwayConstantMaskLeadingOneDim : public OpRewritePattern<vector::ConstantMaskOp> { … }; } // namespace void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { … }