llvm/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// SPIRV Cooperative Matrix ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"

#include <cassert>

namespace mlir {
//===----------------------------------------------------------------------===//
// Patterns and helpers.
//===----------------------------------------------------------------------===//

/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
/// when the elementwise op directly supports with cooperative matrix type.
/// Returns false if cannot.
///
/// See SPV_KHR_cooperative_matrix for supported elementwise ops.
static bool createElementwiseOp(ConversionPatternRewriter &builder,
                                gpu::SubgroupMmaElementwiseOp op, Type coopType,
                                ValueRange operands) {}

bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {}

namespace {
/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
/// matrix ops.
struct WmmaConstantOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the default case.
struct WmmaElementwiseOpToSPIRVDefaultLowering final
    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// matrix times scalar case.
struct WmmaElementwiseOpToSPIRVScalarMulLowering final
    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {};
} // namespace

//===----------------------------------------------------------------------===//
// SPV_KHR_cooperative_matrix
//===----------------------------------------------------------------------===//

namespace khr {
namespace {

/// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
/// dialect.
struct WmmaLoadOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {};

/// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
/// dialect.
struct WmmaStoreOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {};

/// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
/// dialect.
struct WmmaMmaOpToSPIRVLowering final
    : OpConversionPattern<gpu::SubgroupMmaComputeOp> {};

} // namespace
} // namespace khr
} // namespace mlir

void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
    const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {}

/// Registers a conversion from `gpu::MMAMatrixType` to the subgroup-scoped
/// `spirv::CooperativeMatrixType` with the same element type and 2-D shape.
/// The MMA operand kind selects the KHR matrix use: "AOp" -> MatrixA,
/// "BOp" -> MatrixB, anything else (the accumulator) -> MatrixAcc.
void mlir::populateMMAToSPIRVCoopMatrixTypeConversion(
    mlir::SPIRVTypeConverter &typeConverter) {
  typeConverter.addConversion([](gpu::MMAMatrixType type) -> Type {
    ArrayRef<int64_t> shape = type.getShape();
    assert(shape.size() == 2 && "MMA matrices are expected to be 2-D");

    auto use =
        llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
            .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
            .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
            .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);

    return spirv::CooperativeMatrixType::get(type.getElementType(), shape[0],
                                             shape[1], spirv::Scope::Subgroup,
                                             use);
  });
}