llvm/mlir/test/Integration/Dialect/Transform/match_reduction.mlir

// RUN: mlir-opt %s --transform-interpreter --verify-diagnostics

module attributes { transform.with_named_sequence } {
  // Matcher for the optional elementwise ops around the reduction (the
  // "leading" producer of its input and the "trailing" user of its result).
  // Succeeds on a structured op meeting all of the following:
  //   - every iterator dimension is parallel;
  //   - every input has a projected-permutation indexing map;
  //   - every init has a permutation indexing map;
  //   - it has exactly one init.
  // Yields the matched handle unchanged.
  transform.named_sequence @_reduce_leading_trailing(%entry: !transform.any_op {transform.readonly})
      -> (!transform.any_op) {
    %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>

    transform.match.structured %entry : !transform.any_op {
    ^bb0(%struct: !transform.any_op):
      transform.match.structured.dim %struct[all] {parallel} : !transform.any_op
      transform.match.structured.input %struct[all] {projected_permutation} : !transform.any_op
      transform.match.structured.init %struct[all] {permutation} : !transform.any_op
      %ni = transform.match.structured.num_inits %struct : (!transform.any_op) -> !transform.param<i64>
      transform.match.param.cmpi eq %ni, %c1 : !transform.param<i64>
    }
    transform.yield %entry : !transform.any_op
  }

