// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/log/absl_log.h"
#include "absl/strings/match.h"
#include "mediapipe/calculators/core/packet_resampler_calculator.pb.h"
#include "mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/audio_decoder.pb.h"
#include "mediapipe/util/sequence/media_sequence.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
namespace mediapipe {
// Streams:
const char kBBoxTag[] = "BBOX";
const char kImageTag[] = "IMAGE";
const char kKeypointsTag[] = "KEYPOINTS";
const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_";
const char kForwardFlowImageTag[] = "FORWARD_FLOW_ENCODED";
// Side Packets:
const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
const char kDatasetRootDirTag[] = "DATASET_ROOT";
const char kDataPath[] = "DATA_PATH";
const char kPacketResamplerOptions[] = "RESAMPLER_OPTIONS";
const char kImagesFrameRateTag[] = "IMAGE_FRAME_RATE";
const char kAudioDecoderOptions[] = "AUDIO_DECODER_OPTIONS";
namespace tf = ::tensorflow;
namespace mpms = mediapipe::mediasequence;
// Source calculator to unpack side_packets and streams from tf.SequenceExamples.
//
// Often, only side_packets or streams need to be output, but both can be output
// if needed. A tf.SequenceExample always needs to be supplied as an
// input_side_packet. The SequenceExample must be in the format described in
// media_sequence.h. This documentation will first describe the side_packets
// the calculator can output, and then describe the streams.
//
// Side_packets are commonly used to specify which clip to extract data from.
// Seeking into a video does not necessarily provide consistent timestamps when
// resampling to a known rate. To enable consistent timestamps, we unpack the
// metadata into options for the MediaDecoderCalculator and the
// PacketResamplerCalculator. To ensure consistent timestamps, the MediaDecoder
// needs to seek to slightly before the clip starts, so it sees at least one
// packet before the first packet we want to keep. The PacketResamplerCalculator
// then trims down the timestamps. Furthermore, we should always specify that we
// want timestamps from a base timestamp of 0, so we have the same resampled
// frames after a seek that we would have from the start of a video. In summary,
// when decoding image frames, output both the DECODER_OPTIONS and
// RESAMPLER_OPTIONS. In the base_media_decoder_options, specify which streams
// you want. In the base_packet_resampler_options, specify the frame_rate you
// want and set base_timestamp = 0. In this calculator's options, set
// extra_padding_from_media_decoder such that at least one frame arrives
// before the first frame the PacketResamplerCalculator should output.
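//
// As an illustration (the field values below are placeholders, not
// recommendations), the options for a padded decode might look like:
//   [mediapipe.UnpackMediaSequenceCalculatorOptions.ext]: {
//     padding_before_label: 1.0
//     padding_after_label: 1.0
//     extra_padding_from_media_decoder: 0.5
//     base_packet_resampler_options {
//       frame_rate: 24.0
//       base_timestamp: 0
//     }
//   }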
//
// Optional output_side_packets include (referenced by tag):
//   DATA_PATH: The data_path context feature joined onto the
//     options.dataset_root_directory or the DATASET_ROOT input_side_packet.
//   AUDIO_DECODER_OPTIONS: AudioDecoderOptions for decoding the clip's audio,
//     with the clip's (padded) start and end times filled in.
//   RESAMPLER_OPTIONS: CalculatorOptions to pass to the
//     PacketResamplerCalculator. The most accurate procedure for sampling a
//     range of frames is to request a padded time range from the
//     MediaDecoderCalculator and then trim it down to the proper time range
//     with the PacketResamplerCalculator.
//   IMAGE_FRAME_RATE: The frame rate of the images in the original video as a
//     double.
//
// Example config:
// node {
//   calculator: "UnpackMediaSequenceCalculator"
//   input_side_packet: "SEQUENCE_EXAMPLE:example_input_side_packet"
//   input_side_packet: "DATASET_ROOT:path_to_dataset_root_directory"
//   output_side_packet: "DATA_PATH:full_path_to_data_element"
//   output_side_packet: "RESAMPLER_OPTIONS:packet_resampler_options"
//   options {
//     [mediapipe.UnpackMediaSequenceCalculatorOptions.ext]: {
//       base_packet_resampler_options {
//         frame_rate: 1.0  # PARAM_FRAME_RATE
//         base_timestamp: 0
//       }
//     }
//   }
// }
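//
// The RESAMPLER_OPTIONS side packet can then be consumed downstream, e.g. by
// a PacketResamplerCalculator that reads its options from an OPTIONS
// input_side_packet (a sketch; the stream names are placeholders):
// node {
//   calculator: "PacketResamplerCalculator"
//   input_side_packet: "OPTIONS:packet_resampler_options"
//   input_stream: "DATA:decoded_frames"
//   output_stream: "DATA:sampled_frames"
// }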
//
// In addition to side packets, the calculator outputs the data from the
// SequenceExample as streams at the proper timestamps. The SequenceExample
// must conform to the description in media_sequence.h. Timestamps in the
// SequenceExample must be in sequential order.
//
// The following output stream tags are supported:
//   IMAGE: encoded images as strings. (IMAGE_${NAME} is supported.)
//   FORWARD_FLOW_ENCODED: encoded forward optical flow images (stored under
//     the FORWARD_FLOW prefix) as strings.
//   FLOAT_FEATURE_${NAME}: the feature named ${NAME} as vector<float>.
//   BBOX: bounding boxes as vector<Location>s. (BBOX_${NAME} is supported.)
//
// Example config:
// node {
//   calculator: "UnpackMediaSequenceCalculator"
//   input_side_packet: "SEQUENCE_EXAMPLE:example_input_side_packet"
//   output_stream: "IMAGE:frames"
//   output_stream: "FLOAT_FEATURE_FDENSE:fdense_vf"
//   output_stream: "BBOX:faces"
// }
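//
// Prefixed features are unpacked via suffixed tags. For example (the stream
// and prefix names here are illustrative), images stored under the "LEFT"
// prefix could be unpacked with:
//   output_stream: "IMAGE_LEFT:left_frames"
// The suffix after the base tag selects the correspondingly prefixed keys in
// the SequenceExample.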
class UnpackMediaSequenceCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
const auto& options = cc->Options<UnpackMediaSequenceCalculatorOptions>();
RET_CHECK(cc->InputSidePackets().HasTag(kSequenceExampleTag));
cc->InputSidePackets().Tag(kSequenceExampleTag).Set<tf::SequenceExample>();
// Optional side inputs.
if (cc->InputSidePackets().HasTag(kDatasetRootDirTag)) {
cc->InputSidePackets().Tag(kDatasetRootDirTag).Set<std::string>();
}
if (cc->OutputSidePackets().HasTag(kDataPath)) {
cc->OutputSidePackets().Tag(kDataPath).Set<std::string>();
}
if (cc->OutputSidePackets().HasTag(kAudioDecoderOptions)) {
cc->OutputSidePackets()
.Tag(kAudioDecoderOptions)
.Set<AudioDecoderOptions>();
}
if (cc->OutputSidePackets().HasTag(kImagesFrameRateTag)) {
cc->OutputSidePackets().Tag(kImagesFrameRateTag).Set<double>();
}
if (cc->OutputSidePackets().HasTag(kPacketResamplerOptions)) {
cc->OutputSidePackets()
.Tag(kPacketResamplerOptions)
.Set<CalculatorOptions>();
}
if ((options.has_padding_before_label() ||
options.has_padding_after_label()) &&
!(cc->OutputSidePackets().HasTag(kAudioDecoderOptions) ||
cc->OutputSidePackets().HasTag(kPacketResamplerOptions))) {
      return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
             << "If specifying padding, must output " << kPacketResamplerOptions
             << " or " << kAudioDecoderOptions;
}
if (cc->Outputs().HasTag(kForwardFlowImageTag)) {
cc->Outputs().Tag(kForwardFlowImageTag).Set<std::string>();
}
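    // Stream tags may carry a key suffix: e.g. a tag of IMAGE_LEFT (an
    // illustrative name) selects the image keys stored under the "LEFT"
    // prefix. Tags that extend the base tag without a '_' separator are
    // skipped.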
for (const auto& tag : cc->Outputs().GetTags()) {
if (absl::StartsWith(tag, kImageTag)) {
std::string key = "";
if (tag != kImageTag) {
int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1;
if (tag[tag_length] == '_') {
key = tag.substr(tag_length + 1);
} else {
continue; // Skip keys that don't match "(kImageTag)_?"
}
}
cc->Outputs().Tag(tag).Set<std::string>();
}
if (absl::StartsWith(tag, kBBoxTag)) {
std::string key = "";
if (tag != kBBoxTag) {
int tag_length = sizeof(kBBoxTag) / sizeof(*kBBoxTag) - 1;
if (tag[tag_length] == '_') {
key = tag.substr(tag_length + 1);
} else {
continue; // Skip keys that don't match "(kBBoxTag)_?"
}
}
cc->Outputs().Tag(tag).Set<std::vector<Location>>();
}
if (absl::StartsWith(tag, kFloatFeaturePrefixTag)) {
cc->Outputs().Tag(tag).Set<std::vector<float>>();
}
}
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override {
// Copy the packet to copy the otherwise inaccessible shared ptr.
example_packet_holder_ = cc->InputSidePackets().Tag(kSequenceExampleTag);
sequence_ = &example_packet_holder_.Get<tf::SequenceExample>();
const auto& options = cc->Options<UnpackMediaSequenceCalculatorOptions>();
    // Collect the timestamps for all streams keyed by the timestamp feature's
    // key. While creating this data structure we also identify the last
    // timestamp and the associated feature. This information is used in
    // Process() to output batches of packets in order.
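    // For example (keys and values here are illustrative), the loop below
    // might leave timestamps_ mapping "image/timestamp" to [0, 33333, 66666]
    // and "FDENSE/feature/timestamp" to [0, 66666].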
timestamps_.clear();
int64_t last_timestamp_seen = Timestamp::PreStream().Value();
first_timestamp_seen_ = Timestamp::OneOverPostStream().Value();
for (const auto& map_kv : sequence_->feature_lists().feature_list()) {
if (absl::StrContains(map_kv.first, "/timestamp")) {
ABSL_LOG(INFO) << "Found feature timestamps: " << map_kv.first
<< " with size: " << map_kv.second.feature_size();
int64_t recent_timestamp = Timestamp::PreStream().Value();
for (int i = 0; i < map_kv.second.feature_size(); ++i) {
int64_t next_timestamp =
mpms::GetInt64sAt(*sequence_, map_kv.first, i).Get(0);
RET_CHECK_GT(next_timestamp, recent_timestamp)
<< "Timestamps must be sequential. If you're seeing this message "
<< "you may have added images to the same SequenceExample twice. "
<< "Key: " << map_kv.first;
if (options.output_poststream_as_prestream() &&
next_timestamp == Timestamp::PostStream().Value()) {
RET_CHECK_EQ(i, 0)
<< "Detected PostStream() and timestamps being output for the "
<< "same stream. This is currently invalid.";
next_timestamp = Timestamp::PreStream().Value();
}
timestamps_[map_kv.first].push_back(next_timestamp);
recent_timestamp = next_timestamp;
if (recent_timestamp < first_timestamp_seen_) {
first_timestamp_seen_ = recent_timestamp;
}
}
if (recent_timestamp > last_timestamp_seen &&
recent_timestamp < Timestamp::PostStream().Value()) {
last_timestamp_key_ = map_kv.first;
last_timestamp_seen = recent_timestamp;
}
}
}
if (!timestamps_.empty()) {
for (const auto& kv : timestamps_) {
if (!kv.second.empty() &&
kv.second[0] < Timestamp::PostStream().Value()) {
// These checks only make sense if any values are not PostStream, but
// only need to be made once.
RET_CHECK(!last_timestamp_key_.empty())
<< "Something went wrong because the timestamp key is unset. "
<< "Example: " << sequence_->DebugString();
RET_CHECK_GT(last_timestamp_seen, Timestamp::PreStream().Value())
<< "Something went wrong because the last timestamp is unset. "
<< "Example: " << sequence_->DebugString();
RET_CHECK_LT(first_timestamp_seen_,
Timestamp::OneOverPostStream().Value())
<< "Something went wrong because the first timestamp is unset. "
<< "Example: " << sequence_->DebugString();
break;
}
}
}
current_timestamp_index_ = 0;
process_poststream_ = false;
// Determine the data path and output it.
const auto& sequence = cc->InputSidePackets()
.Tag(kSequenceExampleTag)
.Get<tensorflow::SequenceExample>();
if (cc->OutputSidePackets().HasTag(kDataPath)) {
std::string root_directory = "";
if (cc->InputSidePackets().HasTag(kDatasetRootDirTag)) {
root_directory =
cc->InputSidePackets().Tag(kDatasetRootDirTag).Get<std::string>();
} else if (options.has_dataset_root_directory()) {
root_directory = options.dataset_root_directory();
}
std::string data_path = mpms::GetClipDataPath(sequence);
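      // Join the root directory onto the relative data path: e.g. a root of
      // "/data/videos" and a data_path of "clip.mp4" (illustrative values)
      // yield "/data/videos/clip.mp4".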
if (!root_directory.empty()) {
if (root_directory[root_directory.size() - 1] == '/') {
data_path = root_directory + data_path;
} else {
data_path = root_directory + "/" + data_path;
}
}
cc->OutputSidePackets().Tag(kDataPath).Set(
MakePacket<std::string>(data_path));
}
// Set the start and end of the clip in the appropriate options protos.
double start_time = 0;
double end_time = 0;
if (cc->OutputSidePackets().HasTag(kAudioDecoderOptions) ||
cc->OutputSidePackets().HasTag(kPacketResamplerOptions)) {
if (mpms::HasClipStartTimestamp(sequence)) {
start_time =
Timestamp(mpms::GetClipStartTimestamp(sequence)).Seconds() -
options.padding_before_label();
}
if (mpms::HasClipEndTimestamp(sequence)) {
end_time = Timestamp(mpms::GetClipEndTimestamp(sequence)).Seconds() +
options.padding_after_label();
}
}
if (cc->OutputSidePackets().HasTag(kAudioDecoderOptions)) {
auto audio_decoder_options = absl::make_unique<AudioDecoderOptions>(
options.base_audio_decoder_options());
if (mpms::HasClipStartTimestamp(sequence)) {
if (options.force_decoding_from_start_of_media()) {
audio_decoder_options->set_start_time(0);
} else {
audio_decoder_options->set_start_time(
start_time - options.extra_padding_from_media_decoder());
}
}
if (mpms::HasClipEndTimestamp(sequence)) {
audio_decoder_options->set_end_time(
end_time + options.extra_padding_from_media_decoder());
}
ABSL_LOG(INFO) << "Created AudioDecoderOptions:\n"
<< audio_decoder_options->DebugString();
cc->OutputSidePackets()
.Tag(kAudioDecoderOptions)
.Set(Adopt(audio_decoder_options.release()));
}
if (cc->OutputSidePackets().HasTag(kPacketResamplerOptions)) {
auto resampler_options = absl::make_unique<CalculatorOptions>();
*(resampler_options->MutableExtension(
PacketResamplerCalculatorOptions::ext)) =
options.base_packet_resampler_options();
if (mpms::HasClipStartTimestamp(sequence)) {
resampler_options
->MutableExtension(PacketResamplerCalculatorOptions::ext)
->set_start_time(Timestamp::FromSeconds(start_time).Value());
}
if (mpms::HasClipEndTimestamp(sequence)) {
resampler_options
->MutableExtension(PacketResamplerCalculatorOptions::ext)
->set_end_time(Timestamp::FromSeconds(end_time).Value());
}
ABSL_LOG(INFO) << "Created PacketResamplerOptions:\n"
<< resampler_options->DebugString();
cc->OutputSidePackets()
.Tag(kPacketResamplerOptions)
.Set(Adopt(resampler_options.release()));
}
// Output the remaining side outputs.
if (cc->OutputSidePackets().HasTag(kImagesFrameRateTag)) {
cc->OutputSidePackets()
.Tag(kImagesFrameRateTag)
.Set(MakePacket<double>(mpms::GetImageFrameRate(sequence)));
}
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
if (timestamps_.empty()) {
// This occurs when we only have metadata to unpack.
      ABSL_LOG(INFO)
          << "Only unpacking metadata because there are no timestamps.";
return tool::StatusStop();
}
    // In Process(), we loop through timestamps on a reference stream and emit
    // all packets on all streams that have a timestamp between the current
    // reference timestamp and the next reference timestamp. This ensures that
    // we emit all timestamps in order, but also only emit a limited number in
    // any particular call to Process(). At the very end, we output the
    // poststream packets. If we only have poststream packets,
    // last_timestamp_key_ will be empty.
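    // For example, with reference timestamps [10, 20, 30] (illustrative
    // values), the first call emits packets in [first_timestamp_seen_, 20),
    // the next call emits [20, 30), the last emits [30, 31), and a final
    // pass handles any PostStream packets.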
int64_t start_timestamp = 0;
int64_t end_timestamp = 0;
if (last_timestamp_key_.empty() || process_poststream_) {
process_poststream_ = true;
start_timestamp = Timestamp::PostStream().Value();
end_timestamp = Timestamp::OneOverPostStream().Value();
} else {
start_timestamp =
timestamps_[last_timestamp_key_][current_timestamp_index_];
if (current_timestamp_index_ == 0) {
start_timestamp = first_timestamp_seen_;
}
end_timestamp = start_timestamp + 1; // Base case at end of sequence.
if (current_timestamp_index_ <
timestamps_[last_timestamp_key_].size() - 1) {
end_timestamp =
timestamps_[last_timestamp_key_][current_timestamp_index_ + 1];
}
}
for (const auto& map_kv : timestamps_) {
for (int i = 0; i < map_kv.second.size(); ++i) {
if (map_kv.second[i] >= start_timestamp &&
map_kv.second[i] < end_timestamp) {
Timestamp current_timestamp;
if (map_kv.second[i] == Timestamp::PostStream().Value()) {
current_timestamp = Timestamp::PostStream();
} else if (map_kv.second[i] == Timestamp::PreStream().Value()) {
current_timestamp = Timestamp::PreStream();
} else {
current_timestamp = Timestamp(map_kv.second[i]);
}
if (absl::StrContains(map_kv.first, mpms::GetImageTimestampKey())) {
std::vector<std::string> pieces = absl::StrSplit(map_kv.first, '/');
std::string feature_key = "";
std::string possible_tag = kImageTag;
if (pieces[0] != "image") {
feature_key = pieces[0];
possible_tag = absl::StrCat(kImageTag, "_", feature_key);
}
if (cc->Outputs().HasTag(possible_tag)) {
              // If this is triggered, it means that there are no images to
              // match the timestamps. This is clearly an error, but we don't
              // want a segfault.
if (mpms::GetImageEncodedSize(feature_key, *sequence_) <= i) {
return tool::StatusStop();
}
cc->Outputs()
.Tag(possible_tag)
.Add(new std::string(
mpms::GetImageEncodedAt(feature_key, *sequence_, i)),
current_timestamp);
}
}
if (cc->Outputs().HasTag(kForwardFlowImageTag) &&
map_kv.first == mpms::GetForwardFlowTimestampKey()) {
cc->Outputs()
.Tag(kForwardFlowImageTag)
.Add(new std::string(
mpms::GetForwardFlowEncodedAt(*sequence_, i)),
current_timestamp);
}
if (absl::StrContains(map_kv.first, mpms::GetBBoxTimestampKey())) {
std::vector<std::string> pieces = absl::StrSplit(map_kv.first, '/');
std::string feature_key = "";
std::string possible_tag = kBBoxTag;
if (pieces[0] != "region") {
feature_key = pieces[0];
possible_tag = absl::StrCat(kBBoxTag, "_", feature_key);
}
if (cc->Outputs().HasTag(possible_tag)) {
const auto& bboxes = mpms::GetBBoxAt(feature_key, *sequence_, i);
cc->Outputs()
.Tag(possible_tag)
.Add(new std::vector<Location>(bboxes.begin(), bboxes.end()),
current_timestamp);
}
}
if (absl::StrContains(map_kv.first, "feature")) {
std::vector<std::string> pieces = absl::StrSplit(map_kv.first, '/');
RET_CHECK_GT(pieces.size(), 1)
<< "Failed to parse the feature substring before / from key "
<< map_kv.first;
std::string feature_key = pieces[0];
std::string possible_tag = kFloatFeaturePrefixTag + feature_key;
if (cc->Outputs().HasTag(possible_tag)) {
const auto& float_list =
mpms::GetFeatureFloatsAt(feature_key, *sequence_, i);
cc->Outputs()
.Tag(possible_tag)
.Add(new std::vector<float>(float_list.begin(),
float_list.end()),
current_timestamp);
}
}
}
}
}
++current_timestamp_index_;
if (current_timestamp_index_ < timestamps_[last_timestamp_key_].size()) {
return absl::OkStatus();
} else {
if (process_poststream_) {
// Once we've processed the PostStream timestamp we can stop.
return tool::StatusStop();
} else {
// Otherwise, we still need to do one more pass to process it.
process_poststream_ = true;
return absl::OkStatus();
}
}
}
  // Hold a copy of the packet to keep the underlying shared_ptr alive, and
  // access the SequenceExample through a convenient pointer.
const tf::SequenceExample* sequence_;
Packet example_packet_holder_;
// Store a map from the keys for each stream to the timestamps for each
// key. This allows us to identify which packets to output for each stream
// for timestamps within a given time window.
std::map<std::string, std::vector<int64_t>> timestamps_;
// Store the stream with the latest timestamp in the SequenceExample.
std::string last_timestamp_key_;
// Store the index of the current timestamp. Will be less than
// timestamps_[last_timestamp_key_].size().
int current_timestamp_index_;
  // Store the very first timestamp, so we output everything in the first
  // window of Process().
int64_t first_timestamp_seen_;
// List of keypoint names.
std::vector<std::string> keypoint_names_;
// Default keypoint location when missing.
float default_keypoint_location_;
bool process_poststream_;
};
REGISTER_CALCULATOR(UnpackMediaSequenceCalculator);
} // namespace mediapipe