chromium/third_party/mediapipe/src/mediapipe/calculators/util/association_calculator.h

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_CALCULATORS_UTIL_ASSOCIATION_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_UTIL_ASSOCIATION_CALCULATOR_H_

#include <memory>
#include <vector>

#include "absl/log/absl_check.h"
#include "absl/memory/memory.h"
#include "mediapipe/calculators/util/association_calculator.pb.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/rectangle.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/rectangle_util.h"

namespace mediapipe {

// AssocationCalculator<T> accepts multiple inputs of vectors of type T that can
// be converted to Rectangle_f. The output is a vector of type T that contains
// elements from the input vectors that don't overlap with each other. When
// two elements overlap, the element that comes in from a later input stream
// is kept in the output. This association operation is useful for multiple
// instance inference pipelines in MediaPipe.
// If an input stream is tagged with "PREV" tag, IDs of overlapping elements
// from "PREV" input stream are propagated to the output. Elements in the "PREV"
// input stream that don't overlap with other elements are not added to the
// output. This stream is designed to take detections from previous timestamp,
// e.g. output of PreviousLoopbackCalculator to provide temporal association.
// See AssociationDetectionCalculator and AssociationNormRectCalculator for
// example uses.
template <typename T>
class AssociationCalculator : public CalculatorBase {
 public:
  static absl::Status GetContract(CalculatorContract* cc) {
    // Atmost one input stream can be tagged with "PREV".
    RET_CHECK_LE(cc->Inputs().NumEntries("PREV"), 1);

    if (cc->Inputs().HasTag("PREV")) {
      RET_CHECK_GE(cc->Inputs().NumEntries(), 2);
    }

    for (CollectionItemId id = cc->Inputs().BeginId();
         id < cc->Inputs().EndId(); ++id) {
      cc->Inputs().Get(id).Set<std::vector<T>>();
    }

    cc->Outputs().Index(0).Set<std::vector<T>>();

    return absl::OkStatus();
  }

  absl::Status Open(CalculatorContext* cc) override {
    cc->SetOffset(TimestampDiff(0));

    has_prev_input_stream_ = cc->Inputs().HasTag("PREV");
    if (has_prev_input_stream_) {
      prev_input_stream_id_ = cc->Inputs().GetId("PREV", 0);
    }
    options_ = cc->Options<::mediapipe::AssociationCalculatorOptions>();
    ABSL_CHECK_GE(options_.min_similarity_threshold(), 0);

    return absl::OkStatus();
  }

  absl::Status Process(CalculatorContext* cc) override {
    auto get_non_overlapping_elements = GetNonOverlappingElements(cc);
    if (!get_non_overlapping_elements.ok()) {
      return get_non_overlapping_elements.status();
    }
    std::list<T> result = get_non_overlapping_elements.value();

    if (has_prev_input_stream_ &&
        !cc->Inputs().Get(prev_input_stream_id_).IsEmpty()) {
      // Processed all regular input streams. Now compare the result list
      // elements with those in the PREV input stream, and propagate IDs from
      // PREV input stream as appropriate.
      const std::vector<T>& prev_input_vec =
          cc->Inputs()
              .Get(prev_input_stream_id_)
              .template Get<std::vector<T>>();

      MP_RETURN_IF_ERROR(
          PropagateIdsFromPreviousToCurrent(prev_input_vec, &result));
    }

    auto output = absl::make_unique<std::vector<T>>();
    for (auto it = result.begin(); it != result.end(); ++it) {
      output->push_back(*it);
    }
    cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());

    return absl::OkStatus();
  }

 protected:
  ::mediapipe::AssociationCalculatorOptions options_;

  bool has_prev_input_stream_;
  CollectionItemId prev_input_stream_id_;

  virtual absl::StatusOr<Rectangle_f> GetRectangle(const T& input) {
    return absl::OkStatus();
  }

  virtual std::pair<bool, int> GetId(const T& input) { return {false, -1}; }

  virtual void SetId(T* input, int id) {}

 private:
  // Get a list of non-overlapping elements from all input streams, with
  // increasing order of priority based on input stream index.
  absl::StatusOr<std::list<T>> GetNonOverlappingElements(
      CalculatorContext* cc) {
    std::list<T> result;

    // Initialize result with the first non-empty input vector.
    CollectionItemId non_empty_id = cc->Inputs().BeginId();
    for (CollectionItemId id = cc->Inputs().BeginId();
         id < cc->Inputs().EndId(); ++id) {
      if (id == prev_input_stream_id_ || cc->Inputs().Get(id).IsEmpty()) {
        continue;
      }
      const std::vector<T>& input_vec =
          cc->Inputs().Get(id).Get<std::vector<T>>();
      if (!input_vec.empty()) {
        non_empty_id = id;
        result.push_back(input_vec[0]);
        for (int j = 1; j < input_vec.size(); ++j) {
          MP_RETURN_IF_ERROR(AddElementToList(input_vec[j], &result));
        }
        break;
      }
    }

    // Compare remaining input vectors with the non-empty result vector,
    // remove lower-priority overlapping elements from the result vector and
    // had corresponding higher-priority elements as necessary.
    for (CollectionItemId id = non_empty_id + 1; id < cc->Inputs().EndId();
         ++id) {
      if (id == prev_input_stream_id_ || cc->Inputs().Get(id).IsEmpty()) {
        continue;
      }
      const std::vector<T>& input_vec =
          cc->Inputs().Get(id).Get<std::vector<T>>();

      for (int vi = 0; vi < input_vec.size(); ++vi) {
        MP_RETURN_IF_ERROR(AddElementToList(input_vec[vi], &result));
      }
    }

    return result;
  }

  absl::Status AddElementToList(T element, std::list<T>* current) {
    // Compare this element with elements of the input collection. If this
    // element has high overlap with elements of the collection, remove
    // those elements from the collection and add this element.
    MP_ASSIGN_OR_RETURN(auto cur_rect, GetRectangle(element));

    bool change_id = false;
    int new_elem_id = -1;

    for (auto uit = current->begin(); uit != current->end();) {
      MP_ASSIGN_OR_RETURN(auto prev_rect, GetRectangle(*uit));
      if (CalculateIou(cur_rect, prev_rect) >
          options_.min_similarity_threshold()) {
        std::pair<bool, int> prev_id = GetId(*uit);
        // If prev_id.first is false when some element doesn't have an ID,
        // change_id and new_elem_id will not be updated.
        if (prev_id.first) {
          change_id = prev_id.first;
          new_elem_id = prev_id.second;
        }
        uit = current->erase(uit);
      } else {
        ++uit;
      }
    }

    if (change_id) {
      SetId(&element, new_elem_id);
    }
    current->push_back(element);

    return absl::OkStatus();
  }

  // Compare elements of the current list with elements in from the collection
  // of elements from the previous input stream, and propagate IDs from the
  // previous input stream as appropriate.
  absl::Status PropagateIdsFromPreviousToCurrent(
      const std::vector<T>& prev_input_vec, std::list<T>* current) {
    for (auto vit = current->begin(); vit != current->end(); ++vit) {
      auto get_cur_rectangle = GetRectangle(*vit);
      if (!get_cur_rectangle.ok()) {
        return get_cur_rectangle.status();
      }
      const Rectangle_f& cur_rect = get_cur_rectangle.value();

      bool change_id = false;
      int id_for_vi = -1;

      for (int ui = 0; ui < prev_input_vec.size(); ++ui) {
        auto get_prev_rectangle = GetRectangle(prev_input_vec[ui]);
        if (!get_prev_rectangle.ok()) {
          return get_prev_rectangle.status();
        }
        const Rectangle_f& prev_rect = get_prev_rectangle.value();

        if (CalculateIou(cur_rect, prev_rect) >
            options_.min_similarity_threshold()) {
          std::pair<bool, int> prev_id = GetId(prev_input_vec[ui]);
          // If prev_id.first is false when some element doesn't have an ID,
          // change_id and id_for_vi will not be updated.
          if (prev_id.first) {
            change_id = prev_id.first;
            id_for_vi = prev_id.second;
          }
        }
      }

      if (change_id) {
        T element = *vit;
        SetId(&element, id_for_vi);
        *vit = element;
      }
    }
    return absl::OkStatus();
  }
};

}  // namespace mediapipe

#endif  // MEDIAPIPE_CALCULATORS_UTIL_ASSOCIATION_CALCULATOR_H_