// chromium/base/win/com_init_check_hook_unittest.cc

// Copyright 2017 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "base/win/com_init_check_hook.h"

#include <objbase.h>

#include <shlobj.h>
#include <stdint.h>
#include <wrl/client.h>

#include "base/test/gtest_util.h"
#include "base/win/com_init_util.h"
#include "base/win/patch_util.h"
#include "base/win/scoped_com_initializer.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace base {
namespace win {

using Microsoft::WRL::ComPtr;

// With a hook alive and the COM apartment type asserted to be NONE, a call to
// CoCreateInstance is expected to DCHECK when COM_INIT_CHECK_HOOK_ENABLED is
// defined, and to return CO_E_NOTINITIALIZED otherwise.
TEST(ComInitCheckHook, AssertNotInitialized) {
  ComInitCheckHook hook;
  AssertComApartmentType(ComApartmentType::NONE);

  ComPtr<IUnknown> link;
#if defined(COM_INIT_CHECK_HOOK_ENABLED)
  EXPECT_DCHECK_DEATH({
    ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
                       IID_PPV_ARGS(&link));
  });
#else
  const HRESULT result = ::CoCreateInstance(CLSID_ShellLink, nullptr,
                                            CLSCTX_ALL, IID_PPV_ARGS(&link));
  EXPECT_EQ(CO_E_NOTINITIALIZED, result);
#endif
}

// Once a hook has been constructed and destroyed again, CoCreateInstance with
// the COM apartment type still NONE is expected to return CO_E_NOTINITIALIZED.
TEST(ComInitCheckHook, HookRemoval) {
  AssertComApartmentType(ComApartmentType::NONE);

  // Create the hook and let it go out of scope right away.
  {
    ComInitCheckHook scoped_hook;
  }

  ComPtr<IUnknown> link;
  const HRESULT result = ::CoCreateInstance(CLSID_ShellLink, nullptr,
                                            CLSCTX_ALL, IID_PPV_ARGS(&link));
  EXPECT_EQ(CO_E_NOTINITIALIZED, result);
}

// With both a hook and a ScopedCOMInitializer alive, creating a ShellLink
// object through CoCreateInstance is expected to succeed.
TEST(ComInitCheckHook, NoAssertComInitialized) {
  ComInitCheckHook hook;
  ScopedCOMInitializer scoped_com;

  ComPtr<IUnknown> link;
  const HRESULT result = ::CoCreateInstance(CLSID_ShellLink, nullptr,
                                            CLSCTX_ALL, IID_PPV_ARGS(&link));
  EXPECT_TRUE(SUCCEEDED(result));
}

// Same expectation as a single hook, but with two ComInitCheckHook instances
// alive at once: CoCreateInstance DCHECKs when COM_INIT_CHECK_HOOK_ENABLED is
// defined and returns CO_E_NOTINITIALIZED otherwise.
TEST(ComInitCheckHook, MultipleHooks) {
  ComInitCheckHook first_hook;
  ComInitCheckHook second_hook;
  AssertComApartmentType(ComApartmentType::NONE);

  ComPtr<IUnknown> link;
#if defined(COM_INIT_CHECK_HOOK_ENABLED)
  EXPECT_DCHECK_DEATH({
    ::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
                       IID_PPV_ARGS(&link));
  });
#else
  const HRESULT result = ::CoCreateInstance(CLSID_ShellLink, nullptr,
                                            CLSCTX_ALL, IID_PPV_ARGS(&link));
  EXPECT_EQ(CO_E_NOTINITIALIZED, result);
#endif
}

// Verifies that constructing a ComInitCheckHook DCHECKs when the byte located
// five bytes before the CoCreateInstance entry point has been overwritten with
// 0xdb. The body is compiled only when COM_INIT_CHECK_HOOK_ENABLED is defined;
// otherwise the test is empty.
TEST(ComInitCheckHook, UnexpectedHook) {
#if defined(COM_INIT_CHECK_HOOK_ENABLED)
  HMODULE ole32_library = ::LoadLibrary(L"ole32.dll");
  ASSERT_TRUE(ole32_library);

  // Fail the test cleanly if the export cannot be resolved, rather than
  // reading through an address derived from a null pointer below.
  const FARPROC co_create_instance =
      ::GetProcAddress(ole32_library, "CoCreateInstance");
  ASSERT_TRUE(co_create_instance != nullptr);

  // uintptr_t is the integer type meant for holding a pointer value.
  // NOTE(review): the 5-byte offset presumably matches the padding that
  // ComInitCheckHook inspects -- confirm against com_init_check_hook.cc.
  const uintptr_t co_create_instance_padded_address =
      reinterpret_cast<uintptr_t>(co_create_instance) - 5;
  const unsigned char* co_create_instance_bytes =
      reinterpret_cast<const unsigned char*>(co_create_instance_padded_address);

  // Remember the current byte so it can be put back after the death test.
  const unsigned char original_byte = co_create_instance_bytes[0];
  const unsigned char unexpected_byte = 0xdb;
  ASSERT_EQ(static_cast<DWORD>(NO_ERROR),
            internal::ModifyCode(
                reinterpret_cast<void*>(co_create_instance_padded_address),
                &unexpected_byte, sizeof(unexpected_byte)));

  EXPECT_DCHECK_DEATH({ ComInitCheckHook com_check_hook; });

  // If this call fails, really bad things are going to happen to other tests
  // so CHECK here.
  CHECK_EQ(static_cast<DWORD>(NO_ERROR),
           internal::ModifyCode(
               reinterpret_cast<void*>(co_create_instance_padded_address),
               &original_byte, sizeof(original_byte)));

  ::FreeLibrary(ole32_library);
#endif
}

// Verifies that constructing a ComInitCheckHook DCHECKs when the first byte of
// CoCreateInstance itself has been replaced with 0xe9 (the x86 near-JMP
// opcode). The body is compiled only when COM_INIT_CHECK_HOOK_ENABLED is
// defined; otherwise the test is empty.
TEST(ComInitCheckHook, ExternallyHooked) {
#if defined(COM_INIT_CHECK_HOOK_ENABLED)
  HMODULE ole32_library = ::LoadLibrary(L"ole32.dll");
  ASSERT_TRUE(ole32_library);

  // Fail the test cleanly if the export cannot be resolved, rather than
  // reading through a null address below.
  const FARPROC co_create_instance =
      ::GetProcAddress(ole32_library, "CoCreateInstance");
  ASSERT_TRUE(co_create_instance != nullptr);

  // uintptr_t is the integer type meant for holding a pointer value.
  const uintptr_t co_create_instance_address =
      reinterpret_cast<uintptr_t>(co_create_instance);
  const unsigned char* co_create_instance_bytes =
      reinterpret_cast<const unsigned char*>(co_create_instance_address);

  // Remember the current byte so it can be put back after the death test.
  const unsigned char original_byte = co_create_instance_bytes[0];
  const unsigned char jmp_byte = 0xe9;
  ASSERT_EQ(static_cast<DWORD>(NO_ERROR),
            internal::ModifyCode(
                reinterpret_cast<void*>(co_create_instance_address), &jmp_byte,
                sizeof(jmp_byte)));

  // Externally patched instances should crash so we catch these cases on bots.
  EXPECT_DCHECK_DEATH({ ComInitCheckHook com_check_hook; });

  // If this call fails, really bad things are going to happen to other tests
  // so CHECK here.
  CHECK_EQ(
      static_cast<DWORD>(NO_ERROR),
      internal::ModifyCode(reinterpret_cast<void*>(co_create_instance_address),
                           &original_byte, sizeof(original_byte)));

  ::FreeLibrary(ole32_library);
#endif
}

// Expects a DCHECK death from a scope that constructs a ComInitCheckHook and
// then overwrites the byte five bytes before the CoCreateInstance entry point
// with 0xdb. The body is compiled only when COM_INIT_CHECK_HOOK_ENABLED is
// defined; otherwise the test is empty.
TEST(ComInitCheckHook, UnexpectedChangeDuringHook) {
#if defined(COM_INIT_CHECK_HOOK_ENABLED)
  HMODULE ole32_library = ::LoadLibrary(L"ole32.dll");
  ASSERT_TRUE(ole32_library);

  // Fail the test cleanly if the export cannot be resolved, rather than
  // reading through an address derived from a null pointer below.
  const FARPROC co_create_instance =
      ::GetProcAddress(ole32_library, "CoCreateInstance");
  ASSERT_TRUE(co_create_instance != nullptr);

  // uintptr_t is the integer type meant for holding a pointer value.
  const uintptr_t co_create_instance_padded_address =
      reinterpret_cast<uintptr_t>(co_create_instance) - 5;
  const unsigned char* co_create_instance_bytes =
      reinterpret_cast<const unsigned char*>(co_create_instance_padded_address);

  // Remember the current byte so it can be put back after the death test.
  const unsigned char original_byte = co_create_instance_bytes[0];
  const unsigned char unexpected_byte = 0xdb;

  // NOTE(review): this writes the unexpected byte *before* the hook below is
  // constructed, the same setup as the UnexpectedHook test. The expected death
  // may therefore come from the hook's constructor rather than from the
  // modification made while the hook is alive -- confirm this is intended.
  ASSERT_EQ(static_cast<DWORD>(NO_ERROR),
            internal::ModifyCode(
                reinterpret_cast<void*>(co_create_instance_padded_address),
                &unexpected_byte, sizeof(unexpected_byte)));

  EXPECT_DCHECK_DEATH({
    ComInitCheckHook com_check_hook;

    // The return value is deliberately not checked here, as in the original.
    internal::ModifyCode(
        reinterpret_cast<void*>(co_create_instance_padded_address),
        &unexpected_byte, sizeof(unexpected_byte));
  });

  // If this call fails, really bad things are going to happen to other tests
  // so CHECK here.
  CHECK_EQ(static_cast<DWORD>(NO_ERROR),
           internal::ModifyCode(
               reinterpret_cast<void*>(co_create_instance_padded_address),
               &original_byte, sizeof(original_byte)));

  ::FreeLibrary(ole32_library);
#endif
}

}  // namespace win
}  // namespace base