llvm/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir

// RUN: mlir-opt %s -split-input-file --pass-pipeline='builtin.module(func.func(nvgpu-optimize-shared-memory))' | FileCheck %s

// Two address-space-3 buffers of different shapes. Each access to %shm
// (128x32xf16) is expected to have its column index XOR-ed with
// ((row & 6) << 2). Each access to %shmB (32x128xf16) is expected to use
// ((row & 15) << 3) instead. Indices into the source %arg0 are expected to
// pass through unchanged.
// CHECK: @optimize_128x32xf16_32x128xf16([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @optimize_128x32xf16_32x128xf16(%arg0: memref<128x128xf16>,
                               %ldRow: index, %ldCol: index,
                               %stRow: index, %stCol: index,
                               %fragRow: index, %fragCol :index)
                                -> (vector<4x2xf16>, vector<4x2xf16>) {
  // CHECK: [[shm:%.+]] = memref.alloc
  // CHECK: [[shmB:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<128x32xf16, 3>
  %shmB = memref.alloc() : memref<32x128xf16, 3>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c6]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c2]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8
      : memref<128x128xf16> to memref<128x32xf16, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragColPerm]]]
  %mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
      : memref<128x32xf16, 3> -> vector<4x2xf16>

  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c15]]
  // CHECK: [[c3:%.+]] = arith.constant 3 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c3]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shmB]][[[stRow]], [[stColPerm]]]
  %2 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shmB[%stRow, %stCol], 8
      : memref<128x128xf16> to memref<32x128xf16, 3>
  // Group and wait on the copy into %shmB (%2), not the earlier copy (%0).
  %3 = nvgpu.device_async_create_group %2
  nvgpu.device_async_wait %3 { numGroups = 1 : i32}

  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]]
  // CHECK: [[c3:%.+]] = arith.constant 3 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c3]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shmB]][[[fragRow]], [[fragColPerm]]]
  %matB = nvgpu.ldmatrix %shmB[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
      : memref<32x128xf16, 3> -> vector<4x2xf16>

  return %mat, %matB: vector<4x2xf16>, vector<4x2xf16>
}


// -----

// f32 variant with two address-space-3 buffers. Each access to %shm
// (64x16xf32) is expected to have its column index XOR-ed with
// ((row & 6) << 1). Each access to %shmB (16x64xf32) is expected to use
// ((row & 15) << 2) instead. The covered ops are nvgpu.device_async_copy,
// nvgpu.ldmatrix, memref.load/store and vector.load/store.
// CHECK: @optimize_64x16xf32_16x64xf32([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @optimize_64x16xf32_16x64xf32(%arg0: memref<128x128xf32>,
                               %ldRow: index, %ldCol: index,
                               %stRow: index, %stCol: index,
                               %fragRow: index, %fragCol :index)
                                -> (vector<4x1xf32>, vector<4x1xf32>, f32, vector<4xf32>, f32) {
  // CHECK: [[shm:%.+]] = memref.alloc
  // CHECK: [[shmB:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<64x16xf32, 3>
  %shmB = memref.alloc() : memref<16x64xf32, 3>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c1]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 4
      : memref<128x128xf32> to memref<64x16xf32, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragColPerm]]]
  %mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
      : memref<64x16xf32, 3> -> vector<4x1xf32>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.load [[shm]][[[fragRow]], [[fragColPerm]]]
  %elem = memref.load %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>

  // vector.load / vector.store on %shm get the same permutation.

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: vector.load [[shm]][[[fragRow]], [[fragColPerm]]]
  %elem2 = vector.load %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>, vector<4xf32>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: vector.store %{{.+}}, [[shm]][[[fragRow]], [[fragColPerm]]]
  vector.store %elem2, %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>, vector<4xf32>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.store %{{.+}}, [[shm]][[[fragRow]], [[fragColPerm]]]
  memref.store %elem, %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>

  // 16x64xf32 buffer: expected row mask 15 and shift amount 2.

  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c15]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c2]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shmB]][[[stRow]], [[stColPerm]]]
  %2 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shmB[%stRow, %stCol], 4
      : memref<128x128xf32> to memref<16x64xf32, 3>
  // Group and wait on the copy into %shmB (%2), not the earlier copy (%0).
  %3 = nvgpu.device_async_create_group %2
  nvgpu.device_async_wait %3 { numGroups = 1 : i32}

  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shmB]][[[fragRow]], [[fragColPerm]]]
  %matB = nvgpu.ldmatrix %shmB[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
      : memref<16x64xf32, 3> -> vector<4x1xf32>

  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.load [[shmB]][[[fragRow]], [[fragColPerm]]]
  %elemB = memref.load %shmB[%fragRow, %fragCol] : memref<16x64xf32, 3>

  return %mat, %matB, %elem, %elem2, %elemB: vector<4x1xf32>, vector<4x1xf32>, f32, vector<4xf32>, f32
}


