
//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// This file implements lowering of vector operations to GPU dialect ops.

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Region.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

// Debug-logging helper macros for this file.
// NOTE(review): all three are defined with no replacement text here, so
// `DBGS()` and `DBGSNL()` expand to nothing and `DEBUG_TYPE` carries no tag
// string. Confirm whether the replacement text was dropped unintentionally.
#define DEBUG_TYPE
#define DBGS()
#define DBGSNL()

namespace mlir {
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir


/// Computes the effective indices of a vector transfer op after applying a set
/// of offsets.
///
/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
/// AffineMap representing offsets to apply to indices, the function fills
/// `indices` with the original indices plus the offsets. The offsets are
/// applied by taking into account the permutation map of the transfer op. If
/// the `offsetMap` has dimension placeholders, those should be provided in
/// `dimValues`.
///
/// \param rewriter  Rewriter supplied by the caller; presumably used to
///                  materialize the index arithmetic (TODO confirm).
/// \param xferOp    Transfer op whose indices are the starting point.
/// \param offsetMap Offsets to add to the transfer indices.
/// \param dimValues Values bound to the dimension placeholders of `offsetMap`.
/// \param indices   Output parameter; expected to be empty on entry.
template <typename TransferOpType>
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {}

// Return true if the contract op can be converted to MMA matmul.
// NOTE(review): `useNvGpu` presumably selects which lowering path's
// constraints are checked (NVGPU `mma.sync` vs. GPU-dialect MMA ops) -- confirm
// against the implementation.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
                                          bool useNvGpu) {}

// Return true if the given permutation map represents a transposed matrix
// load, i.e. a map of the form (d0, d1, ...) -> (dn-1, dn-2), where the two
// innermost dimensions appear swapped in the result.
static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {}

// Return the stride for the second-to-last dimension of |type| if it is a
// memref and has a constant stride.
// NOTE(review): presumably returns std::nullopt when |type| is not a memref or
// the stride is not statically known -- confirm.
static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {}

// Return true if the transfer read op `readOp` can be converted to a MMA
// matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {}

// Return true if the transfer write op `writeOp` can be converted to a MMA
// matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {}

/// Return true if `constantOp` is a splat constant of 2D vector type, so that
/// it can be converted to a MMA constant matrix op.
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {}

/// Return true if `broadcastOp` is a broadcast from a scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {}

/// Return true if this integer extend op can be folded into a contract op.
///
/// \tparam ExtOpTy Type of the extend op being queried.
/// NOTE(review): presumably instantiated with the signed and unsigned integer
/// extension ops -- confirm at the call sites.
template <typename ExtOpTy>
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {}

/// Query whether the floating-point extend op `extOp` is supported on the MMA
/// matrix conversion path.
/// NOTE(review): by analogy with `integerExtendSupportsMMAMatrixType`, this
/// presumably checks whether the extension can be folded into a contract op --
/// confirm against the implementation.
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) {}

/// Map `op` to its MMA elementwise counterpart.
///
/// Return the `gpu::MMAElementwiseOp` enum value associated with `op` if it is
/// supported. Return `std::nullopt` otherwise.
static std::optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {}

/// Return true if `op` is supported as an elementwise op on MMAMatrix type.
static bool elementwiseSupportsMMAMatrixType(Operation *op) {}

/// Returns true if the extract strided slice op `op` is supported on the
/// `mma.sync` path.
static bool
extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {}

/// Query whether `op` is supported for conversion to MMA matrix operations.
/// NOTE(review): presumably dispatches on the op kind to the
/// `*SupportsMMAMatrixType` predicates above, with `useNvGpu` selecting the
/// target path -- confirm against the implementation.
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {}

/// Return an unsorted slice rooted at `op`, handling scf.for regions
/// differently than `getSlice`: in scf.for we only want to include as part of
/// the slice elements that are part of the use/def chain.
///
/// \param op                   Operation the slice is computed from.
/// \param backwardSliceOptions Options governing the backward part of the
///                             slice.
/// \param forwardSliceOptions  Options governing the forward part of the
///                             slice.
static SetVector<Operation *>
getSliceContract(Operation *op,
                 const BackwardSliceOptions &backwardSliceOptions,
                 const ForwardSliceOptions &forwardSliceOptions) {}

