#include "mediapipe/framework/calculator_graph.h"
#include <stdio.h>
#include <algorithm>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/log/check.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/counter_factory.h"
#include "mediapipe/framework/delegating_executor.h"
#include "mediapipe/framework/executor.h"
#include "mediapipe/framework/graph_output_stream.h"
#include "mediapipe/framework/graph_service_manager.h"
#include "mediapipe/framework/input_stream_manager.h"
#include "mediapipe/framework/mediapipe_profiling.h"
#include "mediapipe/framework/output_side_packet_impl.h"
#include "mediapipe/framework/output_stream_manager.h"
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/packet_generator.h"
#include "mediapipe/framework/packet_generator.pb.h"
#include "mediapipe/framework/packet_set.h"
#include "mediapipe/framework/packet_type.h"
#include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/core_proto_inc.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/map_util.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/source_location.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_builder.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/scheduler.h"
#include "mediapipe/framework/status_handler.h"
#include "mediapipe/framework/status_handler.pb.h"
#include "mediapipe/framework/thread_pool_executor.h"
#include "mediapipe/framework/thread_pool_executor.pb.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/fill_packet_set.h"
#include "mediapipe/framework/tool/status_util.h"
#include "mediapipe/framework/tool/tag_map.h"
#include "mediapipe/framework/tool/validate.h"
#include "mediapipe/framework/tool/validate_name.h"
#include "mediapipe/framework/validated_graph_config.h"
#include "mediapipe/gpu/gpu_service.h"
#include "mediapipe/gpu/graph_support.h"
#include "mediapipe/util/cpu_util.h"
namespace mediapipe {
namespace {
constexpr int kMaxNumAccumulatedErrors = …;
constexpr char kApplicationThreadExecutorType[] = …;
constexpr absl::StatusToStringMode kStatusLogFlags = …;
}
void CalculatorGraph::ScheduleAllOpenableNodes() { … }
void CalculatorGraph::GraphInputStream::SetHeader(const Packet& header) { … }
void CalculatorGraph::GraphInputStream::SetNextTimestampBound(
Timestamp timestamp) { … }
void CalculatorGraph::GraphInputStream::PropagateUpdatesToMirrors() { … }
void CalculatorGraph::GraphInputStream::Close() { … }
CalculatorGraph::CalculatorGraph(
std::shared_ptr<GraphServiceManager> service_manager)
: … { … }
CalculatorGraph::CalculatorGraph()
: … { … }
CalculatorGraph::CalculatorGraph(CalculatorContext* cc)
: … { … }
CalculatorGraph::CalculatorGraph(CalculatorGraphConfig config)
: … { … }
CalculatorGraph::~CalculatorGraph() { … }
absl::Status CalculatorGraph::InitializePacketGeneratorGraph(
const std::map<std::string, Packet>& side_packets) { … }
absl::Status CalculatorGraph::InitializeStreams() { … }
static void MaybeFixupLegacyGpuNodeContract(CalculatorNode& node) { … }
absl::Status CalculatorGraph::InitializeCalculatorNodes() { … }
absl::Status CalculatorGraph::InitializePacketGeneratorNodes(
const std::vector<int>& non_scheduled_generators) { … }
absl::Status CalculatorGraph::InitializeProfiler() { … }
absl::Status CalculatorGraph::InitializeExecutors() { … }
absl::Status CalculatorGraph::InitializeDefaultExecutor(
const ThreadPoolExecutorOptions* default_executor_options,
bool use_application_thread) { … }
absl::Status CalculatorGraph::Initialize(
std::unique_ptr<ValidatedGraphConfig> validated_graph,
const std::map<std::string, Packet>& side_packets) { … }
absl::Status CalculatorGraph::Initialize(CalculatorGraphConfig input_config) { … }
absl::Status CalculatorGraph::Initialize(
CalculatorGraphConfig input_config,
const std::map<std::string, Packet>& side_packets) { … }
absl::Status CalculatorGraph::Initialize(
const std::vector<CalculatorGraphConfig>& input_configs,
const std::vector<CalculatorGraphTemplate>& input_templates,
const std::map<std::string, Packet>& side_packets,
const std::string& graph_type, const Subgraph::SubgraphOptions* options) { … }
absl::Status CalculatorGraph::ObserveOutputStream(
const std::string& stream_name,
std::function<absl::Status(const Packet&)> packet_callback,
bool observe_timestamp_bounds) { … }
absl::Status CalculatorGraph::SetErrorCallback(
std::function<void(const absl::Status&)> error_callback) { … }
absl::StatusOr<OutputStreamPoller> CalculatorGraph::AddOutputStreamPoller(
const std::string& stream_name, bool observe_timestamp_bounds) { … }
absl::StatusOr<Packet> CalculatorGraph::GetOutputSidePacket(
const std::string& packet_name) { … }
absl::Status CalculatorGraph::Run(
const std::map<std::string, Packet>& extra_side_packets) { … }
absl::Status CalculatorGraph::StartRun(
const std::map<std::string, Packet>& extra_side_packets,
const std::map<std::string, Packet>& stream_headers) { … }
#if !MEDIAPIPE_DISABLE_GPU
// Registers `resources` as the service object for kGpuService on this graph's
// service manager.
//
// Returns an error (via RET_CHECK) when `resources` is null, or when a service
// object for kGpuService is already registered. Otherwise returns whatever
// status SetServiceObject() reports.
absl::Status CalculatorGraph::SetGpuResources(
std::shared_ptr<::mediapipe::GpuResources> resources) {
RET_CHECK_NE(resources, nullptr);
// Look up any previously registered GPU service object; a second
// registration is rejected rather than silently replacing the first.
auto gpu_service = service_manager_->GetServiceObject(kGpuService);
RET_CHECK_EQ(gpu_service, nullptr)
<< "The GPU resources have already been configured.";
// `resources` was taken by value, so ownership is moved into the manager.
return service_manager_->SetServiceObject(kGpuService, std::move(resources));
}
// Returns the object currently registered for kGpuService on the service
// manager. The result is whatever GetServiceObject() yields, which may be a
// null pointer when nothing has been registered.
std::shared_ptr<::mediapipe::GpuResources> CalculatorGraph::GetGpuResources()
    const {
  auto registered_resources = service_manager_->GetServiceObject(kGpuService);
  return registered_resources;
}
// Looks up the side packet stored under kGpuSharedSidePacketName.
//
// Returns a copy of that packet when the key is present in `side_packets`,
// and a default-constructed Packet otherwise.
static Packet GetLegacyGpuSharedSidePacket(
    const std::map<std::string, Packet>& side_packets) {
  const auto found = side_packets.find(kGpuSharedSidePacketName);
  return found != side_packets.end() ? found->second : Packet();
}
// Registers the GPU resources carried by a legacy GpuSharedData side packet
// as the kGpuService object, unless one is already registered.
//
// - An empty `legacy_sp` is a no-op and returns OK.
// - If a kGpuService object already exists, the side packet is ignored, a
//   warning is logged, and OK is returned.
// - Otherwise the `gpu_resources` member of the pointed-to GpuSharedData is
//   registered and the status of SetServiceObject() is returned.
absl::Status CalculatorGraph::MaybeSetUpGpuServiceFromLegacySidePacket(
    Packet legacy_sp) {
  if (legacy_sp.IsEmpty()) {
    return absl::OkStatus();
  }

  auto service_resources = service_manager_->GetServiceObject(kGpuService);
  if (service_resources != nullptr) {
    // The already-registered service object takes precedence.
    ABSL_LOG(WARNING)
        << "::mediapipe::GpuSharedData provided as a side packet while the "
        << "graph already had one; ignoring side packet";
    return absl::OkStatus();
  }

  // The packet holds a raw pointer to the legacy shared-data struct.
  auto* legacy_shared_data = legacy_sp.Get<::mediapipe::GpuSharedData*>();
  service_resources = legacy_shared_data->gpu_resources;
  return service_manager_->SetServiceObject(kGpuService, service_resources);
}
// Builds the extra side packets needed to expose the registered GPU resources
// through the legacy kGpuSharedSidePacketName side packet.
//
// Returns an empty map when no kGpuService object is registered, or when
// `legacy_sp` is non-empty and already refers to the same GpuResources.
// Otherwise a fresh GpuSharedData wrapping the registered resources is stored
// in `legacy_gpu_shared_`, and the returned map holds a packet with a raw
// pointer to it.
std::map<std::string, Packet> CalculatorGraph::MaybeCreateLegacyGpuSidePacket(
    Packet legacy_sp) {
  std::map<std::string, Packet> extra_packets;

  auto service_resources = service_manager_->GetServiceObject(kGpuService);
  if (!service_resources) {
    return extra_packets;
  }

  // Only inspect the packet payload when the packet is non-empty.
  const bool packet_already_matches =
      !legacy_sp.IsEmpty() &&
      legacy_sp.Get<::mediapipe::GpuSharedData*>()->gpu_resources ==
          service_resources;
  if (packet_already_matches) {
    return extra_packets;
  }

  // The graph owns the GpuSharedData; the packet only carries a raw pointer.
  legacy_gpu_shared_ =
      std::make_unique<mediapipe::GpuSharedData>(service_resources);
  extra_packets[kGpuSharedSidePacketName] =
      MakePacket<::mediapipe::GpuSharedData*>(legacy_gpu_shared_.get());
  return extra_packets;
}
// Reports whether `node`'s contract lists a service request keyed by
// kGpuService.key.
static bool UsesGpu(const CalculatorNode& node) {
  const bool requests_gpu_service =
      node.Contract().ServiceRequests().contains(kGpuService.key);
  return requests_gpu_service;
}
// Prepares GPU-using nodes and installs GPU executors, when a kGpuService
// object is registered.
//
// Returns OK immediately if no GPU resources are registered. Otherwise calls
// PrepareGpuNode() for every node for which UsesGpu() is true, then passes
// each (name, executor) pair from GetGpuExecutors() to SetExecutorInternal().
// The first non-OK status from either step is returned.
absl::Status CalculatorGraph::PrepareGpu() {
  auto resources = service_manager_->GetServiceObject(kGpuService);
  if (resources == nullptr) {
    return absl::OkStatus();
  }

  for (auto& graph_node : nodes_) {
    if (!UsesGpu(*graph_node)) {
      continue;
    }
    MP_RETURN_IF_ERROR(resources->PrepareGpuNode(graph_node.get()));
  }

  for (const auto& entry : resources->GetGpuExecutors()) {
    // entry.first is the executor name, entry.second the executor itself.
    MP_RETURN_IF_ERROR(SetExecutorInternal(entry.first, entry.second));
  }

  return absl::OkStatus();
}
#endif
absl::Status CalculatorGraph::PrepareServices() { … }
absl::Status CalculatorGraph::PrepareForRun(
const std::map<std::string, Packet>& extra_side_packets,
const std::map<std::string, Packet>& stream_headers) { … }
absl::Status CalculatorGraph::WaitUntilIdle() { … }
absl::Status CalculatorGraph::WaitUntilDone() { … }
absl::Status CalculatorGraph::WaitForObservedOutput() { … }
absl::Status CalculatorGraph::AddPacketToInputStream(
absl::string_view stream_name, const Packet& packet) { … }
absl::Status CalculatorGraph::AddPacketToInputStream(
absl::string_view stream_name, Packet&& packet) { … }
absl::Status CalculatorGraph::SetInputStreamTimestampBound(
const std::string& stream_name, Timestamp timestamp) { … }
template <typename T>
absl::Status CalculatorGraph::AddPacketToInputStreamInternal(
absl::string_view stream_name, T&& packet) { … }
absl::Status CalculatorGraph::SetInputStreamMaxQueueSize(
const std::string& stream_name, int max_queue_size) { … }
bool CalculatorGraph::HasInputStream(const std::string& stream_name) { … }
absl::Status CalculatorGraph::CloseInputStream(const std::string& stream_name) { … }
absl::Status CalculatorGraph::CloseAllInputStreams() { … }
absl::Status CalculatorGraph::CloseAllPacketSources() { … }
void CalculatorGraph::RecordError(const absl::Status& error) { … }
bool CalculatorGraph::GetCombinedErrors(absl::Status* error_status) { … }
bool CalculatorGraph::GetCombinedErrors(const std::string& error_prefix,
absl::Status* error_status) { … }
void CalculatorGraph::CallStatusHandlers(GraphRunState graph_run_state,
const absl::Status& status) { … }
int CalculatorGraph::GetMaxInputStreamQueueSize() { … }
void CalculatorGraph::UpdateThrottledNodes(InputStreamManager* stream,
bool* stream_was_full) { … }
bool CalculatorGraph::IsNodeThrottled(int node_id) { … }
bool IsGraphOutputStream(
InputStreamManager* stream,
const std::vector<std::shared_ptr<internal::GraphOutputStream>>&
graph_output_streams) { … }
bool CalculatorGraph::UnthrottleSources() { … }
CalculatorGraph::GraphInputStreamAddMode
CalculatorGraph::GetGraphInputStreamAddMode() const { … }
void CalculatorGraph::SetGraphInputStreamAddMode(GraphInputStreamAddMode mode) { … }
void CalculatorGraph::Cancel() { … }
void CalculatorGraph::Pause() { … }
void CalculatorGraph::Resume() { … }
absl::Status CalculatorGraph::SetExecutorInternal(
const std::string& name, std::shared_ptr<Executor> executor) { … }
absl::Status CalculatorGraph::SetExecutor(const std::string& name,
std::shared_ptr<Executor> executor) { … }
absl::Status CalculatorGraph::CreateDefaultThreadPool(
const ThreadPoolExecutorOptions* default_executor_options,
int num_threads) { … }
bool CalculatorGraph::IsReservedExecutorName(const std::string& name) { … }
absl::Status CalculatorGraph::FinishRun() { … }
void CalculatorGraph::CleanupAfterRun(absl::Status* status) { … }
const OutputStreamManager* CalculatorGraph::FindOutputStreamManager(
const std::string& name) { … }
std::string CalculatorGraph::ListSourceNodes() const { … }
std::string CalculatorGraph::GetParentNodeDebugName(
InputStreamManager* stream) const { … }
namespace {
void PrintTimingToInfo(const std::string& label, int64_t timer_value) { … }
bool MetricElementComparator(const std::pair<std::string, int64_t>& e1,
const std::pair<std::string, int64_t>& e2) { … }
}
absl::Status CalculatorGraph::GetCalculatorProfiles(
std::vector<CalculatorProfile>* profiles) const { … }
}