// chromium/third_party/mediapipe/src/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Reads serialized GraphDef proto. There are three ways to load a model:
// 1. Specify the path to a graph.pb in the calculator options.
// 2. Specify the path to the graph.pb through the
// input_side_packet:STRING_MODEL_FILE_PATH
// 3. Provide a serialized GraphDef through input_side_packet:STRING_MODEL,
// typically provided by EmbeddingFilePacketFactory.
//
// Produces a SessionBundle that TensorFlowInferenceCalculator can use.

#include <string>

#include "absl/log/absl_log.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/clock.h"
#include "mediapipe/framework/deps/monotonic_clock.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/status_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/public/session_options.h"

#if defined(MEDIAPIPE_MOBILE)
#include "mediapipe/util/android/file/base/helpers.h"
#else
#include "mediapipe/framework/port/file_helpers.h"
#endif

namespace mediapipe {

namespace tf = ::tensorflow;

namespace {

// Side-packet tags used in GetContract()/Open() below: SESSION is the output
// side packet; STRING_MODEL and STRING_MODEL_FILE_PATH are the two optional
// input side packets that can supply the model.
constexpr char kSessionTag[] = "SESSION";
constexpr char kStringModelFilePathTag[] = "STRING_MODEL_FILE_PATH";
constexpr char kStringModelTag[] = "STRING_MODEL";

// Assigns `device_id` to every node in `graph_def` that does not already
// carry an explicit device placement; nodes with a device are left untouched.
void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) {
  const std::string device(device_id);
  for (tf::NodeDef& node : *graph_def->mutable_node()) {
    if (!node.device().empty()) continue;
    node.set_device(device);
  }
}
}  // namespace

class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase {
 public:
  // Declares the calculator's side-packet interface and validates that
  // exactly one model source is configured.
  static absl::Status GetContract(CalculatorContract* cc) {
    const auto& options =
        cc->Options<TensorFlowSessionFromFrozenGraphCalculatorOptions>();
    // Exactly one of {options.graph_proto_path, STRING_MODEL,
    // STRING_MODEL_FILE_PATH} may supply the model. (Use logical ||, not
    // bitwise |, when combining the boolean HasTag results.)
    const bool has_exactly_one_model =
        !options.graph_proto_path().empty()
            ? !(cc->InputSidePackets().HasTag(kStringModelTag) ||
                cc->InputSidePackets().HasTag(kStringModelFilePathTag))
            : (cc->InputSidePackets().HasTag(kStringModelTag) ^
               cc->InputSidePackets().HasTag(kStringModelFilePathTag));
    RET_CHECK(has_exactly_one_model)
        << "Must have exactly one of graph_proto_path in options or "
           "input_side_packets STRING_MODEL or STRING_MODEL_FILE_PATH";
    if (cc->InputSidePackets().HasTag(kStringModelTag)) {
      cc->InputSidePackets()
          .Tag(kStringModelTag)
          .Set<std::string>(
              // String model from embedded path
          );
    } else if (cc->InputSidePackets().HasTag(kStringModelFilePathTag)) {
      cc->InputSidePackets()
          .Tag(kStringModelFilePathTag)
          .Set<std::string>(
              // Filename of string model.
          );
    }
    cc->OutputSidePackets()
        .Tag(kSessionTag)
        .Set<TensorFlowSession>(
            // A TensorFlow model loaded and ready for use along with
            // a map from tags to tensor names.
        );
    // Without at least one tag-to-tensor mapping the resulting session would
    // be unusable by TensorFlowInferenceCalculator.
    RET_CHECK_GT(options.tag_to_tensor_names().size(), 0);
    return absl::OkStatus();
  }

  // Loads the serialized GraphDef from whichever source GetContract()
  // validated, creates a TensorFlow session from it, runs any configured
  // initialization ops, and emits the TensorFlowSession as the SESSION
  // output side packet. Logs the total load time in microseconds.
  absl::Status Open(CalculatorContext* cc) override {
    auto clock = std::unique_ptr<mediapipe::Clock>(
        mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock());
    const uint64_t start_time = absl::ToUnixMicros(clock->TimeNow());
    const auto& options =
        cc->Options<TensorFlowSessionFromFrozenGraphCalculatorOptions>();
    // Output bundle packet.
    auto session = absl::make_unique<TensorFlowSession>();

    tf::SessionOptions session_options;
    session_options.config.CopyFrom(options.config());
    std::vector<mediapipe::ProtoString> initialization_op_names;
    initialization_op_names.reserve(options.initialization_op_names_size());
    for (int i = 0; i < options.initialization_op_names_size(); ++i) {
      initialization_op_names.emplace_back(options.initialization_op_names(i));
    }
    session->session.reset(tf::NewSession(session_options));

    // Resolve the serialized GraphDef: prefer the STRING_MODEL side packet,
    // then the STRING_MODEL_FILE_PATH side packet, then the options path.
    std::string graph_def_serialized;
    if (cc->InputSidePackets().HasTag(kStringModelTag)) {
      graph_def_serialized =
          cc->InputSidePackets().Tag(kStringModelTag).Get<std::string>();
    } else if (cc->InputSidePackets().HasTag(kStringModelFilePathTag)) {
      const std::string& frozen_graph = cc->InputSidePackets()
                                            .Tag(kStringModelFilePathTag)
                                            .Get<std::string>();
      RET_CHECK_OK(
          mediapipe::file::GetContents(frozen_graph, &graph_def_serialized));
    } else {
      RET_CHECK_OK(mediapipe::file::GetContents(options.graph_proto_path(),
                                                &graph_def_serialized));
    }
    tensorflow::GraphDef graph_def;

    RET_CHECK(graph_def.ParseFromString(graph_def_serialized));

    // Update the graph nodes to use the preferred device, if set.
    if (!options.preferred_device_id().empty()) {
      SetPreferredDevice(&graph_def, options.preferred_device_id());
    }

    // RET_CHECK on the tf::Status object itself in order to print an
    // informative error message. (Distinct names for the two statuses avoid
    // shadowing.)
    const tf::Status create_status = session->session->Create(graph_def);
    RET_CHECK(create_status.ok())
        << "Create failed: " << create_status.ToString();

    for (const auto& key_value : options.tag_to_tensor_names()) {
      session->tag_to_tensor_map[key_value.first] = key_value.second;
    }
    if (!initialization_op_names.empty()) {
      const tf::Status run_status =
          session->session->Run({}, {}, initialization_op_names, {});
      RET_CHECK(run_status.ok()) << "Run failed: " << run_status.ToString();
    }

    cc->OutputSidePackets().Tag(kSessionTag).Set(Adopt(session.release()));
    const uint64_t end_time = absl::ToUnixMicros(clock->TimeNow());
    ABSL_LOG(INFO) << "Loaded frozen model in: " << end_time - start_time
                   << " microseconds.";
    return absl::OkStatus();
  }

  // All work happens in Open(); there is no per-packet processing.
  absl::Status Process(CalculatorContext* cc) override {
    return absl::OkStatus();
  }
};
REGISTER_CALCULATOR(TensorFlowSessionFromFrozenGraphCalculator);

}  // namespace mediapipe