// Analyze the slice of operations around convertible ops to figure out if the
// whole slice can be converted to MMA operations; returns the set of
// operations to convert.
// NOTE(review): `useNvGpu` presumably selects which target path's support
// rules are applied -- confirm.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                             bool useNvGpu) {}

namespace {
// Rewrite pattern on vector.contract: transform the contract into
// (m, k)x(k, n)x(m, n) form so that it can be converted to MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {};

// Rewrite pattern on vector.transpose: fold the transpose op into the transfer
// read op. The NVGPU mma.sync op only supports row-, column-, and row-major
// layouts for matrixA, matrixB, and matrixC, respectively. We can fold the
// transpose operation when loading the data from Shared Memory to registers.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {};

} // namespace

// MMA types have different layouts based on how they are used in matmul ops.
// Figure out the right layout to use by looking at the uses of `op`; the
// result is returned as a C string.
// NOTE(review): the set of possible returned strings is not visible here.
// TODO: Change the GPU dialect to abstract the layout at this level and
// only care about it during lowering to NVVM.
static const char *inferFragType(Operation *op) {}

/// Convert the vector transfer read `op` on the MMA matrix path.
/// NOTE(review): presumably emits a GPU-dialect MMA matrix load and records the
/// result in `valueMapping` (original value -> converted value) -- confirm.
/// Returns a LogicalResult reporting the outcome to the caller.
static LogicalResult
convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
                      llvm::DenseMap<Value, Value> &valueMapping) {}

/// Convert the vector transfer write `op` on the MMA matrix path.
/// NOTE(review): presumably emits a GPU-dialect MMA matrix store whose operand
/// is looked up in `valueMapping` (original value -> converted value) --
/// confirm. Returns a LogicalResult reporting the outcome to the caller.
static LogicalResult
convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {}

/// Returns the vector type which represents a matrix fragment described by
/// the fragment element info `regInfo`.
static VectorType
getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
/// NOTE(review): this description is word-for-word the same as the one on
/// `convertConstantOp`. Given the `MmaSync` suffix, this variant presumably
/// targets the `nvgpu.mma.sync` path and may produce a different op -- confirm
/// and correct the comment if so.
static LogicalResult
convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {}

/// Check if the loaded matrix operand requires transposition.
/// Transposed Map Example:
/// Example 1   : (..., d0, d1) -> (d1 * 1, d0 * 2)
/// Example 2   : (d0, d1, d2, d3) -> (d3, d2)
/// The code below checks if the output 2D is transposed using a generalized
/// version     : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
/// Returns     : true if m > n, false otherwise.
/// NOTE(review): the FailureOr wrapper presumably reports maps that do not fit
/// the form above -- confirm.
static FailureOr<bool> isTransposed(vector::TransferReadOp op) {}

/// Lower the transfer read `op` for the case where the load is compatible with
/// ldmatrix.
/// NOTE(review): presumably emits `nvgpu.ldmatrix` and records the result in
/// `valueMapping` -- confirm. The name is missing an 'e' ("creat"); renaming
/// would require updating every caller.
static LogicalResult
creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {}

/// Lower the transfer read `op` for the case where ldmatrix is not used.
/// NOTE(review): presumably emits plain vector/memref loads and records the
/// assembled result in `valueMapping` -- confirm.
static LogicalResult
createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {}

/// Return true if `type` is a shared memory memref type.
/// NOTE(review): presumably decided from the memref's memory space attribute
/// -- confirm.
static bool isSharedMemory(MemRefType type) {}

/// Converts a `vector.transfer_read` operation directly to either a
/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
/// used when converting to `nvgpu.mma.sync` operations.
///
/// \param rewriter     Rewriter supplied by the caller.
/// \param op           Transfer read to convert.
/// \param valueMapping Value-to-value map shared across the conversion;
///                     presumably original value -> converted value (confirm).
static LogicalResult
convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {}

/// Convert the vector transfer write `op` into store operations.
/// NOTE(review): by analogy with `convertTransferReadToLoads`, presumably only
/// meant for the `nvgpu.mma.sync` path, with the stored value looked up in
/// `valueMapping` -- confirm.
static LogicalResult
convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {}

/// Fill `results` from the contents of `arrayAttr`.
/// NOTE(review): per the name, each element of `arrayAttr` is presumably a
/// 64-bit integer attribute whose value is written to `results` -- confirm.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {}

