// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/util/tflite/op_resolver.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api.h"
#include "tensorflow/lite/c/c_api_opaque.h"
namespace mediapipe {
namespace {
constexpr char kMaxPoolingWithArgmax2DOpName[] = "MaxPoolingWithArgmax2D";
constexpr int kMaxPoolingWithArgmax2DOpVersion = 1;
constexpr char kMaxUnpooling2DOpName[] = "MaxUnpooling2D";
constexpr int kMaxUnpooling2DOpVersion = 1;
constexpr char kConvolution2DTransposeBiasOpName[] =
"Convolution2DTransposeBias";
constexpr int kConvolution2DTransposeBiasOpVersion = 1;
TfLiteRegistration* RegisterMaxPoolingWithArgmax2D() {
static TfLiteOperator* reg_external = []() {
// Intentionally allocated and never destroyed.
auto* r = TfLiteOperatorCreate(
kTfLiteBuiltinCustom, kMaxPoolingWithArgmax2DOpName,
kMaxPoolingWithArgmax2DOpVersion, /*user_data=*/nullptr);
TfLiteOperatorSetInit(
r, [](TfLiteOpaqueContext*, const char*, size_t) -> void* {
return new TfLitePaddingValues();
});
TfLiteOperatorSetFree(r, [](TfLiteOpaqueContext*, void* buffer) -> void {
delete reinterpret_cast<TfLitePaddingValues*>(buffer);
});
TfLiteOperatorSetPrepare(
r,
[](TfLiteOpaqueContext* context,
TfLiteOpaqueNode* node) -> TfLiteStatus { return kTfLiteOk; });
TfLiteOperatorSetInvoke(
r, [](TfLiteOpaqueContext* context, TfLiteOpaqueNode*) -> TfLiteStatus {
TfLiteOpaqueContextReportError(
context, "MaxPoolingWithArgmax2D is only available on the GPU.");
return kTfLiteError;
});
return r;
}();
static TfLiteRegistration reg{};
reg.registration_external = reg_external;
return ®
}
TfLiteRegistration* RegisterMaxUnpooling2D() {
static TfLiteOperator* reg_external =
// Intentionally allocated and never destroyed.
TfLiteOperatorCreate(kTfLiteBuiltinCustom, kMaxUnpooling2DOpName,
kMaxUnpooling2DOpVersion,
/*user_data=*/nullptr);
static TfLiteRegistration reg{};
reg.registration_external = reg_external;
return ®
}
TfLiteRegistration* RegisterConvolution2DTransposeBias() {
static TfLiteOperator* reg_external =
// Intentionally allocated and never destroyed.
TfLiteOperatorCreate(
kTfLiteBuiltinCustom, kConvolution2DTransposeBiasOpName,
kConvolution2DTransposeBiasOpVersion, /*user_data=*/nullptr);
static TfLiteRegistration reg{};
reg.registration_external = reg_external;
return ®
}
} // namespace
OpResolver::OpResolver() {
AddCustom(kMaxPoolingWithArgmax2DOpName, RegisterMaxPoolingWithArgmax2D(),
kMaxPoolingWithArgmax2DOpVersion);
AddCustom(kMaxUnpooling2DOpName, RegisterMaxUnpooling2D(),
kMaxUnpooling2DOpVersion);
AddCustom(kConvolution2DTransposeBiasOpName,
RegisterConvolution2DTransposeBias(),
kConvolution2DTransposeBiasOpVersion);
}
} // namespace mediapipe