llvm/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir

// DEFINE: %{entry_point} = za0_d_f64
// DEFINE: %{compile} = mlir-opt %s -test-lower-to-arm-sme -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE:  -march=aarch64 -mattr=+sve,+sme \
// DEFINE:  -e %{entry_point} -entry-point-result=i32 \
// DEFINE:  -shared-libs=%native_mlir_runner_utils,%native_mlir_c_runner_utils,%native_arm_sme_abi_shlib

// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=CHECK-ZA0_D

// REDEFINE: %{entry_point} = load_store_two_za_s_tiles
// RUN: %{compile} | %{run} | FileCheck %s

// Integration tests demonstrating loads from and stores to an SME ZA tile.

// This test verifies that a 64-bit element ZA tile (ZA0.D) holding FP64 data
// is correctly loaded from and stored to memory.
func.func @za0_d_f64() -> i32 {
  %c0 = arith.constant 0 : index
  %c0_f64 = arith.constant 0.0 : f64
  %c1_f64 = arith.constant 1.0 : f64
  %c1_index = arith.constant 1 : index

  // "svl" refers to the Streaming Vector Length and "svl_d" the number of
  // 64-bit elements in a vector of SVL bits.
  %svl_d = arm_sme.streaming_vl <double>

  // Allocate "mem1" and fill each "row" (a run of svl_d elements) with its
  // row number plus 0.1.
  //
  // For example, assuming an SVL of 256-bits:
  //
  //   0.1, 0.1, 0.1, 0.1
  //   1.1, 1.1, 1.1, 1.1
  //   2.1, 2.1, 2.1, 2.1
  //   3.1, 3.1, 3.1, 3.1
  //
  %tilesize = arith.muli %svl_d, %svl_d : index
  %mem1 = memref.alloca(%tilesize) : memref<?xf64>
  %first_row_val = arith.constant 0.1 : f64
  scf.for %i = %c0 to %tilesize step %svl_d iter_args(%val = %first_row_val) -> (f64) {
    %splat_val = vector.broadcast %val : f64 to vector<[2]xf64>
    vector.store %splat_val, %mem1[%i] : memref<?xf64>, vector<[2]xf64>
    %val_next = arith.addf %val, %c1_f64 : f64
    scf.yield %val_next : f64
  }

  // Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
  // 2x2xf64.
  //
  // CHECK-ZA0_D:      ( 0.1, 0.1
  // CHECK-ZA0_D-NEXT: ( 1.1, 1.1
  scf.for %i = %c0 to %tilesize step %svl_d {
    %tileslice = vector.load %mem1[%i] : memref<?xf64>, vector<[2]xf64>
    vector.print %tileslice : vector<[2]xf64>
  }

  // Load ZA0.D from "mem1"
  %za0_d = vector.load %mem1[%c0] : memref<?xf64>, vector<[2]x[2]xf64>

  // Allocate "mem2" to store ZA0.D to
  %mem2 = memref.alloca(%tilesize) : memref<?xf64>

  // Zero "mem2"
  scf.for %i = %c0 to %tilesize step %c1_index {
    memref.store %c0_f64, %mem2[%i] : memref<?xf64>
  }

  // Verify "mem2" is zeroed by doing an add reduction with initial value of
  // zero
  %init_0_f64 = arith.constant 0.0 : f64
  %add_reduce = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_0_f64) -> (f64) {
    %row = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>

    %inner_add_reduce = scf.for %offset = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_0_f64) -> (f64) {
      %t = vector.extract %row[%offset] : f64 from vector<[2]xf64>
      %inner_add_reduce_next = arith.addf %inner_iter, %t : f64
      scf.yield %inner_add_reduce_next : f64
    }

    %add_reduce_next = arith.addf %iter, %inner_add_reduce : f64
    scf.yield %add_reduce_next : f64
  }

  // Anchor the match to a whole line: for SVL >= 1024 bits the "mem1" dump
  // above contains rows such as "( 10.1, ..." which a bare "0" pattern would
  // match, making the CHECK-NEXT lines below fail.
  // CHECK-ZA0_D: {{^}}0{{$}}
  vector.print %add_reduce : f64

  // Dump zeroed "mem2". The smallest SVL is 128-bits so the tile will be at
  // least 2x2xf64.
  //
  // CHECK-ZA0_D-NEXT: ( 0, 0
  // CHECK-ZA0_D-NEXT: ( 0, 0
  scf.for %i = %c0 to %tilesize step %svl_d {
    %tileslice = vector.load %mem2[%i] : memref<?xf64>, vector<[2]xf64>
    vector.print %tileslice : vector<[2]xf64>
  }

  // Verify every element of "mem1" differs from the corresponding element of
  // "mem2" (the product of the per-element "not equal" flags is 1).
  %init_1 = arith.constant 1 : i64
  %mul_reduce_0 = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_1) -> (i64) {
    %row_1 = vector.load %mem1[%vnum] : memref<?xf64>, vector<[2]xf64>
    %row_2 = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
    %cmp = arith.cmpf one, %row_1, %row_2 : vector<[2]xf64>

    %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
      %t = vector.extract %cmp[%i] : i1 from vector<[2]xi1>
      %t_i64 = arith.extui %t : i1 to i64
      %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
      scf.yield %inner_mul_reduce_next : i64
    }

    %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64
    scf.yield %mul_reduce_next : i64
  }

  // CHECK-ZA0_D: {{^}}1{{$}}
  vector.print %mul_reduce_0 : i64

  // Store ZA0.D to "mem2"
  vector.store %za0_d, %mem2[%c0] : memref<?xf64>, vector<[2]x[2]xf64>

  // Verify "mem1" == "mem2" (the product of the per-element "equal" flags is
  // 1).
  %mul_reduce_1 = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_1) -> (i64) {
    %row_1 = vector.load %mem1[%vnum] : memref<?xf64>, vector<[2]xf64>
    %row_2 = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
    %cmp = arith.cmpf oeq, %row_1, %row_2 : vector<[2]xf64>

    %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
      %t = vector.extract %cmp[%i] : i1 from vector<[2]xi1>
      %t_i64 = arith.extui %t : i1 to i64
      %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
      scf.yield %inner_mul_reduce_next : i64
    }

    %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64
    scf.yield %mul_reduce_next : i64
  }

  // CHECK-ZA0_D-NEXT: 1
  vector.print %mul_reduce_1 : i64

  // Dump "mem2". The smallest SVL is 128-bits so the tile will be at least
  // 2x2xf64.
  //
  // CHECK-ZA0_D-NEXT: ( 0.1, 0.1
  // CHECK-ZA0_D-NEXT: ( 1.1, 1.1
  scf.for %i = %c0 to %tilesize step %svl_d {
    %tileslice = vector.load %mem2[%i] : memref<?xf64>, vector<[2]xf64>
    vector.print %tileslice : vector<[2]xf64>
  }

  %c0_i32 = arith.constant 0 : i32
  return %c0_i32 : i32
}

// This test loads two 32-bit element ZA tiles from memory and stores them back
// to memory in reverse order. This verifies the memref indices for the vector
// load and store are correctly preserved since the second tile is offset from
// the first tile.
func.func @load_store_two_za_s_tiles() -> i32 {
  %c0 = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32
  %c2_i32 = arith.constant 2 : i32
  %c1_index = arith.constant 1 : index
  %c2_index = arith.constant 2 : index

  // "svl" refers to the Streaming Vector Length and "svl_s" can mean either:
  // * the number of 32-bit elements in a vector of SVL bits.
  // * the number of tile slices (1d vectors) in a 32-bit element tile.
  %svl_s = arm_sme.streaming_vl <word>

  // Allocate memory for two 32-bit element tiles.
  %size_of_tile = arith.muli %svl_s, %svl_s : index
  %size_of_two_tiles = arith.muli %size_of_tile, %c2_index : index
  %mem1 = memref.alloca(%size_of_two_tiles) : memref<?xi32>

  // Fill memory that tile 1 will be loaded from with '1' and '2' for tile 2.
  //
  // For example, assuming an SVL of 128-bits and two 4x4xi32 tiles:
  //
  // tile 1
  //
  //   1, 1, 1, 1
  //   1, 1, 1, 1
  //   1, 1, 1, 1
  //   1, 1, 1, 1
  //
  // tile 2
  //
  //   2, 2, 2, 2
  //   2, 2, 2, 2
  //   2, 2, 2, 2
  //   2, 2, 2, 2
  //
  scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
    // Slices before %size_of_tile belong to tile 1, the rest to tile 2.
    %is_first_tile = arith.cmpi ult, %i, %size_of_tile : index
    %val = scf.if %is_first_tile -> i32 {
      scf.yield %c1_i32 : i32
    } else {
      scf.yield %c2_i32 : i32
    }
    %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
    vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
  }

  // Dump "mem1". The smallest SVL is 128-bits so each tile will be at least
  // 4x4xi32.
  //
  // CHECK:      ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK:      ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
    %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
    vector.print %tileslice : vector<[4]xi32>
  }

  // Load tile 1 from memory
  %za0_s = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>

  // Load tile 2 from memory
  %za1_s = vector.load %mem1[%size_of_tile] : memref<?xi32>, vector<[4]x[4]xi32>

  // Allocate new memory to store tiles to
  %mem2 = memref.alloca(%size_of_two_tiles)  : memref<?xi32>

  // Zero new memory
  scf.for %i = %c0 to %size_of_two_tiles step %c1_index {
    memref.store %c0_i32, %mem2[%i] : memref<?xi32>
  }

  // Stores tiles back to (new) memory in reverse order

  // Store tile 2 to memory
  vector.store %za1_s, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>

  // Store tile 1 to memory
  vector.store %za0_s, %mem2[%size_of_tile] : memref<?xi32>, vector<[4]x[4]xi32>

  // Dump "mem2" and check the tiles were stored in reverse order. The smallest
  // SVL is 128-bits so the tiles will be at least 4x4xi32.
  //
  // CHECK:      TILE BEGIN
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK-NEXT: ( 2, 2, 2, 2
  // CHECK:      TILE END
  // CHECK-NEXT: TILE BEGIN
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK-NEXT: ( 1, 1, 1, 1
  // CHECK:      TILE END

  // Offset of the last slice of the first tile; once that slice has been
  // printed, the separator between the two tiles is emitted. This value is
  // loop-invariant, so compute it once.
  %last_slice_of_first_tile = arith.subi %size_of_tile, %svl_s : index

  vector.print str "TILE BEGIN\n"
  scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
    %tileslice = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
    vector.print %tileslice : vector<[4]xi32>

    %is_tile_boundary = arith.cmpi eq, %i, %last_slice_of_first_tile : index
    scf.if %is_tile_boundary {
      vector.print str "TILE END\n"
      vector.print str "TILE BEGIN\n"
    }
  }
  vector.print str "TILE END\n"

  return %c0_i32 : i32
}