// llvm/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR, and the LLVM IR dialect.  It also registers the dialect.
//
// The NVVM dialect only contains GPU specific additions on top of the general
// LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/NVVMDialect.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
#include <string>

usingnamespacemlir;
usingnamespaceNVVM;

#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

//===----------------------------------------------------------------------===//
// Printing/parsing for NVVM ops
//===----------------------------------------------------------------------===//

// File-local printing helper taking the printer `p` and the operation `op`.
// Judging only by its name, it presumably prints NVVM intrinsic-style ops.
// TODO confirm.
// NOTE(review): the body is empty in this copy, so a call currently prints
// nothing. The implementation appears to have been elided; restore it.
static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {}

// Custom assembly parser for VoteBallotOp. Expected syntax:
//   <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
// Parsed operands and types presumably go into `result`; confirm.
// NOTE(review): the body is empty in this copy. ParseResult is a non-void
// return type, so flowing off the closing brace is undefined behavior if this
// is called. The implementation appears to have been elided and must be
// restored.
ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {}

// Custom assembly printer for VoteBallotOp, writing to `p`. It is presumably
// the counterpart of VoteBallotOp::parse.
// NOTE(review): the body is empty in this copy, so this currently prints
// nothing. The implementation appears to have been elided; restore it.
void VoteBallotOp::print(OpAsmPrinter &p) {}

// Custom verifier hook for CpAsyncBulkTensorGlobalToSharedClusterOp. By MLIR
// convention, an op's `verify()` reports success/failure as a LogicalResult.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {}

// Custom verifier hook for CpAsyncBulkTensorSharedCTAToGlobalOp. By MLIR
// convention, an op's `verify()` reports success/failure as a LogicalResult.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {}

// Custom verifier hook for CpAsyncOp. By MLIR convention, an op's `verify()`
// reports success/failure as a LogicalResult.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult CpAsyncOp::verify() {}

// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
//
// Parameters:
//   operandElType - element type of the operand being classified.
//   isAccumulator - whether the operand is the accumulator operand.
// Returns an empty std::optional presumably when no PTX type corresponds to
// `operandElType`. TODO confirm.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {}

// File-local predicate on an MMATypes value. Judging by its name, it is
// presumably true for the 4-bit integer PTX types. TODO confirm.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
static bool isInt4PtxType(MMATypes type) {}

// File-local predicate on an MMATypes value. Judging by its name, it is
// presumably true for the 8-bit integer PTX types. TODO confirm.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
static bool isInt8PtxType(MMATypes type) {}

// File-local predicate on an MMATypes value. Judging by its name, it is
// presumably true for any integer PTX type, possibly built on isInt4PtxType /
// isInt8PtxType. TODO confirm.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
static bool isIntegerPtxType(MMATypes type) {}

// Returns an MMATypes value for this op. By its name, it is presumably the PTX
// type of the accumulator operand. TODO confirm.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
MMATypes MmaOp::accumPtxType() {}

// Returns an MMATypes value for this op. By its name, it is presumably the PTX
// type of the op's result. TODO confirm.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
MMATypes MmaOp::resultPtxType() {}

// Custom assembly printer for MmaOp, writing to `p`. It is presumably the
// counterpart of MmaOp::parse.
// NOTE(review): the body is empty in this copy, so this currently prints
// nothing. The implementation appears to have been elided; restore it.
void MmaOp::print(OpAsmPrinter &p) {}

// Builder for MmaOp.
//
// Parameters:
//   builder, result      - the OpBuilder and the OperationState to populate.
//   resultType           - type of the op's result.
//   operandA/B/C         - value ranges for the A, B and C operands.
//   shape                - the MMA shape as int64_t values.
//   b1Op, intOverflow    - optional MMAB1Op / MMAIntOverflow settings.
//   multiplicandPtxTypes - optional pair of MMATypes, one per multiplicand.
//   multiplicandLayouts  - optional pair of MMALayout, one per multiplicand.
// NOTE(review): the body is empty in this copy, so `result` is currently left
// unmodified. The implementation appears to have been elided; restore it.
// The signature also uses std::array, but <array> is not included directly.
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, std::optional<MMAB1Op> b1Op,
                  std::optional<MMAIntOverflow> intOverflow,
                  std::optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  std::optional<std::array<MMALayout, 2>> multiplicandLayouts) {}

// Custom assembly parser for MmaOp. Expected syntax:
// <operation> :=
//   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
//   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
//     `->` type($res)
// NOTE(review): the body is empty in this copy. ParseResult is a non-void
// return type, so flowing off the closing brace is undefined behavior if this
// is called. The implementation appears to have been elided and must be
// restored.
ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {}

// Custom verifier hook for MmaOp. By MLIR convention, an op's `verify()`
// reports success/failure as a LogicalResult.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult MmaOp::verify() {}

// Custom verifier hook for ShflOp. By MLIR convention, an op's `verify()`
// reports success/failure as a LogicalResult.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult ShflOp::verify() {}

// Infers a (Type, unsigned) pair for MMA fragment `frag` of PTX type `type`,
// given `nRow` x `nCol` dimensions, creating types in `context`.
// The meaning of the pair is not visible here. It is presumably the fragment
// element type and an element count. TODO confirm.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                   NVVM::MMAFrag frag, int nRow,
                                                   int nCol,
                                                   MLIRContext *context) {}

// File-local variant of NVVM::inferMMAType that takes the MMA problem size as
// `m`, `n`, `k` instead of row/column counts. It presumably maps (m, n, k) to
// the fragment's rows/columns and delegates to inferMMAType. TODO confirm.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
static std::pair<mlir::Type, unsigned>
inferMMATypeFromMNK(NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n,
                    int k, MLIRContext *context) {}

// Custom verifier hook for WMMALoadOp. By MLIR convention, an op's `verify()`
// reports success/failure as a LogicalResult.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult NVVM::WMMALoadOp::verify() {}

// Custom verifier hook for WMMAStoreOp. By MLIR convention, an op's `verify()`
// reports success/failure as a LogicalResult.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult NVVM::WMMAStoreOp::verify() {}

// Custom verifier hook for WMMAMmaOp. By MLIR convention, an op's `verify()`
// reports success/failure as a LogicalResult.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult NVVM::WMMAMmaOp::verify() {}

// Custom verifier hook for LdMatrixOp. By MLIR convention, an op's `verify()`
// reports success/failure as a LogicalResult.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult NVVM::LdMatrixOp::verify() {}

// Custom verifier hook for StMatrixOp. By MLIR convention, an op's `verify()`
// reports success/failure as a LogicalResult.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult NVVM::StMatrixOp::verify() {}

// Returns an int or failure for WGMMA input type `typeA`. By its name, the int
// is presumably the permitted K dimension, with failure for unsupported types.
// TODO confirm.
// NOTE(review): this is defined at namespace scope without `static`, so it has
// external linkage. If it is only used in this file, consider making it
// `static` to avoid cross-TU name clashes.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {}

// Checks a combination of WGMMA types: `typeD` (presumably the accumulator),
// `typeA` and `typeB`. By its name, it succeeds when the combination is
// allowed. TODO confirm.
// NOTE(review): this is defined at namespace scope without `static`, so it has
// external linkage. If it is only used in this file, consider making it
// `static` to avoid cross-TU name clashes.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
                                     NVVM::WGMMATypes typeA,
                                     NVVM::WGMMATypes typeB) {}

// Checks an N dimension `sizeN` against WGMMA input type `typeA`. By its name,
// it succeeds when `sizeN` is allowed for that type. TODO confirm.
// NOTE(review): this is defined at namespace scope without `static`, so it has
// external linkage. If it is only used in this file, consider making it
// `static` to avoid cross-TU name clashes.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {}

// Custom verifier hook for WgmmaMmaAsyncOp. By MLIR convention, an op's
// `verify()` reports success/failure as a LogicalResult.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {}

// Returns a std::string for this op. By its name, it is presumably the PTX
// text emitted for it. TODO confirm.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {}

// Collects (Value, PTXRegisterMod) pairs into `asmValues`, with `rewriter`
// available for creating values. These are presumably the operands referenced
// by the PTX from getPtx(). TODO confirm.
// NOTE(review): the body is empty in this copy, so `asmValues` is currently
// left unchanged. The implementation appears to have been elided; restore it.
void NVVM::WgmmaMmaAsyncOp::getAsmValues(
    RewriterBase &rewriter,
    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
        &asmValues) {}
// Custom verifier hook for FenceProxyOp. By MLIR convention, an op's
// `verify()` reports success/failure as a LogicalResult.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult NVVM::FenceProxyOp::verify() {}

// Custom verifier hook for FenceProxyAcquireOp. By MLIR convention, an op's
// `verify()` reports success/failure as a LogicalResult.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult NVVM::FenceProxyAcquireOp::verify() {}

// Custom verifier hook for FenceProxyReleaseOp. By MLIR convention, an op's
// `verify()` reports success/failure as a LogicalResult.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult NVVM::FenceProxyReleaseOp::verify() {}

// Custom verifier hook for SetMaxRegisterOp. By MLIR convention, an op's
// `verify()` reports success/failure as a LogicalResult.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult NVVM::SetMaxRegisterOp::verify() {}

// Custom verifier hook for BarrierOp. By MLIR convention, an op's `verify()`
// reports success/failure as a LogicalResult.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult NVVM::BarrierOp::verify() {}

//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//

// Dialect initialization hook for NVVMDialect. By MLIR convention, this is
// where a dialect registers its ops and attributes. TODO confirm for this file.
// TODO: This should be the llvm.nvvm dialect once this is supported.
// NOTE(review): the body is empty in this copy, so initialization currently
// does nothing. The implementation appears to have been elided; restore it.
void NVVMDialect::initialize() {}

// Dialect hook that checks the named attribute `attr` attached to operation
// `op`, presumably for attributes belonging to this dialect. TODO confirm.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute attr) {}

// Dialect hook that checks the named attribute `argAttr` on argument `argIndex`
// of region `regionIndex` of operation `op`.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute argAttr) {}

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
// Checks the parameters of an NVVMTargetAttr: `optLevel`, `triple`, `chip`,
// `features`, `flags` and `files`. Diagnostics are presumably reported through
// `emitError`. TODO confirm.
// NOTE(review): the body is empty in this copy. Because the return type is not
// void, flowing off the closing brace is undefined behavior if this is called.
// The implementation appears to have been elided and must be restored.
LogicalResult
NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                       int optLevel, StringRef triple, StringRef chip,
                       StringRef features, DictionaryAttr flags,
                       ArrayAttr files) {}

#define GET_OP_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"