chromium/chrome/browser/enterprise/platform_auth/cloud_ap_provider_win.cc

// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifdef UNSAFE_BUFFERS_BUILD
// TODO(crbug.com/40285824): Remove this and convert code to safer constructs.
#pragma allow_unsafe_buffers
#endif

#include "chrome/browser/enterprise/platform_auth/cloud_ap_provider_win.h"

#include <objbase.h>

#include <windows.h>

#include <lmcons.h>
#include <lmjoin.h>
#include <proofofpossessioncookieinfo.h>
#include <stdint.h>
#include <windows.security.authentication.web.core.h>
#include <wrl/client.h>

#include <memory>
#include <string_view>
#include <utility>
#include <vector>

#include "base/callback_list.h"
#include "base/check.h"
#include "base/feature_list.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/functional/callback_helpers.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/memory/ref_counted.h"
#include "base/metrics/histogram_functions.h"
#include "base/native_library.h"
#include "base/scoped_native_library.h"
#include "base/sequence_checker.h"
#include "base/strings/string_util.h"
#include "base/strings/utf_string_conversions.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/task_runner.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "base/threading/platform_thread.h"
#include "base/timer/elapsed_timer.h"
#include "base/win/com_init_util.h"
#include "base/win/core_winrt_util.h"
#include "base/win/post_async_results.h"
#include "base/win/scoped_hstring.h"
#include "chrome/browser/enterprise/platform_auth/cloud_ap_utils_win.h"
#include "chrome/browser/enterprise/platform_auth/platform_auth_features.h"
#include "net/cookies/cookie_util.h"
#include "net/http/http_request_headers.h"
#include "url/gurl.h"

using ABI::Windows::Foundation::IAsyncOperation;
using ABI::Windows::Security::Authentication::Web::Core::
    IWebAuthenticationCoreManagerStatics;
using ABI::Windows::Security::Credentials::IWebAccountProvider;
using ABI::Windows::Security::Credentials::WebAccountProvider;
using Microsoft::WRL::ComPtr;

namespace enterprise_auth {

namespace {

// One-shot callback that receives a computed
// `CloudApProviderWin::SupportLevel`.
using OnSupportLevelCallback =
    base::OnceCallback<void(CloudApProviderWin::SupportLevel)>;

// A helper to manage the lifetime of various objects while asking the OS for
// its default web account provider ("https://login.windows.local"). Only the
// provider lookup is performed; the provider's individual WebAccounts are not
// enumerated. The object is ref-counted and reports its result from its
// destructor, i.e. once the last reference (including the one bound into the
// async completion handler by Find()) is released.
class WebAccountSupportFinder
    : public base::RefCountedThreadSafe<WebAccountSupportFinder> {
 public:
  REQUIRE_ADOPTION_FOR_REFCOUNTED_TYPE();

  // `on_support_level` is posted to `result_runner` upon destruction with the
  // result of the operation. Reports `SupportLevel::kEnabled` if the account
  // provider is found, `SupportLevel::kDisabled` if the lookup completes
  // without a provider, or `SupportLevel::kUnsupported` in case of any error.
  WebAccountSupportFinder(scoped_refptr<base::TaskRunner> result_runner,
                          OnSupportLevelCallback on_support_level)
      : result_runner_(std::move(result_runner)),
        on_support_level_(std::move(on_support_level)) {
    // Detached so the checker binds to the sequence of the first checked
    // call (Find()) rather than the constructing sequence.
    DETACH_FROM_SEQUENCE(sequence_checker_);
  }

  WebAccountSupportFinder(const WebAccountSupportFinder&) = delete;
  WebAccountSupportFinder& operator=(const WebAccountSupportFinder&) = delete;

  // Starts the asynchronous provider lookup. Must be called in the COM MTA.
  // Any early failure simply returns, leaving `support_level_` at its initial
  // `kUnsupported` value to be reported when the last reference is released.
  void Find() {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    base::win::AssertComApartmentType(base::win::ComApartmentType::MTA);

    // Get the `WebAuthenticationCoreManager`.
    ComPtr<IWebAuthenticationCoreManagerStatics> auth_manager;
    HRESULT hresult = base::win::GetActivationFactory<
        IWebAuthenticationCoreManagerStatics,
        RuntimeClass_Windows_Security_Authentication_Web_Core_WebAuthenticationCoreManager>(
        &auth_manager);
    if (FAILED(hresult))
      return;  // Unsupported.

    ComPtr<IAsyncOperation<WebAccountProvider*>> find_provider_op;

    // "https://login.windows.local" -- account provider for the OS. Don't
    // specify an authority when using it.
    // https://docs.microsoft.com/en-us/uwp/api/windows.security.authentication.web.core.webauthenticationcoremanager.findaccountproviderasync?view=winrt-19041
    hresult = auth_manager->FindAccountProviderAsync(
        base::win::ScopedHString::Create(L"https://login.windows.local").get(),
        &find_provider_op);
    if (FAILED(hresult))
      return;  // Unsupported.

    // The bound reference keeps this object alive until the completion
    // handler has run (or been discarded).
    hresult = base::win::PostAsyncHandlers(
        find_provider_op.Get(),
        base::BindOnce(&WebAccountSupportFinder::OnAccountProvider,
                       base::WrapRefCounted(this)));
    if (FAILED(hresult)) {
      // Presumably OnAccountProvider() will not run in this case, so the
      // result reported on destruction stays `kUnsupported`.
      DLOG(ERROR)
          << __func__
          << ": Failed to post result task for provider fetch; HRESULT = "
          << std::hex << hresult;
    }
  }

 private:
  friend class base::RefCountedThreadSafe<WebAccountSupportFinder>;

  // Posts `on_support_level_` with `support_level_` to `result_runner_`.
  ~WebAccountSupportFinder() {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    result_runner_->PostTask(
        FROM_HERE,
        base::BindOnce(std::move(on_support_level_), support_level_));
  }

  // Handles the result of a successful call to `FindAccountProviderAsync()`.
  void OnAccountProvider(ComPtr<IWebAccountProvider> account_provider) {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

    // Regardless of whether a provider is found, the machine supports account
    // providers.
    support_level_ = account_provider
                         ? CloudApProviderWin::SupportLevel::kEnabled
                         : CloudApProviderWin::SupportLevel::kDisabled;
  }

  // Task runner on which `on_support_level_` is posted from the destructor.
  scoped_refptr<base::TaskRunner> result_runner_;
  // Receives `support_level_` when this object is destroyed.
  OnSupportLevelCallback on_support_level_;
  // Result to report; stays `kUnsupported` unless OnAccountProvider() runs.
  CloudApProviderWin::SupportLevel support_level_ =
      CloudApProviderWin::SupportLevel::kUnsupported;
  SEQUENCE_CHECKER(sequence_checker_);
};

// Test-only override: when non-null, GetSupportLevel() reports
// `*support_level_for_testing_` instead of probing the platform. Heap-allocated
// and owned here; installed and released by SetSupportLevelForTesting().
CloudApProviderWin::SupportLevel* support_level_for_testing_ = nullptr;

// Instantiates the platform's ProofOfPossessionCookieInfoManager COM object.
// Returns null if it cannot be created (i.e. the platform lacks support).
// When `hresult_out` is non-null, it receives the HRESULT of the creation
// attempt regardless of the outcome.
ComPtr<IProofOfPossessionCookieInfoManager> MakeCookieInfoManager(
    HRESULT* hresult_out = nullptr) {
  // Value of CLSID_ProofOfPossessionCookieInfoManager, as declared in
  // ProofOfPossessionCookieInfo.h.
  static constexpr CLSID kClsidProofOfPossessionCookieInfoManager = {
      0xA9927F85, 0xA304, 0x4390,
      {0x8B, 0x23, 0xA7, 0x5F, 0x1C, 0x66, 0x86, 0x00}};

  // Callers post this work at USER_VISIBLE priority or higher, which is why
  // SCOPED_MAY_LOAD_LIBRARY_AT_BACKGROUND_PRIORITY isn't needed here.
  DCHECK_NE(base::PlatformThread::GetCurrentThreadType(),
            base::ThreadType::kBackground);
  base::win::AssertComInitialized();

  ComPtr<IProofOfPossessionCookieInfoManager> cookie_info_manager;
  const HRESULT create_result = ::CoCreateInstance(
      kClsidProofOfPossessionCookieInfoManager, /*pUnkOuter=*/nullptr,
      CLSCTX_INPROC_SERVER, IID_PPV_ARGS(&cookie_info_manager));

  if (hresult_out != nullptr)
    *hresult_out = create_result;

  if (FAILED(create_result))
    return nullptr;
  return cookie_info_manager;
}

// Converts the `cookie_info_count` proof-of-possession entries in
// `cookie_info` into request headers on `auth_headers`:
// - entries whose name starts with "x-ms-" (case-insensitive) become
//   standalone headers whose value is the cookie data truncated at the first
//   ';' (dropping any cookie attributes);
// - all other entries are serialized together into a single `Cookie` header.
// Entries with a null name are skipped, and a null data pointer is treated as
// an empty value.
void ParseCookieInfo(const ProofOfPossessionCookieInfo* cookie_info,
                     const DWORD cookie_info_count,
                     net::HttpRequestHeaders& auth_headers) {
  static constexpr std::string_view kHeaderPrefix("x-ms-");
  net::cookie_util::ParsedRequestCookies parsed_cookies;

  for (DWORD i = 0; i < cookie_info_count; ++i) {
    const ProofOfPossessionCookieInfo& cookie = cookie_info[i];
    // These strings come from the OS; converting a null wide-string pointer
    // would be undefined behavior, so guard against malformed entries.
    if (!cookie.name)
      continue;
    std::string ascii_cookie_name = base::WideToASCII(cookie.name);
    std::string ascii_cookie_value =
        cookie.data ? base::WideToASCII(cookie.data) : std::string();

    if (base::StartsWith(ascii_cookie_name, kHeaderPrefix,
                         base::CompareCase::INSENSITIVE_ASCII)) {
      // Remove cookie attributes (everything from the first ';') before
      // sending the value as a header.
      const std::string::size_type cookie_attributes_position =
          ascii_cookie_value.find(';');
      if (cookie_attributes_position != std::string::npos)
        ascii_cookie_value.resize(cookie_attributes_position);
      auth_headers.SetHeader(std::move(ascii_cookie_name),
                             std::move(ascii_cookie_value));
    } else {
      parsed_cookies.emplace_back(std::move(ascii_cookie_name),
                                  std::move(ascii_cookie_value));
    }
  }

  if (!parsed_cookies.empty()) {
    auth_headers.SetHeader(
        net::HttpRequestHeaders::kCookie,
        net::cookie_util::SerializeRequestCookieLine(parsed_cookies));
  }
}

// Fetches the proof-of-possession cookies and headers the interactive user
// needs to authenticate to the IdP/STS at `url`. Returns an empty set of
// headers if the cookie info manager is unavailable or the lookup fails.
// Timing and outcome metrics are recorded in both cases.
net::HttpRequestHeaders GetAuthData(const GURL& url) {
  base::win::AssertComInitialized();
  DCHECK(url.is_valid());

  base::ElapsedTimer timer;
  net::HttpRequestHeaders headers;
  DWORD num_cookies = 0;

  HRESULT result = S_OK;
  ComPtr<IProofOfPossessionCookieInfoManager> cookie_manager =
      MakeCookieInfoManager(&result);
  if (cookie_manager) {
    const std::wstring wide_url = base::ASCIIToWide(url.spec());
    ProofOfPossessionCookieInfo* cookies = nullptr;
    result = cookie_manager->GetCookieInfoForUri(wide_url.c_str(),
                                                 &num_cookies, &cookies);
    if (SUCCEEDED(result)) {
      DCHECK(!num_cookies || cookies);
      ParseCookieInfo(cookies, num_cookies, headers);
      // The array is allocated by the platform and must be released by it.
      if (cookies)
        FreeProofOfPossessionCookieInfoArray(cookies, num_cookies);
    }
  }
  const auto elapsed = timer.Elapsed();

  if (FAILED(result)) {
    base::UmaHistogramTimes("Enterprise.PlatformAuth.GetAuthData.FailureTime",
                            elapsed);
    base::UmaHistogramSparse(
        "Enterprise.PlatformAuth.GetAuthData.FailureHresult", int{result});
    return headers;
  }

  base::UmaHistogramTimes("Enterprise.PlatformAuth.GetAuthData.SuccessTime",
                          elapsed);
  // Fewer than 10 cookies are expected.
  base::UmaHistogramExactLinear("Enterprise.PlatformAuth.GetAuthData.Count",
                                num_cookies, 10);
  return headers;
}

// Maps the device's Azure AD join status to a support level:
// `kUnsupported` if the join information cannot be retrieved, `kDisabled` if
// there is no join info or its join type is unknown, `kEnabled` otherwise.
CloudApProviderWin::SupportLevel GetAadJoinSupportLevel() {
  // Callers post this work at `USER_VISIBLE` priority, so
  // `SCOPED_MAY_LOAD_LIBRARY_AT_BACKGROUND_PRIORITY` is unnecessary.
  DCHECK_NE(base::PlatformThread::GetCurrentThreadType(),
            base::ThreadType::kBackground);

  PDSREG_JOIN_INFO raw_join_info = nullptr;
  const HRESULT join_result =
      ::NetGetAadJoinInformation(/*pcszTenantId=*/nullptr, &raw_join_info);
  if (FAILED(join_result))
    return CloudApProviderWin::SupportLevel::kUnsupported;

  // Retrieval succeeded, so the feature is supported. Release the join info
  // with NetFreeAadJoinInformation() when leaving this scope.
  const std::unique_ptr<DSREG_JOIN_INFO, decltype(&NetFreeAadJoinInformation)>
      join_info(raw_join_info, ::NetFreeAadJoinInformation);

  const bool is_joined =
      join_info && join_info->joinType != DSREG_UNKNOWN_JOIN;
  if (is_joined)
    return CloudApProviderWin::SupportLevel::kEnabled;
  return CloudApProviderWin::SupportLevel::kDisabled;
}

// Handles the result reported by `WebAccountSupportFinder`. A `kEnabled`
// result (the OS's default account provider was found) is final. Any other
// result falls back to whether the device is Azure AD joined in some way
// (device joined or workplace joined).
void OnFindWebAccount(OnSupportLevelCallback on_support_level,
                      CloudApProviderWin::SupportLevel support_level) {
  const CloudApProviderWin::SupportLevel final_level =
      support_level == CloudApProviderWin::SupportLevel::kEnabled
          ? CloudApProviderWin::SupportLevel::kEnabled
          : GetAadJoinSupportLevel();
  std::move(on_support_level).Run(final_level);
}

// Determines the level of Cloud AP SSO support and runs `on_support_level`
// with it on the calling sequence, either synchronously (test override or
// missing COM class) or asynchronously (after the account provider lookup).
void GetSupportLevel(OnSupportLevelCallback on_support_level) {
  // Tests may pin the result.
  if (support_level_for_testing_ != nullptr) {
    std::move(on_support_level).Run(*support_level_for_testing_);
    return;
  }

  // Without the ProofOfPossessionCookieInfoManager COM class, nothing else
  // matters: the platform is unsupported.
  const bool has_cookie_info_manager = !!MakeCookieInfoManager();
  if (!has_cookie_info_manager) {
    std::move(on_support_level)
        .Run(CloudApProviderWin::SupportLevel::kUnsupported);
    return;
  }

  // Look up the OS's default account provider; `OnFindWebAccount` refines the
  // result back on this sequence.
  auto finder = base::MakeRefCounted<WebAccountSupportFinder>(
      base::SequencedTaskRunner::GetCurrentDefault(),
      base::BindOnce(&OnFindWebAccount, std::move(on_support_level)));
  finder->Find();
}

// Reads the IdP origins from the Windows registry.
std::vector<url::Origin> ReadOrigins() {
  static constexpr wchar_t kLoginUri[] = L"LoginUri";
  std::vector<url::Origin> result;

  // Windows registry locations (provided by Microsoft) which are expected to
  // contain Microsoft IdP origins.
  AppendRegistryOrigins(HKEY_LOCAL_MACHINE,
                        L"SOFTWARE\\Microsoft\\IdentityStore\\LoadParameters\\"
                        L"{B16898C6-A148-4967-9171-64D755DA8520}",
                        kLoginUri, result);
  AppendRegistryOrigins(
      HKEY_LOCAL_MACHINE,
      L"SOFTWARE\\Microsoft\\IdentityStore\\Providers\\"
      L"{B16898C6-A148-4967-9171-64D755DA8520}\\LoadParameters",
      kLoginUri, result);
  AppendRegistryOrigins(
      HKEY_CURRENT_USER,
      L"Software\\Microsoft\\Windows\\CurrentVersion\\AAD\\Package", kLoginUri,
      result);
  AppendRegistryOrigins(HKEY_LOCAL_MACHINE, L"SOFTWARE\\Microsoft\\IdentityCRL",
                        L"LoginUrl", result);

  if (result.empty()) {
    // Certain legacy versions of Windows may not have origins in the registry.
    // Use the two well-known origins if none other are found.
    result.push_back(url::Origin::Create(GURL("https://login.live.com")));
    result.push_back(
        url::Origin::Create(GURL("https://login.microsoftonline.com")));
  }

  return result;
}

// Handles the result of `GetSupportLevel()` by posting `on_origins` to
// `result_runner` with:
// - null when unsupported (there is no point in retrying);
// - an empty collection when supported but disabled (the device may become
//   joined later);
// - the origins from `ReadOrigins()` when enabled.
void OnSupportLevel(scoped_refptr<base::TaskRunner> result_runner,
                    CloudApProviderWin::FetchOriginsCallback on_origins,
                    CloudApProviderWin::SupportLevel support_level) {
  using SupportLevel = CloudApProviderWin::SupportLevel;
  using OriginList = std::vector<url::Origin>;

  std::unique_ptr<OriginList> origins;
  switch (support_level) {
    case SupportLevel::kEnabled:
      origins = std::make_unique<OriginList>(ReadOrigins());
      break;
    case SupportLevel::kDisabled:
      origins = std::make_unique<OriginList>();
      break;
    case SupportLevel::kUnsupported:
      // Leave `origins` null.
      break;
  }

  auto reply = base::BindOnce(std::move(on_origins), std::move(origins));
  result_runner->PostTask(FROM_HERE, std::move(reply));
}

// Runs in the ThreadPool: evaluates Cloud AP SSO support, then posts
// `on_origins` to `result_runner` with the IdP/STS origins, or with nullptr
// if Cloud AP SSO is not supported.
void FetchOriginsInPool(scoped_refptr<base::TaskRunner> result_runner,
                        CloudApProviderWin::FetchOriginsCallback on_origins) {
  OnSupportLevelCallback handle_support_level = base::BindOnce(
      &OnSupportLevel, std::move(result_runner), std::move(on_origins));
  GetSupportLevel(std::move(handle_support_level));
}

}  // namespace

CloudApProviderWin::CloudApProviderWin() = default;

CloudApProviderWin::~CloudApProviderWin() = default;

// Always true: this provider advertises the IdP/STS origins it handles via
// FetchOrigins().
bool CloudApProviderWin::SupportsOriginFiltering() {
  return true;
}

void CloudApProviderWin::FetchOrigins(FetchOriginsCallback on_fetch_complete) {
  // Support is evaluated on a ThreadPool sequence (the checks may block):
  // 1. If ProofOfPossessionCookieInfoManager cannot be instantiated, the
  //    platform doesn't support AAD SSO.
  // 2. If the OS's default web account provider can be found, AAD SSO is
  //    supported and enabled.
  // 3. Otherwise, if the device is AAD-joined (device joined or workplace
  //    joined), AAD SSO is supported and enabled.
  // 4. If checking the join status fails, AAD SSO is unsupported; otherwise
  //    it is supported but disabled.
  // `on_fetch_complete` then runs on the calling sequence with:
  // - nullptr if AAD SSO is not supported;
  // - an empty collection of origins if AAD SSO is supported but disabled;
  // - a non-empty collection of IdP/STS origins if it is supported and
  //   enabled.
  scoped_refptr<base::SequencedTaskRunner> pool_runner =
      base::ThreadPool::CreateSequencedTaskRunner(
          {base::TaskPriority::USER_VISIBLE,
           base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN,
           base::MayBlock()});
  pool_runner->PostTask(
      FROM_HERE,
      base::BindOnce(&FetchOriginsInPool,
                     base::SequencedTaskRunner::GetCurrentDefault(),
                     std::move(on_fetch_complete)));
}

// Fetches the proof-of-possession auth headers for `url` on a COM STA
// ThreadPool thread and runs `callback` with them on the calling sequence.
// If the work cannot be posted, `callback` is run synchronously with empty
// headers so that the caller is never left waiting.
void CloudApProviderWin::GetData(
    const GURL& url,
    PlatformAuthProviderManager::GetDataCallback callback) {
  // `callback` is bound directly as this request's reply, so the headers
  // fetched for `url` reach only the caller that asked for them, even when
  // several requests are in flight. Keeping a single shared subscription
  // would unregister the previous caller's callback whenever a new request
  // started. The reply also does not reference `this`.
  // SplitOnceCallback keeps a second handle so that `callback` can still be
  // answered if posting fails and the reply is discarded unrun.
  auto [reply_callback, fallback_callback] =
      base::SplitOnceCallback(std::move(callback));

  const bool posted =
      base::ThreadPool::CreateCOMSTATaskRunner(
          {base::TaskPriority::USER_BLOCKING,
           base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN, base::MayBlock()})
          ->PostTaskAndReplyWithResult(FROM_HERE,
                                       base::BindOnce(&GetAuthData, url),
                                       std::move(reply_callback));
  if (!posted)
    std::move(fallback_callback).Run(net::HttpRequestHeaders());
}

// Notifies every callback currently registered in `on_get_data_callback_list_`
// with `auth_headers`.
void CloudApProviderWin::OnGetDataCallback(
    net::HttpRequestHeaders auth_headers) {
  on_get_data_callback_list_.Notify(std::move(auth_headers));
}

// static
void CloudApProviderWin::SetSupportLevelForTesting(
    std::optional<SupportLevel> level) {
  delete std::exchange(support_level_for_testing_, nullptr);
  if (!level)
    return;
  support_level_for_testing_ = new SupportLevel;
  *support_level_for_testing_ = level.value();
}

// Test-only entry point that forwards to the file-local ParseCookieInfo()
// helper, adding the parsed cookies/headers to `auth_headers`.
void CloudApProviderWin::ParseCookieInfoForTesting(
    const ProofOfPossessionCookieInfo* cookie_info,
    const DWORD cookie_info_count,
    net::HttpRequestHeaders& auth_headers) {
  ParseCookieInfo(cookie_info, cookie_info_count, auth_headers);
}

}  // namespace enterprise_auth