#ifndef RUY_RUY_TRACE_H_
#define RUY_RUY_TRACE_H_
#ifdef RUY_TRACE
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/path.h"
#include "ruy/platform.h"
#include "ruy/side_pair.h"
namespace ruy {
// Adapts a formatting argument to a form that can be passed through
// snprintf's variadic parameter list. The generic overload returns the value
// unchanged.
template <typename T>
T value_for_snprintf(T value) {
  return value;
}

// Overload for std::string: yields the underlying C string so that the value
// can be consumed by a "%s" conversion. The returned pointer is only valid
// while `s` is alive and unmodified.
inline const char* value_for_snprintf(const std::string& s) {
  return s.c_str();
}

// Returns the result of printf-style formatting of `args` according to
// `format`, as a std::string. std::string arguments are accepted for "%s"
// conversions (see value_for_snprintf).
//
// Results of any length are supported: a 1024-byte stack buffer is tried
// first, and a heap buffer of the exact required size is used when the output
// does not fit. An empty result is returned as an empty string. Aborts if
// snprintf reports an encoding error (negative return value).
template <typename... Args>
std::string formatted(const char* format, Args... args) {
  char buf[1024];
#pragma GCC diagnostic push
#pragma GCC diagnostic warning "-Wformat-security"
  // snprintf returns the number of characters that the complete output
  // requires (excluding the terminating NUL), even when it had to truncate.
  const int size =
      snprintf(buf, sizeof buf, format, value_for_snprintf(args)...);
  if (size < 0) {
    abort();
  }
  if (size < static_cast<int>(sizeof buf)) {
    return std::string(buf);
  }
  // The stack buffer was too small: format again into a buffer that is large
  // enough for the whole output plus the terminating NUL.
  std::vector<char> big(static_cast<std::size_t>(size) + 1);
  const int big_size =
      snprintf(big.data(), big.size(), format, value_for_snprintf(args)...);
#pragma GCC diagnostic pop
  if (big_size != size) {
    abort();
  }
  return std::string(big.data());
}
// A single recorded trace line, together with the nesting depth and the
// source location that were current when it was recorded.
struct ThreadTraceEntry final {
  // The already-formatted message.
  std::string text;
  // Scope nesting depth, used to indent the message when printing.
  int indent{0};
  // Path of the source file associated with this entry; null until set.
  const char* source_file{nullptr};
  // Line number within `source_file`.
  int source_line{0};
};
// Accumulates trace entries, in the order they are written, along with the
// state needed to produce them: the current scope nesting depth, the current
// source location, a thread id, and a flag telling whether the run-ahead
// packing loop is being executed.
//
// No special member functions are declared: every member manages itself.
class ThreadTrace final {
 public:
  // Thread id accessors. The id is -1 until set_thread_id() is called.
  void set_thread_id(int thread_id) { thread_id_ = thread_id; }
  int thread_id() const { return thread_id_; }

  // Accessors for the "in run-ahead packing loop" flag (initially false).
  bool is_in_run_ahead_packing_loop() const {
    return is_in_run_ahead_packing_loop_;
  }
  void set_is_in_run_ahead_packing_loop(bool value) {
    is_in_run_ahead_packing_loop_ = value;
  }

  // Set the source location that subsequently written entries are tagged
  // with. The file pointer is stored as-is, not copied, so the pointed-to
  // string must outlive this trace's entries.
  void set_current_source_file(const char* source_file) {
    current_source_file_ = source_file;
  }
  void set_current_source_line(int source_line) {
    current_source_line_ = source_line;
  }

  // All entries written so far, oldest first.
  const std::vector<ThreadTraceEntry>& entries() const { return entries_; }

  // Formats a message printf-style (via formatted()) and appends it as a new
  // entry carrying the current indentation and source location.
  template <typename... Args>
  void Write(const char* format, Args... args) {
    ThreadTraceEntry entry;
    entry.text = formatted(format, args...);
    entry.indent = indent_;
    entry.source_file = current_source_file_;
    entry.source_line = current_source_line_;
    entries_.emplace_back(std::move(entry));
  }

  // Records the opening line of a scope, then increases the indentation so
  // that entries written inside the scope are nested under it.
  void EnterScope(const char* scope_name) {
    Write("%s {", scope_name);
    indent_++;
  }

  // Restores the indentation of the enclosing scope, then records the closing
  // line at that level so it lines up with the matching opening line.
  void LeaveScope(const char* scope_name) {
    indent_--;
    Write("} // end of %s", scope_name);
  }

