chromium/sandbox/win/src/service_resolver_unittest.cc

// Copyright 2012 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

// This file contains unit tests for ServiceResolverThunk.

#include "sandbox/win/src/service_resolver.h"

#include <ntstatus.h>
#include <stddef.h>

#include <memory>

#include "base/bit_cast.h"
#include "base/memory/raw_ptr.h"
#include "sandbox/win/src/nt_internals.h"
#include "sandbox/win/src/resolver.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace {

// Concrete resolver used by the tests below to exercise service-call style
// functions of ntdll.dll. It lets a test choose the address that ends up being
// used as the interception target.
class ServiceResolverTest : public sandbox::ServiceResolverThunk {
 public:
  // The resolver needs a process to write to; the current one is used.
  explicit ServiceResolverTest(bool relaxed)
      : sandbox::ServiceResolverThunk(::GetCurrentProcess(), relaxed) {}

  ServiceResolverTest(const ServiceResolverTest&) = delete;
  ServiceResolverTest& operator=(const ServiceResolverTest&) = delete;

  // Remembers |target|; Init() installs it as the interception target.
  void set_target(void* target) { fake_target_ = target; }

 protected:
  // Overrides Resolver::Init. Forwards every argument to the base
  // implementation, expects that call to succeed, and then replaces the
  // target with the address supplied through set_target(). The status of the
  // base call is returned unchanged.
  NTSTATUS Init(const void* target_module,
                const void* interceptor_module,
                const char* target_name,
                const char* interceptor_name,
                const void* interceptor_entry_point,
                void* thunk_storage,
                size_t storage_bytes) final {
    const NTSTATUS base_status = sandbox::ServiceResolverThunk::Init(
        target_module, interceptor_module, target_name, interceptor_name,
        interceptor_entry_point, thunk_storage, storage_bytes);
    EXPECT_EQ(STATUS_SUCCESS, base_status);

    // The swap happens regardless of |base_status|.
    target_ = fake_target_;
    return base_status;
  }

  // Address handed to set_target(); copied into |target_| by Init().
  raw_ptr<void> fake_target_;
};

// Runs |resolver| against a local copy of the first bytes of the ntdll.dll
// export named |function|: the copy, not the export itself, is handed to the
// resolver as its target. When |relaxed| is true and the first Setup()
// succeeds, the copy is altered to start with a 32-bit jump opcode and Setup()
// is run a second time.
//
// Returns the status of the last Setup() call, or STATUS_UNSUCCESSFUL when
// ntdll.dll or |function| cannot be located.
NTSTATUS PatchNtdllWithResolver(const char* function,
                                bool relaxed,
                                ServiceResolverTest& resolver) {
  HMODULE ntdll_base = ::GetModuleHandle(L"ntdll.dll");
  EXPECT_TRUE(ntdll_base);
  // EXPECT_* does not stop the function and ASSERT_* cannot be used in a
  // function that returns a value, so bail out by hand instead of passing a
  // null module handle on to GetProcAddress() and Setup().
  if (!ntdll_base)
    return STATUS_UNSUCCESSFUL;

  void* target =
      reinterpret_cast<void*>(::GetProcAddress(ntdll_base, function));
  EXPECT_TRUE(target);
  if (!target)
    return STATUS_UNSUCCESSFUL;

  // Snapshot of the start of the export; this buffer is what gets patched.
  BYTE service[50];
  memcpy(service, target, sizeof(service));

  resolver.set_target(service);

  // Any pointer will do as an interception_entry_point
  void* function_entry = &resolver;
  size_t thunk_size = resolver.GetThunkSize();
  std::unique_ptr<char[]> thunk = std::make_unique<char[]>(thunk_size);
  // Start from a known value so the EXPECT_EQ below never reads an
  // indeterminate value if Setup() reports success without writing it.
  size_t used = 0;

  resolver.AllowLocalPatches();

  NTSTATUS ret = resolver.Setup(ntdll_base, nullptr, function, nullptr,
                                function_entry, thunk.get(), thunk_size, &used);
  if (NT_SUCCESS(ret)) {
    const BYTE kJump32 = 0xE9;
    EXPECT_EQ(thunk_size, used);
    // The copy must have been modified, and not by writing a jump at its
    // first byte.
    EXPECT_NE(0, memcmp(service, target, sizeof(service)));
    EXPECT_NE(kJump32, service[0]);

    if (relaxed) {
      // It's already patched, let's patch again, and simulate a direct patch.
      service[0] = kJump32;
      ret = resolver.Setup(ntdll_base, nullptr, function, nullptr,
                           function_entry, thunk.get(), thunk_size, &used);
      EXPECT_TRUE(resolver.VerifyJumpTargetForTesting(thunk.get()));
    }
  }

  return ret;
}

// Convenience wrapper around PatchNtdllWithResolver() that patches |function|
// with a resolver created just for this call. Returns that call's status.
NTSTATUS PatchNtdll(const char* function, bool relaxed) {
  ServiceResolverTest one_shot_resolver(relaxed);
  const NTSTATUS status =
      PatchNtdllWithResolver(function, relaxed, one_shot_resolver);
  return status;
}

