// RUN: mlir-opt %s -split-input-file --pass-pipeline='builtin.module(func.func(nvgpu-optimize-shared-memory))' | FileCheck %s
// CHECK: @optimize_128x32xf16_32x128xf16([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @optimize_128x32xf16_32x128xf16(%arg0: memref<128x128xf16>,
                               %ldRow: index, %ldCol: index,
                               %stRow: index, %stCol: index,
                               %fragRow: index, %fragCol :index)
                                -> (vector<4x2xf16>, vector<4x2xf16>) {
  // Two f16 buffers in memory space 3 are written with an async copy and then
  // read back with ldmatrix. For each access the directives expect the column
  // index to be replaced by `col xor ((row and MASK) shl SHIFT)`, with
  // MASK/SHIFT = 6/2 for the 128x32 buffer and 15/3 for the 32x128 buffer.
  // CHECK: [[shm:%.+]] = memref.alloc
  // CHECK: [[shmB:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<128x32xf16, 3>
  %shmB = memref.alloc() : memref<32x128xf16, 3>

  // Store into the 128x32 buffer: the store column is permuted by the store row.
  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c6]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c2]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 8
      : memref<128x128xf16> to memref<128x32xf16, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // Load from the 128x32 buffer: same permutation, driven by the fragment row.
  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragColPerm]]]
  %mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
      : memref<128x32xf16, 3> -> vector<4x2xf16>

  // Store into the 32x128 buffer.
  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c15]]
  // CHECK: [[c3:%.+]] = arith.constant 3 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c3]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shmB]][[[stRow]], [[stColPerm]]]
  %2 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shmB[%stRow, %stCol], 8
      : memref<128x128xf16> to memref<32x128xf16, 3>
  // Group and wait on the token of the copy issued just above (%2), not on the
  // token/group of the first copy.
  %3 = nvgpu.device_async_create_group %2
  nvgpu.device_async_wait %3 { numGroups = 1 : i32}

  // Load from the 32x128 buffer.
  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]]
  // CHECK: [[c3:%.+]] = arith.constant 3 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c3]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shmB]][[[fragRow]], [[fragColPerm]]]
  %matB = nvgpu.ldmatrix %shmB[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
      : memref<32x128xf16, 3> -> vector<4x2xf16>

  return %mat, %matB: vector<4x2xf16>, vector<4x2xf16>
}
// -----
// CHECK: @optimize_64x16xf32_16x64xf32([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @optimize_64x16xf32_16x64xf32(%arg0: memref<128x128xf32>,
                               %ldRow: index, %ldCol: index,
                               %stRow: index, %stCol: index,
                               %fragRow: index, %fragCol :index)
                                -> (vector<4x1xf32>, vector<4x1xf32>, f32, vector<4xf32>, f32) {
  // f32 variant using a 64x16 and a 16x64 buffer in memory space 3. Besides
  // async copy and ldmatrix, it also covers memref.load/store and
  // vector.load/store on the same buffer. The directives expect the column
  // index of every access to become `col xor ((row and MASK) shl SHIFT)`, with
  // MASK/SHIFT = 6/1 for the 64x16 buffer and 15/2 for the 16x64 buffer.
  // CHECK: [[shm:%.+]] = memref.alloc
  // CHECK: [[shmB:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<64x16xf32, 3>
  %shmB = memref.alloc() : memref<16x64xf32, 3>

  // Async copy into the 64x16 buffer.
  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c1]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 4
      : memref<128x128xf32> to memref<64x16xf32, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // ldmatrix from the 64x16 buffer.
  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragColPerm]]]
  %mat = nvgpu.ldmatrix %shm[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
      : memref<64x16xf32, 3> -> vector<4x1xf32>

  // Scalar load from the 64x16 buffer.
  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.load [[shm]][[[fragRow]], [[fragColPerm]]]
  %elem = memref.load %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>

  // Verify vector operations.
  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: vector.load [[shm]][[[fragRow]], [[fragColPerm]]]
  %elem2 = vector.load %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>, vector<4xf32>

  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: vector.store %{{.+}}, [[shm]][[[fragRow]], [[fragColPerm]]]
  vector.store %elem2, %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>, vector<4xf32>

  // Scalar store to the 64x16 buffer.
  // CHECK: [[c6:%.+]] = arith.constant 6 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.store %{{.+}}, [[shm]][[[fragRow]], [[fragColPerm]]]
  memref.store %elem, %shm[%fragRow, %fragCol] : memref<64x16xf32, 3>

  // Verify 16x64xf32 memory size.
  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c15]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[src_bits]], [[c2]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shmB]][[[stRow]], [[stColPerm]]]
  %2 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shmB[%stRow, %stCol], 4
      : memref<128x128xf32> to memref<16x64xf32, 3>
  // Group and wait on the token of the copy issued just above (%2), not on the
  // token/group of the first copy.
  %3 = nvgpu.device_async_create_group %2
  nvgpu.device_async_wait %3 { numGroups = 1 : i32}

  // ldmatrix from the 16x64 buffer.
  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: nvgpu.ldmatrix [[shmB]][[[fragRow]], [[fragColPerm]]]
  %matB = nvgpu.ldmatrix %shmB[%fragRow, %fragCol] {numTiles = 4 : i32, transpose = false}
      : memref<16x64xf32, 3> -> vector<4x1xf32>

  // Scalar load from the 16x64 buffer.
  // CHECK: [[c15:%.+]] = arith.constant 15 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c15]]
  // CHECK: [[c2:%.+]] = arith.constant 2 : index
  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.load [[shmB]][[[fragRow]], [[fragColPerm]]]
  %elemB = memref.load %shmB[%fragRow, %fragCol] : memref<16x64xf32, 3>

  return %mat, %matB, %elem, %elem2, %elemB: vector<4x1xf32>, vector<4x1xf32>, f32, vector<4xf32>, f32
}
// -----
// Small column edge cases
// CHECK: @small_column_size_f64([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @small_column_size_f64(%arg0: memref<32x32xf64>,
                               %ldRow: index, %ldCol: index,
                               %stRow: index, %stCol: index,
                               %fragRow: index, %fragCol :index)
                                -> f64 {
  // A 32x4xf64 buffer in memory space 3. Unlike the wider buffers in this file,
  // the directives here expect a right shift:
  // `col xor ((row and 4) shrui 1)`.
  // CHECK: [[shm:%.+]] = memref.alloc
  %shm = memref.alloc() : memref<32x4xf64, 3>

  // Async copy into the buffer: the store column is permuted by the store row.
  // CHECK: [[c4:%.+]] = arith.constant 4 : index
  // CHECK: [[src_bits:%.+]] = arith.andi [[stRow]], [[c4]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shrui [[src_bits]], [[c1]]
  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol]], [[xorBits]]
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stColPerm]]]
  %0 = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %shm[%stRow, %stCol], 2
      : memref<32x32xf64> to memref<32x4xf64, 3>
  %1 = nvgpu.device_async_create_group %0
  nvgpu.device_async_wait %1 { numGroups = 1 : i32}

  // Scalar load: same permutation, driven by the fragment row. The capture is
  // named c4 because it binds the constant 4.
  // CHECK: [[c4:%.+]] = arith.constant 4 : index
  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c4]]
  // CHECK: [[c1:%.+]] = arith.constant 1 : index
  // CHECK: [[xorBits:%.+]] = arith.shrui [[srcBits]], [[c1]]
  // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol]], [[xorBits]]
  // CHECK: memref.load [[shm]][[[fragRow]], [[fragColPerm]]]
  %el = memref.load %shm[%fragRow, %fragCol] : memref<32x4xf64, 3>

  return %el: f64
}
// CHECK: @too_small_column_size_f16([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @too_small_column_size_f16(
    %arg0: memref<128x128xf16>,
    %ldRow: index, %ldCol: index,
    %stRow: index, %stCol: index,
    %fragRow: index, %fragCol: index) -> vector<1x2xf16> {
  // Negative case: for a 128x8xf16 buffer in memory space 3 the directives
  // below expect both accesses to keep their original, unpermuted indices.
  // CHECK: [[shm:%.+]] = memref.alloc
  %smem = memref.alloc() : memref<128x8xf16, 3>

  // The copy must still address the buffer with the plain store row/column.
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stCol]]]
  %token = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %smem[%stRow, %stCol], 8 : memref<128x128xf16> to memref<128x8xf16, 3>
  %group = nvgpu.device_async_create_group %token
  nvgpu.device_async_wait %group {numGroups = 1 : i32}

  // Likewise the fragment load keeps the plain fragment row/column.
  // CHECK: nvgpu.ldmatrix [[shm]][[[fragRow]], [[fragCol]]]
  %frag = nvgpu.ldmatrix %smem[%fragRow, %fragCol] {numTiles = 1 : i32, transpose = false} : memref<128x8xf16, 3> -> vector<1x2xf16>

  return %frag : vector<1x2xf16>
}
// -----
// CHECK: @abort_if_subview([[arg0:%.+]]: memref<{{.*}}>, [[ldRow:%.+]]: index, [[ldCol:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index)
func.func @abort_if_subview(
    %arg0: memref<128x128xf16>,
    %ldRow: index, %ldCol: index,
    %stRow: index, %stCol: index,
    %fragRow: index, %fragCol: index) -> vector<1x2xf16> {
  // Negative case: the memory space 3 allocation is also accessed through a
  // memref.subview. The directives below expect neither the copy into the
  // allocation nor the ldmatrix through the view to have its indices changed.
  // CHECK: [[shm:%.+]] = memref.alloc
  %smem = memref.alloc() : memref<128x32xf16, 3>

  // View of the first 64 rows of the allocation.
  // CHECK: [[shmView:%.+]] = memref.subview
  %view = memref.subview %smem[0, 0][64, 32][1, 1] : memref<128x32xf16, 3> to memref<64x32xf16, 3>

  // The copy writes to the full allocation with the plain store row/column.
  // CHECK: nvgpu.device_async_copy [[arg0]][[[ldRow]], [[ldCol]]], [[shm]][[[stRow]], [[stCol]]]
  %token = nvgpu.device_async_copy %arg0[%ldRow, %ldCol], %smem[%stRow, %stCol], 8 : memref<128x128xf16> to memref<128x32xf16, 3>
  %group = nvgpu.device_async_create_group %token
  nvgpu.device_async_wait %group {numGroups = 1 : i32}

  // The fragment load reads through the view with the plain fragment indices.
  // CHECK: nvgpu.ldmatrix [[shmView]][[[fragRow]], [[fragCol]]]
  %frag = nvgpu.ldmatrix %view[%fragRow, %fragCol] {numTiles = 1 : i32, transpose = false} : memref<64x32xf16, 3> -> vector<1x2xf16>

  return %frag : vector<1x2xf16>
}