chromium/chrome/updater/net/network_fetcher_win.cc

// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/winhttp/network_fetcher.h"

#include <windows.h>

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "base/containers/flat_map.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/functional/callback_helpers.h"
#include "base/logging.h"
#include "base/memory/scoped_refptr.h"
#include "base/sequence_checker.h"
#include "base/strings/sys_string_conversions.h"
#include "base/strings/utf_string_conversions.h"
#include "chrome/updater/net/network.h"
#include "chrome/updater/policy/service.h"
#include "chrome/updater/updater_scope.h"
#include "chrome/updater/util/util.h"
#include "chrome/updater/util/win_util.h"
#include "chrome/updater/win/scoped_handle.h"
#include "chrome/updater/win/scoped_impersonation.h"
#include "chrome/updater/win/user_info.h"
#include "components/update_client/network.h"
#include "components/winhttp/proxy_configuration.h"
#include "components/winhttp/scoped_hinternet.h"
#include "url/gurl.h"

namespace updater {
namespace {

// Factory method for the proxy configuration strategy.
//
// If `policy_service_proxy_configuration` holds a value, the returned
// configuration is built from it:
//   - the policy's auto-detect flag;
//   - its PAC URL and proxy URL, with an absent optional becoming an empty
//     string;
//   - an empty proxy bypass list.
// Otherwise a `winhttp::AutoProxyConfiguration` is returned, which is used here
// to represent the system proxy configuration.
scoped_refptr<winhttp::ProxyConfiguration> GetProxyConfiguration(
    std::optional<PolicyServiceProxyConfiguration>
        policy_service_proxy_configuration) {
  if (!policy_service_proxy_configuration) {
    VLOG(1) << "Using the system configuration for proxy.";
    return base::MakeRefCounted<winhttp::AutoProxyConfiguration>();
  }

  VLOG(1) << "Using cloud policy configuration for proxy.";
  const PolicyServiceProxyConfiguration& policy =
      *policy_service_proxy_configuration;
  const std::wstring pac_url =
      base::SysUTF8ToWide(policy.proxy_pac_url.value_or(""));
  const std::wstring proxy_url =
      base::SysUTF8ToWide(policy.proxy_url.value_or(""));
  return base::MakeRefCounted<winhttp::ProxyConfiguration>(
      winhttp::ProxyInfo{policy.proxy_auto_detect, pac_url, proxy_url, L""});
}

// Adapts a `winhttp::NetworkFetcher` to the `update_client::NetworkFetcher`
// interface. Each request does three things:
//   1. stores the caller's completion callback;
//   2. forwards the request to the wrapped WinHTTP fetcher with a completion
//      handler bound to this object;
//   3. in that handler, gathers the results from the WinHTTP fetcher and runs
//      the stored callback.
class NetworkFetcher : public update_client::NetworkFetcher {
 public:
  using ResponseStartedCallback =
      update_client::NetworkFetcher::ResponseStartedCallback;
  using ProgressCallback = update_client::NetworkFetcher::ProgressCallback;
  using PostRequestCompleteCallback =
      update_client::NetworkFetcher::PostRequestCompleteCallback;
  using DownloadToFileCompleteCallback =
      update_client::NetworkFetcher::DownloadToFileCompleteCallback;

  // `session_handle` and `proxy_config` are shared with the wrapped WinHTTP
  // fetcher.
  NetworkFetcher(scoped_refptr<winhttp::SharedHInternet> session_handle,
                 scoped_refptr<winhttp::ProxyConfiguration> proxy_config);
  NetworkFetcher(const NetworkFetcher&) = delete;
  NetworkFetcher& operator=(const NetworkFetcher&) = delete;
  ~NetworkFetcher() override;

  // update_client::NetworkFetcher:
  void PostRequest(
      const GURL& url,
      const std::string& post_data,
      const std::string& content_type,
      const base::flat_map<std::string, std::string>& post_additional_headers,
      ResponseStartedCallback response_started_callback,
      ProgressCallback progress_callback,
      PostRequestCompleteCallback post_request_complete_callback) override;
  base::OnceClosure DownloadToFile(
      const GURL& url,
      const base::FilePath& file_path,
      ResponseStartedCallback response_started_callback,
      ProgressCallback progress_callback,
      DownloadToFileCompleteCallback download_to_file_complete_callback)
      override;

