chromium/third_party/mediapipe/src/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/log/absl_check.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_split.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/calculators/tensorflow/tensorflow_inference_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/tensorflow_session.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/clock.h"
#include "mediapipe/framework/deps/monotonic_clock.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/map_util.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/status_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"

#if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__)
#include "tensorflow/core/profiler/lib/traceme.h"
#endif

namespace tf = ::tensorflow;

namespace mediapipe {

namespace {

constexpr char kRecurrentInitTensorsTag[] = "RECURRENT_INIT_TENSORS";
constexpr char kSessionTag[] = "SESSION";
constexpr char kSessionBundleTag[] = "SESSION_BUNDLE";

// A simple semaphore implemented with the C++ standard library. It is intended
// to be used only by TensorFlowInferenceCalculator to throttle concurrent
// calls to TensorFlow Session::Run. This is useful when multiple threads
// execute the graph (e.g. in a mapreduce-style job) without overloading the
// GPU/TPU/etc.
class SimpleSemaphore {
 public:
  explicit SimpleSemaphore(uint32_t initial_count) : count_(initial_count) {}
  SimpleSemaphore(const SimpleSemaphore&) = delete;
  SimpleSemaphore(SimpleSemaphore&&) = delete;

  // Acquires the semaphore by the given amount, blocking until it is
  // available.
  void Acquire(uint32_t amount) {
    mutex_.Lock();
    while (count_ < amount) {
      cond_.Wait(&mutex_);
    }
    count_ -= amount;
    mutex_.Unlock();
  }

  // Releases the semaphore by the given amount.
  void Release(uint32_t amount) {
    mutex_.Lock();
    count_ += amount;
    cond_.SignalAll();
    mutex_.Unlock();
  }

 private:
  uint32_t count_;
  absl::Mutex mutex_;
  absl::CondVar cond_;
};
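
// A minimal usage sketch (illustrative only; it mirrors how
// get_session_run_throttle() below uses the class):
//
//   SimpleSemaphore throttle(/*initial_count=*/2);  // allow 2 concurrent runs
//   throttle.Acquire(1);
//   // ... call Session::Run ...
//   throttle.Release(1);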

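// Accumulates the input tensors and their timestamps collected across
// Process() calls until a full batch is ready to be handed to OutputBatch().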
class InferenceState {
 public:
  InferenceState() : input_tensor_batches_(), batch_timestamps_() {}
  // A mapping between stream tags and the tensors we are collecting as a
  // batch.
  std::map<std::string, std::vector<tf::Tensor>> input_tensor_batches_;
  // The timestamps that go into a batch.
  std::vector<Timestamp> batch_timestamps_;
};

}  // namespace

// This calculator performs inference on a trained TensorFlow model.
//
// TensorFlow Sessions can be created from checkpoint paths, frozen models, or
// the SavedModel system. See the TensorFlowSessionFrom* packet generators for
// details. Each of these methods defines a mapping between MediaPipe streams
// and TensorFlow tensors. All of this information is passed in as an
// input_side_packet.
//
// The input and output streams are TensorFlow tensors labeled by tags. The tags
// for the streams are matched to feeds and fetches in a TensorFlow session
// using a named_signature.generic_signature in the ModelManifest. The
// generic_signature is used as key-value pairs between the MediaPipe tag and
// the TensorFlow tensor. The signature_name in the options proto determines
// which named_signature is used. The keys in the generic_signature must be
// valid MediaPipe tags ([A-Z0-9_]*, no lowercase or special characters). All of
// the tensors corresponding to tags in the signature for input_streams are fed
// to the model and for output_streams the tensors are fetched from the model.
//
// Other calculators are used to convert data to and from tensors; this
// calculator only handles the TensorFlow session and batching. Batching works
// by concatenating
// input tensors along the 0th dimension across timestamps. If the 0th dimension
// is not a batch dimension, this calculator will add a 0th dimension by
// default. Setting add_batch_dim_to_tensors to false disables the dimension
// addition. Once batch_size inputs have been provided, the batch will be run
// and the output tensors sent out on the output streams with timestamps
// corresponding to the input stream packets. Setting the batch_size to 1
// completely disables batching, but is independent of add_batch_dim_to_tensors.
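//
// For example, to batch 32 timestamps per Session::Run (a sketch; the stream
// names here are illustrative):
//
//   node {
//     calculator: "TensorFlowInferenceCalculator"
//     input_stream: "IMAGES:image_tensors"
//     output_stream: "LABELS:label_tensors"
//     input_side_packet: "SESSION:tensorflow_session"
//     options {
//       [mediapipe.TensorFlowInferenceCalculatorOptions.ext]: {
//         batch_size: 32
//       }
//     }
//   }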
//
// The TensorFlowInferenceCalculator also supports feeding states recurrently
// for RNNs and LSTMs. Simply set the recurrent_tag_pair option to define the
// recurrent tensors. Initializing the recurrent state can be handled by the
// GraphTensorsPacketGenerator.
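//
// For example (tag names here are hypothetical; in each pair the first tag is
// the feed and the second is the fetch whose output is fed back at the next
// timestamp):
//
//   options {
//     [mediapipe.TensorFlowInferenceCalculatorOptions.ext]: {
//       batch_size: 1
//       recurrent_tag_pair: "LSTM_C_IN:LSTM_C_OUT"
//       recurrent_tag_pair: "LSTM_H_IN:LSTM_H_OUT"
//     }
//   }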
//
// The calculator updates two Counters to report timing information:
//   --<name>-TotalTimeUsecs = Total time spent running inference (in usecs),
//   --<name>-TotalProcessedTimestamps = # of instances processed
//         (approximately batches processed * batch_size),
// where <name> is replaced with CalculatorGraphConfig::Node::name() if it
// exists, or with TensorFlowInferenceCalculator if the name is not set. The
// name must be set for timing information to be instance-specific in graphs
// with multiple TensorFlowInferenceCalculators.
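//
// For example, a node named "inference" reports counters
// --inference-TotalTimeUsecs and --inference-TotalProcessedTimestamps.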
//
// Example config:
//   packet_generator {
//     packet_generator: "TensorFlowSessionFromSavedModelGenerator"
//     output_side_packet: "tensorflow_session"
//     options {
//       [mediapipe.TensorFlowSessionFromSavedModelGeneratorOptions.ext]: {
//         saved_model_path: "/path/to/saved/model"
//         signature_name: "mediapipe"
//       }
//     }
//   }
//   node {
//     calculator: "TensorFlowInferenceCalculator"
//     input_stream: "IMAGES:image_tensors_keyed_in_signature_by_tag"
//     input_stream: "AUDIO:audio_tensors_keyed_in_signature_by_tag"
//     output_stream: "LABELS:softmax_tensor_keyed_in_signature_by_tag"
//     input_side_packet: "SESSION:tensorflow_session"
//   }
//
//   Where the input and output streams are treated as Packet<tf::Tensor> and
//   the mediapipe_signature has tensor bindings between "IMAGES", "AUDIO", and
//   "LABELS" and their respective tensors exported to /path/to/saved/model.
//   For an example of how this model was exported, see
//   tensorflow_inference_test_graph_generator.py
//
// It is possible to use a GraphDef proto that was not exported by the exporter
// (i.e. without a MetaGraph with bindings). Such a GraphDef can contain all of
// its parameters inlined (for example, it can be the output of
// freeze_graph.py). To instantiate a TensorFlow model from a GraphDef file,
// replace the packet_generator above with
// TensorFlowSessionFromFrozenGraphGenerator:
//
//   packet_generator {
//     packet_generator: "TensorFlowSessionFromFrozenGraphGenerator"
//     output_side_packet: "SESSION:tensorflow_session"
//     options {
//       [mediapipe.TensorFlowSessionFromFrozenGraphGeneratorOptions.ext]: {
//         graph_proto_path: "[PATH]"
//         tag_to_tensor_names {
//           key: "JPG_STRING"
//           value: "input:0"
//         }
//         tag_to_tensor_names {
//           key: "SOFTMAX"
//           value: "softmax:0"
//         }
//       }
//     }
//   }
//
// It is also possible to use a GraphDef proto and checkpoint file that have not
// been frozen. This can be used to load graphs directly as they have been
// written from training. However, it is more brittle, and you are encouraged to
// use one of the more permanent formats described above. To instantiate a
// TensorFlow model from a GraphDef file and checkpoint, replace the
// packet_generator above with TensorFlowSessionFromModelCheckpointGenerator:
//
//   packet_generator {
//     packet_generator: "TensorFlowSessionFromModelCheckpointGenerator"
//     output_side_packet: "SESSION:tensorflow_session"
//     options {
//       [mediapipe.TensorFlowSessionFromModelCheckpointGeneratorOptions.ext]: {
//         graph_proto_path: "[PATH]"
//         model_options {
//           checkpoint_path: "[PATH2]"
//         }
//         tag_to_tensor_names {
//           key: "JPG_STRING"
//           value: "input:0"
//         }
//         tag_to_tensor_names {
//           key: "SOFTMAX"
//           value: "softmax:0"
//         }
//       }
//     }
//   }
class TensorFlowInferenceCalculator : public CalculatorBase {
 public:
  // Counters for recording timing information. The actual names have the value
  // of CalculatorGraphConfig::Node::name() prepended.
  static constexpr char kTotalUsecsCounterSuffix[] = "TotalTimeUsecs";
  static constexpr char kTotalProcessedTimestampsCounterSuffix[] =
      "TotalProcessedTimestamps";
  static constexpr char kTotalSessionRunsTimeUsecsCounterSuffix[] =
      "TotalSessionRunsTimeUsecs";
  static constexpr char kTotalNumSessionRunsCounterSuffix[] =
      "TotalNumSessionRuns";

  TensorFlowInferenceCalculator() : session_(nullptr) {
    clock_ = std::unique_ptr<mediapipe::Clock>(
        mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock());
  }

  static absl::Status GetContract(CalculatorContract* cc) {
    const auto& options = cc->Options<TensorFlowInferenceCalculatorOptions>();
    RET_CHECK(!cc->Inputs().GetTags().empty());
    for (const std::string& tag : cc->Inputs().GetTags()) {
      // The tensorflow::Tensor with the tag equal to the graph node. May
      // have a TimeSeriesHeader if all present TimeSeriesHeaders match.
      if (!options.batched_input()) {
        cc->Inputs().Tag(tag).Set<tf::Tensor>();
      } else {
        cc->Inputs().Tag(tag).Set<std::vector<mediapipe::Packet>>();
      }
    }
    RET_CHECK(!cc->Outputs().GetTags().empty());
    for (const std::string& tag : cc->Outputs().GetTags()) {
      // The tensorflow::Tensor with tag equal to the graph node to
      // output.  Any TimeSeriesHeader from the inputs will be forwarded
      // with channels set to 0.
      cc->Outputs().Tag(tag).Set<tf::Tensor>();
    }
    // A mediapipe::TensorFlowSession with a model loaded and ready for use.
    // For this calculator it must include a tag_to_tensor_map.
    cc->InputSidePackets().Tag(kSessionTag).Set<TensorFlowSession>();
    if (cc->InputSidePackets().HasTag(kRecurrentInitTensorsTag)) {
      cc->InputSidePackets()
          .Tag(kRecurrentInitTensorsTag)
          .Set<std::unique_ptr<std::map<std::string, tf::Tensor>>>();
    }
    return absl::OkStatus();
  }

  std::unique_ptr<InferenceState> CreateInferenceState(CalculatorContext* cc)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
    std::unique_ptr<InferenceState> inference_state =
        absl::make_unique<InferenceState>();
    if (cc->InputSidePackets().HasTag(kRecurrentInitTensorsTag) &&
        !cc->InputSidePackets().Tag(kRecurrentInitTensorsTag).IsEmpty()) {
      std::map<std::string, tf::Tensor>* init_tensor_map;
      init_tensor_map = GetFromUniquePtr<std::map<std::string, tf::Tensor>>(
          cc->InputSidePackets().Tag(kRecurrentInitTensorsTag));
      for (const auto& p : *init_tensor_map) {
        inference_state->input_tensor_batches_[p.first].emplace_back(p.second);
      }
    }
    return inference_state;
  }

  absl::Status Open(CalculatorContext* cc) override {
    options_ = cc->Options<TensorFlowInferenceCalculatorOptions>();

    RET_CHECK(cc->InputSidePackets().HasTag(kSessionTag));
    session_ = cc->InputSidePackets()
                   .Tag(kSessionTag)
                   .Get<TensorFlowSession>()
                   .session.get();
    tag_to_tensor_map_ = cc->InputSidePackets()
                             .Tag(kSessionTag)
                             .Get<TensorFlowSession>()
                             .tag_to_tensor_map;

    // Validate and store the recurrent tags.
    RET_CHECK(options_.has_batch_size());
    RET_CHECK(options_.batch_size() == 1 ||
              options_.recurrent_tag_pair().empty())
        << "To use recurrent_tag_pairs, batch_size must be 1.";

    // Helper for StrJoin. Prints only the key (tag) of each
    // map<string, string> entry.
    auto TagFormatter =
        absl::PairFormatter(absl::StreamFormatter(), "",
                            [](std::string* out, const std::string& second) {});

    for (const auto& tag_pair : options_.recurrent_tag_pair()) {
      const std::vector<std::string> tags = absl::StrSplit(tag_pair, ':');
      RET_CHECK_EQ(tags.size(), 2) << "recurrent_tag_pair must be a colon "
                                      "separated string with two components: "
                                   << tag_pair;

      RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0]))
          << "Can't find tag '" << tags[0] << "' in signature "
          << options_.signature_name() << "; instead found tags "
          << absl::StrJoin(tag_to_tensor_map_, ", ", TagFormatter);
      RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[1]))
          << "Can't find tag '" << tags[1] << "' in signature "
          << options_.signature_name() << "; instead found tags "
          << absl::StrJoin(tag_to_tensor_map_, ", ", TagFormatter);
      recurrent_feed_tags_.insert(tags[0]);
      recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0];
    }

    // Check that all tags are present in this signature bound to tensors.
    for (const std::string& tag : cc->Inputs().GetTags()) {
      RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
          << "Can't find tag '" << tag << "' in signature "
          << options_.signature_name() << "; instead found tags "
          << absl::StrJoin(tag_to_tensor_map_, ", ", TagFormatter);
    }
    for (const std::string& tag : cc->Outputs().GetTags()) {
      RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
          << "Can't find tag '" << tag << "' in signature "
          << options_.signature_name() << "; instead found tags "
          << absl::StrJoin(tag_to_tensor_map_, ", ", TagFormatter);
    }

    {
      absl::WriterMutexLock l(&mutex_);
      inference_state_ = nullptr;
    }

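    // With batch_size 1 or pre-batched input, each input timestamp yields
    // output at the same timestamp, so declare a timestamp offset of 0 to let
    // the framework propagate timestamp bounds immediately.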
    if (options_.batch_size() == 1 || options_.batched_input()) {
      cc->SetOffset(0);
    }

    return absl::OkStatus();
  }

  // Adds a batch dimension to the input tensor if specified in the calculator
  // options.
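  // For example, a tensor of shape [H, W, C] becomes [1, H, W, C].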
  absl::Status AddBatchDimension(tf::Tensor* input_tensor) {
    if (options_.add_batch_dim_to_tensors()) {
      tf::TensorShape new_shape(input_tensor->shape());
      new_shape.InsertDim(0, 1);
      RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape))
          << "Could not add 0th dimension to tensor without changing its shape."
          << " Current shape: " << input_tensor->shape().DebugString();
    }
    return absl::OkStatus();
  }

  absl::Status AggregateTensorPacket(
      const std::string& tag_name, const Packet& packet,
      std::map<Timestamp, std::map<std::string, tf::Tensor>>*
          input_tensors_by_tag_by_timestamp,
      InferenceState* inference_state) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
    tf::Tensor input_tensor(packet.Get<tf::Tensor>());
    RET_CHECK_OK(AddBatchDimension(&input_tensor));
    if (mediapipe::ContainsKey(recurrent_feed_tags_, tag_name)) {
      // If we receive an input on a recurrent tag, override the state.
      // It's OK to override the global state because there is just one
      // input stream allowed for recurrent tensors.
      inference_state->input_tensor_batches_[tag_name].clear();
    }
    (*input_tensors_by_tag_by_timestamp)[packet.Timestamp()].insert(
        std::make_pair(tag_name, input_tensor));
    return absl::OkStatus();
  }

  // Removes the batch dimension of the output tensor if specified in the
  // calculator options.
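  // For example, a tensor of shape [1, N] becomes [N].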
  absl::Status RemoveBatchDimension(tf::Tensor* output_tensor) {
    if (options_.add_batch_dim_to_tensors()) {
      tf::TensorShape new_shape(output_tensor->shape());
      new_shape.RemoveDim(0);
      RET_CHECK(output_tensor->CopyFrom(*output_tensor, new_shape))
          << "Could not remove 0th dimension from tensor without changing its "
          << "shape. Current shape: " << output_tensor->shape().DebugString()
          << " (The expected first dimension is 1 for a batch element.)";
    }
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    std::unique_ptr<InferenceState> inference_state_to_process;
    {
      absl::WriterMutexLock l(&mutex_);
      if (inference_state_ == nullptr) {
        inference_state_ = CreateInferenceState(cc);
      }
      std::map<Timestamp, std::map<std::string, tf::Tensor>>
          input_tensors_by_tag_by_timestamp;
      for (const std::string& tag_as_node_name : cc->Inputs().GetTags()) {
        if (cc->Inputs().Tag(tag_as_node_name).IsEmpty()) {
          // Recurrent tensors can be empty.
          if (!mediapipe::ContainsKey(recurrent_feed_tags_, tag_as_node_name)) {
            if (options_.skip_on_missing_features()) {
              return absl::OkStatus();
            } else {
              return absl::InvalidArgumentError(absl::StrCat(
                  "Tag ", tag_as_node_name,
                  " not present at timestamp: ", cc->InputTimestamp().Value()));
            }
          }
        } else if (options_.batched_input()) {
          const auto& tensor_packets =
              cc->Inputs().Tag(tag_as_node_name).Get<std::vector<Packet>>();
          if (tensor_packets.size() > options_.batch_size()) {
            return absl::InvalidArgumentError(absl::StrCat(
                "Batch for tag ", tag_as_node_name,
                " has more packets than batch capacity. batch_size: ",
                options_.batch_size(), " packets: ", tensor_packets.size()));
          }
          for (const auto& packet : tensor_packets) {
            RET_CHECK_OK(AggregateTensorPacket(
                tag_as_node_name, packet, &input_tensors_by_tag_by_timestamp,
                inference_state_.get()));
          }
        } else {
          RET_CHECK_OK(AggregateTensorPacket(
              tag_as_node_name, cc->Inputs().Tag(tag_as_node_name).Value(),
              &input_tensors_by_tag_by_timestamp, inference_state_.get()));
        }
      }
      for (const auto& timestamp_and_input_tensors_by_tag :
           input_tensors_by_tag_by_timestamp) {
        inference_state_->batch_timestamps_.emplace_back(
            timestamp_and_input_tensors_by_tag.first);
        for (const auto& input_tensor_and_tag :
             timestamp_and_input_tensors_by_tag.second) {
          inference_state_->input_tensor_batches_[input_tensor_and_tag.first]
              .emplace_back(input_tensor_and_tag.second);
        }
      }
      if (inference_state_->batch_timestamps_.size() == options_.batch_size() ||
          options_.batched_input()) {
        inference_state_to_process = std::move(inference_state_);
        inference_state_ = nullptr;
      }
    }

    if (inference_state_to_process) {
      MP_RETURN_IF_ERROR(
          OutputBatch(cc, std::move(inference_state_to_process)));
    }

    return absl::OkStatus();
  }

  absl::Status Close(CalculatorContext* cc) override {
    std::unique_ptr<InferenceState> inference_state_to_process = nullptr;
    {
      absl::WriterMutexLock l(&mutex_);
      if (cc->GraphStatus().ok() && inference_state_ != nullptr &&
          !inference_state_->batch_timestamps_.empty()) {
        inference_state_to_process = std::move(inference_state_);
        inference_state_ = nullptr;
      }
    }
    if (inference_state_to_process) {
      MP_RETURN_IF_ERROR(
          OutputBatch(cc, std::move(inference_state_to_process)));
    }
    return absl::OkStatus();
  }

  // When a batch of input tensors is ready to be run, runs TensorFlow and
  // outputs the output tensors. The output tensors have timestamps matching
  // the input tensor that formed that batch element. Any requested batch
  // dimension is added and removed. This code takes advantage of the fact
  // that copying a tensor shares the same reference-counted, heap-allocated
  // memory buffer, so copies are cheap and do not let the buffer go out of
  // scope. Concat, in contrast, performs a deep copy and is therefore used
  // only where necessary.
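  //
  // For example (illustrative shapes): with batch_size 2 and per-timestamp
  // inputs of shape [1, d] after AddBatchDimension, Concat produces a [2, d]
  // feed; after Session::Run, Split returns one [1, d] tensor per timestamp,
  // and RemoveBatchDimension restores shape [d] before output.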
  absl::Status OutputBatch(CalculatorContext* cc,
                           std::unique_ptr<InferenceState> inference_state) {
    const int64_t start_time = absl::ToUnixMicros(clock_->TimeNow());
    std::vector<std::pair<mediapipe::ProtoString, tf::Tensor>> input_tensors;

    for (auto& keyed_tensors : inference_state->input_tensor_batches_) {
      if (options_.batch_size() == 1) {
        // Short circuit to avoid the cost of deep copying tensors in concat.
        if (!keyed_tensors.second.empty()) {
          input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first],
                                     keyed_tensors.second[0]);
        } else {
          // The input buffer can be empty for recurrent tensors.
          RET_CHECK(
              mediapipe::ContainsKey(recurrent_feed_tags_, keyed_tensors.first))
              << "A non-recurrent tensor does not have an input: "
              << keyed_tensors.first;
        }
      } else {
        if (options_.pad_to_batch_size()) {
          // Pad by replicating the first tensor, then ignore the values.
          keyed_tensors.second.resize(options_.batch_size());
          std::fill(keyed_tensors.second.begin() +
                        inference_state->batch_timestamps_.size(),
                    keyed_tensors.second.end(), keyed_tensors.second[0]);
        }
        tf::Tensor concated;
        const tf::Status concat_status =
            tf::tensor::Concat(keyed_tensors.second, &concated);
        ABSL_CHECK(concat_status.ok()) << concat_status.ToString();
        input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first],
                                   concated);
      }
    }
    inference_state->input_tensor_batches_.clear();
    std::vector<mediapipe::ProtoString> output_tensor_names;
    std::vector<std::string> output_name_in_signature;
    for (const std::string& tag : cc->Outputs().GetTags()) {
      output_tensor_names.emplace_back(tag_to_tensor_map_[tag]);
      output_name_in_signature.emplace_back(tag);
    }
    for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) {
      // Ensure that we always fetch the recurrent state tensors.
      if (std::find(output_name_in_signature.begin(),
                    output_name_in_signature.end(),
                    tag_pair.first) == output_name_in_signature.end()) {
        output_tensor_names.emplace_back(tag_to_tensor_map_[tag_pair.first]);
        output_name_in_signature.emplace_back(tag_pair.first);
      }
    }
    std::vector<tf::Tensor> outputs;

    SimpleSemaphore* session_run_throttle = nullptr;
    if (options_.max_concurrent_session_runs() > 0) {
      session_run_throttle =
          get_session_run_throttle(options_.max_concurrent_session_runs());
      session_run_throttle->Acquire(1);
    }
    const int64_t run_start_time = absl::ToUnixMicros(clock_->TimeNow());
    tf::Status tf_status;
    {
#if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__)
      tsl::profiler::TraceMe trace(absl::string_view(cc->NodeName()));
#endif
      tf_status = session_->Run(input_tensors, output_tensor_names,
                                {} /* target_node_names */, &outputs);
    }

    if (session_run_throttle != nullptr) {
      session_run_throttle->Release(1);
    }

    // RET_CHECK on the tf::Status object itself in order to print an
    // informative error message.
    RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString();

    const int64_t run_end_time = absl::ToUnixMicros(clock_->TimeNow());
    cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix)
        ->IncrementBy(run_end_time - run_start_time);
    cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment();

    // Feed back the recurrent state.
    for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) {
      int pos = std::find(output_name_in_signature.begin(),
                          output_name_in_signature.end(), tag_pair.first) -
                output_name_in_signature.begin();
      inference_state->input_tensor_batches_[tag_pair.second].emplace_back(
          outputs[pos]);
    }

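    // Hold the state lock for the remainder of this function; the recurrent
    // state may be stored back into inference_state_ below.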
    absl::WriterMutexLock l(&mutex_);
    // Set that we want to split on each index of the 0th dimension.
    std::vector<int64_t> split_vector(
        options_.pad_to_batch_size()
            ? options_.batch_size()
            : inference_state->batch_timestamps_.size(),
        1);
    for (int i = 0; i < output_tensor_names.size(); ++i) {
      if (options_.batch_size() == 1) {
        if (cc->Outputs().HasTag(output_name_in_signature[i])) {
          tf::Tensor output_tensor(outputs[i]);
          RET_CHECK_OK(RemoveBatchDimension(&output_tensor));
          cc->Outputs()
              .Tag(output_name_in_signature[i])
              .Add(new tf::Tensor(output_tensor),
                   inference_state->batch_timestamps_[0]);
        }
      } else {
        std::vector<tf::Tensor> split_tensors;
        const tf::Status split_status =
            tf::tensor::Split(outputs[i], split_vector, &split_tensors);
        ABSL_CHECK(split_status.ok()) << split_status.ToString();
        // Loop over timestamps so that we don't copy the padding.
        for (int j = 0; j < inference_state->batch_timestamps_.size(); ++j) {
          tf::Tensor output_tensor(split_tensors[j]);
          RET_CHECK_OK(RemoveBatchDimension(&output_tensor));
          cc->Outputs()
              .Tag(output_name_in_signature[i])
              .Add(new tf::Tensor(output_tensor),
                   inference_state->batch_timestamps_[j]);
        }
      }
    }

    // Get end time and report.
    const int64_t end_time = absl::ToUnixMicros(clock_->TimeNow());
    cc->GetCounter(kTotalUsecsCounterSuffix)
        ->IncrementBy(end_time - start_time);
    cc->GetCounter(kTotalProcessedTimestampsCounterSuffix)
        ->IncrementBy(inference_state->batch_timestamps_.size());

    // Make sure we hold on to the recurrent state.
    if (!options_.recurrent_tag_pair().empty()) {
      inference_state_ = std::move(inference_state);
      inference_state_->batch_timestamps_.clear();
    }

    return absl::OkStatus();
  }

 private:
  // The Session object is provided by a packet generator and is owned by the
  // MediaPipe framework. Individual calls are thread-safe, but session state
  // may be shared across threads.
  tf::Session* session_;

  // A mapping between stream tags and the tensor names they are bound to.
  std::map<std::string, std::string> tag_to_tensor_map_;

  absl::Mutex mutex_;
  std::unique_ptr<InferenceState> inference_state_ ABSL_GUARDED_BY(mutex_);

  // The options for the calculator.
  TensorFlowInferenceCalculatorOptions options_;

  // Store the feed and fetch tags for feed/fetch recurrent networks.
  std::set<std::string> recurrent_feed_tags_;
  std::map<std::string, std::string> recurrent_fetch_tags_to_feed_tags_;

  // Clock used to measure the computation time in OutputBatch().
  std::unique_ptr<mediapipe::Clock> clock_;

  // Returns the lazily-created singleton semaphore used to throttle concurrent
  // session runs. Note that the count is fixed by the first call; later calls
  // passing a different max_concurrent_session_runs reuse the same semaphore.
  static SimpleSemaphore* get_session_run_throttle(
      int32_t max_concurrent_session_runs) {
    static SimpleSemaphore* session_run_throttle =
        new SimpleSemaphore(max_concurrent_session_runs);
    return session_run_throttle;
  }
};
REGISTER_CALCULATOR(TensorFlowInferenceCalculator);

constexpr char TensorFlowInferenceCalculator::kTotalUsecsCounterSuffix[];
constexpr char
    TensorFlowInferenceCalculator::kTotalProcessedTimestampsCounterSuffix[];
constexpr char
    TensorFlowInferenceCalculator::kTotalSessionRunsTimeUsecsCounterSuffix[];
constexpr char
    TensorFlowInferenceCalculator::kTotalNumSessionRunsCounterSuffix[];
}  // namespace mediapipe