chromium/third_party/mediapipe/src/mediapipe/framework/calculator_runner.cc

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Definitions for CalculatorRunner.

#include "mediapipe/framework/calculator_runner.h"

#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {

// Prefix of the input side packet name given to the source node that feeds an
// input stream; the full name is this prefix followed by the stream name.
const char CalculatorRunner::kSourcePrefix[] = "source_for_";
// Prefix of the input side packet name given to the sink node that records an
// output stream; the full name is this prefix followed by the stream name.
const char CalculatorRunner::kSinkPrefix[] = "sink_for_";

namespace {

// Source node that replays pre-recorded stream contents.
// Inputs: none
// Outputs: 1, carrying the header and packets found in the side packet.
// Input side packets: 1, a pointer to a const
//     CalculatorRunner::StreamContents.
class CalculatorRunnerSourceCalculator : public CalculatorBase {
 public:
  // Declares one side packet holding the StreamContents pointer and one
  // output stream of any type.
  static absl::Status GetContract(CalculatorContract* cc) {
    auto& contents_side_packet = cc->InputSidePackets().Index(0);
    contents_side_packet.Set<const CalculatorRunner::StreamContents*>();
    auto& replayed_stream = cc->Outputs().Index(0);
    replayed_stream.SetAny();
    return absl::OkStatus();
  }

  // Emits everything during Open(): first the header, then every packet in
  // the order in which it is stored in the StreamContents.
  absl::Status Open(CalculatorContext* cc) override {
    const CalculatorRunner::StreamContents* const stream =
        cc->InputSidePackets()
            .Index(0)
            .Get<const CalculatorRunner::StreamContents*>();
    auto& output = cc->Outputs().Index(0);
    output.SetHeader(stream->header);
    for (const Packet& queued_packet : stream->packets) {
      output.AddPacket(queued_packet);
    }
    return absl::OkStatus();
  }

  // All of the work happens in Open(), so Process() only reports that this
  // source has nothing more to produce.
  absl::Status Process(CalculatorContext* /*cc*/) override {
    return tool::StatusStop();
  }
};
REGISTER_CALCULATOR(CalculatorRunnerSourceCalculator);

// Sink node that captures everything arriving on a stream.
// Inputs: 1, whose header and packets are copied into the side packet target.
// Outputs: none
// Input side packets: 1, a pointer to a mutable
//     CalculatorRunner::StreamContents.
class CalculatorRunnerSinkCalculator : public CalculatorBase {
 public:
  // Declares one input stream of any type and one side packet holding the
  // StreamContents pointer to write into.
  static absl::Status GetContract(CalculatorContract* cc) {
    auto& recorded_stream = cc->Inputs().Index(0);
    recorded_stream.SetAny();
    auto& contents_side_packet = cc->InputSidePackets().Index(0);
    contents_side_packet.Set<CalculatorRunner::StreamContents*>();
    return absl::OkStatus();
  }

  // Remembers where to record and stores the header of the input stream.
  absl::Status Open(CalculatorContext* cc) override {
    const Packet& side_packet = cc->InputSidePackets().Index(0);
    contents_ = side_packet.Get<CalculatorRunner::StreamContents*>();
    contents_->header = cc->Inputs().Index(0).Header();
    return absl::OkStatus();
  }

  // Appends the packet currently on the input stream to the recording.
  absl::Status Process(CalculatorContext* cc) override {
    const Packet& incoming = cc->Inputs().Index(0).Value();
    contents_->packets.push_back(incoming);
    return absl::OkStatus();
  }

 private:
  // Recording target taken from the side packet in Open(); this class never
  // deletes it.
  CalculatorRunner::StreamContents* contents_ = nullptr;
};
REGISTER_CALCULATOR(CalculatorRunnerSinkCalculator);

}  // namespace

// Constructs a runner from a complete node config.  All of the work is done
// by InitializeFromNodeConfig(); its returned status is checked with
// MEDIAPIPE_CHECK_OK rather than being reported to the caller.
CalculatorRunner::CalculatorRunner(
    const CalculatorGraphConfig::Node& node_config) {
  MEDIAPIPE_CHECK_OK(InitializeFromNodeConfig(node_config));
}