 private:
  std::vector<ThreadTraceEntry> entries_;
  int indent_ = 0;
  int thread_id_ = -1;
  bool is_in_run_ahead_packing_loop_ = false;
  const char* current_source_file_ = nullptr;
  int current_source_line_ = 0;
};
// The part of the library that a trace entry is attributed to. Used to pick
// the color of the entry in the printed trace.
enum class Component {
  kNone,
  kFrontEnd,
  kMiddleEnd,
  kBackEnd,
  kThreadPool
};
// How the trace is rendered: colored with terminal escape sequences, or as
// HTML markup.
enum class TraceOutputFormat {
  kNone,
  kTerminal,
  kHtml
};
// Builds the whitespace prefix for an entry nested `indent` levels deep: one
// indentation unit per level. A zero or negative `indent` yields an empty
// string.
inline std::string IndentString(int indent) {
  std::string prefix;
  int levels_left = indent;
  while (levels_left > 0) {
    prefix.append(" ");
    --levels_left;
  }
  return prefix;
}
// Returns the markup that begins a colored section for `component` in the
// given output format: an ANSI escape sequence for kTerminal, an opening
// <span> tag with a background color for kHtml.
//
// Aborts if `output_format` is neither kTerminal nor kHtml, or if `component`
// is not one of kFrontEnd, kMiddleEnd, kBackEnd, kThreadPool.
inline const char* ColorSectionStart(TraceOutputFormat output_format,
                                     Component component) {
  const bool for_terminal = output_format == TraceOutputFormat::kTerminal;
  if (!for_terminal && output_format != TraceOutputFormat::kHtml) {
    abort();
  }
  if (component == Component::kFrontEnd) {
    return for_terminal ? "\x1b[36m"
                        : "<span style=\"background-color:#B2EBF2\">";
  }
  if (component == Component::kMiddleEnd) {
    return for_terminal ? "\x1b[32m"
                        : "<span style=\"background-color:#C8E6C9\">";
  }
  if (component == Component::kBackEnd) {
    return for_terminal ? "\x1b[31m"
                        : "<span style=\"background-color:#FFCDD2\">";
  }
  if (component == Component::kThreadPool) {
    return for_terminal ? "\x1b[33m"
                        : "<span style=\"background-color:#FFF9C4\">";
  }
  // No color is defined for any other component value.
  abort();
  return nullptr;
}
// Returns the markup that ends a colored section in the given output format:
// the ANSI reset sequence for kTerminal, a closing </span> tag for kHtml.
// Aborts for any other format.
inline const char* ColorSectionEnd(TraceOutputFormat output_format) {
  if (output_format == TraceOutputFormat::kTerminal) {
    return "\x1b[0m";
  }
  if (output_format == TraceOutputFormat::kHtml) {
    return "</span>";
  }
  abort();
  return nullptr;
}
// Chooses the output format from the RUY_TRACE_HTML environment variable:
// kHtml when the variable is set and parses (base 10) to a nonzero integer,
// kTerminal otherwise.
inline TraceOutputFormat GetOutputFormat() {
  const char* const html_setting = getenv("RUY_TRACE_HTML");
  const bool html_requested =
      html_setting != nullptr && strtol(html_setting, nullptr, 10) != 0;
  return html_requested ? TraceOutputFormat::kHtml
                        : TraceOutputFormat::kTerminal;
}
// Returns the final component of `path`: the part after the last '/' or '\\'
// separator, or `path` itself when it contains no separator.
//
// The result points into `path` (no copy is made), so it is valid exactly as
// long as `path` is. A separator in the very first position is honored
// ("/foo" yields "foo"), a path ending in a separator yields an empty string,
// and a null `path` yields an empty string.
inline const char* GetBaseName(const char* path) {
  if (path == nullptr) {
    return "";
  }
  const char* base_name = path;
  // Single forward pass: remember the position just past the most recent
  // separator seen.
  for (const char* cursor = path; *cursor != '\0'; ++cursor) {
    if (*cursor == '/' || *cursor == '\\') {
      base_name = cursor + 1;
    }
  }
  return base_name;
}
// Maps a source file's base name to the component it is attributed to:
//   pack.h, kernel.h        -> kBackEnd
//   trmul.cc, block_map.cc  -> kMiddleEnd
//   thread_pool.cc          -> kThreadPool
//   anything else           -> kFrontEnd
// The comparison is exact and case-sensitive.
inline Component GetComponentBySourceFile(const char* base_name) {
  const auto is_file = [base_name](const char* candidate) {
    return strcmp(base_name, candidate) == 0;
  };
  if (is_file("pack.h") || is_file("kernel.h")) {
    return Component::kBackEnd;
  }
  if (is_file("trmul.cc") || is_file("block_map.cc")) {
    return Component::kMiddleEnd;
  }
  if (is_file("thread_pool.cc")) {
    return Component::kThreadPool;
  }
  return Component::kFrontEnd;
}
// Returns `text` made safe for insertion into output of the given format.
//
// For kHtml, the characters that would otherwise be interpreted as markup are
// replaced by character references: '&' -> "&amp;", '<' -> "&lt;",
// '>' -> "&gt;". For every other format the text is returned unchanged.
inline std::string EscapeText(TraceOutputFormat output_format,
                              const std::string& text) {
  if (output_format != TraceOutputFormat::kHtml) {
    return text;
  }
  std::string escaped_text;
  // The escaped text is at least as long as the input.
  escaped_text.reserve(text.size());
  for (char c : text) {
    switch (c) {
      case '&':
        // Must be escaped too, otherwise text such as "&lt;" in a trace
        // message would be rendered as '<' instead of literally.
        escaped_text += "&amp;";
        break;
      case '<':
        escaped_text += "&lt;";
        break;
      case '>':
        escaped_text += "&gt;";
        break;
      default:
        escaped_text += c;
        break;
    }
  }
  return escaped_text;
}
inline void Print(const ThreadTraceEntry& entry,
TraceOutputFormat output_format, FILE* file) {
const char* base_name = GetBaseName(entry.source_file);
Component component = GetComponentBySourceFile(base_name);
const std::string& source_location =
formatted("%s:%d", base_name, entry.source_line);
const std::string& escaped_text = EscapeText(output_format, entry.text);
fprintf(file, "%s%-32s%s%s%s\n", ColorSectionStart(output_format, component),
source_location.c_str(), IndentString(entry.indent).c_str(),
escaped_text.c_str(), ColorSectionEnd(output_format));
}
inline void Print(const ThreadTrace& trace, TraceOutputFormat output_format,
FILE* file) {
if (output_format == TraceOutputFormat::kHtml) {
fprintf(file, "<html><body><pre>\n<span style=\"font-weight:bold\">\n");
}
fprintf(file, "Ruy trace for thread %d:\n", trace.thread_id());
if (output_format == TraceOutputFormat::kHtml) {
fprintf(file, "</span>\n");
}
for (const ThreadTraceEntry& entry : trace.entries()) {
Print(entry, output_format, file);
}
fprintf(file, "\n");
if (output_format == TraceOutputFormat::kHtml) {
fprintf(file, "</pre></body></html>\n");
}
}
// Owns every ThreadTrace created through AddCurrentThread() and prints all of
// them when it is destroyed. Access to the list of traces is guarded by a
// mutex.
class AllThreadTraces final {
 public:
  // Creates a new ThreadTrace owned by this object and returns a non-owning
  // pointer to it. The pointer stays valid until this object is destroyed.
  ThreadTrace* AddCurrentThread() {
    std::lock_guard<std::mutex> lock(mutex_);
    // Construct directly into an owning pointer so that the trace cannot leak
    // if growing the vector throws.
    thread_traces_.push_back(std::make_unique<ThreadTrace>());
    return thread_traces_.back().get();
  }

