// DEFINE: %{entry} = main
// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry} -entry-point-result=void \
// DEFINE: -shared-libs=%native_mlir_runner_utils,%native_mlir_c_runner_utils,%native_mlir_arm_runner_utils,%native_arm_sme_abi_shlib
// RUN: %{compile} | %{run} | FileCheck %s
// NOTE: QEMU <= 8.2.0 gives incorrect results for the SME SMOPA 4-way outer
// product instruction; see https://gitlab.com/qemu-project/qemu/-/issues/2083.
// NOTE: There is no non-widening variant for these types, and this test can't
// be lowered without the widening pass. Therefore, unlike
// 'test-outerproduct-f16f16f32.mlir', we can't check that the result is the
// same with and without the widening pass.
// Test driver: calls @setArmSVLBits with 128, then runs the unmasked test
// followed by the masked test. The order matters because the expected-output
// lines in the file are matched in print order.
func.func @main() {
  // NOTE(review): presumably this pins the streaming vector length to 128
  // bits so that a vector<[4]xi32> holds exactly four lanes, matching the 4x4
  // tiles printed below -- confirm against the Arm runner utilities.
  %svl_bits = arith.constant 128 : i32
  func.call @setArmSVLBits(%svl_bits) : (i32) -> ()

  func.call @test_outerproduct_i8i8i32() : () -> ()
  func.call @test_masked_outerproduct_i8i8i32() : () -> ()

  return
}
// Accumulates four sign-extended i8 -> i32 outer products into a single
// [4]x[4] i32 tile and prints it. Every element of the printed tile is
//   result[i][j] = lhs0[i]*rhs0[j] + lhs1[i]*rhs1[j]
//                + lhs2[i]*rhs2[j] + lhs3[i]*rhs3[j].
func.func @test_outerproduct_i8i8i32() {
  // Base scalable vector for the inserts below; each insert overwrites the
  // four lanes starting at lane 0.
  %uninit = llvm.mlir.undef : vector<[4]xi8>

  // Left-hand operands: fixed-size i8 data placed at lane 0 of a scalable
  // vector, then sign-extended to i32.
  %lhs0_vals = arith.constant dense<[0, 4, 8, 12]> : vector<4xi8>
  %lhs1_vals = arith.constant dense<[1, 5, 9, 13]> : vector<4xi8>
  %lhs2_vals = arith.constant dense<[2, 6, 10, 14]> : vector<4xi8>
  %lhs3_vals = arith.constant dense<[3, 7, 11, 15]> : vector<4xi8>

  %lhs0_i8 = vector.scalable.insert %lhs0_vals, %uninit[0] : vector<4xi8> into vector<[4]xi8>
  %lhs1_i8 = vector.scalable.insert %lhs1_vals, %uninit[0] : vector<4xi8> into vector<[4]xi8>
  %lhs2_i8 = vector.scalable.insert %lhs2_vals, %uninit[0] : vector<4xi8> into vector<[4]xi8>
  %lhs3_i8 = vector.scalable.insert %lhs3_vals, %uninit[0] : vector<4xi8> into vector<[4]xi8>

  %lhs0 = arith.extsi %lhs0_i8 : vector<[4]xi8> to vector<[4]xi32>
  %lhs1 = arith.extsi %lhs1_i8 : vector<[4]xi8> to vector<[4]xi32>
  %lhs2 = arith.extsi %lhs2_i8 : vector<[4]xi8> to vector<[4]xi32>
  %lhs3 = arith.extsi %lhs3_i8 : vector<[4]xi8> to vector<[4]xi32>

  // Right-hand operands, built the same way.
  %rhs0_vals = arith.constant dense<[16, 20, 24, 28]> : vector<4xi8>
  %rhs1_vals = arith.constant dense<[17, 21, 25, 29]> : vector<4xi8>
  %rhs2_vals = arith.constant dense<[18, 22, 26, 30]> : vector<4xi8>
  %rhs3_vals = arith.constant dense<[19, 23, 27, 31]> : vector<4xi8>

  %rhs0_i8 = vector.scalable.insert %rhs0_vals, %uninit[0] : vector<4xi8> into vector<[4]xi8>
  %rhs1_i8 = vector.scalable.insert %rhs1_vals, %uninit[0] : vector<4xi8> into vector<[4]xi8>
  %rhs2_i8 = vector.scalable.insert %rhs2_vals, %uninit[0] : vector<4xi8> into vector<[4]xi8>
  %rhs3_i8 = vector.scalable.insert %rhs3_vals, %uninit[0] : vector<4xi8> into vector<[4]xi8>

  %rhs0 = arith.extsi %rhs0_i8 : vector<[4]xi8> to vector<[4]xi32>
  %rhs1 = arith.extsi %rhs1_i8 : vector<[4]xi8> to vector<[4]xi32>
  %rhs2 = arith.extsi %rhs2_i8 : vector<[4]xi8> to vector<[4]xi32>
  %rhs3 = arith.extsi %rhs3_i8 : vector<[4]xi8> to vector<[4]xi32>

  // Chain of four outer products: the first has no accumulator, and each
  // later one accumulates onto the previous result.
  %sum0 = vector.outerproduct %lhs0, %rhs0 : vector<[4]xi32>, vector<[4]xi32>
  %sum1 = vector.outerproduct %lhs1, %rhs1, %sum0 : vector<[4]xi32>, vector<[4]xi32>
  %sum2 = vector.outerproduct %lhs2, %rhs2, %sum1 : vector<[4]xi32>, vector<[4]xi32>
  %sum3 = vector.outerproduct %lhs3, %rhs3, %sum2 : vector<[4]xi32>, vector<[4]xi32>

  // Expected tile, one printed row per line
  // (e.g. row 0, column 0: 0*16 + 1*17 + 2*18 + 3*19 = 110).
  // CHECK: ( 110, 134, 158, 182 )
  // CHECK-NEXT: ( 390, 478, 566, 654 )
  // CHECK-NEXT: ( 670, 822, 974, 1126 )
  // CHECK-NEXT: ( 950, 1166, 1382, 1598 )
  vector.print %sum3 : vector<[4]x[4]xi32>

  return
}
// Masked variant of the i8 -> i32 outer-product accumulation. It uses the
// same operand data as the unmasked test, but starts from an accumulator
// filled with 2 and applies a different 2-D mask at each of the four steps.
// The final tile is printed.
func.func @test_masked_outerproduct_i8i8i32() {
  // Base scalable vector for the inserts below; each insert overwrites the
  // four lanes starting at lane 0.
  %uninit = llvm.mlir.undef : vector<[4]xi8>

  // Left-hand operands: fixed-size i8 data placed at lane 0 of a scalable
  // vector, then sign-extended to i32.
  %lhs0_vals = arith.constant dense<[0, 4, 8, 12]> : vector<4xi8>
  %lhs1_vals = arith.constant dense<[1, 5, 9, 13]> : vector<4xi8>
  %lhs2_vals = arith.constant dense<[2, 6, 10, 14]> : vector<4xi8>
  %lhs3_vals = arith.constant dense<[3, 7, 11, 15]> : vector<4xi8>

  %lhs0_i8 = vector.scalable.insert %lhs0_vals, %uninit[0] : vector<4xi8> into vector<[4]xi8>
  %lhs1_i8 = vector.scalable.insert %lhs1_vals, %uninit[0] : vector<4xi8> into vector<[4]xi8>
  %lhs2_i8 = vector.scalable.insert %lhs2_vals, %uninit[0] : vector<4xi8> into vector<[4]xi8>
  %lhs3_i8 = vector.scalable.insert %lhs3_vals, %uninit[0] : vector<4xi8> into vector<[4]xi8>

  %lhs0 = arith.extsi %lhs0_i8 : vector<[4]xi8> to vector<[4]xi32>
  %lhs1 = arith.extsi %lhs1_i8 : vector<[4]xi8> to vector<[4]xi32>
  %lhs2 = arith.extsi %lhs2_i8 : vector<[4]xi8> to vector<[4]xi32>
  %lhs3 = arith.extsi %lhs3_i8 : vector<[4]xi8> to vector<[4]xi32>

  // Right-hand operands, built the same way.
  %rhs0_vals = arith.constant dense<[16, 20, 24, 28]> : vector<4xi8>
  %rhs1_vals = arith.constant dense<[17, 21, 25, 29]> : vector<4xi8>
  %rhs2_vals = arith.constant dense<[18, 22, 26, 30]> : vector<4xi8>
  %rhs3_vals = arith.constant dense<[19, 23, 27, 31]> : vector<4xi8>

  %rhs0_i8 = vector.scalable.insert %rhs0_vals, %uninit[0] : vector<4xi8> into vector<[4]xi8>
  %rhs1_i8 = vector.scalable.insert %rhs1_vals, %uninit[0] : vector<4xi8> into vector<[4]xi8>
  %rhs2_i8 = vector.scalable.insert %rhs2_vals, %uninit[0] : vector<4xi8> into vector<[4]xi8>
  %rhs3_i8 = vector.scalable.insert %rhs3_vals, %uninit[0] : vector<4xi8> into vector<[4]xi8>

  %rhs0 = arith.extsi %rhs0_i8 : vector<[4]xi8> to vector<[4]xi32>
  %rhs1 = arith.extsi %rhs1_i8 : vector<[4]xi8> to vector<[4]xi32>
  %rhs2 = arith.extsi %rhs2_i8 : vector<[4]xi8> to vector<[4]xi32>
  %rhs3 = arith.extsi %rhs3_i8 : vector<[4]xi8> to vector<[4]xi32>

  // One mask per step, each enabling the leading <rows> x <cols> corner of
  // the tile: 1x1, 1x2, 2x3 and 3x4 for steps 0 to 3.
  %one = arith.constant 1 : index
  %two = arith.constant 2 : index
  %three = arith.constant 3 : index
  %four = arith.constant 4 : index

  %mask_1x1 = vector.create_mask %one, %one : vector<[4]x[4]xi1>
  %mask_1x2 = vector.create_mask %one, %two : vector<[4]x[4]xi1>
  %mask_2x3 = vector.create_mask %two, %three : vector<[4]x[4]xi1>
  %mask_3x4 = vector.create_mask %three, %four : vector<[4]x[4]xi1>

  // Initial accumulator: every element is 2.
  %init_acc = arith.constant dense<2> : vector<[4]x[4]xi32>

  // Chain of four masked outer products, each accumulating onto the previous
  // step's result.
  %step0 = vector.mask %mask_1x1 {
    vector.outerproduct %lhs0, %rhs0, %init_acc : vector<[4]xi32>, vector<[4]xi32>
  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>

  %step1 = vector.mask %mask_1x2 {
    vector.outerproduct %lhs1, %rhs1, %step0 : vector<[4]xi32>, vector<[4]xi32>
  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>

  %step2 = vector.mask %mask_2x3 {
    vector.outerproduct %lhs2, %rhs2, %step1 : vector<[4]xi32>, vector<[4]xi32>
  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>

  %step3 = vector.mask %mask_3x4 {
    vector.outerproduct %lhs3, %rhs3, %step2 : vector<[4]xi32>, vector<[4]xi32>
  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>

  // Expected tile, one printed row per line. The values below show that
  // lanes disabled by a step's mask keep the incoming accumulator value: the
  // last row is never enabled, so it stays at the initial 2.
  // CHECK: ( 112, 136, 135, 95 )
  // CHECK-NEXT: ( 243, 295, 347, 219 )
  // CHECK-NEXT: ( 211, 255, 299, 343 )
  // CHECK-NEXT: ( 2, 2, 2, 2 )
  vector.print %step3 : vector<[4]x[4]xi32>

  return
}
// External declaration (no body in this file); @main calls it with 128.
// NOTE(review): presumably implemented by one of the shared libraries passed
// to the runner and used to set the Arm streaming vector length in bits --
// confirm against the Arm runner utilities.
func.func private @setArmSVLBits(%bits : i32)