// Copies |node_config| into node_config_ and allocates the containers for
// input streams, output streams, input side packets and output side packets
// from the tag maps of the corresponding config fields.
//
// Returns an error if both external_input and input_side_packet are set, or
// if a tag map cannot be created for any of the four fields.
absl::Status CalculatorRunner::InitializeFromNodeConfig(
    const CalculatorGraphConfig::Node& node_config) {
  node_config_ = node_config;

  // The deprecated external_input field is folded into input_side_packet so
  // the rest of this class only has to look at input_side_packet.
  const bool uses_external_input = node_config_.external_input_size() > 0;
  if (uses_external_input) {
    RET_CHECK_EQ(0, node_config_.input_side_packet_size())
        << "Only one of input_side_packet or (deprecated) external_input can "
           "be set.";
    // input_side_packet is known to be empty here, so swapping moves the
    // entries over and leaves external_input empty.
    node_config_.mutable_external_input()->Swap(
        node_config_.mutable_input_side_packet());
  }

  // Input streams.
  MP_ASSIGN_OR_RETURN(auto input_stream_tags,
                      tool::TagMap::Create(node_config_.input_stream()));
  inputs_ = absl::make_unique<StreamContentsSet>(input_stream_tags);

  // Output streams.
  MP_ASSIGN_OR_RETURN(auto output_stream_tags,
                      tool::TagMap::Create(node_config_.output_stream()));
  outputs_ = absl::make_unique<StreamContentsSet>(output_stream_tags);

  // Input side packets.
  MP_ASSIGN_OR_RETURN(auto input_side_packet_tags,
                      tool::TagMap::Create(node_config_.input_side_packet()));
  input_side_packets_ = absl::make_unique<PacketSet>(input_side_packet_tags);

  // Output side packets.
  MP_ASSIGN_OR_RETURN(auto output_side_packet_tags,
                      tool::TagMap::Create(node_config_.output_side_packet()));
  output_side_packets_ =
      absl::make_unique<PacketSet>(output_side_packet_tags);

  return absl::OkStatus();
}

// Constructs a runner from a calculator type and its options only.  No
// streams or side packets are set up here; BuildGraph() reports an error
// until inputs, outputs and input side packets have been initialized.
CalculatorRunner::CalculatorRunner(const std::string& calculator_type,
                                   const CalculatorOptions& options) {
  // Have BuildGraph() log the equivalent node config for this runner.
  log_calculator_proto_ = true;
  node_config_.set_calculator(calculator_type);
  node_config_.mutable_options()->CopyFrom(options);
}

#if !defined(MEDIAPIPE_PROTO_LITE)
// Constructs a runner from a text-format CalculatorGraphConfig::Node.
// Both a parse failure and an InitializeFromNodeConfig() error are handled
// with CHECK macros instead of being returned to the caller.
CalculatorRunner::CalculatorRunner(const std::string& node_config_string) {
  CalculatorGraphConfig::Node node_config;
  ABSL_CHECK(
      proto_ns::TextFormat::ParseFromString(node_config_string, &node_config));
  MEDIAPIPE_CHECK_OK(InitializeFromNodeConfig(node_config));
}

// Constructs a runner from a calculator type, text-format options and stream
// counts.  The streams and side packets are created by SetNumInputs(),
// SetNumOutputs() and SetNumInputSidePackets(), which name them "input_<i>",
// "output_<i>" and "side_packet_<i>".  A failure to parse |options_string| is
// handled with ABSL_CHECK.
CalculatorRunner::CalculatorRunner(const std::string& calculator_type,
                                   const std::string& options_string,
                                   int num_inputs, int num_outputs,
                                   int num_side_packets) {
  node_config_.set_calculator(calculator_type);
  ABSL_CHECK(proto_ns::TextFormat::ParseFromString(
      options_string, node_config_.mutable_options()));
  SetNumInputs(num_inputs);
  SetNumOutputs(num_outputs);
  SetNumInputSidePackets(num_side_packets);
  // Reset log_calculator_proto to false, since it was set to true by
  // SetNum*() calls above.  This constructor is not deprecated but is
  // currently implemented in terms of deprecated functions.
  log_calculator_proto_ = false;
}
#endif

// Nothing to release by hand; the members clean up after themselves.
CalculatorRunner::~CalculatorRunner() = default;

void CalculatorRunner::SetNumInputs(int n) {
  tool::TagAndNameInfo info;
  for (int i = 0; i < n; ++i) {
    info.names.push_back(absl::StrCat("input_", i));
  }
  InitializeInputs(info);
}

void CalculatorRunner::SetNumOutputs(int n) {
  tool::TagAndNameInfo info;
  for (int i = 0; i < n; ++i) {
    info.names.push_back(absl::StrCat("output_", i));
  }
  InitializeOutputs(info);
}

void CalculatorRunner::SetNumInputSidePackets(int n) {
  tool::TagAndNameInfo info;
  for (int i = 0; i < n; ++i) {
    info.names.push_back(absl::StrCat("side_packet_", i));
  }
  InitializeInputSidePackets(info);
}

