// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/updater/app/app_net_worker.h"
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "base/command_line.h"
#include "base/containers/flat_map.h"
#include "base/files/file_path.h"
#include "base/functional/bind.h"
#include "base/memory/ref_counted.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/task/bind_post_task.h"
#include "base/threading/sequence_bound.h"
#include "base/threading/thread.h"
#include "chrome/updater/app/app.h"
#include "chrome/updater/constants.h"
#include "chrome/updater/net/mac/mojom/updater_fetcher.mojom.h"
#include "chrome/updater/net/network.h"
#include "chrome/updater/net/network_file_fetcher.h"
#include "chrome/updater/policy/service.h"
#include "components/update_client/network.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "mojo/public/cpp/bindings/self_owned_receiver.h"
#include "mojo/public/cpp/platform/platform_channel.h"
#include "mojo/public/cpp/system/invitation.h"
#include "mojo/public/cpp/system/message_pipe.h"
namespace updater {
namespace {
// Ref-counted holder for a `mojom::PostRequestObserver` remote. The
// constructor binds the remote to a new message pipe and hands the receiving
// end to the supplied callback; the remaining methods forward fetch events to
// the remote. Ref-counting lets several callbacks share a single remote.
//
// All methods must be called on the sequence the wrapper was created on.
class PostRequestObserverWrapper
    : public base::RefCountedThreadSafe<PostRequestObserverWrapper> {
 public:
  // Binds `observer_` to a new pipe and passes the pending receiver to
  // `callback`.
  explicit PostRequestObserverWrapper(
      mojom::FetchService::PostRequestCallback callback) {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    std::move(callback).Run(observer_.BindNewPipeAndPassReceiver());
  }

  // Forwards the HTTP status code and the content length of the response.
  void OnResponseStarted(int32_t http_status_code, int64_t content_length) {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    observer_->OnResponseStarted(http_status_code, content_length);
  }

  // Forwards a progress notification.
  void OnProgress(int64_t current) {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    observer_->OnProgress(current);
  }

  // Forwards the final result of the request. `response_body` may be null
  // when there is no body to report; an empty string is forwarded in that
  // case instead of dereferencing a null pointer.
  void OnRequestComplete(std::unique_ptr<std::string> response_body,
                         int32_t net_error,
                         const std::string& header_etag,
                         const std::string& header_x_cup_server_proof,
                         int64_t xheader_retry_after_sec) {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    const std::string empty_body;
    const std::string& body = response_body ? *response_body : empty_body;
    observer_->OnRequestComplete(body, net_error, header_etag,
                                 header_x_cup_server_proof,
                                 xheader_retry_after_sec);
  }

 private:
  friend class base::RefCountedThreadSafe<PostRequestObserverWrapper>;

  virtual ~PostRequestObserverWrapper() {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  }

  SEQUENCE_CHECKER(sequence_checker_);
  mojo::Remote<mojom::PostRequestObserver> observer_;
};
// Ref-counted holder for a `mojom::FileDownloadObserver` remote. The
// constructor binds the remote to a new message pipe and hands the receiving
// end to the supplied callback; the remaining methods forward download events
// to the remote. Ref-counting lets several callbacks share a single remote.
//
// All methods must be called on the sequence the wrapper was created on.
class FileDownloadObserverWrapper
    : public base::RefCountedThreadSafe<FileDownloadObserverWrapper> {
 public:
  // Binds `observer_` to a new pipe and passes the pending receiver to
  // `callback`.
  explicit FileDownloadObserverWrapper(
      mojom::FetchService::DownloadToFileCallback callback) {
    // Mirrors the sequence check done by `PostRequestObserverWrapper`.
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    std::move(callback).Run(observer_.BindNewPipeAndPassReceiver());
  }

  // Forwards the HTTP status code and the content length of the response.
  void OnResponseStarted(int32_t http_status_code, int64_t content_length) {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    observer_->OnResponseStarted(http_status_code, content_length);
  }

  // Forwards a progress notification.
  void OnProgress(int64_t current) {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    observer_->OnProgress(current);
  }

  // Forwards the final result of the download.
  void OnDownloadComplete(int32_t net_error, int64_t content_length) {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    observer_->OnDownloadComplete(net_error, content_length);
  }

 private:
  friend class base::RefCountedThreadSafe<FileDownloadObserverWrapper>;

  virtual ~FileDownloadObserverWrapper() {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  }

  SEQUENCE_CHECKER(sequence_checker_);
  mojo::Remote<mojom::FileDownloadObserver> observer_;
};
// The stub class that translates and forwards the Mojo requests to the
// underlying fetcher. It also keeps a reference to the remote receiver and
// sends the result back when fetch is done.
class FetchServiceImpl : public mojom::FetchService {
public:
FetchServiceImpl(mojo::PendingReceiver<mojom::FetchService> pending_receiver,
base::OnceCallback<void(int)> on_complete_callback);
~FetchServiceImpl() override {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}
// Overrides for mojom::FetchService.
void PostRequest(const ::GURL& url,
const std::string& post_data,
const std::string& content_type,
std::vector<mojom::HttpHeaderPtr> additional_headers,
mojom::FetchService::PostRequestCallback callback) override;
void DownloadToFile(
const ::GURL& url,
::base::File output_file,
mojom::FetchService::DownloadToFileCallback callback) override;
private:
SEQUENCE_CHECKER(sequence_checker_);
mojo::Receiver<mojom::FetchService> receiver_;
base::OnceCallback<void(int)> on_complete_callback_;
// Network fetcher for POST request.
std::unique_ptr<update_client::NetworkFetcher> fetcher_;
// For file download, `update_client::NetworkFetcher` interface takes
// a `base::FilePath` as the output, and the Mojo interface takes a
// `base::File` object. This customized fetcher is used to support
// the Mojo interface.
std::unique_ptr<NetworkFileFetcher> file_fetcher_;
};
FetchServiceImpl::FetchServiceImpl(
mojo::PendingReceiver<mojom::FetchService> pending_receiver,
base::OnceCallback<void(int)> on_complete_callback)
: receiver_(this, std::move(pending_receiver)),
on_complete_callback_(std::move(on_complete_callback)) {}
// Handles a POST request received over Mojo.
//
// An observer remote is created from `callback` and used to report response
// start, progress and completion. Each service instance serves only one fetch:
// if a fetcher already exists, the request is rejected with
// `kErrorMojoRequestRejected`. Otherwise `additional_headers` are converted
// to a map and the request is issued through a newly created network fetcher;
// on completion the result is forwarded to the observer and the stored
// completion callback is run with the network error code.
void FetchServiceImpl::PostRequest(
    const ::GURL& url,
    const std::string& post_data,
    const std::string& content_type,
    std::vector<mojom::HttpHeaderPtr> additional_headers,
    mojom::FetchService::PostRequestCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  auto wrapper =
      base::MakeRefCounted<PostRequestObserverWrapper>(std::move(callback));
  if (fetcher_ || file_fetcher_) {
    LOG(ERROR) << "Each service instance can do only one fetch request.";
    // Pass an empty body instead of nullptr so the observer wrapper never has
    // to dereference a null pointer.
    wrapper->OnRequestComplete(std::make_unique<std::string>(),
                               kErrorMojoRequestRejected, {}, {}, -1);
    // The request that created the existing fetcher moved
    // `on_complete_callback_` into its completion handler, so it is null
    // here. Running a null callback is an error; only run it if still set.
    if (on_complete_callback_) {
      std::move(on_complete_callback_).Run(kErrorMojoRequestRejected);
    }
    return;
  }

  base::flat_map<std::string, std::string> headers;
  for (const auto& header : additional_headers) {
    headers.emplace(header->name, header->value);
  }

  // Creates a network fetcher without any proxy configuration (let the system
  // handle the proxy settings) to fetch data.
  fetcher_ =
      base::MakeRefCounted<NetworkFetcherFactory>(std::nullopt)->Create();
  fetcher_->PostRequest(
      url, post_data, content_type, headers,
      base::BindRepeating(&PostRequestObserverWrapper::OnResponseStarted,
                          wrapper),
      base::BindRepeating(&PostRequestObserverWrapper::OnProgress, wrapper),
      base::BindOnce(
          [](scoped_refptr<PostRequestObserverWrapper> wrapper,
             base::OnceCallback<void(int)> callback,
             std::unique_ptr<std::string> response_body, int32_t net_error,
             const std::string& header_etag,
             const std::string& header_x_cup_server_proof,
             int64_t xheader_retry_after_sec) {
            // Normalize a missing body to an empty string before forwarding.
            if (!response_body) {
              response_body = std::make_unique<std::string>();
            }
            wrapper->OnRequestComplete(std::move(response_body), net_error,
                                       header_etag, header_x_cup_server_proof,
                                       xheader_retry_after_sec);
            std::move(callback).Run(net_error);
          },
          wrapper, std::move(on_complete_callback_)));
}
// Handles a file download request received over Mojo.
//
// An observer remote is created from `callback` and used to report response
// start, progress and completion. Each service instance serves only one fetch:
// if a fetcher already exists, the request is rejected with
// `kErrorMojoRequestRejected`. Otherwise the download into `output_file` is
// started through a new `NetworkFileFetcher`; on completion the result is
// forwarded to the observer and the stored completion callback is run with
// the network error code.
void FetchServiceImpl::DownloadToFile(
    const ::GURL& url,
    ::base::File output_file,
    mojom::FetchService::DownloadToFileCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  auto wrapper =
      base::MakeRefCounted<FileDownloadObserverWrapper>(std::move(callback));
  if (fetcher_ || file_fetcher_) {
    LOG(ERROR) << "Each service instance can do only one fetch request.";
    wrapper->OnDownloadComplete(kErrorMojoRequestRejected, -1);
    // The request that created the existing fetcher moved
    // `on_complete_callback_` into its completion handler, so it is null
    // here. Running a null callback is an error; only run it if still set.
    if (on_complete_callback_) {
      std::move(on_complete_callback_).Run(kErrorMojoRequestRejected);
    }
    return;
  }

  file_fetcher_ = std::make_unique<NetworkFileFetcher>();
  file_fetcher_->Download(
      url, std::move(output_file),
      base::BindRepeating(&FileDownloadObserverWrapper::OnResponseStarted,
                          wrapper),
      base::BindRepeating(&FileDownloadObserverWrapper::OnProgress, wrapper),
      base::BindOnce(
          [](scoped_refptr<FileDownloadObserverWrapper> wrapper,
             base::OnceCallback<void(int)> callback, int32_t net_error,
             int64_t content_length) {
            wrapper->OnDownloadComplete(net_error, content_length);
            std::move(callback).Run(net_error);
          },
          wrapper, std::move(on_complete_callback_)));
}
// AppNetWorker runs networking tasks in a dedicated process.
class AppNetWorker : public App {
public:
AppNetWorker() {
net_thread_.StartWithOptions({base::MessagePumpType::IO, 0});
}
private:
~AppNetWorker() override = default;
void FirstTaskRun() override {
// This process must be started with the command line switch
/// `--mojo-platform-channel-handle=N`. In other words, the command line
// must be prepared by
// `mojo::PlatformChannel::PrepareToPassRemoteEndpoint()`.
mojo::PlatformChannelEndpoint endpoint =
mojo::PlatformChannel::RecoverPassedEndpointFromCommandLine(
*base::CommandLine::ForCurrentProcess());
if (!endpoint.is_valid()) {
Shutdown(kErrorMojoConnectionFailure);
return;
}
mojo::ScopedMessagePipeHandle pipe =
mojo::IncomingInvitation::AcceptIsolated(std::move(endpoint));
if (!pipe->is_valid()) {
Shutdown(kErrorMojoConnectionFailure);
return;
}
fetcher_stub_ = base::SequenceBound<FetchServiceImpl>(
net_thread_.task_runner(),
mojo::PendingReceiver<mojom::FetchService>(std::move(pipe)),
base::BindPostTaskToCurrentDefault(base::BindOnce(
&AppNetWorker::Shutdown, weak_ptr_factory_.GetWeakPtr())));
// TODO(crbug.com/353751917): Add a timer that shutdown this process if
// no incoming network requests in time.
}
base::Thread net_thread_{"Network"};
base::SequenceBound<FetchServiceImpl> fetcher_stub_;
base::WeakPtrFactory<AppNetWorker> weak_ptr_factory_{this};
};
} // namespace
// Creates the App instance that runs the network worker.
scoped_refptr<App> MakeAppNetWorker() {
  auto net_worker = base::MakeRefCounted<AppNetWorker>();
  return net_worker;
}
} // namespace updater