llvm/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir

// RUN: mlir-opt %s --buffer-deallocation-simplification --split-input-file | FileCheck %s

// Deallocs whose list of memrefs to deallocate overlaps with, or contains
// views of, values in their own retain list.
func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>, %arg3: i1) -> (i1, i1, i1, i1, i1, i1, i1) {
  // %arg0 is the only memref to deallocate and the only retained value.
  %own_single = bufferization.dealloc (%arg0 : memref<2xi32>)
      if (%arg1)
      retain (%arg0 : memref<2xi32>)

  // %arg0 is retained; %arg2 is listed with the same condition but is not
  // retained.
  %own_mixed = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>)
      if (%arg1, %arg1)
      retain (%arg0 : memref<2xi32>)

  // %arg0 is retained alongside %arg2.
  %own_pair:2 = bufferization.dealloc (%arg0 : memref<2xi32>)
      if (%arg1)
      retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>)

  // Multiple must-alias: two subviews of %arg0 plus one fresh allocation.
  %elem0 = memref.subview %arg0[0][1][1] : memref<2xi32> to memref<i32>
  %elem1 = memref.subview %arg0[1][1][1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>>
  %scratch = memref.alloc() : memref<2xi32>
  %own_views:3 = bufferization.dealloc (%arg0, %elem1 : memref<2xi32>, memref<1xi32, strided<[1], offset: 1>>)
      if (%arg1, %arg3)
      retain (%arg0, %scratch, %elem0 : memref<2xi32>, memref<2xi32>, memref<i32>)

  return %own_single, %own_mixed, %own_pair#0, %own_pair#1, %own_views#0, %own_views#1, %own_views#2 : i1, i1, i1, i1, i1, i1, i1
}

// CHECK-LABEL: func @dealloc_deallocated_in_retained
//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
//  CHECK-NEXT: arith.constant false
//  CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
//  CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
//  CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
// COM: The RemoveRetainedMemrefsGuaranteedToNotAlias pattern removes all of the
// COM: retained memrefs: the pattern under test empties the list of memrefs to
// COM: be deallocated, so there is no memref left that the retained values
// COM: could alias with.
// CHECK-NOT: if
//  CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]]
//  CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]]
//  CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1, [[V3]], %false{{[0-9_]*}}, [[V4]] :

// -----

// Same scenario as the plain "deallocated in retained" test, except that the
// memrefs handed to dealloc are the base buffers obtained through
// memref.extract_strided_metadata rather than the arguments themselves.
func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>, %arg3: i1) -> (i1, i1, i1, i1, i1, i1, i1) {
  %base_a, %off_a, %size_a, %stride_a = memref.extract_strided_metadata %arg0 : memref<2xi32> -> memref<i32>, index, index, index
  %base_b, %off_b, %size_b, %stride_b = memref.extract_strided_metadata %arg2 : memref<2xi32> -> memref<i32>, index, index, index

  // Base buffer of %arg0 is deallocated while %arg0 itself is retained.
  %own_single = bufferization.dealloc (%base_a : memref<i32>)
      if (%arg1)
      retain (%arg0 : memref<2xi32>)

  // Base buffers of %arg0 and %arg2 share a condition; only %arg0 is retained.
  %own_mixed = bufferization.dealloc (%base_a, %base_b : memref<i32>, memref<i32>)
      if (%arg1, %arg1)
      retain (%arg0 : memref<2xi32>)

  // Base buffer of %arg0 is deallocated; %arg0 and %arg2 are both retained.
  %own_pair:2 = bufferization.dealloc (%base_a : memref<i32>)
      if (%arg1)
      retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>)

  // Multiple must-alias: two subviews of %arg0 plus one fresh allocation.
  %elem0 = memref.subview %arg0[0][1][1] : memref<2xi32> to memref<i32>
  %elem1 = memref.subview %arg0[1][1][1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>>
  %scratch = memref.alloc() : memref<2xi32>
  %own_views:3 = bufferization.dealloc (%base_a, %elem1 : memref<i32>, memref<1xi32, strided<[1], offset: 1>>)
      if (%arg1, %arg3)
      retain (%arg0, %scratch, %elem0 : memref<2xi32>, memref<2xi32>, memref<i32>)

  return %own_single, %own_mixed, %own_pair#0, %own_pair#1, %own_views#0, %own_views#1, %own_views#2 : i1, i1, i1, i1, i1, i1, i1
}

// CHECK-LABEL: func @dealloc_deallocated_in_retained_extract_base_memref
//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
//  CHECK-NEXT: arith.constant false
//  CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG0]] :
//  CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG2]] :
//  CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[BASE1]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
//  CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
//  CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[BASE0]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
// COM: The RemoveRetainedMemrefsGuaranteedToNotAlias pattern removes all of the
// COM: retained memrefs: the pattern under test empties the list of memrefs to
// COM: be deallocated, so there is no memref left that the retained values
// COM: could alias with.
// CHECK-NOT: if
//  CHECK-NEXT: [[V3:%.+]] = arith.ori [[ARG3]], [[ARG1]]
//  CHECK-NEXT: [[V4:%.+]] = arith.ori [[ARG3]], [[ARG1]]
//  CHECK-NEXT: return [[ARG1]], [[O1]], [[V2]]#0, [[V2]]#1, [[V3]], %false{{[0-9_]*}}, [[V4]] :

