// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "absl/log/absl_check.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_options.pb.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/container_util.h"
#include "mediapipe/framework/tool/name_util.h"
#include "mediapipe/framework/tool/subgraph_expansion.h"
#include "mediapipe/framework/tool/switch_container.pb.h"
namespace mediapipe {
namespace tool {
using mediapipe::SwitchContainerOptions;
// A graph factory producing a CalculatorGraphConfig routing packets to
// one of several contained CalculatorGraphConfigs.
//
// Usage example:
//
//   node {
//     calculator: "SwitchContainer"
//     input_stream: "ENABLE:enable"
//     input_stream: "INPUT_VIDEO:video_frames"
//     output_stream: "OUTPUT_VIDEO:output_frames"
//     options {
//       [mediapipe.SwitchContainerOptions.ext] {
//         contained_node: { calculator: "BasicSubgraph" }
//         contained_node: { calculator: "AdvancedSubgraph" }
//       }
//     }
//   }
//
// Note that the input and output stream tags supplied to the container node
// must match the input and output stream tags required by the contained nodes,
// such as "INPUT_VIDEO" and "OUTPUT_VIDEO" in the example above.
//
// Input stream "ENABLE" specifies routing of packets to either contained_node 0
// or contained_node 1, given "ENABLE:false" or "ENABLE:true" respectively.
// Input side packets "ENABLE" and "SELECT", and input stream "SELECT", can also
// be used similarly to specify the active channel. The container node may
// supply at most one of "ENABLE" and "SELECT" as input streams, and at most
// one of them as input side packets; likewise at most one of the options
// `enable` and `select` may be set.
//
// Note that this container defaults to use ImmediateInputStreamHandler,
// which can be used to accept infrequent "enable" packets asynchronously.
// However, it can be overridden to work with DefaultInputStreamHandler,
// which can be used to accept frequent "enable" packets synchronously.
// Separately, the PacketSequencerCalculator nodes inserted for "ENABLE" /
// "SELECT" input streams are given DefaultInputStreamHandler unless the
// `async_selection` option is set.
class SwitchContainer : public Subgraph {
public:
SwitchContainer() = default;
// Builds the expanded graph for the container node described by `options`.
absl::StatusOr<CalculatorGraphConfig> GetConfig(
const Subgraph::SubgraphOptions& options) override;
};
// Registers the SwitchContainer graph factory with MediaPipe's graph registry.
REGISTER_MEDIAPIPE_GRAPH(SwitchContainer);
// A (tag, index) pair identifying one stream or side packet of a node,
// e.g. {"SELECT", 0}.
using TagIndex = std::pair<std::string, int>;
// Returns the per-channel stream name that links the demux/mux to one
// contained node: "c", the channel number, "__", then the stream name.
// For example, stream "frame" on channel 1 becomes "c1__frame".
std::string ChannelName(const std::string& name, int channel) {
  std::string channel_name = "c" + std::to_string(channel);
  channel_name += "__";
  channel_name += name;
  return channel_name;
}
// Appends a new SwitchDemuxCalculator node to `config` and returns a pointer
// to it. Only the calculator name is set here; streams and options are added
// by the caller. `input_tags` and `container_node` are currently unused.
CalculatorGraphConfig::Node* BuildDemuxNode(
    const std::map<TagIndex, std::string>& input_tags,
    const CalculatorGraphConfig::Node& container_node,
    CalculatorGraphConfig* config) {
  CalculatorGraphConfig::Node* demux_node = config->add_node();
  demux_node->set_calculator("SwitchDemuxCalculator");
  return demux_node;
}
// Appends a new SwitchMuxCalculator node to `config` and returns a pointer to
// it. Only the calculator name is set here; streams and options are added by
// the caller. `output_tags` is currently unused.
CalculatorGraphConfig::Node* BuildMuxNode(
    const std::map<TagIndex, std::string>& output_tags,
    CalculatorGraphConfig* config) {
  CalculatorGraphConfig::Node* mux_node = config->add_node();
  mux_node->set_calculator("SwitchMuxCalculator");
  return mux_node;
}
// Appends a new PacketSequencerCalculator node to `config` and returns a
// pointer to it. When `async_selection` is false the node is explicitly given
// DefaultInputStreamHandler; otherwise its input stream handler is left unset.
CalculatorGraphConfig::Node* BuildTimestampNode(CalculatorGraphConfig* config,
                                                bool async_selection) {
  CalculatorGraphConfig::Node* sequencer = config->add_node();
  sequencer->set_calculator("PacketSequencerCalculator");
  if (async_selection) return sequencer;
  sequencer->mutable_input_stream_handler()->set_input_stream_handler(
      "DefaultInputStreamHandler");
  return sequencer;
}
// Copies the options of `source` into `dest`: the `node_options` list always
// replaces dest's list, while the legacy `options` message is copied only when
// `source` has one.
void CopyOptions(const CalculatorGraphConfig::Node& source,
                 CalculatorGraphConfig::Node* dest) {
  dest->mutable_node_options()->CopyFrom(source.node_options());
  if (!source.has_options()) return;
  dest->mutable_options()->CopyFrom(source.options());
}
// Removes the fields of SwitchContainerOptions that the container consumes
// itself (the `contained_node` list), so that copies of these options handed
// to the demux/mux do not carry them.
void ClearContainerOptions(SwitchContainerOptions* result) {
  result->mutable_contained_node()->Clear();
}
// Removes container-only fields from every SwitchContainerOptions attached to
// `dest`, whether it is carried as the legacy `options` extension or as a
// google::protobuf::Any in `node_options`. Other options are left untouched.
void ClearContainerOptions(CalculatorGraphConfig::Node* dest) {
  if (dest->has_options() &&
      dest->options().HasExtension(SwitchContainerOptions::ext)) {
    ClearContainerOptions(
        dest->mutable_options()->MutableExtension(SwitchContainerOptions::ext));
  }
  for (google::protobuf::Any& a : *dest->mutable_node_options()) {
    if (!a.Is<SwitchContainerOptions>()) continue;
    SwitchContainerOptions extension;
    // Only rewrite the Any when its payload decodes cleanly. Re-packing after
    // a failed (possibly partial) parse would silently replace the original
    // payload with incomplete options instead of letting the consumer report
    // the malformed data.
    if (!a.UnpackTo(&extension)) continue;
    ClearContainerOptions(&extension);
    a.PackFrom(extension);
  }
}
// Returns `name` if it is not already in `names`; otherwise returns the first
// of "name_2", "name_3", ... that is not in `names`. The returned name is
// inserted into `names`. `names` must be non-null.
std::string UniqueName(std::string name, std::set<std::string>* names) {
  ABSL_CHECK(names != nullptr);
  std::string candidate = name;
  for (int n = 2; names->find(candidate) != names->end(); ++n) {
    candidate = absl::StrCat(name, "_", n);
  }
  names->insert(candidate);
  return candidate;
}
// Parses each stream identifier in `streams` (e.g. "TAG:name") into `result`,
// keyed by its (tag, index). Names are de-duplicated with UniqueName.
// Identifiers for which ParseTagIndexFromStream reports index -1 are assigned
// sequential indices 0, 1, ... in order of appearance. If two identifiers map
// to the same key, the first one wins (map insertion does not overwrite).
void ParseTags(const proto_ns::RepeatedPtrField<std::string>& streams,
               std::map<TagIndex, std::string>* result) {
  ABSL_CHECK(result != nullptr);
  std::set<std::string> seen_names;
  int next_index = 0;
  for (int i = 0; i < streams.size(); ++i) {
    const std::string& stream = streams.Get(i);
    std::string unique_name =
        UniqueName(ParseNameFromStream(stream), &seen_names);
    TagIndex key = ParseTagIndexFromStream(stream);
    if (key.second == -1) key.second = next_index++;
    result->emplace(key, unique_name);
  }
}
// Removes the entry for a tag and index from a map.
void EraseTag(const std::string& stream,
std::map<TagIndex, std::string>* streams) {
ABSL_CHECK(streams != nullptr);
streams->erase(ParseTagIndexFromStream(absl::StrCat(stream, ":u")));
}
// Removes the entry for a tag and index from a list.
void EraseTag(const std::string& stream,
proto_ns::RepeatedPtrField<std::string>* streams) {
ABSL_CHECK(streams != nullptr);
TagIndex stream_tag = ParseTagIndexFromStream(absl::StrCat(stream, ":u"));
for (int i = streams->size() - 1; i >= 0; --i) {
TagIndex tag = ParseTagIndexFromStream(streams->at(i));
if (tag == stream_tag) {
streams->erase(streams->begin() + i);
}
}
}
// Copies the input/output streams and side packets of `node` into `result`,
// then removes the "ENABLE" and "SELECT" control entries from the input
// streams and input side packets. `result` must be non-null.
void GetContainerNodeStreams(const CalculatorGraphConfig::Node& node,
                             CalculatorGraphConfig::Node* result) {
  ABSL_CHECK(result != nullptr);
  result->mutable_input_stream()->CopyFrom(node.input_stream());
  result->mutable_output_stream()->CopyFrom(node.output_stream());
  result->mutable_input_side_packet()->CopyFrom(node.input_side_packet());
  result->mutable_output_side_packet()->CopyFrom(node.output_side_packet());
  for (const char* control_tag : {"ENABLE", "SELECT"}) {
    EraseTag(control_tag, result->mutable_input_stream());
    EraseTag(control_tag, result->mutable_input_side_packet());
  }
}
// Checks that the container does not request two channel selectors at once:
// the options may set at most one of `enable`/`select`, and the input side
// packets and the input streams may each carry at most one of the
// "ENABLE"/"SELECT" tags. Returns InvalidArgumentError on violation.
absl::Status ValidateContract(
    const CalculatorGraphConfig::Node& subgraph_node,
    const Subgraph::SubgraphOptions& subgraph_options) {
  const auto options =
      Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(subgraph_options);
  std::map<TagIndex, std::string> stream_tags;
  std::map<TagIndex, std::string> side_packet_tags;
  ParseTags(subgraph_node.input_stream(), &stream_tags);
  ParseTags(subgraph_node.input_side_packet(), &side_packet_tags);

  if (options.has_select() && options.has_enable()) {
    return absl::InvalidArgumentError(
        "Only one of SwitchContainer options 'enable' and 'select' can be "
        "specified");
  }

  auto selector_count = [](const std::map<TagIndex, std::string>& tags) {
    return tags.count({"SELECT", 0}) + tags.count({"ENABLE", 0});
  };
  const bool too_many_selectors =
      selector_count(side_packet_tags) > 1 || selector_count(stream_tags) > 1;
  if (too_many_selectors) {
    return absl::InvalidArgumentError(
        "Only one of SwitchContainer inputs 'ENABLE' and 'SELECT' can be "
        "specified");
  }
  return absl::OkStatus();
}
// Returns true if, after parsing `streams` with ParseTags, there is an entry
// with tag `tag` at index 0.
bool HasTag(const proto_ns::RepeatedPtrField<std::string>& streams,
            std::string tag) {
  std::map<TagIndex, std::string> parsed;
  ParseTags(streams, &parsed);
  const TagIndex wanted(std::move(tag), 0);
  return parsed.find(wanted) != parsed.end();
}
// Returns true if any "TAG:index" entry of `tags`, parsed with ParseTagIndex,
// equals `item`.
bool ContainsTag(const proto_ns::RepeatedPtrField<std::string>& tags,
                 TagIndex item) {
  bool found = false;
  for (int i = 0; i < tags.size() && !found; ++i) {
    found = (ParseTagIndex(tags.Get(i)) == item);
  }
  return found;
}
// Expands the container node described by `options` into a graph containing:
//  - a PacketSequencerCalculator for each "SELECT" / "ENABLE" input stream
//    present on the container node,
//  - one SwitchDemuxCalculator that fans every container input stream and
//    input side packet out to per-channel copies named by ChannelName(),
//  - one node per `contained_node`, wired to its channel's copies, and
//  - one SwitchMuxCalculator that gathers the per-channel outputs back into
//    the container's output streams and output side packets.
// The control inputs are exposed as graph inputs "gate_select"/"gate_enable".
// Returns the error from ValidateContract if the container node is invalid.
absl::StatusOr<CalculatorGraphConfig> SwitchContainer::GetConfig(
    const Subgraph::SubgraphOptions& options) {
  CalculatorGraphConfig config;

  // Parse all input and output tags from the container node, excluding the
  // "ENABLE"/"SELECT" control inputs.
  auto container_node = Subgraph::GetNode(options);
  MP_RETURN_IF_ERROR(ValidateContract(container_node, options));
  CalculatorGraphConfig::Node container_streams;
  GetContainerNodeStreams(container_node, &container_streams);
  std::map<TagIndex, std::string> input_tags, output_tags;
  std::map<TagIndex, std::string> side_input_tags, side_output_tags;
  ParseTags(container_streams.input_stream(), &input_tags);
  ParseTags(container_streams.output_stream(), &output_tags);
  ParseTags(container_streams.input_side_packet(), &side_input_tags);
  ParseTags(container_streams.output_side_packet(), &side_output_tags);

  CalculatorGraphConfig::Node* select_node = nullptr;
  CalculatorGraphConfig::Node* enable_node = nullptr;
  std::string select_stream = "SELECT:gate_select";
  std::string enable_stream = "ENABLE:gate_enable";

  // Add a PacketSequencerCalculator node for "SELECT" or "ENABLE" streams.
  const auto& switch_options =
      Subgraph::GetOptions<mediapipe::SwitchContainerOptions>(options);
  const bool async_selection = switch_options.async_selection();
  if (HasTag(container_node.input_stream(), "SELECT")) {
    select_node = BuildTimestampNode(&config, async_selection);
    select_node->add_input_stream("INPUT:gate_select");
    select_node->add_output_stream("OUTPUT:gate_select_timed");
    select_stream = "SELECT:gate_select_timed";
  }
  if (HasTag(container_node.input_stream(), "ENABLE")) {
    enable_node = BuildTimestampNode(&config, async_selection);
    enable_node->add_input_stream("INPUT:gate_enable");
    enable_node->add_output_stream("OUTPUT:gate_enable_timed");
    enable_stream = "ENABLE:gate_enable_timed";
  }

  // Add a graph node for the demux, mux.
  auto demux = BuildDemuxNode(input_tags, container_node, &config);
  CopyOptions(container_node, demux);
  ClearContainerOptions(demux);
  demux->add_input_stream(select_stream);
  demux->add_input_stream(enable_stream);
  demux->add_input_side_packet("SELECT:gate_select");
  demux->add_input_side_packet("ENABLE:gate_enable");

  auto mux = BuildMuxNode(output_tags, &config);
  CopyOptions(container_node, mux);
  ClearContainerOptions(mux);
  mux->add_input_stream(select_stream);
  mux->add_input_stream(enable_stream);
  mux->add_input_side_packet("SELECT:gate_select");
  mux->add_input_side_packet("ENABLE:gate_enable");

  // Add input streams for graph and demux.
  config.add_input_stream("SELECT:gate_select");
  config.add_input_stream("ENABLE:gate_enable");
  config.add_input_side_packet("SELECT:gate_select");
  config.add_input_side_packet("ENABLE:gate_enable");
  for (const auto& p : input_tags) {
    std::string stream = CatStream(p.first, p.second);
    config.add_input_stream(stream);
    demux->add_input_stream(stream);
  }

  // Add input streams for the timestamper: the streams listed in
  // `tick_input_stream`, or every input stream when that list is empty.
  const auto& tick_streams = switch_options.tick_input_stream();
  int tick_index = 0;
  for (const auto& p : input_tags) {
    if (!tick_streams.empty() && !ContainsTag(tick_streams, p.first)) continue;
    TagIndex tick_tag{"TICK", tick_index++};
    if (select_node) {
      select_node->add_input_stream(CatStream(tick_tag, p.second));
    }
    if (enable_node) {
      enable_node->add_input_stream(CatStream(tick_tag, p.second));
    }
  }

  // Add output streams for graph and mux.
  for (const auto& p : output_tags) {
    std::string stream = CatStream(p.first, p.second);
    config.add_output_stream(stream);
    mux->add_output_stream(stream);
  }
  for (const auto& p : side_input_tags) {
    std::string side = CatStream(p.first, p.second);
    config.add_input_side_packet(side);
    demux->add_input_side_packet(side);
  }
  for (const auto& p : side_output_tags) {
    std::string side = CatStream(p.first, p.second);
    config.add_output_side_packet(side);
    mux->add_output_side_packet(side);
  }

  // Add a subnode for each contained_node and connect it to demux and mux.
  // Every channel exposes exactly the container's non-control streams, so the
  // tag maps parsed above are reused for each channel rather than copying and
  // re-parsing the container streams per channel.
  for (int channel = 0; channel < switch_options.contained_node_size();
       ++channel) {
    CalculatorGraphConfig::Node* subnode = config.add_node();
    *subnode = switch_options.contained_node(channel);

    // Connect each contained graph node input to a demux output.
    for (const auto& p : input_tags) {
      const TagIndex& tag_index = p.first;
      const std::string tag = ChannelTag(tag_index.first, channel);
      const std::string name = ChannelName(p.second, channel);
      demux->add_output_stream(CatStream({tag, tag_index.second}, name));
      subnode->add_input_stream(CatStream(tag_index, name));
    }
    // Connect each contained graph node output to a mux input.
    for (const auto& p : output_tags) {
      const TagIndex& tag_index = p.first;
      const std::string tag = ChannelTag(tag_index.first, channel);
      const std::string name = ChannelName(p.second, channel);
      subnode->add_output_stream(CatStream(tag_index, name));
      mux->add_input_stream(CatStream({tag, tag_index.second}, name));
    }
    // Connect each contained graph node side-input to a demux side-output.
    for (const auto& p : side_input_tags) {
      const TagIndex& tag_index = p.first;
      const std::string tag = ChannelTag(tag_index.first, channel);
      const std::string name = ChannelName(p.second, channel);
      demux->add_output_side_packet(CatStream({tag, tag_index.second}, name));
      subnode->add_input_side_packet(CatStream(tag_index, name));
    }
    // Connect each contained graph node side-output to a mux side-input.
    for (const auto& p : side_output_tags) {
      const TagIndex& tag_index = p.first;
      const std::string tag = ChannelTag(tag_index.first, channel);
      const std::string name = ChannelName(p.second, channel);
      subnode->add_output_side_packet(CatStream(tag_index, name));
      mux->add_input_side_packet(CatStream({tag, tag_index.second}, name));
    }
  }
  return config;
}
} // namespace tool
} // namespace mediapipe