chromium/third_party/mediapipe/src/mediapipe/framework/api2/stream/concatenate.h

#ifndef MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_
#define MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_

#include <vector>

#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/body_rig.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/tensor.h"

namespace mediapipe::api2::builder {

namespace internal_stream_concatenate {

// Adds to `graph` the calculator node able to concatenate payloads of type
// `T`, and returns a reference to the newly added node.
//
// Supported payload types and the calculator each one maps to:
//   mediapipe::LandmarkList -> "ConcatenateLandmarkListCalculator"
//   mediapipe::JointList    -> "ConcatenateJointListCalculator"
//   std::vector<Tensor>     -> "ConcatenateTensorVectorCalculator"
// Instantiating with any other `T` is a compile-time error.
template <class T>
GenericNode& AddConcatenateVectorNode(Graph& graph) {
  constexpr bool kIsLandmarkList = std::is_same_v<T, mediapipe::LandmarkList>;
  constexpr bool kIsJointList = std::is_same_v<T, mediapipe::JointList>;
  constexpr bool kIsTensorVector = std::is_same_v<T, std::vector<Tensor>>;
  static_assert(kIsLandmarkList || kIsJointList || kIsTensorVector,
                "Concatenate node is not available for the specified type.");

  // Past the static_assert exactly one flag is set, so the final arm of the
  // ternary is only reached for std::vector<Tensor>.
  constexpr const char* kCalculatorName =
      kIsLandmarkList ? "ConcatenateLandmarkListCalculator"
      : kIsJointList  ? "ConcatenateJointListCalculator"
                      : "ConcatenateTensorVectorCalculator";
  return graph.AddNode(kCalculatorName);
}

// Adds a concatenation node for `PayloadT` to `graph` (see
// AddConcatenateVectorNode), feeds it every stream in `streams`, and returns
// the node's output stream cast to `Stream<PayloadT>`.
//
// `streams` may be any container of streams that exposes `value_type` and can
// be walked with a range-based for loop (e.g. std::vector<Stream<T>>). Its
// elements are connected, in iteration order, to consecutive indices
// 0, 1, 2, ... of the node's untagged ("") input.
// NOTE(review): an empty `streams` yields a node with no inputs; presumably
// callers always pass at least one stream — confirm.
//
// `only_emit_if_all_present` is written to the node's
// ConcatenateVectorCalculatorOptions.
template <typename StreamsT,
          typename PayloadT = typename StreamsT::value_type::PayloadT>
Stream<PayloadT> Concatenate(StreamsT& streams,
                             const bool only_emit_if_all_present,
                             Graph& graph) {
  auto& concatenator = AddConcatenateVectorNode<PayloadT>(graph);

  // A running int index avoids comparing a signed counter with the unsigned
  // `size()`. It also lets any iterable container be used, not just ones
  // providing operator[].
  int input_index = 0;
  for (auto& stream : streams) {
    stream.ConnectTo(concatenator.In("")[input_index]);
    ++input_index;
  }

  auto& concatenator_opts =
      concatenator
          .template GetOptions<mediapipe::ConcatenateVectorCalculatorOptions>();
  concatenator_opts.set_only_emit_if_all_present(only_emit_if_all_present);

  return concatenator.Out("").template Cast<PayloadT>();
}

}  // namespace internal_stream_concatenate

// Concatenates `streams` into one output stream using the calculator selected
// for their payload type. The node's `only_emit_if_all_present` option is set
// to false, so output is presumably not gated on every input having a packet
// at a timestamp (confirm against ConcatenateVectorCalculatorOptions). Use
// ConcatenateIfAllPresent to set the option to true instead.
template <typename StreamsT,
          typename PayloadT = typename StreamsT::value_type::PayloadT>
Stream<PayloadT> Concatenate(StreamsT& streams, Graph& graph) {
  constexpr bool kOnlyEmitIfAllPresent = false;
  return internal_stream_concatenate::Concatenate(streams,
                                                  kOnlyEmitIfAllPresent, graph);
}

// Same as Concatenate, but sets the node's `only_emit_if_all_present` option
// to true. Per the option name, output is presumably produced only at
// timestamps where every input stream has a packet (confirm against
// ConcatenateVectorCalculatorOptions).
template <typename StreamsT,
          typename PayloadT = typename StreamsT::value_type::PayloadT>
Stream<PayloadT> ConcatenateIfAllPresent(StreamsT& streams, Graph& graph) {
  constexpr bool kOnlyEmitIfAllPresent = true;
  return internal_stream_concatenate::Concatenate(streams,
                                                  kOnlyEmitIfAllPresent, graph);
}

}  // namespace mediapipe::api2::builder

#endif  // MEDIAPIPE_FRAMEWORK_API2_STREAM_CONCATENATE_H_