// llvm/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

usingnamespacemlir;
usingnamespacemlir::linalg;
usingnamespacemlir::vector;

namespace {

// NOTE(review): every pass struct in this file is declared with an empty body:
// no getArgument()/getDescription() and no runOnOperation() override is
// visible. runOnOperation() is pure virtual on mlir::Pass, so these classes
// are presumably abstract as written — confirm the pass bodies before
// registering or instantiating them.
//
// Unless noted otherwise, each pass is a PassWrapper anchored on func::FuncOp
// (OperationPass<func::FuncOp>).

struct TestVectorToVectorLowering
    : public PassWrapper<TestVectorToVectorLowering,
                         OperationPass<func::FuncOp>> {};

struct TestVectorContractionPrepareForMMTLowering
    : public PassWrapper<TestVectorContractionPrepareForMMTLowering,
                         OperationPass<func::FuncOp>> {};

struct TestVectorUnrollingPatterns
    : public PassWrapper<TestVectorUnrollingPatterns,
                         OperationPass<func::FuncOp>> {};

struct TestVectorTransferUnrollingPatterns
    : public PassWrapper<TestVectorTransferUnrollingPatterns,
                         OperationPass<func::FuncOp>> {};

struct TestScalarVectorTransferLoweringPatterns
    : public PassWrapper<TestScalarVectorTransferLoweringPatterns,
                         OperationPass<func::FuncOp>> {};

struct TestVectorTransferOpt
    : public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {};

struct TestVectorTransferCollapseInnerMostContiguousDims
    : public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
                         OperationPass<func::FuncOp>> {};

struct TestVectorSinkPatterns
    : public PassWrapper<TestVectorSinkPatterns, OperationPass<func::FuncOp>> {};

struct TestVectorReduceToContractPatternsPatterns
    : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
                         OperationPass<func::FuncOp>> {};

struct TestVectorChainedReductionFoldingPatterns
    : public PassWrapper<TestVectorChainedReductionFoldingPatterns,
                         OperationPass<func::FuncOp>> {};

struct TestVectorBreakDownReductionPatterns
    : public PassWrapper<TestVectorBreakDownReductionPatterns,
                         OperationPass<func::FuncOp>> {};

struct TestFlattenVectorTransferPatterns
    : public PassWrapper<TestFlattenVectorTransferPatterns,
                         OperationPass<func::FuncOp>> {};

struct TestVectorScanLowering
    : public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {};

} // namespace

// File-local helpers are kept out of the anonymous namespace and marked
// `static`, per the LLVM coding standard (anonymous namespaces are reserved
// for class declarations).

/// Allocate shared memory for a single warp to test lowering of
/// WarpExecuteOnLane0Op.
///
/// NOTE(review): the body is a placeholder. It ignores `loc`, `builder`,
/// `warpOp` and `type`, creates no ops, and returns a null Value. The explicit
/// return prevents control from flowing off the end of a non-void function,
/// which is undefined behavior. Callers must not treat the result as a real
/// allocation until an implementation is provided.
static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
                                        WarpExecuteOnLane0Op warpOp,
                                        Type type) {
  return Value();
}

/// NOTE(review): placeholder body. It ignores `loc`, `builder`, `input`,
/// `kind` and `size`, creates no ops, and returns a null Value. The explicit
/// return avoids undefined behavior from flowing off the end of a non-void
/// function. Confirm the intended reduction semantics before relying on the
/// result.
static Value warpReduction(Location loc, OpBuilder &builder, Value input,
                           CombiningKind kind, uint32_t size) {
  return Value();
}

namespace {

struct TestVectorDistribution
    : public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {};

struct TestVectorExtractStridedSliceLowering
    : public PassWrapper<TestVectorExtractStridedSliceLowering,
                         OperationPass<func::FuncOp>> {};

struct TestVectorBreakDownBitCast
    : public PassWrapper<TestVectorBreakDownBitCast,
                         OperationPass<func::FuncOp>> {};

struct TestCreateVectorBroadcast
    : public PassWrapper<TestCreateVectorBroadcast,
                         OperationPass<func::FuncOp>> {};

struct TestVectorGatherLowering
    : public PassWrapper<TestVectorGatherLowering,
                         OperationPass<func::FuncOp>> {};

struct TestFoldArithExtensionIntoVectorContractPatterns
    : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
                         OperationPass<func::FuncOp>> {};

struct TestVectorEmulateMaskedLoadStore final
    : public PassWrapper<TestVectorEmulateMaskedLoadStore,
                         OperationPass<func::FuncOp>> {};

// Unlike the other passes here, this one uses the op-agnostic
// OperationPass<> rather than anchoring on func::FuncOp.
struct TestVectorLinearize final
    : public PassWrapper<TestVectorLinearize, OperationPass<>> {};

struct TestEliminateVectorMasks
    : public PassWrapper<TestEliminateVectorMasks,
                         OperationPass<func::FuncOp>> {};
} // namespace

namespace mlir::test {

/// Registration hook for this file's vector test passes.
///
/// NOTE(review): the body is currently empty, so calling this function has no
/// effect and registers no passes.
void registerTestVectorLowerings() {
  // No passes are registered here at present.
}

} // namespace mlir::test