chromium/chromeos/ash/services/secure_channel/active_connection_manager_impl_unittest.cc

// Copyright 2018 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/services/secure_channel/active_connection_manager_impl.h"

#include <memory>

#include "base/containers/contains.h"
#include "base/containers/flat_map.h"
#include "base/containers/to_vector.h"
#include "base/functional/bind.h"
#include "base/memory/raw_ptr.h"
#include "base/ranges/algorithm.h"
#include "base/test/gtest_util.h"
#include "base/test/task_environment.h"
#include "base/unguessable_token.h"
#include "chromeos/ash/services/secure_channel/client_connection_parameters.h"
#include "chromeos/ash/services/secure_channel/connection_details.h"
#include "chromeos/ash/services/secure_channel/fake_active_connection_manager.h"
#include "chromeos/ash/services/secure_channel/fake_authenticated_channel.h"
#include "chromeos/ash/services/secure_channel/fake_client_connection_parameters.h"
#include "chromeos/ash/services/secure_channel/fake_multiplexed_channel.h"
#include "chromeos/ash/services/secure_channel/multiplexed_channel_impl.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace ash::secure_channel {

namespace {

class FakeMultiplexedChannelFactory : public MultiplexedChannelImpl::Factory {
 public:
  explicit FakeMultiplexedChannelFactory(
      MultiplexedChannel::Delegate* expected_delegate)
      : expected_delegate_(expected_delegate) {}

  FakeMultiplexedChannelFactory(const FakeMultiplexedChannelFactory&) = delete;
  FakeMultiplexedChannelFactory& operator=(
      const FakeMultiplexedChannelFactory&) = delete;

  ~FakeMultiplexedChannelFactory() override = default;

  base::flat_map<ConnectionDetails, FakeMultiplexedChannel*>&
  connection_details_to_active_channel_map() {
    return connection_details_to_active_channel_map_;
  }

  void set_next_expected_authenticated_channel(
      AuthenticatedChannel* authenticated_channel) {
    next_expected_authenticated_channel_ = authenticated_channel;
  }

  // MultiplexedChannelImpl::Factory:
  std::unique_ptr<MultiplexedChannel> CreateInstance(
      std::unique_ptr<AuthenticatedChannel> authenticated_channel,
      MultiplexedChannel::Delegate* delegate,
      ConnectionDetails connection_details,
      std::vector<std::unique_ptr<ClientConnectionParameters>>* initial_clients)
      override {
    EXPECT_EQ(expected_delegate_, delegate);
    EXPECT_EQ(next_expected_authenticated_channel_,
              authenticated_channel.get());
    next_expected_authenticated_channel_ = nullptr;

    auto fake_channel = std::make_unique<FakeMultiplexedChannel>(
        delegate, connection_details,
        base::BindOnce(&FakeMultiplexedChannelFactory::OnChannelDeleted,
                       base::Unretained(this)));

    for (auto& initial_client : *initial_clients)
      fake_channel->AddClientToChannel(std::move(initial_client));

    connection_details_to_active_channel_map_[connection_details] =
        fake_channel.get();

    return fake_channel;
  }

 private:
  void OnChannelDeleted(const ConnectionDetails& connection_details) {
    size_t num_deleted =
        connection_details_to_active_channel_map_.erase(connection_details);
    EXPECT_EQ(1u, num_deleted);
  }

  raw_ptr<const MultiplexedChannel::Delegate, DanglingUntriaged>
      expected_delegate_;

  raw_ptr<AuthenticatedChannel> next_expected_authenticated_channel_ = nullptr;

  base::flat_map<ConnectionDetails, FakeMultiplexedChannel*>
      connection_details_to_active_channel_map_;
};

std::vector<base::UnguessableToken> ClientListToIdList(
    const std::vector<std::unique_ptr<ClientConnectionParameters>>&
        client_list) {
  return base::ToVector(client_list, &ClientConnectionParameters::id);
}

}  // namespace

