// llvm/mlir/test/Dialect/Shape/unranked-tensor-lowering.mlir

// RUN: mlir-opt -split-input-file -canonicalize -cse %s | FileCheck %s

// This test verifies the simplification of IR patterns that emerge when
// lowering high-level element-wise ops with unranked tensor inputs. Consider
// the following function incrementing and doubling the value of an input
// unranked tensor using ops in a hypothetical high-level dialect called 'hl':
//
//  func.func @f(%input: tensor<*xf32>) -> tensor<*xf32> {
//    %0 = hl.inc %input : tensor<*xf32>
//    %1 = hl.double %0 : tensor<*xf32>
//    return %1 : tensor<*xf32>
//  }
//
// A possible strategy to lower 'hl.inc' consists in reshaping its operand into
// a 1D tensor, creating a 1D tensor splat with the same total size as the input
// operand and with value 1.0, adding both 1D tensors using 'arith.addf', and
// reshaping the result back into the original input shape. A similar process
// applies for 'hl.double', except with a tensor splat with value 2.0 and an
// 'arith.mulf' op. The body of the function in the test below contains the full
// sequence.
//
// Since such a lowering process would operate on individual 'hl' ops in a
// context-oblivious manner, the emitted code contains a redundant IR pattern
// where the result of 'arith.addf' is reshaped into an unranked tensor, only
// to be immediately reshaped back into the 1D tensor consumed by
// 'arith.mulf'. This entails the overhead of re-computing the unranked tensor
// shape ('shape.shape_of') and size ('shape.num_elements').
//
// This test verifies that the consecutive application of a canonicalization and
// a CSE pass successfully simplifies this emerging pattern, leading to a
// version of the code in which the result of the emitted 'arith.addf' op
// associated with 'hl.inc' is directly consumed by the 'arith.mulf' op
// associated with 'hl.double', as observed in the FileCheck directives. The
// main rewrite patterns at play are 'shape.shape_of' canonicalization,
// 'tensor.reshape' canonicalization, and 'shape.num_elements' subexpression
// elimination.
//

// CHECK-LABEL: @unranked_tensor_lowering
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>

// CHECK-DAG: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32

// CHECK: %[[INPUT_SHAPE:.*]] = shape.shape_of %[[INPUT]] : tensor<*xf32> -> tensor<?xindex>
// CHECK: %[[INPUT_SIZE:.*]] = shape.num_elements %[[INPUT_SHAPE]] : tensor<?xindex> -> index
// CHECK: %[[INPUT_COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[INPUT_SIZE]] : tensor<1xindex>
// CHECK: %[[INPUT_COLLAPSED:.*]] = tensor.reshape %[[INPUT]](%[[INPUT_COLLAPSED_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>

// CHECK: %[[ONE_SPLAT:.*]] = tensor.splat %[[ONE]]{{\[}}%[[INPUT_SIZE]]] : tensor<?xf32>
// CHECK: %[[SUM_COLLAPSED:.*]] = arith.addf %[[INPUT_COLLAPSED]], %[[ONE_SPLAT]] : tensor<?xf32>

// CHECK: %[[TWO_SPLAT:.*]] = tensor.splat %[[TWO]]{{\[}}%[[INPUT_SIZE]]] : tensor<?xf32>
// CHECK: %[[PRODUCT_COLLAPSED:.*]] = arith.mulf %[[SUM_COLLAPSED]], %[[TWO_SPLAT]] : tensor<?xf32>

// CHECK: %[[PRODUCT:.*]] = tensor.reshape %[[PRODUCT_COLLAPSED]](%[[INPUT_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[PRODUCT]] : tensor<*xf32>

// Input IR for the test: the op-by-op lowering of 'hl.inc' followed by
// 'hl.double' on an unranked tensor, including the redundant
// expand-then-collapse round trip between the two lowered ops.
func.func @unranked_tensor_lowering(%input: tensor<*xf32>) -> tensor<*xf32> {
  // ---- Lowering of 'hl.inc' ----

  // Flatten the unranked input into a 1D tensor whose length is the total
  // number of elements of the input.
  %in_shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
  %in_count = shape.num_elements %in_shape : tensor<?xindex> -> index
  %in_flat_shape = tensor.from_elements %in_count : tensor<1xindex>
  %in_flat = tensor.reshape %input(%in_flat_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>

  // 1D splat of 1.0 with as many elements as the flattened input.
  %c1 = arith.constant 1.0 : f32
  %ones = tensor.splat %c1[%in_count] : tensor<?xf32>

  // Element-wise increment, then reshape the result back to the input shape.
  %inc_flat = arith.addf %in_flat, %ones : tensor<?xf32>
  %inc = tensor.reshape %inc_flat(%in_shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>

  // ---- Lowering of 'hl.double' ----

  // Flatten the (just expanded) increment result again. This recomputes the
  // shape and the element count from the unranked value.
  %inc_shape = shape.shape_of %inc : tensor<*xf32> -> tensor<?xindex>
  %inc_count = shape.num_elements %inc_shape : tensor<?xindex> -> index
  %inc_flat_shape = tensor.from_elements %inc_count : tensor<1xindex>
  %inc_reflat = tensor.reshape %inc(%inc_flat_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>

  // 1D splat of 2.0 with as many elements as the flattened increment result.
  %c2 = arith.constant 2.0 : f32
  %twos = tensor.splat %c2[%inc_count] : tensor<?xf32>

  // Element-wise doubling, then reshape the result back to the unranked shape.
  %dbl_flat = arith.mulf %inc_reflat, %twos : tensor<?xf32>
  %dbl = tensor.reshape %dbl_flat(%inc_shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>

  return %dbl : tensor<*xf32>
}