// -----

// Edge cases: address-space-3 buffers with only a few columns.

// A 32x4xf64 buffer is still expected to be permuted. Unlike the wider cases,
// the XOR operand is ((row & 4) >> 1), built with a right shift (shrui)
// instead of a left shift.
// CHECK: @small_column_size_f64([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @small_column_size_f64(%arg0: memref<32x32xf64>,
                               %ldRow: index, %ldCol: index,
                               %stRow: index, %stCol: index,
                               %fragRow: index, %fragCol :index)
                                -> f64 {
  // CHECK: [[shm:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<32x4xf64, 3>

  // CHECK: [[c4:%.+]] = arith.constant 4 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c4]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shrui [[src_bits]], [[c1]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 2
      : memref<32x32xf64> to memref<32x4xf64, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // CHECK: [[c4:%.+]] = arith.constant 4 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c4]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shrui [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.load [[shm]][[[fragRow]], [[fragColPerm]]]
  %el = memref.load %shm[%fragRow, %fragCol] : memref<32x4xf64, 3>

  return %el: f64
}

// A 128x8xf16 buffer: the directives below expect both the async copy and the
// ldmatrix to keep their original indices ([[stCol]], [[fragCol]]), i.e. no
// permutation is applied for this narrow buffer.
// CHECK: @too_small_column_size_f16([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @too_small_column_size_f16(%arg0: memref<128x128xf16>,
                               %ldRow: index, %ldCol: index,
                               %stRow: index, %stCol: index,
                               %fragRow: index, %fragCol :index)
                                -> vector<1x2xf16> {
  // CHECK: [[shm:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<128x8xf16, 3>

  // Store-side column index is expected to pass through unchanged.
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stCol]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8
      : memref<128x128xf16> to memref<128x8xf16, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // Load-side column index is expected to pass through unchanged.
  // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragCol]]]
  %mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 1 : i32, transpose = false}
      : memref<128x8xf16, 3> -> vector<1x2xf16>

  return %mat: vector<1x2xf16>
}

// -----

// The shared allocation %shm also has a memref.subview user. The directives
// below expect every index to stay unpermuted: both the direct copy into %shm
// and the ldmatrix through the view. Judging by the function name, the pass is
// meant to bail out for such allocations; confirm against the pass.
// CHECK: @abort_if_subview([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @abort_if_subview(%arg0: memref<128x128xf16>,
                               %ldRow: index, %ldCol: index,
                               %stRow: index, %stCol: index,
                               %fragRow: index, %fragCol :index)
                                -> vector<1x2xf16> {
  // CHECK: [[shm:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<128x32xf16, 3>
  // CHECK: [[shmView:%.+]] = memref.subview
  %shmView = memref.subview %shm[0, 0][64, 32][1, 1] : memref<128x32xf16, 3> to memref<64x32xf16, 3>

  // The copy targets the base allocation directly; its indices stay as-is.
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stCol]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8
      : memref<128x128xf16> to memref<128x32xf16, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // The load goes through the subview; its indices also stay as-is.
  // CHECK: nvgpu.ldmatrix [[shmView]][[[fragRow]], [[fragCol]]]
  %mat = nvgpu.ldmatrix %shmView[%fragRow, %fragCol] {numTiles = 1 : i32, transpose = false}
      : memref<64x32xf16, 3> -> vector<1x2xf16>

  return %mat: vector<1x2xf16>
}