// RUN: mlir-opt -split-input-file -canonicalize -cse %s | FileCheck %s
// This test verifies the simplification of IR patterns that emerge when
// lowering high-level element-wise ops with unranked tensor inputs. Consider
// the following function, which increments and then doubles each element of an
// unranked input tensor using ops in a hypothetical high-level dialect called
// 'hl':
//
// func.func @f(%input: tensor<*xf32>) -> tensor<*xf32> {
//   %0 = hl.inc %input : tensor<*xf32>
//   %1 = hl.double %0 : tensor<*xf32>
//   return %1 : tensor<*xf32>
// }
//
// A possible strategy for lowering 'hl.inc' consists of reshaping its operand
// into a 1D tensor, creating a 1D tensor splat with the same total size as the
// input operand and with value 1.0, adding both 1D tensors with 'arith.addf',
// and reshaping the result back into the original input shape. A similar
// process applies to 'hl.double', except with a tensor splat of value 2.0 and
// an 'arith.mulf' op. The body of the function in the test below contains the
// full sequence.
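//
// For illustration, lowering 'hl.inc' alone would emit a sequence along the
// following lines (a sketch mirroring the first half of the test body below;
// value names are illustrative):
//
//   %shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
//   %size = shape.num_elements %shape : tensor<?xindex> -> index
//   %flat_shape = tensor.from_elements %size : tensor<1xindex>
//   %flat = tensor.reshape %input(%flat_shape)
//       : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
//   %one = arith.constant 1.0 : f32
//   %one_splat = tensor.splat %one[%size] : tensor<?xf32>
//   %flat_sum = arith.addf %flat, %one_splat : tensor<?xf32>
//   %sum = tensor.reshape %flat_sum(%shape)
//       : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>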
//
// Because such a lowering process operates on individual 'hl' ops in a
// context-oblivious manner, the emitted code contains a redundant IR pattern:
// the result of 'arith.addf' is reshaped into an unranked tensor, only to be
// immediately reshaped back into the 1D tensor consumed by 'arith.mulf'. This
// entails the overhead of recomputing the unranked tensor shape
// ('shape.shape_of') and size ('shape.num_elements').
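//
// Excerpted from the test body below, the redundant sequence between the two
// arithmetic ops is:
//
//   %sum = tensor.reshape %sum_collapsed(%input_shape)
//       : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
//   %sum_shape = shape.shape_of %sum : tensor<*xf32> -> tensor<?xindex>
//   %sum_size = shape.num_elements %sum_shape : tensor<?xindex> -> index
//   %sum_collapsed_shape = tensor.from_elements %sum_size : tensor<1xindex>
//   %sum_collapsed_0 = tensor.reshape %sum(%sum_collapsed_shape)
//       : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>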
//
// This test verifies that running a canonicalization pass followed by a CSE
// pass simplifies this pattern, producing code in which the 'arith.addf' op
// associated with 'hl.inc' is directly consumed by the 'arith.mulf' op
// associated with 'hl.double', as captured by the FileCheck directives below.
// The main rewrites at play are the 'shape.shape_of' canonicalization, the
// 'tensor.reshape' canonicalization, and common subexpression elimination of
// the 'shape.num_elements' op.
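//
// Roughly, the expected folding proceeds as follows (a sketch of the intended
// rewrites, not exact pass output):
//
//   1. The 'shape.shape_of' canonicalization folds the shape of a
//      'tensor.reshape' result to the reshape's shape operand, replacing
//      '%sum_shape' with '%input_shape'.
//   2. CSE then folds '%sum_size' into '%input_size' (both are now
//      'shape.num_elements' of '%input_shape') and '%sum_collapsed_shape'
//      into '%input_collapsed_shape'.
//   3. The 'tensor.reshape' canonicalization folds '%sum_collapsed_0', a
//      reshape of the reshape '%sum', back to '%sum_collapsed', since the
//      round trip through the unranked tensor yields a value of the same 1D
//      type. '%sum' then becomes dead and is removed.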
//
// CHECK-LABEL: @unranked_tensor_lowering
// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
// CHECK-DAG: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
// CHECK: %[[INPUT_SHAPE:.*]] = shape.shape_of %[[INPUT]] : tensor<*xf32> -> tensor<?xindex>
// CHECK: %[[INPUT_SIZE:.*]] = shape.num_elements %[[INPUT_SHAPE]] : tensor<?xindex> -> index
// CHECK: %[[INPUT_COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[INPUT_SIZE]] : tensor<1xindex>
// CHECK: %[[INPUT_COLLAPSED:.*]] = tensor.reshape %[[INPUT]](%[[INPUT_COLLAPSED_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[ONE_SPLAT:.*]] = tensor.splat %[[ONE]]{{\[}}%[[INPUT_SIZE]]] : tensor<?xf32>
// CHECK: %[[SUM_COLLAPSED:.*]] = arith.addf %[[INPUT_COLLAPSED]], %[[ONE_SPLAT]] : tensor<?xf32>
// CHECK: %[[TWO_SPLAT:.*]] = tensor.splat %[[TWO]]{{\[}}%[[INPUT_SIZE]]] : tensor<?xf32>
// CHECK: %[[PRODUCT_COLLAPSED:.*]] = arith.mulf %[[SUM_COLLAPSED]], %[[TWO_SPLAT]] : tensor<?xf32>
// CHECK: %[[PRODUCT:.*]] = tensor.reshape %[[PRODUCT_COLLAPSED]](%[[INPUT_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[PRODUCT]] : tensor<*xf32>
func.func @unranked_tensor_lowering(%input: tensor<*xf32>) -> tensor<*xf32> {
// Collapse input
%input_shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
%input_size = shape.num_elements %input_shape : tensor<?xindex> -> index
%input_collapsed_shape = tensor.from_elements %input_size : tensor<1xindex>
%input_collapsed = tensor.reshape %input(%input_collapsed_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// Second operand for sum
%one = arith.constant 1.0 : f32
%one_splat = tensor.splat %one[%input_size] : tensor<?xf32>
// Compute sum and expand it
%sum_collapsed = arith.addf %input_collapsed, %one_splat : tensor<?xf32>
%sum = tensor.reshape %sum_collapsed(%input_shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// Collapse sum. This recomputation of the unranked shape and size is the
// redundant pattern that the '-canonicalize -cse' pipeline folds away.
%sum_shape = shape.shape_of %sum : tensor<*xf32> -> tensor<?xindex>
%sum_size = shape.num_elements %sum_shape : tensor<?xindex> -> index
%sum_collapsed_shape = tensor.from_elements %sum_size : tensor<1xindex>
%sum_collapsed_0 = tensor.reshape %sum(%sum_collapsed_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// Second operand for product
%two = arith.constant 2.0 : f32
%two_splat = tensor.splat %two[%sum_size] : tensor<?xf32>
// Compute product and expand it
%product_collapsed = arith.mulf %sum_collapsed_0, %two_splat : tensor<?xf32>
%product = tensor.reshape %product_collapsed(%sum_shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
return %product : tensor<*xf32>
}