// -----

// A local allocation is deallocated while a second, distinct allocation and a
// function argument are retained. The expected output below has no retain
// list left on the dealloc, and both ownership results are replaced by a
// constant false.
func.func @remove_retained_memrefs_guaranteed_to_not_alias(%arg0: i1, %arg1: memref<2xi32>) -> (i1, i1, memref<2xi32>) {
  %alloc = memref.alloc() : memref<2xi32>
  %alloc0 = memref.alloc() : memref<2xi32>
  // Only %alloc is deallocated; %alloc0 and %arg1 appear solely as retained
  // values.
  %0:2 = bufferization.dealloc (%alloc : memref<2xi32>) if (%arg0) retain (%alloc0, %arg1 : memref<2xi32>, memref<2xi32>)
  return %0#0, %0#1, %alloc : i1, i1, memref<2xi32>
}

// CHECK-LABEL: func @remove_retained_memrefs_guaranteed_to_not_alias
//  CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xi32>)
//  CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
//  CHECK-NEXT: [[ALLOC:%.+]] = memref.alloc(
//  CHECK-NEXT: bufferization.dealloc ([[ALLOC]] : memref<2xi32>) if ([[ARG0]])
//  CHECK-NOT: retain
//  CHECK-NEXT: return [[FALSE]], [[FALSE]], [[ALLOC]] :

// -----

// A single dealloc lists one local allocation and one function argument, and
// retains a function argument plus a value selected between two local
// allocations. The expected output below contains two separate deallocs,
// each with a single memref and a single retained value.
func.func @dealloc_split_when_no_other_aliasing(%arg0: i1, %arg1: memref<2xi32>, %arg2: memref<2xi32>, %arg3: i1) -> (i1, i1) {
  %alloc = memref.alloc() : memref<2xi32>
  %alloc0 = memref.alloc() : memref<2xi32>
  // %0 is either %alloc or %alloc0, chosen at runtime by %arg0.
  %0 = arith.select %arg0, %alloc, %alloc0 : memref<2xi32>
  // %alloc is guarded by %arg0 and %arg2 by %arg3; %arg1 and %0 are retained.
  %1:2 = bufferization.dealloc (%alloc, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg0, %arg3) retain (%arg1, %0 : memref<2xi32>, memref<2xi32>)
  return %1#0, %1#1 : i1, i1
}

// CHECK-LABEL: func @dealloc_split_when_no_other_aliasing
//  CHECK-SAME: ([[ARG0:%.+]]: i1, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: memref<2xi32>, [[ARG3:%.+]]: i1)
//  CHECK-NEXT:   [[ALLOC0:%.+]] = memref.alloc(
//  CHECK-NEXT:   [[ALLOC1:%.+]] = memref.alloc(
//  CHECK-NEXT:   [[V0:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]] :
// COM: Each retain list ends up with only one value because the
// COM: RemoveRetainedMemrefsGuaranteedToNotAlias pattern also applies here:
// COM: - %alloc is guaranteed to not alias with %arg1.
// COM: - %arg2 is guaranteed to not alias with %0.
//  CHECK-NEXT:   [[V1:%.+]] = bufferization.dealloc ([[ALLOC0]] : memref<2xi32>) if ([[ARG0]]) retain ([[V0]] : memref<2xi32>)
//  CHECK-NEXT:   [[V2:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG3]]) retain ([[ARG1]] : memref<2xi32>)
//  CHECK-NEXT:   return [[V2]], [[V1]] :

// -----

// All three arguments are deallocated under a constant-true condition, and
// the first two of them are also retained.
func.func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_condition(
  %arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: memref<2xi32>) -> (memref<2xi32>, memref<2xi32>, i1, i1) {
  %always = arith.constant true
  // %arg0 and %arg1 appear in both lists; %arg2 is only deallocated.
  %owned:2 = bufferization.dealloc (%arg0, %arg1, %arg2 : memref<2xi32>, memref<2xi32>, memref<2xi32>)
      if (%always, %always, %always)
      retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
  return %arg0, %arg1, %owned#0, %owned#1 : memref<2xi32>, memref<2xi32>, i1, i1
}

// CHECK-LABEL: func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_condition
//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: memref<2xi32>)
//       CHECK:   bufferization.dealloc ([[ARG2]] :{{.*}}) if (%true{{[0-9_]*}})
//  CHECK-NEXT:   return [[ARG0]], [[ARG1]], %true{{[0-9_]*}}, %true{{[0-9_]*}} :

// -----

