chromium/base/win/iat_patch_function.cc

// Copyright 2011 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "base/win/iat_patch_function.h"

#include "base/check_op.h"
#include "base/memory/raw_ptr_exclusion.h"
#include "base/notreached.h"
#include "base/win/patch_util.h"
#include "base/win/pe_image.h"

namespace base {
namespace win {

namespace {

// State shared with InterceptEnumCallback (passed through its |cookie|
// argument) while a module's import tables are enumerated. It carries the
// search key and replacement function in, and the patch results out.
// NOTE: the field order matters; instances are built with a positional brace
// initializer.
struct InterceptFunctionInformation {
  // Set to true by the callback once a matching import has been patched; the
  // caller uses it to decide whether the delay-import table must be searched.
  bool finished_operation;
  // Name of the exporting module. The callback itself never reads this field;
  // the module name is handed separately to the PEImage enumeration call
  // (presumably it filters on it).
  const char* imported_from_module;
  // Import name to match, compared case-insensitively with lstrcmpiA.
  const char* function_name;
  // RAW_PTR_EXCLUSION: #reinterpret-cast-trivial-type
  // Value written into the matching IAT entry.
  RAW_PTR_EXCLUSION void* new_function;
  // Optional; when non-null, receives the IAT entry's value before patching.
  RAW_PTR_EXCLUSION void** old_function;
  // Optional; when non-null, receives a pointer to the matching IAT entry.
  RAW_PTR_EXCLUSION IMAGE_THUNK_DATA** iat_thunk;
  // Result of the patch attempt (from internal::ModifyCode); keeps its initial
  // value if no import matches.
  DWORD return_code;
};

// Returns the function address currently stored in the IAT entry |iat_thunk|.
// A null |iat_thunk| is treated as a programming error (NOTREACHED).
//
// IMAGE_THUNK_DATA maps to IMAGE_THUNK_DATA32 or IMAGE_THUNK_DATA64 to match
// the pointer size, and its Function member really holds a pointer to the
// imported function. The entry is reinterpreted through a union rather than
// cast from the integer member, which avoids 64-bit portability warnings.
void* GetIATFunction(IMAGE_THUNK_DATA* iat_thunk) {
  if (iat_thunk == nullptr) {
    NOTREACHED();
  }

  union {
    IMAGE_THUNK_DATA as_thunk;
    // This field is not a raw_ptr<> because it was filtered by the rewriter
    // for: #union
    RAW_PTR_EXCLUSION void* as_pointer;
  } entry;

  entry.as_thunk = *iat_thunk;
  return entry.as_pointer;
}

// Import-enumeration callback used by InterceptImportedFunction. |cookie| must
// point at an InterceptFunctionInformation. When |name| matches
// |function_name| (case-insensitively), the callback does the following:
//   - records the entry's previous value and its address (when requested),
//   - writes |new_function| into the entry,
//   - stores the result in |return_code| and marks the operation finished.
// Returns false to stop enumeration after the match, true to keep going.
bool InterceptEnumCallback(const base::win::PEImage& image,
                           const char* module,
                           DWORD ordinal,
                           const char* name,
                           DWORD hint,
                           IMAGE_THUNK_DATA* iat,
                           void* cookie) {
  auto* info = reinterpret_cast<InterceptFunctionInformation*>(cookie);
  if (info == nullptr) {
    NOTREACHED();
  }

  DCHECK(module);

  // Entries without a name (presumably ordinal-only imports) and entries with
  // a different name are skipped; enumeration continues.
  const bool is_target =
      name != nullptr && lstrcmpiA(name, info->function_name) == 0;
  if (!is_target) {
    return true;
  }

  // The IAT slot must be exactly pointer-sized for the byte copy below.
  static_assert(sizeof(iat->u1.Function) == sizeof(info->new_function),
                "unknown IAT thunk format");

  // Hand back what the slot held before, and where the slot lives.
  if (info->old_function) {
    *info->old_function = GetIATFunction(iat);
  }
  if (info->iat_thunk) {
    *info->iat_thunk = iat;
  }

  // Overwrite the slot with the interceptor.
  info->return_code = internal::ModifyCode(
      &iat->u1.Function, &info->new_function, sizeof(info->new_function));

  // Match handled: stop enumerating.
  info->finished_operation = true;
  return false;
}

// Redirects the import |function_name| (exported by |imported_from_module|)
// in the import table of |module_handle| to |new_function|. The regular
// import table is searched first. The delay-load import table is searched
// only if no match was found there.
//
// Arguments:
// module_handle          Module whose import table is patched
// imported_from_module   Module that exports the symbol
// function_name          Name of the imported API to redirect
//                        (compared case-insensitively)
// new_function           Interceptor function
// old_function           Optional; receives the original function pointer
// iat_thunk              Optional; receives a pointer to the IMAGE_THUNK_DATA
//                        entry for the API in the import table
//
// Missing required arguments or an invalid PE image are treated as
// programming errors (NOTREACHED).
//
// Returns: NO_ERROR on success or a Windows error code as defined in
//          winerror.h. ERROR_GEN_FAILURE is returned when no matching import
//          is found.
DWORD InterceptImportedFunction(HMODULE module_handle,
                                const char* imported_from_module,
                                const char* function_name,
                                void* new_function,
                                void** old_function,
                                IMAGE_THUNK_DATA** iat_thunk) {
  const bool has_required_args = module_handle && imported_from_module &&
                                 function_name && new_function;
  if (!has_required_args) {
    NOTREACHED();
  }

  base::win::PEImage image(module_handle);
  if (!image.VerifyMagic()) {
    NOTREACHED();
  }

  InterceptFunctionInformation state = {
      false,                 // finished_operation
      imported_from_module,  // imported_from_module
      function_name,         // function_name
      new_function,          // new_function
      old_function,          // old_function
      iat_thunk,             // iat_thunk
      ERROR_GEN_FAILURE};    // return_code: reported if nothing matches

  image.EnumAllImports(InterceptEnumCallback, &state, imported_from_module);
  if (state.finished_operation) {
    return state.return_code;
  }

  // Not a regular import; fall back to the delay-load imports.
  image.EnumAllDelayImports(InterceptEnumCallback, &state,
                            imported_from_module);
  return state.return_code;
}

// Undoes an IAT patch by writing |original_function| back into |iat_thunk|.
//
// Arguments:
// intercept_function     Interceptor function that the entry is expected to
//                        still point at
// original_function      Function pointer to write back into the entry
//                        (normally the value saved when the patch was
//                        installed); this is an input, not an output
// iat_thunk              The IAT entry that was patched
//
// Null arguments are treated as programming errors (NOTREACHED). If the entry
// no longer points at |intercept_function|, someone else has probably patched
// on top of us. Unpatching is then unsafe, and this is flagged with
// NOTREACHED as well.
//
// Returns: NO_ERROR on success or Windows error code
//          as defined in winerror.h
DWORD RestoreImportedFunction(void* intercept_function,
                              void* original_function,
                              IMAGE_THUNK_DATA* iat_thunk) {
  if (!intercept_function || !original_function || !iat_thunk) {
    NOTREACHED();
  }

  if (GetIATFunction(iat_thunk) != intercept_function) {
    // Check if someone else has intercepted on top of us.
    // We cannot unpatch in this case, just raise a red flag.
    NOTREACHED();
  }

  // Copy the pointer value itself (not what it points to) into the slot.
  return internal::ModifyCode(&(iat_thunk->u1.Function), &original_function,
                              sizeof(original_function));
}

}  // namespace

// Starts unpatched. Members take their in-class default initializers, which
// are declared in the header (not visible here; presumably all null).
IATPatchFunction::IATPatchFunction() = default;

// Removes any patch still installed by this instance.
IATPatchFunction::~IATPatchFunction() {
  if (!intercept_function_) {
    return;  // Nothing is patched.
  }

  const DWORD unpatch_result = Unpatch();
  DCHECK_EQ(static_cast<DWORD>(NO_ERROR), unpatch_result);
}

// Loads |module| with LoadLibraryW and patches its import of |function_name|
// from |imported_from_module| to call |new_function|. If LoadLibraryW fails,
// this is treated as a programming error (NOTREACHED). On success, the
// library reference is kept in |module_handle_| so the module stays loaded
// while patched; Unpatch() releases it. On failure, the reference is released
// immediately. Returns the result of PatchFromModule().
DWORD IATPatchFunction::Patch(const wchar_t* module,
                              const char* imported_from_module,
                              const char* function_name,
                              void* new_function) {
  HMODULE loaded_module = LoadLibraryW(module);
  if (loaded_module == nullptr) {
    NOTREACHED();
  }

  const DWORD result = PatchFromModule(loaded_module, imported_from_module,
                                       function_name, new_function);
  if (result != NO_ERROR) {
    // Nothing was patched; drop the reference taken above.
    FreeLibrary(loaded_module);
    return result;
  }

  module_handle_ = loaded_module;
  return result;
}

// Patches the import of |function_name| (exported by |imported_from_module|)
// in the already-loaded |module| so that it calls |new_function|. Unlike
// Patch(), no module reference is taken. On success, the IAT entry's previous
// value and address are recorded so Unpatch() can restore them.
//
// Returns: NO_ERROR on success or a Windows error code as defined in
//          winerror.h.
DWORD IATPatchFunction::PatchFromModule(HMODULE module,
                                        const char* imported_from_module,
                                        const char* function_name,
                                        void* new_function) {
  // An instance manages at most one patch at a time.
  DCHECK_EQ(nullptr, original_function_);
  DCHECK_EQ(nullptr, iat_thunk_);
  DCHECK_EQ(nullptr, intercept_function_);
  DCHECK(module);

  DWORD error = InterceptImportedFunction(
      module, imported_from_module, function_name, new_function,
      &original_function_.AsEphemeralRawAddr(),
      &iat_thunk_.AsEphemeralRawAddr());

  if (NO_ERROR == error) {
    // |intercept_function_| is still null here (checked above), so comparing
    // against it only tested for a non-null original. Check for non-null
    // explicitly. Also catch an entry that was already redirected to
    // |new_function|; Unpatch() would otherwise "restore" the interceptor.
    DCHECK_NE(nullptr, original_function_);
    DCHECK_NE(original_function_.get(), new_function);
    intercept_function_ = new_function;
  }

  return error;
}

// Writes the saved original function back into the patched IAT entry. It then
// releases the module reference taken by Patch() (if any) and clears all
// patch state.
//
// The state is cleared even if restoring fails. Such a failure means the
// import address table cannot be unpatched safely, and retrying later (e.g.
// from the destructor) would be no safer, so the intercept is abandoned.
//
// Returns: NO_ERROR on success or a Windows error code as defined in
//          winerror.h.
DWORD IATPatchFunction::Unpatch() {
  const DWORD restore_result = RestoreImportedFunction(
      intercept_function_, original_function_, iat_thunk_);
  DCHECK_EQ(static_cast<DWORD>(NO_ERROR), restore_result);

  if (module_handle_) {
    FreeLibrary(module_handle_);
    module_handle_ = nullptr;
  }
  iat_thunk_ = nullptr;
  original_function_ = nullptr;
  intercept_function_ = nullptr;

  return restore_result;
}

// Returns the value the IAT entry held before it was patched. It is only
// meaningful while a patch is installed.
void* IATPatchFunction::original_function() const {
  DCHECK(is_patched());
  void* const saved_function = original_function_;
  return saved_function;
}

}  // namespace win
}  // namespace base