#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
usingnamespacemlir;
usingnamespacemlir::linalg;
usingnamespacemlir::nvgpu;
usingnamespacemlir::NVVM;
usingnamespacemlir::transform;
// Debug helper macros. Their expansions are elided in this view.
// NOTE(review): DEBUG_TYPE presumably names this file's debug category and
// DBGS/DBGSNL/LDBG presumably wrap debug-stream output — confirm against the
// full source.
#define DEBUG_TYPE …
#define DBGS() …
#define DBGSNL() …
#define LDBG(X) …
/// Receives a type converter and a rewrite pattern set, both by mutable
/// reference.
/// NOTE(review): body elided in this view; judging by the name, presumably
/// adds the NVGPU-to-NVVM conversion patterns to `patterns` — confirm.
void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) { … }
/// Takes the type-converter builder interface and returns a LogicalResult.
/// NOTE(review): body elided in this view; presumably checks that `builder`
/// produces a type converter this op can work with — confirm the exact
/// requirement against the full source.
LogicalResult
transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
transform::TypeConverterBuilderOpInterface builder) { … }
/// Memory-effects hook for CreateAsyncGroupsOp; receives the output list of
/// effect instances by reference.
/// NOTE(review): body elided in this view; which effects are appended to
/// `effects` cannot be determined from here.
void transform::CreateAsyncGroupsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { … }
/// Per-target entry point of CreateAsyncGroupsOp: takes the rewriter, one
/// target operation, the per-target result list, and the transform state, and
/// returns a DiagnosedSilenceableFailure.
/// NOTE(review): body elided in this view; presumably creates async groups
/// under `target` — confirm.
DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
TransformRewriter &rewriter, Operation *target,
ApplyToEachResultList &results, TransformState &state) { … }
/// File-local predicate over a memref type.
/// NOTE(review): body elided; presumably true when `type` is in the default
/// memory space — how "default" is detected is not visible here.
static bool hasDefaultMemorySpace(BaseMemRefType type) { … }
/// File-local predicate over a memref type.
/// NOTE(review): body elided; presumably true when `type` is in the GPU
/// shared (workgroup) memory space — confirm the exact attribute checked.
static bool hasSharedMemorySpace(BaseMemRefType type) { … }
/// File-local helper mapping an operation to a Value.
/// NOTE(review): body elided; presumably returns the value `op` loads from
/// global memory, and likely a null Value when `op` is not such a load —
/// confirm.
static Value getValueLoadedFromGlobal(Operation *op) { … }
/// File-local predicate over an operation and a value.
/// NOTE(review): body elided; presumably true when `op` stores `v` into
/// shared memory — confirm which store ops are recognized.
static bool isStoreToShared(Operation *op, Value v) { … }
/// File-local predicate over an operation.
/// NOTE(review): body elided; presumably true when `op` loads from global
/// memory and the loaded value is then stored to shared memory — confirm.
static bool isLoadFromGlobalStoredToShared(Operation *op) { … }
/// File-local helper taking an scf.for loop and an output pointer set
/// (`ops`, passed by mutable reference); returns a LogicalResult.
/// NOTE(review): body elided; presumably fills `ops` with the operations of
/// `forOp` that belong to pipeline stage 0 and fails when the loop cannot be
/// handled — confirm the failure conditions.
static LogicalResult
collectStage0PipeliningOps(scf::ForOp forOp,
llvm::SmallPtrSet<Operation *, 16> &ops) { … }
/// File-local helper with the (builder, op, pipeliner part, iteration)
/// parameter prefix plus an extra `depth` argument.
/// NOTE(review): body elided; presumably adjusts the number of in-flight
/// groups on async-wait ops depending on `part`, `iteration` and `depth` —
/// confirm the exact formula against the full source.
static void
setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
scf::PipeliningOption::PipelinerPart part,
unsigned iteration, unsigned depth) { … }
/// File-local helper that receives the loop, an output vector of
/// (operation, stage index) pairs, the pipeline depth, and a set of stage-0
/// operations.
/// NOTE(review): body elided; presumably appends every relevant op of `forOp`
/// to `opsWithPipelineStages` with its assigned stage — confirm how stages
/// are chosen for ops outside `stage0Ops`.
static void getPipelineStages(
scf::ForOp forOp,
std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) { … }
/// File-local helper taking a rewriter, an operation and a predicate value;
/// returns an Operation pointer.
/// NOTE(review): body elided; presumably returns a version of `op` guarded by
/// `predicate` (possibly `op` itself when no predication is needed) — confirm.
static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
Operation *op, Value predicate) { … }
/// File-local helper returning a (diagnostic, scf.for) tuple for the given
/// loop, pipeline `depth`, and `epiloguePeeling` flag.
/// NOTE(review): body elided; presumably software-pipelines the shared-memory
/// copies in `forOp` and returns the resulting loop on success — confirm what
/// the second tuple element holds on failure.
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
bool epiloguePeeling) { … }
/// Per-target entry point of PipelineSharedMemoryCopiesOp, applied to a
/// single scf.for loop; returns a DiagnosedSilenceableFailure.
/// NOTE(review): body elided; which values are pushed to `results` cannot be
/// determined from here.
DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
TransformRewriter &rewriter, scf::ForOp forOp,
ApplyToEachResultList &results, TransformState &state) { … }
/// Pair of affine expressions, privately inherited from
/// std::pair<AffineExpr, AffineExpr>; members are elided in this view.
/// NOTE(review): presumably the two expressions are the row and column index,
/// in that order — confirm against the accessors in the full source.
struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> { … };
/// Builder struct whose member definitions appear further down in this file
/// (buildMemRefLoads, buildMemRefStores, getIndexCalculators, buildMmaSync,
/// ...). The struct body is elided in this view.
struct MmaSyncBuilder { … };
/// File-local template helper parameterized on two callables, taking a
/// vector Value.
/// NOTE(review): body elided; presumably invokes `applyFn` for each element
/// position of `vector` and passes the outcome to `reduceFn` — confirm the
/// callable signatures against the full source.
template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
ReduceFn reduceFn) { … }
/// Returns a list of Values given a builder, location, lane id, source
/// memref, and an index-calculator callback.
/// NOTE(review): body elided; presumably emits one load from `memref` per
/// index produced by `indexFn` for `laneId` — confirm.
SmallVector<Value>
MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
OpFoldResult laneId, Value memref,
const IndexCalculator &indexFn) { … }
/// Returns a single Value given the same inputs as buildMemRefLoads plus a
/// target `vectorShape`.
/// NOTE(review): body elided; presumably assembles the loaded scalars into a
/// vector of shape `vectorShape` for use as an mma.sync operand — confirm.
Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) { … }
/// Returns a list of operations given the values to store, the lane id, the
/// destination memref, and an index-calculator callback.
/// NOTE(review): body elided; presumably emits one store per entry of
/// `toStore` at the index produced by `indexFn` and returns the created store
/// ops — confirm.
SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
Value memref, const IndexCalculator &indexFn) { … }
/// Returns a list of operations given a vector value to store, the lane id,
/// the destination memref, an index calculator, and the vector shape.
/// NOTE(review): body elided; presumably decomposes `vectorToStore` into
/// elements and stores each one to `memref` — confirm.
SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) { … }
/// File-local helper returning a tuple of three owned int64_t vectors built
/// from three shape array-refs.
/// NOTE(review): body elided; presumably just copies `lhs`, `rhs` and `res`
/// into the tuple in that order — confirm.
static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
SmallVector<int64_t>>
makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
ArrayRef<int64_t> res) { … }
/// Returns an MmaSyncInfo, or failure, for the given op shape and element
/// types.
/// NOTE(review): body elided; presumably selects index calculators for the
/// supported (shape, element type) combinations and fails otherwise — the
/// set of supported combinations is not visible here.
FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
TypeRange elementalTypes) { … }
/// Returns an operation pointer, or failure, for the given linalg op.
/// NOTE(review): body elided; presumably emits the mma.sync-based replacement
/// for `linalgOp` and returns the created op — confirm.
FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) { … }
/// Per-target entry point of RewriteMatmulAsMmaSyncOp, applied to a single
/// linalg op; returns a DiagnosedSilenceableFailure.
/// NOTE(review): body elided; presumably rewrites a matmul-like `linalgOp`
/// via MmaSyncBuilder — confirm which ops are accepted and what is reported
/// on failure.
DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
transform::TransformRewriter &rewriter, LinalgOp linalgOp,
transform::ApplyToEachResultList &results,
transform::TransformState &state) { … }
/// Builder struct whose member definitions appear further down in this file
/// (buildPredicateLoadsOnThread0, buildAndInitBarrierInSharedMemory,
/// buildGlobalMemRefDescriptor, buildTmaAsyncLoad, buildBarrierArriveTx,
/// buildTryWaitParity). The struct body is elided in this view.
struct HopperBuilder { … };
/// Returns a list of operations given tensor-map descriptors, shared-memory
/// buffers, and an mbarrier group.
/// NOTE(review): body elided; presumably emits the loads so that only
/// thread 0 issues them, and `globalDescriptors` / `sharedMemBuffers` are
/// presumably paired element-wise — confirm.
SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
TypedValue<nvgpu::MBarrierGroupType> barrier) { … }
/// File-local helper returning an Attribute built with `b`.
/// NOTE(review): body elided; presumably the attribute denotes the GPU shared
/// (workgroup) address space — confirm its concrete kind.
static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) { … }
/// Returns an mbarrier-group value given a thread count.
/// NOTE(review): body elided; presumably creates the barrier in shared memory
/// and initializes it for `numThreads` participants — confirm.
TypedValue<nvgpu::MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) { … }
/// Returns a tensor-map descriptor value for the given memref and GPU launch
/// op.
/// NOTE(review): body elided; how `launchOp` is used (e.g. as an insertion
/// anchor) is not visible here — confirm.
TypedValue<nvgpu::TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
gpu::LaunchOp launchOp) { … }
/// Returns an OpFoldResult and takes `loadOps` as an output list by
/// reference.
/// NOTE(review): body elided; presumably emits an async TMA load from
/// `globalDesc` into `sharedMemref` tracked by `barrier`, records the created
/// op in `loadOps`, and returns a size-like quantity for the transfer —
/// confirm what the returned value represents.
OpFoldResult HopperBuilder::buildTmaAsyncLoad(
TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
TypedValue<MemRefType> sharedMemref,
TypedValue<nvgpu::MBarrierGroupType> barrier,
SmallVectorImpl<Operation *> &loadOps) { … }
/// Takes an mbarrier group and a list of mixed static/dynamic sizes.
/// NOTE(review): body elided; presumably emits an arrive-with-expected-
/// transaction-count on `barrier` derived from `mixedSizes` — confirm how the
/// sizes are combined.
void HopperBuilder::buildBarrierArriveTx(
TypedValue<nvgpu::MBarrierGroupType> barrier,
ArrayRef<OpFoldResult> mixedSizes) { … }
/// Takes an mbarrier group.
/// NOTE(review): body elided; presumably emits a try-wait-parity on `barrier`
/// — the parity and timeout operands used are not visible here.
void HopperBuilder::buildTryWaitParity(
TypedValue<nvgpu::MBarrierGroupType> barrier) { … }
/// Builder that publicly derives from HopperBuilder; its `rewrite` member is
/// defined right after this struct. The struct body is elided in this view.
struct CopyBuilder : public HopperBuilder { … };
/// Takes a list of copy operations and returns a list of operations.
/// NOTE(review): body elided; presumably replaces `copyOps` with TMA-based
/// async loads and returns the newly created ops — confirm.
SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) { … }
/// Whole-payload entry point of RewriteCopyAsTmaOp: unlike the applyToOne
/// hooks in this file it receives no single target, only the rewriter,
/// the transform results, and the transform state.
/// NOTE(review): body elided; presumably gathers the payload copy ops and
/// rewrites them via CopyBuilder — confirm.
DiagnosedSilenceableFailure
transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) { … }
namespace {
/// File-local (anonymous-namespace) transform dialect extension class; it
/// derives from transform::TransformDialectExtension parameterized on itself.
/// NOTE(review): class body elided; presumably declares dependent dialects
/// and registers the NVGPU transform ops — confirm.
class NVGPUTransformDialectExtension
: public transform::TransformDialectExtension<
NVGPUTransformDialectExtension> { … };
}
#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry ®istry) { … }