chromium/third_party/mediapipe/src/mediapipe/calculators/core/packet_cloner_calculator.cc

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// This takes packets from N+1 streams, A_1, A_2, ..., A_N, B.
// For every packet that appears in B, outputs the most recent packet from each
// of the A_i on a separate stream.

#include <string_view>
#include <vector>

#include "absl/strings/string_view.h"
#include "mediapipe/calculators/core/packet_cloner_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"

namespace mediapipe {

// For every packet received on the last stream, output the latest packet
// obtained on all other streams. Therefore, if the last stream outputs at a
// higher rate than the others, this effectively clones the packets from the
// other streams to match the last.
//
// Example config:
// node {
//   calculator: "PacketClonerCalculator"
//   input_stream: "first_base_signal"
//   input_stream: "second_base_signal"
//   input_stream: "tick_signal"  # or input_stream: "TICK:tick_signal"
//   output_stream: "cloned_first_base_signal"
//   output_stream: "cloned_second_base_signal"
// }
//
// Or you can use "TICK" tag and put corresponding input stream at any location,
// for example at the very beginning:
// node {
//   calculator: "PacketClonerCalculator"
//   input_stream: "TICK:tick_signal"
//   input_stream: "first_base_signal"
//   input_stream: "second_base_signal"
//   output_stream: "cloned_first_base_signal"
//   output_stream: "cloned_second_base_signal"
// }
//
// Related:
//   packet_cloner_calculator.proto: Options for this calculator.
//   merge_input_streams_calculator.cc: One output stream.
//   packet_inner_join_calculator.cc: Don't output unless all inputs are new.
class PacketClonerCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    cc->SetProcessTimestampBounds(true);
    const Ids ids = GetIds(*cc);
    for (const auto& in_out : ids.inputs_outputs) {
      auto& input = cc->Inputs().Get(in_out.in);
      input.SetAny();
      cc->Outputs().Get(in_out.out).SetSameAs(&input);
    }
    cc->Inputs().Get(ids.tick_id).SetAny();
    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) final {
    // Load options.
    const auto calculator_options =
        cc->Options<mediapipe::PacketClonerCalculatorOptions>();
    output_only_when_all_inputs_received_ =
        calculator_options.output_only_when_all_inputs_received() ||
        calculator_options.output_packets_only_when_all_inputs_received();
    output_empty_packets_before_all_inputs_received_ =
        calculator_options.output_packets_only_when_all_inputs_received();

    // Prepare input and output ids.
    ids_ = GetIds(*cc);
    current_.resize(ids_.inputs_outputs.size());

    // Pass along the header for each stream if present.
    for (const auto& in_out : ids_.inputs_outputs) {
      auto& input = cc->Inputs().Get(in_out.in);
      if (!input.Header().IsEmpty()) {
        cc->Outputs().Get(in_out.out).SetHeader(input.Header());
      }
    }
    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) final {
    // Store input signals.
    for (int i = 0; i < ids_.inputs_outputs.size(); ++i) {
      const auto& input = cc->Inputs().Get(ids_.inputs_outputs[i].in);
      if (!input.IsEmpty()) {
        current_[i] = input.Value();
      }
    }

    bool has_all_inputs = HasAllInputs();
    // Output according to the TICK signal.
    if (!cc->Inputs().Get(ids_.tick_id).IsEmpty() &&
        (has_all_inputs || !output_only_when_all_inputs_received_)) {
      // Output each stream.
      for (int i = 0; i < ids_.inputs_outputs.size(); ++i) {
        auto& output = cc->Outputs().Get(ids_.inputs_outputs[i].out);
        if (!current_[i].IsEmpty()) {
          output.AddPacket(current_[i].At(
              cc->Inputs().Get(ids_.tick_id).Value().Timestamp()));
        }
      }
    }

    // Set timestamp bounds according to the TICK signal.
    bool tick_updated = cc->Inputs().Get(ids_.tick_id).Value().Timestamp() ==
                        cc->InputTimestamp();
    bool producing_output = has_all_inputs ||
                            output_empty_packets_before_all_inputs_received_ ||
                            !output_only_when_all_inputs_received_;
    if (tick_updated && producing_output) {
      SetAllNextTimestampBounds(cc);
    }

    return absl::OkStatus();
  }

 private:
  struct Ids {
    struct InputOutput {
      CollectionItemId in;
      CollectionItemId out;
    };
    CollectionItemId tick_id;
    std::vector<InputOutput> inputs_outputs;
  };

  template <typename CC>
  static Ids GetIds(CC& cc) {
    Ids ids;
    static constexpr absl::string_view kEmptyTag = "";
    int num_inputs_to_clone = cc.Inputs().NumEntries(kEmptyTag);
    static constexpr absl::string_view kTickTag = "TICK";
    if (cc.Inputs().HasTag(kTickTag)) {
      ids.tick_id = cc.Inputs().GetId(kTickTag, 0);
    } else {
      --num_inputs_to_clone;
      ids.tick_id = cc.Inputs().GetId(kEmptyTag, num_inputs_to_clone);
    }
    for (int i = 0; i < num_inputs_to_clone; ++i) {
      ids.inputs_outputs.push_back({.in = cc.Inputs().GetId(kEmptyTag, i),
                                    .out = cc.Outputs().GetId(kEmptyTag, i)});
    }
    return ids;
  }

  void SetAllNextTimestampBounds(CalculatorContext* cc) {
    for (const auto& in_out : ids_.inputs_outputs) {
      cc->Outputs()
          .Get(in_out.out)
          .SetNextTimestampBound(cc->InputTimestamp().NextAllowedInStream());
    }
  }

  bool HasAllInputs() {
    for (int i = 0; i < ids_.inputs_outputs.size(); ++i) {
      if (current_[i].IsEmpty()) {
        return false;
      }
    }
    return true;
  }

  std::vector<Packet> current_;
  Ids ids_;
  bool output_only_when_all_inputs_received_;
  bool output_empty_packets_before_all_inputs_received_;
};

REGISTER_CALCULATOR(PacketClonerCalculator);

}  // namespace mediapipe