llvm/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/outerproduct-i8i8i32.mlir

// DEFINE: %{entry} = main
// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE:   -march=aarch64 -mattr=+sve,+sme \
// DEFINE:   -e %{entry} -entry-point-result=void \
// DEFINE:   -shared-libs=%native_mlir_runner_utils,%native_mlir_c_runner_utils,%native_mlir_arm_runner_utils,%native_arm_sme_abi_shlib

// RUN: %{compile} | %{run} | FileCheck %s

// NOTE: QEMU <= 8.2.0 gives incorrect results for the SME SMOPA 4-way outer
// product instruction; see https://gitlab.com/qemu-project/qemu/-/issues/2083.

// NOTE: There is no non-widening variant for these types, and this test can't
// be lowered without the widening pass. Therefore, we can't check that the
// result is the same without the widening pass, as
// 'test-outerproduct-f16f16f32.mlir' does.

// Test entry point: passes 128 to @setArmSVLBits, then runs the unmasked test
// followed by the masked test. The call order determines the order of the
// printed output that FileCheck matches against.
func.func @main() {
  %svl_bits = arith.constant 128 : i32
  func.call @setArmSVLBits(%svl_bits) : (i32) -> ()

  func.call @test_outerproduct_i8i8i32() : () -> ()
  func.call @test_masked_outerproduct_i8i8i32() : () -> ()

  return
}

// Sums four outer products of i8 vectors that have been sign-extended to i32,
// then prints the accumulated vector<[4]x[4]xi32> result.
func.func @test_outerproduct_i8i8i32() {
  %init = llvm.mlir.undef : vector<[4]xi8>

  // Left-hand inputs: vector k holds { k, k + 4, k + 8, k + 12 }.
  %lhs0_fixed = arith.constant dense<[0, 4, 8, 12]> : vector<4xi8>
  %lhs1_fixed = arith.constant dense<[1, 5, 9, 13]> : vector<4xi8>
  %lhs2_fixed = arith.constant dense<[2, 6, 10, 14]> : vector<4xi8>
  %lhs3_fixed = arith.constant dense<[3, 7, 11, 15]> : vector<4xi8>

  // Right-hand inputs: vector k holds { k + 16, k + 20, k + 24, k + 28 }.
  %rhs0_fixed = arith.constant dense<[16, 20, 24, 28]> : vector<4xi8>
  %rhs1_fixed = arith.constant dense<[17, 21, 25, 29]> : vector<4xi8>
  %rhs2_fixed = arith.constant dense<[18, 22, 26, 30]> : vector<4xi8>
  %rhs3_fixed = arith.constant dense<[19, 23, 27, 31]> : vector<4xi8>

  // Insert each fixed-size input at offset 0 of a scalable vector.
  %lhs0 = vector.scalable.insert %lhs0_fixed, %init[0] : vector<4xi8> into vector<[4]xi8>
  %lhs1 = vector.scalable.insert %lhs1_fixed, %init[0] : vector<4xi8> into vector<[4]xi8>
  %lhs2 = vector.scalable.insert %lhs2_fixed, %init[0] : vector<4xi8> into vector<[4]xi8>
  %lhs3 = vector.scalable.insert %lhs3_fixed, %init[0] : vector<4xi8> into vector<[4]xi8>
  %rhs0 = vector.scalable.insert %rhs0_fixed, %init[0] : vector<4xi8> into vector<[4]xi8>
  %rhs1 = vector.scalable.insert %rhs1_fixed, %init[0] : vector<4xi8> into vector<[4]xi8>
  %rhs2 = vector.scalable.insert %rhs2_fixed, %init[0] : vector<4xi8> into vector<[4]xi8>
  %rhs3 = vector.scalable.insert %rhs3_fixed, %init[0] : vector<4xi8> into vector<[4]xi8>

  // Sign-extend every operand from i8 to i32.
  %lhs0_i32 = arith.extsi %lhs0 : vector<[4]xi8> to vector<[4]xi32>
  %lhs1_i32 = arith.extsi %lhs1 : vector<[4]xi8> to vector<[4]xi32>
  %lhs2_i32 = arith.extsi %lhs2 : vector<[4]xi8> to vector<[4]xi32>
  %lhs3_i32 = arith.extsi %lhs3 : vector<[4]xi8> to vector<[4]xi32>
  %rhs0_i32 = arith.extsi %rhs0 : vector<[4]xi8> to vector<[4]xi32>
  %rhs1_i32 = arith.extsi %rhs1 : vector<[4]xi8> to vector<[4]xi32>
  %rhs2_i32 = arith.extsi %rhs2 : vector<[4]xi8> to vector<[4]xi32>
  %rhs3_i32 = arith.extsi %rhs3 : vector<[4]xi8> to vector<[4]xi32>

  // Chain of four outer products: the first has no accumulator, each later one
  // accumulates onto the previous result.
  %sum0 = vector.outerproduct %lhs0_i32, %rhs0_i32 : vector<[4]xi32>, vector<[4]xi32>
  %sum1 = vector.outerproduct %lhs1_i32, %rhs1_i32, %sum0 : vector<[4]xi32>, vector<[4]xi32>
  %sum2 = vector.outerproduct %lhs2_i32, %rhs2_i32, %sum1 : vector<[4]xi32>, vector<[4]xi32>
  %sum3 = vector.outerproduct %lhs3_i32, %rhs3_i32, %sum2 : vector<[4]xi32>, vector<[4]xi32>

  // Expected: element (i, j) is the sum over k of lhs_k[i] * rhs_k[j].
  // CHECK:      ( 110,  134,  158,  182 )
  // CHECK-NEXT: ( 390,  478,  566,  654 )
  // CHECK-NEXT: ( 670,  822,  974, 1126 )
  // CHECK-NEXT: ( 950, 1166, 1382, 1598 )
  vector.print %sum3 : vector<[4]x[4]xi32>

  return
}