// Fixture that wires an ActiveConnectionManagerImpl to a fake delegate and to
// FakeMultiplexedChannelFactory, so tests can drive channel state directly.
class SecureChannelActiveConnectionManagerImplTest : public testing::Test {
 public:
  SecureChannelActiveConnectionManagerImplTest(
      const SecureChannelActiveConnectionManagerImplTest&) = delete;
  SecureChannelActiveConnectionManagerImplTest& operator=(
      const SecureChannelActiveConnectionManagerImplTest&) = delete;

 protected:
  SecureChannelActiveConnectionManagerImplTest() = default;
  ~SecureChannelActiveConnectionManagerImplTest() override = default;

  // testing::Test:
  void SetUp() override {
    fake_delegate_ = std::make_unique<FakeActiveConnectionManagerDelegate>();

    manager_ =
        ActiveConnectionManagerImpl::Factory::Create(fake_delegate_.get());

    // The manager itself is expected to be the delegate of every channel it
    // creates.
    ActiveConnectionManagerImpl* ptr_as_impl =
        static_cast<ActiveConnectionManagerImpl*>(manager_.get());
    fake_multiplexed_channel_factory_ =
        std::make_unique<FakeMultiplexedChannelFactory>(ptr_as_impl);
    MultiplexedChannelImpl::Factory::SetFactoryForTesting(
        fake_multiplexed_channel_factory_.get());
  }

  void TearDown() override {
    MultiplexedChannelImpl::Factory::SetFactoryForTesting(nullptr);
  }

  // Adds an active BLE connection for |device_id| with |initial_clients| and
  // verifies that the connection becomes active and that its channel holds
  // exactly those clients, in order.
  void AddActiveConnectionAndVerifyState(
      const std::string& device_id,
      std::vector<std::unique_ptr<ClientConnectionParameters>>
          initial_clients) {
    EXPECT_EQ(ActiveConnectionManager::ConnectionState::kNoConnectionExists,
              GetConnectionState(device_id));

    std::vector<base::UnguessableToken> initial_client_ids =
        ClientListToIdList(initial_clients);

    auto fake_authenticated_channel =
        std::make_unique<FakeAuthenticatedChannel>();
    fake_multiplexed_channel_factory_->set_next_expected_authenticated_channel(
        fake_authenticated_channel.get());

    manager_->AddActiveConnection(
        std::move(fake_authenticated_channel), std::move(initial_clients),
        ConnectionDetails(device_id, ConnectionMedium::kBluetoothLowEnergy));

    // The connection should be active, and the initial clients should now be
    // present in the associated channel.
    EXPECT_EQ(ActiveConnectionManager::ConnectionState::kActiveConnectionExists,
              GetConnectionState(device_id));

    // Report a failure instead of crashing on a null dereference if no channel
    // was created.
    FakeMultiplexedChannel* channel = GetActiveChannelForDeviceId(device_id);
    ASSERT_TRUE(channel);
    EXPECT_EQ(initial_client_ids, ClientListToIdList(channel->added_clients()));
  }

  // Adds |client_connection_parameters| to the existing active channel for
  // |device_id| and verifies it was appended to that channel's client list.
  void AddNewClientAndVerifyState(const std::string& device_id,
                                  std::unique_ptr<ClientConnectionParameters>
                                      client_connection_parameters) {
    EXPECT_EQ(ActiveConnectionManager::ConnectionState::kActiveConnectionExists,
              GetConnectionState(device_id));

    FakeMultiplexedChannel* channel = GetActiveChannelForDeviceId(device_id);
    ASSERT_TRUE(channel);

    // Initialize to the IDs before this call, then add the new client's ID.
    std::vector<base::UnguessableToken> client_ids =
        ClientListToIdList(channel->added_clients());
    client_ids.push_back(client_connection_parameters->id());

    manager_->AddClientToChannel(
        std::move(client_connection_parameters),
        ConnectionDetails(device_id, ConnectionMedium::kBluetoothLowEnergy));

    // The connection should remain active, and the clients list should now have
    // the new client.
    EXPECT_EQ(ActiveConnectionManager::ConnectionState::kActiveConnectionExists,
              GetConnectionState(device_id));

    channel = GetActiveChannelForDeviceId(device_id);
    ASSERT_TRUE(channel);
    EXPECT_EQ(client_ids, ClientListToIdList(channel->added_clients()));
  }

