// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine \
// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
// RUN: -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" \
// RUN: -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-translate -mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" \
// RUN: --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s
// Note: this test only runs on a CPU that supports AMX (the lli invocation above enables amx-tile, amx-int8 and amx-bf16).
// Multiply full-size tiles into a zero-initialized destination tile.
// Loads one 16x32 bf16 tile from each of %arg0 and %arg1 (both read starting
// at index [0, 0]), multiplies them with amx.tile_mulf into an f32 tile that
// starts out all zero, and stores the 16x16 f32 result into %arg2 at [0, 0].
func.func @kernel(%arg0: memref<16x32xbf16>,
                  %arg1: memref<16x32xbf16>,
                  %arg2: memref<16x16xf32>) {
  %origin = arith.constant 0 : index
  // Full-tile loads: each operand memref has exactly the 16x32 tile shape.
  %lhs = amx.tile_load %arg0[%origin, %origin] : memref<16x32xbf16> into vector<16x32xbf16>
  %rhs = amx.tile_load %arg1[%origin, %origin] : memref<16x32xbf16> into vector<16x32xbf16>
  // Zeroed accumulator, so the stored tile holds only this product.
  %acc = amx.tile_zero : vector<16x16xf32>
  %prod = amx.tile_mulf %lhs, %rhs, %acc : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
  amx.tile_store %arg2[%origin, %origin], %prod : memref<16x16xf32>, vector<16x16xf32>
  return
}
// Test driver: multiplies two constant 16x32 bf16 matrices with @kernel,
// prints the 16x16 f32 result one row per line (matched by the FileCheck
// directives below), and returns 0.
func.func @entry() -> i32 {
  // Padding value for vector.transfer_read. Every read below stays in bounds,
  // so this value never appears in the printed output.
  %fu = arith.constant -1.0: f32
  %c0 = arith.constant 0: index
  %c1 = arith.constant 1: index
  %c16 = arith.constant 16: index
  // Set up simple test data.
  // Row i of %0 is all 1.0 except for a bumped value at column i.
  %0 = arith.constant dense<[
    [ 1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.3, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ]
  ]> : tensor<16x32xbf16>
  // Each row of %1 is all 1.0 except for one bumped value; across the rows
  // that value sits somewhere in columns 2-10.
  %1 = arith.constant dense<[
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
    [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ]
  ]> : tensor<16x32xbf16>
  // Set up memory. Only %c is allocated here (memref.alloc), so only %c is
  // deallocated at the end; %a and %b come from the constant tensors.
  %a = bufferization.to_memref %0 : memref<16x32xbf16>
  %b = bufferization.to_memref %1 : memref<16x32xbf16>
  %c = memref.alloc() : memref<16x16xf32>
  // Call kernel.
  call @kernel(%a, %b, %c) : (memref<16x32xbf16>, memref<16x32xbf16>, memref<16x16xf32>) -> ()
  //
  // Print and verify the 16x16 result.
  // The inputs are bf16, so decimals such as 1.1 are stored rounded
  // (1.1 becomes 1.1015625); the expected sums below reflect that rounding.
  //
  // CHECK: ( 32.1016, 34.3984, 34.5078, 33.6953, 32.9062, 32.2031, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016 )
  // CHECK-NEXT: ( 32.2031, 34.5, 34.6094, 33.7969, 33.0284, 32.3047, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031 )
  // CHECK-NEXT: ( 32.2969, 34.5938, 34.7031, 33.8906, 33.1619, 32.3984, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969 )
  // CHECK-NEXT: ( 32.3984, 34.6953, 34.8047, 33.9922, 33.2031, 32.5, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984 )
  // CHECK-NEXT: ( 32.5, 34.7969, 34.9062, 34.0938, 33.3047, 32.6016, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5 )
  // CHECK-NEXT: ( 32.6016, 34.8984, 35.0078, 34.3739, 33.4062, 32.7031, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016 )
  // CHECK-NEXT: ( 32.7031, 35, 35.1094, 34.577, 33.5078, 32.8047, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031 )
  // CHECK-NEXT: ( 32.7969, 35.0938, 35.2031, 34.3906, 33.6016, 32.8984, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969 )
  // CHECK-NEXT: ( 32.7969, 35.0938, 35.2031, 34.3906, 33.6016, 32.8984, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969 )
  // CHECK-NEXT: ( 32.7031, 35, 35.4609, 34.2969, 33.5078, 32.8047, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031 )
  // CHECK-NEXT: ( 32.6016, 34.8984, 35.3697, 34.1953, 33.4062, 32.7031, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016 )
  // CHECK-NEXT: ( 32.5, 34.7969, 34.9062, 34.0938, 33.3047, 32.6016, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5 )
  // CHECK-NEXT: ( 32.3984, 34.6953, 34.8047, 33.9922, 33.2031, 32.5, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984 )
  // CHECK-NEXT: ( 32.2969, 34.8025, 34.7031, 33.8906, 33.1016, 32.3984, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969 )
  // CHECK-NEXT: ( 32.2031, 34.6619, 34.6094, 33.7969, 33.0078, 32.3047, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031 )
  // CHECK-NEXT: ( 32.1016, 34.3984, 34.5078, 33.6953, 32.9062, 32.2031, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016 )
  //
  // Print one 16-element row of the result per iteration.
  scf.for %i = %c0 to %c16 step %c1 {
    %v = vector.transfer_read %c[%i, %c0], %fu: memref<16x16xf32>, vector<16xf32>
    vector.print %v : vector<16xf32>
  }
  // Release resources.
  memref.dealloc %c : memref<16x16xf32>
  %i0 = arith.constant 0 : i32
  return %i0 : i32
}