//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include <numeric>
#include <utility>

using namespace mlir;
using namespace mlir::vector;

/// Currently the distribution map is implicit based on the vector shape. In
/// the future it will be part of the op.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   vector.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(VectorType sequentialType,
                                      VectorType distributedType) { … }

namespace {

/// Helper struct to create the load / store operations that permit transit
/// through the parallel / sequential and the sequential / parallel boundaries
/// when performing `rewriteWarpOpToScfFor`.
///
/// The vector distribution dimension is inferred from the vector types.
struct DistributedLoadStoreHelper { … };

} // namespace

/// Helper to create a new WarpExecuteOnLane0Op with a different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) { … }

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
/// `indices` returns the index of each new output.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes,
    llvm::SmallVector<size_t> &indices) { … }

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) { … }

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                const std::function<bool(Operation *)> &fn) { … }

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) { … }
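// Illustrative sketch, not part of the upstream implementation: the implicit
// distribution map documented above `calculateImplicitMap` can be reproduced
// with the AffineMap API by projecting onto the dimensions whose sizes differ
// between the sequential and distributed types. For vector<32x16x64xf32>
// distributed to vector<1x16x2xf32>, d0 and d2 shrink while d1 does not,
// which gives `(d0, d1, d2) -> (d0, d2)`. The function name is hypothetical
// and the function exists only for illustration.
static AffineMap buildExampleImplicitMap(MLIRContext *ctx) {
  // d0 (32 -> 1) and d2 (64 -> 2) are distributed; d1 (16 -> 16) is not.
  SmallVector<AffineExpr> results = {getAffineDimExpr(0, ctx),
                                     getAffineDimExpr(2, ctx)};
  return AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, results, ctx);
}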
namespace {

/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
/// thread `laneId` executes the entirety of the computation.
///
/// After the transformation:
///   - the IR within the scf.if op can be thought of as executing sequentially
///     (from the point of view of threads along `laneId`).
///   - the IR outside of the scf.if op can be thought of as executing in
///     parallel (from the point of view of threads along `laneId`).
///
/// Values that need to transit through the parallel / sequential and the
/// sequential / parallel boundaries do so via reads and writes to a temporary
/// memory location.
///
/// The transformation proceeds in multiple steps:
///   1. Create the scf.if op.
///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
///      within the scf.if to transit the values captured from above.
///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
///      consistent within the scf.if.
///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
///   5. Insert appropriate writes within scf.if and reads after the scf.if to
///      transit the values returned by the op.
///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
///      consistent after the scf.if.
///   7. Perform late cleanups.
///
/// All this assumes the vector distribution occurs along the most minor
/// distributed vector dimension.
struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> { … };

/// Clone `writeOp`, assumed to be nested under `warpOp`, into a new warp
/// execute op with the proper return type. The new write op is updated to
/// write the result of the new warp execute op. The old `writeOp` is deleted.
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                            WarpExecuteOnLane0Op warpOp,
                                            vector::TransferWriteOp writeOp,
                                            VectorType targetType,
                                            VectorType maybeMaskType) { … }

/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a dimension count equal to
/// the original type rank and to be a projection where the results are the
/// distributed dimensions. The number of results should be equal to the number
/// of warp sizes, which is currently limited to 1.
/// Example: a vector<16x32x64> distributed with a map (d0, d1, d2) -> (d1) and
/// a warp size of 16 distributes the second dimension (associated to d1) and
/// returns vector<16x2x64>.
static VectorType getDistributedType(VectorType originalType, AffineMap map,
                                     int64_t warpSize) { … }

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`. Writes of more than `maxNumElementsToExtract` elements
/// will not be distributed (this limit should be less than the warp size).
///
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<WarpExecuteOnLane0Op> { … };

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///   vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///     vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> { … };
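// Illustrative sketch, not part of the upstream implementation: the shape
// arithmetic described above `getDistributedType` amounts to dividing each
// dimension selected by the distribution map by the warp size. For
// vector<16x32x64> with map (d0, d1, d2) -> (d1) and warp size 16, this yields
// vector<16x2x64>. The function name is hypothetical and the sketch assumes
// the selected dimension divides evenly by the warp size.
static VectorType computeExampleDistributedType(VectorType originalType,
                                                AffineMap map,
                                                int64_t warpSize) {
  SmallVector<int64_t> shape = llvm::to_vector(originalType.getShape());
  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
    // Each result of the projected map names one distributed dimension.
    unsigned pos = map.getDimPosition(i);
    if (shape[pos] % warpSize != 0)
      return VectorType(); // Not evenly distributable under this assumption.
    shape[pos] /= warpSize;
  }
  return VectorType::get(shape, originalType.getElementType());
}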
/// Sink out splat constant op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %cst = arith.constant dense<2.0> : vector<32xf32>
///   vector.yield %cst : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// vector.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %0 = arith.constant dense<2.0> : vector<1xf32>
/// ```
struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> { … };

/// Delinearize the given `laneId` into multiple dimensions, where each
/// dimension's size is determined by `originalShape` and `distributedShape`
/// together. This function expects that the total number of threads needed for
/// distribution equals `warpSize`. Returns true and updates `delinearizedIds`
/// if so.
bool delinearizeLaneId(OpBuilder &builder, Location loc,
                       ArrayRef<int64_t> originalShape,
                       ArrayRef<int64_t> distributedShape, int64_t warpSize,
                       Value laneId, SmallVectorImpl<Value> &delinearizedIds) { … }

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///   vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> { … };

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this in WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> { … };

// If an operand is directly yielded out of the region we can forward it
// directly and it doesn't need to go through the region.
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> { … };

struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> { … };

/// Pattern to move shape_cast out of the warp op. shape_cast is basically a
/// no-op for warp distribution; we only need to handle the shape.
struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> { … };

/// Sink out vector.create_mask op feeding into a warp op yield.
/// ```
/// %0 = ...
/// %1 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %mask = vector.create_mask %0 : vector<32xi1>
///   vector.yield %mask : vector<32xi1>
/// }
/// ```
/// To
/// ```
/// %0 = ...
/// vector.warp_execute_on_lane_0(%arg0) {
///   ...
/// }
/// %cmp = arith.cmpi ult, %laneid, %0
/// %ub = arith.select %cmp, %c0, %c1
/// %1 = vector.create_mask %ub : vector<1xi1>
/// ```
struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> { … };

/// Pattern to move out vector.extract of single element vector. Those don't
/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> { … };

/// Pattern to move out vector.extractelement of 0-D vectors. Those don't
/// need to be distributed and can just be propagated outside of the region.
struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> { … };

struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> { … };

struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> { … };
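// Illustrative sketch, not part of the upstream implementation: the index
// arithmetic behind `delinearizeLaneId` above, written on plain integers
// instead of IR values. Dimension i contributes originalShape[i] /
// distributedShape[i] lanes, and the linear lane id is peeled from the
// innermost (fastest-varying) dimension outward. The function name is
// hypothetical and the ordering convention is an assumption of this sketch.
static bool delinearizeLaneIdExample(ArrayRef<int64_t> originalShape,
                                     ArrayRef<int64_t> distributedShape,
                                     int64_t warpSize, int64_t laneId,
                                     SmallVectorImpl<int64_t> &ids) {
  // Assumes both shapes have the same rank.
  // Number of lanes needed along each dimension.
  SmallVector<int64_t> sizes;
  for (size_t i = 0, e = originalShape.size(); i < e; ++i) {
    if (distributedShape[i] == 0 || originalShape[i] % distributedShape[i] != 0)
      return false;
    sizes.push_back(originalShape[i] / distributedShape[i]);
  }
  // The distribution must use exactly `warpSize` lanes.
  int64_t product = 1;
  for (int64_t s : sizes)
    product *= s;
  if (product != warpSize)
    return false;
  // Peel per-dimension ids off the linear lane id, innermost first.
  ids.assign(sizes.size(), 0);
  for (int i = static_cast<int>(sizes.size()) - 1; i >= 0; --i) {
    ids[i] = laneId % sizes[i];
    laneId /= sizes[i];
  }
  return true;
}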
/// Sink the scf.for region out of WarpExecuteOnLane0Op. This can be done only
/// if the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///     -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///   -> (vector<4xf32>) {
///   %iw = vector.warp_execute_on_lane_0(%laneid)
///     args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       vector.yield %ir : vector<128xf32>
///   }
///   scf.yield %iw : vector<4xf32>
/// }
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> { … };

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vector sizes
/// matching the warpOp size. E.g.:
/// ```
/// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   vector_ext.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   vector_ext.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : f32 from vector<1xf32>
/// %r = ("warp.reduction %a")
/// ```
struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> { … };

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options,
    PatternBenefit benefit) { … }

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    unsigned maxNumElementsToExtract, PatternBenefit benefit) { … }

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
    PatternBenefit readBenefit) { … }

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    const DistributedReductionFn &distributedReductionFn,
    PatternBenefit benefit) { … }

void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) { … }
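// Illustrative usage sketch, not part of the upstream file: a client pass
// would typically collect the patterns registered above into a
// RewritePatternSet and run them with the greedy rewrite driver, roughly:
//
//   #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populatePropagateWarpVectorDistributionPatterns(
//       patterns, distributionMapFn, warpShuffleFromIdxFn);
//   vector::populateDistributeTransferWriteOpPatterns(
//       patterns, distributionMapFn, /*maxNumElementsToExtract=*/1);
//   vector::populateDistributeReduction(patterns, distributedReductionFn);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
//
// Here `funcOp`, `distributionMapFn`, `warpShuffleFromIdxFn`, and
// `distributedReductionFn` are placeholders for client-provided values; the
// callbacks must match the DistributionMapFn, WarpShuffleFromIdxFn, and
// DistributedReductionFn types used by the populate functions above.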