//===- LowerVectorGather.cpp - Lower 'vector.gather' operation ------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements target-independent rewrites and utilities to lower the // 'vector.gather' operation. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Location.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/VectorInterfaces.h" #define DEBUG_TYPE … usingnamespacemlir; usingnamespacemlir::vector; namespace { /// Flattens 2 or more dimensional `vector.gather` ops by unrolling the /// outermost dimension. For example: /// ``` /// %g = vector.gather %base[%c0][%v], %mask, %pass_thru : /// ... into vector<2x3xf32> /// /// ==> /// /// %0 = arith.constant dense<0.0> : vector<2x3xf32> /// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ... /// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32> /// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ... /// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32> /// ``` /// /// When applied exhaustively, this will produce a sequence of 1-d gather ops. /// /// Supports vector types with a fixed leading dimension. struct FlattenGather : OpRewritePattern<vector::GatherOp> { … }; /// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided /// MemRef with updated indices that model the strided access. /// /// ```mlir /// %subview = memref.subview %M (...) /// : memref<100x3xf32> to memref<100xf32, strided<[3]>> /// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>> /// ``` /// ==> /// ```mlir /// %collapse_shape = memref.collapse_shape %M (...) /// : memref<100x3xf32> into memref<300xf32> /// %new_idxs = arith.muli %idxs, %c3 : vector<4xindex> /// %gather = vector.gather %collapse_shape[%new_idxs] (...) /// : memref<300xf32> (...) /// ``` /// /// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef, /// but should be fairly straightforward to extend beyond that. struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> { … }; /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these /// loads/extracts are made conditional using `scf.if` ops. struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> { … }; } // namespace void mlir::vector::populateVectorGatherLoweringPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { … }