chromium/third_party/ruy/src/ruy/trace.h

/* Copyright 2021 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef RUY_RUY_TRACE_H_
#define RUY_RUY_TRACE_H_

#ifdef RUY_TRACE

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/side_pair.h"

namespace ruy {

// Helpers for `formatted` so we don't have to put .c_str() on strings.
//
// Generic case: the value is handed to snprintf unchanged.
template <typename T>
T value_for_snprintf(T value) {
  return value;
}

// std::string case: hand snprintf the underlying C string. The returned
// pointer is only valid for as long as `s` is alive.
inline const char* value_for_snprintf(const std::string& s) {
  return s.c_str();
}

// A sprintf-like function returning a std::string.
// Remove this once we can rely on std::format (c++20).
//
// `format` is a printf-style format string; `args` are the matching values.
// std::string arguments may be passed directly for `%s`.
//
// Aborts only if snprintf reports an encoding error (negative return value).
// An empty result is returned as an empty string, and results that do not fit
// in the on-stack buffer are formatted again into a heap buffer of the exact
// size instead of being silently truncated.
template <typename... Args>
std::string formatted(const char* format, Args... args) {
  char stack_buf[1024];
#pragma GCC diagnostic push
#pragma GCC diagnostic warning "-Wformat-security"
  const int size = snprintf(stack_buf, sizeof stack_buf, format,
                            value_for_snprintf(args)...);
#pragma GCC diagnostic pop
  if (size < 0) {
    abort();
  }
  // snprintf returns the length the full output would have, excluding the
  // terminating NUL, regardless of how much actually fit in the buffer.
  const std::size_t length = static_cast<std::size_t>(size);
  if (length < sizeof stack_buf) {
    return std::string(stack_buf, length);
  }
  // The output was truncated: redo the formatting in a buffer that is large
  // enough (one extra byte for the terminating NUL written by snprintf).
  std::string result(length + 1, '\0');
#pragma GCC diagnostic push
#pragma GCC diagnostic warning "-Wformat-security"
  const int retry_size = snprintf(&result[0], result.size(), format,
                                  value_for_snprintf(args)...);
#pragma GCC diagnostic pop
  if (retry_size != size) {
    abort();
  }
  result.resize(length);
  return result;
}

// An entry in the trace: one line of text plus the nesting depth and source
// location that were current when it was written (see `ThreadTrace::Write`).
struct ThreadTraceEntry final {
  // The already-formatted text of the entry.
  std::string text;
  // Nesting depth of the entry. `Print` emits two spaces per level before the
  // text.
  int indent = 0;
  // Source location associated with the entry. The file name is stored as a
  // raw pointer and is not copied.
  // NOTE(review): presumably always a `__FILE__` string literal, so that the
  // pointer stays valid until the trace is printed --- confirm against callers.
  const char* source_file = nullptr;
  int source_line = 0;
};

// Trace for one thread.
class ThreadTrace final {
 public:
  ~ThreadTrace() {}

  void set_thread_id(int thread_id) { thread_id_ = thread_id; }
  int thread_id() const { return thread_id_; }

  bool is_in_run_ahead_packing_loop() const {
    return is_in_run_ahead_packing_loop_;
  }
  void set_is_in_run_ahead_packing_loop(bool value) {
    is_in_run_ahead_packing_loop_ = value;
  }

  void set_current_source_file(const char* source_file) {
    current_source_file_ = source_file;
  }

  void set_current_source_line(int source_line) {
    current_source_line_ = source_line;
  }

  const std::vector<ThreadTraceEntry>& entries() const { return entries_; }

  template <typename... Args>
  void Write(const char* format, Args... args) {
    ThreadTraceEntry entry;
    entry.text = formatted(format, args...);
    entry.indent = indent_;
    entry.source_file = current_source_file_;
    entry.source_line = current_source_line_;
    entries_.emplace_back(std::move(entry));
  }

  template <typename... Args>
  void EnterScope(const char* scope_name) {
    Write("%s {", scope_name);
    indent_++;
  }
  void LeaveScope(const char* scope_name) {
    indent_--;
    Write("}  // end of %s", scope_name);
  }

 private:
  // The trace contents
  std::vector<ThreadTraceEntry> entries_;

  // Current indentation level.
  int indent_ = 0;
  // Thread's ID as set by Ruy, e.g. [0,N-1]. Not OS TID.
  int thread_id_ = -1;
  // The run-ahead loop in `EnsurePacked` may run many iterations when the
  // thread is waiting for a block to be packed by another thread --- it's
  // a busy wait. We track whether we are in that mode to avoid generating
  // many uninteresting trace entries.
  bool is_in_run_ahead_packing_loop_ = false;
  // Last recorded value of __FILE__ and __LINE__, as a convenience so we don't
  // have to pass these in every call to `Write`.
  const char* current_source_file_ = nullptr;
  int current_source_line_ = 0;
};

// The main components of ruy. Trace entries are colorized according to the
// component they are attributed to.
enum class Component {
  kNone,
  kFrontEnd,
  kMiddleEnd,
  kBackEnd,
  kThreadPool,
};

// The format in which the trace is written out.
enum class TraceOutputFormat {
  kNone,
  kTerminal,
  kHtml,
};

// Returns the whitespace prefix for the given indentation level: two spaces
// per level, and an empty string for levels that are zero or negative.
inline std::string IndentString(int indent) {
  if (indent <= 0) {
    return std::string();
  }
  const std::string::size_type num_spaces =
      static_cast<std::string::size_type>(indent) * 2;
  return std::string(num_spaces, ' ');
}

// Returns the text to write to the trace to open a colored section: an ANSI
// color escape sequence for terminal output, an opening <span> tag for HTML
// output. Aborts if `output_format` is neither kTerminal nor kHtml, or if
// `component` is not one of the four colorized components.
inline const char* ColorSectionStart(TraceOutputFormat output_format,
                                     Component component) {
  const bool to_terminal = output_format == TraceOutputFormat::kTerminal;
  const bool to_html = output_format == TraceOutputFormat::kHtml;
  if (!to_terminal && !to_html) {
    abort();
  }
  // From here on the format is either terminal or HTML.
  switch (component) {
    case Component::kFrontEnd:
      return to_terminal ? "\x1b[36m"
                         : "<span style=\"background-color:#B2EBF2\">";
    case Component::kMiddleEnd:
      return to_terminal ? "\x1b[32m"
                         : "<span style=\"background-color:#C8E6C9\">";
    case Component::kBackEnd:
      return to_terminal ? "\x1b[31m"
                         : "<span style=\"background-color:#FFCDD2\">";
    case Component::kThreadPool:
      return to_terminal ? "\x1b[33m"
                         : "<span style=\"background-color:#FFF9C4\">";
    default:
      break;
  }
  abort();
  return nullptr;
}

// Returns the text to write to the trace to close a colored section: the ANSI
// reset sequence for terminal output, a closing </span> tag for HTML output.
// Aborts for any other output format.
inline const char* ColorSectionEnd(TraceOutputFormat output_format) {
  if (output_format == TraceOutputFormat::kTerminal) {
    return "\x1b[0m";
  }
  if (output_format == TraceOutputFormat::kHtml) {
    return "</span>";
  }
  abort();
  return nullptr;
}

// Returns the output format to use for the trace. HTML is selected when the
// RUY_TRACE_HTML environment variable is set and parses (base 10) to a nonzero
// integer; otherwise the terminal format is used.
inline TraceOutputFormat GetOutputFormat() {
  const char* const html_setting = getenv("RUY_TRACE_HTML");
  const bool want_html =
      html_setting != nullptr && strtol(html_setting, nullptr, 10) != 0;
  return want_html ? TraceOutputFormat::kHtml : TraceOutputFormat::kTerminal;
}

// A `basename` function that's good enough for ruy __FILE__'s.
// Note: `basename` is POSIX-only and annoying (takes a char*, may mutate).
//
// Returns a pointer into `path`, just past its last '/' or '\\' separator, or
// `path` itself when it contains no separator. A separator in the very first
// position is recognized like any other, so "/foo.cc" yields "foo.cc".
// `path` must be a non-null, NUL-terminated string.
inline const char* GetBaseName(const char* path) {
  const char* base_name = path;
  for (const char* ptr = path; *ptr != '\0'; ++ptr) {
    if (*ptr == '/' || *ptr == '\\') {
      // Tentatively, the base name starts right after this separator; a later
      // separator overrides it.
      base_name = ptr + 1;
    }
  }
  return base_name;
}

// Determines a Component (used for colorization) by source file. `base_name`
// is compared exactly against a short list of known file names; any file not
// on the list is attributed to the front-end.
inline Component GetComponentBySourceFile(const char* base_name) {
  struct KnownFile {
    const char* name;
    Component component;
  };
  const KnownFile known_files[] = {
      {"pack.h", Component::kBackEnd},
      {"kernel.h", Component::kBackEnd},
      {"trmul.cc", Component::kMiddleEnd},
      {"block_map.cc", Component::kMiddleEnd},
      {"thread_pool.cc", Component::kThreadPool},
  };
  for (const KnownFile& known_file : known_files) {
    if (strcmp(base_name, known_file.name) == 0) {
      return known_file.component;
    }
  }
  return Component::kFrontEnd;
}

// Returns `text` made safe for inclusion in the trace output. For HTML output
// the characters '&', '<' and '>' are replaced by the corresponding character
// references; '&' has to be escaped as well, since otherwise trace text such
// as "&lt;" would be rendered as '<'. For any other output format the text is
// returned unchanged.
inline std::string EscapeText(TraceOutputFormat output_format,
                              const std::string& text) {
  if (output_format != TraceOutputFormat::kHtml) {
    return text;
  }
  std::string escaped_text;
  // The escaped text is at least as long as the input.
  escaped_text.reserve(text.size());
  for (char c : text) {
    switch (c) {
      case '&':
        escaped_text += "&amp;";
        break;
      case '<':
        escaped_text += "&lt;";
        break;
      case '>':
        escaped_text += "&gt;";
        break;
      default:
        escaped_text += c;
        break;
    }
  }
  return escaped_text;
}

// Prints an entry from the trace to the destination trace file.
inline void Print(const ThreadTraceEntry& entry,
                  TraceOutputFormat output_format, FILE* file) {
  const char* base_name = GetBaseName(entry.source_file);
  Component component = GetComponentBySourceFile(base_name);
  const std::string& source_location =
      formatted("%s:%d", base_name, entry.source_line);
  const std::string& escaped_text = EscapeText(output_format, entry.text);
  fprintf(file, "%s%-32s%s%s%s\n", ColorSectionStart(output_format, component),
          source_location.c_str(), IndentString(entry.indent).c_str(),
          escaped_text.c_str(), ColorSectionEnd(output_format));
}

// Prints a thread's entire trace to the destination trace file.
inline void Print(const ThreadTrace& trace, TraceOutputFormat output_format,
                  FILE* file) {
  if (output_format == TraceOutputFormat::kHtml) {
    fprintf(file, "<html><body><pre>\n<span style=\"font-weight:bold\">\n");
  }
  fprintf(file, "Ruy trace for thread %d:\n", trace.thread_id());
  if (output_format == TraceOutputFormat::kHtml) {
    fprintf(file, "</span>\n");
  }
  for (const ThreadTraceEntry& entry : trace.entries()) {
    Print(entry, output_format, file);
  }
  fprintf(file, "\n");
  if (output_format == TraceOutputFormat::kHtml) {
    fprintf(file, "</pre></body></html>\n");
  }
}

// Holds all the threads' traces. This is a global singleton class.
// On exit, when the singleton is destroyed, the destructor prints out the
// traces.
class AllThreadTraces final {
 public:
  // Add a new ThreadTrace for the current thread. Should be called only once
  // on each thread. The returned pointer is owned by this object.
  ThreadTrace* AddCurrentThread() {
    std::lock_guard<std::mutex> lock(mutex_);
    // Construct directly into an owning pointer so that nothing leaks if
    // growing the vector throws.
    thread_traces_.push_back(std::make_unique<ThreadTrace>());
    return thread_traces_.back().get();
  }

  // Prints all the traces, sorted by Ruy thread ID, to the file named by the
  // RUY_TRACE_FILE environment variable, or to stdout if it is not set. The
  // output format is the one returned by `GetOutputFormat()`.
  ~AllThreadTraces() {
    std::lock_guard<std::mutex> lock(mutex_);
    // Open the destination file.
    const char* file_env = getenv("RUY_TRACE_FILE");
    FILE* file = stdout;
    if (file_env) {
      file = fopen(file_env, "w");
      if (!file) {
        fprintf(stderr, "Failed to open %s for write\n", file_env);
        // This destructor runs during the destruction of a static object,
        // where calling exit() is undefined behavior. Flush the stdio streams
        // ourselves and terminate immediately with the same exit status.
        fflush(nullptr);
        std::_Exit(1);
      }
    }
    // Sort the threads by Ruy Thread ID (not OS TID).
    auto output_format = GetOutputFormat();
    std::sort(std::begin(thread_traces_), std::end(thread_traces_),
              [](const auto& a, const auto& b) {
                return a->thread_id() < b->thread_id();
              });
    // Print all the threads' traces.
    for (const auto& trace : thread_traces_) {
      Print(*trace, output_format, file);
    }
    // Only close what we opened; stdout is left alone.
    if (file_env) {
      fclose(file);
    }
  }

  // Returns the process-wide instance, constructed on first use.
  static AllThreadTraces* Singleton() {
    static AllThreadTraces all_thread_traces;
    return &all_thread_traces;
  }

 private:
  // One trace per thread that called `AddCurrentThread`. Guarded by `mutex_`.
  std::vector<std::unique_ptr<ThreadTrace>> thread_traces_;
  std::mutex mutex_;
};

// Returns the thread-local ThreadTrace singleton, constructing it as needed.
// The first call on a given thread obtains a new ThreadTrace from
// `AllThreadTraces::Singleton()->AddCurrentThread()`; the resulting pointer is
// kept in a `thread_local` variable and returned by all later calls on that
// thread.
inline ThreadTrace* ThreadLocalTrace() {
  static thread_local ThreadTrace* thread_local_trace =
      AllThreadTraces::Singleton()->AddCurrentThread();
  return thread_local_trace;
}

// RAII helper to trace a scope, e.g. a function scope.
class RuyTraceScope {
  const char* source_file_;
  int source_line_;
  const char* scope_name_;

 public:
  RuyTraceScope(const char* source_file, int source_line,
                const char* scope_name)
      : source_file_(source_file),
        source_line_(source_line),
        scope_name_(scope_name) {
    ThreadLocalTrace()->set_current_source_file(source_file_);
    ThreadLocalTrace()->set_current_source_line(source_line_);
    ThreadLocalTrace()->EnterScope(scope_name_);
  }
  ~RuyTraceScope() {
    ThreadLocalTrace()->set_current_source_file(source_file_);
    ThreadLocalTrace()->set_current_source_line(source_line_);
    ThreadLocalTrace()->LeaveScope(scope_name_);
  }
};

// Scope-tracing macros for the RUY_TRACE-enabled build.
// NOTE(review): as written here, all three macros are object-like and expand
// to nothing, whereas the non-tracing branch at the end of this file defines
// RUY_TRACE_SCOPE_NAME as a function-like macro taking `name`. Presumably the
// parameter lists and replacement text (instantiating a RuyTraceScope) are
// missing from this copy --- confirm against the upstream file.
#define RUY_TRACE_SCOPE_NAME_IMPL
#define RUY_TRACE_SCOPE_NAME
#define RUY_TRACE_SCOPE

// Helpers to trace Ruy objects.

// Describes a storage order: "row-major" for Order::kRowMajor, "column-major"
// for any other value.
inline std::string str(Order o) {
  if (o == Order::kRowMajor) {
    return "row-major";
  }
  return "column-major";
}

// Describes a side: "LHS" for Side::kLhs, "RHS" for any other value.
inline std::string str(Side s) {
  if (s == Side::kLhs) {
    return "LHS";
  }
  return "RHS";
}

// Describes a Layout as "<rows>x<cols>, <order>", followed by either
// ", stride=<stride>" or, when the stride equals the size of the dimension
// that is contiguous for that order, ", unstrided".
inline std::string str(const Layout& layout) {
  const bool is_row_major = layout.order() == Order::kRowMajor;
  const int contiguous_dim = is_row_major ? layout.cols() : layout.rows();
  std::string description =
      formatted("%dx%d, %s", layout.rows(), layout.cols(), str(layout.order()));
  if (contiguous_dim == layout.stride()) {
    description += ", unstrided";
  } else {
    description += formatted(", stride=%d", layout.stride());
  }
  return description;
}

// Describes a MatLayout as "<rows>x<cols>, <order>", followed by either
// ", stride=<stride>" or, when the stride equals the size of the dimension
// that is contiguous for that order, ", unstrided".
inline std::string str(const MatLayout& layout) {
  const bool is_row_major = layout.order == Order::kRowMajor;
  const int contiguous_dim = is_row_major ? layout.cols : layout.rows;
  std::string description =
      formatted("%dx%d, %s", layout.rows, layout.cols, str(layout.order));
  if (contiguous_dim == layout.stride) {
    description += ", unstrided";
  } else {
    description += formatted(", stride=%d", layout.stride);
  }
  return description;
}

// Describes a PMatLayout: the same "<rows>x<cols>, <order>" and stride
// information as for the other layout types, followed by the dimensions and
// order of the kernel blocks.
inline std::string str(const PMatLayout& layout) {
  const bool is_row_major = layout.order == Order::kRowMajor;
  const int contiguous_dim = is_row_major ? layout.cols : layout.rows;
  std::string description =
      formatted("%dx%d, %s", layout.rows, layout.cols, str(layout.order));
  if (contiguous_dim == layout.stride) {
    description += ", unstrided";
  } else {
    description += formatted(", stride=%d", layout.stride);
  }
  description += formatted(", kernel blocks: %dx%d %s", layout.kernel.rows,
                           layout.kernel.cols, str(layout.kernel.order));
  return description;
}

// Returns the name of type T for use in traces. This primary template is the
// fallback used for any type that has no dedicated specialization.
template <typename T>
std::string str() {
  return std::string("<unknown type>");
}
// Per-type uses of the RUY_IMPL_STR_TYPE / RUY_IMPL_STR_TYPE_STD macros.
// NOTE(review): both macros are defined here as object-like macros with no
// replacement text, yet they are invoked below with a type argument, so each
// use would expand to a stray parenthesized type name. Presumably the
// definitions (likely specializations of `str<T>()`) are missing from this
// copy --- confirm against the upstream file.
#define RUY_IMPL_STR_TYPE_STD
#define RUY_IMPL_STR_TYPE

RUY_IMPL_STR_TYPE(float)
RUY_IMPL_STR_TYPE(double)
RUY_IMPL_STR_TYPE_STD(int8_t)
RUY_IMPL_STR_TYPE_STD(uint8_t)
RUY_IMPL_STR_TYPE_STD(int16_t)
RUY_IMPL_STR_TYPE_STD(uint16_t)
RUY_IMPL_STR_TYPE_STD(int32_t)
RUY_IMPL_STR_TYPE_STD(uint32_t)
RUY_IMPL_STR_TYPE_STD(int64_t)
RUY_IMPL_STR_TYPE_STD(uint64_t)

// Describes a Matrix<T> as "Matrix<type>, <layout>". The zero point is
// appended only when it is nonzero, and the cache policy (as its integer
// value) only when it differs from CachePolicy::kNeverCache.
template <typename T>
std::string str(const Matrix<T>& matrix) {
  std::string description =
      formatted("Matrix<%s>, %s", str<T>(), str(matrix.layout()));
  if (matrix.zero_point()) {
    const int zero_point = static_cast<int>(matrix.zero_point());
    description += formatted(", zero_point=%d", zero_point);
  }
  const auto cache_policy = matrix.cache_policy();
  if (cache_policy != CachePolicy::kNeverCache) {
    description +=
        formatted(", cache_policy=%d", static_cast<int>(cache_policy));
  }
  return description;
}

// Describes a Type as a one-letter kind followed by its width in bits: 'f'
// for floating-point, 'i' for signed, 'u' otherwise; the width is
// `type.size * 8`.
inline std::string str(const Type& type) {
  const char kind =
      type.is_floating_point ? 'f' : (type.is_signed ? 'i' : 'u');
  return formatted("%c%d", kind, type.size * 8);
}

// Describes an EMat as "EMat, data_type=<type>, <layout>". The zero point is
// appended only when it is nonzero, and the cache policy (as its integer
// value) only when it differs from CachePolicy::kNeverCache.
inline std::string str(const EMat& mat) {
  std::string description =
      formatted("EMat, data_type=%s, %s", str(mat.data_type), str(mat.layout));
  if (mat.zero_point) {
    const int zero_point = static_cast<int>(mat.zero_point);
    description += formatted(", zero_point=%d", zero_point);
  }
  if (mat.cache_policy != CachePolicy::kNeverCache) {
    const int cache_policy = static_cast<int>(mat.cache_policy);
    description += formatted(", cache_policy=%d", cache_policy);
  }
  return description;
}

// Describes a PEMat as "PEMat, data_type=<type>, <layout>", with the zero
// point appended only when it is nonzero.
inline std::string str(const PEMat& mat) {
  std::string description =
      formatted("PEMat, data_type=%s, %s", str(mat.data_type), str(mat.layout));
  if (mat.zero_point) {
    const int zero_point = static_cast<int>(mat.zero_point);
    description += formatted(", zero_point=%d", zero_point);
  }
  return description;
}

// Describes a set of paths. Each of the 16 low bits of `paths` is examined in
// turn; for every bit that is set, an entry is added to the result, with
// " | " between consecutive entries. Prints a message to stderr and aborts if
// a set bit is not handled by any `case` of the switch.
//
// NOTE(review): RUY_HANDLE_PATH is defined below with no parameter list and
// no replacement text, so the RUY_HANDLE_PATH(...) lines cannot produce
// `case` labels as written. Presumably its definition (a `case` that appends
// the path's name to `s`) is missing from this copy --- confirm against the
// upstream file.
inline std::string str(Path paths) {
  // True until the first entry has been added; controls the " | " separator.
  bool first = true;
  std::string s;
  for (int bit = 0; bit < 16; bit++) {
    // The single-bit Path value corresponding to this bit position.
    Path cur_path = static_cast<Path>(1 << bit);
    if ((paths & cur_path) != Path::kNone) {
      if (!first) {
        s += " | ";
      }
      first = false;
      switch (cur_path) {
        case Path::kNone:
          continue;
#define RUY_HANDLE_PATH
          RUY_HANDLE_PATH(kStandardCpp)
          RUY_HANDLE_PATH(kInternalStandardCppVariant1)
          RUY_HANDLE_PATH(kInternalStandardCppVariant2)
          RUY_HANDLE_PATH(kInternalStandardCppVariant3)
#if RUY_PLATFORM_ARM
          RUY_HANDLE_PATH(kNeon)
          RUY_HANDLE_PATH(kNeonDotprod)
#endif  // RUY_PLATFORM_ARM
#if RUY_PLATFORM_X86
          RUY_HANDLE_PATH(kAvx)
          RUY_HANDLE_PATH(kAvx2Fma)
          RUY_HANDLE_PATH(kAvx512)
#endif  // RUY_PLATFORM_X86
#undef RUY_HANDLE_PATH
        default:
          fprintf(stderr, "Unhandled Path value 0x%x\n",
                  static_cast<int>(cur_path));
          abort();
      }
    }
  }
  return s;
}

// Implementation of RUY_TRACE_INFO(X) macros.
//
// There is one RUY_TRACE_INFO_<ID> macro per trace point, plus
// RUY_TRACE_SET_THEAD_ID and the RUY_TRACE_INFO entry point itself.
// NOTE(review): every macro in this section is defined with no parameter list
// and no replacement text, whereas the non-tracing branch at the end of this
// file defines RUY_TRACE_INFO(id) and RUY_TRACE_SET_THEAD_ID(thread_id) as
// function-like macros. Presumably the bodies are missing from this copy ---
// confirm against the upstream file.
// NOTE(review): "THEAD" in RUY_TRACE_SET_THEAD_ID looks like a typo for
// "THREAD", but the same spelling is used in the non-tracing branch, so it is
// part of the macro's public name.

#define RUY_TRACE_INFO_MUL

#define RUY_TRACE_INFO_CREATE_TRMUL_PARAMS_TRANSPOSING

#define RUY_TRACE_INFO_CREATE_TRMUL_PARAMS_ASSUMING_COLMAJOR_DST

#define RUY_TRACE_INFO_POPULATE_TRMUL_PARAMS

#define RUY_TRACE_INFO_GET_RUNTIME_ENABLED_PATHS_USING_SET_VALUE

#define RUY_TRACE_INFO_GET_RUNTIME_ENABLED_PATHS_USING_ENV_VAR

#define RUY_TRACE_INFO_GET_RUNTIME_ENABLED_PATHS_USING_DETECTION

#define RUY_TRACE_INFO_PREPARE_PACKED_MATRICES_SHOULD_CACHE

#define RUY_TRACE_INFO_PREPARE_PACKED_MATRICES_NO_CACHE

#define RUY_TRACE_INFO_GET_TENTATIVE_THREAD_COUNT

#define RUY_TRACE_INFO_GET_USE_SIMPLE_LOOP_RETURNS_TRUE

#define RUY_TRACE_INFO_GET_USE_SIMPLE_LOOP_RETURNS_FALSE

#define RUY_TRACE_INFO_TRMUL_SIMPLE_LOOP

#define RUY_TRACE_INFO_TRMUL_GENERAL_CASE

#define RUY_TRACE_INFO_MAKE_BLOCK_MAP_START

#define RUY_TRACE_INFO_MAKE_BLOCK_MAP_EACH_TENTATIVE_BLOCK_SIZE

#define RUY_TRACE_INFO_MAKE_BLOCK_MAP_END

#define RUY_TRACE_SET_THEAD_ID

#define RUY_TRACE_INFO_TRMUL_TASK_MAIN_LOOP_GOT_BLOCK_COORDS

#define RUY_TRACE_INFO_TRYPACK_PACKING

#define RUY_TRACE_INFO_TRYPACK_ANOTHER_THREAD_PACKING

#define RUY_TRACE_INFO_TRYPACK_PREVIOUSLY_PACKED

#define RUY_TRACE_INFO_TRYPACK_PACKED_BY_ANOTHER_THREAD

#define RUY_TRACE_INFO_ENSURE_PACKED_ENTER_RUN_AHEAD

#define RUY_TRACE_INFO_ENSURE_PACKED_END

#define RUY_TRACE_INFO_RUN_PACK

#define RUY_TRACE_INFO_RUN_KERNEL

#define RUY_TRACE_INFO_THREAD_FUNC_IMPL_WAITING

#define RUY_TRACE_INFO_THREADPOOL_EXECUTE_STARTING_TASK

#define RUY_TRACE_INFO_THREADPOOL_EXECUTE_STARTING_TASK_ZERO_ON_CUR_THREAD

#define RUY_TRACE_INFO_THREADPOOL_EXECUTE_WAITING_FOR_THREADS

#define RUY_TRACE_INFO

}  // namespace ruy

#else

// Vacuous implementation when RUY_TRACE is not defined: every tracing macro
// expands to nothing, so trace points have no effect in the compiled code.
#define RUY_TRACE_SCOPE_NAME(name)
#define RUY_TRACE_SCOPE
#define RUY_TRACE_SET_THEAD_ID(thread_id)
#define RUY_TRACE_INFO(id)

#endif

#endif  // RUY_RUY_TRACE_H_