chromium/third_party/tensorflow_models/patches/0002-Add-int8-quantization-support-to-the-denylist-op.patch

From a7b009d15354e8d80eec463b95c7a7a08489bc58 Mon Sep 17 00:00:00 2001
From: Robert Ogden <[email protected]>
Date: Wed, 26 Jun 2024 08:56:07 -0700
Subject: [PATCH] Add int8 quantization support to the denylist op

---
 .../seq_flow_lite/tflite_ops/denylist.cc      | 58 ++++++++++++-------
 1 file changed, 36 insertions(+), 22 deletions(-)

diff --git a/third_party/tensorflow_models/src/research/seq_flow_lite/tflite_ops/denylist.cc b/third_party/tensorflow_models/src/research/seq_flow_lite/tflite_ops/denylist.cc
index 99671a6bb947e..97f58fdf7c9e4 100644
--- a/third_party/tensorflow_models/src/research/seq_flow_lite/tflite_ops/denylist.cc
+++ b/third_party/tensorflow_models/src/research/seq_flow_lite/tflite_ops/denylist.cc
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 #include "tflite_ops/denylist.h"  // seq_flow_lite
 
+#include <cstdint>
+
 #include "absl/container/flat_hash_set.h"
 #include "tensorflow/lite/context.h"
 #include "tflite_ops/quantization_util.h"  // seq_flow_lite
@@ -44,6 +46,34 @@ TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
       output_dims);
 }
 
+// Writes quantized 1.0/0.0 category flags of type T into `output_categories`.
+template <typename T>
+TfLiteStatus QuantizedEval(DenylistOp* op, TfLiteContext* context,
+                           TfLiteTensor* output_categories, int input_size) {
+  // Hold the constants as T; a uint8_t temporary would make the int8
+  // instantiation convert unsigned values back to signed on every store.
+  const auto& params = output_categories->params;
+  const T one = PodQuantize<T>(1.0, params.zero_point, 1.0 / params.scale);
+  const T zero = PodQuantize<T>(0.0, params.zero_point, 1.0 / params.scale);
+  int n_categories = op->categories();
+  for (int i = 0; i < input_size; i++) {
+    absl::flat_hash_set<int> categories;
+    TF_LITE_ENSURE_STATUS(op->GetCategories(context, i, categories));
+    if (categories.empty()) {
+      for (int j = 0; j < n_categories; j++) {
+        tflite::GetTensorData<T>(output_categories)[i * n_categories + j] =
+            (j < op->negative_categories()) ? one : zero;
+      }
+    } else {
+      for (int j = 0; j < n_categories; j++) {
+        tflite::GetTensorData<T>(output_categories)[i * n_categories + j] =
+            (categories.find(j) != categories.end()) ? one : zero;
+      }
+    }
+  }
+  return kTfLiteOk;
+}
+
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   auto* op = reinterpret_cast<DenylistOp*>(node->user_data);
   TF_LITE_ENSURE_STATUS(op->CheckErrors(context));
@@ -56,10 +86,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   for (int i = 0; i < input_dims->size; i++) {
     input_size *= input_dims->data[i];
   }
-  int n_categories = op->categories();
 
   TF_LITE_ENSURE_STATUS(op->InitializeInput(context, node));
   if (output_categories->type == kTfLiteFloat32) {
+    int n_categories = op->categories();
     for (int i = 0; i < input_size; i++) {
       absl::flat_hash_set<int> categories;
       TF_LITE_ENSURE_STATUS(op->GetCategories(context, i, categories));
@@ -76,27 +106,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       }
     }
   } else if (output_categories->type == kTfLiteUInt8) {
-    const uint8_t one =
-        PodQuantize<uint8_t>(1.0, output_categories->params.zero_point,
-                             1.0 / output_categories->params.scale);
-    const uint8_t zero =
-        PodQuantize<uint8_t>(0.0, output_categories->params.zero_point,
-                             1.0 / output_categories->params.scale);
-    for (int i = 0; i < input_size; i++) {
-      absl::flat_hash_set<int> categories;
-      TF_LITE_ENSURE_STATUS(op->GetCategories(context, i, categories));
-      if (categories.empty()) {
-        for (int j = 0; j < n_categories; j++) {
-          output_categories->data.uint8[i * n_categories + j] =
-              (j < op->negative_categories()) ? one : zero;
-        }
-      } else {
-        for (int j = 0; j < n_categories; j++) {
-          output_categories->data.uint8[i * n_categories + j] =
-              (categories.find(j) != categories.end()) ? one : zero;
-        }
-      }
-    }
+    TF_LITE_ENSURE_STATUS(
+        QuantizedEval<uint8_t>(op, context, output_categories, input_size));
+  } else if (output_categories->type == kTfLiteInt8) {
+    TF_LITE_ENSURE_STATUS(
+        QuantizedEval<int8_t>(op, context, output_categories, input_size));
   }
   op->FinalizeInput();
   return kTfLiteOk;
-- 
2.45.2.741.gdbec12cfda-goog