  // Returns the manager's state for the BLE connection to |device_id|.
  ActiveConnectionManager::ConnectionState GetConnectionState(
      const std::string& device_id) {
    return manager_->GetConnectionState(
        ConnectionDetails(device_id, ConnectionMedium::kBluetoothLowEnergy));
  }

  // Returns the live fake channel for |device_id|'s BLE connection, or nullptr
  // if the factory has no such channel.
  FakeMultiplexedChannel* GetActiveChannelForDeviceId(
      const std::string& device_id) {
    const auto& map =
        fake_multiplexed_channel_factory_->connection_details_to_active_channel_map();
    auto it = map.find(
        ConnectionDetails(device_id, ConnectionMedium::kBluetoothLowEnergy));
    return it == map.end() ? nullptr : it->second;
  }

  // Returns how many fake channels are currently alive.
  size_t GetNumActiveChannels() {
    return fake_multiplexed_channel_factory_
        ->connection_details_to_active_channel_map()
        .size();
  }

  // Returns how many disconnections the delegate recorded for |device_id|'s BLE
  // connection. Records a test failure and returns 0 if none were recorded,
  // rather than dereferencing an end() iterator.
  size_t GetNumDisconnections(const std::string& device_id) {
    ConnectionDetails connection_details(device_id,
                                         ConnectionMedium::kBluetoothLowEnergy);

    const auto& map =
        fake_delegate_->connection_details_to_num_disconnections_map();
    auto it = map.find(connection_details);
    if (it == map.end()) {
      ADD_FAILURE() << "No disconnections recorded for device " << device_id;
      return 0u;
    }
    return it->second;
  }

  ActiveConnectionManager* active_connection_manager() {
    return manager_.get();
  }

 private:
  base::test::TaskEnvironment task_environment_;

  // Declared before |manager_| so it outlives the manager (members are
  // destroyed in reverse order); channels destroyed with the manager call back
  // into the factory.
  std::unique_ptr<FakeMultiplexedChannelFactory>
      fake_multiplexed_channel_factory_;
  std::unique_ptr<FakeActiveConnectionManagerDelegate> fake_delegate_;

  std::unique_ptr<ActiveConnectionManager> manager_;
};

TEST_F(SecureChannelActiveConnectionManagerImplTest, EdgeCases) {
  // Returns a new list holding a single client.
  //
  // Each call site gets its own list rather than re-filling a moved-from
  // vector. Statements inside EXPECT_DCHECK_DEATH() execute only in the
  // death-test child process, and may not execute at all when DCHECKs are
  // disabled. A shared vector would therefore keep accumulating clients in the
  // parent, and later attempts would pass more clients than intended.
  auto make_client_list = [] {
    std::vector<std::unique_ptr<ClientConnectionParameters>> client_list;
    client_list.push_back(
        std::make_unique<FakeClientConnectionParameters>("feature"));
    return client_list;
  };

  AddActiveConnectionAndVerifyState("deviceId", make_client_list());

  // Try to add another channel for the same ConnectionDetails; this should
  // fail, since one already exists.
  EXPECT_DCHECK_DEATH(active_connection_manager()->AddActiveConnection(
      std::make_unique<FakeAuthenticatedChannel>(), make_client_list(),
      ConnectionDetails("deviceId", ConnectionMedium::kBluetoothLowEnergy)));

  // Move to disconnecting state.
  GetActiveChannelForDeviceId("deviceId")->SetDisconnecting();
  EXPECT_EQ(
      ActiveConnectionManager::ConnectionState::kDisconnectingConnectionExists,
      GetConnectionState("deviceId"));

  // Try to add another channel; this should still fail while disconnecting.
  EXPECT_DCHECK_DEATH(active_connection_manager()->AddActiveConnection(
      std::make_unique<FakeAuthenticatedChannel>(), make_client_list(),
      ConnectionDetails("deviceId", ConnectionMedium::kBluetoothLowEnergy)));

  // Try to add an additional client; this should also fail while disconnecting.
  EXPECT_DCHECK_DEATH(active_connection_manager()->AddClientToChannel(
      std::make_unique<FakeClientConnectionParameters>("feature"),
      ConnectionDetails("deviceId", ConnectionMedium::kBluetoothLowEnergy)));

  GetActiveChannelForDeviceId("deviceId")->SetDisconnected();
  EXPECT_EQ(ActiveConnectionManager::ConnectionState::kNoConnectionExists,
            GetConnectionState("deviceId"));

  // Try to add an additional client; this should also fail while disconnected.
  EXPECT_DCHECK_DEATH(active_connection_manager()->AddClientToChannel(
      std::make_unique<FakeClientConnectionParameters>("feature"),
      ConnectionDetails("deviceId", ConnectionMedium::kBluetoothLowEnergy)));
}

