chromium/chrome/browser/win/conflicts/enumerate_shell_extensions_unittest.cc

// Copyright 2017 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/win/conflicts/enumerate_shell_extensions.h"

#include <vector>

#include "base/containers/contains.h"
#include "base/files/file_path.h"
#include "base/functional/bind.h"
#include "base/path_service.h"
#include "base/strings/strcat_win.h"
#include "base/test/task_environment.h"
#include "base/test/test_reg_util_win.h"
#include "chrome/browser/win/conflicts/module_info_util.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace {

class EnumerateShellExtensionsTest : public testing::Test {
 public:
  EnumerateShellExtensionsTest() = default;

  EnumerateShellExtensionsTest(const EnumerateShellExtensionsTest&) = delete;
  EnumerateShellExtensionsTest& operator=(const EnumerateShellExtensionsTest&) =
      delete;

  ~EnumerateShellExtensionsTest() override = default;

  // Override all registry hives so that real shell extensions don't mess up
  // the unit tests.
  void SetUp() override {
    ASSERT_NO_FATAL_FAILURE(
        registry_override_manager_.OverrideRegistry(HKEY_CLASSES_ROOT));
    ASSERT_NO_FATAL_FAILURE(
        registry_override_manager_.OverrideRegistry(HKEY_CURRENT_USER));
    ASSERT_NO_FATAL_FAILURE(
        registry_override_manager_.OverrideRegistry(HKEY_LOCAL_MACHINE));
  }

  void RunUntilIdle() { task_environment_.RunUntilIdle(); }

 private:
  base::test::TaskEnvironment task_environment_;

  registry_util::RegistryOverrideManager registry_override_manager_;
};

// Adds a fake shell extension entry to the registry that should be found by
// the ShellExtensionEnumerator. The call must be wrapped inside an
// ASSERT_NO_FATAL_FAILURE.
void RegisterFakeApprovedShellExtension(HKEY key,
                                        const wchar_t* guid,
                                        const wchar_t* path) {
  base::win::RegKey class_id(HKEY_CLASSES_ROOT, GuidToClsid(guid).c_str(),
                             KEY_WRITE);
  ASSERT_TRUE(class_id.Valid());

  ASSERT_EQ(ERROR_SUCCESS, class_id.WriteValue(nullptr, path));

  base::win::RegKey registration(key, kApprovedShellExtensionRegistryKey,
                                 KEY_WRITE);
  ASSERT_EQ(ERROR_SUCCESS, registration.WriteValue(guid, L""));
}

// Adds a fake shell extension entry to the registry that should be found by
// the ShellExtensionEnumerator. The call must be wrapped inside an
// ASSERT_NO_FATAL_FAILURE.
void RegisterFakeShellExtension(const wchar_t* guid,
                                const wchar_t* path,
                                const wchar_t* shell_extension_type,
                                const wchar_t* shell_object_type) {
  base::win::RegKey class_id(HKEY_CLASSES_ROOT, GuidToClsid(guid).c_str(),
                             KEY_WRITE);
  ASSERT_TRUE(class_id.Valid());

  ASSERT_EQ(ERROR_SUCCESS, class_id.WriteValue(nullptr, path));

  base::win::RegKey registration(
      HKEY_CLASSES_ROOT,
      base::StrCat({shell_object_type, L"\\shellex\\", shell_extension_type,
                    L"\\", guid})
          .c_str(),
      KEY_WRITE);
  ASSERT_EQ(ERROR_SUCCESS, registration.WriteValue(nullptr, guid));
}

void OnShellExtensionPathEnumerated(
    std::vector<base::FilePath>* shell_extension_paths,
    const base::FilePath& path) {
  shell_extension_paths->push_back(path);
}

void OnShellExtensionEnumerated(
    std::vector<base::FilePath>* shell_extension_paths,
    const base::FilePath& path,
    uint32_t size_of_image,
    uint32_t time_date_stamp) {
  shell_extension_paths->push_back(path);
}

void OnEnumerationFinished(bool* is_enumeration_finished) {
  *is_enumeration_finished = true;
}

}  // namespace