 private:
  // Completion handlers handed to `winhttp_network_fetcher_`. Each one
  // collects the results and runs the matching stored callback below.
  void PostRequestComplete(int response_code);
  void DownloadToFileComplete(int response_code);

  SEQUENCE_CHECKER(sequence_checker_);

  // The WinHTTP fetcher that performs the actual network work.
  scoped_refptr<winhttp::NetworkFetcher> winhttp_network_fetcher_;

  // Caller callbacks, saved until the corresponding request completes.
  DownloadToFileCompleteCallback download_to_file_complete_callback_;
  PostRequestCompleteCallback post_request_complete_callback_;
};

// Creates the wrapped WinHTTP fetcher, sharing `session_handle` and
// `proxy_config` with it. Both refcounted handles are taken by value, so they
// are moved (not copied) into the WinHTTP fetcher. This avoids a redundant
// atomic reference-count increment/decrement pair for every fetcher created.
NetworkFetcher::NetworkFetcher(
    scoped_refptr<winhttp::SharedHInternet> session_handle,
    scoped_refptr<winhttp::ProxyConfiguration> proxy_config)
    : winhttp_network_fetcher_(base::MakeRefCounted<winhttp::NetworkFetcher>(
          std::move(session_handle),
          std::move(proxy_config))) {}

// Closes the wrapped WinHTTP fetcher.
// NOTE(review): the completion handlers given to the WinHTTP fetcher are bound
// with base::Unretained(this). This presumably relies on Close() preventing
// them from running after this object is destroyed; confirm against
// winhttp::NetworkFetcher::Close().
NetworkFetcher::~NetworkFetcher() {
  winhttp_network_fetcher_->Close();
}

// Starts a POST of `post_data` to `url` through the WinHTTP fetcher.
// `post_request_complete_callback` is stored and runs later from
// PostRequestComplete(). The started and progress callbacks go straight to the
// WinHTTP fetcher.
void NetworkFetcher::PostRequest(
    const GURL& url,
    const std::string& post_data,
    const std::string& content_type,
    const base::flat_map<std::string, std::string>& post_additional_headers,
    ResponseStartedCallback response_started_callback,
    ProgressCallback progress_callback,
    PostRequestCompleteCallback post_request_complete_callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  VLOG(2) << __func__;

  post_request_complete_callback_ = std::move(post_request_complete_callback);

  auto on_complete = base::BindOnce(&NetworkFetcher::PostRequestComplete,
                                    base::Unretained(this));
  winhttp_network_fetcher_->PostRequest(
      url, post_data, content_type, post_additional_headers,
      std::move(response_started_callback), std::move(progress_callback),
      std::move(on_complete));
}

// Starts downloading `url` into `file_path` through the WinHTTP fetcher.
// `download_to_file_complete_callback` is stored and runs later from
// DownloadToFileComplete(). Returns the closure produced by the WinHTTP
// fetcher's DownloadToFile().
base::OnceClosure NetworkFetcher::DownloadToFile(
    const GURL& url,
    const base::FilePath& file_path,
    ResponseStartedCallback response_started_callback,
    ProgressCallback progress_callback,
    DownloadToFileCompleteCallback download_to_file_complete_callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  VLOG(2) << __func__;

  download_to_file_complete_callback_ =
      std::move(download_to_file_complete_callback);

  auto on_complete = base::BindOnce(&NetworkFetcher::DownloadToFileComplete,
                                    base::Unretained(this));
  return winhttp_network_fetcher_->DownloadToFile(
      url, file_path, std::move(response_started_callback),
      std::move(progress_callback), std::move(on_complete));
}

// Completion handler for PostRequest(). It performs three steps:
//   1. Queries the X-Cup-Server-Proof, ETag and X-Retry-After response
//      headers. Each starts from a default (empty string, or -1 for the
//      retry-after value), and the query results are not checked, so the
//      lookups are best effort.
//   2. Reads the response body and net error.
//   3. Runs the stored completion callback with all of these values.
void NetworkFetcher::PostRequestComplete(int response_code) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  VLOG(2) << __func__ << ": response code=" << response_code;

  std::wstring x_cup_server_proof;
  winhttp_network_fetcher_->QueryHeaderString(
      base::SysUTF8ToWide(
          update_client::NetworkFetcher::kHeaderXCupServerProof),
      &x_cup_server_proof);

  std::wstring etag;
  winhttp_network_fetcher_->QueryHeaderString(
      base::SysUTF8ToWide(update_client::NetworkFetcher::kHeaderEtag), &etag);

  int x_retry_after_sec = -1;
  winhttp_network_fetcher_->QueryHeaderInt(
      base::SysUTF8ToWide(update_client::NetworkFetcher::kHeaderXRetryAfter),
      &x_retry_after_sec);

  auto response_body = std::make_unique<std::string>(
      winhttp_network_fetcher_->GetResponseBody());
  const auto net_error = winhttp_network_fetcher_->GetNetError();

  std::move(post_request_complete_callback_)
      .Run(std::move(response_body), net_error, base::SysWideToUTF8(etag),
           base::SysWideToUTF8(x_cup_server_proof), x_retry_after_sec);
}

// Completion handler for DownloadToFile(). Runs the stored completion callback
// with the WinHTTP fetcher's net error and content size. The response code is
// intentionally unused.
void NetworkFetcher::DownloadToFileComplete(int /*response_code*/) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  VLOG(2) << __func__;

  const auto net_error = winhttp_network_fetcher_->GetNetError();
  const auto content_size = winhttp_network_fetcher_->GetContentSize();
  std::move(download_to_file_complete_callback_).Run(net_error, content_size);
}

}  // namespace

class NetworkFetcherFactory::Impl {
 public:
  // Resolves the proxy configuration and creates the WinHTTP session that is
  // shared by every fetcher this object creates. For system installs, the
  // session is created while impersonating the logged-on user, if that user's
  // token can be obtained. If session creation fails, `session_handle_` stays
  // null and Create() returns nullptr.
  explicit Impl(std::optional<PolicyServiceProxyConfiguration>
                    policy_service_proxy_configuration)
      : proxy_configuration_(GetProxyConfiguration(
            std::move(policy_service_proxy_configuration))) {
    ScopedImpersonation impersonate;
    if (IsSystemInstall()) {
      HResultOr<ScopedKernelHANDLE> token = GetLoggedOnUserToken();
      VLOG_IF(2, !token.has_value())
          << __func__ << ": GetLoggedOnUserToken failed: " << std::hex
          << token.error();
      if (token.has_value()) {
        const HRESULT hr = impersonate.Impersonate(token.value().get());
        VLOG(2)
            << __func__
            << ": Successfully got logged on user token. Impersonate result: "
            << std::hex << hr;
      }
    }

    // Validate the raw session handle BEFORE wrapping it. A freshly made
    // SharedHInternet is never null, so testing `session_handle_` after
    // wrapping could never detect a failed CreateSessionHandle() call. That
    // would let Create() hand out fetchers that hold an invalid session.
    auto session = winhttp::CreateSessionHandle(
        base::SysUTF8ToWide(GetUpdaterUserAgent()),
        proxy_configuration_->access_type(), proxy_configuration_->proxy(),
        proxy_configuration_->proxy_bypass());
    if (!session.is_valid()) {
      VLOG(2) << "Failed to create a winhttp session.";
      return;
    }
    session_handle_ =
        base::MakeRefCounted<winhttp::SharedHInternet>(std::move(session));
  }

  // Returns a new fetcher that shares this object's session and proxy
  // configuration, or nullptr if no WinHTTP session could be created.
  std::unique_ptr<update_client::NetworkFetcher> Create() {
    return session_handle_ ? std::make_unique<NetworkFetcher>(
                                 session_handle_, proxy_configuration_)
                           : nullptr;
  }

 private:
  scoped_refptr<winhttp::ProxyConfiguration> proxy_configuration_;

  // Null when session creation failed.
  scoped_refptr<winhttp::SharedHInternet> session_handle_;
};

// The optional policy proxy configuration is taken by value and moved into the
// implementation. Its contents include std::string fields, so moving avoids a
// second copy of them.
NetworkFetcherFactory::NetworkFetcherFactory(
    std::optional<PolicyServiceProxyConfiguration>
        policy_service_proxy_configuration)
    : impl_(std::make_unique<Impl>(
          std::move(policy_service_proxy_configuration))) {}
NetworkFetcherFactory::~NetworkFetcherFactory() = default;

// Forwards to Impl::Create(), which returns nullptr when it holds no WinHTTP
// session handle.
std::unique_ptr<update_client::NetworkFetcher> NetworkFetcherFactory::Create()
    const {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  return impl_->Create();
}

}  // namespace updater