// llvm/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp

//===- MMAUtils.cpp - MLIR NVGPU dialect utils for MMA operations----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

usingnamespacemlir;
usingnamespacemlir::nvgpu;

/// There are always 4 threads per [128|256|512] bit row.
static constexpr int64_t kThreadsPerRow = 4;
/// A tile is 8 rows tall, i.e. tiles are 8 x [128|256|512] bits.
static constexpr int64_t kNumRowsPerTile = 8;

/// Predicate over a `MatMulOperandRole` value.
///
/// NOTE(review): the body is empty in this view, so control falls off the end
/// of a non-void function without returning a value. Going by the name, this
/// presumably reports whether `operandType` designates the accumulator/result
/// operand — confirm against the complete implementation.
static bool isAccumulatorOrResult(MatMulOperandRole operandType) {}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
///
/// `type` describes the warp-level matrix operand being queried.
///
/// NOTE(review): the body is empty in this view, so no value is actually
/// returned; the computation itself cannot be documented from here.
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {}

/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
///
/// The result is a two-element array; presumably one tile count per dimension
/// of `operandShape` — TODO confirm the ordering. `lineSizeBits` is presumably
/// the width of one tile row in bits (128, 256 or 512) and `elementType`
/// supplies the element bit width — verify against the full implementation.
///
/// NOTE(review): the body is empty in this view, so no value is actually
/// returned.
static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                           Type elementType,
                                           int64_t lineSizeBits) {}

/// Returns the first user of the `op` that is vector.contract. If no
/// vector.contract user exists, return failure.
///
/// NOTE(review): the body is empty in this view, so neither the success nor
/// the failure path is actually implemented here.
FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {}

/// Produces a `WarpMatrixInfo` for `op`, or failure.
///
/// NOTE(review): the body is empty in this view. Presumably this inspects `op`
/// (and possibly its users) to determine the operand's vector type and its
/// role in a matrix multiply — confirm against the complete implementation.
FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {}

/// Returns a tile width, in bits, for the warp-level matrix operand `type`.
///
/// NOTE(review): the body is empty in this view. Judging by the other comments
/// in this file the result is presumably one of 128, 256 or 512 — TODO confirm
/// how the width is selected.
int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type) {}

/// Produces a `FragmentElementInfo` for the warp-level matrix operand `type`,
/// or failure.
///
/// NOTE(review): the body is empty in this view. Presumably the result
/// describes the per-thread register type used for this operand by an
/// mma.sync operation, with failure for unsupported element types — confirm
/// against the complete implementation.
FailureOr<FragmentElementInfo>
nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {}

/// Builds an `AffineMap` from the given tile/operand description.
///
/// NOTE(review): the body is empty in this view. Going by the name, the map
/// presumably takes a register index (derived from `logicalValueId` and
/// `elementsPerRegister`) to the offset of the tile containing that register
/// within `operandShape`; `lineSize` is presumably the tile row width in bits
/// and `isAccumulator` presumably selects the accumulator layout — all to be
/// confirmed against the complete implementation.
static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {}

/// Produces an `AffineMap` for the warp-level matrix operand `fragmentType`,
/// or failure.
///
/// NOTE(review): the body is empty in this view. Going by the name, the map
/// presumably takes a (lane id, value id) pair to the coordinate of the
/// corresponding element in the operand matrix — TODO confirm. How `builder`
/// and `loc` are used cannot be determined from here.
FailureOr<AffineMap>
nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                         const WarpMatrixInfo &fragmentType) {}

/// Produces `LdMatrixParams` for the warp-level matrix operand `type`, or
/// failure.
///
/// NOTE(review): the body is empty in this view. `transpose` presumably
/// requests the transposed variant of the load — confirm against the complete
/// implementation, along with the conditions under which failure is returned.
FailureOr<nvgpu::LdMatrixParams>
nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {}

/// Produces an `AffineMap` for the ldmatrix configuration in `params`, or
/// failure.
///
/// NOTE(review): the body is empty in this view. Going by the name, the map
/// presumably takes a lane id to the matrix coordinate that lane supplies to
/// the ldmatrix operation — TODO confirm. How `builder` and `loc` are used
/// cannot be determined from here.
FailureOr<AffineMap>
nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                      const LdMatrixParams &params) {}

/// Overload for `vector::TransferReadOp`.
///
/// NOTE(review): the body is empty in this view, so no value is actually
/// returned. Going by the name, this presumably reports whether the transfer
/// read `op` satisfies the preconditions for lowering to a warp-level matrix
/// operation — the specific checks must be confirmed against the complete
/// implementation.
bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {}

/// Overload for `vector::TransferWriteOp`.
///
/// NOTE(review): the body is empty in this view, so no value is actually
/// returned. Going by the name, this presumably reports whether the transfer
/// write `op` satisfies the preconditions for lowering to a warp-level matrix
/// operation — the specific checks must be confirmed against the complete
/// implementation.
bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {}