llvm/mlir/test/Integration/Dialect/Vector/CPU/AMX/mulf-full.mlir

// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine \
// RUN: -one-shot-bufferize="bufferize-function-boundaries" \
// RUN: -convert-scf-to-cf -convert-vector-to-llvm="enable-amx" \
// RUN:  -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-translate -mlir-to-llvmir | \
// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" \
// RUN:  --dlopen=%mlir_c_runner_utils | \
// RUN: FileCheck %s

// Note: To run this test, your CPU must support AMX.

// Multiply full-size tiles into a zero-initialized destination tile.
func.func @kernel(%lhs: memref<16x32xbf16>,
                  %rhs: memref<16x32xbf16>,
                  %out: memref<16x16xf32>) {
  // Both operand tiles are read starting at element (0, 0).
  %zero = arith.constant 0 : index
  %ta = amx.tile_load %lhs[%zero, %zero] : memref<16x32xbf16> into vector<16x32xbf16>
  %tb = amx.tile_load %rhs[%zero, %zero] : memref<16x32xbf16> into vector<16x32xbf16>
  // Start from an all-zero f32 accumulator so the stored tile is exactly the
  // product: out[m][n] = sum_k lhs[m][2k]*rhs[k][2n] + lhs[m][2k+1]*rhs[k][2n+1].
  %acc = amx.tile_zero : vector<16x16xf32>
  %prod = amx.tile_mulf %ta, %tb, %acc : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
  // Write the full 16x16 result back to the output buffer.
  amx.tile_store %out[%zero, %zero], %prod : memref<16x16xf32>, vector<16x16xf32>
  return
}

// Entry point run by lli (see the RUN lines). It builds two constant bf16
// input tiles, runs @kernel on them, and prints the 16x16 f32 result one row
// per line so FileCheck can verify it. It always returns 0.
func.func @entry() -> i32 {
  // Padding value for vector.transfer_read. It is never materialized here,
  // because every 16-element row read stays inside the 16x16 result buffer.
  %fu = arith.constant -1.0: f32
  %c0 = arith.constant 0: index
  %c1 = arith.constant 1: index
  %c16 = arith.constant 16: index

  // Set up simple test data. Every entry is 1.0 except one perturbed value
  // per row. Literals are rounded to bf16 (e.g. 1.1 -> 1.1015625), which is
  // why the expected sums below are not round decimals.
  %0 = arith.constant dense<[
     [ 1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.3, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ]
  ]> : tensor<16x32xbf16>
  %1 = arith.constant dense<[
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.7, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.6, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.4, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ],
     [ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1, 1.0, 1.0, 1.0, 1.0, 1.0,
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 ]
  ]> : tensor<16x32xbf16>

  // Set up memory. %c is left uninitialized; @kernel stores the full 16x16
  // tile into it.
  %a = bufferization.to_memref %0 : memref<16x32xbf16>
  %b = bufferization.to_memref %1 : memref<16x32xbf16>
  %c = memref.alloc() : memref<16x16xf32>

  // Call kernel.
  call @kernel(%a, %b, %c) : (memref<16x32xbf16>, memref<16x32xbf16>, memref<16x16xf32>) -> ()

  //
  // Print and verify the 16x16 result.
  //
  // CHECK:      ( 32.1016, 34.3984, 34.5078, 33.6953, 32.9062, 32.2031, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016 )
  // CHECK-NEXT: ( 32.2031, 34.5, 34.6094, 33.7969, 33.0284, 32.3047, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031 )
  // CHECK-NEXT: ( 32.2969, 34.5938, 34.7031, 33.8906, 33.1619, 32.3984, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969 )
  // CHECK-NEXT: ( 32.3984, 34.6953, 34.8047, 33.9922, 33.2031, 32.5, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984 )
  // CHECK-NEXT: ( 32.5, 34.7969, 34.9062, 34.0938, 33.3047, 32.6016, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5 )
  // CHECK-NEXT: ( 32.6016, 34.8984, 35.0078, 34.3739, 33.4062, 32.7031, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016 )
  // CHECK-NEXT: ( 32.7031, 35, 35.1094, 34.577, 33.5078, 32.8047, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031 )
  // CHECK-NEXT: ( 32.7969, 35.0938, 35.2031, 34.3906, 33.6016, 32.8984, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969 )
  // CHECK-NEXT: ( 32.7969, 35.0938, 35.2031, 34.3906, 33.6016, 32.8984, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969, 32.7969 )
  // CHECK-NEXT: ( 32.7031, 35, 35.4609, 34.2969, 33.5078, 32.8047, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031, 32.7031 )
  // CHECK-NEXT: ( 32.6016, 34.8984, 35.3697, 34.1953, 33.4062, 32.7031, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016, 32.6016 )
  // CHECK-NEXT: ( 32.5, 34.7969, 34.9062, 34.0938, 33.3047, 32.6016, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5, 32.5 )
  // CHECK-NEXT: ( 32.3984, 34.6953, 34.8047, 33.9922, 33.2031, 32.5, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984, 32.3984 )
  // CHECK-NEXT: ( 32.2969, 34.8025, 34.7031, 33.8906, 33.1016, 32.3984, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969, 32.2969 )
  // CHECK-NEXT: ( 32.2031, 34.6619, 34.6094, 33.7969, 33.0078, 32.3047, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031, 32.2031 )
  // CHECK-NEXT: ( 32.1016, 34.3984, 34.5078, 33.6953, 32.9062, 32.2031, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016, 32.1016 )
  //
  scf.for %i = %c0 to %c16 step %c1 {
    %v = vector.transfer_read %c[%i, %c0], %fu: memref<16x16xf32>, vector<16xf32>
    vector.print %v : vector<16xf32>
  }

  // Release resources.
  memref.dealloc %c : memref<16x16xf32>
  %i0 = arith.constant 0 : i32
  return %i0 : i32
}