chromium/third_party/mediapipe/src/mediapipe/framework/calculator_base.h

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Defines CalculatorBase, the base class for feature computation.

#ifndef MEDIAPIPE_FRAMEWORK_CALCULATOR_BASE_H_
#define MEDIAPIPE_FRAMEWORK_CALCULATOR_BASE_H_

#include <memory>
#include <string>
#include <type_traits>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/deps/registration.h"
#include "mediapipe/framework/port.h"
#include "mediapipe/framework/timestamp.h"

namespace mediapipe {

// Experimental: CalculatorBase will eventually replace Calculator as the
// base class of leaf (non-subgraph) nodes in a CalculatorGraph.
//
// The base calculator class.  A subclass must, at a minimum, provide the
// implementation of GetContract(), Process(), and register the calculator
// using REGISTER_CALCULATOR(MyClass).
//
// The framework calls four primary functions on a calculator.
// On initialization of the graph, a static function is called.
//   GetContract()
// Then, for each run of the graph on a set of input side packets, the
// following sequence will occur.
//   Open()
//   Process() (repeatedly)
//   Close()
//
// The entire calculator is constructed and destroyed for each graph run
// (set of input side packets, which could mean once per video, or once
// per image).  Any expensive operations and large objects should be
// input side packets.
//
// The framework calls Open() to initialize the calculator.
// If appropriate, Open() should call cc->SetOffset() or
// cc->Outputs().Get(id)->SetNextTimestampBound() to allow the framework to
// better optimize packet queueing.
//
// The framework calls Process() for every packet received on the input
// streams.  The framework guarantees that cc->InputTimestamp() will
// increase with every call to Process().  An empty packet will be on the
// input stream if there is no packet on a particular input stream (but
// some other input stream has a packet).
//
// The framework calls Close() after all calls to Process().
//
// Calculators with no inputs are referred to as "sources" and are handled
// slightly differently than non-sources (see the function comments for
// Process() for more details).
//
// Calculators must be thread-compatible.
// The framework does not call the non-const methods of a calculator from
// multiple threads at the same time.  However, the thread that calls the
// methods of a calculator is not fixed.  Therefore, calculators should not
// use ThreadLocal objects.
class CalculatorBase {};

// Forward declaration only: lets CalculatorBaseFactoryFor exclude
// api2::Node subclasses without pulling in the api2 headers.
namespace api2 { class Node; }  // namespace api2

namespace internal {

// Gives access to the static functions within subclasses of CalculatorBase.
// This adds functionality akin to virtual static functions.
class CalculatorBaseFactory {};

// Functions for checking that the calculator has the required GetContract.
template <class T>
constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) {}
template <class T>
constexpr bool CalculatorHasGetContract(...) {}

// Provides access to the static functions within a specific subclass
// of CalculatorBase.
// Primary template: used for any type that none of the partial
// specializations match. The specialization below covers
// CalculatorBase-derived types that are not api2::Node. Other
// specializations may exist elsewhere — not visible from here.
template <typename CalculatorT, typename = void>
class CalculatorBaseFactoryFor : public CalculatorBaseFactory {};

// Factory for classic (non-api2) calculators: any T derived from
// CalculatorBase but not from api2::Node.
template <class T>
class CalculatorBaseFactoryFor<
    T, typename std::enable_if<
           std::is_base_of<mediapipe::CalculatorBase, T>::value &&
           !std::is_base_of<mediapipe::api2::Node, T>::value>::type>
    : public CalculatorBaseFactory {
 public:
  // Rejects, at instantiation time, calculators whose static GetContract is
  // missing or has the wrong signature.
  static_assert(CalculatorHasGetContract<T>(nullptr),
                "GetContract() must be defined with the correct signature in "
                "every calculator.");

  // Exposes T's static GetContract() through a virtual call. T itself is
  // required to define it; the base class provides no implementation.
  absl::Status GetContract(CalculatorContract* contract) final {
    return T::GetContract(contract);
  }

  // Produces a default-constructed T. The context argument is not used.
  std::unique_ptr<CalculatorBase> CreateCalculator(
      CalculatorContext* /*calculator_context*/) final {
    std::unique_ptr<CalculatorBase> instance = absl::make_unique<T>();
    return instance;
  }
};

}  // namespace internal

// Global registry mapping calculator names to factories that build them.
using CalculatorBaseRegistry = GlobalFactoryRegistry<
    std::unique_ptr<::mediapipe::internal::CalculatorBaseFactory>>;

}  // namespace mediapipe

#endif  // MEDIAPIPE_FRAMEWORK_CALCULATOR_BASE_H_