TEST_F(SecureChannelActiveConnectionManagerImplTest, SingleChannel_OneClient) {
  using ConnectionState = ActiveConnectionManager::ConnectionState;
  constexpr char kDeviceId[] = "deviceId";

  // Open a single channel carrying exactly one client.
  std::vector<std::unique_ptr<ClientConnectionParameters>> initial_clients;
  initial_clients.push_back(
      std::make_unique<FakeClientConnectionParameters>("feature"));
  AddActiveConnectionAndVerifyState(kDeviceId, std::move(initial_clients));
  EXPECT_EQ(1u, GetNumActiveChannels());

  // While disconnecting, the channel still counts as alive.
  GetActiveChannelForDeviceId(kDeviceId)->SetDisconnecting();
  EXPECT_EQ(ConnectionState::kDisconnectingConnectionExists,
            GetConnectionState(kDeviceId));
  EXPECT_EQ(1u, GetNumActiveChannels());

  // Once disconnected, the channel is gone and the delegate has recorded one
  // disconnection.
  GetActiveChannelForDeviceId(kDeviceId)->SetDisconnected();
  EXPECT_EQ(ConnectionState::kNoConnectionExists,
            GetConnectionState(kDeviceId));
  EXPECT_EQ(0u, GetNumActiveChannels());
  EXPECT_EQ(1u, GetNumDisconnections(kDeviceId));
}

TEST_F(SecureChannelActiveConnectionManagerImplTest,
       SingleChannel_MultipleClients) {
  using ConnectionState = ActiveConnectionManager::ConnectionState;
  constexpr char kDeviceId[] = "deviceId";

  // Open one channel that starts out with two clients.
  std::vector<std::unique_ptr<ClientConnectionParameters>> initial_clients;
  for (const char* feature : {"feature1", "feature2"}) {
    initial_clients.push_back(
        std::make_unique<FakeClientConnectionParameters>(feature));
  }
  AddActiveConnectionAndVerifyState(kDeviceId, std::move(initial_clients));
  EXPECT_EQ(1u, GetNumActiveChannels());

  // A third client joins the already-open channel; no new channel is created.
  AddNewClientAndVerifyState(
      kDeviceId, std::make_unique<FakeClientConnectionParameters>("feature3"));
  EXPECT_EQ(1u, GetNumActiveChannels());

  // While disconnecting, the channel still counts as alive.
  GetActiveChannelForDeviceId(kDeviceId)->SetDisconnecting();
  EXPECT_EQ(ConnectionState::kDisconnectingConnectionExists,
            GetConnectionState(kDeviceId));
  EXPECT_EQ(1u, GetNumActiveChannels());

  // Once disconnected, the channel is gone and the delegate has recorded one
  // disconnection.
  GetActiveChannelForDeviceId(kDeviceId)->SetDisconnected();
  EXPECT_EQ(ConnectionState::kNoConnectionExists,
            GetConnectionState(kDeviceId));
  EXPECT_EQ(0u, GetNumActiveChannels());
  EXPECT_EQ(1u, GetNumDisconnections(kDeviceId));
}

