chromium/third_party/mediapipe/src/mediapipe/framework/stream_handler/fixed_size_input_stream_handler.cc

// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/stream_handler/fixed_size_input_stream_handler.h"

#include <algorithm>
#include <list>
#include <memory>
#include <utility>
#include <vector>

#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/calculator_context_manager.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/input_stream_handler.h"
#include "mediapipe/framework/mediapipe_options.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/stream_handler/default_input_stream_handler.h"
#include "mediapipe/framework/stream_handler/fixed_size_input_stream_handler.pb.h"
#include "mediapipe/framework/tool/tag_map.h"

namespace mediapipe {

// Builds the handler on top of DefaultInputStreamHandler and reads the
// queue-size policy from the FixedSizeInputStreamHandlerOptions extension of
// `options`.
FixedSizeInputStreamHandler::FixedSizeInputStreamHandler(
    std::shared_ptr<tool::TagMap> tag_map, CalculatorContextManager* cc_manager,
    const mediapipe::MediaPipeOptions& options, bool calculator_run_in_parallel)
    : DefaultInputStreamHandler(std::move(tag_map), cc_manager, options,
                                calculator_run_in_parallel) {
  // Initial bookkeeping state: no input set has been promised yet, and no
  // truncation point has been recorded.
  pending_ = false;
  kept_timestamp_ = Timestamp::Unset();

  // Copy the truncation policy out of the handler-specific options.
  const auto& handler_options =
      options.GetExtension(mediapipe::FixedSizeInputStreamHandlerOptions::ext);
  trigger_queue_size_ = handler_options.trigger_queue_size();
  target_queue_size_ = handler_options.target_queue_size();
  fixed_min_size_ = handler_options.fixed_min_size();

  // TODO: Either re-enable SetLatePreparation(true) with
  // CalculatorContext::InputTimestamp set correctly, or remove the
  // implementation of SetLatePreparation.
}

void FixedSizeInputStreamHandler::EraseAllSurplus() {
  Timestamp min_timestamp_all_streams = Timestamp::Max();
  for (const auto& stream : input_stream_managers_) {
    // Check whether every InputStreamImpl grew beyond trigger_queue_size.
    if (stream->QueueSize() < trigger_queue_size_) {
      return;
    }
    Timestamp min_timestamp =
        stream->GetMinTimestampAmongNLatest(target_queue_size_);

    // Record the min timestamp among the newest target_queue_size_ packets
    // across all InputStreamImpls.
    min_timestamp_all_streams =
        std::min(min_timestamp_all_streams, min_timestamp);
  }
  for (auto& stream : input_stream_managers_) {
    stream->ErasePacketsEarlierThan(min_timestamp_all_streams);
  }
}

// Returns `bound - 1` when `bound` is a range value (per
// Timestamp::IsRangeValue). Otherwise `bound` is returned unchanged.
Timestamp FixedSizeInputStreamHandler::PreviousAllowedInStream(
    Timestamp bound) {
  if (!bound.IsRangeValue()) {
    return bound;
  }
  return bound - 1;
}

// Returns the smallest per-stream bound across all input streams, or
// Timestamp::Done() when there are no streams.
//
// How each stream's bound is chosen:
// - If GetMinTimestampAmongNLatest(1) yields a timestamp greater than
//   Timestamp::Unset(), the bound is that timestamp's NextAllowedInStream().
// - Otherwise the bound is the stream's MinTimestampOrBound().
Timestamp FixedSizeInputStreamHandler::MinStreamBound() {
  Timestamp lowest = Timestamp::Done();
  for (const auto& input_stream : input_stream_managers_) {
    Timestamp newest = input_stream->GetMinTimestampAmongNLatest(1);
    Timestamp bound = (newest > Timestamp::Unset())
                          ? newest.NextAllowedInStream()
                          : input_stream->MinTimestampOrBound(nullptr);
    lowest = std::min(lowest, bound);
  }
  return lowest;
}

// Returns the earliest timestamp that is safe to process across all streams,
// or Timestamp::Done() when there are no streams.
Timestamp FixedSizeInputStreamHandler::MinTimestampToProcess() {
  Timestamp earliest = Timestamp::Done();
  for (const auto& input_stream : input_stream_managers_) {
    bool is_empty = false;
    Timestamp head = input_stream->MinTimestampOrBound(&is_empty);
    // An empty stream reports its *bound* rather than a packet timestamp.
    // A packet may still arrive exactly at the bound, so only timestamps
    // strictly before it are safe to process.
    Timestamp candidate = is_empty ? PreviousAllowedInStream(head) : head;
    earliest = std::min(earliest, candidate);
  }
  return earliest;
}

void FixedSizeInputStreamHandler::EraseAnySurplus(bool keep_one) {
  // Record the most recent first kept timestamp on any stream.
  for (const auto& stream : input_stream_managers_) {
    int32_t queue_size = (stream->QueueSize() >= trigger_queue_size_)
                             ? target_queue_size_
                             : trigger_queue_size_ - 1;
    if (stream->QueueSize() > queue_size) {
      kept_timestamp_ = std::max(
          kept_timestamp_, stream->GetMinTimestampAmongNLatest(queue_size + 1)
                               .NextAllowedInStream());
    }
  }
  if (keep_one) {
    // In order to preserve one viable timestamp, do not truncate past
    // the timestamp bound of the least current stream.
    kept_timestamp_ =
        std::min(kept_timestamp_, PreviousAllowedInStream(MinStreamBound()));
  }
  for (auto& stream : input_stream_managers_) {
    stream->ErasePacketsEarlierThan(kept_timestamp_);
  }
}

// Selects a truncation policy. With fixed_min_size_ set, all streams are
// trimmed together by EraseAllSurplus(), which ignores keep_one. Otherwise
// EraseAnySurplus(keep_one) trims streams individually.
void FixedSizeInputStreamHandler::EraseSurplusPackets(bool keep_one) {
  if (fixed_min_size_) {
    EraseAllSurplus();
  } else {
    EraseAnySurplus(keep_one);
  }
}

// Reports whether the node can run, trimming surplus packets first.
//
// The body runs under erase_mutex_. AddPackets, MovePackets and FillInputSet
// also take that mutex, so truncation done by those methods is serialized
// with this check.
//
// min_stream_timestamp: must be non-null (DCHECKed). It is passed through to
//   DefaultInputStreamHandler::GetNodeReadiness, which writes the candidate
//   timestamp into it.
// Returns kNotReady immediately while an earlier kReadyForProcess result has
// not yet been consumed by FillInputSet. Otherwise returns the result of
// DefaultInputStreamHandler::GetNodeReadiness after truncation.
NodeReadiness FixedSizeInputStreamHandler::GetNodeReadiness(
    Timestamp* min_stream_timestamp) {
  ABSL_DCHECK(min_stream_timestamp);
  absl::MutexLock lock(&erase_mutex_);
  // kReadyForProcess is returned only once until FillInputSet completes.
  // In late_preparation mode, GetNodeReadiness must return kReadyForProcess
  // exactly once for each input-set produced.  Here, GetNodeReadiness
  // releases just one input-set at a time and then disables input queue
  // truncation until that promised input-set is consumed.
  if (pending_) {
    return NodeReadiness::kNotReady;
  }
  // Trim queues (without the keep_one safeguard) before evaluating readiness.
  EraseSurplusPackets(false);
  NodeReadiness result =
      DefaultInputStreamHandler::GetNodeReadiness(min_stream_timestamp);

  // If a packet has arrived below kept_timestamp_, recalculate.
  // AddPackets/MovePackets enqueue through InputStreamHandler before taking
  // erase_mutex_, so a packet may land below the truncation point while this
  // lock is held. Loop until the reported timestamp is not below
  // kept_timestamp_ or the node is no longer ready.
  while (*min_stream_timestamp < kept_timestamp_ &&
         result == NodeReadiness::kReadyForProcess) {
    EraseSurplusPackets(false);
    result = DefaultInputStreamHandler::GetNodeReadiness(min_stream_timestamp);
  }
  // Record whether an input set has been promised. While pending_ is true,
  // AddPackets/MovePackets skip truncation; FillInputSet clears the flag.
  pending_ = (result == NodeReadiness::kReadyForProcess);
  return result;
}

// Enqueues copies of `packets` on stream `id` via the base class. Surplus
// packets are then trimmed, unless an input set promised by GetNodeReadiness
// is still waiting for FillInputSet.
void FixedSizeInputStreamHandler::AddPackets(CollectionItemId id,
                                             const std::list<Packet>& packets) {
  InputStreamHandler::AddPackets(id, packets);
  absl::MutexLock lock(&erase_mutex_);
  if (pending_) {
    return;  // Truncation is suspended until the promised set is filled.
  }
  EraseSurplusPackets(false);
}

// Moves `packets` onto stream `id` via the base class. Surplus packets are
// then trimmed, unless an input set promised by GetNodeReadiness is still
// waiting for FillInputSet.
void FixedSizeInputStreamHandler::MovePackets(CollectionItemId id,
                                              std::list<Packet>* packets) {
  InputStreamHandler::MovePackets(id, packets);
  absl::MutexLock lock(&erase_mutex_);
  if (pending_) {
    return;  // Truncation is suspended until the promised set is filled.
  }
  EraseSurplusPackets(false);
}

// Fills `input_set` with the packets to process and marks the promised
// input set as consumed.
//
// input_timestamp: ignored. It is replaced with MinTimestampToProcess() after
//   truncation, so that the most recent packets are processed.
// input_set: must be non-null (CHECKed).
void FixedSizeInputStreamHandler::FillInputSet(Timestamp input_timestamp,
                                               InputStreamShardSet* input_set) {
  ABSL_CHECK(input_set);
  absl::MutexLock lock(&erase_mutex_);
  const bool promised_by_readiness = pending_;
  if (!promised_by_readiness) {
    // Tolerated, but it means the readiness handshake was skipped.
    ABSL_LOG(ERROR) << "FillInputSet called without GetNodeReadiness.";
  }
  // Trim with keep_one (honored by the per-stream policy), then recompute
  // the timestamp to process from what remains.
  EraseSurplusPackets(true);
  input_timestamp = MinTimestampToProcess();
  DefaultInputStreamHandler::FillInputSet(input_timestamp, input_set);
  // The promised input set has been delivered; truncation may resume.
  pending_ = false;
}

// Registers FixedSizeInputStreamHandler through MediaPipe's
// REGISTER_INPUT_STREAM_HANDLER macro. Presumably this makes the handler
// selectable by name in graph configs; confirm against the macro definition.
REGISTER_INPUT_STREAM_HANDLER(FixedSizeInputStreamHandler);

}  // namespace mediapipe