chromium/chrome/updater/ipc/update_service_internal_proxy_win.cc

// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/updater/ipc/update_service_internal_proxy_win.h"

#include <windows.h>

#include <wrl/client.h>
#include <wrl/implements.h>

#include <ios>
#include <optional>
#include <utility>

#include "base/check_op.h"
#include "base/functional/callback.h"
#include "base/logging.h"
#include "base/memory/scoped_refptr.h"
#include "base/sequence_checker.h"
#include "base/task/bind_post_task.h"
#include "chrome/updater/app/server/win/updater_internal_idl.h"
#include "chrome/updater/ipc/proxy_impl_base_win.h"
#include "chrome/updater/ipc/update_service_internal_proxy.h"
#include "chrome/updater/updater_scope.h"
#include "chrome/updater/util/win_util.h"
#include "chrome/updater/win/setup/setup_util.h"
#include "chrome/updater/win/win_constants.h"

namespace updater {
namespace {

// COM callback object passed to the server for `IUpdaterInternal` calls. It
// keeps the caller's completion callback until that callback is either taken
// back through `Disconnect()` or run by the destructor.
class UpdaterInternalCallback
    : public DYNAMICIIDSIMPL(IUpdaterInternalCallback) {
 public:
  explicit UpdaterInternalCallback(
      base::OnceCallback<void(std::optional<RpcError>)> callback)
      : callback_(std::move(callback)) {}
  UpdaterInternalCallback(const UpdaterInternalCallback&) = delete;
  UpdaterInternalCallback& operator=(const UpdaterInternalCallback&) = delete;

  // IUpdaterInternalCallback implementation. COM RPC invokes this on a system
  // thread.
  IFACEMETHODIMP Run(LONG result) override;

  // Detaches the completion callback from this object and returns it, so the
  // owner regains ownership of the callback and this object no longer holds
  // it.
  base::OnceCallback<void(std::optional<RpcError>)> Disconnect();

 private:
  // Runs the completion callback with `std::nullopt` if this object still
  // holds it at destruction time.
  ~UpdaterInternalCallback() override {
    if (!callback_) {
      return;
    }
    std::move(callback_).Run(std::nullopt);
  }

  // Completion callback supplied at construction. It is run by the destructor
  // unless `Disconnect()` has moved it out first.
  base::OnceCallback<void(std::optional<RpcError>)> callback_;
};

// IUpdaterInternalCallback override. Logs `result` at verbose level 2 and
// returns S_OK. The completion callback is not run here; the destructor of
// this class runs it if it has not been taken back through `Disconnect()`.
IFACEMETHODIMP UpdaterInternalCallback::Run(LONG result) {
  VLOG(2) << __func__ << " result " << result << ".";
  return S_OK;
}

// Moves the stored completion callback out of this object and returns it to
// the caller. The destructor only runs `callback_` when it is still set, so
// the caller becomes responsible for running the returned callback.
base::OnceCallback<void(std::optional<RpcError>)>
UpdaterInternalCallback::Disconnect() {
  VLOG(2) << __func__;
  base::OnceCallback<void(std::optional<RpcError>)> reclaimed =
      std::move(callback_);
  return reclaimed;
}

}  // namespace

// Client-side implementation of the `IUpdaterInternal` COM calls. Each public
// method hands its work to `PostRPCTask` (provided by `ProxyImplBase`); the
// corresponding `*OnTaskRunner` method then connects to the server and issues
// the COM call.
class UpdateServiceInternalProxyImplImpl
    : public base::RefCountedThreadSafe<UpdateServiceInternalProxyImplImpl>,
      public ProxyImplBase<UpdateServiceInternalProxyImplImpl,
                           IUpdaterInternal,
                           __uuidof(IUpdaterInternalUser),
                           __uuidof(IUpdaterInternalSystem)> {
 public:
  explicit UpdateServiceInternalProxyImplImpl(UpdaterScope scope)
      : ProxyImplBase(scope) {}

  // Returns the GUID of the COM server class to use for `scope`: the system
  // class for system installs, the user class otherwise.
  static auto GetClassGuid(UpdaterScope scope) {
    return IsSystemInstall(scope) ? __uuidof(UpdaterInternalSystemClass)
                                  : __uuidof(UpdaterInternalUserClass);
  }

  // Schedules `IUpdaterInternal::Run`. `callback` receives the failing
  // HRESULT if connecting or issuing the call fails.
  void Run(base::OnceCallback<void(std::optional<RpcError>)> callback) {
    PostRPCTask(
        base::BindOnce(&UpdateServiceInternalProxyImplImpl::RunOnTaskRunner,
                       this, std::move(callback)));
  }

  // Schedules `IUpdaterInternal::Hello`. `callback` receives the failing
  // HRESULT if connecting or issuing the call fails.
  void Hello(base::OnceCallback<void(std::optional<RpcError>)> callback) {
    PostRPCTask(
        base::BindOnce(&UpdateServiceInternalProxyImplImpl::HelloOnTaskRunner,
                       this, std::move(callback)));
  }

 private:
  friend class base::RefCountedThreadSafe<UpdateServiceInternalProxyImplImpl>;
  ~UpdateServiceInternalProxyImplImpl() = default;

  // Connects to the server and calls `IUpdaterInternal::Run`, passing a
  // `UpdaterInternalCallback` that owns `callback`.
  void RunOnTaskRunner(
      base::OnceCallback<void(std::optional<RpcError>)> callback) {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    if (HRESULT connection = ConnectToServer(); FAILED(connection)) {
      std::move(callback).Run(connection);
      return;
    }
    auto callback_wrapper =
        MakeComObjectOrCrash<UpdaterInternalCallback>(std::move(callback));
    HRESULT hr = get_interface()->Run(callback_wrapper.Get());
    if (FAILED(hr)) {
      VLOG(2) << "Failed to call IUpdaterInternal::Run: 0x" << std::hex << hr;
      // Take the callback back from the wrapper so that the failure is
      // reported here with `hr`, rather than by the wrapper's destructor.
      callback_wrapper->Disconnect().Run(hr);
    }
  }

  // Connects to the server and calls `IUpdaterInternal::Hello`, passing a
  // `UpdaterInternalCallback` that owns `callback`.
  void HelloOnTaskRunner(
      base::OnceCallback<void(std::optional<RpcError>)> callback) {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    if (HRESULT connection = ConnectToServer(); FAILED(connection)) {
      std::move(callback).Run(connection);
      return;
    }
    auto callback_wrapper =
        MakeComObjectOrCrash<UpdaterInternalCallback>(std::move(callback));
    HRESULT hr = get_interface()->Hello(callback_wrapper.Get());
    if (FAILED(hr)) {
      VLOG(2) << "Failed to call IUpdaterInternal::Hello: 0x" << std::hex
              << hr;
      // Take the callback back from the wrapper so that the failure is
      // reported here with `hr`, rather than by the wrapper's destructor.
      callback_wrapper->Disconnect().Run(hr);
    }
  }
};

// Creates the ref-counted `UpdateServiceInternalProxyImplImpl` for `scope`
// and stores it in `impl_`.
UpdateServiceInternalProxyImpl::UpdateServiceInternalProxyImpl(
    UpdaterScope scope)
    : impl_(base::MakeRefCounted<UpdateServiceInternalProxyImplImpl>(scope)) {}

// Gives up this object's reference to the implementation by moving `impl_`
// into `UpdateServiceInternalProxyImplImpl::Destroy`.
UpdateServiceInternalProxyImpl::~UpdateServiceInternalProxyImpl() {
  VLOG(1) << __func__;
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  UpdateServiceInternalProxyImplImpl::Destroy(std::move(impl_));
}

// Forwards `Run` to the implementation. `callback` is first wrapped with
// `base::BindPostTaskToCurrentDefault`, so it is posted back to the caller's
// current default task runner instead of running wherever the implementation
// invokes it.
void UpdateServiceInternalProxyImpl::Run(
    base::OnceCallback<void(std::optional<RpcError>)> callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  VLOG(1) << __func__;
  auto reply_on_caller_sequence =
      base::BindPostTaskToCurrentDefault(std::move(callback));
  impl_->Run(std::move(reply_on_caller_sequence));
}

// Forwards `Hello` to the implementation. `callback` is first wrapped with
// `base::BindPostTaskToCurrentDefault`, so it is posted back to the caller's
// current default task runner instead of running wherever the implementation
// invokes it.
void UpdateServiceInternalProxyImpl::Hello(
    base::OnceCallback<void(std::optional<RpcError>)> callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  VLOG(1) << __func__;
  auto reply_on_caller_sequence =
      base::BindPostTaskToCurrentDefault(std::move(callback));
  impl_->Hello(std::move(reply_on_caller_sequence));
}

// Builds an `UpdateServiceInternalProxyImpl` for `updater_scope` and returns
// an `UpdateServiceInternalProxy` constructed from it.
scoped_refptr<UpdateServiceInternal> CreateUpdateServiceInternalProxy(
    UpdaterScope updater_scope) {
  auto proxy_impl =
      base::MakeRefCounted<UpdateServiceInternalProxyImpl>(updater_scope);
  return base::MakeRefCounted<UpdateServiceInternalProxy>(
      std::move(proxy_impl));
}

}  // namespace updater