// llvm/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp

//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to optimize accesses to shared memory.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

namespace mlir {
namespace nvgpu {
#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
#include "mlir/Dialect/NVGPU/Transforms/Passes.h.inc"
} // namespace nvgpu
} // namespace mlir

usingnamespacemlir;
usingnamespacemlir::nvgpu;

/// The size of a shared memory line, in bytes, according to NV documentation.
// NOTE(review): the initializer was missing here. 128 is, as far as I recall,
// the value upstream MLIR uses for this constant -- confirm against the NVIDIA
// documentation.
constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// We optimize for 128bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 128;

/// Uses the source index value to permute the target index value via
///   result = xor(floordiv(srcIdxVal, permuteEveryN),
///                floordiv(tgtIdxVal, vectorSize))
///            + tgtIdxVal % vectorSize
/// The stated intent is to do this using an optimized sequence of `arith`
/// operations.
///
/// \param b        Builder used to create any new operations.
/// \param loc      Location to attach to created operations.
/// \param indices  Index values of the access being rewritten.
/// \param memrefTy Type of the memref being indexed.
/// \param srcDim   Position in `indices` of the source index.
/// \param tgtDim   Position in `indices` of the index to permute.
/// \returns The permuted index value.
///
/// NOTE(review): presumably `srcIdxVal` is `indices[srcDim]` and `tgtIdxVal`
/// is `indices[tgtDim]`; the body is not visible here -- confirm.
/// NOTE(review): the body is empty in this view. As written, control flows off
/// the end of a function returning `Value`, which is undefined behavior; the
/// real implementation must return a value.
static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                 ArrayRef<Value> indices, MemRefType memrefTy,
                                 int64_t srcDim, int64_t tgtDim) {}

/// Helper operating on the index list of a memref access.
///
/// `indices` is taken by non-const reference, so the callee is able to modify
/// it in place; `srcDim` and `tgtDim` have the same names as the parameters of
/// `permuteVectorOffset`.
///
/// NOTE(review): presumably this replaces `indices[tgtDim]` with the result of
/// `permuteVectorOffset` -- confirm. The body is empty in this view, so as
/// written this function is a no-op and leaves `indices` unchanged.
static void transformIndices(OpBuilder &builder, Location loc,
                             SmallVector<Value, 4> &indices,
                             MemRefType memrefTy, int64_t srcDim,
                             int64_t tgtDim) {}

/// Return all operations within `parentOp` that read from or write to
/// `shmMemRef`.
///
/// \param parentOp  Operation whose contents are searched.
/// \param shmMemRef Memref value whose users are of interest.
/// \param readOps   Output list (passed by reference) for reading operations.
/// \param writeOps  Output list (passed by reference) for writing operations.
/// \returns A `LogicalResult`.
///
/// NOTE(review): the conditions under which failure is returned are not
/// visible here. The body is empty in this view; as written, control flows off
/// the end of a value-returning function, which is undefined behavior.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                      SmallVector<Operation *, 16> &readOps,
                      SmallVector<Operation *, 16> &writeOps) {}

/// Entry point of the transform, defined in the `mlir::nvgpu` namespace.
///
/// \param parentOp    Operation handed to the transform.
/// \param memrefValue Memref value handed to the transform.
/// \returns A `LogicalResult`.
///
/// NOTE(review): presumably this rewrites the indices of accesses to
/// `memrefValue` found within `parentOp`, using the static helpers in this
/// file -- confirm. The failure conditions are not visible here.
/// NOTE(review): the body is empty in this view; as written, control flows off
/// the end of a value-returning function, which is undefined behavior.
llvm::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                Value memrefValue) {}

namespace {
/// Pass class for the shared-memory optimization. It derives from
/// `nvgpu::impl::OptimizeSharedMemoryBase`, instantiated on this class.
///
/// NOTE(review): the base is presumably the one pulled in by
/// `GEN_PASS_DEF_OPTIMIZESHAREDMEMORY` and the `Passes.h.inc` include near the
/// top of this file -- confirm.
/// NOTE(review): the class body is empty in this view, so no members (for
/// example a `runOnOperation` override) are visible here.
class OptimizeSharedMemoryPass
    : public nvgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {};
} // namespace

/// Factory function in the `mlir::nvgpu` namespace returning an owning pointer
/// to a `Pass`.
///
/// NOTE(review): presumably this returns a new `OptimizeSharedMemoryPass` --
/// confirm. The body is empty in this view; as written, control flows off the
/// end of a value-returning function, which is undefined behavior.
std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {}