//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a
// vector.transfer op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE …

using namespace mlir;
using namespace mlir::vector;

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) { … }

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
/// At this time, only the vector.transfer_read case is implemented.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getSource()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) { … }

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///      strides and:
///        a. keeping the ones that are static and equal across `aT` and `bT`.
///        b. using a dynamic shape and/or stride for the dimensions that don't
///           agree.
static MemRefType getCastCompatibleMemRefType(MemRefType aT,
                                              MemRefType bT) { … }
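// For intuition, a sketch of the intended behavior on hypothetical types (not
// taken from this file's tests): given
//   `aT` = memref<4x8xf32, strided<[8, 1]>>
//   `bT` = memref<?x8xf32, strided<[?, 1]>>
// the common type keeps the trailing size/stride, which are static and equal,
// and relaxes the disagreeing leading ones to dynamic:
//   memref<?x8xf32, strided<[?, 1]>>
// and both inputs can be `memref.cast` to this result type.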
/// Casts the given memref to a compatible memref type. If the source memref
/// has a different address space than the target type, a
/// `memref.memory_space_cast` is inserted first, followed by a `memref.cast`.
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
                                        MemRefType compatibleMemRefType) { … }

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.getSource()` @ `xferOp.getIndices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) { … }

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) { … }

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) { … }
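// For intuition, the slow path above, instantiated for a hypothetical read of
// vector<4x8xf32> from %view : memref<?x?xf32> (a sketch; the SSA names,
// shapes, and indices are invented, not taken from a test):
//   %2 = vector.transfer_read %view[%i, %j], %pad
//       : memref<?x?xf32>, vector<4x8xf32>
//   %3 = vector.type_cast %alloc : memref<4x8xf32> to memref<vector<4x8xf32>>
//   memref.store %2, %3[] : memref<vector<4x8xf32>>
//   %4 = memref.cast %alloc : memref<4x8xf32> to memref<?x?xf32>
//   scf.yield %4, %c0, %c0 : memref<?x?xf32>, index, index
// The masked read fills the temporary buffer with padded data; the rewritten
// in-bounds read then reads back from `%alloc` at [0, 0].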
/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) { … }

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [0, 0][...][...]
///      %4 = subview %view [...][...][...]
///      linalg.copy(%3, %4)
///    }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) { … }

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : vector<...>, memref<A...>
///    }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) { … }
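// For intuition, the slow path above, instantiated for a hypothetical
// vector<4x8xf32> that was written to `%alloc` and must be copied back into
// %view : memref<?x?xf32> (a sketch; the SSA names and shapes are invented,
// and the type_cast makes explicit the load shown as `load %alloc` above):
//   %notInBounds = arith.xori %inBounds, %true
//   scf.if (%notInBounds) {
//     %0 = vector.type_cast %alloc : memref<4x8xf32> to memref<vector<4x8xf32>>
//     %1 = memref.load %0[] : memref<vector<4x8xf32>>
//     vector.transfer_write %1, %view[%i, %j] : vector<4x8xf32>, memref<?x?xf32>
//   }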
// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
static Operation *getAutomaticAllocationScope(Operation *op) { … }

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// For vector.transfer_write:
/// There are two conditional blocks. First, a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then, a conditional block
/// to further copy a partial buffer into the final result in the slow path
/// case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
///    scf.if (%notInBounds) {
///      // slow path: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getSource()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) { … }

namespace {
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern { … };
} // namespace

LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const { … }

void mlir::vector::populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options) { … }
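// A minimal usage sketch for the entry point above (assumptions: the caller
// has a `func::FuncOp`, the greedy driver from
// "mlir/Transforms/GreedyPatternRewriteDriver.h" is available, and
// `applySplitTransfers` is a hypothetical helper, not part of this file):
// ```
//   void applySplitTransfers(func::FuncOp funcOp) {
//     RewritePatternSet patterns(funcOp.getContext());
//     vector::VectorTransformsOptions options;
//     // Select the slow-path strategy; VectorTransfer is the other choice.
//     options.setVectorTransferSplit(vector::VectorTransferSplit::LinalgCopy);
//     vector::populateVectorTransferFullPartialPatterns(patterns, options);
//     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
//   }
// ```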