chromium/third_party/mediapipe/src/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cctype>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_replace.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.pb.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/packet_generator.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/status_util.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"

#if !defined(__ANDROID__)
#include "mediapipe/framework/port/file_helpers.h"
#endif

namespace mediapipe {

namespace {

// Tag of the output side packet that carries the loaded TensorFlowSession.
constexpr char kSessionTag[] = "SESSION";

// Tag of the optional std::string input side packet holding the SavedModel
// directory. Exactly one of this packet or options.saved_model_path() must be
// provided.
constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH";

// Tag of the optional std::string input side packet that, when present,
// overrides options.signature_name().
constexpr char kStringSignatureName[] = "STRING_SIGNATURE_NAME";

// Given the path to a directory containing multiple TensorFlow SavedModels in
// subdirectories, searches those top-level subdirectories for SavedModel
// protobuf files (tensorflow::kSavedModelFilenamePb) and replaces `*path` with
// the directory of the lexicographically greatest match, i.e. the
// alphabetically last subdirectory that contains a SavedModel.
//
// Returns an error if the search fails or finds no SavedModel. On Android the
// directory search is unavailable, so UnimplementedError is returned and
// `*path` is left untouched.
absl::Status GetLatestDirectory(std::string* path) {
#if defined(__ANDROID__)
  return absl::UnimplementedError(
      "GetLatestDirectory is not implemented on Android");
#else
  std::vector<std::string> saved_models;
  RET_CHECK_OK(file::MatchInTopSubdirectories(
      *path, tensorflow::kSavedModelFilenamePb, &saved_models));
  // Stream the string itself; streaming `path` would print the pointer value.
  RET_CHECK(!saved_models.empty())
      << "No exported bundles found in " << *path;
  // Only the greatest path is needed, so a linear scan replaces a full sort.
  const std::string& latest =
      *std::max_element(saved_models.begin(), saved_models.end());
  *path = std::string(file::Dirname(latest));
  return absl::OkStatus();
#endif
}

// If options.convert_signature_to_tags() is set, converts a SavedModel
// signature key into a stream tag: letters are uppercased (via std::toupper)
// and every '/', '-', '.', and ':' is replaced with '_'. For example,
// "serving_default:0" becomes "SERVING_DEFAULT_0". This enables the standard
// SavedModel classification, regression, and prediction signatures to be used
// as uppercase INPUTS and OUTPUTS tags for streams and supports other common
// patterns. Otherwise `name` is returned unchanged.
//
// Returns by non-const value so the result can be moved into its destination.
std::string MaybeConvertSignatureToTag(
    const std::string& name,
    const TensorFlowSessionFromSavedModelGeneratorOptions& options) {
  if (!options.convert_signature_to_tags()) {
    return name;
  }
  std::string output = name;
  // Go through unsigned char: std::toupper is undefined for negative values.
  std::transform(output.begin(), output.end(), output.begin(),
                 [](unsigned char c) {
                   return static_cast<char>(std::toupper(c));
                 });
  output = absl::StrReplaceAll(
      output, {{"/", "_"}, {"-", "_"}, {".", "_"}, {":", "_"}});
  ABSL_LOG(INFO) << "Renamed TAG from: " << name << " to " << output;
  return output;
}

}  // namespace

// TensorFlowSessionFromSavedModelGenerator is a MediaPipe packet generator
// that loads a trained TensorFlow model exported via SavedModel's exporter and
// outputs a SESSION side packet owning a mediapipe::TensorFlowSession, which in
// turn contains a TensorFlow Session ready for execution and a map between
// tags and tensor names.
//
// Input side packets (both optional, both std::string):
//   STRING_SAVED_MODEL_PATH - SavedModel directory. Exactly one of this packet
//                             or options.saved_model_path() must be provided.
//   STRING_SIGNATURE_NAME   - overrides options.signature_name().
// Output side packets:
//   SESSION - the loaded TensorFlowSession.
class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator {
 public:
  // Declares the side-packet types and verifies that exactly one source for
  // the SavedModel path is configured.
  static absl::Status FillExpectations(
      const PacketGeneratorOptions& extendable_options,
      PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) {
    const TensorFlowSessionFromSavedModelGeneratorOptions& options =
        extendable_options.GetExtension(
            TensorFlowSessionFromSavedModelGeneratorOptions::ext);
    // Exactly one model source: the options path must be empty if and only
    // if the side packet is present.
    const bool has_exactly_one_model =
        options.saved_model_path().empty() ==
        input_side_packets->HasTag(kStringSavedModelPath);
    RET_CHECK(has_exactly_one_model)
        << "Must have exactly one of saved model filepath in options or "
           "input_side_packets "
        << kStringSavedModelPath;
    // Path of the SavedModel directory.
    if (input_side_packets->HasTag(kStringSavedModelPath)) {
      input_side_packets->Tag(kStringSavedModelPath).Set<std::string>();
    }
    // Name of the signature def to expose.
    if (input_side_packets->HasTag(kStringSignatureName)) {
      input_side_packets->Tag(kStringSignatureName).Set<std::string>();
    }
    // The loaded TensorFlow session together with its tag-to-tensor-name map.
    output_side_packets->Tag(kSessionTag).Set<TensorFlowSession>();
    return absl::OkStatus();
  }