// Variant of the constant-true-condition test in which the memrefs handed to
// dealloc are the base buffers obtained through
// memref.extract_strided_metadata, while the retained values are still the
// original arguments. The expected output below keeps only the dealloc of
// the base buffer of the third argument.
func.func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_condition_extract_base_memref(
  %arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: memref<2xi32>) -> (memref<2xi32>, memref<2xi32>, i1, i1) {
  %true = arith.constant true
  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %arg0 : memref<2xi32> -> memref<i32>, index, index, index
  %base_buffer_1, %offset_1, %size_1, %stride_1 = memref.extract_strided_metadata %arg1 : memref<2xi32> -> memref<i32>, index, index, index
  %base_buffer_2, %offset_2, %size_2, %stride_2 = memref.extract_strided_metadata %arg2 : memref<2xi32> -> memref<i32>, index, index, index
  // The base buffers of %arg0 and %arg1 are deallocated while %arg0 and %arg1
  // themselves are retained; the base buffer of %arg2 is only deallocated.
  %0:2 = bufferization.dealloc (%base_buffer, %base_buffer_1, %base_buffer_2 : memref<i32>, memref<i32>, memref<i32>) if (%true, %true, %true) retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
  return %arg0, %arg1, %0#0, %0#1 : memref<2xi32>, memref<2xi32>, i1, i1
}

// CHECK-LABEL: func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_condition_extract_base_memref
//  CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: memref<2xi32>)
//       CHECK:   [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG2]]
//       CHECK:   bufferization.dealloc ([[BASE]] :{{.*}}) if (%true{{[0-9_]*}})
//  CHECK-NEXT:   return [[ARG0]], [[ARG1]], %true{{[0-9_]*}}, %true{{[0-9_]*}} :

// -----

// A loop carries a memref together with an i1 flag. Every iteration
// allocates a new buffer, copies the carried memref into it, and issues a
// dealloc that lists both the carried memref's base buffer and the new
// buffer while retaining the new buffer.
func.func @alloc_and_bbarg(%arg0: memref<5xf32>, %arg1: index, %arg2: index, %arg3: index) -> f32 {
  %yes = arith.constant true
  %no = arith.constant false
  %loop:2 = scf.for %arg4 = %arg1 to %arg2 step %arg3
      iter_args(%arg5 = %arg0, %arg6 = %no) -> (memref<5xf32>, i1) {
    %fresh = memref.alloc() : memref<5xf32>
    memref.copy %arg5, %fresh : memref<5xf32> to memref<5xf32>
    %iter_base, %iter_offset, %iter_sizes, %iter_strides = memref.extract_strided_metadata %arg5 : memref<5xf32> -> memref<f32>, index, index, index
    // %fresh is listed for deallocation with a constant-true condition and is
    // retained by the same op.
    %fresh_flag = bufferization.dealloc (%iter_base, %fresh : memref<f32>, memref<5xf32>)
        if (%arg6, %yes)
        retain (%fresh : memref<5xf32>)
    scf.yield %fresh, %fresh_flag : memref<5xf32>, i1
  }
  %elem = memref.load %loop#0[%arg1] : memref<5xf32>
  %out_base, %out_offset, %out_sizes, %out_strides = memref.extract_strided_metadata %loop#0 : memref<5xf32> -> memref<f32>, index, index, index
  // Release the buffer produced by the loop, guarded by the loop's i1 result.
  bufferization.dealloc (%out_base : memref<f32>) if (%loop#1)
  return %elem : f32
}

// CHECK-LABEL: func @alloc_and_bbarg
//       CHECK:   %[[true:.*]] = arith.constant true
//       CHECK:   scf.for {{.*}} iter_args(%[[iter:.*]] = %{{.*}}, %{{.*}} = %{{.*}})
//       CHECK:     %[[alloc:.*]] = memref.alloc
//       CHECK:     %[[view:.*]], %{{.*}}, %{{.*}}, %{{.*}} =  memref.extract_strided_metadata %[[iter]]
//       CHECK:     bufferization.dealloc (%[[view]] : memref<f32>)
//   CHECK-NOT:     retain
//       CHECK:     scf.yield %[[alloc]], %[[true]]

// -----

// The same memref is listed twice, with the same condition, in one dealloc.
// The expected output below lists a single memref and a single condition.
func.func @duplicate_memref(%arg0: memref<5xf32>, %arg1: memref<6xf32>, %c: i1) -> i1 {
  %owned = bufferization.dealloc (%arg0, %arg0 : memref<5xf32>, memref<5xf32>)
      if (%c, %c)
      retain (%arg1 : memref<6xf32>)
  return %owned : i1
}

// CHECK-LABEL: func @duplicate_memref(
//       CHECK:   %[[r:.*]] = bufferization.dealloc (%{{.*}} : memref<5xf32>) if (%{{.*}}) retain (%{{.*}} : memref<6xf32>)
//       CHECK:   return %[[r]]