llvm/mlir/test/Dialect/Mesh/canonicalization.mlir

// RUN: mlir-opt --canonicalize %s | FileCheck %s

// Mesh symbol of shape 2x4; the collective ops in this file refer to it
// through their `on @mesh0` clause.
mesh.mesh @mesh0(shape = 2x4)

// all_reduce over an empty mesh_axes list with identical operand and result
// types: the op is expected to disappear and the argument to be returned
// directly.
// CHECK-LABEL: func @all_reduce_empty_mesh_axes
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
// CHECK-NOT: mesh.all_reduce
// CHECK: return %[[ARG]]
func.func @all_reduce_empty_mesh_axes(%input : tensor<4xf32>) -> tensor<4xf32> {
  %result = mesh.all_reduce %input on @mesh0 mesh_axes = []
    : tensor<4xf32> -> tensor<4xf32>
  return %result : tensor<4xf32>
}

// Operand is f32 while the result is f64: the test expects the all_reduce to
// remain in the output, with no `mesh_axes` printed after it.
// CHECK-LABEL: func @all_reduce_empty_mesh_axes_different_return_type
// CHECK: mesh.all_reduce
// CHECK-NOT: mesh_axes
func.func @all_reduce_empty_mesh_axes_different_return_type(
    %input : tensor<4xf32>) -> tensor<4xf64> {
  %result = mesh.all_reduce %input on @mesh0 mesh_axes = []
    : tensor<4xf32> -> tensor<4xf64>
  return %result : tensor<4xf64>
}

// The input spells out `reduction = sum`; the test expects the word
// "reduction" to be absent from the output for this function.
// CHECK-LABEL: func @all_reduce_default_reduction
// CHECK-NOT: reduction
func.func @all_reduce_default_reduction(
    %input : tensor<4xf32>) -> tensor<4xf64> {
  %result = mesh.all_reduce %input on @mesh0 mesh_axes = [0] reduction = sum
    : tensor<4xf32> -> tensor<4xf64>
  return %result : tensor<4xf64>
}

// all_to_all over an empty mesh_axes list with identical operand and result
// types: the op is expected to disappear and the argument to be returned
// directly.
// CHECK-LABEL: func @all_to_all_empty_mesh_axes
// CHECK-SAME: %[[ARG:.*]]: tensor<8xf32>
// CHECK-NOT: mesh.all_to_all
// CHECK: return %[[ARG]]
func.func @all_to_all_empty_mesh_axes(%input : tensor<8xf32>) -> tensor<8xf32> {
  %result = mesh.all_to_all %input on @mesh0 mesh_axes = []
    split_axis = 0 concat_axis = 0
    : tensor<8xf32> -> tensor<8xf32>
  return %result : tensor<8xf32>
}

// all_gather over an empty mesh_axes list with identical operand and result
// types: the op is expected to disappear and the argument to be returned
// directly.
// CHECK-LABEL: func @all_gather_empty_mesh_axes
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
// CHECK-NOT: mesh.all_gather
// CHECK: return %[[ARG]]
func.func @all_gather_empty_mesh_axes(%input : tensor<4xf32>) -> tensor<4xf32> {
  %result = mesh.all_gather %input on @mesh0 mesh_axes = [] gather_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
  return %result : tensor<4xf32>
}

// all_slice over an empty mesh_axes list with identical operand and result
// types: the op is expected to disappear and the argument to be returned
// directly. The negative check names the op under test (mesh.all_slice) so
// that a surviving all_slice cannot slip past it.
// CHECK-LABEL: func @all_slice_empty_mesh_axes
func.func @all_slice_empty_mesh_axes(
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
    %arg0 : tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NOT: mesh.all_slice
  %0 = mesh.all_slice %arg0 on @mesh0
    mesh_axes = []
    slice_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
// CHECK: return %[[ARG]]
  return %0 : tensor<4xf32>
}

// broadcast over an empty mesh_axes list (and empty root) with identical
// operand and result types: the op is expected to disappear and the argument
// to be returned directly.
// CHECK-LABEL: func @broadcast_empty_mesh_axes
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
// CHECK-NOT: mesh.broadcast
// CHECK: return %[[ARG]]
func.func @broadcast_empty_mesh_axes(%input : tensor<4xf32>) -> tensor<4xf32> {
  %result = mesh.broadcast %input on @mesh0 mesh_axes = [] root = []
    : (tensor<4xf32>) -> tensor<4xf32>
  return %result : tensor<4xf32>
}

