chromium/chromeos/services/machine_learning/cpp/ash/service_connection_ash.cc

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/services/machine_learning/public/cpp/service_connection.h"

#include <utility>

#include "base/component_export.h"
#include "base/functional/bind.h"
#include "base/no_destructor.h"
#include "base/sequence_checker.h"
#include "base/task/sequenced_task_runner.h"
#include "chromeos/dbus/machine_learning/machine_learning_client.h"
#include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
#include "mojo/core/embedder/embedder.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "mojo/public/cpp/platform/platform_channel.h"
#include "mojo/public/cpp/system/invitation.h"
#include "third_party/cros_system_api/dbus/service_constants.h"

namespace ash {
namespace machine_learning {

namespace {

// Real (non-fake) implementation of
// chromeos::machine_learning::ServiceConnection used in Ash. It owns a single
// mojo::Remote to the ML Service daemon's top-level MachineLearningService
// interface. The remote is bound lazily on first use and is reset on Mojo
// disconnect or on D-Bus bootstrap failure, so the next use rebinds it.
class COMPONENT_EXPORT(CHROMEOS_MLSERVICE) ServiceConnectionAsh
    : public chromeos::machine_learning::ServiceConnection {
 public:
  ServiceConnectionAsh();
  // Not copyable or assignable.
  ServiceConnectionAsh(const ServiceConnectionAsh&) = delete;
  ServiceConnectionAsh& operator=(const ServiceConnectionAsh&) = delete;

  ~ServiceConnectionAsh() override;

  // Returns the top-level MachineLearningService interface, binding it to the
  // daemon first if it is not currently bound. Initialize() must have been
  // called, and the call must be made on the sequence tracked by
  // |sequence_checker_| (both enforced by DCHECK).
  chromeos::machine_learning::mojom::MachineLearningService&
  GetMachineLearningService() override;

  // Binds |receiver| by calling Clone() on the top-level interface. May be
  // called from any sequence: when not already running on |task_runner_|, the
  // call re-posts itself there. Initialize() must have been called.
  void BindMachineLearningService(
      mojo::PendingReceiver<
          chromeos::machine_learning::mojom::MachineLearningService> receiver)
      override;

  // Records the current default sequenced task runner in |task_runner_|. Must
  // be called exactly once (enforced by DCHECK).
  void Initialize() override;

 private:
  // Binds the primordial, top-level interface |machine_learning_service_| to an
  // implementation in the ML Service daemon, if it is not already bound. The
  // binding is accomplished via D-Bus bootstrap.
  void BindPrimordialMachineLearningServiceIfNeeded();

  // Mojo disconnect handler. Resets |machine_learning_service_|, which
  // will be reconnected upon next use.
  void OnMojoDisconnect();

  // Response callback for MlClient::BootstrapMojoConnection. Resets
  // |machine_learning_service_| when |success| is false.
  void OnBootstrapMojoConnectionResponse(bool success);

  // Top-level interface to the ML Service daemon. Unbound until first use and
  // after any reset.
  mojo::Remote<chromeos::machine_learning::mojom::MachineLearningService>
      machine_learning_service_;
  // Task runner captured by Initialize(); null until then.
  scoped_refptr<base::SequencedTaskRunner> task_runner_;

  // Detached in the constructor, so it is not tied to the constructing
  // sequence.
  SEQUENCE_CHECKER(sequence_checker_);
};

// Constructs an unbound connection with no task runner; Initialize() must be
// called before the connection is used.
ServiceConnectionAsh::ServiceConnectionAsh() {
  // Do not tie the checker to the constructing sequence.
  // NOTE(review): presumably the instance may be created on a different
  // sequence from the one that later calls Initialize() -- confirm.
  DETACH_FROM_SEQUENCE(sequence_checker_);
}

// No explicit teardown: the members clean up after themselves. Note that the
// instance handed out by ServiceConnection::CreateRealInstance() is stored in
// a base::NoDestructor.
ServiceConnectionAsh::~ServiceConnectionAsh() = default;

// Hands out the top-level MachineLearningService proxy, (re)establishing the
// connection to the daemon first when the remote is not currently bound.
// Requires a prior Initialize() and must run on the checked sequence.
chromeos::machine_learning::mojom::MachineLearningService&
ServiceConnectionAsh::GetMachineLearningService() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(task_runner_)
      << "Call Initialize before first use of ServiceConnection.";

  // Lazily bind so callers never observe an unbound remote.
  BindPrimordialMachineLearningServiceIfNeeded();

  auto* const service_proxy = machine_learning_service_.get();
  return *service_proxy;
}

// Connects |receiver| to the ML Service by cloning the top-level interface.
// Safe to invoke from any sequence: off-sequence calls are bounced to
// |task_runner_|, where the same method runs again and takes the direct path.
void ServiceConnectionAsh::BindMachineLearningService(
    mojo::PendingReceiver<
        chromeos::machine_learning::mojom::MachineLearningService> receiver) {
  DCHECK(task_runner_)
      << "Call Initialize before first use of ServiceConnection.";

  // Fast path: already on the owning sequence, so use the remote directly.
  if (task_runner_->RunsTasksInCurrentSequence()) {
    GetMachineLearningService().Clone(std::move(receiver));
    return;
  }

  // Otherwise hop to the owning sequence and retry there. |this| is passed
  // unretained, so the object must outlive the posted task.
  auto rebind_on_owning_sequence =
      base::BindOnce(&ServiceConnectionAsh::BindMachineLearningService,
                     base::Unretained(this), std::move(receiver));
  task_runner_->PostTask(FROM_HERE, std::move(rebind_on_owning_sequence));
}

// Captures the current default sequenced task runner in |task_runner_|.
// BindMachineLearningService() re-posts itself to this runner when it is
// invoked from another sequence. Must be called exactly once; a second call
// trips the DCHECK below.
void ServiceConnectionAsh::Initialize() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(!task_runner_) << "Initialize must be called only once.";

  task_runner_ = base::SequencedTaskRunner::GetCurrentDefault();
}

// Ensures |machine_learning_service_| is bound. If it is not, this creates a
// platform channel, sends a Mojo invitation carrying one message pipe over the
// local end, binds the remote to that pipe, and passes the file descriptor of
// the remote end to the ML Service daemon via D-Bus. The remote is bound
// before the D-Bus response arrives; a failed response resets it in
// OnBootstrapMojoConnectionResponse().
void ServiceConnectionAsh::BindPrimordialMachineLearningServiceIfNeeded() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // Already bound: nothing to do.
  if (machine_learning_service_) {
    return;
  }

  mojo::PlatformChannel platform_channel;

  // Prepare a Mojo invitation to send through |platform_channel|.
  mojo::OutgoingInvitation invitation;
  // Include an initial Mojo pipe in the invitation.
  mojo::ScopedMessagePipeHandle pipe =
      invitation.AttachMessagePipe(ml::kBootstrapMojoConnectionChannelToken);
  if (mojo::core::IsMojoIpczEnabled()) {
    // IPCz requires an application to explicitly opt in to broker sharing
    // and inheritance when establishing a direct connection between two
    // non-broker nodes.
    invitation.set_extra_flags(MOJO_SEND_INVITATION_FLAG_SHARE_BROKER);
  }
  // NOTE(review): no target process handle is supplied (kNullProcessHandle);
  // presumably because the daemon is not launched by this process -- confirm.
  mojo::OutgoingInvitation::Send(std::move(invitation),
                                 base::kNullProcessHandle,
                                 platform_channel.TakeLocalEndpoint());

  // Bind our end of |pipe| to our mojo::Remote<MachineLearningService>. The
  // daemon should bind its end to a MachineLearningService implementation.
  machine_learning_service_.Bind(
      mojo::PendingRemote<
          chromeos::machine_learning::mojom::MachineLearningService>(
          std::move(pipe), 0u /* version */));
  // |this| is passed unretained to the handler; the remote that holds the
  // handler is a member of |this|.
  machine_learning_service_.set_disconnect_handler(base::BindOnce(
      &ServiceConnectionAsh::OnMojoDisconnect, base::Unretained(this)));

  // Send the file descriptor for the other end of |platform_channel| to the
  // ML service daemon over D-Bus.
  // |this| is passed unretained to the response callback, so the object must
  // outlive the pending D-Bus call. The instance created by
  // ServiceConnection::CreateRealInstance() is held in a base::NoDestructor.
  chromeos::MachineLearningClient::Get()->BootstrapMojoConnection(
      platform_channel.TakeRemoteEndpoint().TakePlatformHandle().TakeFD(),
      base::BindOnce(&ServiceConnectionAsh::OnBootstrapMojoConnectionResponse,
                     base::Unretained(this)));
}

// Disconnect handler registered on |machine_learning_service_| in
// BindPrimordialMachineLearningServiceIfNeeded(). Logs a warning and resets
// the remote so that the next GetMachineLearningService() call binds a new
// connection.
void ServiceConnectionAsh::OnMojoDisconnect() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // Connection errors are not expected so log a warning.
  LOG(WARNING) << "ML Service Mojo connection closed";
  machine_learning_service_.reset();
}

// Receives the outcome of the BootstrapMojoConnection D-Bus call. A successful
// bootstrap needs no further action; on failure the remote is reset so that a
// later use attempts a fresh bootstrap.
void ServiceConnectionAsh::OnBootstrapMojoConnectionResponse(
    const bool success) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

  if (success) {
    return;
  }

  LOG(WARNING) << "BootstrapMojoConnection D-Bus call failed";
  machine_learning_service_.reset();
}

}  // namespace

}  // namespace machine_learning
}  // namespace ash

namespace chromeos {
namespace machine_learning {

// Returns the process-wide Ash implementation of ServiceConnection. The
// instance is a function-local static held in base::NoDestructor, created on
// the first call; every call returns the same pointer.
ServiceConnection* ServiceConnection::CreateRealInstance() {
  using RealConnection = ash::machine_learning::ServiceConnectionAsh;

  static base::NoDestructor<RealConnection> real_instance;
  return real_instance.get();
}

}  // namespace machine_learning
}  // namespace chromeos