  // Matcher for a "fill + reduction" pattern with optional leading/trailing
  // elementwise ops. The reduction must satisfy all of the following:
  //   - it is a structured op of rank 2 to 4;
  //   - its last dimension is a reduction and all others are parallel;
  //   - it has exactly one input and one init, both with projected-permutation
  //     indexing maps;
  //   - its init value is produced by a linalg.fill.
  //
  // Yields, in order:
  //   - the leading op (the producer of input 0, if it matches
  //     @_reduce_leading_trailing);
  //   - the fill;
  //   - the reduction itself;
  //   - the trailing op (the single user of result 0, if it matches
  //     @_reduce_leading_trailing);
  //   - the rank, the sizes of all dimensions, and the element bitwidth of the
  //     init.
  // The leading and trailing handles are empty when no such op matches.
  transform.named_sequence @fill_reduce_leading_trailing(%entry: !transform.any_op {transform.readonly})
      -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
          !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) {
    %c1 = transform.param.constant 1 : i64 -> !transform.param<i64>
    %c2 = transform.param.constant 2 : i64 -> !transform.param<i64>
    %c4 = transform.param.constant 4 : i64 -> !transform.param<i64>

    %rk, %dms, %bw, %operand_o, %init_v, %trailing_o = transform.match.structured failures(propagate) %entry
        : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
                                  !transform.any_op, !transform.any_value, !transform.any_op) {
    ^bb0(%struct: !transform.any_op):
      // The rank must be within [2, 4].
      %rank = transform.match.structured.rank %struct : (!transform.any_op) -> !transform.param<i64>
      transform.match.param.cmpi ge %rank, %c2 : !transform.param<i64>
      transform.match.param.cmpi le %rank, %c4 : !transform.param<i64>

      // The last dimension is the reduction; all others are parallel.
      // %dims captures the sizes of all dimensions.
      transform.match.structured.dim %struct[-1] {reduction} : !transform.any_op
      transform.match.structured.dim %struct[except(-1)] {parallel} : !transform.any_op
      %dims = transform.match.structured.dim %struct[all] : (!transform.any_op) -> !transform.param<i64>

      // There is exactly one input and exactly one init.
      %n_inputs = transform.match.structured.num_inputs %struct : (!transform.any_op) -> !transform.param<i64>
      %n_outputs = transform.match.structured.num_inits %struct : (!transform.any_op) -> !transform.param<i64>
      transform.match.param.cmpi eq %n_inputs, %c1 : !transform.param<i64>
      transform.match.param.cmpi eq %n_outputs, %c1 : !transform.param<i64>

      // Both operands use projected-permutation maps. %init is the init value
      // itself; it is used below for the bitwidth and for the fill check.
      transform.match.structured.input %struct[0] {projected_permutation} : !transform.any_op
      transform.match.structured.init %struct[0] {projected_permutation} : !transform.any_op
      %init = transform.match.structured.init %struct[0] : (!transform.any_op) -> !transform.any_value

      // This dance is necessary to create an empty handle if there is no
      // single (matching) user, without failing the entire match.
      %trailing_optional = transform.sequence %struct : (!transform.any_op) -> !transform.any_op failures(suppress) {
      ^bb0(%struct_inner: !transform.any_op):
        %result = transform.match.structured failures(propagate) %struct_inner : (!transform.any_op) -> !transform.any_op {
        ^bb0(%struct_inner_inner: !transform.any_op):
          %result_inner = transform.match.structured.result %struct_inner_inner[0] {single} : (!transform.any_op) -> !transform.any_op
          %trailing = transform.include @_reduce_leading_trailing failures(propagate) (%result_inner) : (!transform.any_op) -> !transform.any_op
          transform.match.structured.yield %trailing : !transform.any_op
        }
        transform.yield %result: !transform.any_op
      }

      // Suppress errors as a way to implement optionality. We cannot suppress them in
      // the include because it keeps matching after "get_defining_op" fails, which
      // breaks the single-op precondition of the following ops. We don't want to
      // propagate that failure though.
      //
      // Additionally, we cannot put the sequence inside the call because its first
      // operand must be an operation handle (the verifier asserts!) and there is
      // no such handle available there.
      //
      // TODO: extend the structured matching to gracefully handle empty handles
      // or provide the suppress-errors-but-stop failure mode for includes to
      // implement optionality.
      %operand_optional = transform.sequence %struct : (!transform.any_op) -> !transform.any_op failures(suppress) {
      ^bb0(%struct_inner: !transform.any_op):
        %operand3 = transform.match.structured failures(propagate) %struct_inner : (!transform.any_op) -> !transform.any_op {
        ^bb1(%struct_inner_inner: !transform.any_op):
          %operand = transform.match.structured.input %struct_inner_inner[0] : (!transform.any_op) -> !transform.any_op
          %operand2 = transform.include @_reduce_leading_trailing failures(propagate) (%operand) : (!transform.any_op) -> !transform.any_op
          transform.match.structured.yield %operand2 : !transform.any_op
        }
        transform.yield %operand3 : !transform.any_op
      }

      // Element bitwidth of the init value.
      %bitwidth = transform.match.structured.elemental_bitwidth %init : (!transform.any_value) -> !transform.param<i64>

      // Body check with reduction_position = 0. Presumably this requires the
      // body to reduce into init 0; confirm against the
      // match.structured.body documentation.
      transform.match.structured.body %struct { reduction_position = 0 } : !transform.any_op
      transform.match.structured.yield %rank, %dims, %bitwidth, %operand_optional, %init, %trailing_optional
        : !transform.param<i64>, !transform.param<i64>, !transform.param<i64>,
          !transform.any_op, !transform.any_value, !transform.any_op
    }

    // The init value must be produced by a linalg.fill. This check runs on the
    // init's defining op, outside the structured match region.
    %init_o = transform.get_defining_op %init_v : (!transform.any_value) -> !transform.any_op
    transform.match.operation_name %init_o ["linalg.fill"] : !transform.any_op

    // Reorder the captured values to match the parameter order of
    // @print_reduce_leading_trailing.
    transform.yield %operand_o, %init_o, %entry, %trailing_o, %rk, %dms, %bw
        : !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op,
          !transform.param<i64>, !transform.param<i64>, !transform.param<i64>
  }

  // Action paired with @fill_reduce_leading_trailing. It emits one remark per
  // matched op handle ("leading", "fill", "reduction", "trailing"). It also
  // emits the rank, dimension sizes and bitwidth parameters as remarks located
  // at the reduction. The functions below expect no "leading"/"trailing" remark
  // when the corresponding handle is empty.
  transform.named_sequence @print_reduce_leading_trailing(
      %leading: !transform.any_op {transform.readonly},
      %fill: !transform.any_op {transform.readonly},
      %reduction: !transform.any_op {transform.readonly},
      %trailing: !transform.any_op {transform.readonly},
      %rank: !transform.param<i64> {transform.readonly},
      %dims: !transform.param<i64> {transform.readonly},
      %bitwidth: !transform.param<i64> {transform.readonly}) {
    transform.debug.emit_remark_at %leading, "leading" : !transform.any_op
    transform.debug.emit_remark_at %fill, "fill" : !transform.any_op
    transform.debug.emit_remark_at %reduction, "reduction" : !transform.any_op
    transform.debug.emit_remark_at %trailing, "trailing" : !transform.any_op
    transform.debug.emit_param_as_remark %rank, "rank" at %reduction : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %dims, "dimensions" at %reduction : !transform.param<i64>, !transform.any_op
    transform.debug.emit_param_as_remark %bitwidth, "bitwidth" at %reduction : !transform.param<i64>, !transform.any_op
    transform.yield
  }

  // Entry point of the transform script. For every op under %root that
  // @fill_reduce_leading_trailing matches, it runs
  // @print_reduce_leading_trailing on the values the matcher yields.
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.consumed}) {
    transform.foreach_match in %root
      @fill_reduce_leading_trailing -> @print_reduce_leading_trailing
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// Tensor types shared by the payload functions below: the 8x64 input of each
// reduction and the 8-element tensor it reduces into.
!in_tensor_t = tensor<8x64xf32>
!out_tensor_t = tensor<8xf32>

// Payload: an elementwise op feeds a fill-initialized reduction. The
// reduction's result is only returned, so no "trailing" remark is expected.
func.func @eltwise_reduce(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %neg_zero = arith.constant -0.000000e+00 : f32

  %acc_empty = tensor.empty() : !out_tensor_t
  // expected-remark @below {{fill}}
  %acc_init = linalg.fill ins(%neg_zero : f32) outs(%acc_empty : !out_tensor_t) -> !out_tensor_t
  %lead_empty = tensor.empty() : !in_tensor_t
  // expected-remark @below {{leading}}
  %lead = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg : !in_tensor_t) outs(%lead_empty : !in_tensor_t) {
  ^bb0(%x: f32, %out: f32):
    %twice = arith.addf %x, %x : f32
    %quad = arith.addf %twice, %twice : f32
    linalg.yield %quad : f32
  } -> !in_tensor_t

  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %sum = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%lead : !in_tensor_t) outs(%acc_init : !out_tensor_t) {
  ^bb0(%elem: f32, %acc: f32):
    %next = arith.addf %elem, %acc : f32
    linalg.yield %next : f32
  } -> !out_tensor_t

  return %sum : !out_tensor_t
}