// gather over an empty mesh_axes list (and empty root) with identical operand
// and result types: the op is expected to disappear and the argument to be
// returned directly.
// CHECK-LABEL: func @gather_empty_mesh_axes
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
// CHECK-NOT: mesh.gather
// CHECK: return %[[ARG]]
func.func @gather_empty_mesh_axes(%input : tensor<4xf32>) -> tensor<4xf32> {
  %result = mesh.gather %input on @mesh0 mesh_axes = []
    gather_axis = 0 root = []
    : (tensor<4xf32>) -> tensor<4xf32>
  return %result : tensor<4xf32>
}

// recv over an empty mesh_axes list with identical operand and result types:
// the op is expected to disappear and the argument to be returned directly.
// CHECK-LABEL: func @receive_empty_mesh_axes
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
// CHECK-NOT: mesh.recv
// CHECK: return %[[ARG]]
func.func @receive_empty_mesh_axes(%input : tensor<4xf32>) -> tensor<4xf32> {
  %result = mesh.recv %input on @mesh0 mesh_axes = []
    : (tensor<4xf32>) -> tensor<4xf32>
  return %result : tensor<4xf32>
}

// reduce over an empty mesh_axes list (and empty root) with identical operand
// and result types: the op is expected to disappear and the argument to be
// returned directly.
// CHECK-LABEL: func @reduce_empty_mesh_axes
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
// CHECK-NOT: mesh.reduce
// CHECK: return %[[ARG]]
func.func @reduce_empty_mesh_axes(%input : tensor<4xf32>) -> tensor<4xf32> {
  %result = mesh.reduce %input on @mesh0 mesh_axes = [] root = []
    : (tensor<4xf32>) -> tensor<4xf32>
  return %result : tensor<4xf32>
}

// reduce_scatter over an empty mesh_axes list with identical operand and
// result types: the op is expected to disappear and the argument to be
// returned directly.
// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
// CHECK-NOT: mesh.reduce_scatter
// CHECK: return %[[ARG]]
func.func @reduce_scatter_empty_mesh_axes(
    %input : tensor<4xf32>) -> tensor<4xf32> {
  %result = mesh.reduce_scatter %input on @mesh0 mesh_axes = [] scatter_axis = 0
    : tensor<4xf32> -> tensor<4xf32>
  return %result : tensor<4xf32>
}

// Operand is f32 while the result is f64: the test expects the reduce_scatter
// to remain in the output, with no `mesh_axes` printed after it.
// CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_different_return_type
// CHECK: mesh.reduce_scatter
// CHECK-NOT: mesh_axes
func.func @reduce_scatter_empty_mesh_axes_different_return_type(
    %input : tensor<4xf32>) -> tensor<4xf64> {
  %result = mesh.reduce_scatter %input on @mesh0 mesh_axes = [] scatter_axis = 0
    : tensor<4xf32> -> tensor<4xf64>
  return %result : tensor<4xf64>
}

// The input spells out `reduction = sum`; the test expects the word
// "reduction" to be absent from the output for this function.
// CHECK-LABEL: func @reduce_scatter_default_reduction
// CHECK-NOT: reduction
func.func @reduce_scatter_default_reduction(
    %input : tensor<4xf32>) -> tensor<2xf64> {
  %result = mesh.reduce_scatter %input on @mesh0 mesh_axes = [0]
    reduction = sum scatter_axis = 0
    : tensor<4xf32> -> tensor<2xf64>
  return %result : tensor<2xf64>
}

// scatter over an empty mesh_axes list (and empty root) with identical operand
// and result types: the op is expected to disappear and the argument to be
// returned directly.
// CHECK-LABEL: func @scatter_empty_mesh_axes
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
// CHECK-NOT: mesh.scatter
// CHECK: return %[[ARG]]
func.func @scatter_empty_mesh_axes(%input : tensor<4xf32>) -> tensor<4xf32> {
  %result = mesh.scatter %input on @mesh0 mesh_axes = []
    scatter_axis = 0 root = []
    : (tensor<4xf32>) -> tensor<4xf32>
  return %result : tensor<4xf32>
}

// send over an empty mesh_axes list (and empty destination) with identical
// operand and result types: the op is expected to disappear and the argument
// to be returned directly.
// CHECK-LABEL: func @send_empty_mesh_axes
// CHECK-SAME: %[[ARG:.*]]: tensor<4xf32>
// CHECK-NOT: mesh.send
// CHECK: return %[[ARG]]
func.func @send_empty_mesh_axes(%input : tensor<4xf32>) -> tensor<4xf32> {
  %result = mesh.send %input on @mesh0 mesh_axes = [] destination = []
    : (tensor<4xf32>) -> tensor<4xf32>
  return %result : tensor<4xf32>
}