chromium/third_party/mediapipe/src/mediapipe/framework/tool/switch/graph_processor.cc

#include "mediapipe/framework/tool/switch/graph_processor.h"

#include "absl/synchronization/mutex.h"

namespace mediapipe {

// TODO: add support for input and output side packets.
absl::Status GraphProcessor::Initialize(CalculatorGraphConfig graph_config) {
  graph_config_ = graph_config;

  MP_ASSIGN_OR_RETURN(graph_input_map_,
                      tool::TagMap::Create(graph_config_.input_stream()));
  MP_ASSIGN_OR_RETURN(graph_output_map_,
                      tool::TagMap::Create(graph_config_.output_stream()));
  return absl::OkStatus();
}

absl::Status GraphProcessor::AddPacket(CollectionItemId id, Packet packet) {
  absl::MutexLock lock(&graph_mutex_);
  const std::string& stream_name = graph_input_map_->Names().at(id.value());
  return graph_->AddPacketToInputStream(stream_name, packet);
}

std::shared_ptr<tool::TagMap> GraphProcessor::InputTags() {
  return graph_input_map_;
}

absl::Status GraphProcessor::SendPacket(CollectionItemId id, Packet packet) {
  MP_RETURN_IF_ERROR(WaitUntilInitialized());
  auto it = consumer_ids_.find(id);
  if (it == consumer_ids_.end()) {
    return absl::NotFoundError(
        absl::StrCat("Consumer stream not found: ", id.value()));
  }
  return consumer_->AddPacket(it->second, packet);
}

void GraphProcessor::SetConsumer(PacketConsumer* consumer) {
  absl::MutexLock lock(&graph_mutex_);
  consumer_ = consumer;
  auto input_map = consumer_->InputTags();
  for (auto id = input_map->BeginId(); id != input_map->EndId(); ++id) {
    auto tag_index = input_map->TagAndIndexFromId(id);
    auto stream_id = graph_input_map_->GetId(tag_index.first, tag_index.second);
    consumer_ids_[stream_id] = id;
  }
}

absl::Status GraphProcessor::ObserveGraph() {
  for (auto id = graph_output_map_->BeginId(); id != graph_output_map_->EndId();
       ++id) {
    std::string stream_name = graph_output_map_->Names().at(id.value());
    MP_RETURN_IF_ERROR(graph_->ObserveOutputStream(
        stream_name,
        [this, id](const Packet& packet) { return SendPacket(id, packet); },
        true));
  }
  return absl::OkStatus();
}

absl::Status GraphProcessor::WaitUntilInitialized() {
  absl::MutexLock lock(&graph_mutex_);
  auto is_initialized = [this]() ABSL_SHARED_LOCKS_REQUIRED(graph_mutex_) {
    return graph_ != nullptr && consumer_ != nullptr;
  };
  graph_mutex_.AwaitWithTimeout(absl::Condition(&is_initialized),
                                absl::Seconds(4));
  RET_CHECK(is_initialized()) << "GraphProcessor initialization timed out.";
  return absl::OkStatus();
}

absl::Status GraphProcessor::Start() {
  absl::MutexLock lock(&graph_mutex_);
  graph_ = std::make_unique<CalculatorGraph>();

  // The graph is validated here with its specified inputs and output.
  MP_RETURN_IF_ERROR(graph_->Initialize(graph_config_, side_packets_));
  MP_RETURN_IF_ERROR(ObserveGraph());
  MP_RETURN_IF_ERROR(graph_->StartRun({}));
  return absl::OkStatus();
}

absl::Status GraphProcessor::Shutdown() {
  absl::MutexLock lock(&graph_mutex_);
  if (!graph_) {
    return absl::OkStatus();
  }
  MP_RETURN_IF_ERROR(graph_->CloseAllPacketSources());
  MP_RETURN_IF_ERROR(graph_->WaitUntilDone());
  graph_ = nullptr;
  return absl::OkStatus();
}

absl::Status GraphProcessor::WaitUntilIdle() {
  absl::MutexLock lock(&graph_mutex_);
  return graph_->WaitUntilIdle();
}

// TODO
absl::Status GraphProcessor::SetSidePacket(CollectionItemId id, Packet packet) {
  return absl::OkStatus();
}
// TODO
std::shared_ptr<tool::TagMap> GraphProcessor::SideInputTags() {
  return nullptr;
}
// TODO
void GraphProcessor::SetSideConsumer(SidePacketConsumer* consumer) {}

}  // namespace mediapipe