// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mediapipe/calculators/tensor/inference_runner_ml_drift_opencl_delegate.h"

#include <memory>
#include <utility>
#include <vector>

#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/inference_calculator_utils.h"
#include "mediapipe/calculators/tensor/inference_io_mapper.h"
#include "mediapipe/calculators/tensor/tensor_span.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/util/tflite/tflite_model_loader.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/core/interpreter_builder.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/util.h"
#include "third_party/ml_drift/contrib/tflite_op_resolver.h"
#include "third_party/odml/infra/ml_drift_delegate/ml_drift_cl.h"

namespace mediapipe::api2 {

using ::tflite::ml_drift::MlDriftClDelegateDefaultOptionsPtr;
using ::tflite::ml_drift::TfLiteCreateMlDriftClDelegate;

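// Builds the TFLite interpreter from the provided model and op resolver,
// applies the ML Drift OpenCL delegate, and allocates tensors so the runner
// is ready for Run().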
absl::Status InferenceRunnerMlDriftOpenClDelegate::Init(
CalculatorContext* cc, Packet<TfLiteModelPtr> model_packet,
Packet<ml_drift::contrib::TfLiteOpResolver> op_resolver_packet) {
const auto& options = cc->Options<InferenceCalculatorOptions>();
RET_CHECK_EQ(options.delegate().gpu().api(),
InferenceCalculatorOptions::Delegate::Gpu::ML_DRIFT_OPENCL);
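  // Keep the model packet alive for the lifetime of this runner; the TFLite
  // interpreter references the model's flatbuffer without taking ownership.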
model_packet_ = model_packet;
tflite::InterpreterBuilder builder(*model_packet_.Get(),
op_resolver_packet.Get());
builder(&interpreter_);
ABSL_CHECK(interpreter_ != nullptr);
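  // Record the model's input/output tensor names so InferenceIoMapper can
  // remap tensors by name.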
MP_ASSIGN_OR_RETURN(
input_output_tensor_names_,
InferenceIoMapper::GetInputOutputTensorNamesFromInterpreter(
*interpreter_));
// Initialize ML Drift CL.
auto delegate_options = MlDriftClDelegateDefaultOptionsPtr();
delegate_options->enable_fast_tuning = true;
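  // Create the delegate and hand the graph to it; supported ops will run
  // through ML Drift on OpenCL.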
tflite::Interpreter::TfLiteDelegatePtr delegate =
TfLiteCreateMlDriftClDelegate(std::move(delegate_options));
ABSL_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(std::move(delegate)),
kTfLiteOk);
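  // Allocate tensor buffers up front so the first Run() can copy inputs
  // directly; dynamic input shapes may still trigger a re-allocation later.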
  RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
return absl::OkStatus();
}

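// Runs a single inference: resizes dynamically shaped inputs if needed,
// copies the inputs into the interpreter, invokes the delegated graph, and
// returns the outputs as newly allocated MediaPipe tensors.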
absl::StatusOr<std::vector<Tensor>> InferenceRunnerMlDriftOpenClDelegate::Run(
CalculatorContext* cc, const TensorSpan& input_tensors) {
// If the input tensors have dynamic shape, then the tensors need to be
// resized and reallocated before we can copy the tensor values.
bool resized_tensor_shapes = false;
for (int i = 0; i < input_tensors.size(); ++i) {
const Tensor& input_tensor = input_tensors[i];
if (input_tensor.shape().is_dynamic) {
const TfLiteTensor* interpreter_tensor =
interpreter_->tensor(interpreter_->inputs()[i]);
// TODO: Can avoid copying even these <= 4 values in the future.
std::vector<int> interpreter_dims{
interpreter_tensor->dims->data,
interpreter_tensor->dims->data + interpreter_tensor->dims->size};
      if (interpreter_dims != input_tensor.shape().dims) {
        RET_CHECK_EQ(interpreter_->ResizeInputTensorStrict(
                         interpreter_->inputs()[i],
                         input_tensor.shape().dims),
                     kTfLiteOk);
        resized_tensor_shapes = true;
      }
}
}
  // Reallocate tensor buffers so the interpreter's inputs match the resized
  // shapes before any data is copied.
  if (resized_tensor_shapes) {
    RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
  }
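  // Copy each MediaPipe input tensor into the corresponding interpreter
  // input tensor on the CPU.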
for (int i = 0; i < input_tensors.size(); ++i) {
const Tensor& input_tensor = input_tensors[i];
MP_RETURN_IF_ERROR(
CopyCpuInputIntoInterpreterTensor(input_tensor, *interpreter_, i));
}
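  // Run inference.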
  RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
MP_ASSIGN_OR_RETURN(auto output_tensors,
AllocateOutputTensors(*interpreter_));
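  // Copy the interpreter outputs into the newly allocated MediaPipe tensors.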
for (int i = 0; i < output_tensors.size(); ++i) {
const int output_tensor_index = interpreter_->outputs()[i];
MP_RETURN_IF_ERROR(CopyInterpreterTensorIntoCpuOutput(
*interpreter_, output_tensor_index, output_tensors[i]));
}
return output_tensors;
}

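// Creates one MediaPipe Tensor per interpreter output, mirroring the element
// type and shape of the corresponding TfLiteTensor; the tensor contents are
// filled by Run().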
absl::StatusOr<std::vector<Tensor>>
InferenceRunnerMlDriftOpenClDelegate::AllocateOutputTensors(
const tflite::Interpreter& interpreter) {
const int num_outputs = interpreter.outputs().size();
std::vector<Tensor> output_tensors;
output_tensors.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
const TfLiteTensor* reference_tensor =
interpreter.tensor(interpreter.outputs()[i]);
MP_ASSIGN_OR_RETURN(Tensor output_tensor,
CreateTensorWithTfLiteTensorSpecs(
*reference_tensor, /*memory_manager=*/nullptr,
tflite::kDefaultTensorAlignment));
output_tensors.push_back(std::move(output_tensor));
}
return output_tensors;
}

const InputOutputTensorNames&
InferenceRunnerMlDriftOpenClDelegate::GetInputOutputTensorNames() const {
return input_output_tensor_names_;
}

}  // namespace mediapipe::api2