// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <memory>
#include <set>
#include <string>
#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
using mediapipe::PacketTypeSet;
using mediapipe::Timestamp;
namespace {
constexpr char kTagAtPreStream[] = "AT_PRESTREAM";
constexpr char kTagAtPostStream[] = "AT_POSTSTREAM";
constexpr char kTagAtZero[] = "AT_ZERO";
constexpr char kTagAtFirstTick[] = "AT_FIRST_TICK";
constexpr char kTagAtTick[] = "AT_TICK";
constexpr char kTagTick[] = "TICK";
constexpr char kTagAtTimestamp[] = "AT_TIMESTAMP";
constexpr char kTagSideInputTimestamp[] = "TIMESTAMP";
static std::map<std::string, Timestamp>* kTimestampMap = []() {
auto* res = new std::map<std::string, Timestamp>();
res->emplace(kTagAtPreStream, Timestamp::PreStream());
res->emplace(kTagAtPostStream, Timestamp::PostStream());
res->emplace(kTagAtZero, Timestamp(0));
res->emplace(kTagAtTick, Timestamp::Unset());
res->emplace(kTagAtFirstTick, Timestamp::Unset());
res->emplace(kTagAtTimestamp, Timestamp::Unset());
return res;
}();
template <typename CC>
std::string GetOutputTag(const CC& cc) {
// Single output tag only is required by contract.
return *cc.Outputs().GetTags().begin();
}
} // namespace
// Outputs side packet(s) in corresponding output stream(s) with a particular
// timestamp, depending on the tag used to define the output stream(s). Only one
// tag can be used. Output stream i carries the input side packet at index i.
//
// Valid tags and the corresponding output timestamps are:
//   AT_PRESTREAM  - Timestamp::PreStream()
//   AT_POSTSTREAM - Timestamp::PostStream()
//   AT_ZERO       - Timestamp(0)
//   AT_TICK       - the timestamp of each packet received in the TICK input
//   AT_FIRST_TICK - the timestamp of the first packet received in the TICK
//                   input; the output stream(s) are closed afterwards
//   AT_TIMESTAMP  - the int64_t value of the TIMESTAMP input side packet
//
// Examples:
// node {
// calculator: "SidePacketToStreamCalculator"
// input_side_packet: "side_packet"
// output_stream: "AT_PRESTREAM:packet"
// }
//
// node {
// calculator: "SidePacketToStreamCalculator"
// input_stream: "TICK:tick"
// input_side_packet: "side_packet"
// output_stream: "AT_TICK:packet"
// }
//
// node {
// calculator: "SidePacketToStreamCalculator"
// input_side_packet: "TIMESTAMP:timestamp"
// input_side_packet: "side_packet"
// output_stream: "AT_TIMESTAMP:packet"
// }
class SidePacketToStreamCalculator : public CalculatorBase {
 public:
  // Validates the node configuration and declares packet types.
  static absl::Status GetContract(CalculatorContract* cc);

  // Caches the output tag and configures tick-driven mode.
  absl::Status Open(CalculatorContext* cc) override;
  // Emits side packets on each TICK, or stops immediately when not ticking.
  absl::Status Process(CalculatorContext* cc) override;
  // Emits side packets for the fixed / side-packet timestamp tags.
  absl::Status Close(CalculatorContext* cc) override;

