chromium/chromeos/ash/components/dbus/kerberos/kerberos_client.cc

// Copyright 2019 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chromeos/ash/components/dbus/kerberos/kerberos_client.h"

#include <optional>
#include <utility>

#include "base/callback_list.h"
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/task/single_thread_task_runner.h"
#include "chromeos/ash/components/dbus/kerberos/fake_kerberos_client.h"
#include "dbus/bus.h"
#include "dbus/message.h"
#include "dbus/object_proxy.h"
#include "third_party/cros_system_api/dbus/kerberos/dbus-constants.h"

namespace ash {
namespace {

// The single live KerberosClient instance. It is set by the KerberosClient
// constructor, cleared by its destructor, and returned by KerberosClient::Get().
KerberosClient* g_instance = nullptr;

// Decodes the first argument of |response| (a byte array) into |proto|.
// A null |response| means the D-Bus call itself failed. Both that case and an
// undecodable payload are logged and reported by returning false. Returns
// true only when |proto| was filled successfully.
bool ParseProto(dbus::Response* response,
                google::protobuf::MessageLite* proto) {
  if (response == nullptr) {
    LOG(ERROR) << "Failed to call kerberosd";
    return false;
  }

  dbus::MessageReader response_reader(response);
  const bool parsed = response_reader.PopArrayOfBytesAsProto(proto);
  if (!parsed)
    LOG(ERROR) << "Failed to parse response message from kerberosd";
  return parsed;
}

void OnSignalConnected(const std::string& interface_name,
                       const std::string& signal_name,
                       bool success) {
  DCHECK_EQ(interface_name, kerberos::kKerberosInterface);
  DCHECK(success);
}

// "Real" implementation of KerberosClient talking to the Kerberos daemon on
// the ChromeOS side.
class KerberosClientImpl : public KerberosClient {
 public:
  KerberosClientImpl() = default;

  KerberosClientImpl(const KerberosClientImpl&) = delete;
  KerberosClientImpl& operator=(const KerberosClientImpl&) = delete;

  ~KerberosClientImpl() override = default;

  // KerberosClient overrides:
  void AddAccount(const kerberos::AddAccountRequest& request,
                  AddAccountCallback callback) override {
    CallProtoMethod(kerberos::kAddAccountMethod, request, std::move(callback));
  }

  void RemoveAccount(const kerberos::RemoveAccountRequest& request,
                     RemoveAccountCallback callback) override {
    CallProtoMethod(kerberos::kRemoveAccountMethod, request,
                    std::move(callback));
  }

  void ClearAccounts(const kerberos::ClearAccountsRequest& request,
                     ClearAccountsCallback callback) override {
    CallProtoMethod(kerberos::kClearAccountsMethod, request,
                    std::move(callback));
  }

  void ListAccounts(const kerberos::ListAccountsRequest& request,
                    ListAccountsCallback callback) override {
    CallProtoMethod(kerberos::kListAccountsMethod, request,
                    std::move(callback));
  }

  void SetConfig(const kerberos::SetConfigRequest& request,
                 SetConfigCallback callback) override {
    CallProtoMethod(kerberos::kSetConfigMethod, request, std::move(callback));
  }

  void ValidateConfig(const kerberos::ValidateConfigRequest& request,
                      ValidateConfigCallback callback) override {
    CallProtoMethod(kerberos::kValidateConfigMethod, request,
                    std::move(callback));
  }

  void AcquireKerberosTgt(const kerberos::AcquireKerberosTgtRequest& request,
                          int password_fd,
                          AcquireKerberosTgtCallback callback) override {
    // kAcquireKerberosTgtMethod takes |password_fd| as extra arg.
    CallProtoMethodWithExtraArgs(
        kerberos::kAcquireKerberosTgtMethod, request, std::move(callback),
        base::BindOnce(
            [](int in_password_fd, dbus::MessageWriter* writer) {
              writer->AppendFileDescriptor(in_password_fd);
            },
            password_fd));
  }

  void GetKerberosFiles(const kerberos::GetKerberosFilesRequest& request,
                        GetKerberosFilesCallback callback) override {
    CallProtoMethod(kerberos::kGetKerberosFilesMethod, request,
                    std::move(callback));
  }

