// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cctype>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#if !defined(__ANDROID__)
#include "mediapipe/framework/port/file_helpers.h"
#endif
#include "absl/log/absl_log.h"
#include "absl/strings/str_replace.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
namespace mediapipe {
namespace {
// Tag of the output side packet that carries the loaded TensorFlowSession.
constexpr char kSessionTag[] = "SESSION";
// Tag of the optional input side packet that provides the SavedModel path as a
// std::string (alternative to the `saved_model_path` option).
constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";
// Given the path to a directory containing multiple TensorFlow SavedModels in
// subdirectories, replaces `*path` with the alphabetically last subdirectory in
// which file::MatchInTopSubdirectories found a SavedModel proto file.
// Returns an error if no such subdirectory exists. On Android this is not
// supported and always returns UnimplementedError.
absl::Status GetLatestDirectory(std::string* path) {
#if defined(__ANDROID__)
  return absl::UnimplementedError(
      "GetLatestDirectory is not implemented on Android");
#else
  std::vector<std::string> saved_models;
  RET_CHECK_OK(file::MatchInTopSubdirectories(
      *path, tensorflow::kSavedModelFilenamePb, &saved_models));
  // Stream the directory itself (`*path`); streaming `path` would print the
  // pointer value instead.
  RET_CHECK(!saved_models.empty())
      << "No exported bundles found in " << *path;
  // Only the alphabetically last match is needed, so a linear scan suffices
  // instead of sorting the whole list.
  const auto latest =
      std::max_element(saved_models.begin(), saved_models.end());
  *path = std::string(file::Dirname(*latest));
  return absl::OkStatus();
#endif
}
// If options.convert_signature_to_tags() is set, converts letters to uppercase
// and replaces '/', '-', '.' and ':' with '_'. This enables the standard
// SavedModel classification, regression, and prediction signatures to be used
// as uppercase INPUTS and OUTPUTS tags for streams and supports other common
// patterns. Otherwise `name` is returned unchanged.
//
// Returns by (non-const) value so callers can move the result, e.g. into a map
// key.
std::string MaybeConvertSignatureToTag(
    const std::string& name,
    const TensorFlowSessionFromSavedModelCalculatorOptions& options) {
  if (!options.convert_signature_to_tags()) {
    return name;
  }
  std::string output;
  output.reserve(name.size());
  for (const unsigned char c : name) {
    // Pass an unsigned char value: std::toupper is undefined for negative
    // values that a plain (signed) char may hold.
    output.push_back(static_cast<char>(std::toupper(c)));
  }
  output = absl::StrReplaceAll(
      output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
  ABSL_LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
  return output;
}
} // namespace
// TensorFlowSessionFromSavedModelCalculator is a MediaPipe packet calculator
// that loads a trained TensorFlow model exported via SavedModel's exporter and
// emits a SESSION output side packet holding a mediapipe::TensorFlowSession,
// which in turn contains a TensorFlow Session ready for execution and a map
// between tags and tensor names.
//
// The SavedModel directory comes from exactly one of:
//   * the `saved_model_path` option, or
//   * the STRING_SAVED_MODEL_PATH input side packet (a std::string).
// If `load_latest_model` is set, the alphabetically last subdirectory of that
// path containing a SavedModel is loaded instead. If no `saved_model_tag` is
// given, tensorflow::kSavedModelTagServe is used. `signature_name` is required
// and must name a signature present in the loaded model.
//
// Example usage:
// node {
//   calculator: "TensorFlowSessionFromSavedModelCalculator"
//   output_side_packet: "SESSION:vod_session"
//   options {
//     [mediapipe.TensorFlowSessionFromSavedModelCalculatorOptions.ext]: {
//       signature_name: "serving_default"
//       saved_model_path: "path/to/model"
//     }
//   }
// }
class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    const auto& options =
        cc->Options<TensorFlowSessionFromSavedModelCalculatorOptions>();
    const bool has_path_side_packet =
        cc->InputSidePackets().HasTag(kStringSavedModelPath);
    // Exactly one source: an empty option requires the side packet, and a
    // non-empty option forbids it.
    const bool has_exactly_one_model =
        options.saved_model_path().empty() == has_path_side_packet;
    RET_CHECK(has_exactly_one_model)
        << "Must have exactly one of saved model filepath in options or "
           "input_side_packets "
        << kStringSavedModelPath;
    // Path of the SavedModel directory.
    if (has_path_side_packet) {
      cc->InputSidePackets().Tag(kStringSavedModelPath).Set<std::string>();
    }
    // The loaded TensorFlow session along with its tag-to-tensor-name map.
    cc->OutputSidePackets().Tag(kSessionTag).Set<TensorFlowSession>();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override {
    const auto& options =
        cc->Options<TensorFlowSessionFromSavedModelCalculatorOptions>();
    // Validate cheap preconditions before the potentially expensive load.
    RET_CHECK(!options.signature_name().empty())
        << "signature_name must be set in the calculator options";

    std::string path = cc->InputSidePackets().HasTag(kStringSavedModelPath)
                           ? cc->InputSidePackets()
                                 .Tag(kStringSavedModelPath)
                                 .Get<std::string>()
                           : options.saved_model_path();
    if (options.load_latest_model()) {
      RET_CHECK_OK(GetLatestDirectory(&path));
    }

    // Use the user-specified tags; fall back to kSavedModelTagServe if none.
    std::unordered_set<std::string> tags_set(options.saved_model_tag().begin(),
                                             options.saved_model_tag().end());
    if (tags_set.empty()) {
      tags_set.insert(tensorflow::kSavedModelTagServe);
    }

    tensorflow::RunOptions run_options;
    tensorflow::SessionOptions session_options;
    session_options.config = options.session_config();
    auto saved_model = std::make_unique<tensorflow::SavedModelBundle>();
    const ::tensorflow::Status status = tensorflow::LoadSavedModel(
        session_options, run_options, path, tags_set, saved_model.get());
    if (!status.ok()) {
      return absl::Status(static_cast<absl::StatusCode>(status.code()),
                          status.ToString());
    }

    // Look the signature up explicitly: protobuf Map::at() CHECK-fails (aborts
    // the process) on a missing key instead of reporting an error status.
    const auto& signature_def_map = saved_model->meta_graph_def.signature_def();
    const auto signature_it = signature_def_map.find(options.signature_name());
    RET_CHECK(signature_it != signature_def_map.end())
        << "Signature \"" << options.signature_name()
        << "\" not found in SavedModel at " << path;
    const auto& signature_def = signature_it->second;

    auto session = std::make_unique<TensorFlowSession>();
    session->session = std::move(saved_model->session);
    for (const auto& input_signature : signature_def.inputs()) {
      session->tag_to_tensor_map[MaybeConvertSignatureToTag(
          input_signature.first, options)] = input_signature.second.name();
    }
    for (const auto& output_signature : signature_def.outputs()) {
      session->tag_to_tensor_map[MaybeConvertSignatureToTag(
          output_signature.first, options)] = output_signature.second.name();
    }
    cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release()));
    return absl::OkStatus();
  }

  // All work happens in Open(); there is nothing to do per input packet.
  absl::Status Process(CalculatorContext* cc) override {
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(TensorFlowSessionFromSavedModelCalculator);
} // namespace mediapipe