// Replaces the input streams of the node config with those described by
// |info| and allocates matching storage for their contents.  Must be called
// before the graph is built; that and a failure to write the config are both
// handled with CHECK macros.
void CalculatorRunner::InitializeInputs(const tool::TagAndNameInfo& info) {
  ABSL_CHECK(graph_ == nullptr);
  MEDIAPIPE_CHECK_OK(
      tool::SetFromTagAndNameInfo(info, node_config_.mutable_input_stream()));
  // This setup path makes BuildGraph() log the equivalent node config.
  log_calculator_proto_ = true;
  inputs_ = absl::make_unique<StreamContentsSet>(info);
}

// Replaces the output streams of the node config with those described by
// |info| and allocates matching storage for their contents.  Must be called
// before the graph is built; that and a failure to write the config are both
// handled with CHECK macros.
void CalculatorRunner::InitializeOutputs(const tool::TagAndNameInfo& info) {
  ABSL_CHECK(graph_ == nullptr);
  MEDIAPIPE_CHECK_OK(
      tool::SetFromTagAndNameInfo(info, node_config_.mutable_output_stream()));
  // This setup path makes BuildGraph() log the equivalent node config.
  log_calculator_proto_ = true;
  outputs_ = absl::make_unique<StreamContentsSet>(info);
}

// Replaces the input side packets of the node config with those described by
// |info| and allocates a matching packet set.  Must be called before the
// graph is built; that and a failure to write the config are both handled
// with CHECK macros.
void CalculatorRunner::InitializeInputSidePackets(
    const tool::TagAndNameInfo& info) {
  ABSL_CHECK(graph_ == nullptr);
  MEDIAPIPE_CHECK_OK(tool::SetFromTagAndNameInfo(
      info, node_config_.mutable_input_side_packet()));
  // This setup path makes BuildGraph() log the equivalent node config.
  log_calculator_proto_ = true;
  input_side_packets_ = absl::make_unique<PacketSet>(info);
}

// Forwards to GetCounter(|name|) on the counter factory of the underlying
// graph and returns its result.
//
// The graph only exists after it has been built (Run() builds it on first
// use).  Calling this earlier used to dereference a null graph_; it is now a
// CHECK failure with an explanatory message.
mediapipe::Counter* CalculatorRunner::GetCounter(const std::string& name) {
  ABSL_CHECK(graph_ != nullptr)
      << "GetCounter() requires the graph to be built; call Run() first.";
  return graph_->GetCounterFactory()->GetCounter(name);
}

std::map<std::string, int64_t> CalculatorRunner::GetCountersValues() {
  return graph_->GetCounterFactory()->GetCounterSet()->GetCountersValues();
}

// Builds the graph used to exercise the calculator, unless it already exists.
//
// The graph config contains the node described by node_config_, one
// CalculatorRunnerSourceCalculator per input stream and one
// CalculatorRunnerSinkCalculator per output stream.  Each source/sink takes
// an input side packet named kSourcePrefix/kSinkPrefix followed by the
// stream name.
//
// Returns an error if the inputs, outputs or input side packets were never
// initialized, if a stream entry cannot be parsed, or if the graph rejects
// the config.  graph_ is only set after CalculatorGraph::Initialize()
// succeeded, so a failed attempt never leaves an uninitialized graph behind
// that later calls would mistake for a finished one.
absl::Status CalculatorRunner::BuildGraph() {
  if (graph_ != nullptr) {
    // The graph was already built.
    return absl::OkStatus();
  }
  RET_CHECK(inputs_) << "The inputs were not initialized.";
  RET_CHECK(outputs_) << "The outputs were not initialized.";
  RET_CHECK(input_side_packets_)
      << "The input side packets were not initialized.";

  CalculatorGraphConfig config;
  // Add the calculator node.
  *(config.add_node()) = node_config_;

  for (int i = 0; i < node_config_.input_stream_size(); ++i) {
    std::string name;
    std::string tag;
    int index;
    MP_RETURN_IF_ERROR(tool::ParseTagIndexName(node_config_.input_stream(i),
                                               &tag, &index, &name));
    // Add a source for each input stream.
    auto* node = config.add_node();
    node->set_calculator("CalculatorRunnerSourceCalculator");
    node->add_output_stream(name);
    node->add_input_side_packet(absl::StrCat(kSourcePrefix, name));
  }
  for (int i = 0; i < node_config_.output_stream_size(); ++i) {
    std::string name;
    std::string tag;
    int index;
    MP_RETURN_IF_ERROR(tool::ParseTagIndexName(node_config_.output_stream(i),
                                               &tag, &index, &name));
    // Add a sink for each output stream.
    auto* node = config.add_node();
    node->set_calculator("CalculatorRunnerSinkCalculator");
    node->add_input_stream(name);
    node->add_input_side_packet(absl::StrCat(kSinkPrefix, name));
  }
  config.set_num_threads(1);

  // Runners set up through the SetNum*/Initialize* path log the node config
  // they ended up with, as a hint for switching to the config-based
  // constructor.
  if (log_calculator_proto_) {
#if defined(MEDIAPIPE_PROTO_LITE)
    ABSL_LOG(INFO)
        << "Please initialize CalculatorRunner using the recommended "
           "constructor:\n    CalculatorRunner runner(node_config);";
#else
    std::string config_string;
    proto_ns::TextFormat::Printer printer;
    printer.SetInitialIndentLevel(4);
    printer.PrintToString(node_config_, &config_string);
    ABSL_LOG(INFO)
        << "Please initialize CalculatorRunner using the recommended "
           "constructor:\n    CalculatorRunner runner(R\"(\n"
        << config_string << "\n    )\");";
#endif
  }

  // Initialize into a local graph first and publish it only on success.  If
  // graph_ were assigned before Initialize(), an error return would leave it
  // non-null and the early return at the top would report OK next time.
  auto graph = absl::make_unique<CalculatorGraph>();
  MP_RETURN_IF_ERROR(graph->Initialize(config));
  // graph_ is null at this point, so the swap transfers ownership to it.
  graph_.swap(graph);
  return absl::OkStatus();
}

