#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
usingnamespacemlir;
usingnamespacemlir::amx;
namespace {
/// File-local helper that takes a vector type `vType` and returns a pair of
/// `Value`s, with `rewriter`, `typeConverter` and `loc` available for
/// building them.
///
/// NOTE(review): the body is not visible in this view. Judging by the name,
/// the pair presumably holds the two tile dimensions derived from `vType` --
/// confirm the order and units of the two values against the implementation
/// before relying on them.
std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter,
VectorType vType, Location loc) { … }
/// File-local helper that takes the memref type `mType` and reports its
/// outcome as a `LogicalResult`.
///
/// NOTE(review): the body is not visible in this view. Presumably this
/// validates that the stride layout of `mType` is usable by the tile
/// load/store lowerings -- TODO confirm which layouts are rejected.
LogicalResult verifyStride(MemRefType mType) { … }
/// File-local helper that returns a single `Value` computed from the memref
/// type `mType` and the value `base`, with `rewriter`, `typeConverter` and
/// `loc` available for building it.
///
/// NOTE(review): the body is not visible in this view. Presumably the result
/// is the row stride of the memref behind `base`; whether it is expressed in
/// bytes or elements cannot be told from here -- confirm against the
/// implementation.
Value getStride(ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter, MemRefType mType,
Value base, Location loc) { … }
/// Conversion pattern for `TileZeroOp`, derived from
/// `ConvertOpToLLVMPattern<TileZeroOp>`.
///
/// NOTE(review): the members are not visible in this view; presumably the
/// pattern rewrites the op into its LLVM-level counterpart -- confirm the
/// exact replacement op in the implementation.
struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> { … };
/// Conversion pattern for `TileLoadOp`, derived from
/// `ConvertOpToLLVMPattern<TileLoadOp>`.
///
/// NOTE(review): the members are not visible in this view; presumably the
/// pattern rewrites the op into its LLVM-level counterpart -- confirm the
/// exact replacement op and any failure conditions in the implementation.
struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> { … };
/// Conversion pattern for `TileStoreOp`, derived from
/// `ConvertOpToLLVMPattern<TileStoreOp>`.
///
/// NOTE(review): the members are not visible in this view; presumably the
/// pattern rewrites the op into its LLVM-level counterpart -- confirm the
/// exact replacement op and any failure conditions in the implementation.
struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> { … };
/// Conversion pattern for `TileMulFOp`, derived from
/// `ConvertOpToLLVMPattern<TileMulFOp>`.
///
/// NOTE(review): the members are not visible in this view; presumably this
/// handles the floating-point tile multiply -- confirm which element types
/// are supported in the implementation.
struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> { … };
/// Conversion pattern for `TileMulIOp`, derived from
/// `ConvertOpToLLVMPattern<TileMulIOp>`.
///
/// NOTE(review): the members are not visible in this view; presumably this
/// handles the integer tile multiply -- confirm how operand signedness is
/// handled in the implementation.
struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> { … };
}
/// Public entry point taking the LLVM type converter `converter` and the
/// pattern set `patterns`.
///
/// NOTE(review): the body is not visible in this view. Presumably it adds
/// the AMX conversion patterns defined in this file to `patterns`,
/// constructing them with `converter` -- confirm the exact list there.
void mlir::populateAMXLegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) { … }
/// Public entry point that receives the LLVM conversion target `target` by
/// mutable reference.
///
/// NOTE(review): the body is not visible in this view. Presumably it sets
/// the legality of AMX dialect ops and their LLVM-level replacements on
/// `target` -- confirm which ops are marked legal/illegal.
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) { … }