//===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements target-independent rewrites and utilities to lower the // 'vector.interleave' and 'vector.deinterleave' operations. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #define DEBUG_TYPE … usingnamespacemlir; usingnamespacemlir::vector; namespace { /// A one-shot unrolling of vector.interleave to the `targetRank`. /// /// Example: /// /// ```mlir /// vector.interleave %a, %b : vector<1x2x3x4xi64> -> vector<1x2x3x8xi64> /// ``` /// Would be unrolled to: /// ```mlir /// %result = arith.constant dense<0> : vector<1x2x3x8xi64> /// %0 = vector.extract %a[0, 0, 0] ─┐ /// : vector<4xi64> from vector<1x2x3x4xi64> | /// %1 = vector.extract %b[0, 0, 0] | /// : vector<4xi64> from vector<1x2x3x4xi64> | - Repeated 6x for /// %2 = vector.interleave %0, %1 | all leading positions /// : vector<4xi64> -> vector<8xi64> | /// %3 = vector.insert %2, %result [0, 0, 0] | /// : vector<8xi64> into vector<1x2x3x8xi64> ┘ /// ``` /// /// Note: If any leading dimension before the `targetRank` is scalable, the /// unrolling will stop before the scalable dimension. class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> { … }; /// A one-shot unrolling of vector.deinterleave to the `targetRank`.
/// /// Example: /// /// ```mlir /// %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64> /// ``` /// Would be unrolled to: /// ```mlir /// %result = arith.constant dense<0> : vector<1x2x3x4xi64> /// %0 = vector.extract %a[0, 0, 0] ─┐ /// : vector<8xi64> from vector<1x2x3x8xi64> | /// %1, %2 = vector.deinterleave %0 | /// : vector<8xi64> -> vector<4xi64> | -- Initial deinterleave /// %3 = vector.insert %1, %result [0, 0, 0] | operation unrolled. /// : vector<4xi64> into vector<1x2x3x4xi64> | /// %4 = vector.insert %2, %result [0, 0, 0] | /// : vector<4xi64> into vector<1x2x3x4xi64> ┘ /// %5 = vector.extract %a[0, 0, 1] ─┐ /// : vector<8xi64> from vector<1x2x3x8xi64> | /// %6, %7 = vector.deinterleave %5 | /// : vector<8xi64> -> vector<4xi64> | -- Recursive pattern for /// %8 = vector.insert %6, %3 [0, 0, 1] | subsequent unrolled /// : vector<4xi64> into vector<1x2x3x4xi64> | deinterleave /// %9 = vector.insert %7, %4 [0, 0, 1] | operations. Repeated /// : vector<4xi64> into vector<1x2x3x4xi64> ┘ 5x in this case. /// ``` /// /// Note: If any leading dimension before the `targetRank` is scalable, the /// unrolling will stop before the scalable dimension. class UnrollDeinterleaveOp final : public OpRewritePattern<vector::DeinterleaveOp> { … }; /// Rewrite vector.interleave op into an equivalent vector.shuffle op, when /// applicable: `sourceType` must be 1D and non-scalable.
/// /// Example: /// /// ```mlir /// vector.interleave %a, %b : vector<7xi16> -> vector<14xi16> /// ``` /// /// Is rewritten into: /// /// ```mlir /// vector.shuffle %a, %b [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13] /// : vector<7xi16>, vector<7xi16> /// ``` struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> { … }; } // namespace void mlir::vector::populateVectorInterleaveLoweringPatterns( RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) { … } void mlir::vector::populateVectorInterleaveToShufflePatterns( RewritePatternSet &patterns, PatternBenefit benefit) { … }