/// Convert the vector extract strided slice `op`.
/// NOTE(review): `extractStridedSliceSupportsMMAMatrixType` is documented as
/// `mma.sync`-specific, so this conversion presumably belongs to that path and
/// records its result in `valueMapping` -- confirm.
static LogicalResult
convertExtractStridedSlice(RewriterBase &rewriter,
                           vector::ExtractStridedSliceOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {}

/// Convert the vector contraction `op` on the MMA matrix path.
/// NOTE(review): presumably emits a GPU-dialect MMA compute op using operands
/// looked up in `valueMapping`, and records the result there -- confirm.
static LogicalResult
convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {}

/// Convert the vector contraction `op` on the mma.sync path.
/// NOTE(review): presumably emits `nvgpu.mma.sync` using operands looked up in
/// `valueMapping`, and records the result there -- confirm.
static LogicalResult
convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
///
/// \param rewriter     Rewriter supplied by the caller.
/// \param op           Constant op to convert.
/// \param valueMapping Value-to-value map shared across the conversion;
///                     presumably original value -> converted value (confirm).
static LogicalResult
convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {}

/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
///
/// \param rewriter     Rewriter supplied by the caller.
/// \param op           Broadcast op to convert.
/// \param valueMapping Value-to-value map shared across the conversion;
///                     presumably original value -> converted value (confirm).
static LogicalResult
convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
                   llvm::DenseMap<Value, Value> &valueMapping) {}

// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
//
// `loop` is the loop being replaced and `newInitArgs` are the extra operands
// for the new loop, which is returned.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
                                               scf::ForOp loop,
                                               ValueRange newInitArgs) {}

/// Convert the scf.for `op` as part of the MMA conversion.
/// NOTE(review): presumably rebuilds the loop via
/// `replaceForOpWithNewSignature` so that converted values recorded in
/// `valueMapping` are carried through the loop -- confirm.
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {}

/// Convert the scf.yield `op` as part of the MMA conversion.
/// NOTE(review): `replaceForOpWithNewSignature` leaves the yield un-updated, so
/// this presumably adds the converted values from `valueMapping` to the yield
/// -- confirm.
static LogicalResult
convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
               llvm::DenseMap<Value, Value> &valueMapping) {}

/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
///
/// \param rewriter     Rewriter supplied by the caller.
/// \param op           Elementwise operation to convert.
/// \param opType       MMA elementwise kind to use for the converted op.
/// \param valueMapping Value-to-value map shared across the conversion;
///                     presumably original value -> converted value (confirm).
static LogicalResult
convertElementwiseOp(RewriterBase &rewriter, Operation *op,
                     gpu::MMAElementwiseOp opType,
                     llvm::DenseMap<Value, Value> &valueMapping) {}

/// Populate `patterns` with the preparation rewrites for the vector-to-MMA
/// conversion.
/// NOTE(review): presumably registers the `PrepareContractToGPUMMA` and
/// `CombineTransferReadOpTranspose` patterns declared in this file, with
/// `useNvGpu` selecting the variant -- confirm.
void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
                                              bool useNvGpu) {}

/// Convert vector operations under `rootOp` to MMA operations.
/// NOTE(review): presumably the GPU-dialect (non-NVGPU) entry point, driving
/// the `convert*Op` helpers above over the ops returned by `getOpToConvert` --
/// confirm.
LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
                                          Operation *rootOp) {}

/// Convert vector operations under `rootOp` to NVVM-compatible mma.sync form.
/// NOTE(review): presumably the NVGPU entry point, driving the `*MmaSync`,
/// `*ToLoads` and `*ToStores` helpers above -- confirm.
LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
                                                         Operation *rootOp) {}

namespace {

/// Pass wrapper for the vector-to-GPU conversion, deriving from the generated
/// `impl::ConvertVectorToGPUBase`.
/// NOTE(review): the struct declares no members of its own here; any
/// pass-specific behavior would have to come from the base class -- confirm.
struct ConvertVectorToGPUPass
    : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {};

} // namespace

/// Factory for the vector-to-GPU conversion pass.
/// NOTE(review): presumably constructs a `ConvertVectorToGPUPass` configured
/// with `useNvGpu` -- confirm.
std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {}