#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include <optional>
usingnamespacemlir;
usingnamespacemlir::gpu;
// File-local (static) compile-time limits.
// NOTE(review): the initializers are elided in this view. By name these look
// like the maximum launch dimension, the maximum cluster dimension and the
// maximum subgroup size used to bound inferred ranges -- confirm against the
// real definitions.
static constexpr uint64_t kMaxDim = …;
static constexpr uint64_t kMaxClusterDim = …;
static constexpr uint64_t kMaxSubgroupSize = …;
/// Returns a ConstantIntRanges built from the unsigned bounds `umin`/`umax`.
/// NOTE(review): the body is not visible here. The name suggests the range is
/// created for an index-typed value -- confirm the bit width used, whether the
/// bounds are inclusive, and how the signed bounds are derived.
static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) { … }
// The anonymous namespace gives LaunchDims internal linkage.
namespace {
/// Scoped enumeration (underlying type uint32_t) selecting a kind of launch
/// dimension.
/// NOTE(review): enumerators are elided in this view. Presumably they
/// distinguish block sizes from grid sizes -- confirm.
enum class LaunchDims : uint32_t { … };
} // namespace
/// Returns a Value selected from `dims` according to `dim`.
/// NOTE(review): the body is elided. Presumably it maps the x/y/z Dimension to
/// the matching KernelDim3 field -- confirm.
static Value valueByDim(KernelDim3 dims, Dimension dim) { … }
/// Widens a 32-bit unsigned value to uint64_t (a zero-extension, per the name;
/// the body is not visible here).
static uint64_t zext(uint32_t arg) { … }
/// Returns the statically known launch value, if any, for dimension `dim` of
/// the kind selected by `dims` on the GPU function `func`. An empty optional
/// means no value is available.
/// NOTE(review): the body is elided. Presumably it reads a known-size attribute
/// attached to the function -- confirm.
static std::optional<uint64_t>
getKnownLaunchAttr(GPUFuncOp func, LaunchDims dims, Dimension dim) { … }
/// Overload for any FunctionOpInterface. Returns the known value for `dim`,
/// presumably read from the attribute named `attrName`, or an empty optional
/// when none is available.
/// NOTE(review): the body is elided -- confirm the expected attribute type.
static std::optional<uint64_t> getKnownLaunchAttr(FunctionOpInterface func,
StringRef attrName,
Dimension dim) { … }
/// Returns the statically known launch dimension of kind `type` for `op`, or an
/// empty optional if it is unknown.
/// NOTE(review): the body is elided. Presumably it consults the enclosing
/// function via the getKnownLaunchAttr overloads above -- confirm which Op types
/// instantiate this template.
template <typename Op>
static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) { … }
/// InferIntRangeInterface hook for ClusterDimOp. The operand ranges are ignored
/// (the parameter is unnamed); the result range is presumably reported through
/// `setResultRange`.
/// NOTE(review): the body is elided -- presumably bounded via kMaxClusterDim;
/// confirm.
void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) { … }
/// InferIntRangeInterface hook for ClusterDimBlocksOp. The operand ranges are
/// ignored (the parameter is unnamed); the result range is presumably reported
/// through `setResultRange`.
/// NOTE(review): the body is elided -- confirm the bound used.
void ClusterDimBlocksOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) { … }
/// InferIntRangeInterface hook for ClusterIdOp. The operand ranges are ignored
/// (the parameter is unnamed); the result range is presumably reported through
/// `setResultRange`.
/// NOTE(review): the body is elided -- confirm the bound used.
void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) { … }
/// InferIntRangeInterface hook for ClusterBlockIdOp. The operand ranges are
/// ignored (the parameter is unnamed); the result range is presumably reported
/// through `setResultRange`.
/// NOTE(review): the body is elided -- confirm the bound used.
void ClusterBlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) { … }
/// InferIntRangeInterface hook for BlockDimOp. The operand ranges are ignored
/// (the parameter is unnamed); the result range is presumably reported through
/// `setResultRange`.
/// NOTE(review): the body is elided -- presumably it uses getKnownLaunchDim
/// when a size is known and otherwise falls back to kMaxDim; confirm.
void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) { … }
/// InferIntRangeInterface hook for BlockIdOp. The operand ranges are ignored
/// (the parameter is unnamed); the result range is presumably reported through
/// `setResultRange`.
/// NOTE(review): the body is elided -- confirm the bound used.
void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) { … }
/// InferIntRangeInterface hook for GridDimOp. The operand ranges are ignored
/// (the parameter is unnamed); the result range is presumably reported through
/// `setResultRange`.
/// NOTE(review): the body is elided -- confirm the bound used.
void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) { … }
/// InferIntRangeInterface hook for ThreadIdOp. The operand ranges are ignored
/// (the parameter is unnamed); the result range is presumably reported through
/// `setResultRange`.
/// NOTE(review): the body is elided -- confirm the bound used.
void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) { … }
/// InferIntRangeInterface hook for LaneIdOp. The operand ranges are ignored
/// (the parameter is unnamed); the result range is presumably reported through
/// `setResultRange`.
/// NOTE(review): the body is elided -- presumably bounded via
/// kMaxSubgroupSize; confirm.
void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) { … }
/// InferIntRangeInterface hook for SubgroupIdOp. The operand ranges are
/// ignored (the parameter is unnamed); the result range is presumably reported
/// through `setResultRange`.
/// NOTE(review): the body is elided -- confirm the bound used.
void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) { … }
/// InferIntRangeInterface hook for GlobalIdOp. The operand ranges are ignored
/// (the parameter is unnamed); the result range is presumably reported through
/// `setResultRange`.
/// NOTE(review): the body is elided -- confirm how the block and grid bounds
/// are combined.
void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) { … }
/// InferIntRangeInterface hook for NumSubgroupsOp. The operand ranges are
/// ignored (the parameter is unnamed); the result range is presumably reported
/// through `setResultRange`.
/// NOTE(review): the body is elided -- confirm the bound used.
void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) { … }
/// InferIntRangeInterface hook for SubgroupSizeOp. The operand ranges are
/// ignored (the parameter is unnamed); the result range is presumably reported
/// through `setResultRange`.
/// NOTE(review): the body is elided -- presumably bounded via
/// kMaxSubgroupSize; confirm.
void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
SetIntRangeFn setResultRange) { … }
/// InferIntRangeInterface hook for LaunchOp. Unlike the hooks above, this one
/// names its operand ranges (`argRanges`), so it may use them. Inferred ranges
/// are presumably reported through `setResultRange`.
/// NOTE(review): the body is elided -- confirm which values receive ranges and
/// how `argRanges` are consumed.
void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRange) { … }