From a7b009d15354e8d80eec463b95c7a7a08489bc58 Mon Sep 17 00:00:00 2001
From: Robert Ogden <[email protected]>
Date: Wed, 26 Jun 2024 08:56:07 -0700
Subject: [PATCH] Add int8 quantization support to the denylist op
---
.../seq_flow_lite/tflite_ops/denylist.cc | 58 ++++++++++++-------
1 file changed, 36 insertions(+), 22 deletions(-)
diff --git a/third_party/tensorflow_models/src/research/seq_flow_lite/tflite_ops/denylist.cc b/third_party/tensorflow_models/src/research/seq_flow_lite/tflite_ops/denylist.cc
index 99671a6bb947e..97f58fdf7c9e4 100644
--- a/third_party/tensorflow_models/src/research/seq_flow_lite/tflite_ops/denylist.cc
+++ b/third_party/tensorflow_models/src/research/seq_flow_lite/tflite_ops/denylist.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tflite_ops/denylist.h" // seq_flow_lite
+#include <cstdint>
+
#include "absl/container/flat_hash_set.h"
#include "tensorflow/lite/context.h"
#include "tflite_ops/quantization_util.h" // seq_flow_lite
@@ -44,6 +46,34 @@ TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
output_dims);
}
+template <typename T>  // T: quantized output element type (uint8_t/int8_t).
+TfLiteStatus QuantizedEval(DenylistOp* op,
+                           TfLiteContext* context,
+                           TfLiteTensor* output_categories,
+                           int input_size) {
+  const T one = PodQuantize<T>(1.0, output_categories->params.zero_point,
+                               1.0 / output_categories->params.scale);
+  const T zero = PodQuantize<T>(0.0, output_categories->params.zero_point,
+                                1.0 / output_categories->params.scale);
+  int n_categories = op->categories();  // Row width of the output tensor.
+  for (int i = 0; i < input_size; i++) {
+    absl::flat_hash_set<int> categories;
+    TF_LITE_ENSURE_STATUS(op->GetCategories(context, i, categories));
+    if (categories.empty()) {  // Empty set: first negative slots get `one`.
+      for (int j = 0; j < n_categories; j++) {
+        tflite::GetTensorData<T>(output_categories)[i * n_categories + j] =
+            (j < op->negative_categories()) ? one : zero;
+      }
+    } else {  // Otherwise slot j gets `one` iff j is in the returned set.
+      for (int j = 0; j < n_categories; j++) {
+        tflite::GetTensorData<T>(output_categories)[i * n_categories + j] =
+            (categories.find(j) != categories.end()) ? one : zero;
+      }
+    }
+  }
+  return kTfLiteOk;
+}
+
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* op = reinterpret_cast<DenylistOp*>(node->user_data);
TF_LITE_ENSURE_STATUS(op->CheckErrors(context));
@@ -56,10 +86,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
for (int i = 0; i < input_dims->size; i++) {
input_size *= input_dims->data[i];
}
- int n_categories = op->categories();
TF_LITE_ENSURE_STATUS(op->InitializeInput(context, node));
if (output_categories->type == kTfLiteFloat32) {
+ int n_categories = op->categories();
for (int i = 0; i < input_size; i++) {
absl::flat_hash_set<int> categories;
TF_LITE_ENSURE_STATUS(op->GetCategories(context, i, categories));
@@ -76,27 +106,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
}
} else if (output_categories->type == kTfLiteUInt8) {
- const uint8_t one =
- PodQuantize<uint8_t>(1.0, output_categories->params.zero_point,
- 1.0 / output_categories->params.scale);
- const uint8_t zero =
- PodQuantize<uint8_t>(0.0, output_categories->params.zero_point,
- 1.0 / output_categories->params.scale);
- for (int i = 0; i < input_size; i++) {
- absl::flat_hash_set<int> categories;
- TF_LITE_ENSURE_STATUS(op->GetCategories(context, i, categories));
- if (categories.empty()) {
- for (int j = 0; j < n_categories; j++) {
- output_categories->data.uint8[i * n_categories + j] =
- (j < op->negative_categories()) ? one : zero;
- }
- } else {
- for (int j = 0; j < n_categories; j++) {
- output_categories->data.uint8[i * n_categories + j] =
- (categories.find(j) != categories.end()) ? one : zero;
- }
- }
- }
+ TF_LITE_ENSURE_STATUS(
+ QuantizedEval<uint8_t>(op, context, output_categories, input_size));
+ } else if (output_categories->type == kTfLiteInt8) {
+ TF_LITE_ENSURE_STATUS(
+ QuantizedEval<int8_t>(op, context, output_categories, input_size));
}
op->FinalizeInput();
return kTfLiteOk;
--
2.45.2.741.gdbec12cfda-goog