llvm/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp

//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for particular
// sizes of interest.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"

usingnamespacemlir;
usingnamespacemlir::vector;
usingnamespacemlir::x86vector;
usingnamespacemlir::x86vector::avx2;
usingnamespacemlir::x86vector::avx2::inline_asm;
usingnamespacemlir::x86vector::avx2::intrin;

/// Inline-asm flavored blend helper (judging by its name and namespace; the
/// `intrin::mm256BlendPs` helper in this file has the same signature).
///
/// \param b    builder handed in by the caller.
/// \param v1   first operand.
/// \param v2   second operand.
/// \param mask 8-bit immediate; presumably one selection bit per f32 element,
///             as documented for `intrin::mm256BlendPs` -- TODO confirm.
/// \returns a Value.
///
/// NOTE(review): the body is empty in this view even though a Value must be
/// returned; confirm that the implementation was not dropped.
Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {}

/// Helper whose name mirrors the AVX `_mm256_unpacklo_ps` intrinsic.
///
/// \param b  builder handed in by the caller.
/// \param v1 first operand.
/// \param v2 second operand.
/// \returns a Value.
///
/// NOTE(review): presumably interleaves the low f32 elements of each 128-bit
/// half of `v1` and `v2`, like the intrinsic it is named after -- confirm
/// against the implementation. The body is empty in this view even though a
/// Value must be returned.
Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {}

/// Helper whose name mirrors the AVX `_mm256_unpackhi_ps` intrinsic.
///
/// \param b  builder handed in by the caller.
/// \param v1 first operand.
/// \param v2 second operand.
/// \returns a Value.
///
/// NOTE(review): presumably interleaves the high f32 elements of each 128-bit
/// half of `v1` and `v2`, like the intrinsic it is named after -- confirm
/// against the implementation. The body is empty in this view even though a
/// Value must be returned.
Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {}
/// Shuffle of `v1` (called a below) and `v2` (called b below) controlled by an
/// 8-bit immediate.
///
/// Takes an 8 bit mask made of four 2-bit selectors. Each selector addresses
/// one position in [0, 4) of a or of b; the first two result positions of a
/// 128-bit half come from a and the last two come from b. The same selectors
/// are reused for the upper half with an additional offset of 4:
///
///   source:                  a    a    b   b   |  a     a     b    b
///   bits:                         0:127        |         128:255
///   index:                  b01  b23  C8  D8   |  b01+4 b23+4 C8+4 D8+4
///
/// NOTE(review): b01 and b23 presumably name the selectors held in mask bits
/// 0-1 and 2-3, and C8 / D8 the two remaining selectors offset by 8 so that
/// they address b -- confirm against the implementation.
///
/// \param b    builder handed in by the caller.
/// \param v1   operand a.
/// \param v2   operand b.
/// \param mask the 8-bit shuffle immediate described above.
/// \returns a Value.
///
/// NOTE(review): the body is empty in this view even though a Value must be
/// returned; confirm that the implementation was not dropped.
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {}

/// Permutation of the 128-bit halves of `v1` (called a below) and `v2` (called
/// b below) controlled by an 8-bit immediate.
///
/// Each 128-bit half of the result is one of the four input halves, picked by
/// a 2-bit selector with values 0 to 3. In the diagram, the part left of `|`
/// is selected by imm[0:1] out of imm[0:3] and the part right of `|` by
/// imm[0:1] out of imm[4:7]:
///
/// imm[0:1] out of imm[0:3] is:
///    0             1           2             3
/// a[0:127] or a[128:255] or b[0:127] or b[128:255]    |
///          a[0:127] or a[128:255] or b[0:127] or b[128:255]
///             0             1           2             3
/// imm[0:1] out of imm[4:7].
///
/// \param b    builder handed in by the caller.
/// \param v1   operand a.
/// \param v2   operand b.
/// \param mask the 8-bit immediate (imm) described above.
/// \returns a Value.
///
/// NOTE(review): the body is empty in this view even though a Value must be
/// returned; confirm that the implementation was not dropped.
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {}

/// Blend of `v1` and `v2` driven by an 8-bit immediate: if bit i of `mask` is
/// zero, take f32@i from v1 else take it from v2.
///
/// \param b    builder handed in by the caller.
/// \param v1   operand chosen where the mask bit is 0.
/// \param v2   operand chosen where the mask bit is 1.
/// \param mask 8-bit immediate, one bit per f32 element.
/// \returns a Value.
///
/// NOTE(review): the body is empty in this view even though a Value must be
/// returned; confirm that the implementation was not dropped.
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
                                                  Value v1, Value v2,
                                                  uint8_t mask) {}

/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
///
/// \param ib builder handed in by the caller.
/// \param vs mutable list of Values; presumably the rows of the 4x8 tile,
///           overwritten in place with the transposed result since nothing is
///           returned -- TODO confirm against the implementation.
///
/// NOTE(review): the body is empty in this view, so as written `vs` is left
/// untouched; confirm that the implementation was not dropped.
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {}

/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
///
/// \param ib builder handed in by the caller.
/// \param vs mutable list of Values; presumably the rows of the 8x8 tile,
///           overwritten in place with the transposed result since nothing is
///           returned -- TODO confirm against the implementation.
///
/// NOTE(review): the body is empty in this view, so as written `vs` is left
/// untouched; confirm that the implementation was not dropped.
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {}

namespace {
/// Pattern that rewrites `vector.transpose` into the AVX2-specific lowerings,
/// for the supported cases and depending on the `TransposeLoweringOptions`.
///
/// Both plain 2-D transposes and n-D transposes that have been decomposed into
/// 2-D transposition slices are supported. For example, the 3-D transpose
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
///
/// could be sliced into 2-D transposes by tiling two of its dimensions to one
/// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///      : vector<1x4x8xf32> to vector<8x1x4xf32>
///
/// The lowering analyzes the n-D vector.transpose and determines whether it is
/// a supported 2-D transposition slice to which any of the AVX2 patterns can
/// be applied.
///
/// The class lives in an anonymous namespace because it is an implementation
/// detail of this translation unit; internal linkage keeps it from clashing
/// with an identically named class defined in another file.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {};
} // namespace

/// Entry point for the specialized AVX2 transpose lowering patterns.
///
/// \param patterns pattern set handed in by the caller; presumably the set the
///                 specialized transpose patterns are added to -- TODO confirm.
/// \param options  lowering options, passed by value.
/// \param benefit  integer benefit; presumably forwarded to the registered
///                 pattern(s) -- TODO confirm.
///
/// NOTE(review): the body is empty in this view, so as written nothing is
/// added to `patterns`; confirm that the implementation was not dropped.
void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {}