// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-rewrite-extract-slice-from-collapse-shape %s | FileCheck %s
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-rewrite-extract-slice-from-collapse-shape use-foreach" %s | FileCheck %s --check-prefix=FOREACH
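
// The patterns exercised here rewrite `tensor.extract_slice(tensor.collapse_shape)`
// so that data is sliced directly out of the expanded source tensor: each sliced
// linearized dimension becomes a loop (a single scf.forall when `use-foreach` is
// set), the loop index is delinearized back into source coordinates with
// affine.delinearize_index, and the per-iteration tile is extracted, re-collapsed,
// and inserted into the result. A rough sketch of the serial form, where %src,
// %init, and %d0..%d2 stand for the expanded source, the destination, and the
// delinearization basis (illustrative only, types elided, not checked by
// FileCheck):
//
//   scf.for %iv = %c0 to %size step %c1 iter_args(%acc = %init) {
//     %idx:3 = affine.delinearize_index %iv into (%d0, %d1, %d2) : index, index, index
//     %s = tensor.extract_slice %src[%idx#0, %idx#1, %idx#2, 0] [1, 1, 1, 11] [1, 1, 1, 1]
//     %flat = tensor.collapse_shape %s [[0, 1, 2], [3]]
//     %next = tensor.insert_slice %flat into %acc[%iv, 0] [1, 11] [1, 1]
//     scf.yield %next
//   }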

func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf32> {
  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32>
  %slice = tensor.extract_slice %collapsed [0, 0] [20, 11] [1, 1] : tensor<105x11xf32> to tensor<20x11xf32>
  return %slice : tensor<20x11xf32>
}

//     CHECK: func.func @extract_slice_static(%[[arg0:.+]]:
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
// CHECK-DAG: %[[init:.+]] = tensor.empty() : tensor<20x11xf32>
//     CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c20]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
//     CHECK:   %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
//     CHECK:   %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] :
//     CHECK:   %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
//     CHECK:   %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 11] [1, 1] :
//     CHECK:   scf.yield %[[update]] :
//     CHECK: return %[[tile]]

//     FOREACH: func.func @extract_slice_static(%[[arg0:.+]]:
// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index
// FOREACH-DAG: %[[c5:.+]] = arith.constant 5 : index
// FOREACH-DAG: %[[c7:.+]] = arith.constant 7 : index
// FOREACH-DAG: %[[init:.+]] = tensor.empty() : tensor<20x11xf32>
//     FOREACH: %[[tile:.+]] = scf.forall (%[[iv:.+]]) in (20) shared_outs(%[[dest:.+]] = %[[init]])
//     FOREACH:   %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
//     FOREACH:   %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] :
//     FOREACH:   %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
//     FOREACH:   in_parallel
//FOREACH-NEXT:   tensor.parallel_insert_slice %[[sliceFlat]] into %[[dest]][%[[iv]], 0] [1, 11] [1, 1] :
//     FOREACH: return %[[tile]]
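
// As a concrete instance of the delinearization above: linear index 19 in basis
// (3, 5, 7) maps to (0, 2, 5), since 19 = 0 * (5 * 7) + 2 * 7 + 5.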

// -----

func.func @extract_slice_static_strided(%input: tensor<3x5x7x11xf32>) -> tensor<10x5xf32> {
  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32>
  %slice = tensor.extract_slice %collapsed [13, 0] [10, 5] [2, 2] : tensor<105x11xf32> to tensor<10x5xf32>
  return %slice : tensor<10x5xf32>
}

//     CHECK: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2 + 13)>
//     CHECK: func.func @extract_slice_static_strided(%[[arg0:.+]]:
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
//     CHECK: %[[init:.+]] = tensor.empty() : tensor<10x5xf32>
//     CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
//     CHECK:   %[[inputIv:.+]] = affine.apply #[[$map0]](%[[iv]])
//     CHECK:   %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[c5]], %[[c7]]
//     CHECK:   %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] :
//     CHECK:   %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
//     CHECK:   %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] :
//     CHECK:   scf.yield %[[update]] :
//     CHECK: return %[[tile]]
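
// The affine map above folds the slice offset (13) and stride (2) into the loop
// induction variable: iteration %iv reads source row 2 * %iv + 13, so the final
// iteration (%iv = 9) touches row 31.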

// -----

func.func @extract_slice_dynamic(%input: tensor<3x?x?x11xf32>, %offt: index, %size: index) -> tensor<?x5xf32> {
  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x?x?x11xf32> into tensor<?x11xf32>
  %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 5] [2, 2] : tensor<?x11xf32> to tensor<?x5xf32>
  return %slice : tensor<?x5xf32>
}

//     CHECK: #[[map0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
//     CHECK: func.func @extract_slice_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[lb:.+]]: index, %[[sz:.+]]: index)
// CHECK-DAG:   %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG:   %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG:   %[[c2:.+]] = arith.constant 2 : index
// CHECK-DAG:   %[[c3:.+]] = arith.constant 3 : index
//     CHECK:   %[[init:.+]] = tensor.empty(%[[sz]]) : tensor<?x5xf32>
// CHECK-DAG:   %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] : tensor<3x?x?x11xf32>
// CHECK-DAG:   %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] : tensor<3x?x?x11xf32>
//     CHECK:   %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[sz]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
//     CHECK:     %[[inputIv:.+]] = affine.apply #[[map0]](%[[iv]])[%[[lb]]]
//     CHECK:     %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[d1]], %[[d2]]) :
//     CHECK:     %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] :
//     CHECK:     %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
//     CHECK:     %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] :
//     CHECK:     scf.yield %[[update]] :
//     CHECK:   return %[[tile]] :

// -----

func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0: index, %size0: index, %offt1: index, %size1: index) -> tensor<?x?xf32> {
  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x?xf32> into tensor<?x?xf32>
  %slice = tensor.extract_slice %collapsed [%offt0, %offt1] [%size0, %size1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
  return %slice : tensor<?x?xf32>
}

//     CHECK: #[[map0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
//     CHECK: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index)
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index
// CHECK-DAG: %[[c11:.+]] = arith.constant 11 : index
//     CHECK: %[[init:.+]] = tensor.empty(%[[sz1]], %[[sz2]]) : tensor<?x?xf32>
// CHECK-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] :
// CHECK-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] :
// CHECK-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] :
//     CHECK: %[[tile1:.+]] = scf.for %[[iv1:.+]] = %[[c0]] to %[[sz1]] step %[[c1]] iter_args(%[[iterArg1:.+]] = %[[init]])
//     CHECK:   %[[tile2:.+]] = scf.for %[[iv2:.+]] = %[[c0]] to %[[sz2]] step %[[c1]] iter_args(%[[iterArg2:.+]] = %[[iterArg1]])
//     CHECK:       %[[inputIv1:.+]] = affine.apply #[[map0]](%[[iv1]])[%[[lb1]]]
//     CHECK:       %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[inputIv1]] into (%[[c3]], %[[d1]], %[[d2]]) :
//     CHECK:       %[[inputIv2:.+]] = affine.apply #[[map0]](%[[iv2]])[%[[lb2]]]
//     CHECK:       %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[inputIv2]] into (%[[c11]], %[[d4]]) :
//     CHECK:       %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] :
//     CHECK:       %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} :
//     CHECK:       %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg2]][%[[iv1]], %[[iv2]]] [1, 1] [1, 1] :
//     CHECK:       scf.yield %[[update]] :
//     CHECK:     scf.yield %[[tile2]] :
//     CHECK:   return %[[tile1]] :

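// With two sliced linearized dimensions, the serial lowering above nests one
// scf.for per dimension; the `use-foreach` lowering below instead emits a single
// two-dimensional scf.forall and writes each tile with
// tensor.parallel_insert_slice.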
//     FOREACH: #[[map1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
//     FOREACH: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index)
// FOREACH-DAG: %[[c1:.+]] = arith.constant 1 : index
// FOREACH-DAG: %[[c2:.+]] = arith.constant 2 : index
// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index
// FOREACH-DAG: %[[c4:.+]] = arith.constant 4 : index
// FOREACH-DAG: %[[c11:.+]] = arith.constant 11 : index
//     FOREACH:     %[[init:.+]] = tensor.empty(%[[sz1]], %[[sz2]]) : tensor<?x?xf32>
// FOREACH-DAG:     %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] :
// FOREACH-DAG:     %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] :
// FOREACH-DAG:     %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] :
//     FOREACH:     %[[tile1:.+]] = scf.forall (%[[tid1:.+]], %[[tid2:.+]]) in (%[[sz1]], %[[sz2]]) shared_outs(%[[dest:.+]] = %[[init]])
// FOREACH-DAG:       %[[iv1:.+]] = affine.apply #[[map1]](%[[tid1]])[%[[lb1]]]
//     FOREACH:       %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[iv1]] into (%[[c3]], %[[d1]], %[[d2]]) :
// FOREACH-DAG:       %[[iv2:.+]] = affine.apply #[[map1]](%[[tid2]])[%[[lb2]]]
//     FOREACH:       %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[iv2]] into (%[[c11]], %[[d4]]) :
//     FOREACH:       %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] :
//     FOREACH:       %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} :
//     FOREACH:       in_parallel
//FOREACH-NEXT:         tensor.parallel_insert_slice %[[sliceFlat]] into %[[dest]][%[[tid1]], %[[tid2]]] [1, 1] [1, 1] :

// -----

// Verifies that a linearized dimension that is not sliced does not generate a
// loop. Note that this only works for static shapes: deciding that a dimension
// is not sliced requires comparing the slice size against a static extent.

// CHECK: @extract_slice_non_sliced_linearized_dim(%[[arg0:.+]]: tensor<{{.*}}>,
func.func @extract_slice_non_sliced_linearized_dim(%input: tensor<3x?x?x11x2xf32>, %offt: index, %size: index) -> tensor<?x22xf32> {
  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x2xf32> into tensor<?x22xf32>
  %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 22] [1, 1] : tensor<?x22xf32> to tensor<?x22xf32>
  // CHECK: scf.for
  // CHECK-NOT: scf.for
  // CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index
  // CHECK: tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0, 0] [1, 1, 1, 11, 2] [1, 1, 1, 1, 1]
  return %slice : tensor<?x22xf32>
}

// -----

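// Verifies that when no group of collapsed dimensions is sliced (the slice below
// only narrows the trailing dimension, which collapses a single source dim), no
// loop is generated: the slice is taken from the expanded source and
// re-collapsed directly.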
// CHECK: @no_sliced_linearized_dims(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
func.func @no_sliced_linearized_dims(%input: tensor<30x11x100xf32>, %offt: index, %size: index) -> tensor<330x?xf32> {
  %collapsed = tensor.collapse_shape %input [[0, 1], [2]] : tensor<30x11x100xf32> into tensor<330x100xf32>
  %slice = tensor.extract_slice %collapsed [0, %offt] [330, %size] [1, 1] : tensor<330x100xf32> to tensor<330x?xf32>
  // CHECK-NOT: scf.for
  // CHECK: %[[init:.+]] = tensor.empty(%[[arg2]])
  // CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, 0, %[[arg1]]] [30, 11, %[[arg2]]] [1, 1, 1]
  // CHECK: %[[c:.+]] = tensor.collapse_shape %[[e]] {{\[}}[0, 1], [2]]
  // CHECK: %[[res:.+]] = tensor.insert_slice %[[c]] into %[[init]]
  // CHECK: return %[[res]]
  return %slice : tensor<330x?xf32>
}

// -----

// The tests below verify that a dimension produced by collapsing at most one
// non-unit dim is handled properly.

// CHECK: @collapse_and_slice_unit_dim(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
func.func @collapse_and_slice_unit_dim(%input: tensor<1x11x100xf32>, %offt: index, %size: index) -> tensor<?x100xf32> {
  %collapsed = tensor.collapse_shape %input [[0, 1], [2]] : tensor<1x11x100xf32> into tensor<11x100xf32>
  %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 100] [1, 1] : tensor<11x100xf32> to tensor<?x100xf32>
  // CHECK-NOT: scf.for
  // CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, 0, 0] [1, 11, 100] [1, 1, 1]
  // CHECK-SAME:           tensor<1x11x100xf32> to tensor<11x100xf32>
  // CHECK: %[[e1:.+]] = tensor.extract_slice %[[e]][%[[arg1]], 0] [%[[arg2]], 100] [1, 1]
  // CHECK-SAME:           tensor<11x100xf32> to tensor<?x100xf32>
  return %slice : tensor<?x100xf32>
}

// CHECK: @collapse_and_slice_multiple_unit_dim_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
func.func @collapse_and_slice_multiple_unit_dim_dynamic(%input: tensor<1x?x1x100xf32>, %offt: index, %size: index) -> tensor<?x100xf32> {
  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<1x?x1x100xf32> into tensor<?x100xf32>
  %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 100] [1, 1] : tensor<?x100xf32> to tensor<?x100xf32>
  // CHECK-NOT: scf.for
  // CHECK: %[[c1:.+]] = arith.constant 1 : index
  // CHECK: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c1]] :
  // CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, 0, 0, 0] [1, %[[dim]], 1, 100] [1, 1, 1, 1]
  // CHECK-SAME:           tensor<1x?x1x100xf32> to tensor<?x100xf32>
  // CHECK: %[[e1:.+]] = tensor.extract_slice %[[e]][%[[arg1]], 0] [%[[arg2]], 100] [1, 1]
  // CHECK-SAME:           tensor<?x100xf32> to tensor<?x100xf32>
  return %slice : tensor<?x100xf32>
}

// CHECK: @collapse_and_slice_multiple_unit_dim_mixed(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
func.func @collapse_and_slice_multiple_unit_dim_mixed(%input: tensor<1x?x1x100x10xf32>, %offt: index, %size: index) -> tensor<?x?xf32> {
  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<1x?x1x100x10xf32> into tensor<?x1000xf32>
  %slice = tensor.extract_slice %collapsed [%offt, %offt] [%size, %size] [1, 1] : tensor<?x1000xf32> to tensor<?x?xf32>
  // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
  // CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
  // CHECK: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c1]]
  // CHECK: %[[rank_reduced:.+]] = tensor.extract_slice %[[arg0]][0, 0, 0, 0, 0] [1, %[[dim]], 1, 100, 10] [1, 1, 1, 1, 1]
  // CHECK: %[[empty:.+]] = tensor.empty
  // CHECK: %[[result:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[arg2]] step %[[c1]] iter_args(%[[ia:.+]] = %[[empty]])
  // CHECK:     %[[idx:.+]] = affine.apply
  // CHECK:     %[[multi_index:.+]] = affine.delinearize_index %[[idx]] into
  // CHECK:     %[[collapsed:.+]] = tensor.collapse_shape
  // CHECK:     %[[updated:.+]] = tensor.insert_slice
  // CHECK:     scf.yield %[[updated]]
  // CHECK: return %[[result]]
  return %slice : tensor<?x?xf32>
}

// Edge case where all collapsed dims are unit dims. This pattern can't eliminate
// the collapse_shape; folding it away is left to `linalg-fold-unit-extent-dims`.

// CHECK: @collapse_and_slice_multiple_all_unit_dim(%[[arg0:.+]]: tensor<{{.*}}>)
func.func @collapse_and_slice_multiple_all_unit_dim(%input: tensor<1x1x1x100xf32>) -> tensor<1x100xf32> {
  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<1x1x1x100xf32> into tensor<1x100xf32>
  %slice = tensor.extract_slice %collapsed [0, 0] [1, 100] [1, 1] : tensor<1x100xf32> to tensor<1x100xf32>
  return %slice : tensor<1x100xf32>
  // CHECK: %[[collapse:.+]] = tensor.collapse_shape %[[arg0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x1x100xf32> into tensor<1x100xf32>
  // CHECK: return %[[collapse]]
}