// llvm/mlir/include/mlir/Dialect/X86Vector/Transforms.h

//=- Transforms.h - X86Vector Dialect Transformation Entrypoints -*- C++ -*-=//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
#define MLIR_DIALECT_X86VECTOR_TRANSFORMS_H

#include "mlir/IR/Value.h"

namespace mlir {

class ImplicitLocOpBuilder;
class LLVMConversionTarget;
class LLVMTypeConverter;
class RewritePatternSet;

namespace x86vector {

/// Helper class to factor out the creation and extraction of masks from nibs.
/// NOTE(review): the struct body is empty here — its static helpers appear to
/// have been elided from this view; confirm against the full upstream header.
struct MaskHelper {};

//===----------------------------------------------------------------------===//
/// Helpers extracted from:
///   - clang/lib/Headers/avxintrin.h
///   - clang/test/CodeGen/X86/avx-builtins.c
///   - clang/test/CodeGen/X86/avx2-builtins.c
///   - clang/test/CodeGen/X86/avx-shuffle-builtins.c
/// as well as the Intel Intrinsics Guide
/// (https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html)
/// make it easier to just implement known good lowerings.
/// All intrinsics correspond 1-1 to the Intel definition.
//===----------------------------------------------------------------------===//

namespace avx2 {

namespace inline_asm {
//===----------------------------------------------------------------------===//
/// Methods in the inline_asm namespace emit calls to LLVM::InlineAsmOp.
//===----------------------------------------------------------------------===//
/// Inline-asm variant of `_mm256_blend_ps`:
/// if bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mm256BlendPsAsm(ImplicitLocOpBuilder &b, Value v1, Value v2,
                      uint8_t mask);

} // namespace inline_asm

namespace intrin {
//===----------------------------------------------------------------------===//
/// Methods in the intrin namespace emulate clang's impl. of X86 intrinsics.
//===----------------------------------------------------------------------===//
/// Emulates `_mm256_unpacklo_ps`: interleave the low halves of each 128-bit
/// lane of v1 and v2.
/// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2);

/// Emulates `_mm256_unpackhi_ps`: interleave the high halves of each 128-bit
/// lane of v1 and v2.
/// Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].
Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2);

/// Emulates `_mm256_shuffle_ps(v1, v2, mask)`: within each 128-bit lane the
/// result takes two elements from v1 followed by two elements from v2, each
/// selected by a 2-bit field of `mask`:
///   res[0] = v1[mask[1:0]]   res[1] = v1[mask[3:2]]
///   res[2] = v2[mask[5:4]]   res[3] = v2[mask[7:6]]
/// and likewise, with element indices offset by 4, for the upper lane.
Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask);

/// Emulates `_mm256_permute2f128_ps(v1, v2, mask)`: each 128-bit half of the
/// result is one of the four input 128-bit halves, selected by a 2-bit field
/// of `mask` (mask[1:0] selects the low half of the result, mask[5:4] the
/// high half):
///   0 -> v1[0:127]   1 -> v1[128:255]   2 -> v2[0:127]   3 -> v2[128:255]
Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2,
                          uint8_t mask);

/// Emulates `_mm256_blend_ps`:
/// if bit i of `mask` is zero, take f32@i from v1 else take it from v2.
Value mm256BlendPs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask);
} // namespace intrin

//===----------------------------------------------------------------------===//
/// Generic lowerings may either use intrin or inline_asm depending on needs.
//===----------------------------------------------------------------------===//
/// 4x8xf32-specific AVX2 transpose lowering.
/// NOTE(review): `vs` presumably holds the matrix rows as vector values and
/// is updated in place with the transposed result — confirm against callers.
void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);

/// 8x8xf32-specific AVX2 transpose lowering.
/// NOTE(review): `vs` presumably holds the matrix rows as vector values and
/// is updated in place with the transposed result — confirm against callers.
void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);

/// Structure to control the behavior of specialized AVX2 transpose lowering.
/// NOTE(review): the struct body is empty here — its option fields appear to
/// have been elided from this view; confirm against the full upstream header.
struct TransposeLoweringOptions {};

/// Options for controlling specialized AVX2 lowerings.
/// NOTE(review): the struct body is empty here — its option fields appear to
/// have been elided from this view; confirm against the full upstream header.
struct LoweringOptions {};

/// Insert specialized transpose lowering patterns into `patterns`.
/// `options` controls which specialized lowerings are used; `benefit` sets
/// the pattern benefit (defaults to 10 so these outrank generic lowerings).
void populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options = LoweringOptions(),
    int benefit = 10);

} // namespace avx2
} // namespace x86vector

/// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM
/// intrinsics. `converter` supplies the LLVM type conversions the patterns
/// need; the patterns are appended to `patterns`.
void populateX86VectorLegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns);

/// Configure the conversion `target` to support lowering X86Vector ops to ops
/// that map to LLVM intrinsics (i.e. mark the relevant ops legal/illegal).
void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target);

} // namespace mlir

#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMS_H