#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <type_traits>
usingnamespacemlir;
usingnamespacemlir::gpu;
usingnamespacemlir::transform;
usingnamespacemlir::transform::gpu;
// Debug-logging macros for this file. Their definitions are elided in this
// view. NOTE(review): presumably DEBUG_TYPE / DEBUG_TYPE_ALIAS are the
// -debug-only tags, and DBGS()/LDBG(X)/DBGS_ALIAS() wrap llvm::dbgs()
// (llvm/Support/Debug.h is included above) -- confirm against the real text.
#define DEBUG_TYPE …
#define DEBUG_TYPE_ALIAS …
#define DBGS() …
#define LDBG(X) …
#define DBGS_ALIAS() …
/// Pattern-population hook of ApplyGPUToNVVMConversionPatternsOp: receives the
/// op's TypeConverter and the RewritePatternSet to extend.
/// NOTE(review): body elided here; presumably adds GPU-to-NVVM conversion
/// patterns (GPUToNVVM/GPUToNVVMPass.h is included) -- confirm.
void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) { … }
/// Verification hook: decides, via LogicalResult, whether the type converter
/// described by `builder` is acceptable for ApplyGPUToNVVMConversionPatternsOp.
/// NOTE(review): body elided; presumably requires an LLVM type converter
/// (LLVMCommon/TypeConverter.h is included) -- confirm.
LogicalResult
transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
transform::TypeConverterBuilderOpInterface builder) { … }
/// Pattern-population hook of ApplyGPUWwmaToNVVMConversionPatternsOp: receives
/// the op's TypeConverter and the RewritePatternSet to extend.
/// NOTE(review): body elided; presumably adds the GPU WMMA -> NVVM lowering
/// patterns -- confirm.
void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) { … }
/// Verification hook: decides, via LogicalResult, whether the type converter
/// described by `builder` is acceptable for
/// ApplyGPUWwmaToNVVMConversionPatternsOp. Body elided in this view.
LogicalResult
transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
transform::TypeConverterBuilderOpInterface builder) { … }
/// Pattern-population hook of ApplyGPUSubgroupReduceToNVVMConversionPatternsOp:
/// receives the op's TypeConverter and the RewritePatternSet to extend.
/// NOTE(review): body elided; presumably adds subgroup-reduce -> NVVM
/// lowering patterns -- confirm.
void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
populatePatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) { … }
/// Verification hook: decides, via LogicalResult, whether the type converter
/// described by `builder` is acceptable for
/// ApplyGPUSubgroupReduceToNVVMConversionPatternsOp. Body elided in this view.
LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) { … }
/// Pattern-population hook of ApplyGPURewritePatternsOp; receives the
/// RewritePatternSet to extend (no type converter).
/// NOTE(review): body elided; presumably adds GPU dialect rewrite patterns
/// (GPU/Transforms/Passes.h is included) -- confirm.
void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) { … }
/// File-local helper: returns an optional list of int64_t indices for the
/// given vector::ContractionOp.
/// NOTE(review): body elided; from the name, presumably the loop order used
/// when unrolling `contract` for subgroup MMA, with std::nullopt meaning
/// "no preferred order" -- confirm.
static std::optional<SmallVector<int64_t>>
gpuMmaUnrollOrder(vector::ContractionOp contract) { … }
/// File-local helper: given an operation and three int64_t sizes `m`, `n`,
/// `k`, returns an optional shape vector.
/// NOTE(review): body elided; presumably the native (unroll target) vector
/// shape for `op` given an m x n x k MMA tile, with std::nullopt for ops it
/// does not handle -- confirm.
static std::optional<SmallVector<int64_t>>
getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) { … }
/// Pattern-population hook of ApplyUnrollVectorsSubgroupMmaOp.
/// NOTE(review): body elided; presumably adds vector unrolling patterns
/// configured with gpuMmaUnrollOrder / getSubgroupMmaNativeVectorSize
/// (both defined just above) -- confirm.
void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
RewritePatternSet &patterns) { … }
/// Pattern-population hook of EliminateBarriersOp.
/// NOTE(review): body elided; presumably adds the GPU barrier-elimination
/// patterns -- confirm.
void EliminateBarriersOp::populatePatterns(RewritePatternSet &patterns) { … }
// File-local mapping-kind descriptors: a common MappingKind base plus block-
// and thread-level specializations.
// NOTE(review): struct bodies are elided; presumably these are the
// `MappingKindType` template arguments of checkMappingAttributeTypes /
// verifyGpuMapping below -- confirm.
namespace {
struct MappingKind { … };
struct BlockMappingKind : MappingKind { … };
struct ThreadMappingKind : MappingKind { … };
}
/// Builds a DiagnosedSilenceableFailure for `target` carrying `message`.
/// `transformOp` is optional. NOTE(review): body elided; presumably the
/// diagnostic is emitted on the transform op when present and on `target`
/// otherwise -- confirm.
static DiagnosedSilenceableFailure
definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
Operation *target, const Twine &message) { … }
/// Checks the mapping attributes of `forallOp` against `MappingKindType`,
/// reporting problems as a DiagnosedSilenceableFailure (optionally attributed
/// to `transformOp`). Body elided in this view.
template <typename MappingKindType>
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
scf::ForallOp forallOp) { … }
/// Verifies that `forallOp` carries a GPU mapping suitable for
/// `MappingKindType`, returning a DiagnosedSilenceableFailure. Body elided.
/// NOTE(review): presumably builds on checkMappingAttributeTypes -- confirm.
template <typename MappingKindType>
static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
scf::ForallOp forallOp) { … }
/// Output record of rewriteOneForallCommonImpl, which takes it as its
/// `ForallRewriteResult &result` out-parameter. Members elided in this view.
struct ForallRewriteResult { … };
/// Generic helper over an op type `OpTy` and a container `parent` (an
/// Operation or a Block) that rewrites with `replacement` at `loc`, guided by
/// `availableMappingSizes`.
/// NOTE(review): body elided; from the name, presumably replaces `OpTy` id
/// ops under `parent` whose mapping dimension has size 1 with `replacement`
/// -- confirm.
template <typename OpTy, typename OperationOrBlock>
static void
replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
OperationOrBlock *parent, Value replacement,
ArrayRef<int64_t> availableMappingSizes) { … }
/// Shared rewrite step for one scf.forall: uses `gpuIdBuilder` and
/// `availableMappingSizes`, writes its outcome into `result`, and reports
/// failures as a DiagnosedSilenceableFailure (optionally via `transformOp`).
/// Body elided in this view.
static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) { … }
/// Public entry point: maps `forallOp` to GPU blocks using `gpuIdBuilder`.
/// `gridDims` is a mutable SmallVectorImpl, so it is an in/out parameter.
/// NOTE(review): body elided; whether an empty `gridDims` is filled in from
/// the forall is not visible here -- confirm.
DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
RewriterBase &rewriter, TransformOpInterface transformOp,
scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
const GpuIdBuilder &gpuIdBuilder) { … }
/// Public helper: searches under `target` for a top-level scf.forall and
/// returns it through the `topLevelForallOp` out-parameter; failures are
/// reported as a DiagnosedSilenceableFailure attributed to `transformOp`.
/// Body elided in this view.
DiagnosedSilenceableFailure
mlir::transform::gpu::findTopLevelForallOp(Operation *target,
scf::ForallOp &topLevelForallOp,
TransformOpInterface transformOp) { … }
/// Transform-op application of MapForallToBlocks on a single `target`;
/// results are appended to `results`. Body elided.
/// NOTE(review): presumably drives findTopLevelForallOp +
/// mapForallToBlocksImpl -- confirm.
DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
ApplyToEachResultList &results, transform::TransformState &state) { … }
/// Op verifier for MapForallToBlocks (attribute/operand consistency checks).
/// Body elided in this view.
LogicalResult transform::MapForallToBlocks::verify() { … }
/// Compares the parallel iteration counts of `forallOp`
/// (`numParallelIterations`) against the available block or grid sizes,
/// scaled by `factor`. `useLinearMapping` defaults to false. Returns a
/// DiagnosedSilenceableFailure describing any mismatch. Body elided.
static DiagnosedSilenceableFailure checkMappingSpec(
std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes,
int factor, bool useLinearMapping = false) { … }
/// Selects or constructs the GpuIdBuilder for `forallOp` given `blockSizes`
/// and `warpSize`, returning it through the `gpuIdBuilder` out-parameter.
/// Body elided.
/// NOTE(review): presumably chooses between thread/warp/linear builders
/// based on the forall's mapping attributes -- confirm.
static DiagnosedSilenceableFailure
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
int64_t warpSize, GpuIdBuilder &gpuIdBuilder) { … }
/// Public entry point: maps a single `forallOp` to GPU threads for the given
/// `blockSizes` and `warpSize`. The `syncAfterDistribute` flag is passed in.
/// NOTE(review): body elided; presumably the flag inserts a barrier after
/// the distributed body -- confirm.
DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
bool syncAfterDistribute) { … }
/// Public entry point: maps the scf.forall ops nested under `target` to GPU
/// threads for the given `blockDims` and `warpSize`. Body elided.
/// NOTE(review): presumably walks `target` and calls mapOneForallToThreadsImpl
/// for each forall -- confirm.
DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
bool syncAfterDistribute) { … }
/// Transform-op application of MapNestedForallToThreads on a single `target`;
/// results are appended to `results`. Body elided.
/// NOTE(review): presumably forwards to mapNestedForallToThreadsImpl -- confirm.
DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
transform::TransformRewriter &rewriter, Operation *target,
ApplyToEachResultList &results, TransformState &state) { … }
// File-local transform-dialect extension for the GPU transform ops. It uses
// the CRTP base transform::TransformDialectExtension<Self>. Class body elided.
// NOTE(review): presumably registers the ops generated in
// GPUTransformOps.cpp.inc and declares dependent dialects -- confirm.
namespace {
class GPUTransformDialectExtension
: public transform::TransformDialectExtension<
GPUTransformDialectExtension> { … };
}
// Pull in the TableGen-generated op class definitions; GET_OP_CLASSES selects
// the class-definition section of the generated .cpp.inc.
#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
void mlir::gpu::registerTransformDialectExtension(DialectRegistry ®istry) { … }