// llvm/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir

// RUN: mlir-opt -resolve-shaped-type-result-dims -split-input-file %s | FileCheck %s

// Test case: dims of a tensor.insert_slice result.
// %arg0 is inserted into %arg1 at dynamic offsets (%arg2, %arg3, %arg4) with
// unit strides, and all three dims of the result are returned. The
// expectations below require the returned values to be tensor.dim ops on the
// destination %arg1 rather than on the insert_slice result.
func.func @insert_slice(
    %arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>,
    %arg2 : index, %arg3 : index, %arg4 : index) -> (index, index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  // The slice sizes are the dims of the source tensor %arg0.
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
  %0 = tensor.insert_slice %arg0 into %arg1[%arg2, %arg3, %arg4] [%d0, %d1, %d2] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
  // Dim queries on the insert_slice result; these are what the pass rewrites.
  %1 = tensor.dim %0, %c0 : tensor<?x?x?xf32>
  %2 = tensor.dim %0, %c1 : tensor<?x?x?xf32>
  %3 = tensor.dim %0, %c2 : tensor<?x?x?xf32>
  return %1, %2, %3 : index, index, index
}
// Expected output: three tensor.dim ops on the second function argument
// (the destination), returned in dimension order.
// CHECK-LABEL: func @insert_slice(
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG1]], %[[C0]]
//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
//   CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
//       CHECK:   return %[[D0]], %[[D1]], %[[D2]]

// -----

// Test case: dims of a full-rank tensor.extract_slice result.
// The slice uses static zero offsets, unit strides, and three dynamic sizes
// (%arg1, %arg2, %arg3). The expectations below require each returned dim to
// be the corresponding size operand directly.
func.func @extract_slice(%arg0 : tensor<?x?x?xf32>, %arg1 : index, %arg2 : index,
    %arg3 : index) -> (index, index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, %arg2, %arg3] [1, 1, 1] :
      tensor<?x?x?xf32> to tensor<?x?x?xf32>
  // Dim queries on the slice result; these are what the pass rewrites.
  %1 = tensor.dim %0, %c0 : tensor<?x?x?xf32>
  %2 = tensor.dim %0, %c1 : tensor<?x?x?xf32>
  %3 = tensor.dim %0, %c2 : tensor<?x?x?xf32>
  return %1, %2, %3 : index, index, index
}
// Expected output: the three index arguments are returned as-is.
// CHECK-LABEL: func @extract_slice(
//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]], %[[ARG2]], %[[ARG3]]

// -----

// Test case: rank-reducing tensor.extract_slice, 3-D source to 1-D result.
// Sizes are [1, %arg1, 1] and the result type is tensor<?xf32>, so the only
// remaining dimension is the dynamic one. The expectation below requires
// dim 0 of the result to be %arg1.
func.func @extract_slice_rank_reduced_1(%arg0 : tensor<?x?x?xf32>,
    %arg1 : index) -> index {
  %c0 = arith.constant 0 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
     tensor<?x?x?xf32> to tensor<?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?xf32>
  return %1 : index
}
// Expected output: the index argument is returned as-is.
// CHECK-LABEL: func @extract_slice_rank_reduced_1(
//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]]

// -----

// Test case: rank-reducing tensor.extract_slice, 3-D source to 2-D result.
// Sizes are [1, %arg1, 1] and the result type is tensor<?x1xf32>, i.e. the
// dynamic dimension comes first and one unit dimension is kept after it.
// The expectation below requires dim 0 of the result to be %arg1.
func.func @extract_slice_rank_reduced_2(%arg0 : tensor<?x?x?xf32>,
    %arg1 : index) -> index {
  %c0 = arith.constant 0 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
     tensor<?x?x?xf32> to tensor<?x1xf32>
  %1 = tensor.dim %0, %c0 : tensor<?x1xf32>
  return %1 : index
}
// Expected output: the index argument is returned as-is.
// CHECK-LABEL: func @extract_slice_rank_reduced_2(
//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]]

// -----