// Masked variant: sums four outer products of sign-extended i8 vectors onto an
// accumulator initialised to 2, each under its own 2-D mask, then prints the
// vector<[4]x[4]xi32> result.
func.func @test_masked_outerproduct_i8i8i32() {
  %init = llvm.mlir.undef : vector<[4]xi8>

  // Left-hand inputs: vector k holds { k, k + 4, k + 8, k + 12 }.
  %lhs0_fixed = arith.constant dense<[0, 4, 8, 12]> : vector<4xi8>
  %lhs1_fixed = arith.constant dense<[1, 5, 9, 13]> : vector<4xi8>
  %lhs2_fixed = arith.constant dense<[2, 6, 10, 14]> : vector<4xi8>
  %lhs3_fixed = arith.constant dense<[3, 7, 11, 15]> : vector<4xi8>

  // Right-hand inputs: vector k holds { k + 16, k + 20, k + 24, k + 28 }.
  %rhs0_fixed = arith.constant dense<[16, 20, 24, 28]> : vector<4xi8>
  %rhs1_fixed = arith.constant dense<[17, 21, 25, 29]> : vector<4xi8>
  %rhs2_fixed = arith.constant dense<[18, 22, 26, 30]> : vector<4xi8>
  %rhs3_fixed = arith.constant dense<[19, 23, 27, 31]> : vector<4xi8>

  // Insert each fixed-size input at offset 0 of a scalable vector.
  %lhs0 = vector.scalable.insert %lhs0_fixed, %init[0] : vector<4xi8> into vector<[4]xi8>
  %lhs1 = vector.scalable.insert %lhs1_fixed, %init[0] : vector<4xi8> into vector<[4]xi8>
  %lhs2 = vector.scalable.insert %lhs2_fixed, %init[0] : vector<4xi8> into vector<[4]xi8>
  %lhs3 = vector.scalable.insert %lhs3_fixed, %init[0] : vector<4xi8> into vector<[4]xi8>
  %rhs0 = vector.scalable.insert %rhs0_fixed, %init[0] : vector<4xi8> into vector<[4]xi8>
  %rhs1 = vector.scalable.insert %rhs1_fixed, %init[0] : vector<4xi8> into vector<[4]xi8>
  %rhs2 = vector.scalable.insert %rhs2_fixed, %init[0] : vector<4xi8> into vector<[4]xi8>
  %rhs3 = vector.scalable.insert %rhs3_fixed, %init[0] : vector<4xi8> into vector<[4]xi8>

  // Sign-extend every operand from i8 to i32.
  %lhs0_i32 = arith.extsi %lhs0 : vector<[4]xi8> to vector<[4]xi32>
  %lhs1_i32 = arith.extsi %lhs1 : vector<[4]xi8> to vector<[4]xi32>
  %lhs2_i32 = arith.extsi %lhs2 : vector<[4]xi8> to vector<[4]xi32>
  %lhs3_i32 = arith.extsi %lhs3 : vector<[4]xi8> to vector<[4]xi32>
  %rhs0_i32 = arith.extsi %rhs0 : vector<[4]xi8> to vector<[4]xi32>
  %rhs1_i32 = arith.extsi %rhs1 : vector<[4]xi8> to vector<[4]xi32>
  %rhs2_i32 = arith.extsi %rhs2 : vector<[4]xi8> to vector<[4]xi32>
  %rhs3_i32 = arith.extsi %rhs3 : vector<[4]xi8> to vector<[4]xi32>

  // Mask bounds (rows, cols) per step: 1x1, 1x2, 2x3 and 3x4.
  %dim1 = arith.constant 1 : index
  %dim2 = arith.constant 2 : index
  %dim3 = arith.constant 3 : index
  %dim4 = arith.constant 4 : index

  %step0_mask = vector.create_mask %dim1, %dim1 : vector<[4]x[4]xi1>
  %step1_mask = vector.create_mask %dim1, %dim2 : vector<[4]x[4]xi1>
  %step2_mask = vector.create_mask %dim2, %dim3 : vector<[4]x[4]xi1>
  %step3_mask = vector.create_mask %dim3, %dim4 : vector<[4]x[4]xi1>

  // Every element of the initial accumulator is 2.
  %init_acc = arith.constant dense<2> : vector<[4]x[4]xi32>

  // Chain of four masked outer products, each accumulating onto the previous
  // result.
  %sum0 = vector.mask %step0_mask {
    vector.outerproduct %lhs0_i32, %rhs0_i32, %init_acc : vector<[4]xi32>, vector<[4]xi32>
  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>
  %sum1 = vector.mask %step1_mask {
    vector.outerproduct %lhs1_i32, %rhs1_i32, %sum0 : vector<[4]xi32>, vector<[4]xi32>
  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>
  %sum2 = vector.mask %step2_mask {
    vector.outerproduct %lhs2_i32, %rhs2_i32, %sum1 : vector<[4]xi32>, vector<[4]xi32>
  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>
  %sum3 = vector.mask %step3_mask {
    vector.outerproduct %lhs3_i32, %rhs3_i32, %sum2 : vector<[4]xi32>, vector<[4]xi32>
  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>

  // Expected: the last row is covered by none of the masks, so it still holds
  // the initial accumulator value 2.
  // CHECK:      ( 112, 136, 135,  95 )
  // CHECK-NEXT: ( 243, 295, 347, 219 )
  // CHECK-NEXT: ( 211, 255, 299, 343 )
  // CHECK-NEXT: (   2,   2,   2,   2 )
  vector.print %sum3 : vector<[4]x[4]xi32>

  return
}

// External declaration (no body in this file); @main calls it with 128 before
// running the tests. Presumably resolved from one of the libraries passed via
// -shared-libs in the run command above -- TODO confirm which one.
func.func private @setArmSVLBits(%bits : i32)