//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements target-independent rewrites and utilities to lower the // 'vector.shape_cast' operation. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/VectorInterfaces.h" #define DEBUG_TYPE … usingnamespacemlir; usingnamespacemlir::vector; namespace { /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D /// vectors progressively on the way to target llvm.matrix intrinsics. /// This iterates over the most major dimension of the 2-D vector and performs /// rewrites into: /// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D class ShapeCastOp2DDownCastRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { … }; /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D /// vectors progressively. /// This iterates over the most major dimension of the 2-D vector and performs /// rewrites into: /// vector.extract_strided_slice from 1-D + vector.insert into 2-D /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle. class ShapeCastOp2DUpCastRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { … }; static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp, int dimIdx, int initialStep = 1) { … } // We typically should not lower general shape cast operations into data // movement instructions, since the assumption is that these casts are // optimized away during progressive lowering. For completeness, however, // we fall back to a reference implementation that moves all elements // into the right place if we get here. class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { … }; /// A shape_cast lowering for scalable vectors with a single trailing scalable /// dimension. This is similar to the general shape_cast lowering but makes use /// of vector.scalable.insert and vector.scalable.extract to move elements a /// subvector at a time. /// /// E.g.: /// ``` /// // Flatten scalable vector /// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32> /// ``` /// is rewritten to: /// ``` /// // Flatten scalable vector /// %c = arith.constant dense<0> : vector<[8]xi32> /// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32> /// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32> /// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32> /// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32> /// ``` /// or: /// ``` /// // Un-flatten scalable vector /// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32> /// ``` /// is rewritten to: /// ``` /// // Un-flatten scalable vector /// %c = arith.constant dense<0> : vector<2x1x[4]xi32> /// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32> /// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32> /// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32> /// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32> /// ``` class ScalableShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { … }; } // namespace void mlir::vector::populateVectorShapeCastLoweringPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { … }