#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include <cassert>
namespace mlir {
/// File-local helper (`static`) that, per its name, creates the SPIR-V
/// replacement for the GPU subgroup MMA elementwise op `op`, emitting through
/// `builder`, with `coopType` as the target type and `operands` as the inputs.
///
/// NOTE(review): the body is not visible in this view. The `bool` result
/// presumably reports whether a replacement was created (false meaning the
/// elementwise kind is not directly supported); confirm against the
/// implementation before relying on it.
static bool createElementwiseOp(ConversionPatternRewriter &builder,
gpu::SubgroupMmaElementwiseOp op, Type coopType,
ValueRange operands) { … }
/// Predicate over `operands`. By its name, it presumably returns true when
/// every value in `operands` shares one SPIR-V cooperative matrix type.
/// Confirm against the body, which is not visible here, including how an
/// empty range is treated.
///
/// NOTE(review): unlike `createElementwiseOp`, this function is neither
/// `static` nor in an anonymous namespace, so it has external linkage in
/// namespace `mlir`. If no header declares it, consider making it `static`
/// to keep it file-local.
bool allOperandsHaveSameCoopMatrixType(ValueRange operands) { … }
// File-local (anonymous-namespace) conversion patterns for GPU subgroup MMA
// constant and elementwise ops.
namespace {
/// Conversion pattern for `gpu::SubgroupMmaConstantMatrixOp`.
/// NOTE(review): body elided in this view; presumably it lowers the constant
/// matrix op to a SPIR-V cooperative-matrix value. Confirm.
struct WmmaConstantOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> { … };
/// "Default" conversion pattern for `gpu::SubgroupMmaElementwiseOp`.
/// NOTE(review): the same op is also matched by the scalar-multiply pattern
/// below. Confirm in the implementation how the two are ordered or
/// prioritized (for example, via pattern benefit).
struct WmmaElementwiseOpToSPIRVDefaultLowering final
: OpConversionPattern<gpu::SubgroupMmaElementwiseOp> { … };
/// Conversion pattern for `gpu::SubgroupMmaElementwiseOp` that, per its name,
/// handles the scalar-multiplication case specially.
struct WmmaElementwiseOpToSPIRVScalarMulLowering final
: OpConversionPattern<gpu::SubgroupMmaElementwiseOp> { … };
} // namespace
// Patterns grouped under `khr`. Presumably these target the KHR cooperative
// matrix extension, matching the "CoopMatrixKHR" in the name of
// `populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns`. Confirm.
namespace khr {
namespace {
/// Conversion pattern for `gpu::SubgroupMmaLoadMatrixOp`.
struct WmmaLoadOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> { … };
/// Conversion pattern for `gpu::SubgroupMmaStoreMatrixOp`.
struct WmmaStoreOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> { … };
/// Conversion pattern for `gpu::SubgroupMmaComputeOp` (the matrix
/// multiply-accumulate op, per the pattern's "Mma" name).
struct WmmaMmaOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaComputeOp> { … };
} // namespace
} // namespace khr
}
/// Populates `patterns` with the GPU WMMA to SPIR-V cooperative matrix (KHR)
/// conversion patterns. `converter` is taken by const reference, so it is
/// only read, presumably as the type converter for the patterns.
/// NOTE(review): body elided here. Presumably this registers the patterns
/// defined earlier in this file, both the constant/elementwise ones and the
/// `khr::` load/store/mma ones. Confirm the exact set.
void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
const SPIRVTypeConverter &converter, RewritePatternSet &patterns) { … }
/// Extends `typeConverter`, which is taken by non-const reference and so may
/// be modified. Per the function name, it presumably registers a conversion
/// from GPU MMA matrix types to SPIR-V cooperative matrix types. Confirm
/// against the body, which is not visible here.
void mlir::populateMMAToSPIRVCoopMatrixTypeConversion(
mlir::SPIRVTypeConverter &typeConverter) { … }