// Copyright 2021 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/util/tflite/operations/transform_landmarks.h"
#include "tensorflow/lite/delegates/gpu/common/mediapipe/transform_landmarks.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/schema/schema_generated.h"
namespace mediapipe {
namespace tflite_operations {
namespace {
constexpr int kDataInput0Tensor = 0;
constexpr int kDataInput1Tensor = 1;
constexpr int kOutputTensor = 0;
float DotProduct(const tflite::gpu::float4& l, const tflite::gpu::float4& r) {
return l.x * r.x + l.y * r.y + l.z * r.z + l.w * r.w;
}
namespace v1 {
inline void TransformLandmarks(
const tflite::gpu::TransformLandmarksAttributes& params,
const tflite::RuntimeShape& input0_shape, const float* landmarks,
const tflite::RuntimeShape& input1_shape, const float* transform_matrix,
const tflite::RuntimeShape& output_shape, float* output_data) {
TFLITE_CHECK_EQ(input0_shape.DimensionsCount(), 4);
TFLITE_CHECK_EQ(output_shape.DimensionsCount(), 4);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
const int output_channels = output_shape.Dims(3);
TFLITE_CHECK_EQ(input0_shape.Dims(3) % params.dimensions, 0);
TFLITE_CHECK_NE(params.scale, 0);
tflite::RuntimeShape input_shape_with_batch{/*batch=*/1, input0_shape.Dims(1),
input0_shape.Dims(2),
input0_shape.Dims(3)};
tflite::RuntimeShape output_shape_with_batch{
/*batch=*/1, output_shape.Dims(1), output_shape.Dims(2),
output_shape.Dims(3)};
// Read first two rows of transformation matrix
tflite::gpu::float4 x_transform(transform_matrix[0], transform_matrix[1],
transform_matrix[2],
transform_matrix[3] * params.scale);
tflite::gpu::float4 y_transform(transform_matrix[4], transform_matrix[5],
transform_matrix[6],
transform_matrix[7] * params.scale);
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int landmark = 0; landmark < output_channels / params.dimensions;
++landmark) {
const int offset = Offset(output_shape_with_batch, 0, out_y, out_x,
landmark * params.dimensions);
if (params.dimensions == 2) {
tflite::gpu::float4 lv(landmarks[offset], landmarks[offset + 1],
static_cast<float>(0.0),
static_cast<float>(1.0));
tflite::gpu::float2 transformed(DotProduct(x_transform, lv),
DotProduct(y_transform, lv));
output_data[offset] = transformed.x;
output_data[offset + 1] = transformed.y;
}
if (params.dimensions == 3) {
tflite::gpu::float4 lv(landmarks[offset], landmarks[offset + 1],
static_cast<float>(0.0),
static_cast<float>(1.0));
tflite::gpu::float3 transformed(DotProduct(x_transform, lv),
DotProduct(y_transform, lv), lv.z);
output_data[offset] = transformed.x;
output_data[offset + 1] = transformed.y;
output_data[offset + 2] = landmarks[offset + 2];
}
}
}
}
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
const TfLiteTensor* input =
tflite::GetInput(context, node, kDataInput0Tensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
output_size->data[0] = input->dims->data[0];
output_size->data[1] = input->dims->data[1];
output_size->data[2] = input->dims->data[2];
output_size->data[3] = input->dims->data[3];
return context->ResizeTensor(context, output, output_size);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::gpu::TransformLandmarksAttributes op_params;
tflite::gpu::BHWC output_shape;
auto status = tflite::gpu::ParseTransformLandmarksV1Attributes(
node->custom_initial_data, node->custom_initial_data_size, &op_params,
&output_shape);
if (!status.ok()) {
context->ReportError(context, status.message().data());
return kTfLiteError;
}
if (op_params.dimensions != 3 && op_params.dimensions != 2) {
context->ReportError(context, "Incorrect dimensions size: %d",
op_params.dimensions);
return kTfLiteError;
}
if (op_params.scale == 0) {
context->ReportError(context, "Incorrect scale value: %d", op_params.scale);
return kTfLiteError;
}
const TfLiteTensor* input0 =
tflite::GetInput(context, node, kDataInput0Tensor);
TF_LITE_ENSURE(context, input0 != nullptr);
const TfLiteTensor* input1 =
tflite::GetInput(context, node, kDataInput1Tensor);
TF_LITE_ENSURE(context, input1 != nullptr);
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TransformLandmarks(
op_params, tflite::GetTensorShape(input0),
tflite::GetTensorData<float>(input0), tflite::GetTensorShape(input1),
tflite::GetTensorData<float>(input1), tflite::GetTensorShape(output),
tflite::GetTensorData<float>(output));
return kTfLiteOk;
}
} // namespace v1
namespace v2 {
inline void TransformLandmarksV2(
const tflite::gpu::TransformLandmarksAttributes& params,
const tflite::RuntimeShape& input0_shape, const float* landmarks,
const float* transform_matrix, // transformation matrix
const tflite::RuntimeShape& output_shape, float* output_data) {
TFLITE_CHECK_EQ(input0_shape.DimensionsCount(), 3);
TFLITE_CHECK_EQ(output_shape.DimensionsCount(), 3);
const int output_width = output_shape.Dims(1);
TFLITE_CHECK_EQ(input0_shape.Dims(2) % params.dimensions, 0);
tflite::RuntimeShape input_shape_with_batch{/*batch=*/1, input0_shape.Dims(0),
input0_shape.Dims(1),
input0_shape.Dims(2)};
tflite::RuntimeShape output_shape_with_batch{
/*batch=*/1, output_shape.Dims(0), output_shape.Dims(1),
output_shape.Dims(2)};
// Read first two rows of transformation matrix
tflite::gpu::float4 x_transform(transform_matrix[0], transform_matrix[1],
transform_matrix[2], transform_matrix[3]);
tflite::gpu::float4 y_transform(transform_matrix[4], transform_matrix[5],
transform_matrix[6], transform_matrix[7]);
for (int landmark = 0; landmark < output_width; ++landmark) {
const int offset = Offset(input_shape_with_batch, 0, 0, landmark, 0);
if (params.dimensions == 2) {
tflite::gpu::float4 lv(landmarks[offset], landmarks[offset + 1],
static_cast<float>(0.0), static_cast<float>(1.0));
tflite::gpu::float2 transformed(DotProduct(x_transform, lv),
DotProduct(y_transform, lv));
output_data[offset] = transformed.x;
output_data[offset + 1] = transformed.y;
}
if (params.dimensions == 3) {
tflite::gpu::float4 lv(landmarks[offset], landmarks[offset + 1],
static_cast<float>(0.0), static_cast<float>(1.0));
tflite::gpu::float3 transformed(DotProduct(x_transform, lv),
DotProduct(y_transform, lv), lv.z);
output_data[offset] = transformed.x;
output_data[offset + 1] = transformed.y;
output_data[offset + 2] = landmarks[offset + 2];
}
}
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
const TfLiteTensor* input =
tflite::GetInput(context, node, kDataInput0Tensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 3);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = input->dims->data[0];
output_size->data[1] = input->dims->data[1];
output_size->data[2] = input->dims->data[2];
return context->ResizeTensor(context, output, output_size);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::gpu::TransformLandmarksAttributes op_params;
TfLiteTensor* output = tflite::GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
tflite::RuntimeShape runtime_output_shape = tflite::GetTensorShape(output);
tflite::gpu::BHWC output_shape(1, runtime_output_shape.Dims(0),
runtime_output_shape.Dims(1),
runtime_output_shape.Dims(2));
auto status = tflite::gpu::ParseTransformLandmarksV2Attributes(
node->custom_initial_data, node->custom_initial_data_size, &op_params,
&output_shape);
if (!status.ok()) {
context->ReportError(context, status.message().data());
return kTfLiteError;
}
if (op_params.dimensions != 3 && op_params.dimensions != 2) {
context->ReportError(context, "Incorrect dimensions size: %d",
op_params.dimensions);
return kTfLiteError;
}
const TfLiteTensor* input0 =
tflite::GetInput(context, node, kDataInput0Tensor);
TF_LITE_ENSURE(context, input0 != nullptr);
const TfLiteTensor* input1 =
tflite::GetInput(context, node, kDataInput1Tensor);
TF_LITE_ENSURE(context, input1 != nullptr);
TransformLandmarksV2(op_params, tflite::GetTensorShape(input0),
tflite::GetTensorData<float>(input0),
tflite::GetTensorData<float>(input1),
tflite::GetTensorShape(output),
tflite::GetTensorData<float>(output));
return kTfLiteOk;
}
} // namespace v2
} // namespace
TfLiteRegistration* RegisterTransformLandmarksV1() {
static TfLiteRegistration reg = {
/*.init=*/nullptr,
/*.free=*/nullptr,
/*.prepare=*/v1::Prepare,
/*.invoke=*/v1::Eval,
/*.profiling_string=*/nullptr,
/*.builtin_code=*/tflite::BuiltinOperator_CUSTOM,
/*.custom_name=*/"TransformLandmarks",
/*.version=*/1,
};
return ®
}
TfLiteRegistration* RegisterTransformLandmarksV2() {
static TfLiteRegistration reg = {
/*.init=*/nullptr,
/*.free=*/nullptr,
/*.prepare=*/v2::Prepare,
/*.invoke=*/v2::Eval,
/*.profiling_string=*/nullptr,
/*.builtin_code=*/tflite::BuiltinOperator_CUSTOM,
/*.custom_name=*/"TransformLandmarks",
/*.version=*/2,
};
return ®
}
} // namespace tflite_operations
} // namespace mediapipe