//===-- MMAUtils.h - MLIR NVGPU dialect utilities for MMA operations-------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file provides utilities to assist in the lowering of other dialects // (e.g. Vector) to `nvgpu.mma.*` dialect operations. // //===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H #define MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Types.h" namespace mlir { namespace nvgpu { /// Represents the role of an operand in an MMA instruction: /// `result := matmul(A, B) + C` enum class MatMulOperandRole : int32_t { … }; /// Returns the first user of the `op` that is vector.contract. If no /// vector.contract user exists, return failure. FailureOr<vector::ContractionOp> getUserContract(Operation *op); /// Collects information about a warp-level matrix operand represented by a /// VectorType. struct WarpMatrixInfo { … }; /// If `op` is a `vector.transfer_write`, return the `WarpMatrixInfo` for the /// vector operand. If op is a `vector.transfer_read`, `vector.contraction`, or /// `arith.constant`, return the `WarpMatrixInfo` corresponding to the result. /// Otherwise, return failure. FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op); /// Returns the number of bits in a single tile row. It is either 128, 256, or /// 512 bits depending on the data type and` whether the operand is an /// accumulator/result operand int64_t inferTileWidthInBits(const WarpMatrixInfo &type); /// Specifies information about the registers which compose a matrix fragment /// according to the PTX documentation. struct FragmentElementInfo { … }; /// Returns a FragmentElementInfo struct describing the register types for the /// given matrix fragment type. FailureOr<FragmentElementInfo> getMmaSyncRegisterType(const WarpMatrixInfo &type); /// Returns an AffineMap which maps a two dimensions representing (laneId, /// logicalValueId) and returns two results representing offsets within a /// matrix operand. The offsets point to the values the thread is responsible /// for (AKA the matrix fragment values) during a warp-collective matrix /// operation. For a visual reference of this LaneId -> (row, col) mapping, /// please see NVIDIA's PTX documentation: /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma FailureOr<AffineMap> getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType); /// Encapsulates the parameters needed to lower a `nvgpu.ldmatrix` operation to /// `nvvm.ldmatrix`. struct LdMatrixParams { … }; /// Given `type` that contains info for a warp-matrix operand and whether or not /// the load is a transposed load, return the LdMatrixParams. FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type, bool transpose); /// Returns an AffineMap which maps a single dimension representing the laneId /// to two results representing offsets within the matrix operand that should /// be the pointer locations a thread should pass to the ldmatrix instruction. FailureOr<AffineMap> getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, const LdMatrixParams ¶ms); /// Returns whether the `vector.transfer_read` instruction can be interpreted /// as a warp-level cooperative matrix load operation. This function is meant to /// be used to establish whether `op` is part of a chain of such warp-level /// operations. bool canLowerToWarpMatrixOperation(vector::TransferReadOp op); /// Returns whether the `vector.transfer_write` instruction can be interpreted /// as a warp-level cooperative matrix store operation. This function is meant /// to be used to establish whether `op` is part of a chain of such warp-level /// operations. bool canLowerToWarpMatrixOperation(vector::TransferWriteOp op); } // namespace nvgpu } // namespace mlir #endif // MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H