  // Loads the SavedModel, resolves the requested signature, and emits the
  // resulting TensorFlowSession on the SESSION output side packet.
  //
  // Possible errors:
  //   - the TensorFlow load error, with its status code preserved, if the
  //     model cannot be loaded;
  //   - a RET_CHECK failure if no signature name is configured;
  //   - NotFoundError if the signature does not exist in the model.
  static absl::Status Generate(const PacketGeneratorOptions& extendable_options,
                               const PacketSet& input_side_packets,
                               PacketSet* output_side_packets) {
    const TensorFlowSessionFromSavedModelGeneratorOptions& options =
        extendable_options.GetExtension(
            TensorFlowSessionFromSavedModelGeneratorOptions::ext);
    std::string path =
        input_side_packets.HasTag(kStringSavedModelPath)
            ? input_side_packets.Tag(kStringSavedModelPath).Get<std::string>()
            : options.saved_model_path();
    if (options.load_latest_model()) {
      RET_CHECK_OK(GetLatestDirectory(&path));
    }

    // Use the user-specified tags. If none are given, fall back to
    // tensorflow::kSavedModelTagServe.
    std::unordered_set<std::string> tags_set(options.saved_model_tag().begin(),
                                             options.saved_model_tag().end());
    if (tags_set.empty()) {
      tags_set.insert(tensorflow::kSavedModelTagServe);
    }

    tensorflow::RunOptions run_options;
    tensorflow::SessionOptions session_options;
    session_options.config = options.session_config();
    auto saved_model = std::make_unique<tensorflow::SavedModelBundle>();
    const ::tensorflow::Status status = tensorflow::LoadSavedModel(
        session_options, run_options, path, tags_set, saved_model.get());
    if (!status.ok()) {
      // Carry the TensorFlow error code over into the returned absl::Status.
      return absl::Status(static_cast<absl::StatusCode>(status.code()),
                          status.ToString());
    }

    // Use the input side packet, if present, to override the signature name
    // from the options.
    const std::string signature_name =
        input_side_packets.HasTag(kStringSignatureName)
            ? input_side_packets.Tag(kStringSignatureName).Get<std::string>()
            : options.signature_name();
    RET_CHECK(!signature_name.empty())
        << "A signature name must be set in options or via "
        << kStringSignatureName;
    const auto& signature_def_map = saved_model->meta_graph_def.signature_def();
    const auto signature_it = signature_def_map.find(signature_name);
    if (signature_it == signature_def_map.end()) {
      return absl::NotFoundError(absl::StrFormat(
          "Signature name '%s' does not exist in the loaded signature def",
          signature_name));
    }
    const auto& signature_def = signature_it->second;

    auto session = std::make_unique<TensorFlowSession>();
    session->session = std::move(saved_model->session);
    // Inputs are inserted first, so an output whose (possibly converted) tag
    // collides with an input tag overwrites that input's tensor name.
    for (const auto& input_signature : signature_def.inputs()) {
      session->tag_to_tensor_map[MaybeConvertSignatureToTag(
          input_signature.first, options)] = input_signature.second.name();
    }
    for (const auto& output_signature : signature_def.outputs()) {
      session->tag_to_tensor_map[MaybeConvertSignatureToTag(
          output_signature.first, options)] = output_signature.second.name();
    }

    output_side_packets->Tag(kSessionTag) = Adopt(session.release());
    return absl::OkStatus();
  }
};
REGISTER_PACKET_GENERATOR(TensorFlowSessionFromSavedModelGenerator);

}  // namespace mediapipe