chromium/chrome/credential_provider/gaiacp/scoped_lsa_policy.cc

// Copyright 2018 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/credential_provider/gaiacp/scoped_lsa_policy.h"

#include <Windows.h>  // Needed for ACCESS_MASK, <lm.h>
#include <Winternl.h>
#include <lm.h>  // Needed for LSA_UNICODE_STRING
#include <ntstatus.h>

#define _NTDEF_  // Prevent redefition errors, must come after <winternl.h>
#include <ntsecapi.h>  // For LSA_xxx types

#include "chrome/credential_provider/gaiacp/gcp_utils.h"  // For STATUS_SUCCESS.
#include "chrome/credential_provider/gaiacp/logging.h"

namespace credential_provider {

// static
ScopedLsaPolicy::CreatorCallback* ScopedLsaPolicy::GetCreatorCallbackStorage() {
  static CreatorCallback creator_for_testing;
  return &creator_for_testing;
}

// static
void ScopedLsaPolicy::SetCreatorForTesting(CreatorCallback creator) {
  *GetCreatorCallbackStorage() = creator;
}

// static
std::unique_ptr<ScopedLsaPolicy> ScopedLsaPolicy::Create(ACCESS_MASK mask) {
  if (!GetCreatorCallbackStorage()->is_null())
    return GetCreatorCallbackStorage()->Run(mask);

  std::unique_ptr<ScopedLsaPolicy> scoped(new ScopedLsaPolicy(mask));
  return scoped->IsValid() ? std::move(scoped) : nullptr;
}

ScopedLsaPolicy::ScopedLsaPolicy(ACCESS_MASK mask) {
  LSA_OBJECT_ATTRIBUTES oa;
  memset(&oa, 0, sizeof(oa));
  NTSTATUS sts = ::LsaOpenPolicy(nullptr, &oa, mask, &handle_);
  if (sts != STATUS_SUCCESS) {
    HRESULT hr = HRESULT_FROM_NT(sts);
    LOGFN(ERROR) << "LsaOpenPolicy hr=" << putHR(hr);
    ::SetLastError(hr);
    handle_ = nullptr;
  }
}

ScopedLsaPolicy::~ScopedLsaPolicy() {
  if (handle_ != nullptr)
    ::LsaClose(handle_);
}

bool ScopedLsaPolicy::IsValid() const {
  return handle_ != nullptr;
}

HRESULT ScopedLsaPolicy::StorePrivateData(const wchar_t* key,
                                          const wchar_t* value) {
  LSA_UNICODE_STRING lsa_key;
  InitLsaString(key, &lsa_key);
  LSA_UNICODE_STRING lsa_value;
  InitLsaString(value, &lsa_value);

  // When calling LsaStorePrivateData(), the value's length should include
  // the null terminator.
  lsa_value.Length = lsa_value.MaximumLength;

  NTSTATUS sts = ::LsaStorePrivateData(handle_, &lsa_key, &lsa_value);
  if (sts != STATUS_SUCCESS) {
    HRESULT hr = HRESULT_FROM_NT(sts);
    LOGFN(ERROR) << "LsaStorePrivateData hr=" << putHR(hr);
    return hr;
  }
  return S_OK;
}

HRESULT ScopedLsaPolicy::RemovePrivateData(const wchar_t* key) {
  LSA_UNICODE_STRING lsa_key;
  InitLsaString(key, &lsa_key);

  NTSTATUS sts = ::LsaStorePrivateData(handle_, &lsa_key, nullptr);
  if (sts != STATUS_SUCCESS) {
    HRESULT hr = HRESULT_FROM_NT(sts);
    LOGFN(ERROR) << "LsaStorePrivateData hr=" << putHR(hr);
    return hr;
  }
  return S_OK;
}

HRESULT ScopedLsaPolicy::RetrievePrivateData(const wchar_t* key,
                                             wchar_t* value,
                                             size_t length) {
  LSA_UNICODE_STRING lsa_key;
  InitLsaString(key, &lsa_key);
  LSA_UNICODE_STRING* lsa_value;

  NTSTATUS sts = ::LsaRetrievePrivateData(handle_, &lsa_key, &lsa_value);
  if (sts != STATUS_SUCCESS)
    return HRESULT_FROM_NT(sts);

  errno_t err = wcscpy_s(value, length, lsa_value->Buffer);
  SecurelyClearBuffer(lsa_value->Buffer, lsa_value->Length);
  ::LsaFreeMemory(lsa_value);

  return err == 0 ? S_OK : E_FAIL;
}

bool ScopedLsaPolicy::PrivateDataExists(const wchar_t* key) {
  LSA_UNICODE_STRING lsa_key;
  InitLsaString(key, &lsa_key);
  LSA_UNICODE_STRING* lsa_value;

  NTSTATUS sts = ::LsaRetrievePrivateData(handle_, &lsa_key, &lsa_value);

  if (sts != STATUS_SUCCESS)
    return false;

  SecurelyClearBuffer(lsa_value->Buffer, lsa_value->Length);
  ::LsaFreeMemory(lsa_value);

  return true;
}

HRESULT ScopedLsaPolicy::AddAccountRights(
    PSID sid,
    const std::vector<std::wstring>& rights) {
  LOGFN(VERBOSE);
  for (auto& right : rights) {
    std::vector<LSA_UNICODE_STRING> lsa_rights;
    LSA_UNICODE_STRING lsa_right;
    InitLsaString(right.c_str(), &lsa_right);
    lsa_rights.push_back(lsa_right);
    NTSTATUS sts = ::LsaAddAccountRights(handle_, sid, lsa_rights.data(),
                                         lsa_rights.size());
    if (sts != STATUS_SUCCESS) {
      HRESULT hr = HRESULT_FROM_NT(sts);
      LOGFN(ERROR) << "LsaAddAccountRights " << right << "sts=" << putHR(sts)
                   << " hr=" << putHR(hr);
      return hr;
    }
  }
  return S_OK;
}

HRESULT ScopedLsaPolicy::RemoveAccountRights(
    PSID sid,
    const std::vector<std::wstring>& rights) {
  LOGFN(VERBOSE);
  for (auto& right : rights) {
    std::vector<LSA_UNICODE_STRING> lsa_rights;
    LSA_UNICODE_STRING lsa_right;
    InitLsaString(right.c_str(), &lsa_right);
    lsa_rights.push_back(lsa_right);
    NTSTATUS sts = ::LsaRemoveAccountRights(
        handle_, sid, FALSE, lsa_rights.data(), lsa_rights.size());
    if (sts != STATUS_SUCCESS) {
      HRESULT hr = HRESULT_FROM_NT(sts);
      // Donot log error message when the privilege isn't
      // assigned to the user and removal of the account right
      // fails.
      if (hr != HRESULT_FROM_NT(STATUS_OBJECT_NAME_NOT_FOUND)) {
        LOGFN(ERROR) << "LsaRemoveAccountRights " << right
                     << " sts=" << putHR(sts) << " hr=" << putHR(hr);
      }
    }
  }
  return S_OK;
}

HRESULT ScopedLsaPolicy::RemoveAccount(PSID sid) {
  // When all rights are removed from an account, the account itself is also
  // deleted.
  // I thought the above meant the user would be removed from the
  // computer, but apparently I am mistaken.  It is still important to call
  // NetUserDel().
  NTSTATUS sts = ::LsaRemoveAccountRights(handle_, sid, TRUE, nullptr, 0);
  if (sts != STATUS_SUCCESS) {
    HRESULT hr = HRESULT_FROM_NT(sts);
    LOGFN(ERROR) << "LsaRemoveAccountRights sts=" << putHR(sts)
                 << " hr=" << putHR(hr);
    return hr;
  }
  return S_OK;
}

// static
void ScopedLsaPolicy::InitLsaString(const wchar_t* string,
                                    _UNICODE_STRING* lsa_string) {
  lsa_string->Buffer = const_cast<wchar_t*>(string);
  lsa_string->Length =
      static_cast<USHORT>(wcslen(lsa_string->Buffer) * sizeof(wchar_t));
  lsa_string->MaximumLength = lsa_string->Length + sizeof(wchar_t);
}

}  // namespace credential_provider