 private:
  // Set in Open() when a TICK input stream is connected.
  bool is_tick_processing_{false};
  // Set in Open() for AT_FIRST_TICK: outputs are closed after the first tick.
  bool close_on_first_tick_{false};
  // The single output tag, cached in Open().
  std::string output_tag_;
};
REGISTER_CALCULATOR(SidePacketToStreamCalculator);
// Validates the node configuration and declares the packet types.
//
// Enforced requirements:
//  * Exactly one output tag, and it must be a key of kTimestampMap.
//  * A TICK input stream is present iff the output tag is AT_TICK or
//    AT_FIRST_TICK.
//  * A TIMESTAMP input side packet (int64_t) is present iff the output tag is
//    AT_TIMESTAMP.
//  * One untagged input side packet per output stream; output stream i gets
//    the same type as untagged input side packet i.
absl::Status SidePacketToStreamCalculator::GetContract(CalculatorContract* cc) {
  const auto& tags = cc->Outputs().GetTags();
  RET_CHECK(tags.size() == 1 && kTimestampMap->count(*tags.begin()) == 1)
      << "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK, "
         "AT_FIRST_TICK and AT_TIMESTAMP tags is allowed and required to "
         "specify output stream(s).";

  const bool has_tick_output =
      cc->Outputs().HasTag(kTagAtTick) || cc->Outputs().HasTag(kTagAtFirstTick);
  const bool has_tick_input = cc->Inputs().HasTag(kTagTick);
  RET_CHECK(has_tick_output == has_tick_input)
      << "Either both TICK input and tick (AT_TICK/AT_FIRST_TICK) output "
         "should be used or none of them.";

  const bool has_timestamp_output = cc->Outputs().HasTag(kTagAtTimestamp);
  const bool has_timestamp_side_input =
      cc->InputSidePackets().HasTag(kTagSideInputTimestamp);
  RET_CHECK(has_timestamp_output == has_timestamp_side_input)
      << "Either both TIMESTAMP and AT_TIMESTAMP should be used or none of "
         "them.";

  const std::string output_tag = GetOutputTag(*cc);
  const int num_entries = cc->Outputs().NumEntries(output_tag);
  if (has_timestamp_output) {
    // The previous message hard-coded "2", which is wrong for more than one
    // output stream; the expected count is num_entries + 1.
    RET_CHECK_EQ(num_entries + 1, cc->InputSidePackets().NumEntries())
        << "For AT_TIMESTAMP tag, one input side packet per output stream "
           "plus the TIMESTAMP side packet are required.";
    cc->InputSidePackets().Tag(kTagSideInputTimestamp).Set<int64_t>();
  } else {
    RET_CHECK_EQ(num_entries, cc->InputSidePackets().NumEntries())
        << "Same number of input side packets and output streams is required.";
  }
  // The forwarded side packets are addressed by index (untagged entries)
  // below and in Process()/Close(). Verify they really are untagged so a stray
  // tagged side packet yields an error status rather than an invalid lookup.
  RET_CHECK_EQ(num_entries, cc->InputSidePackets().NumEntries(""))
      << "Input side packets forwarded to output streams must be untagged.";

  for (int i = 0; i < num_entries; ++i) {
    cc->InputSidePackets().Index(i).SetAny();
    cc->Outputs()
        .Get(output_tag, i)
        .SetSameAs(cc->InputSidePackets().Index(i).GetSameAs());
  }
  if (has_tick_input) {
    cc->Inputs().Tag(kTagTick).SetAny();
  }
  return absl::OkStatus();
}
// Caches the output tag and decides between tick-driven and source operation.
absl::Status SidePacketToStreamCalculator::Open(CalculatorContext* cc) {
  output_tag_ = GetOutputTag(*cc);
  close_on_first_tick_ = (output_tag_ == kTagAtFirstTick);

  is_tick_processing_ = cc->Inputs().HasTag(kTagTick);
  if (is_tick_processing_) {
    // A zero offset makes output timestamp bounds follow TICK timestamp bound
    // updates.
    cc->SetOffset(TimestampDiff(0));
  }
  return absl::OkStatus();
}
// In tick mode, re-emits every side packet at the incoming TICK timestamp.
// Otherwise stops processing right away; packets are then emitted in Close().
absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) {
  if (!is_tick_processing_) {
    return mediapipe::tool::StatusStop();
  }
  // Outputs are closed e.g. after AT_FIRST_TICK has fired; ignore later ticks.
  if (cc->Outputs().Get(output_tag_, 0).IsClosed()) {
    return absl::OkStatus();
  }
  // TICK is the only input stream, so it always carries a packet here.
  const Timestamp tick_timestamp =
      cc->Inputs().Tag(kTagTick).Value().Timestamp();
  const int num_outputs = cc->Outputs().NumEntries(output_tag_);
  for (int idx = 0; idx < num_outputs; ++idx) {
    auto& out = cc->Outputs().Get(output_tag_, idx);
    out.AddPacket(cc->InputSidePackets().Index(idx).At(tick_timestamp));
    if (close_on_first_tick_) {
      out.Close();
    }
  }
  return absl::OkStatus();
}
// Emits the side packets for non-tick tags: at the TIMESTAMP side packet value
// for AT_TIMESTAMP, or at the fixed timestamp from kTimestampMap otherwise.
// Tick-driven tags emit nothing here (they were handled in Process()).
absl::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) {
  const bool at_side_timestamp = cc->Outputs().HasTag(kTagAtTimestamp);
  if (!at_side_timestamp && (cc->Outputs().HasTag(kTagAtTick) ||
                             cc->Outputs().HasTag(kTagAtFirstTick))) {
    return absl::OkStatus();
  }
  const Timestamp emit_timestamp =
      at_side_timestamp
          ? Timestamp(cc->InputSidePackets()
                          .Tag(kTagSideInputTimestamp)
                          .Get<int64_t>())
          : kTimestampMap->at(output_tag_);
  const int num_outputs = cc->Outputs().NumEntries(output_tag_);
  for (int idx = 0; idx < num_outputs; ++idx) {
    cc->Outputs()
        .Get(output_tag_, idx)
        .AddPacket(cc->InputSidePackets().Index(idx).At(emit_timestamp));
  }
  return absl::OkStatus();
}
} // namespace mediapipe