chromium/chromeos/ash/components/dbus/shill/shill_third_party_vpn_driver_client.cc

// Copyright 2014 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifdef UNSAFE_BUFFERS_BUILD
// TODO(crbug.com/40285824): Remove this and convert code to safer constructs.
#pragma allow_unsafe_buffers
#endif

#include "chromeos/ash/components/dbus/shill/shill_third_party_vpn_driver_client.h"

#include <stddef.h>
#include <stdint.h>

#include <map>
#include <set>

#include "base/containers/contains.h"
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/logging.h"
#include "base/memory/raw_ptr.h"
#include "base/values.h"
#include "chromeos/ash/components/dbus/shill/fake_shill_third_party_vpn_driver_client.h"
#include "chromeos/ash/components/dbus/shill/shill_third_party_vpn_observer.h"
#include "dbus/bus.h"
#include "dbus/message.h"
#include "dbus/object_proxy.h"
#include "third_party/cros_system_api/dbus/service_constants.h"

namespace ash {

namespace {

const char* kSetParametersKeyList[] = {
    shill::kAddressParameterThirdPartyVpn,
    shill::kBroadcastAddressParameterThirdPartyVpn,
    shill::kExclusionListParameterThirdPartyVpn,
    shill::kInclusionListParameterThirdPartyVpn,
    shill::kSubnetPrefixParameterThirdPartyVpn,
    shill::kMtuParameterThirdPartyVpn,
    shill::kDomainSearchParameterThirdPartyVpn,
    shill::kDnsServersParameterThirdPartyVpn,
    shill::kReconnectParameterThirdPartyVpn};

ShillThirdPartyVpnDriverClient* g_instance = nullptr;

// The ShillThirdPartyVpnDriverClient implementation.
class ShillThirdPartyVpnDriverClientImpl
    : public ShillThirdPartyVpnDriverClient {
 public:
  explicit ShillThirdPartyVpnDriverClientImpl(dbus::Bus* bus);

  ShillThirdPartyVpnDriverClientImpl(
      const ShillThirdPartyVpnDriverClientImpl&) = delete;
  ShillThirdPartyVpnDriverClientImpl& operator=(
      const ShillThirdPartyVpnDriverClientImpl&) = delete;

  ~ShillThirdPartyVpnDriverClientImpl() override;

  // ShillThirdPartyVpnDriverClient overrides
  void AddShillThirdPartyVpnObserver(
      const std::string& object_path_value,
      ShillThirdPartyVpnObserver* observer) override;

  void RemoveShillThirdPartyVpnObserver(
      const std::string& object_path_value) override;

  void SetParameters(const std::string& object_path_value,
                     const base::Value::Dict& parameters,
                     StringCallback callback,
                     ErrorCallback error_callback) override;

  void UpdateConnectionState(const std::string& object_path_value,
                             const uint32_t connection_state,
                             base::OnceClosure callback,
                             ErrorCallback error_callback) override;

  void SendPacket(const std::string& object_path_value,
                  const std::vector<char>& ip_packet,
                  base::OnceClosure callback,
                  ErrorCallback error_callback) override;

  TestInterface* GetTestInterface() override { return nullptr; }

 private:
  class HelperInfo {
   public:
    explicit HelperInfo(dbus::ObjectProxy* object_proxy);

    ShillClientHelper* helper() { return &helper_; }
    ShillThirdPartyVpnObserver* observer() { return observer_; }

    void set_observer(ShillThirdPartyVpnObserver* observer) {
      observer_ = observer;
    }

    base::WeakPtr<HelperInfo> GetWeakPtr() {
      return weak_ptr_factory_.GetWeakPtr();
    }

   private:
    ShillClientHelper helper_;
    raw_ptr<ShillThirdPartyVpnObserver> observer_;

    base::WeakPtrFactory<HelperInfo> weak_ptr_factory_{this};
  };
  using HelperMap = std::map<std::string, HelperInfo*>;

  static void OnPacketReceived(base::WeakPtr<HelperInfo> helper_info,
                               dbus::Signal* signal);
  static void OnPlatformMessage(base::WeakPtr<HelperInfo> helper_info,
                                dbus::Signal* signal);
  static void OnSignalConnected(const std::string& interface,
                                const std::string& signal,
                                bool success);

  // Returns or creates the corresponding ShillClientHelper for the
  // |object_path_value|.
  ShillClientHelper* GetHelper(const std::string& object_path_value);

  // Returns or creates the corresponding HelperInfo for the
  // |object_path_value|.
  HelperInfo* GetHelperInfo(const std::string& object_path_value);

  // Returns the corresponding HelperInfo for the |object_path| if exists,
  // nullptr if not.
  HelperInfo* FindHelperInfo(const dbus::ObjectPath& object_path);

  // Deletes the helper object corresponding to |object_path|.
  void DeleteHelper(const dbus::ObjectPath& object_path);

  raw_ptr<dbus::Bus> bus_;
  HelperMap helpers_;
  std::set<std::string> valid_keys_;
};

ShillThirdPartyVpnDriverClientImpl::HelperInfo::HelperInfo(
    dbus::ObjectProxy* object_proxy)
    : helper_(object_proxy), observer_(nullptr) {}

ShillThirdPartyVpnDriverClientImpl::ShillThirdPartyVpnDriverClientImpl(
    dbus::Bus* bus)
    : bus_(bus) {
  for (uint32_t i = 0; i < std::size(kSetParametersKeyList); ++i) {
    valid_keys_.insert(kSetParametersKeyList[i]);
  }
}

ShillThirdPartyVpnDriverClientImpl::~ShillThirdPartyVpnDriverClientImpl() {
  for (auto& iter : helpers_) {
    HelperInfo* helper_info = iter.second;
    bus_->RemoveObjectProxy(
        shill::kFlimflamServiceName,
        helper_info->helper()->object_proxy()->object_path(),
        base::DoNothing());
    delete helper_info;
  }
}

void ShillThirdPartyVpnDriverClientImpl::AddShillThirdPartyVpnObserver(
    const std::string& object_path_value,
    ShillThirdPartyVpnObserver* observer) {
  HelperInfo* helper_info = GetHelperInfo(object_path_value);
  if (helper_info->observer()) {
    LOG(ERROR) << "Observer exists for " << object_path_value;
    return;
  }

  // TODO(kaliamoorthi): Remove the const_cast.
  helper_info->set_observer(observer);
  dbus::ObjectProxy* proxy =
      const_cast<dbus::ObjectProxy*>(helper_info->helper()->object_proxy());

  proxy->ConnectToSignal(
      shill::kFlimflamThirdPartyVpnInterface, shill::kOnPlatformMessageFunction,
      base::BindRepeating(
          &ShillThirdPartyVpnDriverClientImpl::OnPlatformMessage,
          helper_info->GetWeakPtr()),
      base::BindOnce(&ShillThirdPartyVpnDriverClientImpl::OnSignalConnected));

  proxy->ConnectToSignal(
      shill::kFlimflamThirdPartyVpnInterface, shill::kOnPacketReceivedFunction,
      base::BindRepeating(&ShillThirdPartyVpnDriverClientImpl::OnPacketReceived,
                          helper_info->GetWeakPtr()),
      base::BindOnce(&ShillThirdPartyVpnDriverClientImpl::OnSignalConnected));
}

void ShillThirdPartyVpnDriverClientImpl::RemoveShillThirdPartyVpnObserver(
    const std::string& object_path_value) {
  HelperInfo* helper_info = FindHelperInfo(dbus::ObjectPath(object_path_value));
  if (!helper_info) {
    LOG(ERROR) << "Unknown object_path_value " << object_path_value;
    return;
  }

  CHECK(helper_info->observer());
  helper_info->set_observer(nullptr);
  DeleteHelper(dbus::ObjectPath(object_path_value));
}

void ShillThirdPartyVpnDriverClientImpl::DeleteHelper(
    const dbus::ObjectPath& object_path) {
  HelperInfo* helper_info = FindHelperInfo(dbus::ObjectPath(object_path));
  if (!helper_info) {
    LOG(ERROR) << "Unknown object_path " << object_path.value();
    return;
  }

  bus_->RemoveObjectProxy(shill::kFlimflamServiceName, object_path,
                          base::DoNothing());
  helpers_.erase(helpers_.find(object_path.value()));
  delete helper_info;
}

void ShillThirdPartyVpnDriverClientImpl::SetParameters(
    const std::string& object_path_value,
    const base::Value::Dict& parameters,
    StringCallback callback,
    ErrorCallback error_callback) {
  dbus::MethodCall method_call(shill::kFlimflamThirdPartyVpnInterface,
                               shill::kSetParametersFunction);
  dbus::MessageWriter writer(&method_call);
  dbus::MessageWriter array_writer(nullptr);
  writer.OpenArray("{ss}", &array_writer);
  for (auto it : parameters) {
    if (!base::Contains(valid_keys_, it.first)) {
      LOG(WARNING) << "Unknown key " << it.first;
      continue;
    }
    if (!it.second.is_string()) {
      LOG(WARNING) << "Non string value " << it.second;
      continue;
    }
    dbus::MessageWriter entry_writer(nullptr);
    array_writer.OpenDictEntry(&entry_writer);
    entry_writer.AppendString(it.first);
    entry_writer.AppendString(it.second.GetString());
    array_writer.CloseContainer(&entry_writer);
  }
  writer.CloseContainer(&array_writer);
  GetHelper(object_path_value)
      ->CallStringMethodWithErrorCallback(&method_call, std::move(callback),
                                          std::move(error_callback));
}

void ShillThirdPartyVpnDriverClientImpl::UpdateConnectionState(
    const std::string& object_path_value,
    const uint32_t connection_state,
    base::OnceClosure callback,
    ErrorCallback error_callback) {
  dbus::MethodCall method_call(shill::kFlimflamThirdPartyVpnInterface,
                               shill::kUpdateConnectionStateFunction);
  dbus::MessageWriter writer(&method_call);
  writer.AppendUint32(connection_state);
  GetHelper(object_path_value)
      ->CallVoidMethodWithErrorCallback(&method_call, std::move(callback),
                                        std::move(error_callback));
}

void ShillThirdPartyVpnDriverClientImpl::SendPacket(
    const std::string& object_path_value,
    const std::vector<char>& ip_packet,
    base::OnceClosure callback,
    ErrorCallback error_callback) {
  dbus::MethodCall method_call(shill::kFlimflamThirdPartyVpnInterface,
                               shill::kSendPacketFunction);
  dbus::MessageWriter writer(&method_call);
  static_assert(sizeof(uint8_t) == sizeof(char),
                "Can't reinterpret ip_packet if char is not 8 bit large.");
  writer.AppendArrayOfBytes(base::as_byte_span(ip_packet));
  GetHelper(object_path_value)
      ->CallVoidMethodWithErrorCallback(&method_call, std::move(callback),
                                        std::move(error_callback));
}

// static
void ShillThirdPartyVpnDriverClientImpl::OnPacketReceived(
    base::WeakPtr<HelperInfo> helper_info,
    dbus::Signal* signal) {
  if (!helper_info || !helper_info->observer())
    return;

  dbus::MessageReader reader(signal);
  const uint8_t* data = nullptr;
  size_t length = 0;
  if (reader.PopArrayOfBytes(&data, &length)) {
    helper_info->observer()->OnPacketReceived(
        std::vector<char>(data, data + length));
  }
}

// static
void ShillThirdPartyVpnDriverClientImpl::OnPlatformMessage(
    base::WeakPtr<HelperInfo> helper_info,
    dbus::Signal* signal) {
  if (!helper_info || !helper_info->observer())
    return;

  dbus::MessageReader reader(signal);
  uint32_t platform_message = 0;
  if (reader.PopUint32(&platform_message))
    helper_info->observer()->OnPlatformMessage(platform_message);
}

// static
void ShillThirdPartyVpnDriverClientImpl::OnSignalConnected(
    const std::string& interface,
    const std::string& signal,
    bool success) {
  LOG_IF(ERROR, !success) << "Connect to " << interface << " " << signal
                          << " failed.";
}

ShillClientHelper* ShillThirdPartyVpnDriverClientImpl::GetHelper(
    const std::string& object_path_value) {
  return GetHelperInfo(object_path_value)->helper();
}

ShillThirdPartyVpnDriverClientImpl::HelperInfo*
ShillThirdPartyVpnDriverClientImpl::FindHelperInfo(
    const dbus::ObjectPath& object_path) {
  HelperMap::iterator it = helpers_.find(object_path.value());
  return (it != helpers_.end()) ? it->second : nullptr;
}

ShillThirdPartyVpnDriverClientImpl::HelperInfo*
ShillThirdPartyVpnDriverClientImpl::GetHelperInfo(
    const std::string& object_path_value) {
  dbus::ObjectPath object_path(object_path_value);
  HelperInfo* helper_info = FindHelperInfo(object_path);
  if (helper_info)
    return helper_info;

  // There is no helper for the profile, create it.
  dbus::ObjectProxy* object_proxy =
      bus_->GetObjectProxy(shill::kFlimflamServiceName, object_path);
  helper_info = new HelperInfo(object_proxy);
  helpers_[object_path_value] = helper_info;
  return helper_info;
}

}  // namespace

ShillThirdPartyVpnDriverClient::ShillThirdPartyVpnDriverClient() {
  DCHECK(!g_instance);
  g_instance = this;
}

ShillThirdPartyVpnDriverClient::~ShillThirdPartyVpnDriverClient() {
  DCHECK_EQ(this, g_instance);
  g_instance = nullptr;
}

// static
void ShillThirdPartyVpnDriverClient::Initialize(dbus::Bus* bus) {
  DCHECK(bus);
  new ShillThirdPartyVpnDriverClientImpl(bus);
}

// static
void ShillThirdPartyVpnDriverClient::InitializeFake() {
  // If a browser test creates a custom fake, don't override it.
  if (!FakeShillThirdPartyVpnDriverClient::Get())
    new FakeShillThirdPartyVpnDriverClient();
}

// static
void ShillThirdPartyVpnDriverClient::Shutdown() {
  DCHECK(g_instance);
  delete g_instance;
}

// static
ShillThirdPartyVpnDriverClient* ShillThirdPartyVpnDriverClient::Get() {
  return g_instance;
}

}  // namespace ash