// Payload: a fill-initialized reduction over a function argument whose result
// feeds one elementwise op. The input has no producing op, so no "leading"
// remark is expected.
func.func @reduce_eltwise(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %neg_zero = arith.constant -0.000000e+00 : f32

  %acc_empty = tensor.empty() : !out_tensor_t
  // expected-remark @below {{fill}}
  %acc_init = linalg.fill ins(%neg_zero : f32) outs(%acc_empty : !out_tensor_t) -> !out_tensor_t
  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %sum = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%arg : !in_tensor_t) outs(%acc_init : !out_tensor_t) {
  ^bb0(%elem: f32, %acc: f32):
    %next = arith.addf %elem, %acc : f32
    linalg.yield %next : f32
  } -> !out_tensor_t

  %trail_empty = tensor.empty() : !out_tensor_t
  // expected-remark @below {{trailing}}
  %trail = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>,
                     affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]}
    ins(%sum : !out_tensor_t) outs(%trail_empty : !out_tensor_t) {
  ^bb0(%x: f32, %out: f32):
    %sqrt = math.sqrt %x : f32
    linalg.yield %sqrt : f32
  } -> !out_tensor_t
  return %trail : !out_tensor_t
}

// Payload: the full pattern. An elementwise producer feeds a fill-initialized
// reduction, whose result feeds an elementwise consumer. All four op remarks
// are expected.
func.func @eltwise_reduce_eltwise(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %neg_zero = arith.constant -0.000000e+00 : f32

  %acc_empty = tensor.empty() : !out_tensor_t
  // expected-remark @below {{fill}}
  %acc_init = linalg.fill ins(%neg_zero : f32) outs(%acc_empty : !out_tensor_t) -> !out_tensor_t
  %lead_empty = tensor.empty() : !in_tensor_t
  // expected-remark @below {{leading}}
  %lead = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg : !in_tensor_t) outs(%lead_empty : !in_tensor_t) {
  ^bb0(%x: f32, %out: f32):
    %twice = arith.addf %x, %x : f32
    %quad = arith.addf %twice, %twice : f32
    linalg.yield %quad : f32
  } -> !in_tensor_t

  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %sum = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%lead : !in_tensor_t) outs(%acc_init : !out_tensor_t) {
  ^bb0(%elem: f32, %acc: f32):
    %next = arith.addf %elem, %acc : f32
    linalg.yield %next : f32
  } -> !out_tensor_t

  %trail_empty = tensor.empty() : !out_tensor_t
  // expected-remark @below {{trailing}}
  %trail = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>,
                     affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]}
    ins(%sum : !out_tensor_t) outs(%trail_empty : !out_tensor_t) {
  ^bb0(%v: f32, %out: f32):
    %sqrt = math.sqrt %v : f32
    linalg.yield %sqrt : f32
  } -> !out_tensor_t

  return %trail : !out_tensor_t
}

