//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements vector.transpose rewrites as AVX patterns for particular // sizes of interest. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "llvm/Support/Format.h" #include "llvm/Support/FormatVariadic.h" usingnamespacemlir; usingnamespacemlir::vector; usingnamespacemlir::x86vector; usingnamespacemlir::x86vector::avx2; usingnamespacemlir::x86vector::avx2::inline_asm; usingnamespacemlir::x86vector::avx2::intrin; Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm( ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) { … } Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2) { … } Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2) { … } /// a a b b a a b b /// Takes an 8 bit mask, 2 bit for each position of a[0, 3) **and** b[0, 4): /// 0:127 | 128:255 /// b01 b23 C8 D8 | b01+4 b23+4 C8+4 D8+4 Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) { … } // imm[0:1] out of imm[0:3] is: // 0 1 2 3 // a[0:127] or a[128:255] or b[0:127] or b[128:255] | // a[0:127] or a[128:255] or b[0:127] or b[128:255] // 0 1 2 3 // imm[0:1] out of imm[4:7]. 
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) { … }

/// If bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
                                                  Value v1, Value v2,
                                                  uint8_t mask) { … }

/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
/// NOTE(review): the function returns void and takes `vs` as a mutable range,
/// so the result is presumably written back into `vs` -- confirm against the
/// implementation.
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) { … }

/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
/// NOTE(review): the function returns void and takes `vs` as a mutable range,
/// so the result is presumably written back into `vs` -- confirm against the
/// implementation.
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) { … }

/// Rewrite AVX2-specific vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
/// transpose cases and n-D cases that have been decomposed into 2-D
/// transposition slices. For example, a 3-D transpose:
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
///
/// could be sliced into 2-D transposes by tiling two of its dimensions to one
/// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1x4x8xf32> to vector<8x1x4xf32>
///
/// This lowering will analyze the n-D vector.transpose and determine if it's a
/// supported 2-D transposition slice where any of the AVX2 patterns can be
/// applied.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> { … };

/// Entry point taking the pattern set to extend, the lowering `options`, and a
/// pattern `benefit`.
/// NOTE(review): presumably registers `TransposeOpLowering` into `patterns`
/// with the given options and benefit -- confirm against the implementation.
void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) { … }