// Test case: rank-reducing tensor.extract_slice, 3-D source to 2-D result.
// Sizes are [1, %arg1, 1] and the result type is tensor<1x?xf32>, i.e. one
// unit dimension is kept before the dynamic one. The dynamic dimension is
// therefore at index 1 of the result, and the expectation below requires
// that dim to be %arg1.
func.func @extract_slice_rank_reduced_3(%arg0 : tensor<?x?x?xf32>,
    %arg1 : index) -> index {
  %c1 = arith.constant 1 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
     tensor<?x?x?xf32> to tensor<1x?xf32>
  %1 = tensor.dim %0, %c1 : tensor<1x?xf32>
  return %1 : index
}
// Expected output: the index argument is returned as-is.
// CHECK-LABEL: func @extract_slice_rank_reduced_3(
//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]]

// -----

// Test case: tensor.extract_slice with sizes [1, %arg1, 1] whose result type
// tensor<1x?x1xf32> keeps both unit dimensions (the result is still 3-D).
// The dynamic dimension is at index 1 of the result, and the expectation
// below requires that dim to be %arg1.
func.func @extract_slice_rank_reduced_4(%arg0 : tensor<?x?x?xf32>,
    %arg1 : index) -> index {
  %c1 = arith.constant 1 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
     tensor<?x?x?xf32> to tensor<1x?x1xf32>
  %1 = tensor.dim %0, %c1 : tensor<1x?x1xf32>
  return %1 : index
}
// Expected output: the index argument is returned as-is.
// CHECK-LABEL: func @extract_slice_rank_reduced_4(
//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]]

// -----

// Test case: rank-reducing tensor.extract_slice with two dynamic sizes.
// Sizes are [%arg1, 1, %arg2] and the result type is tensor<?x?xf32>, so the
// middle unit dimension is absent from the result. The expectations below
// require result dims 0 and 1 to be %arg1 and %arg2 respectively.
func.func @extract_slice_rank_reduced_5(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, 1, %arg2] [1, 1, 1] :
     tensor<?x?x?xf32> to tensor<?x?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
  %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
  return %1, %2 : index, index
}
// Expected output: the two index arguments are returned as-is, in order.
// CHECK-LABEL: func @extract_slice_rank_reduced_5(
//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]], %[[ARG2]]

// -----

// Test case: tensor.extract_slice with sizes [%arg1, 1, %arg2] whose result
// type tensor<?x1x?xf32> keeps the middle unit dimension (the result is
// still 3-D). The dynamic dimensions are at indices 0 and 2 of the result,
// and the expectations below require them to be %arg1 and %arg2 respectively.
func.func @extract_slice_rank_reduced_6(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, 1, %arg2] [1, 1, 1] :
     tensor<?x?x?xf32> to tensor<?x1x?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?x1x?xf32>
  %2 = tensor.dim %0, %c2 : tensor<?x1x?xf32>
  return %1, %2 : index, index
}
// Expected output: the two index arguments are returned as-is, in order.
// CHECK-LABEL: func @extract_slice_rank_reduced_6(
//  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]], %[[ARG2]]

// -----

// Test case: dim of a tensor.collapse_shape result.
// A tensor<?x22xi16> is generated with its dynamic extent given by the
// constant 7, then both dimensions are collapsed into one. The expectations
// below require dim 0 of the collapsed tensor to become the constant
// 154 (= 7 * 22).
func.func @collapse_shape() -> index {
  %c0 = arith.constant 0 : index
  %c7 = arith.constant 7 : index
  %c1_i16 = arith.constant 1 : i16
  // %c7 supplies the single dynamic extent (dimension 0) of the result type.
  %generated = tensor.generate %c7 {
  ^bb0(%arg3: index, %arg4: index):
    tensor.yield %c1_i16 : i16
  } : tensor<?x22xi16>
  // Reassociation [[0, 1]] merges both source dimensions into result dim 0.
  %collapsed = tensor.collapse_shape %generated [[0, 1]] : tensor<?x22xi16> into tensor<?xi16>
  %d0 = tensor.dim %collapsed, %c0 : tensor<?xi16>
  return %d0 : index
}
// Expected output: a constant 154 of index type is returned.
// CHECK-LABEL: func @collapse_shape(
//       CHECK:   %[[c154:.*]] = arith.constant 154 : index
//       CHECK:   return %[[c154]]