// Registers a few fake approved shell extensions then see if
// EnumerateShellExtensionPaths() finds them.
TEST_F(EnumerateShellExtensionsTest, EnumerateApprovedShellExtensionPaths) {
  struct {
    HKEY hkey_root;
    const wchar_t* guid;
    const wchar_t* path;
  } kTestCases[] = {
      {HKEY_CURRENT_USER, L"{FAKE_GUID_0001}", L"c:\\module.dll"},
      {HKEY_LOCAL_MACHINE, L"{FAKE_GUID_0002}", L"c:\\dir\\shell_ext.dll"},
      {HKEY_LOCAL_MACHINE, L"{FAKE_GUID_0003}", L"c:\\path\\test.dll"},
  };

  // Register all fake shell extensions in kTestCases.
  for (const auto& test_case : kTestCases) {
    ASSERT_NO_FATAL_FAILURE(RegisterFakeApprovedShellExtension(
        test_case.hkey_root, test_case.guid, test_case.path));
  }

  std::vector<base::FilePath> shell_extension_paths;
  internal::EnumerateShellExtensionPaths(
      base::BindRepeating(&OnShellExtensionPathEnumerated,
                          base::Unretained(&shell_extension_paths)));

  ASSERT_EQ(3u, shell_extension_paths.size());
  for (const auto& test_case : kTestCases) {
    // The inefficiency is fine as long as the number of test cases stays small.
    EXPECT_TRUE(
        base::Contains(shell_extension_paths, base::FilePath(test_case.path)));
  }
}

TEST_F(EnumerateShellExtensionsTest, EnumerateApprovedShellExtensions) {
  // Use the current exe file as an arbitrary module that exists.
  base::FilePath file_exe;
  ASSERT_TRUE(base::PathService::Get(base::FILE_EXE, &file_exe));
  ASSERT_NO_FATAL_FAILURE(RegisterFakeApprovedShellExtension(
      HKEY_LOCAL_MACHINE, L"{FAKE_GUID}", file_exe.value().c_str()));

  std::vector<base::FilePath> shell_extension_paths;
  bool is_enumeration_finished = false;
  EnumerateShellExtensions(
      base::BindRepeating(&OnShellExtensionEnumerated,
                          base::Unretained(&shell_extension_paths)),
      base::BindOnce(&OnEnumerationFinished,
                     base::Unretained(&is_enumeration_finished)));

  RunUntilIdle();

  EXPECT_TRUE(is_enumeration_finished);
  ASSERT_EQ(1u, shell_extension_paths.size());
  EXPECT_EQ(file_exe, shell_extension_paths[0]);
}

TEST_F(EnumerateShellExtensionsTest, EnumerateShellExtensionPaths) {
  struct {
    const wchar_t* guid;
    const wchar_t* path;
    const wchar_t* shell_extension_type;
    const wchar_t* shell_object_type;
  } kTestCases[] = {
      {L"{FAKE_GUID_0001}", L"c:\\module.dll", L"ColumnHandlers", L"Folder"},
      {L"{FAKE_GUID_0002}", L"c:\\dir\\shell_ext.dll", L"ContextMenuHandlers",
       L"*"},
      {L"{FAKE_GUID_0003}", L"c:\\path\\test.dll", L"CopyHookHandlers",
       L"Printers"},
      {L"{FAKE_GUID_0004}", L"c:\\foo\\bar.dll", L"DragDropHandlers", L"Drive"},
      {L"{FAKE_GUID_0005}", L"c:\\foo\\baz.dll", L"PropertySheetHandlers",
       L"AllFileSystemObjects"},
  };

  // Register all fake shell extensions in kTestCases.
  for (const auto& test_case : kTestCases) {
    ASSERT_NO_FATAL_FAILURE(RegisterFakeShellExtension(
        test_case.guid, test_case.path, test_case.shell_extension_type,
        test_case.shell_object_type));
  }

  std::vector<base::FilePath> shell_extension_paths;
  internal::EnumerateShellExtensionPaths(
      base::BindRepeating(&OnShellExtensionPathEnumerated,
                          base::Unretained(&shell_extension_paths)));

  ASSERT_EQ(5u, shell_extension_paths.size());
  for (const auto& test_case : kTestCases) {
    // The inefficiency is fine as long as the number of test cases stays small.
    EXPECT_TRUE(
        base::Contains(shell_extension_paths, base::FilePath(test_case.path)));
  }
}