//===- LowerVectorScan.cpp - Lower 'vector.scan' operation ----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements target-independent rewrites and utilities to lower the // 'vector.scan' operation. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/VectorInterfaces.h" #define DEBUG_TYPE … usingnamespacemlir; usingnamespacemlir::vector; /// This function checks to see if the vector combining kind /// is consistent with the integer or float element type. static bool isValidKind(bool isInt, vector::CombiningKind kind) { … } namespace { /// Convert vector.scan op into arith ops and vector.insert_strided_slice / /// vector.extract_strided_slice. 
/// /// Example: /// /// ``` /// %0:2 = vector.scan <mul>, %arg0, %arg1 /// {inclusive = true, reduction_dim = 1} : /// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>) /// ``` /// /// is converted to: /// /// ``` /// %cst = arith.constant dense<0> : vector<2x3xi32> /// %0 = vector.extract_strided_slice %arg0 /// {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} /// : vector<2x3xi32> to vector<2x1xi32> /// %1 = vector.insert_strided_slice %0, %cst /// {offsets = [0, 0], strides = [1, 1]} /// : vector<2x1xi32> into vector<2x3xi32> /// %2 = vector.extract_strided_slice %arg0 /// {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} /// : vector<2x3xi32> to vector<2x1xi32> /// %3 = arith.muli %0, %2 : vector<2x1xi32> /// %4 = vector.insert_strided_slice %3, %1 /// {offsets = [0, 1], strides = [1, 1]} /// : vector<2x1xi32> into vector<2x3xi32> /// %5 = vector.extract_strided_slice %arg0 /// {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]} /// : vector<2x3xi32> to vector<2x1xi32> /// %6 = arith.muli %3, %5 : vector<2x1xi32> /// %7 = vector.insert_strided_slice %6, %4 /// {offsets = [0, 2], strides = [1, 1]} /// : vector<2x1xi32> into vector<2x3xi32> /// %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> /// return %7, %8 : vector<2x3xi32>, vector<2xi32> /// ``` struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> { … }; } // namespace void mlir::vector::populateVectorScanLoweringPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { … }