chromium/third_party/mediapipe/src/mediapipe/framework/profiler/reporter/reporter.cc

#include "mediapipe/framework/profiler/reporter/reporter.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <ostream>
#include <string>
#include <string_view>
#include <vector>

#include "absl/container/btree_map.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "fstream"
#include "map"
#include "mediapipe/framework/calculator_profile.pb.h"
#include "mediapipe/framework/port/advanced_proto_inc.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/re2.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"

namespace mediapipe {
namespace reporter {

const LazyRE2 kValidColumnRegex = {"^[a-zA-Z0-9_?*]+$"};
const LazyRE2 kReplace1WildcharRegex = {"\\?"};
const LazyRE2 kReplace0toNWildcharRegex = {"\\*"};

std::string ToStringF(double d) { return absl::StrFormat("%1.2f", d); }
std::string ToString(double d) { return absl::StrFormat("%1.0f", d); }

absl::btree_map<std::string,
                std::function<const std::string(const CalculatorData&)>>
    kColumns = {
        {"calculator",
         [](const CalculatorData& d) -> const std::string { return d.name; }},
        {"counter",
         [](const CalculatorData& d) -> const std::string {
           return ToString(d.counter);
         }},
        {"completed",
         [](const CalculatorData& d) -> const std::string {
           return ToString(d.completed);
         }},
        {"dropped",
         [](const CalculatorData& d) -> const std::string {
           return ToString(d.dropped);
         }},
        {"fps",
         [](const CalculatorData& d) -> const std::string {
           return ToStringF(d.fps);
         }},
        {"frequency",
         [](const CalculatorData& d) -> const std::string {
           return ToStringF(d.frequency);
         }},
        {"processing_rate",
         [](const CalculatorData& d) -> const std::string {
           return ToStringF(d.processing_rate);
         }},
        {"thread_count",
         [](const CalculatorData& d) -> const std::string {
           return ToStringF(d.thread_count);
         }},
        {"time_mean",
         [](const CalculatorData& d) -> const std::string {
           return ToStringF(d.time_stat.mean());
         }},
        {"time_stddev",
         [](const CalculatorData& d) -> const std::string {
           return ToStringF(d.time_stat.stddev());
         }},
        {"time_total",
         [](const CalculatorData& d) -> const std::string {
           return ToString(d.time_stat.total());
         }},
        {"time_percent",
         [](const CalculatorData& d) -> const std::string {
           return ToStringF(d.time_percent);
         }},
        {"input_latency_mean",
         [](const CalculatorData& d) -> const std::string {
           return ToStringF(d.input_latency_stat.mean());
         }},
        {"input_latency_stddev",
         [](const CalculatorData& d) -> const std::string {
           return ToStringF(d.input_latency_stat.stddev());
         }},
        {"input_latency_total",
         [](const CalculatorData& d) -> const std::string {
           return ToString(d.input_latency_stat.total());
         }}};

// Holds calculator traces that have an output trace with a provided stream ID
// and Packet timestamp.
typedef std::map<std::pair<int64_t, int32_t>,
                 const mediapipe::GraphTrace::CalculatorTrace*>
    PacketKeyToCalcTrace;

// Use this to locate the calculator trace with the start time for a given
// node its thread_id and its packet timestamp.
typedef std::map<std::pair<int64_t, std::pair<int32_t, int32_t>>,
                 const mediapipe::GraphTrace::CalculatorTrace*>
    TimestampNodeIdToCalcTrace;

// Maps node IDs to names.
typedef std::map<int32_t, std::string> NameLookup;

Reporter::Reporter() { MEDIAPIPE_CHECK_OK(set_columns({"*"})); }

int64_t RecursePacketStartTime(
    const PacketKeyToCalcTrace& output_trace_lookup,
    const mediapipe::GraphProfile& profile,
    const mediapipe::GraphTrace::CalculatorTrace& trace,
    std::vector<int32_t>* visited_calculators) {
  int64_t child_start_time =
      trace.has_start_time() ? trace.start_time() : trace.finish_time();
  const int32_t node_id = trace.node_id();
  if (std::find(visited_calculators->begin(), visited_calculators->end(),
                node_id) != visited_calculators->end()) {
    return child_start_time;
  }
  visited_calculators->push_back(node_id);

  for (size_t index = 0; index < trace.input_trace_size(); ++index) {
    const auto& stream_trace = trace.input_trace(index);

    // Find the output corresponding to this input.
    const auto it = output_trace_lookup.find(std::make_pair(
        stream_trace.packet_timestamp(), stream_trace.stream_id()));
    if (it != output_trace_lookup.end()) {
      child_start_time =
          std::min(child_start_time,
                   RecursePacketStartTime(output_trace_lookup, profile,
                                          *it->second, visited_calculators));
    }
  }
  return child_start_time;
}

int64_t CalculateInputLatency(
    const PacketKeyToCalcTrace& output_trace_lookup,
    const mediapipe::GraphProfile& profile,
    const mediapipe::GraphTrace::CalculatorTrace& trace) {
  // Track visited calculators to detect loops.
  std::vector<int> visited_calculators;

  // If a calculator has no start time, then there is no latency to measure.
  const auto result =
      !trace.has_start_time()
          ? 0L
          : trace.start_time() - RecursePacketStartTime(output_trace_lookup,
                                                        profile, trace,
                                                        &visited_calculators);

  return result;
}

void CacheNodeNameLookup(const mediapipe::GraphProfile& profile,
                         NameLookup* result) {
  for (const auto& graph_trace : profile.graph_trace()) {
    int key = 0;
    for (const auto& calc_name : graph_trace.calculator_name()) {
      (*result)[key] = calc_name;
      ++key;
    }
  }
}

void CacheOutputTraceLookup(const mediapipe::GraphProfile& profile,
                            PacketKeyToCalcTrace* output_trace_lookup,
                            TimestampNodeIdToCalcTrace* start_time_lookup) {
  for (const auto& graph_trace : profile.graph_trace()) {
    for (const auto& calc_trace : graph_trace.calculator_trace()) {
      if (calc_trace.event_type() != mediapipe::GraphTrace_EventType_PROCESS) {
        continue;
      }
      if (calc_trace.has_start_time() && !calc_trace.has_finish_time()) {
        (*start_time_lookup)[std::make_pair(
            calc_trace.input_timestamp(),
            std::make_pair(calc_trace.node_id(), calc_trace.thread_id()))] =
            &calc_trace;
      }
      for (const auto& stream_trace : calc_trace.output_trace()) {
        (*output_trace_lookup)[std::make_pair(stream_trace.packet_timestamp(),
                                              stream_trace.stream_id())] =
            &calc_trace;
      }
    }
  }
}

void CompleteCalculatorData(
    const GraphData& graph_data,
    std::map<std::string, CalculatorData>* calculator_data) {
  for (auto& calc_entry : *calculator_data) {
    auto& calc_data = calc_entry.second;
    const auto& time_stat = calc_data.time_stat;
    const auto& latency_stat = calc_data.input_latency_stat;
    const auto time_to_process = time_stat.mean() + latency_stat.mean();
    calc_data.fps = time_to_process == 0 ? 0 : 1.0E+6 / time_to_process;

    const auto duration = graph_data.max_time - graph_data.min_time;
    calc_data.frequency = calc_data.completed / (duration / 1.0E+6);
    calc_data.time_percent = 100 * time_stat.total() / graph_data.total_time;
    calc_data.dropped = calc_data.counter - calc_data.completed;

    calc_data.processing_rate = calc_data.time_stat.mean() == 0
                                    ? 0
                                    : 1.0 / calc_data.time_stat.mean() * 1.0E+6;
    calc_data.thread_count = calc_data.threads.size();
  }
}

void Reporter::Accumulate(const mediapipe::GraphProfile& profile) {
  // Cache nodeID to its string name.
  NameLookup name_lookup;
  CacheNodeNameLookup(profile, &name_lookup);

  // Cache some lookups so that we can quickly find the matching output stream
  // for a given input stream, and so that we can find the start time of a
  // given timestamp of a node.
  PacketKeyToCalcTrace output_trace_lookup;
  TimestampNodeIdToCalcTrace start_event_lookup;
  CacheOutputTraceLookup(profile, &output_trace_lookup, &start_event_lookup);

  // Hold the domain of all times found in the trace file.
  auto& min_time = graph_data_.min_time;
  auto& max_time = graph_data_.max_time;
  auto& total_time = graph_data_.total_time;

  // The start and finish time of PROCESS events might be split between
  // events. If a start event has been seen, we'll record it so that we can
  // match it up later.
  std::map<int32_t, absl::optional<int64_t>> start_times;

  for (const auto& graph_trace : profile.graph_trace()) {
    for (const auto& calc_trace : graph_trace.calculator_trace()) {
      if (calc_trace.event_type() != mediapipe::GraphTrace_EventType_PROCESS) {
        continue;
      }

      const auto& node_name = name_lookup[calc_trace.node_id()];
      auto& calc_data = calculator_data_[node_name];

      calc_data.name = node_name;
      calc_data.threads.insert(calc_trace.thread_id());

      // If there is a start time, update the domain of the trace time, and
      // mark that we've seen a start time for this calculator.
      if (calc_trace.has_start_time()) {
        min_time =
            std::min(min_time, static_cast<int64_t>(calc_trace.start_time() +
                                                    graph_trace.base_time()));
        ++calc_data.counter;
      }

      // If there is a finish time, update the domain and mark that an event
      // has been completed.
      if (calc_trace.has_finish_time()) {
        const auto finish_time =
            calc_trace.finish_time() + graph_trace.base_time();
        max_time = std::max(max_time, static_cast<int64_t>(finish_time));

        absl::optional<int64_t> start_time;
        if (calc_trace.has_start_time()) {
          start_time.emplace(calc_trace.start_time());
        } else {
          const auto start_event_it = start_event_lookup.find(std::make_pair(
              calc_trace.input_timestamp(),
              std::make_pair(calc_trace.node_id(), calc_trace.thread_id())));
          if (start_event_it != start_event_lookup.end()) {
            start_time.emplace(start_event_it->second->start_time());
          }
        }

        // Edge case -- If a finish time came in without a start time, then
        // we know that an event started before the trace became available.
        // But since we don't know when that is, we can't record its duration
        // and won't count it.
        if (start_time) {
          ++calc_data.completed;
          // Add up the duration of the events that led up to this start
          // event.
          const auto input_latency =
              CalculateInputLatency(output_trace_lookup, profile, calc_trace);
          calc_data.input_latency_stat.Push(input_latency);
          const auto duration =
              finish_time - (start_time.value() + graph_trace.base_time());
          calc_data.time_stat.Push(duration);
          total_time += duration;
        }
      }
    }
  }
}

absl::Status Reporter::set_columns(const std::vector<std::string>& columns) {
  bool error = false;
  std::stringstream warnings;
  std::vector<std::string> new_columns({"calculator"});

  // Iterate through the desired columns and build a regex.
  for (const auto& column_matcher : columns) {
    if (!RE2::PartialMatch(column_matcher, *kValidColumnRegex)) {
      warnings << "Column '" << column_matcher << "' is invalid." << std::endl;
      error = true;
      continue;
    }
    std::string colString = column_matcher;
    RE2::GlobalReplace(&colString, *kReplace0toNWildcharRegex, ".*");
    RE2::GlobalReplace(&colString, *kReplace1WildcharRegex, ".");

    // Iterator through our available columns and add them to our collection
    // of new columns if they do not already exist.
    bool matched = false;
    for (const auto& column : kColumns) {
      if (RE2::FullMatch(column.first, colString)) {
        matched = true;
        if (std::find(new_columns.begin(), new_columns.end(), column.first) ==
            new_columns.end()) {
          new_columns.push_back(column.first);
        }
      }
    }
    if (!matched) {
      warnings << "Column '" << column_matcher << "' did not match any columns."
               << std::endl;
      error = true;
    }
  }
  // As long as there is still one column, honor the request, even if error.
  if (!new_columns.empty()) {
    columns_.swap(new_columns);
  }
  if (!error) {
    return absl::OkStatus();
  }
  return absl::InvalidArgumentError(warnings.str());
}

class ReportImpl : public Report {
 public:
  ReportImpl(const std::map<std::string, CalculatorData>& calculator_data,
             const GraphData& graph_data)
      : calculator_data_(calculator_data), graph_data_(graph_data) {}
  void Print(std::ostream& output) override;
  const std::vector<std::string>& headers() override { return headers_impl; }
  const std::vector<std::vector<std::string>>& lines() override {
    return lines_impl;
  }
  const GraphData& graph_data() override { return graph_data_; }
  const std::map<std::string, CalculatorData>& calculator_data() override {
    return calculator_data_;
  }

  // Each header name in alphabetical order, except the first column, which is
  // always "calculator".
  std::vector<std::string> headers_impl;

  // Values for each calculator, corresponding to the label in headers().
  std::vector<std::vector<std::string>> lines_impl;

  // The longest string of any value in a given column (including the header
  // for that column). Used for formatting the output.
  std::vector<size_t> char_counts_impl;
  bool compact_flag = false;

  const std::map<std::string, CalculatorData>& calculator_data_;
  const GraphData& graph_data_;
};

void ReportImpl::Print(std::ostream& output) {
  // Print the results, but aside from the last column, add whitespace to
  // fill space up to char_counts[column] + 1. The strings in the output
  // are mutable to support padding, hence no const in the for loops.
  int column_number = 0;
  // Make a copy of the column string because we might be adding spaces.
  for (auto column : headers_impl) {
    int padding_needed = char_counts_impl[column_number] + 1 - column.length();
    if (compact_flag) {
      padding_needed = 1;
    }
    column.append(padding_needed, ' ');
    output << column;
    column_number++;
  }
  output << std::endl;

  for (auto& row : lines_impl) {
    int column_number = 0;
    for (auto column : row) {
      int padding_needed =
          char_counts_impl[column_number] + 1 - column.length();
      if (compact_flag) {
        padding_needed = 1;
      }
      column.append(padding_needed, ' ');
      output << column;
      column_number++;
    }
    output << std::endl;
  }
}

std::unique_ptr<Report> Reporter::Report() {
  CompleteCalculatorData(graph_data_, &calculator_data_);

  auto report = std::make_unique<ReportImpl>(calculator_data_, graph_data_);
  report->compact_flag = compact_flag_;

  // First row contains the column headers.
  auto& headers = report->headers_impl;
  auto& lines = report->lines_impl;
  auto& char_counts = report->char_counts_impl;

  headers = columns_;
  char_counts.resize(headers.size());
  std::transform(headers.begin(), headers.end(), char_counts.begin(),
                 [](const auto& header) { return header.size(); });

  for (const auto& header : headers) {
    size_t line_num = 0;
    for (auto calc_it = calculator_data_.begin();
         calc_it != calculator_data_.end(); ++calc_it) {
      const std::string value = kColumns[header](calc_it->second);
      if (calc_it->second.name.empty()) {
        continue;
      }
      while (line_num >= lines.size()) {
        lines.push_back({});
      }
      auto& line = lines[line_num];
      line.push_back(value);
      const size_t char_count_index = line.size() - 1;
      char_counts[char_count_index] =
          std::max(char_counts[char_count_index], value.length());
      ++line_num;
    }
  }
  return report;
}

}  // namespace reporter
}  // namespace mediapipe