chromium/chrome/chrome_elf/third_party_dlls/main_unittest_exe.cc

// Copyright 2018 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/chrome_elf/third_party_dlls/main_unittest_exe.h"

#include <windows.h>

#include <shellapi.h>
#include <stdlib.h>

#include <memory>

#include "base/files/file.h"
#include "base/files/file_util.h"
#include "base/scoped_native_library.h"
#include "base/strings/utf_string_conversions.h"
#include "base/test/test_reg_util_win.h"
#include "chrome/chrome_elf/nt_registry/nt_registry.h"
#include "chrome/chrome_elf/third_party_dlls/main.h"
#include "chrome/chrome_elf/third_party_dlls/packed_list_file.h"
#include "chrome/chrome_elf/third_party_dlls/public_api.h"
#include "chrome/install_static/install_util.h"
#include "chrome/install_static/product_install_details.h"

namespace {

// Deleter for std::unique_ptr that hands its pointer to ::LocalFree(). Used
// for the argument array returned by CommandLineToArgvW(), which must be
// released with LocalFree().
struct LocalFreeDeleter {
  void operator()(wchar_t** args) const { ::LocalFree(args); }
};

// Attempts to load the DLL at path |name|.
// Returns kDllLoadSuccess if base::ScopedNativeLibrary obtained a valid
// library handle, kDllLoadFailed otherwise. The library is released again when
// |dll| goes out of scope; only the outcome of the load attempt is reported.
// |name| is taken by const reference since it is only read.
third_party_dlls::ExitCode LoadDll(const std::wstring& name) {
  base::ScopedNativeLibrary dll{base::FilePath(name)};
  return dll.is_valid() ? third_party_dlls::kDllLoadSuccess
                        : third_party_dlls::kDllLoadFailed;
}

// Protects the real user registry during the test: asks |rom| to override
// HKEY_CURRENT_USER, then points nt_registry's HKCU root at the same override
// path so both registry layers see the same test key.
void RegRedirect(registry_util::RegistryOverrideManager* rom) {
  std::wstring hkcu_override_path;
  rom->OverrideRegistry(HKEY_CURRENT_USER, &hkcu_override_path);
  nt::SetTestingOverride(nt::HKCU, hkcu_override_path);
}

// Compare an argument path with a module-load log path.
// - |arg_path| is a UTF-16 drive-letter path.
// - |log.path| is UTF-8, and will be a device path, so it is converted to a
//   drive-letter path before comparing.
// Returns false if |log.path| cannot be mapped to a drive-letter path, or if
// the two paths differ. The comparison ignores case because Windows paths are
// case-insensitive, and the drive letter produced by the device-path mapping
// may not match the case used in |arg_path| (e.g. "C:" vs "c:").
bool MatchPath(const wchar_t* arg_path, const third_party_dlls::LogEntry& log) {
  base::FilePath drive_path;
  if (!base::DevicePathToDriveLetterPath(
          base::FilePath(base::UTF8ToWide(log.path)), &drive_path)) {
    return false;
  }

  return base::FilePath::CompareEqualIgnoreCase(drive_path.value(), arg_path);
}

}  // namespace

//------------------------------------------------------------------------------
// PUBLIC
//------------------------------------------------------------------------------

// Good ol' main.
// - Init third_party_dlls, which will apply a hook to NtMapViewOfSection.
// - Attempt to load a specific DLL.
//
// Arguments:
// #1: path to test blocklist file (mandatory).
// #2: test identifier (mandatory).
// #3: path to dll (test-identifier dependent).
//
// Returns:
// - Negative values in case of unexpected error.
// - 0 for successful DLL load.
// - 1 for failed DLL load.
int main() {
  // NOTE: The arguments must be treated as unicode for these tests.
  int argument_count = 0;
  std::unique_ptr<wchar_t*[], LocalFreeDeleter> argv(
      ::CommandLineToArgvW(::GetCommandLineW(), &argument_count));
  if (!argv)
    return third_party_dlls::kBadCommandLine;

  if (IsThirdPartyInitialized())
    return third_party_dlls::kThirdPartyAlreadyInitialized;

  install_static::InitializeProductDetailsForPrimaryModule();
  install_static::InitializeProcessType();

  // Get the required arguments, path to blocklist file and test id to run.
  if (argument_count < 3)
    return third_party_dlls::kMissingArgument;

  const wchar_t* blocklist_path = argv[1];
  if (!blocklist_path || ::wcslen(blocklist_path) == 0)
    return third_party_dlls::kBadBlocklistPath;

  const wchar_t* arg2 = argv[2];
  int test_id = ::_wtoi(arg2);
  if (!test_id)
    return third_party_dlls::kUnsupportedTestId;

  // Override blocklist path before initializing.
  third_party_dlls::OverrideFilePathForTesting(blocklist_path);

  // Enable a registry test net before initializing.
  registry_util::RegistryOverrideManager rom;
  RegRedirect(&rom);

  if (!third_party_dlls::Init())
    return third_party_dlls::kThirdPartyInitFailure;

  switch (test_id) {
    case third_party_dlls::kTestOnlyInitialization:
      break;

    case third_party_dlls::kTestSingleDllLoad:
    case third_party_dlls::kTestLogPath: {
      if (argument_count < 4)
        return third_party_dlls::kMissingArgument;
      const wchar_t* dll_name = argv[3];
      if (!dll_name || ::wcslen(dll_name) == 0)
        return third_party_dlls::kBadArgument;
      third_party_dlls::ExitCode code = LoadDll(dll_name);

      // Get logging.  Ensure the log is as expected.
      uint32_t bytes = 0;
      DrainLog(nullptr, 0, &bytes);
      if (!bytes)
        return third_party_dlls::kEmptyLog;
      auto buffer = std::make_unique<uint8_t[]>(bytes);
      bytes = DrainLog(&buffer[0], bytes, nullptr);
      third_party_dlls::LogEntry* entry =
          reinterpret_cast<third_party_dlls::LogEntry*>(&buffer[0]);
      if (!bytes || bytes < third_party_dlls::GetLogEntrySize(entry->path_len))
        return third_party_dlls::kBadLogEntrySize;

      if ((code == third_party_dlls::kDllLoadFailed &&
           entry->type != third_party_dlls::kBlocked) ||
          (code == third_party_dlls::kDllLoadSuccess &&
           entry->type != third_party_dlls::kAllowed)) {
        return third_party_dlls::kUnexpectedLog;
      }

      if (test_id == third_party_dlls::kTestLogPath &&
          !MatchPath(dll_name, *entry))
        return third_party_dlls::kUnexpectedSectionPath;

      return code;
    }

    default:
      return third_party_dlls::kUnsupportedTestId;
  }

  return 0;
}