chromium/chrome/browser/win/conflicts/module_list_filter_unittest.cc

// Copyright 2018 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/win/conflicts/module_list_filter.h"

#include <optional>
#include <string>
#include <utility>

#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/files/scoped_temp_dir.h"
#include "base/hash/sha1.h"
#include "base/i18n/case_conversion.h"
#include "base/strings/stringprintf.h"
#include "base/strings/utf_string_conversions.h"
#include "chrome/browser/win/conflicts/module_info.h"
#include "chrome/browser/win/conflicts/proto/module_list.pb.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace {

// Convenience alias that pairs a module's key with its associated data, so a
// test can create and pass both around as a single value.
using ModuleInfo = std::pair<ModuleInfoKey, ModuleInfoData>;

// Builds the Code ID string for a module: |module_time_date_stamp| rendered as
// 8-digit, zero-padded, uppercase hex, immediately followed by |module_size|
// rendered as lowercase hex without padding.
std::string GetCodeId(uint32_t module_time_date_stamp, uint32_t module_size) {
  std::string code_id =
      base::StringPrintf("%08X%x", module_time_date_stamp, module_size);
  return code_id;
}

// Helper class to build and serialize a ModuleList.
class ModuleListBuilder {
 public:
  explicit ModuleListBuilder(const base::FilePath& module_list_path)
      : module_list_path_(module_list_path) {
    // Include an empty blocklist and allowlist.
    module_list_.mutable_blocklist();
    module_list_.mutable_allowlist();
  }

  ModuleListBuilder(const ModuleListBuilder&) = delete;
  ModuleListBuilder& operator=(const ModuleListBuilder&) = delete;

  // Adds a module to the allowlist.
  void AddAllowlistedModule(std::optional<std::u16string> basename,
                            std::optional<std::string> code_id) {
    CHECK(basename.has_value() || code_id.has_value());

    chrome::conflicts::ModuleGroup* module_group =
        module_list_.mutable_allowlist()->add_module_groups();

    chrome::conflicts::Module* module = module_group->add_modules();

    if (basename.has_value()) {
      module->set_basename_hash(base::SHA1HashString(
          base::UTF16ToUTF8(base::i18n::ToLower(basename.value()))));
    }

    if (code_id.has_value())
      module->set_code_id_hash(base::SHA1HashString(code_id.value()));
  }

  // Adds a module to the allowlist. Used when both the Code ID and the basename
  // must be set.
  void AddAllowlistedModule(const ModuleInfoKey& module_key,
                            const ModuleInfoData& module_data) {
    AddAllowlistedModule(
        module_data.inspection_result->basename,
        GetCodeId(module_key.module_time_date_stamp, module_key.module_size));
  }

  // Adds a module to the blocklist.
  void AddBlocklistedModule(
      const ModuleInfoKey& module_key,
      const ModuleInfoData& module_data,
      bool allow_load_value,
      chrome::conflicts::BlocklistMessageType message_type,
      const std::string& message_url) {
    chrome::conflicts::BlocklistModuleGroup* module_group =
        module_list_.mutable_blocklist()->add_module_groups();

    chrome::conflicts::BlocklistAction* blocklist_action =
        module_group->mutable_action();
    blocklist_action->set_allow_load(true);
    blocklist_action->set_message_type(message_type);
    blocklist_action->set_message_url(message_url);

    chrome::conflicts::Module* module =
        module_group->mutable_modules()->add_modules();

    module->set_basename_hash(base::SHA1HashString(base::UTF16ToUTF8(
        base::i18n::ToLower(module_data.inspection_result->basename))));

    module->set_code_id_hash(base::SHA1HashString(
        GetCodeId(module_key.module_time_date_stamp, module_key.module_size)));
  }

  // Serializes the |module_list_| to |module_list_path_|. Returns true on
  // success.
  bool Finalize() {
    std::string contents;
    return module_list_.SerializeToString(&contents) &&
           base::WriteFile(module_list_path_, contents);
  }

 private:
  const base::FilePath module_list_path_;

  chrome::conflicts::ModuleList module_list_;
};