  // Prints all traces, ordered by ascending thread id, in the format selected
  // by GetOutputFormat(). Output goes to the file named by the RUY_TRACE_FILE
  // environment variable if it is set, to stdout otherwise. If that file
  // cannot be opened for writing, an error is reported on stderr and the
  // process exits with status 1.
  ~AllThreadTraces() {
    std::lock_guard<std::mutex> lock(mutex_);
    const char* file_env = getenv("RUY_TRACE_FILE");
    FILE* file = stdout;
    if (file_env) {
      file = fopen(file_env, "w");
      if (!file) {
        fprintf(stderr, "Failed to open %s for write\n", file_env);
        // NOTE(review): for the Singleton() instance this destructor runs
        // during program termination; calling exit() from there may be
        // problematic -- confirm whether this should be abort()/_Exit().
        exit(1);
      }
    }
    auto output_format = GetOutputFormat();
    std::sort(std::begin(thread_traces_), std::end(thread_traces_),
              [](const auto& a, const auto& b) {
                return a->thread_id() < b->thread_id();
              });
    for (const auto& trace : thread_traces_) {
      Print(*trace, output_format, file);
    }
    // Only close what was opened here; stdout is left alone.
    if (file_env) {
      fclose(file);
    }
  }

  // Returns the process-wide instance, created on first use.
  static AllThreadTraces* Singleton() {
    static AllThreadTraces all_thread_traces;
    return &all_thread_traces;
  }