// Patching each of these ntdll.dll services in non-relaxed mode must succeed.
TEST(ServiceResolverTest, PatchesServices) {
  // Service name paired with the text that prefixes its failure message.
  static constexpr struct {
    const char* function;
    const char* failure_prefix;
  } kCases[] = {
      {"NtClose", "NtClose, last error: "},
      {"NtCreateFile", "NtCreateFile, last error: "},
      {"NtCreateMutant", "NtCreateMutant, last error: "},
      {"NtMapViewOfSection", "NtMapViewOfSection, last error: "},
  };

  for (const auto& test_case : kCases) {
    const NTSTATUS status = PatchNtdll(test_case.function, false);
    EXPECT_EQ(STATUS_SUCCESS, status)
        << test_case.failure_prefix << ::GetLastError();
  }
}

// Patching these ntdll.dll exports is expected to be rejected.
TEST(ServiceResolverTest, FailsIfNotService) {
#if !defined(_WIN64)
  // This export is only exercised on non-Win64 builds.
  const NTSTATUS byte_swap_status = PatchNtdll("RtlUlongByteSwap", false);
  EXPECT_NE(STATUS_SUCCESS, byte_swap_status);
#endif

  const NTSTATUS load_dll_status = PatchNtdll("LdrLoadDll", false);
  EXPECT_NE(STATUS_SUCCESS, load_dll_status);
}

// In relaxed mode, each service is patched, made to look directly patched, and
// patched again; every run uses a fresh resolver and must succeed.
TEST(ServiceResolverTest, PatchesPatchedServices) {
// "Relaxed mode" is not supported for Win64 apps, so the body is 32-bit only.
#if !defined(_WIN64)
  // Service name paired with the text that prefixes its failure message.
  static constexpr struct {
    const char* function;
    const char* failure_prefix;
  } kCases[] = {
      {"NtClose", "NtClose, last error: "},
      {"NtCreateFile", "NtCreateFile, last error: "},
      {"NtCreateMutant", "NtCreateMutant, last error: "},
      {"NtMapViewOfSection", "NtMapViewOfSection, last error: "},
  };

  for (const auto& test_case : kCases) {
    const NTSTATUS status = PatchNtdll(test_case.function, true);
    EXPECT_EQ(STATUS_SUCCESS, status)
        << test_case.failure_prefix << ::GetLastError();
  }
#endif
}

// Same sequence as PatchesPatchedServices, except that one relaxed resolver is
// reused for every service instead of creating a new one per call.
TEST(ServiceResolverTest, MultiplePatchedServices) {
// "Relaxed mode" is not supported for Win64 apps, so the body is 32-bit only.
#if !defined(_WIN64)
  // Service name paired with the text that prefixes its failure message.
  static constexpr struct {
    const char* function;
    const char* failure_prefix;
  } kCases[] = {
      {"NtClose", "NtClose, last error: "},
      {"NtCreateFile", "NtCreateFile, last error: "},
      {"NtCreateMutant", "NtCreateMutant, last error: "},
      {"NtMapViewOfSection", "NtMapViewOfSection, last error: "},
  };

  ServiceResolverTest shared_resolver(true);
  for (const auto& test_case : kCases) {
    const NTSTATUS status =
        PatchNtdllWithResolver(test_case.function, true, shared_resolver);
    EXPECT_EQ(STATUS_SUCCESS, status)
        << test_case.failure_prefix << ::GetLastError();
  }
#endif
}

// Setup() must fail until AllowLocalPatches() has been called on the resolver,
// and succeed afterwards.
TEST(ServiceResolverTest, LocalPatchesAllowed) {
  ServiceResolverTest resolver(true);

  HMODULE ntdll = ::GetModuleHandle(L"ntdll.dll");
  ASSERT_TRUE(ntdll);

  const char kFunctionName[] = "NtClose";

  void* real_service =
      reinterpret_cast<void*>(::GetProcAddress(ntdll, kFunctionName));
  ASSERT_TRUE(real_service);

  // The resolver is pointed at a local copy of the start of the export.
  BYTE service_copy[50];
  memcpy(service_copy, real_service, sizeof(service_copy));
  resolver.set_target(service_copy);

  // Any pointer will do as an interception_entry_point
  void* entry_point = &resolver;
  size_t thunk_size = resolver.GetThunkSize();
  std::unique_ptr<char[]> thunk = std::make_unique<char[]>(thunk_size);
  size_t bytes_used;

  // Both attempts below issue exactly the same Setup() call.
  const auto run_setup = [&]() -> NTSTATUS {
    return resolver.Setup(ntdll, nullptr, kFunctionName, nullptr, entry_point,
                          thunk.get(), thunk_size, &bytes_used);
  };

  // Local patches have not been allowed yet, so this attempt has to fail.
  const NTSTATUS blocked_status = run_setup();
  EXPECT_FALSE(NT_SUCCESS(blocked_status));

  // Once local patches are allowed the very same call has to succeed.
  resolver.AllowLocalPatches();
  const NTSTATUS allowed_status = run_setup();
  EXPECT_EQ(STATUS_SUCCESS, allowed_status);
}

}  // namespace