// Builds a ModuleInfo whose key is constructed from |module_path|,
// |module_size| and |module_time_date_stamp|, and whose data carries an
// inspection result with only the basename filled in (taken from
// |module_path|). The tests in this file pass the result to
// ModuleListFilter::IsAllowlisted() and IsBlocklisted().
ModuleInfo CreateModuleInfo(const base::FilePath& module_path,
                            uint32_t module_size,
                            uint32_t module_time_date_stamp) {
  ModuleInfo module_info(
      std::piecewise_construct,
      std::forward_as_tuple(module_path, module_size, module_time_date_stamp),
      std::forward_as_tuple());

  // Engage the optional with a default-constructed inspection result, then
  // set the one field populated here.
  ModuleInspectionResult& inspection =
      module_info.second.inspection_result.emplace();
  inspection.basename = module_path.BaseName().AsUTF16Unsafe();

  return module_info;
}

// Paths of the two fake modules that the ModuleListFilterTest fixture exposes
// as |dll1_| and |dll2_|.
constexpr wchar_t kDllPath1[] = L"c:\\path\\to\\module.dll";
constexpr wchar_t kDllPath2[] = L"c:\\some\\shellextension.dll";

}  // namespace

class ModuleListFilterTest : public ::testing::Test {
 public:
  ModuleListFilterTest(const ModuleListFilterTest&) = delete;
  ModuleListFilterTest& operator=(const ModuleListFilterTest&) = delete;

 protected:
  ModuleListFilterTest()
      : dll1_(kDllPath1),
        dll2_(kDllPath2),
        module_list_filter_(base::MakeRefCounted<ModuleListFilter>()) {}

  ~ModuleListFilterTest() override = default;

  void SetUp() override {
    ASSERT_TRUE(scoped_temp_dir_.CreateUniqueTempDir());
    module_list_path_ = scoped_temp_dir_.GetPath().Append(L"ModuleList.bin");
  }

  const base::FilePath& module_list_path() { return module_list_path_; }
  ModuleListFilter& module_list_filter() { return *module_list_filter_; }

  const base::FilePath dll1_;
  const base::FilePath dll2_;

 private:
  base::ScopedTempDir scoped_temp_dir_;

  base::FilePath module_list_path_;

  scoped_refptr<ModuleListFilter> module_list_filter_;
};

// Exercises the IsAllowlisted() overload that takes the basename hash and the
// Code ID hash directly.
TEST_F(ModuleListFilterTest, IsAllowlistedStringPieceVersion) {
  // Must be lowercase: the builder lowercases the basename before hashing it,
  // whereas the hash computed further down uses the string as-is.
  const std::u16string module_basename = u"basename.dll";
  const std::string module_code_id = GetCodeId(12u, 32u);

  ModuleListBuilder builder(module_list_path());
  builder.AddAllowlistedModule(module_basename, module_code_id);
  ASSERT_TRUE(builder.Finalize());
  ASSERT_TRUE(module_list_filter().Initialize(module_list_path()));

  // Hash both identifiers by hand and query the filter with the hashes.
  const std::string hashed_basename =
      base::SHA1HashString(base::UTF16ToUTF8(module_basename));
  const std::string hashed_code_id = base::SHA1HashString(module_code_id);

  EXPECT_TRUE(
      module_list_filter().IsAllowlisted(hashed_basename, hashed_code_id));
}

// A module added to the allowlist is reported as allowlisted; a module that
// was never added is not.
TEST_F(ModuleListFilterTest, AllowlistedModules) {
  auto [listed_key, listed_data] = CreateModuleInfo(dll1_, 0123, 4567);
  auto [unlisted_key, unlisted_data] = CreateModuleInfo(dll2_, 7654, 3210);

  // Only the first module goes into the list.
  ModuleListBuilder builder(module_list_path());
  builder.AddAllowlistedModule(listed_key, listed_data);
  ASSERT_TRUE(builder.Finalize());
  ASSERT_TRUE(module_list_filter().Initialize(module_list_path()));

  EXPECT_TRUE(module_list_filter().IsAllowlisted(listed_key, listed_data));
  EXPECT_FALSE(module_list_filter().IsAllowlisted(unlisted_key, unlisted_data));
}

