// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <istream>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "absl/container/node_hash_map.h"
#include "mediapipe/calculators/util/top_k_scores_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/util/resource_util.h"
#if defined(MEDIAPIPE_MOBILE)
#include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/helpers.h"
#else
#include "mediapipe/framework/port/file_helpers.h"
#endif
namespace mediapipe {
constexpr char kTopKClassificationTag[] = "TOP_K_CLASSIFICATION";
constexpr char kSummaryTag[] = "SUMMARY";
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
constexpr char kTopKLabelsTag[] = "TOP_K_LABELS";
constexpr char kTopKScoresTag[] = "TOP_K_SCORES";
constexpr char kTopKIndexesTag[] = "TOP_K_INDEXES";
constexpr char kScoresTag[] = "SCORES";
// A calculator that takes a vector of scores and returns the indexes, scores,
// labels of the top k elements, classification protos, and summary string (in
// csv format).
//
// Usage example:
// node {
// calculator: "TopKScoresCalculator"
// input_stream: "SCORES:score_vector"
// output_stream: "TOP_K_INDEXES:top_k_indexes"
// output_stream: "TOP_K_SCORES:top_k_scores"
// output_stream: "TOP_K_LABELS:top_k_labels"
// output_stream: "TOP_K_CLASSIFICATIONS:top_k_classes"
// output_stream: "SUMMARY:summary"
// options: {
// [mediapipe.TopKScoresCalculatorOptions.ext] {
// top_k: 5
// threshold: 0.1
// label_map_path: "/path/to/label/map"
// }
// }
// }
class TopKScoresCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
absl::Status LoadLabelmap(std::string label_map_path);
int top_k_ = -1;
float threshold_ = 0.0;
absl::node_hash_map<int, std::string> label_map_;
bool label_map_loaded_ = false;
};
REGISTER_CALCULATOR(TopKScoresCalculator);
absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kScoresTag));
cc->Inputs().Tag(kScoresTag).Set<std::vector<float>>();
if (cc->Outputs().HasTag(kTopKIndexesTag)) {
cc->Outputs().Tag(kTopKIndexesTag).Set<std::vector<int>>();
}
if (cc->Outputs().HasTag(kTopKScoresTag)) {
cc->Outputs().Tag(kTopKScoresTag).Set<std::vector<float>>();
}
if (cc->Outputs().HasTag(kTopKLabelsTag)) {
cc->Outputs().Tag(kTopKLabelsTag).Set<std::vector<std::string>>();
}
if (cc->Outputs().HasTag(kClassificationsTag)) {
cc->Outputs().Tag(kClassificationsTag).Set<ClassificationList>();
}
if (cc->Outputs().HasTag(kSummaryTag)) {
cc->Outputs().Tag(kSummaryTag).Set<std::string>();
}
return absl::OkStatus();
}
absl::Status TopKScoresCalculator::Open(CalculatorContext* cc) {
const auto& options = cc->Options<::mediapipe::TopKScoresCalculatorOptions>();
RET_CHECK(options.has_top_k() || options.has_threshold())
<< "Must specify at least one of the top_k and threshold fields in "
"TopKScoresCalculatorOptions.";
if (options.has_top_k()) {
RET_CHECK(options.top_k() > 0) << "top_k must be greater than zero.";
top_k_ = options.top_k();
}
if (options.has_threshold()) {
threshold_ = options.threshold();
}
if (options.has_label_map_path()) {
MP_RETURN_IF_ERROR(LoadLabelmap(options.label_map_path()));
}
if (cc->Outputs().HasTag(kTopKLabelsTag)) {
RET_CHECK(!label_map_.empty());
}
return absl::OkStatus();
}
absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
const std::vector<float>& input_vector =
cc->Inputs().Tag(kScoresTag).Get<std::vector<float>>();
std::vector<int> top_k_indexes;
std::vector<float> top_k_scores;
std::vector<std::string> top_k_labels;
if (top_k_ > 0) {
top_k_indexes.reserve(top_k_);
top_k_scores.reserve(top_k_);
top_k_labels.reserve(top_k_);
}
std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>,
std::greater<std::pair<float, int>>>
pq;
for (int i = 0; i < input_vector.size(); ++i) {
if (input_vector[i] < threshold_) {
continue;
}
if (top_k_ > 0) {
if (pq.size() < top_k_) {
pq.push(std::pair<float, int>(input_vector[i], i));
} else if (pq.top().first < input_vector[i]) {
pq.pop();
pq.push(std::pair<float, int>(input_vector[i], i));
}
} else {
pq.push(std::pair<float, int>(input_vector[i], i));
}
}
while (!pq.empty()) {
top_k_indexes.push_back(pq.top().second);
top_k_scores.push_back(pq.top().first);
pq.pop();
}
reverse(top_k_indexes.begin(), top_k_indexes.end());
reverse(top_k_scores.begin(), top_k_scores.end());
if (label_map_loaded_) {
for (int index : top_k_indexes) {
top_k_labels.push_back(label_map_[index]);
}
}
if (cc->Outputs().HasTag(kTopKIndexesTag)) {
cc->Outputs()
.Tag(kTopKIndexesTag)
.AddPacket(MakePacket<std::vector<int>>(top_k_indexes)
.At(cc->InputTimestamp()));
}
if (cc->Outputs().HasTag(kTopKScoresTag)) {
cc->Outputs()
.Tag(kTopKScoresTag)
.AddPacket(MakePacket<std::vector<float>>(top_k_scores)
.At(cc->InputTimestamp()));
}
if (cc->Outputs().HasTag(kTopKLabelsTag)) {
cc->Outputs()
.Tag(kTopKLabelsTag)
.AddPacket(MakePacket<std::vector<std::string>>(top_k_labels)
.At(cc->InputTimestamp()));
}
if (cc->Outputs().HasTag(kSummaryTag)) {
std::vector<std::string> results;
for (int index = 0; index < top_k_indexes.size(); ++index) {
if (label_map_loaded_) {
results.push_back(
absl::StrCat(top_k_labels[index], ":", top_k_scores[index]));
} else {
results.push_back(
absl::StrCat(top_k_indexes[index], ":", top_k_scores[index]));
}
}
cc->Outputs()
.Tag(kSummaryTag)
.AddPacket(MakePacket<std::string>(absl::StrJoin(results, ","))
.At(cc->InputTimestamp()));
}
if (cc->Outputs().HasTag(kTopKClassificationTag)) {
auto classification_list = absl::make_unique<ClassificationList>();
for (int index = 0; index < top_k_indexes.size(); ++index) {
Classification* classification =
classification_list->add_classification();
classification->set_index(top_k_indexes[index]);
classification->set_score(top_k_scores[index]);
if (label_map_loaded_) {
classification->set_label(top_k_labels[index]);
}
}
}
return absl::OkStatus();
}
absl::Status TopKScoresCalculator::LoadLabelmap(std::string label_map_path) {
std::string string_path;
MP_ASSIGN_OR_RETURN(string_path, PathToResourceAsFile(label_map_path));
std::string label_map_string;
MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
std::istringstream stream(label_map_string);
std::string line;
int i = 0;
while (std::getline(stream, line)) {
label_map_[i++] = line;
}
label_map_loaded_ = true;
return absl::OkStatus();
}
} // namespace mediapipe