  // Registers |callback| to run with the principal name of every
  // KerberosFilesChanged signal until the returned subscription is destroyed.
  // The D-Bus handler is connected only on the first subscription.
  // dbus::ObjectProxy keeps one handler per ConnectToSignal() call, so
  // connecting once per subscriber would notify every subscriber once per
  // connection, i.e. multiple times per signal.
  base::CallbackListSubscription SubscribeToKerberosFileChangedSignal(
      KerberosFilesChangedCallback callback) override {
    if (!files_changed_signal_connected_) {
      files_changed_signal_connected_ = true;
      proxy_->ConnectToSignal(
          kerberos::kKerberosInterface, kerberos::kKerberosFilesChangedSignal,
          base::BindRepeating(&KerberosClientImpl::OnKerberosFilesChanged,
                              weak_factory_.GetWeakPtr()),
          base::BindOnce(&OnSignalConnected));
    }

    return kerberos_files_changed_callback_list_.Add(std::move(callback));
  }

  // Same as SubscribeToKerberosFileChangedSignal(), but for the
  // KerberosTicketExpiring signal.
  base::CallbackListSubscription SubscribeToKerberosTicketExpiringSignal(
      KerberosTicketExpiringCallback callback) override {
    if (!ticket_expiring_signal_connected_) {
      ticket_expiring_signal_connected_ = true;
      proxy_->ConnectToSignal(
          kerberos::kKerberosInterface, kerberos::kKerberosTicketExpiringSignal,
          base::BindRepeating(&KerberosClientImpl::OnKerberosTicketExpiring,
                              weak_factory_.GetWeakPtr()),
          base::BindOnce(&OnSignalConnected));
    }

    return kerberos_ticket_expiring_callback_list_.Add(std::move(callback));
  }

  // Handles the KerberosFilesChanged signal. It reads the principal name
  // argument and forwards it to all current subscribers. The D-Bus handler
  // stays connected after the last subscription is destroyed, so an empty
  // callback list is legitimate here. Notify() is then a no-op.
  void OnKerberosFilesChanged(dbus::Signal* signal) {
    DCHECK_EQ(signal->GetInterface(), kerberos::kKerberosInterface);
    DCHECK_EQ(signal->GetMember(), kerberos::kKerberosFilesChangedSignal);

    dbus::MessageReader signal_reader(signal);
    std::string principal_name;
    if (!signal_reader.PopString(&principal_name)) {
      LOG(ERROR)
          << "Failed to read principal name for KerberosFilesChanged signal";
      return;
    }

    kerberos_files_changed_callback_list_.Notify(principal_name);
  }

  // Handles the KerberosTicketExpiring signal. The same contract as
  // OnKerberosFilesChanged() applies, including that the callback list may be
  // empty.
  void OnKerberosTicketExpiring(dbus::Signal* signal) {
    DCHECK_EQ(signal->GetInterface(), kerberos::kKerberosInterface);
    DCHECK_EQ(signal->GetMember(), kerberos::kKerberosTicketExpiringSignal);

    dbus::MessageReader signal_reader(signal);
    std::string principal_name;
    if (!signal_reader.PopString(&principal_name)) {
      LOG(ERROR)
          << "Failed to read principal name for KerberosTicketExpiring signal";
      return;
    }

    kerberos_ticket_expiring_callback_list_.Notify(principal_name);
  }

  // Resolves the kerberosd object proxy on |bus|. This must run before any
  // other method, since they all dereference |proxy_|.
  void Init(dbus::Bus* bus) {
    proxy_ =
        bus->GetObjectProxy(kerberos::kKerberosServiceName,
                            dbus::ObjectPath(kerberos::kKerberosServicePath));
  }

 private:
  using KerberosFilesChangedCallbackList =
      base::RepeatingCallbackList<PrincipalNameFunc>;
  using KerberosTicketExpiringCallbackList =
      base::RepeatingCallbackList<PrincipalNameFunc>;

  TestInterface* GetTestInterface() override { return nullptr; }