// Builds the graph if necessary and runs it once.
//
// The graph receives one side packet per source (pointing at the contents to
// replay), the calculator's own input side packets, and one side packet per
// sink (pointing at the contents to record into).  After a successful run the
// output side packets of the calculator are copied into
// output_side_packets_.
//
// Returns the first error from building the graph, parsing a stream or side
// packet entry, running the graph, or fetching an output side packet.
absl::Status CalculatorRunner::Run() {
  MP_RETURN_IF_ERROR(BuildGraph());

  // Everything handed to the graph, keyed by side packet name.
  std::map<std::string, Packet> input_side_packets;

  // Sources: point each one at the contents of its input stream.
  int positional_index = -1;
  for (const auto& stream_spec : node_config_.input_stream()) {
    std::string tag;
    std::string name;
    int index;
    MP_RETURN_IF_ERROR(
        tool::ParseTagIndexName(stream_spec, &tag, &index, &name));
    // An index of -1 is resolved to the next positional slot.  The single
    // running counter relies on the tag being empty, which is always the case
    // when index == -1.  If indices are ever supported for non-empty tags
    // ("ABC:input1" and "ABC:input2" with automatic indices), this should be
    // changed to use a map instead.
    const int slot = (index == -1) ? ++positional_index : index;
    const CalculatorRunner::StreamContents* contents =
        &inputs_->Get(tag, slot);
    input_side_packets.emplace(absl::StrCat(kSourcePrefix, name),
                               Adopt(new auto(contents)));
  }

  // Calculator under test: pass its input side packets through under the
  // names given in the node config.
  positional_index = -1;
  for (const auto& side_packet_spec : node_config_.input_side_packet()) {
    std::string tag;
    std::string name;
    int index;
    MP_RETURN_IF_ERROR(
        tool::ParseTagIndexName(side_packet_spec, &tag, &index, &name));
    const int slot = (index == -1) ? ++positional_index : index;
    input_side_packets.emplace(name, input_side_packets_->Get(tag, slot));
  }

  // Sinks: point each one at the contents it should record into.
  positional_index = -1;
  for (const auto& stream_spec : node_config_.output_stream()) {
    std::string tag;
    std::string name;
    int index;
    MP_RETURN_IF_ERROR(
        tool::ParseTagIndexName(stream_spec, &tag, &index, &name));
    const int slot = (index == -1) ? ++positional_index : index;
    CalculatorRunner::StreamContents* contents = &outputs_->Get(tag, slot);
    // Start from empty contents because Run() may be called multiple times.
    *contents = CalculatorRunner::StreamContents();
    input_side_packets.emplace(absl::StrCat(kSinkPrefix, name),
                               Adopt(new auto(contents)));
  }

  MP_RETURN_IF_ERROR(graph_->Run(input_side_packets));

  // Collect the output side packets produced by the run.
  positional_index = -1;
  for (const auto& side_packet_spec : node_config_.output_side_packet()) {
    std::string tag;
    std::string name;
    int index;
    MP_RETURN_IF_ERROR(
        tool::ParseTagIndexName(side_packet_spec, &tag, &index, &name));
    const int slot = (index == -1) ? ++positional_index : index;
    Packet& side_packet = output_side_packets_->Get(tag, slot);
    MP_ASSIGN_OR_RETURN(side_packet, graph_->GetOutputSidePacket(name));
  }
  return absl::OkStatus();
}

}  // namespace mediapipe