// A module added to the blocklist yields a BlocklistAction carrying the values
// it was registered with; a module that was never added yields no action.
TEST_F(ModuleListFilterTest, BlocklistedModules) {
  const char kFurtherInfoURL[] = "http://www.further-info.com";

  auto [blocked_key, blocked_data] = CreateModuleInfo(dll1_, 0123, 4567);
  auto [other_key, other_data] = CreateModuleInfo(dll2_, 7654, 3210);

  // Only the first module goes into the blocklist.
  ModuleListBuilder builder(module_list_path());
  builder.AddBlocklistedModule(
      blocked_key, blocked_data, /*allow_load_value=*/true,
      chrome::conflicts::BlocklistMessageType::FURTHER_INFORMATION,
      kFurtherInfoURL);
  ASSERT_TRUE(builder.Finalize());
  ASSERT_TRUE(module_list_filter().Initialize(module_list_path()));

  const std::unique_ptr<chrome::conflicts::BlocklistAction> action =
      module_list_filter().IsBlocklisted(blocked_key, blocked_data);
  ASSERT_TRUE(action);
  EXPECT_TRUE(action->allow_load());
  EXPECT_EQ(chrome::conflicts::BlocklistMessageType::FURTHER_INFORMATION,
            action->message_type());
  EXPECT_EQ(kFurtherInfoURL, action->message_url());

  // The second module was never added, so no action is returned for it.
  EXPECT_FALSE(module_list_filter().IsBlocklisted(other_key, other_data));
}

// An allowlist entry that only specifies a basename matches every module with
// that basename, regardless of its directory or Code ID.
TEST_F(ModuleListFilterTest, BasenameOnly) {
  const base::FilePath original_path(L"c:\\path\\basename.dll");
  const base::FilePath same_basename_path(L"c:\\wrong_path\\basename.dll");
  const base::FilePath same_dir_path(L"c:\\path\\wrong_basename.dll");
  const base::FilePath unrelated_path(L"c:\\wrong_path\\wrong_basename.dll");

  ModuleInfo original = CreateModuleInfo(original_path, 1111, 0001);
  ModuleInfo same_basename = CreateModuleInfo(same_basename_path, 2222, 0002);
  ModuleInfo same_path = CreateModuleInfo(same_dir_path, 3333, 0003);
  ModuleInfo same_code_id = CreateModuleInfo(unrelated_path, 1111, 0001);

  // Register the original module by basename only; no Code ID is given.
  ModuleListBuilder builder(module_list_path());
  builder.AddAllowlistedModule(original.second.inspection_result->basename,
                               std::nullopt);
  ASSERT_TRUE(builder.Finalize());
  ASSERT_TRUE(module_list_filter().Initialize(module_list_path()));

  const auto is_allowlisted = [this](const ModuleInfo& info) {
    return module_list_filter().IsAllowlisted(info.first, info.second);
  };

  EXPECT_TRUE(is_allowlisted(original));
  EXPECT_TRUE(is_allowlisted(same_basename));
  EXPECT_FALSE(is_allowlisted(same_path));
  EXPECT_FALSE(is_allowlisted(same_code_id));
}

// An allowlist entry that only specifies a Code ID matches every module with
// that Code ID, regardless of its basename.
TEST_F(ModuleListFilterTest, CodeIdOnly) {
  const base::FilePath original_path(L"c:\\path\\basename.dll");
  const base::FilePath same_basename_path(L"c:\\wrong_path\\basename.dll");
  const base::FilePath other_basename_path(
      L"c:\\wrong_path\\wrong_basename.dll");

  ModuleInfo original = CreateModuleInfo(original_path, 1111, 0001);
  ModuleInfo same_basename = CreateModuleInfo(same_basename_path, 2222, 0002);
  ModuleInfo same_code_id = CreateModuleInfo(other_basename_path, 1111, 0001);

  // Register the original module by Code ID only; no basename is given.
  const std::string original_code_id = GetCodeId(
      original.first.module_time_date_stamp, original.first.module_size);
  ModuleListBuilder builder(module_list_path());
  builder.AddAllowlistedModule(std::nullopt, original_code_id);
  ASSERT_TRUE(builder.Finalize());
  ASSERT_TRUE(module_list_filter().Initialize(module_list_path()));

  const auto is_allowlisted = [this](const ModuleInfo& info) {
    return module_list_filter().IsAllowlisted(info.first, info.second);
  };

  EXPECT_TRUE(is_allowlisted(original));
  EXPECT_FALSE(is_allowlisted(same_basename));
  EXPECT_TRUE(is_allowlisted(same_code_id));
}