 private:
  std::vector<std::unique_ptr<ThreadTrace>> thread_traces_;
  std::mutex mutex_;
};
// Returns the calling thread's ThreadTrace. The trace is created and
// registered with the AllThreadTraces singleton the first time a given thread
// calls this function; later calls on that thread return the same pointer.
inline ThreadTrace* ThreadLocalTrace() {
  thread_local ThreadTrace* const trace_of_this_thread =
      AllThreadTraces::Singleton()->AddCurrentThread();
  return trace_of_this_thread;
}
class RuyTraceScope {
const char* source_file_;
int source_line_;
const char* scope_name_;
public:
RuyTraceScope(const char* source_file, int source_line,
const char* scope_name)
: source_file_(source_file),
source_line_(source_line),
scope_name_(scope_name) {
ThreadLocalTrace()->set_current_source_file(source_file_);
ThreadLocalTrace()->set_current_source_line(source_line_);
ThreadLocalTrace()->EnterScope(scope_name_);
}
~RuyTraceScope() {
ThreadLocalTrace()->set_current_source_file(source_file_);
ThreadLocalTrace()->set_current_source_line(source_line_);
ThreadLocalTrace()->LeaveScope(scope_name_);
}
};
#define RUY_TRACE_SCOPE_NAME_IMPL …
#define RUY_TRACE_SCOPE_NAME …
#define RUY_TRACE_SCOPE …
// Human-readable name of a storage order: "row-major" for Order::kRowMajor,
// "column-major" for anything else.
inline std::string str(Order o) {
  if (o == Order::kRowMajor) {
    return "row-major";
  }
  return "column-major";
}
// Human-readable name of a side: "LHS" for Side::kLhs, "RHS" for anything
// else.
inline std::string str(Side s) {
  if (s == Side::kLhs) {
    return "LHS";
  }
  return "RHS";
}
// Describes a Layout as "<rows>x<cols>, <order>" followed by either
// ", stride=<stride>" or, when the stride equals the size of the inner
// dimension (cols for row-major, rows otherwise), ", unstrided".
inline std::string str(const Layout& layout) {
  const bool is_row_major = layout.order() == Order::kRowMajor;
  const int inner_size = is_row_major ? layout.cols() : layout.rows();
  std::string description =
      formatted("%dx%d, %s", layout.rows(), layout.cols(), str(layout.order()));
  if (inner_size == layout.stride()) {
    description += ", unstrided";
  } else {
    description += formatted(", stride=%d", layout.stride());
  }
  return description;
}
// Describes a MatLayout as "<rows>x<cols>, <order>" followed by either
// ", stride=<stride>" or, when the stride equals the size of the inner
// dimension (cols for row-major, rows otherwise), ", unstrided".
inline std::string str(const MatLayout& layout) {
  const bool is_row_major = layout.order == Order::kRowMajor;
  const int inner_size = is_row_major ? layout.cols : layout.rows;
  std::string description =
      formatted("%dx%d, %s", layout.rows, layout.cols, str(layout.order));
  if (inner_size == layout.stride) {
    description += ", unstrided";
  } else {
    description += formatted(", stride=%d", layout.stride);
  }
  return description;
}
// Describes a PMatLayout as "<rows>x<cols>, <order>", then either
// ", stride=<stride>" or ", unstrided" (stride equal to the inner dimension:
// cols for row-major, rows otherwise), then the kernel block shape and order
// as ", kernel blocks: <rows>x<cols> <order>".
inline std::string str(const PMatLayout& layout) {
  const bool is_row_major = layout.order == Order::kRowMajor;
  const int inner_size = is_row_major ? layout.cols : layout.rows;
  std::string description =
      formatted("%dx%d, %s", layout.rows, layout.cols, str(layout.order));
  if (inner_size == layout.stride) {
    description += ", unstrided";
  } else {
    description += formatted(", stride=%d", layout.stride);
  }
  description += formatted(", kernel blocks: %dx%d %s", layout.kernel.rows,
                           layout.kernel.cols, str(layout.kernel.order));
  return description;
}
// Primary template for the name of a type, as used in trace output. Any type
// without a more specific definition is reported as "<unknown type>".
template <typename T>
std::string str() {
  return std::string("<unknown type>");
}
#define RUY_IMPL_STR_TYPE_STD …
#define RUY_IMPL_STR_TYPE …
// Per-type expansions of the RUY_IMPL_STR_TYPE / RUY_IMPL_STR_TYPE_STD macros
// for the floating-point types and the fixed-width integer types.
// NOTE(review): the macro bodies are not visible here; presumably each
// expansion specializes str<T>() to return the type's name, the _STD variant
// being the one for std::-qualified types -- confirm against the macro
// definitions.
RUY_IMPL_STR_TYPE(float)
RUY_IMPL_STR_TYPE(double)
RUY_IMPL_STR_TYPE_STD(int8_t)
RUY_IMPL_STR_TYPE_STD(uint8_t)
RUY_IMPL_STR_TYPE_STD(int16_t)
RUY_IMPL_STR_TYPE_STD(uint16_t)
RUY_IMPL_STR_TYPE_STD(int32_t)
RUY_IMPL_STR_TYPE_STD(uint32_t)
RUY_IMPL_STR_TYPE_STD(int64_t)
RUY_IMPL_STR_TYPE_STD(uint64_t)
// Describes a Matrix<T> as "Matrix<<type name>>, <layout>". The zero point is
// appended (as an int) only when it is nonzero, and the cache policy (as its
// integer value) only when it differs from CachePolicy::kNeverCache.
template <typename T>
std::string str(const Matrix<T>& matrix) {
  std::string description =
      formatted("Matrix<%s>, %s", str<T>(), str(matrix.layout()));
  if (matrix.zero_point()) {
    const int zero_point = static_cast<int>(matrix.zero_point());
    description += formatted(", zero_point=%d", zero_point);
  }
  if (matrix.cache_policy() != CachePolicy::kNeverCache) {
    const int policy = static_cast<int>(matrix.cache_policy());
    description += formatted(", cache_policy=%d", policy);
  }
  return description;
}
// Short name of a runtime Type: a kind letter ('f' for floating-point, 'i'
// for signed, 'u' otherwise) followed by `size * 8`, e.g. "f32" when size
// is 4.
inline std::string str(const Type& type) {
  const char kind_letter =
      type.is_floating_point ? 'f' : (type.is_signed ? 'i' : 'u');
  return formatted("%c%d", kind_letter, type.size * 8);
}
// Describes an EMat as "EMat, data_type=<type>, <layout>". The zero point is
// appended (as an int) only when it is nonzero, and the cache policy (as its
// integer value) only when it differs from CachePolicy::kNeverCache.
inline std::string str(const EMat& mat) {
  std::string description =
      formatted("EMat, data_type=%s, %s", str(mat.data_type), str(mat.layout));
  if (mat.zero_point) {
    const int zero_point = static_cast<int>(mat.zero_point);
    description += formatted(", zero_point=%d", zero_point);
  }
  if (mat.cache_policy != CachePolicy::kNeverCache) {
    const int policy = static_cast<int>(mat.cache_policy);
    description += formatted(", cache_policy=%d", policy);
  }
  return description;
}
// Describes a PEMat as "PEMat, data_type=<type>, <layout>", with the zero
// point appended (as an int) only when it is nonzero.
inline std::string str(const PEMat& mat) {
  std::string description =
      formatted("PEMat, data_type=%s, %s", str(mat.data_type), str(mat.layout));
  if (mat.zero_point) {
    const int zero_point = static_cast<int>(mat.zero_point);
    description += formatted(", zero_point=%d", zero_point);
  }
  return description;
}
// Renders a set of paths (a Path bitmask) as text. Each of the 16 low bits is
// tested in increasing order; the entries produced for the bits that are set
// are separated by " | ". Returns an empty string when none of those bits is
// set. Prints a diagnostic to stderr and aborts when a set bit is not handled
// by any case of the switch below.
inline std::string str(Path paths) {
// True until the first set bit has been found; used to place separators.
bool first = true;
std::string s;
for (int bit = 0; bit < 16; bit++) {
// The single-bit Path value corresponding to this bit position.
Path cur_path = static_cast<Path>(1 << bit);
if ((paths & cur_path) != Path::kNone) {
// The separator goes before every set bit except the first one found.
if (!first) {
s += " | ";
}
first = false;
switch (cur_path) {
case Path::kNone:
continue;
// NOTE(review): the body of RUY_HANDLE_PATH is not visible here; presumably
// each use below expands to a `case` that appends the path's name to `s` --
// confirm against the macro definition.
#define RUY_HANDLE_PATH …
RUY_HANDLE_PATH(kStandardCpp)
RUY_HANDLE_PATH(kInternalStandardCppVariant1)
RUY_HANDLE_PATH(kInternalStandardCppVariant2)
RUY_HANDLE_PATH(kInternalStandardCppVariant3)
// Architecture-specific paths are only handled when building for that
// architecture.
#if RUY_PLATFORM_ARM
RUY_HANDLE_PATH(kNeon)
RUY_HANDLE_PATH(kNeonDotprod)
#endif
#if RUY_PLATFORM_X86
RUY_HANDLE_PATH(kAvx)
RUY_HANDLE_PATH(kAvx2Fma)
RUY_HANDLE_PATH(kAvx512)
#endif
#undef RUY_HANDLE_PATH
default:
// A bit is set that no case above handles: report its value and abort.
fprintf(stderr, "Unhandled Path value 0x%x\n",
static_cast<int>(cur_path));
abort();
}
}
}
return s;
}
#define RUY_TRACE_INFO_MUL …
#define RUY_TRACE_INFO_CREATE_TRMUL_PARAMS_TRANSPOSING …
#define RUY_TRACE_INFO_CREATE_TRMUL_PARAMS_ASSUMING_COLMAJOR_DST …
#define RUY_TRACE_INFO_POPULATE_TRMUL_PARAMS …
#define RUY_TRACE_INFO_GET_RUNTIME_ENABLED_PATHS_USING_SET_VALUE …
#define RUY_TRACE_INFO_GET_RUNTIME_ENABLED_PATHS_USING_ENV_VAR …
#define RUY_TRACE_INFO_GET_RUNTIME_ENABLED_PATHS_USING_DETECTION …
#define RUY_TRACE_INFO_PREPARE_PACKED_MATRICES_SHOULD_CACHE …
#define RUY_TRACE_INFO_PREPARE_PACKED_MATRICES_NO_CACHE …
#define RUY_TRACE_INFO_GET_TENTATIVE_THREAD_COUNT …
#define RUY_TRACE_INFO_GET_USE_SIMPLE_LOOP_RETURNS_TRUE …
#define RUY_TRACE_INFO_GET_USE_SIMPLE_LOOP_RETURNS_FALSE …
#define RUY_TRACE_INFO_TRMUL_SIMPLE_LOOP …
#define RUY_TRACE_INFO_TRMUL_GENERAL_CASE …
#define RUY_TRACE_INFO_MAKE_BLOCK_MAP_START …
#define RUY_TRACE_INFO_MAKE_BLOCK_MAP_EACH_TENTATIVE_BLOCK_SIZE …
#define RUY_TRACE_INFO_MAKE_BLOCK_MAP_END …
#define RUY_TRACE_SET_THEAD_ID …
#define RUY_TRACE_INFO_TRMUL_TASK_MAIN_LOOP_GOT_BLOCK_COORDS …
#define RUY_TRACE_INFO_TRYPACK_PACKING …
#define RUY_TRACE_INFO_TRYPACK_ANOTHER_THREAD_PACKING …
#define RUY_TRACE_INFO_TRYPACK_PREVIOUSLY_PACKED …
#define RUY_TRACE_INFO_TRYPACK_PACKED_BY_ANOTHER_THREAD …
#define RUY_TRACE_INFO_ENSURE_PACKED_ENTER_RUN_AHEAD …
#define RUY_TRACE_INFO_ENSURE_PACKED_END …
#define RUY_TRACE_INFO_RUN_PACK …
#define RUY_TRACE_INFO_RUN_KERNEL …
#define RUY_TRACE_INFO_THREAD_FUNC_IMPL_WAITING …
#define RUY_TRACE_INFO_THREADPOOL_EXECUTE_STARTING_TASK …
#define RUY_TRACE_INFO_THREADPOOL_EXECUTE_STARTING_TASK_ZERO_ON_CUR_THREAD …
#define RUY_TRACE_INFO_THREADPOOL_EXECUTE_WAITING_FOR_THREADS …
#define RUY_TRACE_INFO …
}
#else
#define RUY_TRACE_SCOPE_NAME(name) …
#define RUY_TRACE_SCOPE
#define RUY_TRACE_SET_THEAD_ID(thread_id) …
#define RUY_TRACE_INFO(id) …
#endif
#endif