  // Calls kerberosd's |method_name| method, passing in |request| as input. Once
  // the (asynchronous) call finishes, |callback| is called with the response
  // proto (on the same thread as this call). If |request| cannot be
  // serialized, |callback| is posted with an ERROR_DBUS_FAILURE response
  // instead of being run synchronously. |write_extra_args|, if set, appends
  // further arguments after the serialized request.
  template <class TRequest, class TResponse>
  void CallProtoMethodWithExtraArgs(
      const char* method_name,
      const TRequest& request,
      base::OnceCallback<void(const TResponse&)> callback,
      base::OnceCallback<void(dbus::MessageWriter*)> write_extra_args) {
    dbus::MethodCall method_call(kerberos::kKerberosInterface, method_name);
    dbus::MessageWriter writer(&method_call);
    if (!writer.AppendProtoAsArrayOfBytes(request)) {
      TResponse response;
      response.set_error(kerberos::ERROR_DBUS_FAILURE);
      base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
          FROM_HERE, base::BindOnce(std::move(callback), std::move(response)));
      return;
    }
    if (write_extra_args)
      std::move(write_extra_args).Run(&writer);
    proxy_->CallMethod(
        &method_call, dbus::ObjectProxy::TIMEOUT_USE_DEFAULT,
        base::BindOnce(&KerberosClientImpl::HandleResponse<TResponse>,
                       weak_factory_.GetWeakPtr(), std::move(callback)));
  }

  // Same as CallProtoMethodWithExtraArgs, but doesn't pass in extra args.
  // Use for methods that only take a request proto as input.
  template <class TRequest, class TResponse>
  void CallProtoMethod(const char* method_name,
                       const TRequest& request,
                       base::OnceCallback<void(const TResponse&)> callback) {
    CallProtoMethodWithExtraArgs(method_name, request, std::move(callback), {});
  }

  // Parses the response proto message from |response| and calls |callback| with
  // the decoded message. Calls |callback| with an |ERROR_DBUS_FAILURE| message
  // on error.
  template <class TProto>
  void HandleResponse(base::OnceCallback<void(const TProto&)> callback,
                      dbus::Response* response) {
    TProto response_proto;
    if (!ParseProto(response, &response_proto))
      response_proto.set_error(kerberos::ERROR_DBUS_FAILURE);
    std::move(callback).Run(response_proto);
  }

  // D-Bus proxy for the Kerberos daemon, not owned.
  raw_ptr<dbus::ObjectProxy> proxy_ = nullptr;

  // Whether the D-Bus handler for the corresponding signal has been connected.
  // Each handler is connected at most once; see the Subscribe*Signal() methods.
  bool files_changed_signal_connected_ = false;
  bool ticket_expiring_signal_connected_ = false;

  // Signal callback lists.
  KerberosFilesChangedCallbackList kerberos_files_changed_callback_list_;
  KerberosTicketExpiringCallbackList kerberos_ticket_expiring_callback_list_;

  base::WeakPtrFactory<KerberosClientImpl> weak_factory_{this};
};

}  // namespace

// Registers this object as the process-wide instance returned by Get().
// Crashes (CHECK) if another KerberosClient instance is already registered.
KerberosClient::KerberosClient() {
  CHECK(!g_instance);
  g_instance = this;
}

// Unregisters this object so that Get() returns nullptr afterwards.
// Crashes (CHECK) if this object is not the registered instance.
KerberosClient::~KerberosClient() {
  CHECK_EQ(this, g_instance);
  g_instance = nullptr;
}

// static
void KerberosClient::Initialize(dbus::Bus* bus) {
  CHECK(bus);
  // The new object registers itself as |g_instance| in the KerberosClient
  // constructor and is deleted by Shutdown(), so it is deliberately not owned
  // by a smart pointer here.
  KerberosClientImpl* const client = new KerberosClientImpl();
  client->Init(bus);
}

// Creates the fake client used in place of the real one. The pointer is
// intentionally discarded. NOTE(review): this relies on FakeKerberosClient
// deriving from KerberosClient, so that its base constructor registers it as
// |g_instance| and Shutdown() deletes it. Confirm in fake_kerberos_client.h.
// static
void KerberosClient::InitializeFake() {
  new FakeKerberosClient();
}

// Destroys the registered instance. Crashes (CHECK) if none exists. The
// KerberosClient destructor resets |g_instance| to nullptr.
// static
void KerberosClient::Shutdown() {
  CHECK(g_instance);
  delete g_instance;
}

// Returns the currently registered instance (|g_instance|), or nullptr if no
// KerberosClient is alive.
// static
KerberosClient* KerberosClient::Get() {
  return g_instance;
}

}  // namespace ash