// chromium/third_party/mediapipe/src/mediapipe/framework/api2/port.h

// This file defines an API to define a node's ports in a concise, type-safe
// way. Example usage in a node:
//
//   static constexpr Input<int> kBase("IN");
//   static constexpr Output<float> kOut("OUT");
//   static constexpr SideInput<float>::Optional kDelta("DELTA");
//   static constexpr SideOutput<float> kForward("FORWARD");
//
// Pass a CalculatorContext to a port to access the inputs or outputs in the
// context. For example:
//
//   kBase(cc) yields an InputShardAccess<int>
//   kOut(cc) yields an OutputShardAccess<float>
//   kDelta(cc) yields an InputSidePacketAccess<float>
//   kForward(cc) yields an OutputSidePacketAccess<float>

#ifndef MEDIAPIPE_FRAMEWORK_API2_PORT_H_
#define MEDIAPIPE_FRAMEWORK_API2_PORT_H_

#include <type_traits>
#include <utility>

#include "absl/log/absl_check.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/const_str.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/output_side_packet.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/tool/type_util.h"

namespace mediapipe {
namespace api2 {

// This is a base class for various types of port. It is not meant to be used
// directly by node code.
class PortBase {};

// These four base classes are used to distinguish between ports of different
// kinds. They are not meant to be used directly by node code.
// Each is an empty class deriving from PortBase; internal helpers such as
// internal::GetCollection overload on them to select the collection that
// matches a port's kind.
class InputBase : public PortBase {};
class OutputBase : public PortBase {};
class SideInputBase : public PortBase {};
class SideOutputBase : public PortBase {};

// Marker payload type; internal::SetType has an explicit specialization for it.
struct NoneType {};

// Marker payload type that refers to another port object (kP) by reference.
// internal::IsSameType detects it, and internal::ActualPayloadType resolves
// it to the payload type of the referenced port.
// NOTE(review): ActualPayloadType reads `T::kPort` from this type, but no
// `kPort` member is declared in this copy — confirm it is meant to expose kP.
template <auto& kP>
struct SameType {};

// Forward declarations of the accessor types that the SinglePortAccess
// overloads below return. All of them are defined later in this file.
class PacketTypeAccess;
class PacketTypeAccessFallback;
template <typename T>
class InputShardAccess;
template <typename T>
class OutputShardAccess;
template <typename T>
class InputSidePacketAccess;
template <typename T>
class OutputSidePacketAccess;
template <typename T>
class InputShardOrSideAccess;

namespace internal {

// Forward declaration for AddToContract friend.
// Variadic class template; it is only declared, not defined, in this file.
template <typename...>
class Contract;

// Returns `cc->Inputs()`. The GetCollection overloads are selected by the
// kind base class of `port` (InputBase, SideInputBase, OutputBase,
// SideOutputBase), so a port object picks the collection matching its kind;
// `port` is used only for overload resolution.
template <class CC>
auto GetCollection(CC* cc, const InputBase& port) -> decltype(cc->Inputs()) {
  return cc->Inputs();
}

// Returns `cc->InputSidePackets()`; chosen for ports whose kind base is
// SideInputBase. `port` is used only for overload resolution.
template <class CC>
auto GetCollection(CC* cc, const SideInputBase& port)
    -> decltype(cc->InputSidePackets()) {
  return cc->InputSidePackets();
}

// Returns `cc->Outputs()`; chosen for ports whose kind base is OutputBase.
// `port` is used only for overload resolution.
template <class CC>
auto GetCollection(CC* cc, const OutputBase& port) -> decltype(cc->Outputs()) {
  return cc->Outputs();
}

// Returns `cc->OutputSidePackets()`; chosen for ports whose kind base is
// SideOutputBase. `port` is used only for overload resolution.
template <class CC>
auto GetCollection(CC* cc, const SideOutputBase& port)
    -> decltype(cc->OutputSidePackets()) {
  return cc->OutputSidePackets();
}

// Looks up the item addressed by (`tag`, `index`) in `collection`.
// Returns a pointer to that item, or nullptr when the collection has no item
// for that tag/index pair (i.e. the resolved CollectionItemId is invalid).
template <class Collection>
auto GetOrNull(Collection& collection, const absl::string_view& tag, int index)
    -> decltype(&collection.Get(std::declval<CollectionItemId>())) {
  CollectionItemId id = collection.GetId(tag, index);
  return id.IsValid() ? &collection.Get(id) : nullptr;
}

// Trait: std::true_type when T is a OneOf<...> list of alternative payload
// types, std::false_type for every other type.
template <class T>
struct IsOneOf : std::false_type {};

template <class... T>
struct IsOneOf<OneOf<T...>> : std::true_type {};

// Trait: std::true_type for SameType<kP> markers, std::false_type for every
// other type.
template <class T>
struct IsSameType : std::false_type {};

// Matches SameType instantiations; P is deduced as the type of the referenced
// port object.
template <class P, P& kP>
struct IsSameType<SameType<kP>> : std::true_type {};

template <typename T,
          typename std::enable_if<!std::is_same<T, AnyType>{}

template <typename T, typename std::enable_if<IsSameType<T>{}

template <typename T,
          typename std::enable_if<std::is_same<T, AnyType>{}

// Explicit specialization of SetType for the NoneType marker payload.
// NOTE(review): the body is empty in this copy, so `pt` is left unmodified;
// confirm whether NoneType ports are expected to mark their PacketType.
template <>
inline void SetType<NoneType>(CalculatorContract* cc, PacketType& pt) {}

// Helper for OneOf payload types. The unnamed OneOf<T...> argument exists only
// so that the alternative types T... can be deduced from it.
// NOTE(review): the body is empty in this copy, so `pt` is left unmodified.
template <typename... T>
inline void SetTypeOneOf(OneOf<T...>, CalculatorContract* cc, PacketType& pt) {}

template <typename T, typename std::enable_if<IsOneOf<T>{}

// Declared to produce the typed InputShardAccess<ValueT> for an input stream
// shard `stream` of `cc`.
// NOTE(review): the body is empty in this copy although the return type is
// non-void; flowing off the end is undefined behavior if the result is used.
template <typename ValueT>
InputShardAccess<ValueT> SinglePortAccess(mediapipe::CalculatorContext* cc,
                                          InputStreamShard* stream) {}

// Declared to produce the typed OutputShardAccess<ValueT> for an output stream
// shard `stream` of `cc`.
// NOTE(review): the body is empty in this copy although the return type is
// non-void; flowing off the end is undefined behavior if the result is used.
template <typename ValueT>
OutputShardAccess<ValueT> SinglePortAccess(mediapipe::CalculatorContext* cc,
                                           OutputStreamShard* stream) {}

// Declared to produce the typed InputSidePacketAccess<ValueT> for the input
// side packet `packet`.
// NOTE(review): the body is empty in this copy although the return type is
// non-void; flowing off the end is undefined behavior if the result is used.
template <typename ValueT>
InputSidePacketAccess<ValueT> SinglePortAccess(
    mediapipe::CalculatorContext* cc, const mediapipe::Packet* packet) {}

// Declared to produce the typed OutputSidePacketAccess<ValueT> for the output
// side packet `osp`.
// NOTE(review): the body is empty in this copy although the return type is
// non-void; flowing off the end is undefined behavior if the result is used.
template <typename ValueT>
OutputSidePacketAccess<ValueT> SinglePortAccess(
    mediapipe::CalculatorContext* cc, OutputSidePacket* osp) {}

// Declared to produce an InputShardOrSideAccess<ValueT> from both an input
// stream shard and an input side packet; presumably one of the two may be
// absent — confirm against callers.
// NOTE(review): the body is empty in this copy although the return type is
// non-void; flowing off the end is undefined behavior if the result is used.
template <typename ValueT>
InputShardOrSideAccess<ValueT> SinglePortAccess(
    mediapipe::CalculatorContext* cc, InputStreamShard* stream,
    const mediapipe::Packet* packet) {}

// CalculatorContract overloads, only declared here. PacketTypeAccess and
// PacketTypeAccessFallback are still incomplete at this point, and a function
// returning an incomplete type can be declared but not defined. The
// definitions appear at the end of this file.
template <typename ValueT>
PacketTypeAccess SinglePortAccess(mediapipe::CalculatorContract* cc,
                                  PacketType* pt);

template <typename ValueT>
PacketTypeAccessFallback SinglePortAccess(mediapipe::CalculatorContract* cc,
                                          PacketType* pt, bool is_stream);

// AccessPort is tag-dispatched on std::false_type / std::true_type.
// NOTE(review): the body is empty in this copy, so with its deduced `auto`
// return type this overload returns void. Presumably the false_type case
// handles single (non-multiple) ports — confirm.
template <typename ValueT, typename PortT, class CC>
auto AccessPort(std::false_type, const PortT& port, CC* cc) {}

// Accessor parameterized by payload type, X, and context type; presumably
// used for ports with multiple entries — confirm. Empty in this copy.
template <typename ValueT, typename X, class CC>
class MultiplePortAccess {};

// std::true_type overload of the tag-dispatched AccessPort. Presumably the
// multiple-port case (see MultiplePortAccess) — confirm.
// NOTE(review): the body is empty in this copy, so this overload returns void.
template <typename ValueT, typename PortT, class CC>
auto AccessPort(std::true_type, const PortT& port, CC* cc) {}

// Trait keyed on a port-kind base class. The primary template is only
// declared; the InputBase specialization is empty in this copy.
// NOTE(review): the name suggests it maps a stream kind to its side-packet
// counterpart (e.g. InputBase -> SideInputBase) — confirm.
template <class Base>
struct SideBase;

template <>
struct SideBase<InputBase> {};

// TODO: maybe return a PacketBase instead of a Packet<internal::Generic>?
template <typename T, typename = void>
struct ActualPayloadType {};

// For SameType markers (IsSameType<T> is true), the payload type is that of
// the referenced port's value_t. It is resolved through ActualPayloadType
// again, so special or chained value types are mapped as well.
template <typename T>
struct ActualPayloadType<T, std::enable_if_t<IsSameType<T>{}, void>> {
  using type = typename ActualPayloadType<
      typename std::decay_t<decltype(T::kPort)>::value_t>::type;
};

}  // namespace internal

// Maps special port value types, such as AnyType, to internal::Generic.
// All of the mapping rules live in internal::ActualPayloadType.
template <class PortValueT>
using ActualPayloadT =
    typename internal::ActualPayloadType<PortValueT, void>::type;

// Compile-time checks: ordinary payload types pass through ActualPayloadT
// unchanged, while the AnyType marker maps to internal::Generic.
static_assert(std::is_same<ActualPayloadT<int>, int>::value, "");
static_assert(std::is_same<ActualPayloadT<AnyType>, internal::Generic>::value,
              "");

// Forward declaration of SideFallbackT. The default template arguments are
// given here, so the definition further down must not repeat them.
template <typename Base, typename ValueT, bool IsOptional = false,
          bool IsMultiple = false>
class SideFallbackT;

// This template is used to define a port. Nodes should use it through one
// of the aliases below (Input, Output, SideInput, SideOutput).
// Base is one of the port-kind classes (InputBase, OutputBase, SideInputBase,
// SideOutputBase) and ValueT is the payload type. IsOptionalV and IsMultipleV
// both default to false; their names suggest optional / multiple-port
// variants — confirm.
template <typename Base, typename ValueT, bool IsOptionalV = false,
          bool IsMultipleV = false>
class PortCommon : public Base {};

// Port declarations for node code. Each alias fixes the port kind; when the
// payload type is omitted it defaults to internal::Generic.
template <class PayloadT = internal::Generic>
using Input = PortCommon<InputBase, PayloadT>;

template <class PayloadT = internal::Generic>
using Output = PortCommon<OutputBase, PayloadT>;

template <class PayloadT = internal::Generic>
using SideInput = PortCommon<SideInputBase, PayloadT>;

template <class PayloadT = internal::Generic>
using SideOutput = PortCommon<SideOutputBase, PayloadT>;

// Definition of SideFallbackT; its default template arguments come from the
// earlier forward declaration. Like PortCommon, it derives from its port-kind
// Base.
template <typename Base, typename ValueT, bool IsOptionalV, bool IsMultipleV>
class SideFallbackT : public Base {};

// An OutputShardAccess is returned when accessing an output stream within a
// CalculatorContext (e.g. kOut(cc)), and provides a type-safe interface to
// OutputStreamShard. Like that class, this class is not usually named in
// calculator code, but is used as a temporary object (e.g. kOut(cc).Send(...)).
//
// If not connected (!IsConnected()) SetNextTimestampBound is safe to call and
// does nothing.
// All the sub-classes that define Send should implement it to be safe to
// call if not connected, and to do nothing in that case.
class OutputShardAccessBase {};

// Typed output-stream access for payload type T.
template <typename T>
class OutputShardAccess : public OutputShardAccessBase {};

// Specialization for generic (untyped) outputs, e.g. an Output<> declared
// with its default internal::Generic payload.
template <>
class OutputShardAccess<internal::Generic> : public OutputShardAccessBase {};

// Equivalent of OutputShardAccess, but for side packets (e.g. kForward(cc)).
// Unlike OutputShardAccess, it does not derive from OutputShardAccessBase.
template <typename T>
class OutputSidePacketAccess {};

// Read-side accessors. Each publicly derives from Packet<T>, so the Packet<T>
// interface is available directly on the access object.

// Returned for an input stream (e.g. kBase(cc)).
template <typename T>
class InputShardAccess : public Packet<T> {};

// Returned for an input side packet (e.g. kDelta(cc)).
template <typename T>
class InputSidePacketAccess : public Packet<T> {};

// Produced by the SinglePortAccess overload that takes both an input stream
// shard and an input side packet.
template <typename T>
class InputShardOrSideAccess : public Packet<T> {};

// Returned by the CalculatorContract overloads of internal::SinglePortAccess.
class PacketTypeAccess {};

// Variant returned by the CalculatorContract overload that also takes
// `is_stream`; it derives from PacketTypeAccess.
class PacketTypeAccessFallback : public PacketTypeAccess {};

namespace internal {
// Definitions of the CalculatorContract overloads declared earlier. They live
// here because PacketTypeAccess and PacketTypeAccessFallback only become
// complete types above.
// NOTE(review): both bodies are empty in this copy although the return types
// are non-void; flowing off the end is undefined behavior if the result is
// used. Confirm that they are meant to construct the access objects.
template <typename ValueT>
PacketTypeAccess SinglePortAccess(mediapipe::CalculatorContract* cc,
                                  PacketType* pt) {}
template <typename ValueT>
PacketTypeAccessFallback SinglePortAccess(mediapipe::CalculatorContract* cc,
                                          PacketType* pt, bool is_stream) {}
}  // namespace internal

}  // namespace api2
}  // namespace mediapipe

#endif  // MEDIAPIPE_FRAMEWORK_API2_PORT_H_