chromium/third_party/mediapipe/src/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/strip.h"
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/formats/location_opencv.h"
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/sequence/media_sequence.h"
#include "mediapipe/util/sequence/media_sequence_util.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"

namespace mediapipe {

const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
const char kImageTag[] = "IMAGE";
const char kImageLabelPrefixTag[] = "IMAGE_LABEL_";
const char kClipLabelPrefixTag[] = "CLIP_LABEL_";
const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_";
const char kIntsContextFeaturePrefixTag[] = "INTS_CONTEXT_FEATURE_";
const char kBytesContextFeaturePrefixTag[] = "BYTES_CONTEXT_FEATURE_";
const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_";
const char kIntFeaturePrefixTag[] = "INT_FEATURE_";
const char kBytesFeaturePrefixTag[] = "BYTES_FEATURE_";
const char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED";
const char kBBoxTag[] = "BBOX";
const char kKeypointsTag[] = "KEYPOINTS";
const char kSegmentationMaskTag[] = "CLASS_SEGMENTATION";
const char kClipMediaIdTag[] = "CLIP_MEDIA_ID";

namespace tf = ::tensorflow;
namespace mpms = mediapipe::mediasequence;

// Sink calculator to package streams into tf.SequenceExamples.
//
// The calculator takes a tf.SequenceExample as a side input and then adds
// the data from inputs to the SequenceExample with timestamps. Additional
// context features can be supplied verbatim in the calculator's options. The
// SequenceExample will conform to the description in media_sequence.h.
//
// The supported input stream tags are:
// * "IMAGE", which stores the encoded images from the
//   OpenCVImageEncoderCalculator,
// * "IMAGE_LABEL", which stores whole image labels from Detection,
// * "FORWARD_FLOW_ENCODED", which stores the encoded optical flow from the same
//   calculator,
// * "BBOX" which stores bounding boxes from vector<Detections>,
// * streams with the "FLOAT_FEATURE_${NAME}" pattern, which stores the values
//   from vector<float>'s associated with the name ${NAME},
// * "KEYPOINTS" stores a map of 2D keypoints from flat_hash_map<string,
//   vector<pair<float, float>>>,
// * "CLIP_MEDIA_ID", which stores the clip's media ID as a string.
// * "CLIP_LABEL_${NAME}" which stores sparse feature labels, ID and scores in
//   mediapipe::Detection. In the input Detection, the score field is required,
//   and label and label_id are optional but at least one of them should be set.
// "IMAGE_${NAME}", "BBOX_${NAME}", and "KEYPOINTS_${NAME}" will also store
// prefixed versions of each stream, which allows for multiple image streams to
// be included. However, the default names are supported by more tools.
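// For example, a hypothetical depth stream supplied as
// input_stream: "IMAGE_DEPTH:depth_frames" would have its frames stored under
// the "DEPTH" prefix.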
//
// Example config:
// node {
//   calculator: "PackMediaSequenceCalculator"
//   input_side_packet: "SEQUENCE_EXAMPLE:example_input_side_packet"
//   input_stream: "IMAGE:frames"
//   input_stream: "FLOAT_FEATURE_FDENSE:fdense_vf"
//   output_stream: "SEQUENCE_EXAMPLE:example_output_stream"
//   options {
//     [mediapipe.PackMediaSequenceCalculatorOptions.ext]: {
//       context_feature_map {
//         feature {
//           key: "image/frames_per_second"
//           value {
//             float_list {
//               value: 30.0
//             }
//           }
//         }
//       }
//     }
//   }
// }
namespace {
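// Maps a float in [0, 1] to a byte in [0, 255], rounding to the nearest value.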
uint8_t ConvertFloatToByte(const float float_value) {
  float clamped_value = std::clamp(float_value, 0.0f, 1.0f);
  return static_cast<uint8_t>(clamped_value * 255.0f + 0.5f);
}
}  // namespace

class PackMediaSequenceCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    RET_CHECK(cc->InputSidePackets().HasTag(kSequenceExampleTag));
    cc->InputSidePackets().Tag(kSequenceExampleTag).Set<tf::SequenceExample>();
    if (cc->InputSidePackets().HasTag(kClipMediaIdTag)) {
      cc->InputSidePackets().Tag(kClipMediaIdTag).Set<std::string>();
    }

    if (cc->Inputs().HasTag(kForwardFlowEncodedTag)) {
      cc->Inputs()
          .Tag(kForwardFlowEncodedTag)
          .Set<OpenCvImageEncoderCalculatorResults>();
    }
    if (cc->Inputs().HasTag(kSegmentationMaskTag)) {
      cc->Inputs().Tag(kSegmentationMaskTag).Set<std::vector<Detection>>();
    }

    for (const auto& tag : cc->Inputs().GetTags()) {
      if (absl::StartsWith(tag, kImageTag)) {
        if (absl::StartsWith(tag, kImageLabelPrefixTag)) {
          cc->Inputs().Tag(tag).Set<Detection>();
          continue;
        }
        std::string key = "";
        if (tag != kImageTag) {
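          // sizeof includes the trailing '\0', so subtract one to get the
          // length of the tag prefix.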
          int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1;
          if (tag[tag_length] == '_') {
            key = tag.substr(tag_length + 1);
          } else {
            continue;  // Skip keys that don't match "(kImageTag)_?"
          }
        }
        cc->Inputs().Tag(tag).Set<OpenCvImageEncoderCalculatorResults>();
      }
      if (absl::StartsWith(tag, kKeypointsTag)) {
        std::string key = "";
        if (tag != kKeypointsTag) {
          int tag_length = sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1;
          if (tag[tag_length] == '_') {
            key = tag.substr(tag_length + 1);
          } else {
            continue;  // Skip keys that don't match "(kKeypointsTag)_?"
          }
        }
        cc->Inputs()
            .Tag(tag)
            .Set<absl::flat_hash_map<std::string,
                                     std::vector<std::pair<float, float>>>>();
      }
      if (absl::StartsWith(tag, kBBoxTag)) {
        std::string key = "";
        if (tag != kBBoxTag) {
          int tag_length = sizeof(kBBoxTag) / sizeof(*kBBoxTag) - 1;
          if (tag[tag_length] == '_') {
            key = tag.substr(tag_length + 1);
          } else {
            continue;  // Skip keys that don't match "(kBBoxTag)_?"
          }
        }
        cc->Inputs().Tag(tag).Set<std::vector<Detection>>();
      }
      if (absl::StartsWith(tag, kClipLabelPrefixTag)) {
        cc->Inputs().Tag(tag).Set<Detection>();
      }
      if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag)) {
        cc->Inputs().Tag(tag).Set<std::vector<float>>();
      }
      if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag)) {
        cc->Inputs().Tag(tag).Set<std::vector<int64_t>>();
      }
      if (absl::StartsWith(tag, kBytesContextFeaturePrefixTag)) {
        cc->Inputs().Tag(tag).Set<std::vector<std::string>>();
      }
      if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
        cc->Inputs().Tag(tag).Set<std::vector<float>>();
      }
      if (absl::StartsWith(tag, kIntFeaturePrefixTag)) {
        cc->Inputs().Tag(tag).Set<std::vector<int64_t>>();
      }
      if (absl::StartsWith(tag, kBytesFeaturePrefixTag)) {
        cc->Inputs().Tag(tag).Set<std::vector<std::string>>();
      }
    }

    RET_CHECK(cc->Outputs().HasTag(kSequenceExampleTag) ||
              cc->OutputSidePackets().HasTag(kSequenceExampleTag))
        << "Neither the output stream nor the output side packet is set to "
           "output the sequence example.";
    if (cc->Outputs().HasTag(kSequenceExampleTag)) {
      cc->Outputs().Tag(kSequenceExampleTag).Set<tf::SequenceExample>();
    }
    if (cc->OutputSidePackets().HasTag(kSequenceExampleTag)) {
      cc->OutputSidePackets()
          .Tag(kSequenceExampleTag)
          .Set<tf::SequenceExample>();
    }
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override {
    sequence_ = ::absl::make_unique<tf::SequenceExample>(
        cc->InputSidePackets()
            .Tag(kSequenceExampleTag)
            .Get<tf::SequenceExample>());
    if (cc->InputSidePackets().HasTag(kClipMediaIdTag) &&
        !cc->InputSidePackets().Tag(kClipMediaIdTag).IsEmpty()) {
      clip_media_id_ =
          cc->InputSidePackets().Tag(kClipMediaIdTag).Get<std::string>();
    }

    const auto& context_features =
        cc->Options<PackMediaSequenceCalculatorOptions>().context_feature_map();
    for (const auto& feature : context_features.feature()) {
      *mpms::MutableContext(feature.first, sequence_.get()) = feature.second;
    }
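    // Track which input streams deliver at least one packet; consulted by
    // VerifySequence() when output_only_if_all_present is set.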
    for (const auto& tag : cc->Inputs().GetTags()) {
      features_present_[tag] = false;
    }

    replace_keypoints_ = false;
    if (cc->Options<PackMediaSequenceCalculatorOptions>()
            .replace_data_instead_of_append()) {
      // Clear the existing values under the same key.
      for (const auto& tag : cc->Inputs().GetTags()) {
        if (absl::StartsWith(tag, kImageTag)) {
          if (absl::StartsWith(tag, kImageLabelPrefixTag)) {
            std::string key =
                std::string(absl::StripPrefix(tag, kImageLabelPrefixTag));
            mpms::ClearImageLabelString(key, sequence_.get());
            mpms::ClearImageLabelConfidence(key, sequence_.get());
            if (!key.empty() || mpms::HasImageEncoded(*sequence_)) {
              mpms::ClearImageTimestamp(key, sequence_.get());
            }
            continue;
          }
          std::string key = "";
          if (tag != kImageTag) {
            int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1;
            if (tag[tag_length] == '_') {
              key = tag.substr(tag_length + 1);
            } else {
              continue;  // Skip keys that don't match "(kImageTag)_?"
            }
          }
          mpms::ClearImageEncoded(key, sequence_.get());
          mpms::ClearImageTimestamp(key, sequence_.get());
        }
        if (absl::StartsWith(tag, kBBoxTag)) {
          std::string key = "";
          if (tag != kBBoxTag) {
            int tag_length = sizeof(kBBoxTag) / sizeof(*kBBoxTag) - 1;
            if (tag[tag_length] == '_') {
              key = tag.substr(tag_length + 1);
            } else {
              continue;  // Skip keys that don't match "(kBBoxTag)_?"
            }
          }
          mpms::ClearBBox(key, sequence_.get());
          mpms::ClearBBoxTimestamp(key, sequence_.get());
          mpms::ClearBBoxIsAnnotated(key, sequence_.get());
          mpms::ClearBBoxNumRegions(key, sequence_.get());
          mpms::ClearBBoxLabelString(key, sequence_.get());
          mpms::ClearBBoxLabelIndex(key, sequence_.get());
          mpms::ClearBBoxLabelConfidence(key, sequence_.get());
          mpms::ClearBBoxClassString(key, sequence_.get());
          mpms::ClearBBoxClassIndex(key, sequence_.get());
          mpms::ClearBBoxTrackString(key, sequence_.get());
          mpms::ClearBBoxTrackIndex(key, sequence_.get());
          mpms::ClearUnmodifiedBBoxTimestamp(key, sequence_.get());
        }
        if (absl::StartsWith(tag, kClipLabelPrefixTag)) {
          const std::string& key = tag.substr(
              sizeof(kClipLabelPrefixTag) / sizeof(*kClipLabelPrefixTag) - 1);
          mpms::ClearClipLabelIndex(key, sequence_.get());
          mpms::ClearClipLabelString(key, sequence_.get());
          mpms::ClearClipLabelConfidence(key, sequence_.get());
        }
        if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag)) {
          const std::string& key =
              tag.substr(sizeof(kFloatContextFeaturePrefixTag) /
                             sizeof(*kFloatContextFeaturePrefixTag) -
                         1);
          mpms::ClearContextFeatureFloats(key, sequence_.get());
        }
        if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag)) {
          const std::string& key =
              tag.substr(sizeof(kIntsContextFeaturePrefixTag) /
                             sizeof(*kIntsContextFeaturePrefixTag) -
                         1);
          mpms::ClearContextFeatureInts(key, sequence_.get());
        }
        if (absl::StartsWith(tag, kBytesContextFeaturePrefixTag)) {
          const std::string& key =
              tag.substr(sizeof(kBytesContextFeaturePrefixTag) /
                             sizeof(*kBytesContextFeaturePrefixTag) -
                         1);
          mpms::ClearContextFeatureBytes(key, sequence_.get());
        }
        if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
          std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
                                           sizeof(*kFloatFeaturePrefixTag) -
                                       1);
          mpms::ClearFeatureFloats(key, sequence_.get());
          mpms::ClearFeatureTimestamp(key, sequence_.get());
        }
        if (absl::StartsWith(tag, kIntFeaturePrefixTag)) {
          std::string key = tag.substr(
              sizeof(kIntFeaturePrefixTag) / sizeof(*kIntFeaturePrefixTag) - 1);
          mpms::ClearFeatureInts(key, sequence_.get());
          mpms::ClearFeatureTimestamp(key, sequence_.get());
        }
        if (absl::StartsWith(tag, kBytesFeaturePrefixTag)) {
          std::string key = tag.substr(sizeof(kBytesFeaturePrefixTag) /
                                           sizeof(*kBytesFeaturePrefixTag) -
                                       1);
          mpms::ClearFeatureBytes(key, sequence_.get());
          mpms::ClearFeatureTimestamp(key, sequence_.get());
        }
        if (absl::StartsWith(tag, kKeypointsTag)) {
          // Keypoints are stored under per-point prefixes that are only known
          // in Process(), so clearing is deferred to the first packet.
          replace_keypoints_ = true;
        }
      }
      if (cc->Inputs().HasTag(kForwardFlowEncodedTag)) {
        mpms::ClearForwardFlowEncoded(sequence_.get());
        mpms::ClearForwardFlowTimestamp(sequence_.get());
      }
    }

    return absl::OkStatus();
  }

  absl::Status VerifySequence() {
    std::string error_msg = "Missing features - ";
    bool all_present = true;
    for (const auto& iter : features_present_) {
      if (!iter.second) {
        all_present = false;
        absl::StrAppend(&error_msg, iter.first, ", ");
      }
    }
    if (all_present) {
      return absl::OkStatus();
    } else {
      return ::mediapipe::NotFoundErrorBuilder(MEDIAPIPE_LOC) << error_msg;
    }
  }

  absl::Status VerifySize() {
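    // 2^30 - 1 bytes, well below the 2 GB protobuf serialization limit.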
    const int64_t MAX_PROTO_BYTES = 1073741823;
    std::string id = mpms::HasExampleId(*sequence_)
                         ? mpms::GetExampleId(*sequence_)
                         : "example";
    RET_CHECK_LT(sequence_->ByteSizeLong(), MAX_PROTO_BYTES)
        << "sequence '" << id
        << "' would be too many bytes to serialize after adding features.";
    return absl::OkStatus();
  }

  absl::Status Close(CalculatorContext* cc) override {
    auto& options = cc->Options<PackMediaSequenceCalculatorOptions>();
    if (options.reconcile_metadata()) {
      RET_CHECK_OK(mpms::ReconcileMetadata(
          options.reconcile_bbox_annotations(),
          options.reconcile_region_annotations(), sequence_.get()));
    }

    if (options.skip_large_sequences()) {
      RET_CHECK_OK(VerifySize());
    }
    if (options.output_only_if_all_present()) {
      absl::Status status = VerifySequence();
      if (!status.ok()) {
        cc->GetCounter(status.ToString())->Increment();
        return status;
      }
    }

    if (cc->OutputSidePackets().HasTag(kSequenceExampleTag)) {
      cc->OutputSidePackets()
          .Tag(kSequenceExampleTag)
          .Set(MakePacket<tensorflow::SequenceExample>(*sequence_));
    }
    if (cc->Outputs().HasTag(kSequenceExampleTag)) {
      cc->Outputs()
          .Tag(kSequenceExampleTag)
          .Add(sequence_.release(), options.output_as_zero_timestamp()
                                        ? Timestamp(0ll)
                                        : Timestamp::PostStream());
    }
    sequence_.reset();

    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    int image_height = -1;
    int image_width = -1;
    // Because the tag order may vary, we need to loop through tags to get
    // image information before processing other tag types.
    for (const auto& tag : cc->Inputs().GetTags()) {
      if (!cc->Inputs().Tag(tag).IsEmpty()) {
        features_present_[tag] = true;
      }
      if (absl::StartsWith(tag, kImageTag) &&
          !cc->Inputs().Tag(tag).IsEmpty()) {
        std::string key = "";
        if (absl::StartsWith(tag, kImageLabelPrefixTag)) {
          std::string key =
              std::string(absl::StripPrefix(tag, kImageLabelPrefixTag));
          const auto& detection = cc->Inputs().Tag(tag).Get<Detection>();
          if (detection.label().empty()) continue;
          RET_CHECK(detection.label_size() == detection.score_size())
              << "Wrong image label data format: " << detection.label_size()
              << " vs " << detection.score_size();
          if (!detection.label_id().empty()) {
            RET_CHECK(detection.label_id_size() == detection.label_size())
                << "Wrong image label ID format: " << detection.label_id_size()
                << " vs " << detection.label_size();
          }
          std::vector<std::string> labels(detection.label().begin(),
                                          detection.label().end());
          std::vector<float> confidences(detection.score().begin(),
                                         detection.score().end());
          std::vector<int32_t> ids(detection.label_id().begin(),
                                   detection.label_id().end());
          if (!key.empty() || mpms::HasImageEncoded(*sequence_)) {
            mpms::AddImageTimestamp(key, cc->InputTimestamp().Value(),
                                    sequence_.get());
          }
          mpms::AddImageLabelString(key, labels, sequence_.get());
          mpms::AddImageLabelConfidence(key, confidences, sequence_.get());
          if (!ids.empty()) mpms::AddImageLabelIndex(key, ids, sequence_.get());
          continue;
        }
        if (tag != kImageTag) {
          int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1;
          if (tag[tag_length] == '_') {
            key = tag.substr(tag_length + 1);
          } else {
            continue;  // Skip keys that don't match "(kImageTag)_?"
          }
        }
        const OpenCvImageEncoderCalculatorResults& image =
            cc->Inputs().Tag(tag).Get<OpenCvImageEncoderCalculatorResults>();
        if (!image.has_encoded_image()) {
          return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
                 << "No encoded image";
        }
        image_height = image.height();
        image_width = image.width();
        mpms::AddImageTimestamp(key, cc->InputTimestamp().Value(),
                                sequence_.get());
        mpms::AddImageEncoded(key, image.encoded_image(), sequence_.get());
      }
    }
    for (const auto& tag : cc->Inputs().GetTags()) {
      if (!cc->Inputs().Tag(tag).IsEmpty()) {
        features_present_[tag] = true;
      }
      if (absl::StartsWith(tag, kKeypointsTag) &&
          !cc->Inputs().Tag(tag).IsEmpty()) {
        std::string key = "";
        if (tag != kKeypointsTag) {
          int tag_length = sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1;
          if (tag[tag_length] == '_') {
            key = tag.substr(tag_length + 1);
          } else {
            continue;  // Skip keys that don't match "(kKeypointsTag)_?"
          }
        }
        const auto& keypoints =
            cc->Inputs()
                .Tag(tag)
                .Get<absl::flat_hash_map<
                    std::string, std::vector<std::pair<float, float>>>>();
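        // Each named set of keypoints is stored as point annotations under a
        // prefix derived from the stream key and the map key.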
        for (const auto& pair : keypoints) {
          std::string prefix = mpms::merge_prefix(key, pair.first);
          if (replace_keypoints_) {
            mpms::ClearBBoxPoint(prefix, sequence_.get());
            mpms::ClearBBoxTimestamp(prefix, sequence_.get());
            mpms::ClearBBoxIsAnnotated(prefix, sequence_.get());
            mpms::ClearBBoxNumRegions(prefix, sequence_.get());
            mpms::ClearBBoxLabelString(prefix, sequence_.get());
            mpms::ClearBBoxLabelIndex(prefix, sequence_.get());
            mpms::ClearBBoxLabelConfidence(prefix, sequence_.get());
            mpms::ClearBBoxClassString(prefix, sequence_.get());
            mpms::ClearBBoxClassIndex(prefix, sequence_.get());
            mpms::ClearBBoxTrackString(prefix, sequence_.get());
            mpms::ClearBBoxTrackIndex(prefix, sequence_.get());
            mpms::ClearUnmodifiedBBoxTimestamp(prefix, sequence_.get());
          }
          mpms::AddBBoxTimestamp(prefix, cc->InputTimestamp().Value(),
                                 sequence_.get());
          mpms::AddBBoxPoint(prefix, pair.second, sequence_.get());
        }
        replace_keypoints_ = false;
      }
      if (absl::StartsWith(tag, kClipLabelPrefixTag) &&
          !cc->Inputs().Tag(tag).IsEmpty()) {
        const std::string& key = tag.substr(
            sizeof(kClipLabelPrefixTag) / sizeof(*kClipLabelPrefixTag) - 1);
        const Detection& detection = cc->Inputs().Tag(tag).Get<Detection>();
        bool add_empty_labels =
            cc->Options<PackMediaSequenceCalculatorOptions>()
                .add_empty_labels();
        if (detection.score().empty()) {
          if (add_empty_labels) {
            mpms::SetClipLabelString(key, {}, sequence_.get());
            mpms::SetClipLabelConfidence(key, {}, sequence_.get());
          }
          continue;
        }
        if (detection.label().empty() && detection.label_id().empty()) {
          return absl::InvalidArgumentError(
              "detection.label and detection.label_id can't be both empty");
        }
        // Allow empty label (for indexed feature inputs), but if label is not
        // empty, it should have the same size as the score field.
        if (!detection.label().empty()) {
          if (detection.label().size() != detection.score().size()) {
            return absl::InvalidArgumentError(
                "Different size of detection.label and detection.score");
          }
        }
        // Allow empty label_ids, but if label_ids is not empty, it should have
        // the same size as the score field.
        if (!detection.label_id().empty()) {
          if (detection.label_id().size() != detection.score().size()) {
            return absl::InvalidArgumentError(
                "Different size of detection.label_id and detection.score");
          }
        }
        for (int i = 0; i < detection.score().size(); ++i) {
          if (!detection.label_id().empty()) {
            mpms::AddClipLabelIndex(key, detection.label_id(i),
                                    sequence_.get());
          }
          if (!detection.label().empty()) {
            mpms::AddClipLabelString(key, detection.label(i), sequence_.get());
          }
          mpms::AddClipLabelConfidence(key, detection.score(i),
                                       sequence_.get());
        }
      }
      if (absl::StartsWith(tag, kFloatContextFeaturePrefixTag) &&
          !cc->Inputs().Tag(tag).IsEmpty()) {
        std::string key =
            tag.substr(sizeof(kFloatContextFeaturePrefixTag) /
                           sizeof(*kFloatContextFeaturePrefixTag) -
                       1);
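        // To ensure only one packet is provided for this tag.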
        RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::PostStream());
        for (const auto& value :
             cc->Inputs().Tag(tag).Get<std::vector<float>>()) {
          mpms::AddContextFeatureFloats(key, value, sequence_.get());
        }
      }
      if (absl::StartsWith(tag, kIntsContextFeaturePrefixTag) &&
          !cc->Inputs().Tag(tag).IsEmpty()) {
        const std::string& key =
            tag.substr(sizeof(kIntsContextFeaturePrefixTag) /
                           sizeof(*kIntsContextFeaturePrefixTag) -
                       1);
        // To ensure only one packet is provided for this tag.
        RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::PostStream());
        for (const auto& value :
             cc->Inputs().Tag(tag).Get<std::vector<int64_t>>()) {
          mpms::AddContextFeatureInts(key, value, sequence_.get());
        }
      }
      if (absl::StartsWith(tag, kBytesContextFeaturePrefixTag) &&
          !cc->Inputs().Tag(tag).IsEmpty()) {
        const std::string& key =
            tag.substr(sizeof(kBytesContextFeaturePrefixTag) /
                           sizeof(*kBytesContextFeaturePrefixTag) -
                       1);
        // To ensure only one packet is provided for this tag.
        RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::PostStream());
        for (const auto& value :
             cc->Inputs().Tag(tag).Get<std::vector<std::string>>()) {
          mpms::AddContextFeatureBytes(key, value, sequence_.get());
        }
      }
      if (absl::StartsWith(tag, kFloatFeaturePrefixTag) &&
          !cc->Inputs().Tag(tag).IsEmpty()) {
        std::string key = tag.substr(sizeof(kFloatFeaturePrefixTag) /
                                         sizeof(*kFloatFeaturePrefixTag) -
                                     1);
        mpms::AddFeatureTimestamp(key, cc->InputTimestamp().Value(),
                                  sequence_.get());
        mpms::AddFeatureFloats(key,
                               cc->Inputs().Tag(tag).Get<std::vector<float>>(),
                               sequence_.get());
      }
      if (absl::StartsWith(tag, kIntFeaturePrefixTag) &&
          !cc->Inputs().Tag(tag).IsEmpty()) {
        std::string key = tag.substr(
            sizeof(kIntFeaturePrefixTag) / sizeof(*kIntFeaturePrefixTag) - 1);
        mpms::AddFeatureTimestamp(key, cc->InputTimestamp().Value(),
                                  sequence_.get());
        mpms::AddFeatureInts(key,
                             cc->Inputs().Tag(tag).Get<std::vector<int64_t>>(),
                             sequence_.get());
      }
      if (absl::StartsWith(tag, kBytesFeaturePrefixTag) &&
          !cc->Inputs().Tag(tag).IsEmpty()) {
        std::string key = tag.substr(sizeof(kBytesFeaturePrefixTag) /
                                         sizeof(*kBytesFeaturePrefixTag) -
                                     1);
        mpms::AddFeatureTimestamp(key, cc->InputTimestamp().Value(),
                                  sequence_.get());
        mpms::AddFeatureBytes(
            key, cc->Inputs().Tag(tag).Get<std::vector<std::string>>(),
            sequence_.get());
      }
      if (absl::StartsWith(tag, kBBoxTag) && !cc->Inputs().Tag(tag).IsEmpty()) {
        std::string key = "";
        if (tag != kBBoxTag) {
          int tag_length = sizeof(kBBoxTag) / sizeof(*kBBoxTag) - 1;
          if (tag[tag_length] == '_') {
            key = tag.substr(tag_length + 1);
          } else {
            continue;  // Skip keys that don't match "(kBBoxTag)_?"
          }
        }
        std::vector<Location> predicted_locations;
        std::vector<std::string> predicted_class_strings;
        std::vector<float> predicted_class_confidences;
        std::vector<int> predicted_label_ids;
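        // Converting to relative coordinates requires the image dimensions,
        // taken from the example's metadata when present, otherwise from the
        // images in this packet.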
        for (auto& detection :
             cc->Inputs().Tag(tag).Get<std::vector<Detection>>()) {
          if (detection.location_data().format() ==
                  LocationData::BOUNDING_BOX ||
              detection.location_data().format() ==
                  LocationData::RELATIVE_BOUNDING_BOX) {
            if (mpms::HasImageHeight(*sequence_) &&
                mpms::HasImageWidth(*sequence_)) {
              image_height = mpms::GetImageHeight(*sequence_);
              image_width = mpms::GetImageWidth(*sequence_);
            }
            if (image_height == -1 || image_width == -1) {
              return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
                     << "Images must be provided with bounding boxes or the "
                        "image "
                     << "height and width must already be in the example.";
            }
            Location relative_bbox = Location::CreateRelativeBBoxLocation(
                Location(detection.location_data())
                    .ConvertToRelativeBBox(image_width, image_height));
            predicted_locations.push_back(relative_bbox);
            if (detection.label_size() > 0) {
              predicted_class_strings.push_back(detection.label(0));
            }
            if (detection.label_id_size() > 0) {
              predicted_label_ids.push_back(detection.label_id(0));
            }
            if (detection.score_size() > 0) {
              predicted_class_confidences.push_back(detection.score(0));
            }
          }
        }
        if (!predicted_locations.empty()) {
          mpms::AddBBox(key, predicted_locations, sequence_.get());
          mpms::AddBBoxTimestamp(key, cc->InputTimestamp().Value(),
                                 sequence_.get());
          if (!predicted_class_strings.empty()) {
            mpms::AddBBoxLabelString(key, predicted_class_strings,
                                     sequence_.get());
          }
          if (!predicted_label_ids.empty()) {
            mpms::AddBBoxLabelIndex(key, predicted_label_ids, sequence_.get());
          }
          if (!predicted_class_confidences.empty()) {
            mpms::AddBBoxLabelConfidence(key, predicted_class_confidences,
                                         sequence_.get());
          }
        }
      }
    }
    if (cc->Inputs().HasTag(kForwardFlowEncodedTag) &&
        !cc->Inputs().Tag(kForwardFlowEncodedTag).IsEmpty()) {
      const OpenCvImageEncoderCalculatorResults& forward_flow =
          cc->Inputs()
              .Tag(kForwardFlowEncodedTag)
              .Get<OpenCvImageEncoderCalculatorResults>();
      if (!forward_flow.has_encoded_image()) {
        return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
               << "No encoded forward flow";
      }
      mpms::AddForwardFlowTimestamp(cc->InputTimestamp().Value(),
                                    sequence_.get());
      mpms::AddForwardFlowEncoded(forward_flow.encoded_image(),
                                  sequence_.get());
    }
    if (cc->Inputs().HasTag(kSegmentationMaskTag) &&
        !cc->Inputs().Tag(kSegmentationMaskTag).IsEmpty()) {
      bool already_has_mask = false;
      for (auto& detection : cc->Inputs()
                                 .Tag(kSegmentationMaskTag)
                                 .Get<std::vector<Detection>>()) {
        if (detection.location_data().format() == LocationData::MASK) {
          RET_CHECK(!already_has_mask)
              << "We currently only support adding one mask per timestamp. "
              << sequence_->DebugString();
          auto mask_mat_ptr = GetCvMask(Location(detection.location_data()));
          std::vector<uchar> bytes;
          RET_CHECK(cv::imencode(".png", *mask_mat_ptr, bytes, {}));

          std::string encoded_mask(bytes.begin(), bytes.end());
          mpms::AddClassSegmentationEncoded(encoded_mask, sequence_.get());
          mpms::AddClassSegmentationTimestamp(cc->InputTimestamp().Value(),
                                              sequence_.get());
          // SegmentationClassLabelString is a context feature for the entire
          // sequence. The values in the last detection will be saved.
          mpms::SetClassSegmentationClassLabelString({detection.label(0)},
                                                     sequence_.get());
          already_has_mask = true;
        } else {
          return absl::UnimplementedError(
              "Global detections and empty detections are not supported.");
        }
      }
    }
    if (clip_media_id_.has_value()) {
      mpms::SetClipMediaId(*clip_media_id_, sequence_.get());
    }
    return absl::OkStatus();
  }

  // The SequenceExample being assembled; emitted and released in Close().
  std::unique_ptr<tf::SequenceExample> sequence_;
  // Media ID from the CLIP_MEDIA_ID input side packet, if provided.
  std::optional<std::string> clip_media_id_ = std::nullopt;
  // Tracks whether each input stream has delivered at least one packet.
  std::map<std::string, bool> features_present_;
  // Whether existing keypoints should be cleared before the first write.
  bool replace_keypoints_;
};
REGISTER_CALCULATOR(PackMediaSequenceCalculator);

}  // namespace mediapipe