llvm/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/outerproduct-f16f16f32.mlir

// DEFINE: %{opts} =
// DEFINE: %{entry} = main
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE:   -test-lower-to-arm-sme=%{opts} -test-lower-to-llvm -o %t
// DEFINE: %{run} = %mcr_aarch64_cmd %t \
// DEFINE:   -march=aarch64 -mattr=+sve,+sme \
// DEFINE:   -e %{entry} -entry-point-result=void \
// DEFINE:   -shared-libs=%native_mlir_runner_utils,%native_mlir_c_runner_utils,%native_mlir_arm_runner_utils,%native_arm_sme_abi_shlib

// RUN: %{compile}

// RUN: %{run} | FileCheck %s

// Check result is the same when outerproducts are not combined into widening
// variant.

// REDEFINE: %{opts} = fuse-outer-products=false
// RUN: %{run} | FileCheck %s

func.func @main() {
  %c128 = arith.constant 128 : i32
  func.call @setArmSVLBits(%c128) : (i32) -> ()

  func.call @test_outerproduct_f16f16f32() : () -> ()

  // TODO: A bug in QEMU causes masked FMOPAs to hang [1]. Should be fixed in
  // 8.2.0, this test currently isn't run, once this version is available in CI
  // it can be run. The output without check lines in the function are correct
  // and have been verified on a version with the fix.
  // [1] https://gitlab.com/qemu-project/qemu/-/issues/1985
  //func.call @test_masked_outerproduct_f16f16f32() : () -> ()

  return
}

func.func @test_outerproduct_f16f16f32() {
  %undef = llvm.mlir.undef : vector<[4]xf16>

  %a0_data = arith.constant dense<[0., 2., 4., 6.]> : vector<4xf16>
  %b0_data = arith.constant dense<[1., 3., 5., 7.]> : vector<4xf16>
  %a1_data = arith.constant dense<[8., 10., 12., 14.]> : vector<4xf16>
  %b1_data = arith.constant dense<[9., 11., 13., 15.]> : vector<4xf16>

  %a0 = vector.scalable.insert %a0_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
  %b0 = vector.scalable.insert %b0_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
  %a1 = vector.scalable.insert %a1_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
  %b1 = vector.scalable.insert %b1_data, %undef[0] : vector<4xf16> into vector<[4]xf16>

  %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
  %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
  %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
  %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>

  %acc = arith.constant dense<7.0> : vector<[4]x[4]xf32>
  %0 = vector.outerproduct %a0_ext, %b0_ext, %acc : vector<[4]xf32>, vector<[4]xf32>
  %1 = vector.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>, vector<[4]xf32>

  // CHECK:      (  79,  95, 111, 127 )
  // CHECK-NEXT: (  99, 123, 147, 171 )
  // CHECK-NEXT: ( 119, 151, 183, 215 )
  // CHECK-NEXT: ( 139, 179, 219, 259 )
  vector.print %1 : vector<[4]x[4]xf32>

  return
}

func.func @test_masked_outerproduct_f16f16f32() {
  %undef = llvm.mlir.undef : vector<[4]xf16>

  %a0_data = arith.constant dense<[0., 2., 4., 6.]> : vector<4xf16>
  %b0_data = arith.constant dense<[1., 3., 5., 7.]> : vector<4xf16>
  %a1_data = arith.constant dense<[8., 10., 12., 14.]> : vector<4xf16>
  %b1_data = arith.constant dense<[9., 11., 13., 15.]> : vector<4xf16>

  %a0 = vector.scalable.insert %a0_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
  %b0 = vector.scalable.insert %b0_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
  %a1 = vector.scalable.insert %a1_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
  %b1 = vector.scalable.insert %b1_data, %undef[0] : vector<4xf16> into vector<[4]xf16>

  %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
  %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
  %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
  %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>

  %acc = arith.constant dense<7.0> : vector<[4]x[4]xf32>

  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %mask0 = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
  %mask1 = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>

  %0 = vector.mask %mask0 {
    vector.outerproduct %a0_ext, %b0_ext, %acc : vector<[4]xf32>, vector<[4]xf32>
  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>

  %1 = vector.mask %mask1 {
    vector.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>, vector<[4]xf32>
  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>

  // TODO: CHECK these lines once QEMU is fixed.
  // (  79,  95,  7, 7 )
  // (  99, 123, 17, 7 )
  // ( 115, 139,  7, 7 )
  // (   7,   7,  7, 7 )
  vector.print %1 : vector<[4]x[4]xf32>

  return
}

func.func private @setArmSVLBits(%bits : i32)