TEST_F(SecureChannelActiveConnectionManagerImplTest,
       MultipleChannels_MultipleClients) {
  // Each connection gets its own client list. This avoids re-filling a vector
  // after it has been moved from, which leaves it in an unspecified state and
  // is flagged as use-after-move.

  // Add an initial channel with two clients.
  std::vector<std::unique_ptr<ClientConnectionParameters>> device1_clients;
  device1_clients.push_back(
      std::make_unique<FakeClientConnectionParameters>("feature1"));
  device1_clients.push_back(
      std::make_unique<FakeClientConnectionParameters>("feature2"));

  AddActiveConnectionAndVerifyState("deviceId1", std::move(device1_clients));
  EXPECT_EQ(1u, GetNumActiveChannels());

  // Add another channel with two more clients.
  std::vector<std::unique_ptr<ClientConnectionParameters>> device2_clients;
  device2_clients.push_back(
      std::make_unique<FakeClientConnectionParameters>("feature3"));
  device2_clients.push_back(
      std::make_unique<FakeClientConnectionParameters>("feature4"));

  AddActiveConnectionAndVerifyState("deviceId2", std::move(device2_clients));
  EXPECT_EQ(2u, GetNumActiveChannels());

  // Add a new client to the first channel.
  AddNewClientAndVerifyState(
      "deviceId1",
      std::make_unique<FakeClientConnectionParameters>("feature5"));
  EXPECT_EQ(2u, GetNumActiveChannels());

  // Add a new client to the second channel.
  AddNewClientAndVerifyState(
      "deviceId2",
      std::make_unique<FakeClientConnectionParameters>("feature6"));
  EXPECT_EQ(2u, GetNumActiveChannels());

  // Start disconnecting the first channel.
  GetActiveChannelForDeviceId("deviceId1")->SetDisconnecting();
  EXPECT_EQ(
      ActiveConnectionManager::ConnectionState::kDisconnectingConnectionExists,
      GetConnectionState("deviceId1"));
  EXPECT_EQ(2u, GetNumActiveChannels());

  // Disconnect the first channel.
  GetActiveChannelForDeviceId("deviceId1")->SetDisconnected();
  EXPECT_EQ(ActiveConnectionManager::ConnectionState::kNoConnectionExists,
            GetConnectionState("deviceId1"));
  EXPECT_EQ(1u, GetNumActiveChannels());
  EXPECT_EQ(1u, GetNumDisconnections("deviceId1"));

  // Now, add another channel for the same device that just disconnected.
  std::vector<std::unique_ptr<ClientConnectionParameters>>
      device1_reconnect_clients;
  device1_reconnect_clients.push_back(
      std::make_unique<FakeClientConnectionParameters>("feature7"));
  device1_reconnect_clients.push_back(
      std::make_unique<FakeClientConnectionParameters>("feature8"));

  AddActiveConnectionAndVerifyState("deviceId1",
                                    std::move(device1_reconnect_clients));
  EXPECT_EQ(2u, GetNumActiveChannels());

  // Start disconnecting the second channel.
  GetActiveChannelForDeviceId("deviceId2")->SetDisconnecting();
  EXPECT_EQ(
      ActiveConnectionManager::ConnectionState::kDisconnectingConnectionExists,
      GetConnectionState("deviceId2"));
  EXPECT_EQ(2u, GetNumActiveChannels());

  // Disconnect the second channel.
  GetActiveChannelForDeviceId("deviceId2")->SetDisconnected();
  EXPECT_EQ(ActiveConnectionManager::ConnectionState::kNoConnectionExists,
            GetConnectionState("deviceId2"));
  EXPECT_EQ(1u, GetNumActiveChannels());
  EXPECT_EQ(1u, GetNumDisconnections("deviceId2"));

  // Start disconnecting the second iteration of the first channel.
  GetActiveChannelForDeviceId("deviceId1")->SetDisconnecting();
  EXPECT_EQ(
      ActiveConnectionManager::ConnectionState::kDisconnectingConnectionExists,
      GetConnectionState("deviceId1"));
  EXPECT_EQ(1u, GetNumActiveChannels());

  // Disconnect the second iteration of the first channel. The delegate's count
  // for this device is cumulative across both channel lifetimes.
  GetActiveChannelForDeviceId("deviceId1")->SetDisconnected();
  EXPECT_EQ(ActiveConnectionManager::ConnectionState::kNoConnectionExists,
            GetConnectionState("deviceId1"));
  EXPECT_EQ(0u, GetNumActiveChannels());
  EXPECT_EQ(2u, GetNumDisconnections("deviceId1"));
}

}  // namespace ash::secure_channel