chromium/net/http/mock_sspi_library_win.cc

// Copyright 2010 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/http/mock_sspi_library_win.h"

#include <algorithm>
#include <cstring>
#include <memory>
#include <string>

#include "base/check_op.h"
#include "base/memory/raw_ptr.h"
#include "base/strings/string_util_win.h"
#include "base/strings/stringprintf.h"
#include "base/strings/utf_string_conversions.h"
#include "base/time/time.h"
#include "testing/gtest/include/gtest/gtest.h"

// Comparator so we can use CredHandle and CtxtHandle with std::set. Both of
// those types are typedefs for _SecHandle.
//
// Orders handles lexicographically by (dwUpper, dwLower). std::set requires a
// strict weak ordering. Testing each field independently and OR-ing the
// results is not one: it can report both a < b and b < a, which corrupts the
// set's invariants and makes find()/erase() unreliable.
bool operator<(const _SecHandle left, const _SecHandle right) {
  if (left.dwUpper != right.dwUpper)
    return left.dwUpper < right.dwUpper;
  return left.dwLower < right.dwLower;
}

namespace net {

namespace {

// Counter shared by MockCredential and MockContext. Every new instance
// pre-increments it, so each object gets a distinct positive serial number.
// It is a plain int, so concurrent construction is not synchronized.
int uniquifier_{0};

struct MockCredential {
  std::u16string source_principal;
  std::u16string package;
  bool has_explicit_credentials = false;
  int uniquifier = ++uniquifier_;

  // CredHandle and CtxtHandle both shared the following definition:
  //
  // typedef struct _SecHandle {
  //   ULONG_PTR       dwLower;
  //   ULONG_PTR       dwUpper;
  // } SecHandle, * PSecHandle;
  //
  // ULONG_PTR type can hold a pointer. This function stuffs |this| into dwUpper
  // and adds a uniquifier to dwLower. This ensures that all PCredHandles issued
  // by this method during the lifetime of this process is unique.
  void StoreInHandle(PCredHandle handle) {
    DCHECK(uniquifier > 0);
    EXPECT_FALSE(SecIsValidHandle(handle));

    handle->dwLower = uniquifier;
    handle->dwUpper = reinterpret_cast<ULONG_PTR>(this);

    DCHECK(SecIsValidHandle(handle));
  }

  static MockCredential* FromHandle(PCredHandle handle) {
    return reinterpret_cast<MockCredential*>(handle->dwUpper);
  }
};

struct MockContext {
  raw_ptr<MockCredential> credential = nullptr;
  std::u16string target_principal;
  int uniquifier = ++uniquifier_;
  int rounds = 0;

  // CredHandle and CtxtHandle both shared the following definition:
  //
  // typedef struct _SecHandle {
  //   ULONG_PTR       dwLower;
  //   ULONG_PTR       dwUpper;
  // } SecHandle, * PSecHandle;
  //
  // ULONG_PTR type can hold a pointer. This function stuffs |this| into dwUpper
  // and adds a uniquifier to dwLower. This ensures that all PCredHandles issued
  // by this method during the lifetime of this process is unique.
  void StoreInHandle(PCtxtHandle handle) {
    EXPECT_FALSE(SecIsValidHandle(handle));
    DCHECK(uniquifier > 0);

    handle->dwLower = uniquifier;
    handle->dwUpper = reinterpret_cast<ULONG_PTR>(this);

    DCHECK(SecIsValidHandle(handle));
  }

  std::string ToString() const {
    return base::StringPrintf(
        "%s's token #%d for %s",
        base::UTF16ToUTF8(credential->source_principal).c_str(), rounds + 1,
        base::UTF16ToUTF8(target_principal).c_str());
  }

  static MockContext* FromHandle(PCtxtHandle handle) {
    return reinterpret_cast<MockContext*>(handle->dwUpper);
  }
};

}  // namespace

// Forwards the security package name to the SSPILibrary base class.
MockSSPILibrary::MockSSPILibrary(const wchar_t* package)
    : SSPILibrary(package) {}

// Teardown checks: every scripted package query was consumed, every package
// buffer handed out by QuerySecurityPackageInfo() was released through
// FreeContextBuffer(), and no credential or context handles are still live.
MockSSPILibrary::~MockSSPILibrary() {
  EXPECT_TRUE(expected_package_queries_.empty());
  EXPECT_TRUE(expected_freed_packages_.empty());
  EXPECT_TRUE(active_credentials_.empty());
  EXPECT_TRUE(active_contexts_.empty());
}

// Creates a MockCredential and encodes it into |phCredential|. The credential
// records the principal (or "<Default>" when none is given), the package name,
// and whether explicit auth data was supplied. Always succeeds.
SECURITY_STATUS MockSSPILibrary::AcquireCredentialsHandle(
    LPWSTR pszPrincipal,
    unsigned long fCredentialUse,
    void* pvLogonId,
    void* pvAuthData,
    SEC_GET_KEY_FN pGetKeyFn,
    void* pvGetKeyArgument,
    PCredHandle phCredential,
    PTimeStamp ptsExpiry) {
  DCHECK(!SecIsValidHandle(phCredential));

  auto owned = std::make_unique<MockCredential>();
  if (pszPrincipal)
    owned->source_principal = base::as_u16cstr(pszPrincipal);
  else
    owned->source_principal = u"<Default>";
  owned->package = base::as_u16cstr(package_name_.c_str());
  owned->has_explicit_credentials = pvAuthData != nullptr;

  // From here on the handle owns the credential; FreeCredentialsHandle()
  // reclaims and deletes it.
  owned.release()->StoreInHandle(phCredential);

  // Fixed, arbitrary expiry timestamp.
  if (ptsExpiry != nullptr) {
    ptsExpiry->HighPart = 0x01D54E17;
    ptsExpiry->LowPart = 0xBAA5B780;
  }

  active_credentials_.insert(*phCredential);
  return SEC_E_OK;
}

// Creates a fresh MockContext for every call and stores it in |phNewContext|.
// If |phContext| holds a previous context, that context is deleted and its
// round count is carried forward (+1). The first output buffer, when supplied,
// receives the context's ToString() text truncated to the buffer's capacity.
// Always returns SEC_E_OK.
SECURITY_STATUS MockSSPILibrary::InitializeSecurityContext(
    PCredHandle phCredential,
    PCtxtHandle phContext,
    SEC_WCHAR* pszTargetName,
    unsigned long fContextReq,
    unsigned long Reserved1,
    unsigned long TargetDataRep,
    PSecBufferDesc pInput,
    unsigned long Reserved2,
    PCtxtHandle phNewContext,
    PSecBufferDesc pOutput,
    unsigned long* contextAttr,
    PTimeStamp ptsExpiry) {
  // Ownership passes to |phNewContext| below. DeleteSecurityContext(), or the
  // next call that rotates this context out, reclaims it.
  MockContext* new_context = new MockContext;
  new_context->credential = MockCredential::FromHandle(phCredential);
  // pszTargetName is optional in the SSPI API. Constructing a std::u16string
  // from a null pointer is undefined behavior, so leave the principal empty.
  if (pszTargetName)
    new_context->target_principal = base::as_u16cstr(pszTargetName);

  // Always rotate contexts. That way tests will fail if the caller's context
  // management is broken.
  if (phContext && SecIsValidHandle(phContext)) {
    std::unique_ptr<MockContext> old_context{
        MockContext::FromHandle(phContext)};
    EXPECT_EQ(old_context->credential, new_context->credential);
    EXPECT_EQ(1u, active_contexts_.erase(*phContext));

    new_context->rounds = old_context->rounds + 1;
    SecInvalidateHandle(phContext);
  }

  new_context->StoreInHandle(phNewContext);
  active_contexts_.insert(*phNewContext);

  // pOutput is optional in the SSPI API; only emit a token when a buffer is
  // actually provided.
  if (pOutput && pOutput->pBuffers) {
    auto token = new_context->ToString();
    PSecBuffer out_buffer = pOutput->pBuffers;
    out_buffer->cbBuffer = std::min<ULONG>(out_buffer->cbBuffer, token.size());
    std::memcpy(out_buffer->pvBuffer, token.data(), out_buffer->cbBuffer);
  }

  // Fixed, arbitrary expiry timestamp.
  if (ptsExpiry) {
    ptsExpiry->LowPart = 0xBAA5B780;
    ptsExpiry->HighPart = 0x01D54E15;
  }
  return SEC_E_OK;
}

// Answers a small set of context-attribute queries with canned data derived
// from the MockContext behind |phContext|. Unknown attributes yield
// SEC_E_UNSUPPORTED_FUNCTION.
SECURITY_STATUS MockSSPILibrary::QueryContextAttributesEx(PCtxtHandle phContext,
                                                          ULONG ulAttribute,
                                                          PVOID pBuffer,
                                                          ULONG cbBuffer) {
  // Canned package description for SECPKG_ATTR_NEGOTIATION_INFO. It is a
  // function-local static, so the pointer handed out stays valid.
  static const SecPkgInfoW kNegotiatedPackage = {
      0,
      0,
      0,
      0,
      const_cast<SEC_WCHAR*>(L"Itsa me Kerberos!!"),
      const_cast<SEC_WCHAR*>(L"I like turtles")};

  MockContext* const context = MockContext::FromHandle(phContext);

  if (ulAttribute == SECPKG_ATTR_NATIVE_NAMES) {
    auto* names = static_cast<SecPkgContext_NativeNames*>(pBuffer);
    DCHECK_EQ(sizeof(*names), cbBuffer);
    // Both names point into strings owned by the mock credential/context.
    names->sClientName =
        base::as_writable_wcstr(context->credential->source_principal);
    names->sServerName = base::as_writable_wcstr(context->target_principal);
    return SEC_E_OK;
  }

  if (ulAttribute == SECPKG_ATTR_NEGOTIATION_INFO) {
    auto* info = static_cast<SecPkgContext_NegotiationInfo*>(pBuffer);
    DCHECK_EQ(sizeof(*info), cbBuffer);
    info->PackageInfo = const_cast<SecPkgInfoW*>(&kNegotiatedPackage);
    // Negotiation is reported complete exactly when rounds == 1.
    if (context->rounds == 1)
      info->NegotiationState = SECPKG_NEGOTIATION_COMPLETE;
    else
      info->NegotiationState = SECPKG_NEGOTIATION_IN_PROGRESS;
    return SEC_E_OK;
  }

  if (ulAttribute == SECPKG_ATTR_AUTHORITY) {
    auto* authority = static_cast<SecPkgContext_Authority*>(pBuffer);
    DCHECK_EQ(sizeof(*authority), cbBuffer);
    authority->sAuthorityName = const_cast<SEC_WCHAR*>(L"Dodgy Server");
    return SEC_E_OK;
  }

  return SEC_E_UNSUPPORTED_FUNCTION;
}

// Returns the next scripted response queued by
// ExpectQuerySecurityPackageInfo(). With nothing queued, it succeeds with a
// static default package. Every successful result is recorded as a buffer that
// must later be passed to FreeContextBuffer().
SECURITY_STATUS MockSSPILibrary::QuerySecurityPackageInfo(
    PSecPkgInfoW* pkgInfo) {
  if (!expected_package_queries_.empty()) {
    // Consume the oldest scripted expectation.
    const PackageQuery next = expected_package_queries_.front();
    expected_package_queries_.pop_front();
    *pkgInfo = next.package_info;
    if (next.response_code == SEC_E_OK)
      expected_freed_packages_.insert(next.package_info);
    return next.response_code;
  }

  static SecPkgInfoW kDefaultPkgInfo{
      0, 0, 0, kDefaultMaxTokenLength, nullptr, nullptr};
  *pkgInfo = &kDefaultPkgInfo;
  expected_freed_packages_.insert(&kDefaultPkgInfo);
  return SEC_E_OK;
}

// Releases a credential produced by AcquireCredentialsHandle(). The handle is
// dropped from the live set, its MockCredential is deleted, and the handle is
// invalidated.
SECURITY_STATUS MockSSPILibrary::FreeCredentialsHandle(
    PCredHandle phCredential) {
  DCHECK(SecIsValidHandle(phCredential));
  // Take back ownership of the MockCredential stored in the handle; it is
  // destroyed when this scope exits.
  std::unique_ptr<MockCredential> credential(
      MockCredential::FromHandle(phCredential));
  EXPECT_EQ(1u, active_credentials_.erase(*phCredential));
  SecInvalidateHandle(phCredential);
  return SEC_E_OK;
}

// Releases a context produced by InitializeSecurityContext(). The handle is
// dropped from the live set, its MockContext is deleted, and the handle is
// invalidated.
SECURITY_STATUS MockSSPILibrary::DeleteSecurityContext(PCtxtHandle phContext) {
  // Same precondition as FreeCredentialsHandle(). An invalidated handle no
  // longer refers to a live MockContext, so deleting through it would be
  // undefined behavior.
  DCHECK(SecIsValidHandle(phContext));
  std::unique_ptr<MockContext> context{MockContext::FromHandle(phContext)};
  EXPECT_EQ(1u, active_contexts_.erase(*phContext));
  SecInvalidateHandle(phContext);
  return SEC_E_OK;
}

// Releases a package-info buffer handed out by QuerySecurityPackageInfo().
// Freeing a buffer that was not handed out (or was already freed) is reported
// as a non-fatal test failure.
SECURITY_STATUS MockSSPILibrary::FreeContextBuffer(PVOID pvContextBuffer) {
  PSecPkgInfoW package_info = static_cast<PSecPkgInfoW>(pvContextBuffer);
  // Erase by key and check the count. EXPECT_* does not abort, so erasing
  // through a find() iterator would reach erase(end()), which is undefined
  // behavior.
  EXPECT_EQ(1u, expected_freed_packages_.erase(package_info))
      << "FreeContextBuffer() called on an unexpected buffer";
  return SEC_E_OK;
}

// Queues a scripted result for a future QuerySecurityPackageInfo() call.
// Queued results are consumed in FIFO order.
void MockSSPILibrary::ExpectQuerySecurityPackageInfo(
    SECURITY_STATUS response_code,
    PSecPkgInfoW package_info) {
  PackageQuery query{response_code, package_info};
  expected_package_queries_.push_back(query);
}

}  // namespace net