// Payload: the same pattern as @eltwise_reduce_eltwise, except the leading
// elementwise op is created before the fill in program order. The same
// remarks are expected.
func.func @eltwise_reduce_eltwise_swapped(%arg : !in_tensor_t) -> (!out_tensor_t) {
  %neg_zero = arith.constant -0.000000e+00 : f32

  %lead_empty = tensor.empty() : !in_tensor_t
  // expected-remark @below {{leading}}
  %lead = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%arg : !in_tensor_t) outs(%lead_empty : !in_tensor_t) {
  ^bb0(%x: f32, %out: f32):
    %twice = arith.addf %x, %x : f32
    %quad = arith.addf %twice, %twice : f32
    linalg.yield %quad : f32
  } -> !in_tensor_t

  %acc_empty = tensor.empty() : !out_tensor_t
  // expected-remark @below {{fill}}
  %acc_init = linalg.fill ins(%neg_zero : f32) outs(%acc_empty : !out_tensor_t) -> !out_tensor_t
  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 64 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %sum = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%lead : !in_tensor_t) outs(%acc_init : !out_tensor_t) {
  ^bb0(%elem: f32, %acc: f32):
    %next = arith.addf %elem, %acc : f32
    linalg.yield %next : f32
  } -> !out_tensor_t

  %trail_empty = tensor.empty() : !out_tensor_t
  // expected-remark @below {{trailing}}
  %trail = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>,
                     affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]}
    ins(%sum : !out_tensor_t) outs(%trail_empty : !out_tensor_t) {
  ^bb0(%v: f32, %out: f32):
    %sqrt = math.sqrt %v : f32
    linalg.yield %sqrt : f32
  } -> !out_tensor_t

  return %trail : !out_tensor_t
}

// Payload: a fill-initialized reduction with no leading or trailing ops. It is
// followed by a second linalg.fill that no reduction consumes; that second
// fill is expected to produce no remark.
func.func @reduction_with_extra_op_in_func(%arg0: tensor<8x479xf32>, %arg1: tensor<32x32xf32>) -> (tensor<8xf32>, tensor<32xf32>) {
  %zero = arith.constant 0.0 : f32
  %acc_empty = tensor.empty() : tensor<8xf32>
  // expected-remark @below {{fill}}
  %acc_init = linalg.fill ins(%zero : f32) outs(%acc_empty : tensor<8xf32>) -> tensor<8xf32>
  // expected-remark @below {{reduction}}
  // expected-remark @below {{rank 2}}
  // expected-remark @below {{dimensions 8 : i64, 479 : i64}}
  // expected-remark @below {{bitwidth 32 : i64}}
  %sum = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"]}
    ins(%arg0 : tensor<8x479xf32>)
    outs(%acc_init : tensor<8xf32>) {
  ^bb0(%elem: f32, %acc: f32):
    %next = arith.addf %elem, %acc : f32
    linalg.yield %next : f32
  } -> tensor<8xf32>

  %unrelated_empty = tensor.empty() : tensor<32xf32>
  %unrelated_fill = linalg.fill ins(%zero : f32) outs(%unrelated_empty : tensor<32xf32>) -> tensor<32xf32>
  return %sum, %unrelated_fill : tensor<8xf32>, tensor<32xf32>
}