#include "sanitizer_platform.h"
#if SANITIZER_WINDOWS
# include <stddef.h>
# include "interception/interception.h"
# include "sanitizer_addrhashmap.h"
# include "sanitizer_common.h"
# include "sanitizer_internal_defs.h"
# include "sanitizer_placement_new.h"
# include "sanitizer_win_immortalize.h"
# include "sanitizer_win_interception.h"
using namespace __sanitizer;
extern "C" void *__ImageBase;
namespace __sanitizer {
// Looks up `export_name` through __interception::InternalGetProcAddress,
// using the address of __ImageBase as the module to search, and returns the
// resulting address.
//
// A zero result is treated as fatal: the failure is reported and an
// always-false CHECK fires (the string literal inside the CHECK only labels
// the failure).
static uptr GetSanitizerDllExport(const char *export_name) {
  const uptr resolved =
      __interception::InternalGetProcAddress(&__ImageBase, export_name);
  if (resolved != 0)
    return resolved;

  Report("ERROR: Failed to find sanitizer DLL export '%s'\n", export_name);
  CHECK("Failed to find sanitizer DLL export" && 0);
  return resolved;
}
// Node of a singly linked list of callbacks. Nodes are allocated and released
// through InternalAlloc/InternalFree via the class-specific operator
// new/delete below.
struct WeakCallbackList {
  // Callback stored in this node.
  RegisterWeakFunctionCallback callback;
  // Following node in the list, or nullptr when this node is the last one.
  WeakCallbackList *next = nullptr;

  // Creates a stand-alone node (no successor) holding `cb`.
  explicit constexpr WeakCallbackList(RegisterWeakFunctionCallback cb)
      : callback(cb) {}

  static void *operator new(size_t size) { return InternalAlloc(size); }
  static void operator delete(void *p) { InternalFree(p); }
};
// Map keyed by address whose values are heads of WeakCallbackList chains.
using WeakCallbackMap = AddrHashMap<WeakCallbackList *, 11>;

// Returns a pointer to the WeakCallbackMap instance provided by
// immortalize<WeakCallbackMap>().
static WeakCallbackMap *GetWeakCallbackMap() {
  WeakCallbackMap &callback_map = immortalize<WeakCallbackMap>();
  return &callback_map;
}
// Registers `cb` to be run by RunWeakFunctionCallbacks(export_address).
//
// Each address maps to a singly linked list of callbacks. A new callback is
// appended at the tail, so callbacks for one address run in registration
// order and any number of them can be registered for the same address.
//
// NOTE(review): the trailing `false, true` handle arguments are presumably
// (remove, create) - i.e. "find the entry or create it"; confirm against
// AddrHashMap::Handle.
void AddRegisterWeakFunctionCallback(uptr export_address,
                                     RegisterWeakFunctionCallback cb) {
  WeakCallbackMap::Handle h_find_or_create(GetWeakCallbackMap(), export_address,
                                           false, true);
  CHECK(h_find_or_create.exists());

  WeakCallbackList *const node = new WeakCallbackList(cb);
  if (h_find_or_create.created()) {
    // First callback for this address: the new node becomes the list head.
    *h_find_or_create = node;
    return;
  }

  // Walk to the last node before linking. Assigning to the head's `next`
  // directly would overwrite (and leak) every callback already chained after
  // the head once a third callback is registered for the same address.
  WeakCallbackList *tail = *h_find_or_create;
  while (tail->next != nullptr) {
    tail = tail->next;
  }
  tail->next = node;
}
// Invokes, in list order, every callback stored for `export_address`.
// Does nothing when the map has no entry for that address.
static void RunWeakFunctionCallbacks(uptr export_address) {
  WeakCallbackMap::Handle h_find(GetWeakCallbackMap(), export_address, false,
                                 false);
  if (h_find.exists()) {
    // The stored head is dereferenced unconditionally, so an existing entry
    // is expected to hold at least one node.
    WeakCallbackList *node = *h_find;
    do {
      node->callback();
      node = node->next;
    } while (node != nullptr);
  }
}
}
// Exported entry point: overrides the local function at `user_function` with
// the sanitizer DLL export named `export_name`.
//
// `export_name` and `user_function` must be non-null/non-zero (CHECKed).
// `old_user_function` is forwarded untouched to
// __interception::OverrideFunction.
//
// Returns true on success. On failure the error is reported and an
// always-false CHECK fires.
extern "C" __declspec(dllexport) bool __cdecl __sanitizer_override_function(
    const char *export_name, const uptr user_function,
    uptr *const old_user_function) {
  CHECK(export_name);
  CHECK(user_function);

  const uptr sanitizer_function = GetSanitizerDllExport(export_name);
  if (__interception::OverrideFunction(user_function, sanitizer_function,
                                       old_user_function)) {
    return true;
  }

  Report(
      "ERROR: Failed to override local function at '%p' with sanitizer "
      "function '%s'\n",
      user_function, export_name);
  CHECK("Failed to replace local function with sanitizer version." && 0);
  return false;
}
// Exported entry point: overrides the function at `target_function` with the
// function at `source_function`, both given as raw addresses.
//
// Both addresses must be non-zero (CHECKed). `old_target_function` is
// forwarded untouched to __interception::OverrideFunction.
//
// Returns true on success. On failure the error is reported and an
// always-false CHECK fires.
extern "C"
__declspec(dllexport) bool __cdecl __sanitizer_override_function_by_addr(
    const uptr source_function, const uptr target_function,
    uptr *const old_target_function) {
  CHECK(source_function);
  CHECK(target_function);

  // Note the argument order: the function being overridden comes first.
  if (__interception::OverrideFunction(target_function, source_function,
                                       old_target_function)) {
    return true;
  }

  Report(
      "ERROR: Failed to override function at '%p' with function at "
      "'%p'\n",
      target_function, source_function);
  CHECK("Failed to apply function override." && 0);
  return false;
}
// Exported entry point: registers the local function at `user_function` to be
// used in place of the sanitizer DLL export named `export_name`, then runs the
// callbacks registered for that export via AddRegisterWeakFunctionCallback.
//
// `export_name` and `user_function` must be non-null/non-zero (CHECKed).
// `old_user_function` is forwarded untouched to
// __interception::OverrideFunction.
//
// Returns the result of the override. On failure the error is reported and an
// always-false CHECK fires.
extern "C"
__declspec(dllexport) bool __cdecl __sanitizer_register_weak_function(
    const char *export_name, const uptr user_function,
    uptr *const old_user_function) {
  CHECK(export_name);
  CHECK(user_function);

  const uptr sanitizer_function = GetSanitizerDllExport(export_name);

  // Argument order is the reverse of __sanitizer_override_function: here the
  // sanitizer export is the function being overridden.
  const bool function_overridden = __interception::OverrideFunction(
      sanitizer_function, user_function, old_user_function);
  if (!function_overridden) {
    // The message ends with a plain newline, like the sibling diagnostics in
    // this file; a '.' after the '\n' would begin the next output line.
    Report(
        "ERROR: Failed to register local function at '%p' to be used in "
        "place of sanitizer function '%s'\n",
        user_function, export_name);
    CHECK("Failed to register weak function." && 0);
  }

  // Callbacks are keyed by the address of the sanitizer export.
  RunWeakFunctionCallbacks(sanitizer_function);
  return function_overridden;
}
#endif