// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=true use-32-bit=true" %s -verify-diagnostics -o -| FileCheck %s
// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=false" %s -verify-diagnostics -o -| FileCheck --check-prefix="SCALE" %s
// CHECK-LABEL: func @const_test
func.func @const_test() -> (tensor<i32>) {
// CHECK: [[C3:%.+]] = arith.constant dense<3> : tensor<i32>
%result = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
// CHECK: return [[C3]]
return %result : tensor<i32>
}
// -----
// CHECK-LABEL: @apply_scale_test_i32
// SCALE: tosa.apply_scale
func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: %[[S32:.+]] = arith.extui %arg2 : i8 to i32
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i32
// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : i32
// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i32
// Compute the high-low values of the matmul in 64-bits.
// CHECK-DAG: %[[LOW:.+]], %[[HI:.+]] = arith.mulsi_extended %arg0, %arg1
// Determine whether the high bits need to shift left or right and by how much.
// CHECK-DAG: %[[OVER31:.+]] = arith.cmpi sge, %[[S32]], %[[C32]]
// CHECK-DAG: %[[OVER32:.+]] = arith.cmpi sgt, %[[S32]], %[[C32]]
// CHECK-DAG: %[[HISHLN:.+]] = arith.subi %[[C32]], %[[S32]]
// CHECK-DAG: %[[HISHRN:.+]] = arith.subi %[[S32]], %[[C32]]
// CHECK-DAG: %[[HISHL:.+]] = arith.select %[[OVER31]], %[[C0]], %[[HISHLN]]
// CHECK-DAG: %[[HISHR:.+]] = arith.select %[[OVER31]], %[[HISHRN]], %[[C0]]
// Apply double rounding.
// CHECK-DAG: %[[CN1:.+]] = arith.constant -1
// CHECK-DAG: %[[POS:.+]] = arith.cmpi sge, %arg0, %[[C0]]
// CHECK-DAG: %[[DIR:.+]] = arith.select %[[POS]], %[[C1]], %[[CN1]]
// CHECK-DAG: %[[DRND:.+]] = arith.select %[[OVER31]], %[[DIR]], %[[C0]]
// CHECK-DAG: %[[DSHFTR:.+]] = arith.shrui %[[LOW]], %[[C30]]
// CHECK-DAG: %[[DRNDED:.+]] = arith.addi %[[DSHFTR]], %[[DRND]]
// CHECK-DAG: %[[DCARRY:.+]] = arith.shrsi %[[DRNDED]], %[[C2:.+]]
// CHECK-DAG: %[[DBIT:.+]] = arith.shli %[[DRND]], %[[C30]]
// CHECK-DAG: %[[DLOW:.+]] = arith.addi %[[LOW]], %[[DBIT]]
// CHECK-DAG: %[[DHI:.+]] = arith.addi %[[HI]], %[[DCARRY]]
// Apply low-bit rounding.
// CHECK-DAG: %[[SHFTM1:.+]] = arith.subi %[[S32]], %[[C1]]
// CHECK-DAG: %[[LBIT:.+]] = arith.shli %[[C1]], %[[SHFTM1]]
// CHECK-DAG: %[[HALF:.+]] = arith.select %[[OVER32]], %[[C0]], %[[LBIT]]
// CHECK-DAG: %[[LADD:.+]] = arith.addi %[[DLOW]], %[[HALF]]
// CHECK-DAG: %[[LLO:.+]] = arith.cmpi ugt, %[[DLOW]], %[[LADD]]
// CHECK-DAG: %[[LCARRY:.+]] = arith.extui %[[LLO]] : i1 to i32
// CHECK-DAG: %[[LRNDED:.+]] = arith.addi %[[DHI]], %[[LCARRY]]
// Apply high-bit rounding.
// CHECK-DAG: %[[HISHRM1:.+]] = arith.subi %[[HISHR]], %[[C1]]
// CHECK-DAG: %[[LHISHFT:.+]] = arith.shli %[[C1]], %[[HISHRM1]]
// CHECK-DAG: %[[LHI:.+]] = arith.select %[[OVER32]], %[[LHISHFT]], %[[C0]]
// CHECK-DAG: %[[FHI:.+]] = arith.addi %[[LRNDED]], %[[LHI]]
// Combine hi-low into the final result.
// CHECK-DAG: %[[HIL:.+]] = arith.shli %[[FHI]], %[[HISHL]]
// CHECK-DAG: %[[HIALIGN:.+]] = arith.shrsi %[[HIL:.+]], %[[HISHR]]
// CHECK-DAG: %[[LOR:.+]] = arith.shrui %[[LADD]], %[[S32]]
// CHECK-DAG: %[[LOWALIGN:.+]] = arith.select %[[OVER31]], %[[C0]], %[[LOR]]
// CHECK-DAG: %[[RESULT:.+]] = arith.addi %[[LOWALIGN]], %[[HIALIGN]]
// CHECK: return %[[RESULT]]
%res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i32, i32, i8) -> i32
return %res : i32
}
// -----
// CHECK-LABEL: @apply_scale_test_vector
// SCALE: tosa.apply_scale
func.func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) {
// CHECK-NOT: "tosa.apply_scale"
%res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
return %res : vector<4xi32>
}
// -----
// CHECK-LABEL: @apply_scale_test_i48
// SCALE: tosa.apply_scale
func.func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i48
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i64
// CHECK-DAG: %[[C31:.+]] = arith.constant 31 : i32
// Multiply in 64 bits.
// CHECK-DAG: %[[V64:.+]] = arith.extsi %arg0 : i48 to i64
// CHECK-DAG: %[[M64:.+]] = arith.extsi %arg1 : i32 to i64
// CHECK-DAG: %[[MUL:.+]] = arith.muli %[[V64]], %[[M64]]
// Round normally.
// CHECK-DAG: %[[S32:.+]] = arith.extui %arg2 : i8 to i32
// CHECK-DAG: %[[S64:.+]] = arith.extui %[[S32]] : i32 to i64
// CHECK-DAG: %[[ONEL:.+]] = arith.shli %[[C1]], %[[S64]] : i64
// CHECK-DAG: %[[ONER:.+]] = arith.shrui %[[ONEL]], %[[C1]]
// CHECK-DAG: %[[ROUND:.+]] = arith.addi %[[MUL]], %[[ONER]]
// Apply double rounding.
// CHECK-DAG: %[[DUP:.+]] = arith.constant 1073741824 : i64
// CHECK-DAG: %[[DDOWN:.+]] = arith.constant -1073741824 : i64
// CHECK-DAG: %[[POS:.+]] = arith.cmpi sge, %arg0, %[[C0]]
// CHECK-DAG: %[[DBIT:.+]] = arith.select %[[POS]], %[[DUP]], %[[DDOWN]]
// CHECK-DAG: %[[DRND:.+]] = arith.addi %[[DBIT]], %[[ROUND]]
// CHECK-DAG: %[[USED:.+]] = arith.cmpi sgt, %[[S32]], %[[C31]] : i32
// CHECK-DAG: %[[RES64:.+]] = arith.select %[[USED]], %[[DRND]], %[[ROUND]] : i64
// Shift and truncate final answer.
// CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
// CHECK: return %[[TRUNC]]
%res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i48, i32, i8) -> i32
return %res : i32
}
// -----
// CHECK-LABEL: @apply_scale_test_i64
// SCALE: tosa.apply_scale
func.func @apply_scale_test_i64(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i64
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i64
// CHECK-DAG: %[[C31:.+]] = arith.constant 31 : i32
// Multiply in 64 bits.
// CHECK-DAG: %[[M64:.+]] = arith.extsi %arg1 : i32 to i64
// CHECK-DAG: %[[MUL:.+]] = arith.muli %arg0, %[[M64]]
// Round normally.
// CHECK-DAG: %[[S32:.+]] = arith.extui %arg2 : i8 to i32
// CHECK-DAG: %[[S64:.+]] = arith.extui %[[S32]] : i32 to i64
// CHECK-DAG: %[[ONEL:.+]] = arith.shli %[[C1]], %[[S64]] : i64
// CHECK-DAG: %[[ONER:.+]] = arith.shrui %[[ONEL]], %[[C1]]
// CHECK-DAG: %[[ROUND:.+]] = arith.addi %[[MUL]], %[[ONER]]
// Apply double rounding.
// CHECK-DAG: %[[DUP:.+]] = arith.constant 1073741824 : i64
// CHECK-DAG: %[[DDOWN:.+]] = arith.constant -1073741824 : i64
// CHECK-DAG: %[[POS:.+]] = arith.cmpi sge, %arg0, %[[C0]]
// CHECK-DAG: %[[DBIT:.+]] = arith.select %[[POS]], %[[DUP]], %[[DDOWN]]
// CHECK-DAG: %[[DRND:.+]] = arith.addi %[[DBIT]], %[[ROUND]]
// CHECK-DAG: %[[USED:.+]] = arith.cmpi sgt, %[[S32]], %[[C31]] : i32
// CHECK-DAG: %[[RES64:.+]] = arith.select %[[USED]], %[[DRND]], %[[ROUND]] : i64
// Shift and truncate final answer.
// CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
// CHECK: return %[[TRUNC]]
%res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i64, i32, i8) -> i32
return %res : i32
}