From 300388985a1835c993499cb413b8d900666651ed Mon Sep 17 00:00:00 2001
From: Robert Ogden <[email protected]>
Date: Wed, 30 Nov 2022 10:24:16 -0800
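Subject: [PATCH 05/10] check cancel flag before calling invoke

Previously the cancel flag was reset unconditionally right before
Invoke(), so a cancellation requested before the call was silently
dropped. InvokeWithFallback() and InvokeWithoutFallback() now return
absl::CancelledError without invoking the interpreter when the flag is
already set, and they clear the flag whenever they report a cancellation
so the next call starts from a clean state.
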
---
.../cc/port/default/tflite_wrapper.cc | 14 ++++++++++----
1 file changed, 10 insertions(+), 4 deletions(-)
diff --git a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
index dce66868264cb..70fedca9a3f22 100644
--- a/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
+++ b/third_party/tflite_support/src/tensorflow_lite_support/cc/port/default/tflite_wrapper.cc
@@ -258,8 +258,10 @@ absl::Status TfLiteInterpreterWrapper::InvokeWithFallback(
const std::function<absl::Status(tflite::Interpreter* interpreter)>&
set_inputs) {
RETURN_IF_ERROR(set_inputs(interpreter_.get()));
- // Reset cancel flag before calling `Invoke()`.
- cancel_flag_.Set(false);
+ if (cancel_flag_.Get()) {
+ cancel_flag_.Set(false);
+ return absl::CancelledError("cancelled before Invoke() was called");
+ }
TfLiteStatus status = kTfLiteError;
if (fallback_on_execution_error_) {
status = InterpreterUtils::InvokeWithCPUFallback(interpreter_.get());
@@ -273,6 +275,7 @@ absl::Status TfLiteInterpreterWrapper::InvokeWithFallback(
// Assume the inference is cancelled successfully if Invoke() returns
// kTfLiteError and the cancel flag is `true`.
if (status == kTfLiteError && cancel_flag_.Get()) {
+ cancel_flag_.Set(false);
return absl::CancelledError("Invoke() cancelled.");
}
if (delegate_) {
@@ -289,14 +292,17 @@ absl::Status TfLiteInterpreterWrapper::InvokeWithFallback(
}
absl::Status TfLiteInterpreterWrapper::InvokeWithoutFallback() {
- // Reset cancel flag before calling `Invoke()`.
- cancel_flag_.Set(false);
+ if (cancel_flag_.Get()) {
+ cancel_flag_.Set(false);
+ return absl::CancelledError("cancelled before Invoke() was called");
+ }
TfLiteStatus status = interpreter_->Invoke();
if (status != kTfLiteOk) {
// Assume InvokeWithoutFallback() is guarded under caller's synchronization.
// Assume the inference is cancelled successfully if Invoke() returns
// kTfLiteError and the cancel flag is `true`.
if (status == kTfLiteError && cancel_flag_.Get()) {
+ cancel_flag_.Set(false);
return absl::CancelledError("Invoke() cancelled.");
}
return absl::InternalError("Invoke() failed.");
--
2.42.0.515.g380fc7ccd1-goog