
// RUN: mlir-opt %s -allow-unregistered-dialect -arm-sve-legalize-vector-storage -split-input-file -verify-diagnostics | FileCheck %s

/// Tests for the `-arm-sve-legalize-vector-storage` pass, which legalizes the
/// in-memory (storage) representation of SVE vectors and predicates.
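///
/// SVE predicate registers are svbool-sized: one bit for each byte of a
/// 128-bit vector granule, i.e. vector<[16]xi1>. Predicate masks with fewer
/// elements (vector<[1|2|4|8]xi1>) have no direct in-memory representation,
/// so the pass widens their storage to vector<[16]xi1>, inserting
/// `arm_sve.convert_to_svbool` before each store and
/// `arm_sve.convert_from_svbool` after each load. Predicate allocas also
/// receive a 2-byte alignment (the size of one svbool at vscale = 1).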

// -----

// CHECK-LABEL: @store_and_reload_sve_predicate_nxv1i1(
// CHECK-SAME:                                         %[[MASK:.*]]: vector<[1]xi1>)
func.func @store_and_reload_sve_predicate_nxv1i1(%mask: vector<[1]xi1>) -> vector<[1]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[1]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[1]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[1]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[1]xi1>
  %reload = memref.load %alloca[] : memref<vector<[1]xi1>>
  // CHECK-NEXT: return %[[MASK]] : vector<[1]xi1>
  return %reload : vector<[1]xi1>
}

// -----

// CHECK-LABEL: @store_and_reload_sve_predicate_nxv2i1(
// CHECK-SAME:                                         %[[MASK:.*]]: vector<[2]xi1>)
func.func @store_and_reload_sve_predicate_nxv2i1(%mask: vector<[2]xi1>) -> vector<[2]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[2]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[2]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[2]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[2]xi1>
  %reload = memref.load %alloca[] : memref<vector<[2]xi1>>
  // CHECK-NEXT: return %[[MASK]] : vector<[2]xi1>
  return %reload : vector<[2]xi1>
}

// -----

// CHECK-LABEL: @store_and_reload_sve_predicate_nxv4i1(
// CHECK-SAME:                                         %[[MASK:.*]]: vector<[4]xi1>)
func.func @store_and_reload_sve_predicate_nxv4i1(%mask: vector<[4]xi1>) -> vector<[4]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[4]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[4]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[4]xi1>
  %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
  // CHECK-NEXT: return %[[MASK]] : vector<[4]xi1>
  return %reload : vector<[4]xi1>
}

// -----

// CHECK-LABEL: @store_and_reload_sve_predicate_nxv8i1(
// CHECK-SAME:                                         %[[MASK:.*]]: vector<[8]xi1>)
func.func @store_and_reload_sve_predicate_nxv8i1(%mask: vector<[8]xi1>) -> vector<[8]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[8]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[8]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[8]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1>
  %reload = memref.load %alloca[] : memref<vector<[8]xi1>>
  // CHECK-NEXT: return %[[MASK]] : vector<[8]xi1>
  return %reload : vector<[8]xi1>
}

// -----
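
/// A vector<[16]xi1> mask is already svbool-sized, so no conversions are
/// inserted; only the alloca's alignment is set.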

// CHECK-LABEL: @store_and_reload_sve_predicate_nxv16i1(
// CHECK-SAME:                                          %[[MASK:.*]]: vector<[16]xi1>)
func.func @store_and_reload_sve_predicate_nxv16i1(%mask: vector<[16]xi1>) -> vector<[16]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[16]xi1>>
  // CHECK-NEXT: memref.store %[[MASK]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  %reload = memref.load %alloca[] : memref<vector<[16]xi1>>
  // CHECK-NEXT: return %[[RELOAD]] : vector<[16]xi1>
  return %reload : vector<[16]xi1>
}

// -----

/// This is not a valid SVE mask type, so the
/// `-arm-sve-legalize-vector-storage` pass leaves its storage type unchanged
/// (only the default predicate alloca alignment is applied).

// CHECK-LABEL: @store_and_reload_unsupported_type(
// CHECK-SAME:                                     %[[MASK:.*]]: vector<[7]xi1>)
func.func @store_and_reload_unsupported_type(%mask: vector<[7]xi1>) -> vector<[7]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[7]xi1>>
  %alloca = memref.alloca() : memref<vector<[7]xi1>>
  // CHECK-NEXT: memref.store %[[MASK]], %[[ALLOCA]][] : memref<vector<[7]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[7]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[7]xi1>>
  %reload = memref.load %alloca[] : memref<vector<[7]xi1>>
  // CHECK-NEXT: return %[[RELOAD]] : vector<[7]xi1>
  return %reload : vector<[7]xi1>
}

// -----
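
/// Multi-dimensional masks are widened in their trailing (scalable)
/// dimension, e.g. vector<3x[8]xi1> is stored as vector<3x[16]xi1>; a slice
/// reloaded through `vector.type_cast` is converted back from an svbool.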

// CHECK-LABEL: @store_2d_mask_and_reload_slice(
// CHECK-SAME:                                  %[[MASK:.*]]: vector<3x[8]xi1>)
func.func @store_2d_mask_and_reload_slice(%mask: vector<3x[8]xi1>) -> vector<[8]xi1> {
  // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
  %c0 = arith.constant 0 : index
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<3x[16]xi1>>
  %alloca = memref.alloca() : memref<vector<3x[8]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<3x[8]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<3x[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<3x[8]xi1>>
  // CHECK-NEXT: %[[UNPACK:.*]] = vector.type_cast %[[ALLOCA]] : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>>
  %unpack = vector.type_cast %alloca : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[UNPACK]][%[[C0]]] : memref<3xvector<[16]xi1>>
  // CHECK-NEXT: %[[SLICE:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1>
  %slice = memref.load %unpack[%c0] : memref<3xvector<[8]xi1>>
  // CHECK-NEXT: return %[[SLICE]] : vector<[8]xi1>
  return %slice : vector<[8]xi1>
}

// -----

// CHECK-LABEL: @set_sve_alloca_alignment
func.func @set_sve_alloca_alignment() {
  /// This checks that allocas of scalable vectors receive an alignment the
  /// backend can handle. Currently, the backend sets the alignment of
  /// scalable vectors to their base size (i.e. their size at vscale = 1).
  /// This works for hardware-sized types, which always get a 16-byte
  /// alignment. The problem is that larger types, e.g. vector<[8]xf32>, end
  /// up with alignments greater than 16 bytes (32 bytes here), which are
  /// unsupported. The `-arm-sve-legalize-vector-storage` pass avoids this
  /// issue by explicitly setting a 16-byte alignment on scalable-vector
  /// allocas.
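  ///
  /// For example, at vscale = 1, vector<[8]xf32> occupies 8 x 4 = 32 bytes,
  /// so base-size alignment would require a 32-byte-aligned alloca; pinning
  /// the alignment to 16 bytes (one 128-bit SVE granule) keeps it legal.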

  // CHECK-COUNT-6: alignment = 16
  %a1 = memref.alloca() : memref<vector<[32]xi8>>
  %a2 = memref.alloca() : memref<vector<[16]xi8>>
  %a3 = memref.alloca() : memref<vector<[8]xi8>>
  %a4 = memref.alloca() : memref<vector<[4]xi8>>
  %a5 = memref.alloca() : memref<vector<[2]xi8>>
  %a6 = memref.alloca() : memref<vector<[1]xi8>>

  // CHECK-COUNT-6: alignment = 16
  %b1 = memref.alloca() : memref<vector<[32]xi16>>
  %b2 = memref.alloca() : memref<vector<[16]xi16>>
  %b3 = memref.alloca() : memref<vector<[8]xi16>>
  %b4 = memref.alloca() : memref<vector<[4]xi16>>
  %b5 = memref.alloca() : memref<vector<[2]xi16>>
  %b6 = memref.alloca() : memref<vector<[1]xi16>>

  // CHECK-COUNT-6: alignment = 16
  %c1 = memref.alloca() : memref<vector<[32]xi32>>
  %c2 = memref.alloca() : memref<vector<[16]xi32>>
  %c3 = memref.alloca() : memref<vector<[8]xi32>>
  %c4 = memref.alloca() : memref<vector<[4]xi32>>
  %c5 = memref.alloca() : memref<vector<[2]xi32>>
  %c6 = memref.alloca() : memref<vector<[1]xi32>>

  // CHECK-COUNT-6: alignment = 16
  %d1 = memref.alloca() : memref<vector<[32]xi64>>
  %d2 = memref.alloca() : memref<vector<[16]xi64>>
  %d3 = memref.alloca() : memref<vector<[8]xi64>>
  %d4 = memref.alloca() : memref<vector<[4]xi64>>
  %d5 = memref.alloca() : memref<vector<[2]xi64>>
  %d6 = memref.alloca() : memref<vector<[1]xi64>>

  // CHECK-COUNT-6: alignment = 16
  %e1 = memref.alloca() : memref<vector<[32]xf32>>
  %e2 = memref.alloca() : memref<vector<[16]xf32>>
  %e3 = memref.alloca() : memref<vector<[8]xf32>>
  %e4 = memref.alloca() : memref<vector<[4]xf32>>
  %e5 = memref.alloca() : memref<vector<[2]xf32>>
  %e6 = memref.alloca() : memref<vector<[1]xf32>>

  // CHECK-COUNT-6: alignment = 16
  %f1 = memref.alloca() : memref<vector<[32]xf64>>
  %f2 = memref.alloca() : memref<vector<[16]xf64>>
  %f3 = memref.alloca() : memref<vector<[8]xf64>>
  %f4 = memref.alloca() : memref<vector<[4]xf64>>
  %f5 = memref.alloca() : memref<vector<[2]xf64>>
  %f6 = memref.alloca() : memref<vector<[1]xf64>>
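
  /// `prevent.dce` is an unregistered op (hence `-allow-unregistered-dialect`
  /// in the RUN line); using the allocas here keeps them from being folded
  /// away before the output is checked.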

  "prevent.dce"(
    %a1, %a2, %a3, %a4, %a5, %a6,
    %b1, %b2, %b3, %b4, %b5, %b6,
    %c1, %c2, %c3, %c4, %c5, %c6,
    %d1, %d2, %d3, %d4, %d5, %d6,
    %e1, %e2, %e3, %e4, %e5, %e6,
    %f1, %f2, %f3, %f4, %f5, %f6)
    : (memref<vector<[32]xi8>>, memref<vector<[16]xi8>>, memref<vector<[8]xi8>>, memref<vector<[4]xi8>>, memref<vector<[2]xi8>>, memref<vector<[1]xi8>>,
       memref<vector<[32]xi16>>, memref<vector<[16]xi16>>, memref<vector<[8]xi16>>, memref<vector<[4]xi16>>, memref<vector<[2]xi16>>, memref<vector<[1]xi16>>,
       memref<vector<[32]xi32>>, memref<vector<[16]xi32>>, memref<vector<[8]xi32>>, memref<vector<[4]xi32>>, memref<vector<[2]xi32>>, memref<vector<[1]xi32>>,
       memref<vector<[32]xi64>>, memref<vector<[16]xi64>>, memref<vector<[8]xi64>>, memref<vector<[4]xi64>>, memref<vector<[2]xi64>>, memref<vector<[1]xi64>>,
       memref<vector<[32]xf32>>, memref<vector<[16]xf32>>, memref<vector<[8]xf32>>, memref<vector<[4]xf32>>, memref<vector<[2]xf32>>, memref<vector<[1]xf32>>,
       memref<vector<[32]xf64>>, memref<vector<[16]xf64>>, memref<vector<[8]xf64>>, memref<vector<[4]xf64>>, memref<vector<[2]xf64>>, memref<vector<[1]xf64>>) -> ()
  return
}