//===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file contains definitions of patterns to lower GPU Subgroup MMA ops to // NVVM Dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/TypeUtilities.h" usingnamespacemlir; namespace { /// Checks if all the operands of the op being lowered are of LLVM Types. The /// types are expected to be converted by the `LLVMTypeConverter` before the op /// is actually lowered. If the type of an operands is not already converted it /// hints a missing typeConversion and failure is returned in that case. static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter) { … } /// Error string to emit when an unimplemented WMMA variant is encountered. static constexpr StringRef kInvalidCaseStr = …; static NVVM::MMAFrag convertOperand(StringRef operandName) { … } static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) { … } /// This class implements the conversion of GPU MMA loadOp to wmma.load op /// in the NVVM dialect. The conversion not only emits the NVVM op but also /// emits code that is necessary to store the data in the destination memref /// after it has been loaded. struct WmmaLoadOpToNVVMLowering : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> { … }; /// This class implements the conversion of GPU MMA storeOp to wmma.store op /// in the NVVM dialect. The conversion not only emits the NVVM op but also /// emits code that is necessary to unpack the data in the source and /// convert the data in the format that is needed by the NVVM op. struct WmmaStoreOpToNVVMLowering : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> { … }; /// This class implements the conversion of GPU MMA computeOp to wmma.mma op /// in the NVVM dialect. struct WmmaMmaOpToNVVMLowering : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> { … }; /// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp. struct WmmaConstantOpToNVVMLowering : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> { … }; static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, Value rhs, bool isMin) { … } static Value createScalarOp(OpBuilder &builder, Location loc, gpu::MMAElementwiseOp op, ArrayRef<Value> operands) { … } /// Convert GPU MMA elementwise ops to extract + op + insert. struct WmmaElementwiseOpToNVVMLowering : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> { … }; } // namespace /// Return the LLVMStructureType corresponding to the MMAMatrixType `type`. LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) { … } void mlir::populateGpuWMMAToNVVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { … }