// Copyright 2018 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifdef UNSAFE_BUFFERS_BUILD
// TODO(crbug.com/40285824): Remove this and convert code to safer constructs.
#pragma allow_unsafe_buffers
#endif
#include "chrome/credential_provider/gaiacp/win_http_url_fetcher.h"
#include <Windows.h>
#include <atlconv.h>
#include <process.h>
#include <winhttp.h>
#include <string>
#include <string_view>
#include "base/base64.h"
#include "base/containers/contains.h"
#include "base/containers/span.h"
#include "base/json/json_reader.h"
#include "base/json/json_writer.h"
#include "base/memory/ptr_util.h"
#include "base/strings/strcat_win.h"
#include "base/strings/utf_string_conversions.h"
#include "base/synchronization/lock.h"
#include "base/time/time.h"
#include "chrome/credential_provider/gaiacp/logging.h"
namespace {
// Key name containing the HTTP error code within the dictionary returned by the
// server in case of errors.
constexpr char kHttpErrorCodeKeyNameInResponse[] = "code";
// Error key name that is likely to be present in HTTP responses.
const char kErrorKeyInRequestResult[] = "error";
// The HTTP response codes for which the request is re-tried on failure.
constexpr int kRetryableHttpErrorCodes[] = {
503, // Service Unavailable
504 // Gateway Timeout
};
// Self deleting http service requester. This class will try to make a query
// using the given url fetcher. It will delete itself when the request is
// completed, either because the request completed successfully within the
// timeout or the request has timed out and is allowed to complete in the
// background without having the result read by anyone.
// There are two situations where the request will be deleted:
// 1. If the background thread making the request returns within the given
// timeout, the function is guaranteed to return the result that was fetched.
// 2. If however the background thread times out there are two potential
// race conditions that can occur:
// 1. The main thread making the request can mark that the background thread
// is orphaned before it can complete. In this case when the background
// thread completes it will check whether the request is orphaned and self
// delete.
// 2. The background thread completes before the main thread can mark the
// request as orphaned. In this case the background thread will have
// marked that the request is no longer processing and thus the main
// thread can self delete.
class HttpServiceRequest {
public:
static HttpServiceRequest* Create(
const GURL& request_url,
const std::string& access_token,
const std::vector<std::pair<std::string, std::string>>& headers,
const std::string& request_body,
const base::TimeDelta& request_timeout);
// Tries to fetch the request stored in |fetcher_| in a background thread
// within the given |request_timeout|. If the background thread returns before
// the timeout expires, it is guaranteed that a result can be returned and the
// requester will delete itself.
std::optional<base::Value> WaitForResponseFromHttpService(
const base::TimeDelta& request_timeout) {
std::optional<base::Value> result;
// Start the thread and wait on its handle until |request_timeout| expires
// or the thread finishes.
unsigned wait_thread_id;
uintptr_t wait_thread = ::_beginthreadex(
nullptr, 0, &HttpServiceRequest::FetchResultFromHttpService, this, 0,
&wait_thread_id);
HRESULT hr = S_OK;
if (wait_thread == 0)
return result;
// Hold the handle in the scoped handle so that it can be immediately
// closed when the wait is complete allowing the thread to finish
// completely if needed.
base::win::ScopedHandle thread_handle(
reinterpret_cast<HANDLE>(wait_thread));
hr = ::WaitForSingleObject(thread_handle.Get(),
request_timeout.InMilliseconds());
// The race condition starts here. It is possible that between the expiry of
// the timeout in the call for WaitForSingleObject and the call to
// OrphanRequest, the fetching thread could have finished. So there is a two
// part handshake. Either the background thread has called ProcessingDone
// in which case it has already passed its own check for |is_orphaned_| and
// the call to OrphanRequest should delete this object right now. Otherwise
// the background thread is still running and will be able to query the
// |is_orphaned_| state and delete the object after thread completion.
if (hr != WAIT_OBJECT_0) {
LOGFN(ERROR) << "Wait for response timed out or failed hr="
<< credential_provider::putHR(hr);
OrphanRequest();
return result;
}
result = base::JSONReader::Read(
std::string_view(response_.data(), response_.size()),
base::JSON_PARSE_CHROMIUM_EXTENSIONS |
base::JSON_ALLOW_TRAILING_COMMAS);
if (!result) {
LOGFN(ERROR) << "base::JSONReader::Read returned 0";
result.reset();
} else if (!result->is_dict()) {
LOGFN(ERROR) << "json result is not a dictionary";
result.reset();
}
delete this;
return result;
}
private:
explicit HttpServiceRequest(
std::unique_ptr<credential_provider::WinHttpUrlFetcher> fetcher)
: fetcher_(std::move(fetcher)) {
DCHECK(fetcher_);
}
void OrphanRequest() {
bool delete_self = false;
{
base::AutoLock locker(orphan_lock_);
CHECK(!is_orphaned_);
if (!is_processing_) {
delete_self = true;
} else {
is_orphaned_ = true;
}
}
if (delete_self)
delete this;
}
void ProcessingDone() {
bool delete_self = false;
{
base::AutoLock locker(orphan_lock_);
CHECK(is_processing_);
if (is_orphaned_) {
delete_self = true;
} else {
is_processing_ = false;
}
}
if (delete_self)
delete this;
}
// Background thread function that is used to query the request to the
// http service. This thread never times out and simply marks the fetcher
// as finished processing when it is done.
static unsigned __stdcall FetchResultFromHttpService(void* param) {
DCHECK(param);
auto* requester = reinterpret_cast<HttpServiceRequest*>(param);
HRESULT hr = requester->fetcher_->Fetch(&requester->response_);
if (FAILED(hr))
LOGFN(ERROR) << "fetcher.Fetch hr=" << credential_provider::putHR(hr);
requester->ProcessingDone();
return 0;
}
base::Lock orphan_lock_;
std::unique_ptr<credential_provider::WinHttpUrlFetcher> fetcher_;
std::vector<char> response_;
bool is_orphaned_ = false;
bool is_processing_ = true;
};
HttpServiceRequest* HttpServiceRequest::Create(
const GURL& request_url,
const std::string& access_token,
const std::vector<std::pair<std::string, std::string>>& headers,
const std::string& request_body,
const base::TimeDelta& request_timeout) {
auto url_fetcher =
credential_provider::WinHttpUrlFetcher::Create(request_url);
if (!url_fetcher) {
LOGFN(ERROR) << "Could not create valid fetcher for url="
<< request_url.spec();
return nullptr;
}
url_fetcher->SetRequestHeader("Content-Type", "application/json");
if (!access_token.empty()) {
url_fetcher->SetRequestHeader("Authorization",
("Bearer " + access_token).c_str());
}
for (auto& header : headers)
url_fetcher->SetRequestHeader(header.first.c_str(), header.second.c_str());
if (!request_body.empty()) {
HRESULT hr = url_fetcher->SetRequestBody(request_body.c_str());
if (FAILED(hr)) {
LOGFN(ERROR) << "fetcher.SetRequestBody hr="
<< credential_provider::putHR(hr);
return nullptr;
}
}
if (!request_timeout.is_zero()) {
url_fetcher->SetHttpRequestTimeout(request_timeout.InMilliseconds());
}
return new HttpServiceRequest(std::move(url_fetcher));
}
} // namespace
namespace credential_provider {
// static
WinHttpUrlFetcher::CreatorCallback*
WinHttpUrlFetcher::GetCreatorFunctionStorage() {
static CreatorCallback creator_for_testing;
return &creator_for_testing;
}
// static
std::unique_ptr<WinHttpUrlFetcher> WinHttpUrlFetcher::Create(const GURL& url) {
return !GetCreatorFunctionStorage()->is_null()
? GetCreatorFunctionStorage()->Run(url)
: base::WrapUnique(new WinHttpUrlFetcher(url));
}
// static
void WinHttpUrlFetcher::SetCreatorForTesting(CreatorCallback creator) {
*GetCreatorFunctionStorage() = creator;
}
WinHttpUrlFetcher::WinHttpUrlFetcher(const GURL& url)
: url_(url), session_(nullptr), request_(nullptr) {
LOGFN(VERBOSE) << "url=" << url.spec() << " (scheme and port ignored)";
ScopedWinHttpHandle::Handle session = ::WinHttpOpen(
L"GaiaCP/1.0 (Windows NT)", WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY,
WINHTTP_NO_PROXY_NAME, WINHTTP_NO_PROXY_BYPASS, 0);
if (!session) {
HRESULT hr = HRESULT_FROM_WIN32(::GetLastError());
LOGFN(ERROR) << "WinHttpOpen hr=" << putHR(hr);
}
session_.Set(session);
}
WinHttpUrlFetcher::WinHttpUrlFetcher() = default;
WinHttpUrlFetcher::~WinHttpUrlFetcher() {
// Closing the session handle closes all derived handles too.
}
bool WinHttpUrlFetcher::IsValid() const {
return session_.IsValid();
}
HRESULT WinHttpUrlFetcher::SetRequestHeader(const char* name,
const char* value) {
DCHECK(name);
DCHECK(value);
// TODO(rogerta): does not support multivalued headers.
request_headers_[name] = value;
return S_OK;
}
HRESULT WinHttpUrlFetcher::SetRequestBody(const char* body) {
DCHECK(body);
body_ = body;
return S_OK;
}
HRESULT WinHttpUrlFetcher::SetHttpRequestTimeout(const int timeout_in_millis) {
DCHECK(timeout_in_millis);
timeout_in_millis_ = timeout_in_millis;
return S_OK;
}
HRESULT WinHttpUrlFetcher::Fetch(std::vector<char>* response) {
USES_CONVERSION;
DCHECK(response);
response->clear();
if (!session_.IsValid()) {
LOGFN(ERROR) << "Invalid fetcher";
return E_UNEXPECTED;
}
// Open a connection to the server.
ScopedWinHttpHandle connect;
{
std::string host = url_.host();
ScopedWinHttpHandle::Handle connect_tmp = ::WinHttpConnect(
session_.Get(), A2CW(host.c_str()), INTERNET_DEFAULT_PORT, 0);
if (!connect_tmp) {
HRESULT hr = HRESULT_FROM_WIN32(::GetLastError());
LOGFN(ERROR) << "WinHttpConnect hr=" << putHR(hr);
return hr;
}
connect.Set(connect_tmp);
}
{
// Set timeout if specified.
if (timeout_in_millis_ != 0) {
if (!::WinHttpSetTimeouts(session_.Get(), timeout_in_millis_,
timeout_in_millis_, timeout_in_millis_,
timeout_in_millis_)) {
HRESULT hr = HRESULT_FROM_WIN32(::GetLastError());
LOGFN(ERROR) << "WinHttpSetTimeouts hr=" << putHR(hr);
return hr;
}
}
}
{
bool use_post = !body_.empty();
std::string path = url_.path();
std::string path_for_request = url_.PathForRequest();
ScopedWinHttpHandle::Handle request = ::WinHttpOpenRequest(
connect.Get(), use_post ? L"POST" : L"GET",
use_post ? A2CW(path.c_str()) : A2CW(path_for_request.c_str()), nullptr,
WINHTTP_NO_REFERER, WINHTTP_DEFAULT_ACCEPT_TYPES,
WINHTTP_FLAG_REFRESH | WINHTTP_FLAG_SECURE);
if (!request) {
HRESULT hr = HRESULT_FROM_WIN32(::GetLastError());
LOGFN(ERROR) << "WinHttpOpenRequest hr=" << putHR(hr);
return hr;
}
request_.Set(request);
}
// Add request headers.
for (const auto& kv : request_headers_) {
const wchar_t* key = A2CW(kv.first.c_str());
const wchar_t* value = A2CW(kv.second.c_str());
std::wstring header = base::StrCat({key, L": ", value});
if (!::WinHttpAddRequestHeaders(
request_.Get(), header.c_str(), header.length(),
WINHTTP_ADDREQ_FLAG_ADD | WINHTTP_ADDREQ_FLAG_REPLACE)) {
HRESULT hr = HRESULT_FROM_WIN32(::GetLastError());
LOGFN(ERROR) << "WinHttpAddRequestHeaders name=" << kv.first
<< " hr=" << putHR(hr);
return hr;
}
}
// Write request body if needed.
if (!::WinHttpSendRequest(request_.Get(), WINHTTP_NO_ADDITIONAL_HEADERS, 0,
const_cast<char*>(body_.c_str()), body_.length(),
body_.length(),
reinterpret_cast<DWORD_PTR>(nullptr))) {
HRESULT hr = HRESULT_FROM_WIN32(::GetLastError());
LOGFN(ERROR) << "WinHttpSendRequest hr=" << putHR(hr);
return hr;
}
// Wait for the response.
if (!::WinHttpReceiveResponse(request_.Get(), nullptr)) {
HRESULT hr = HRESULT_FROM_WIN32(::GetLastError());
LOGFN(ERROR) << "WinHttpReceiveResponse hr=" << putHR(hr);
return hr;
}
DWORD length = 0;
if (!::WinHttpQueryDataAvailable(request_.Get(), &length)) {
HRESULT hr = HRESULT_FROM_WIN32(::GetLastError());
LOGFN(ERROR) << "WinHttpQueryDataAvailable hr=" << putHR(hr);
return hr;
}
// 256k max response size to make sure bad data does not crash GCPW.
// This fetcher is only used to retrieve small information such as token
// handle status and profile picture images so it should not need a larger
// buffer than 256k.
constexpr size_t kMaxResponseSize = 256 * 1024 * 1024;
// Read the response.
auto buffer = std::make_unique<char[]>(length);
DWORD actual = 0;
do {
if (!::WinHttpReadData(request_.Get(), buffer.get(), length, &actual)) {
HRESULT hr = HRESULT_FROM_WIN32(::GetLastError());
LOGFN(ERROR) << "WinHttpReadData hr=" << putHR(hr);
return hr;
}
size_t current_size = response->size();
response->resize(response->size() + actual);
memcpy(response->data() + current_size, buffer.get(), actual);
if (response->size() >= kMaxResponseSize) {
LOGFN(ERROR) << "Response has exceeded max size=" << kMaxResponseSize;
return E_OUTOFMEMORY;
}
} while (actual);
return S_OK;
}
HRESULT WinHttpUrlFetcher::Close() {
request_.Close();
return S_OK;
}
HRESULT WinHttpUrlFetcher::BuildRequestAndFetchResultFromHttpService(
const GURL& request_url,
std::string access_token,
const std::vector<std::pair<std::string, std::string>>& headers,
const base::Value::Dict& request_dict,
const base::TimeDelta& request_timeout,
unsigned int request_retries,
std::optional<base::Value>* request_result) {
DCHECK(request_result);
std::string request_body;
if (!request_dict.empty() &&
!base::JSONWriter::Write(request_dict, &request_body)) {
LOGFN(ERROR) << "base::JSONWriter::Write failed";
return E_FAIL;
}
if ((request_dict.empty() && !request_body.empty()) ||
(!request_dict.empty() && request_body.empty())) {
LOGFN(ERROR) << "Mismatch between request dict and body";
return E_FAIL;
}
for (unsigned int try_count = 0; try_count <= request_retries; ++try_count) {
HttpServiceRequest* request = HttpServiceRequest::Create(
request_url, access_token, headers, request_body, request_timeout);
if (!request) {
LOGFN(ERROR)
<< "Could not create an HttpServiceRequest object. request url: "
<< request_url.spec() << " request body: " << request_body;
return E_FAIL;
}
auto extracted_param =
request->WaitForResponseFromHttpService(request_timeout);
if (!extracted_param)
continue;
*request_result = std::move(extracted_param);
const base::Value::Dict* error_detail =
(*request_result)->GetDict().FindDict(kErrorKeyInRequestResult);
if (!error_detail)
return S_OK;
LOGFN(ERROR) << "error: " << *error_detail;
// If error code is known, retry only on retryable server errors.
std::optional<int> error_code =
error_detail->FindInt(kHttpErrorCodeKeyNameInResponse);
if (error_code.has_value() &&
!base::Contains(kRetryableHttpErrorCodes, error_code.value())) {
return E_FAIL;
}
}
LOGFN(ERROR) << "Unable to serve http service request";
return E_FAIL;
}
} // namespace credential_provider