chromium/third_party/mediapipe/src/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator.cc

// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"

namespace mediapipe {
namespace api2 {

using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;

// Aggregates EmbeddingResult packets into a vector of timestamped
// EmbeddingResult. Acts as a pass-through if no timestamp aggregation is
// needed.
//
// Inputs:
//   EMBEDDINGS: EmbeddingResult
//     The EmbeddingResult packets to aggregate.
//   TIMESTAMPS: std::vector<Timestamp> @Optional.
//     The collection of timestamps that this calculator should aggregate. This
//     stream is optional: if connected, the TIMESTAMPED_EMBEDDINGS output
//     contains the aggregated results. Otherwise no timestamp aggregation is
//     required and the EMBEDDINGS output is used to forward each input
//     EmbeddingResult.
//
// Outputs:
//   EMBEDDINGS: EmbeddingResult @Optional
//     A copy of the input EmbeddingResult whose timestamp_ms field is set to
//     the input timestamp value divided by 1000. Must be connected if the
//     TIMESTAMPS input is not connected, as it signals that timestamp
//     aggregation is not required.
//   TIMESTAMPED_EMBEDDINGS: std::vector<EmbeddingResult> @Optional
//     The embedding results aggregated by timestamp, each with timestamp_ms
//     set relative to the first timestamp of the TIMESTAMPS packet. Must be
//     connected if the TIMESTAMPS input is connected, as it signals that
//     timestamp aggregation is required.
//
// Example without timestamp aggregation (pass-through):
// node {
//   calculator: "EmbeddingAggregationCalculator"
//   input_stream: "EMBEDDINGS:embeddings_in"
//   output_stream: "EMBEDDINGS:embeddings_out"
// }
//
// Example with timestamp aggregation:
// node {
//   calculator: "EmbeddingAggregationCalculator"
//   input_stream: "EMBEDDINGS:embeddings_in"
//   input_stream: "TIMESTAMPS:timestamps_in"
//   output_stream: "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out"
// }
class EmbeddingAggregationCalculator : public Node {
 public:
  // Required input: one EmbeddingResult per input timestamp.
  static constexpr Input<EmbeddingResult> kEmbeddingsIn{"EMBEDDINGS"};
  // Optional input: the list of timestamps whose cached results should be
  // emitted together. Its presence switches the calculator to aggregation
  // mode.
  static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{
      "TIMESTAMPS"};
  // Optional output used in pass-through mode.
  static constexpr Output<EmbeddingResult>::Optional kEmbeddingsOut{
      "EMBEDDINGS"};
  // Optional output used in aggregation mode.
  static constexpr Output<std::vector<EmbeddingResult>>::Optional
      kTimestampedEmbeddingsOut{"TIMESTAMPED_EMBEDDINGS"};
  MEDIAPIPE_NODE_CONTRACT(kEmbeddingsIn, kTimestampsIn, kEmbeddingsOut,
                          kTimestampedEmbeddingsOut);

  // Checks that the output matching the selected mode is connected.
  static absl::Status UpdateContract(CalculatorContract* cc);
  // Records whether timestamp aggregation is enabled for this graph run.
  absl::Status Open(CalculatorContext* cc) override;
  // Caches/aggregates (aggregation mode) or forwards (pass-through mode) the
  // EmbeddingResult received at the current input timestamp.
  absl::Status Process(CalculatorContext* cc) override;

