llvm/mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir

// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map | FileCheck %s --check-prefix=CHECK-FOLD
// RUN: mlir-opt %s --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification | FileCheck %s

#trait = {
  indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
  ],
  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
}

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

#COO = #sparse_tensor.encoding<{map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa))}>
#CCCD = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : dense) }>

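// The sparse_tensor.convert of the dense result into #CCCD is expected to be
// folded into the producing linalg.generic (CHECK-FOLD); after sparsification
// the kernel conditionally inserts into the sparse output and finishes with a
// sparse_tensor.load (CHECK).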
// CHECK-LABEL:   func.func @fold_convert(
// CHECK:           scf.for
// CHECK:             scf.for
// CHECK:               scf.for
// CHECK:                 scf.if
// CHECK-NEXT:               tensor.insert
// CHECK-NEXT:               scf.yield
// CHECK-NEXT:             else
// CHECK-NEXT:               scf.yield
// CHECK:                 scf.yield
// CHECK:               scf.yield
// CHECK:             scf.yield
// CHECK:           sparse_tensor.load

// CHECK-FOLD-LABEL:   func.func @fold_convert(
// CHECK-FOLD-NOT:     sparse_tensor.convert
func.func @fold_convert(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>, %arg2: tensor<128x32x32x1xf32>) -> tensor<128x32x32x1xf32, #CCCD> {
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 1.000000e+00 : f32
  %cst_1 = arith.constant 1.000000e+00 : f32
  %0 = tensor.empty() : tensor<128x32x32x1xf32>
  %1 = linalg.generic #trait
  ins(%arg0, %arg1, %arg2 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
  outs(%0 : tensor<128x32x32x1xf32>) {
    ^bb0(%in: f32, %in_2: f32, %in_3: f32, %out: f32):
      %3 = arith.subf %cst_0, %in_2 : f32
      %4 = arith.mulf %in, %3 : f32
      %5 = arith.mulf %4, %cst_1 : f32
      %6 = arith.addf %5, %in_3 : f32
      %7 = arith.subf %6, %cst_0 : f32
      %8 = arith.cmpf uge, %7, %cst : f32
      %9 = arith.uitofp %8 : i1 to f32
      linalg.yield %9 : f32
    } -> tensor<128x32x32x1xf32>
  %2 = sparse_tensor.convert %1 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #CCCD>
  return %2 : tensor<128x32x32x1xf32, #CCCD>
}

#trait_bin = {
  indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
  ],
  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
}

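// Same as @fold_convert, except that the dense init %0 is used by two kernels.
// The convert is still expected to be folded away: the first kernel keeps the
// dense tensor.empty init, while the folded kernel gets a fresh sparse
// tensor.empty init (see the CHECK-FOLD lines below).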
// CHECK-FOLD-LABEL:   func.func @fold_convert_multi_use(
// CHECK-FOLD:           tensor.empty() : tensor<128x32x32x1xf32>
// CHECK-FOLD:           linalg.generic
// CHECK-FOLD:           tensor.empty() : tensor<128x32x32x1xf32, #sparse>
// CHECK-FOLD:           linalg.generic
// CHECK-FOLD-NOT:       sparse_tensor.convert
func.func @fold_convert_multi_use(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>,
                        %arg2: tensor<128x32x32x1xf32>, %arg3: tensor<128x32x32x1xf32>) -> (tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32, #CCCD>) {
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 1.000000e+00 : f32
  %cst_1 = arith.constant 1.000000e+00 : f32

  %0 = tensor.empty() : tensor<128x32x32x1xf32>
  %1 = linalg.generic #trait_bin
  ins(%arg0, %arg1 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
  outs(%0 : tensor<128x32x32x1xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %3 = arith.mulf %in, %in_1 : f32
      linalg.yield %3 : f32
    } -> tensor<128x32x32x1xf32>

  // A second kernel that uses %0 as the init operand.
  %3 = linalg.generic #trait_bin
  ins(%arg2, %arg3 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
  outs(%0 : tensor<128x32x32x1xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %3 = arith.mulf %in, %in_1 : f32
      linalg.yield %3 : f32
    } -> tensor<128x32x32x1xf32>
  %4 = sparse_tensor.convert %3 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #CCCD>

  return %1, %4 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32, #CCCD>
}

// FIXME: The following kernel is not sparsifiable because the `arith.select`
// operation is not handled by the sparse compiler at the moment.
//
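// The trailing tensor.cast into the #COO type is folded just like an explicit
// convert; no sparse_tensor.convert should remain (CHECK-FOLD).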
// CHECK-FOLD-LABEL:   func.func @fold_cast(
// CHECK-FOLD-NOT:     sparse_tensor.convert
func.func @fold_cast(%0: tensor<10x20x30xf64, #COO>) -> tensor<10x20x30xf64, #COO> {
  %cst = arith.constant 0.000000e+00 : f64
  %1 = tensor.empty() : tensor<10x20x30xf64>
  %2 = linalg.generic { indexing_maps = [#map, #map],
                        iterator_types = ["parallel", "parallel", "parallel"]
                      }
  ins (%0 : tensor<10x20x30xf64, #COO>)
  outs(%1 : tensor<10x20x30xf64>) {
      ^bb0(%in: f64, %out: f64):
        %4 = arith.cmpf ugt, %in, %cst : f64
        %5 = arith.select %4, %in, %cst : f64
        linalg.yield %5 : f64
  } -> tensor<10x20x30xf64>
  %cast = tensor.cast %2 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #COO>
  return %cast : tensor<10x20x30xf64, #COO>
}