 private:
  // True when the TIMESTAMPS input is connected. Given a default so the member
  // never holds an indeterminate value before Open() assigns it.
  bool time_aggregation_enabled_ = false;
  // Results received so far, keyed by the raw input timestamp value, waiting
  // for a TIMESTAMPS packet that references them.
  std::unordered_map<int64_t, EmbeddingResult> cached_embeddings_;
};

// Validates the stream wiring: the TIMESTAMPED_EMBEDDINGS output is mandatory
// when the TIMESTAMPS input is connected, and the EMBEDDINGS output is
// mandatory otherwise.
absl::Status EmbeddingAggregationCalculator::UpdateContract(
    CalculatorContract* cc) {
  const bool aggregation_requested = kTimestampsIn(cc).IsConnected();
  if (!aggregation_requested) {
    // Pass-through mode needs somewhere to forward the results.
    RET_CHECK(kEmbeddingsOut(cc).IsConnected());
    return absl::OkStatus();
  }
  // Aggregation mode needs the aggregated output stream.
  RET_CHECK(kTimestampedEmbeddingsOut(cc).IsConnected());
  return absl::OkStatus();
}

// Decides the operating mode once per run: aggregation is enabled exactly when
// the optional TIMESTAMPS input stream is connected.
absl::Status EmbeddingAggregationCalculator::Open(CalculatorContext* cc) {
  const bool timestamps_connected = kTimestampsIn(cc).IsConnected();
  time_aggregation_enabled_ = timestamps_connected;
  return absl::OkStatus();
}

// Handles the packets available at the current input timestamp.
//
// Pass-through mode (TIMESTAMPS not connected): forwards a copy of the input
// EmbeddingResult on EMBEDDINGS with timestamp_ms set to the input timestamp
// value divided by 1000.
//
// Aggregation mode (TIMESTAMPS connected): caches the input EmbeddingResult
// under the current input timestamp. When a TIMESTAMPS packet is present, the
// cached results for the listed timestamps are emitted as a single vector on
// TIMESTAMPED_EMBEDDINGS, each with timestamp_ms set relative to the first
// listed timestamp. A listed timestamp with no cached result yields a
// default-constructed EmbeddingResult (with timestamp_ms set).
//
// Returns an error if cached results remain after an aggregation, i.e. the
// TIMESTAMPS packet did not cover every result received so far.
absl::Status EmbeddingAggregationCalculator::Process(CalculatorContext* cc) {
  if (!time_aggregation_enabled_) {
    // One copy out of the input packet, then move into the output packet
    // instead of copying a second time. The cache is only ever written in
    // aggregation mode, so there is nothing to verify here.
    EmbeddingResult result = kEmbeddingsIn(cc).Get();
    result.set_timestamp_ms(cc->InputTimestamp().Value() / 1000);
    kEmbeddingsOut(cc).Send(std::move(result));
    return absl::OkStatus();
  }

  // Process() may be invoked for a timestamp at which only TIMESTAMPS carries
  // a packet; only read the EMBEDDINGS payload when there is one.
  // NOTE(review): the payload is presumably exposed as a const reference, so
  // this is a copy either way (the previous std::move had no effect) - confirm
  // against the api2 Packet accessor.
  if (!kEmbeddingsIn(cc).IsEmpty()) {
    cached_embeddings_[cc->InputTimestamp().Value()] = *kEmbeddingsIn(cc);
  }
  if (kTimestampsIn(cc).IsEmpty()) {
    // Nothing to emit yet; keep accumulating.
    return absl::OkStatus();
  }

  // Bind by reference: the list is only read, so copying it is unnecessary.
  const std::vector<Timestamp>& timestamps = kTimestampsIn(cc).Get();
  std::vector<EmbeddingResult> results;
  results.reserve(timestamps.size());
  for (const Timestamp& timestamp : timestamps) {
    const int64_t key = timestamp.Value();
    // Starts default-constructed so that a timestamp without a cached entry
    // still produces an (empty) result, without inserting a throwaway map
    // entry just to erase it again.
    EmbeddingResult result;
    auto cached = cached_embeddings_.find(key);
    if (cached != cached_embeddings_.end()) {
      result = std::move(cached->second);
      cached_embeddings_.erase(cached);
    }
    // Offsets are expressed relative to the first listed timestamp.
    result.set_timestamp_ms((key - timestamps.front().Value()) / 1000);
    results.push_back(std::move(result));
  }
  kTimestampedEmbeddingsOut(cc).Send(std::move(results));

  // Every cached result is expected to have been consumed by the list above.
  RET_CHECK(cached_embeddings_.empty());
  return absl::OkStatus();
}

// Registers the calculator with the framework; presumably this is what lets
// graph configs refer to it as "EmbeddingAggregationCalculator".
MEDIAPIPE_REGISTER_NODE(EmbeddingAggregationCalculator);

}  // namespace api2
}  // namespace mediapipe