chromium/chromeos/ash/services/secure_channel/nearby_connection_manager_impl_unittest.cc

// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/services/secure_channel/nearby_connection_manager_impl.h"

#include <memory>
#include <optional>

#include "base/functional/bind.h"
#include "base/memory/raw_ptr.h"
#include "base/test/task_environment.h"
#include "chromeos/ash/components/multidevice/remote_device_test_util.h"
#include "chromeos/ash/services/secure_channel/authenticated_channel_impl.h"
#include "chromeos/ash/services/secure_channel/fake_authenticated_channel.h"
#include "chromeos/ash/services/secure_channel/fake_ble_scanner.h"
#include "chromeos/ash/services/secure_channel/fake_connection.h"
#include "chromeos/ash/services/secure_channel/fake_secure_channel_connection.h"
#include "chromeos/ash/services/secure_channel/fake_secure_channel_disconnector.h"
#include "chromeos/ash/services/secure_channel/nearby_connection.h"
#include "chromeos/ash/services/secure_channel/public/cpp/client/fake_nearby_connector.h"
#include "chromeos/ash/services/secure_channel/public/mojom/nearby_connector.mojom-shared.h"
#include "chromeos/ash/services/secure_channel/public/mojom/secure_channel.mojom-shared.h"
#include "chromeos/ash/services/secure_channel/secure_channel.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace ash::secure_channel {

namespace {

// Number of fake remote devices generated for each test fixture instance.
constexpr size_t kNumTestDevices = 3;

// Test double for NearbyConnection::Factory. Each CreateInstance() call hands
// back a FakeConnection and remembers a non-owning pointer to it, so that a
// test can reach the most recently created connection.
class FakeNearbyConnectionFactory : public NearbyConnection::Factory {
 public:
  FakeNearbyConnectionFactory() = default;
  FakeNearbyConnectionFactory(const FakeNearbyConnectionFactory&) = delete;
  FakeNearbyConnectionFactory& operator=(const FakeNearbyConnectionFactory&) =
      delete;
  ~FakeNearbyConnectionFactory() override = default;

  // Returns the connection produced by the latest CreateInstance() call, or
  // nullptr if none has been created yet. The pointer is not owning.
  FakeConnection* last_created_instance() { return last_created_instance_; }

 private:
  // NearbyConnection::Factory:
  std::unique_ptr<Connection> CreateInstance(
      multidevice::RemoteDeviceRef remote_device,
      const std::vector<uint8_t>& eid,
      mojom::NearbyConnector* nearby_connector) override {
    // |eid| and |nearby_connector| are ignored; the fake only needs the
    // remote device.
    std::unique_ptr<FakeConnection> fake_connection =
        std::make_unique<FakeConnection>(remote_device);
    last_created_instance_ = fake_connection.get();
    return fake_connection;
  }

  // Non-owning; ownership of the connection is passed to the caller of
  // CreateInstance().
  raw_ptr<FakeConnection, DanglingUntriaged> last_created_instance_ = nullptr;
};

// Test double for SecureChannel::Factory. Each CreateInstance() call wraps the
// provided Connection in a FakeSecureChannelConnection and remembers a
// non-owning pointer to it, so that a test can drive the most recently created
// channel.
class FakeSecureChannelFactory : public SecureChannel::Factory {
 public:
  FakeSecureChannelFactory() = default;
  FakeSecureChannelFactory(const FakeSecureChannelFactory&) = delete;
  FakeSecureChannelFactory& operator=(const FakeSecureChannelFactory&) = delete;
  // Declared with |override| (rather than |virtual|) to match the other fake
  // factories in this file and to let the compiler verify that the base class
  // destructor is virtual.
  ~FakeSecureChannelFactory() override = default;

  // Returns the channel produced by the latest CreateInstance() call, or
  // nullptr if none has been created yet. The pointer is not owning.
  FakeSecureChannelConnection* last_created_instance() {
    return last_created_instance_;
  }

 private:
  // SecureChannel::Factory:
  std::unique_ptr<SecureChannel> CreateInstance(
      std::unique_ptr<Connection> connection) override {
    auto instance =
        std::make_unique<FakeSecureChannelConnection>(std::move(connection));
    last_created_instance_ = instance.get();
    return instance;
  }

  // Non-owning; ownership of the channel is passed to the caller of
  // CreateInstance().
  raw_ptr<FakeSecureChannelConnection, DanglingUntriaged>
      last_created_instance_ = nullptr;
};

// Test double for AuthenticatedChannelImpl::Factory. Before an authenticated
// channel is expected to be created, a test registers the SecureChannel it
// expects to be handed over; CreateInstance() checks that expectation and
// returns a FakeAuthenticatedChannel.
class FakeAuthenticatedChannelFactory
    : public AuthenticatedChannelImpl::Factory {
 public:
  FakeAuthenticatedChannelFactory() = default;
  FakeAuthenticatedChannelFactory(const FakeAuthenticatedChannelFactory&) =
      delete;
  FakeAuthenticatedChannelFactory& operator=(
      const FakeAuthenticatedChannelFactory&) = delete;
  ~FakeAuthenticatedChannelFactory() override = default;

  // Records the SecureChannel which the next CreateInstance() call must
  // receive.
  void SetExpectationsForNextCall(
      FakeSecureChannelConnection* expected_fake_secure_channel) {
    expected_fake_secure_channel_ = expected_fake_secure_channel;
  }

  // Returns the channel produced by the latest CreateInstance() call, or
  // nullptr if none has been created yet. The pointer is not owning.
  FakeAuthenticatedChannel* last_created_instance() {
    return last_created_instance_;
  }

 private:
  // AuthenticatedChannelImpl::Factory:
  std::unique_ptr<AuthenticatedChannel> CreateInstance(
      const std::vector<mojom::ConnectionCreationDetail>&
          connection_creation_details,
      std::unique_ptr<SecureChannel> secure_channel) override {
    // The channel handed over must be the one registered via
    // SetExpectationsForNextCall(). |connection_creation_details| is ignored.
    EXPECT_EQ(expected_fake_secure_channel_, secure_channel.get());

    std::unique_ptr<FakeAuthenticatedChannel> fake_channel =
        std::make_unique<FakeAuthenticatedChannel>();
    last_created_instance_ = fake_channel.get();
    return fake_channel;
  }

  // Both pointers are non-owning.
  raw_ptr<FakeSecureChannelConnection, DanglingUntriaged>
      expected_fake_secure_channel_ = nullptr;
  raw_ptr<FakeAuthenticatedChannel, DanglingUntriaged> last_created_instance_ =
      nullptr;
};

}  // namespace

// Test fixture for NearbyConnectionManagerImpl.
//
// SetUp() installs fake factories for NearbyConnection, SecureChannel and
// AuthenticatedChannelImpl, then creates the manager under test on top of a
// FakeBleScanner and a FakeSecureChannelDisconnector. The fixture binds its
// private On*() methods as the callbacks passed to the manager and records what
// they receive, so that the Simulate*() helpers can verify the manager's
// reactions.
class SecureChannelNearbyConnectionManagerImplTest : public testing::Test {
 protected:
  SecureChannelNearbyConnectionManagerImplTest()
      : task_environment_(
            base::test::TaskEnvironment::MainThreadType::DEFAULT,
            base::test::TaskEnvironment::ThreadPoolExecutionMode::QUEUED),
        test_devices_(
            multidevice::CreateRemoteDeviceRefListForTest(kNumTestDevices)) {}
  SecureChannelNearbyConnectionManagerImplTest(
      const SecureChannelNearbyConnectionManagerImplTest&) = delete;
  SecureChannelNearbyConnectionManagerImplTest& operator=(
      const SecureChannelNearbyConnectionManagerImplTest&) = delete;
  ~SecureChannelNearbyConnectionManagerImplTest() override = default;

  // testing::Test:
  void SetUp() override {
    fake_nearby_connection_factory_ =
        std::make_unique<FakeNearbyConnectionFactory>();
    NearbyConnection::Factory::SetFactoryForTesting(
        fake_nearby_connection_factory_.get());

    fake_secure_channel_factory_ = std::make_unique<FakeSecureChannelFactory>();
    SecureChannel::Factory::SetFactoryForTesting(
        fake_secure_channel_factory_.get());

    fake_authenticated_channel_factory_ =
        std::make_unique<FakeAuthenticatedChannelFactory>();
    AuthenticatedChannelImpl::Factory::SetFactoryForTesting(
        fake_authenticated_channel_factory_.get());

    fake_ble_scanner_ = std::make_unique<FakeBleScanner>();
    fake_secure_channel_disconnector_ =
        std::make_unique<FakeSecureChannelDisconnector>();

    manager_ = NearbyConnectionManagerImpl::Factory::Create(
        fake_ble_scanner_.get(), fake_secure_channel_disconnector_.get());

    // The manager starts without a NearbyConnector; provide a fake one.
    EXPECT_FALSE(manager_->IsNearbyConnectorSet());
    fake_nearby_connector_ = std::make_unique<FakeNearbyConnector>();
    manager_->SetNearbyConnector(
        fake_nearby_connector_->GeneratePendingRemote());
    EXPECT_TRUE(manager_->IsNearbyConnectorSet());
  }

  void TearDown() override {
    // Uninstall the fake factories registered in SetUp().
    NearbyConnection::Factory::SetFactoryForTesting(nullptr);
    SecureChannel::Factory::SetFactoryForTesting(nullptr);
    AuthenticatedChannelImpl::Factory::SetFactoryForTesting(nullptr);
  }

  // Starts a Nearby initiator connection attempt for |device_id_pair|, wiring
  // all manager callbacks to this fixture, and verifies whether a matching
  // scan request is present on the fake scanner afterwards
  // (|expected_to_add_request|). If |should_cancel_attempt_on_failure| is
  // true, the failure callback cancels the attempt from within the callback.
  void AttemptNearbyInitiatorConnection(const DeviceIdPair& device_id_pair,
                                        bool expected_to_add_request,
                                        bool should_cancel_attempt_on_failure) {
    SetInRemoteDeviceIdToMetadataMap(device_id_pair);

    manager_->AttemptNearbyInitiatorConnection(
        device_id_pair,
        base::BindRepeating(&SecureChannelNearbyConnectionManagerImplTest::
                                OnBleDiscoveryStateChanged,
                            base::Unretained(this), device_id_pair),
        base::BindRepeating(&SecureChannelNearbyConnectionManagerImplTest::
                                OnNearbyConnectionStateChanged,
                            base::Unretained(this), device_id_pair),
        base::BindRepeating(&SecureChannelNearbyConnectionManagerImplTest::
                                OnSecureChannelAuthenticationStateChanged,
                            base::Unretained(this), device_id_pair),
        base::BindOnce(
            &SecureChannelNearbyConnectionManagerImplTest::OnConnectionSuccess,
            base::Unretained(this), device_id_pair),
        base::BindRepeating(&SecureChannelNearbyConnectionManagerImplTest::
                                OnNearbyInitiatorFailure,
                            base::Unretained(this), device_id_pair,
                            should_cancel_attempt_on_failure));

    bool has_request =
        fake_ble_scanner_->HasScanRequest(ConnectionAttemptDetails(
            device_id_pair, ConnectionMedium::kNearbyConnections,
            ConnectionRole::kInitiatorRole));
    EXPECT_EQ(expected_to_add_request, has_request);
  }

  // Cancels the attempt for |device_id_pair| and verifies that the
  // corresponding scan request is no longer present on the fake scanner.
  void CancelNearbyInitiatorConnectionAttempt(
      const DeviceIdPair& device_id_pair) {
    RemoveFromRemoteDeviceIdToMetadataMap(device_id_pair);
    manager_->CancelNearbyInitiatorConnectionAttempt(device_id_pair);
    EXPECT_FALSE(fake_ble_scanner_->HasScanRequest(ConnectionAttemptDetails(
        device_id_pair, ConnectionMedium::kNearbyConnections,
        ConnectionRole::kInitiatorRole)));
  }

  // Notifies the fake scanner that BLE discovery failed (with a timeout error
  // code) for |device_id_pair| and verifies that the failure was reported
  // through the discovery-state callback bound for that pair.
  //
  // NOTE: The method name is misspelled ("Disvocery"); it is kept as-is
  // because tests in this file call it by this name.
  void SimulateBleDisvoceryFailed(const DeviceIdPair& device_id_pair) {
    fake_ble_scanner_->NotifyBleDiscoverySessionFailed(
        device_id_pair, mojom::DiscoveryResult::kFailure,
        mojom::DiscoveryErrorCode::kTimeout);

    // Look the result up with find() rather than operator[]: operator[] would
    // insert a value-initialized entry if the callback had never run, which
    // could hide a missing notification.
    const auto it = device_discovery_results_.find(device_id_pair);
    ASSERT_TRUE(it != device_discovery_results_.end());
    EXPECT_EQ(it->second, mojom::DiscoveryResult::kFailure);
  }

  // Simulates receiving an advertisement from |remote_device|, which is
  // expected to result in a SecureChannel being created and initialized.
  // Returns the SecureChannel created by this call, or nullptr (after
  // recording a test failure) if no channel was created.
  FakeSecureChannelConnection* SimulateConnectionEstablished(
      multidevice::RemoteDeviceRef remote_device) {
    fake_ble_scanner_->NotifyReceivedAdvertisementFromDevice(
        remote_device, /*bluetooth_device=*/nullptr,
        ConnectionMedium::kNearbyConnections, ConnectionRole::kInitiatorRole,
        {0, 0} /* eid */);

    // As a result of the connection, all ongoing connection attempts should
    // have been canceled, since a connection is in progress.
    EXPECT_TRUE(
        fake_ble_scanner_
            ->GetAllScanRequestsForRemoteDevice(remote_device.GetDeviceId())
            .empty());

    FakeSecureChannelConnection* last_created_secure_channel =
        fake_secure_channel_factory_->last_created_instance();
    // Report a readable failure instead of dereferencing a null pointer when
    // the factory was never asked to create a channel.
    if (!last_created_secure_channel) {
      ADD_FAILURE() << "No SecureChannel was created for the advertisement.";
      return nullptr;
    }
    EXPECT_TRUE(last_created_secure_channel->was_initialized());
    return last_created_secure_channel;
  }

  // Drives |fake_secure_channel| through CONNECTED (and AUTHENTICATING, if
  // |fail_during_authentication| is true) to DISCONNECTED, then verifies that
  // every attempt still tracked for |remote_device_id| received the matching
  // failure callback and had its scan request restored.
  // |num_initiator_attempts_canceled_from_disconnection| is the number of
  // failure callbacks expected from attempts which cancel themselves inside
  // the failure callback (and are therefore no longer tracked afterwards).
  void SimulateSecureChannelDisconnection(
      const std::string& remote_device_id,
      bool fail_during_authentication,
      FakeSecureChannelConnection* fake_secure_channel,
      size_t num_initiator_attempts_canceled_from_disconnection = 0u) {
    size_t num_nearby_initiator_failures_before_call =
        nearby_initiator_failures_.size();

    // Connect, then disconnect. If needed, start authenticating before
    // disconnecting.
    fake_secure_channel->ChangeStatus(SecureChannel::Status::CONNECTED);
    if (fail_during_authentication) {
      fake_secure_channel->ChangeStatus(SecureChannel::Status::AUTHENTICATING);
    }
    fake_secure_channel->ChangeStatus(SecureChannel::Status::DISCONNECTED);

    // Iterate through all pending requests to |remote_device_id|, ensuring that
    // all expected failures have been communicated back to the client.
    size_t initiator_failures_index =
        num_nearby_initiator_failures_before_call +
        num_initiator_attempts_canceled_from_disconnection;
    for (const auto& pair :
         remote_device_id_to_id_pairs_map_[remote_device_id]) {
      // Stop with a test failure rather than reading past the end of the
      // vector if fewer failures were reported than expected.
      ASSERT_LT(initiator_failures_index, nearby_initiator_failures_.size());
      EXPECT_EQ(pair,
                nearby_initiator_failures_[initiator_failures_index].first);
      EXPECT_EQ(fail_during_authentication
                    ? NearbyInitiatorFailureType::kAuthenticationError
                    : NearbyInitiatorFailureType::kConnectivityError,
                nearby_initiator_failures_[initiator_failures_index].second);
      ++initiator_failures_index;
    }
    EXPECT_EQ(initiator_failures_index, nearby_initiator_failures_.size());

    // All requests which were paused during the connection should have started
    // back up again, since the connection became disconnected.
    for (const auto& pair :
         remote_device_id_to_id_pairs_map_[remote_device_id]) {
      EXPECT_TRUE(fake_ble_scanner_->HasScanRequest(
          ConnectionAttemptDetails(pair, ConnectionMedium::kNearbyConnections,
                                   ConnectionRole::kInitiatorRole)));
    }
  }

  // Drives |fake_secure_channel| through the connection and authentication
  // states up to AUTHENTICATED and verifies that exactly one success callback
  // was delivered as a result.
  void SimulateSecureChannelAuthentication(
      const std::string& remote_device_id,
      FakeSecureChannelConnection* fake_secure_channel) {
    fake_authenticated_channel_factory_->SetExpectationsForNextCall(
        fake_secure_channel);

    size_t num_success_callbacks_before_call = successful_connections_.size();

    fake_secure_channel->ChangeNearbyConnectionState(
        mojom::NearbyConnectionStep::
            kWaitingForConnectionToBeAcceptedByRemoteDeviceStarted,
        mojom::NearbyConnectionStepResult::kSuccess);
    fake_secure_channel->ChangeStatus(SecureChannel::Status::CONNECTED);
    fake_secure_channel->ChangeStatus(SecureChannel::Status::AUTHENTICATING);
    fake_secure_channel->ChangeSecureChannelAuthenticationState(
        mojom::SecureChannelState::kValidatedResponderAuth);
    fake_secure_channel->ChangeStatus(SecureChannel::Status::AUTHENTICATED);

    // Verify that the callback was made. Verification that the provided
    // DeviceIdPair was correct occurs in OnConnectionSuccess().
    EXPECT_EQ(num_success_callbacks_before_call + 1u,
              successful_connections_.size());

    // For all remaining requests, verify that they were added back.
    for (const auto& pair :
         remote_device_id_to_id_pairs_map_[remote_device_id]) {
      EXPECT_TRUE(fake_ble_scanner_->HasScanRequest(
          ConnectionAttemptDetails(pair, ConnectionMedium::kNearbyConnections,
                                   ConnectionRole::kInitiatorRole)));
    }
  }

  // Returns whether |fake_secure_channel| was passed to the fake disconnector.
  bool WasChannelHandledByDisconnector(
      FakeSecureChannelConnection* fake_secure_channel) {
    return fake_secure_channel_disconnector_->WasChannelHandled(
        fake_secure_channel);
  }

  base::test::TaskEnvironment task_environment_;

  // Fake remote devices generated for this fixture (kNumTestDevices entries).
  const multidevice::RemoteDeviceRefList& test_devices() {
    return test_devices_;
  }

 private:
  // Success callback bound in AttemptNearbyInitiatorConnection(). Stores the
  // authenticated channel and cancels every other attempt still tracked for
  // the same remote device.
  void OnConnectionSuccess(
      const DeviceIdPair& device_id_pair,
      std::unique_ptr<AuthenticatedChannel> authenticated_channel) {
    successful_connections_.emplace_back(device_id_pair,
                                         std::move(authenticated_channel));

    // The request which received the success callback is automatically removed
    // by NearbyConnectionManager, so it no longer needs to be tracked.
    remote_device_id_to_id_pairs_map_[device_id_pair.remote_device_id()].erase(
        device_id_pair);

    // Make a copy of the entries which should be canceled. This is required
    // because the Cancel*() calls below end up removing entries from
    // |remote_device_id_to_id_pairs_map_|, which can cause access to deleted
    // memory.
    base::flat_set<DeviceIdPair> to_cancel =
        remote_device_id_to_id_pairs_map_[device_id_pair.remote_device_id()];

    for (const auto& pair : to_cancel)
      CancelNearbyInitiatorConnectionAttempt(pair);
  }

  // Failure callback bound in AttemptNearbyInitiatorConnection(). Records the
  // failure and, if |should_cancel_attempt_on_failure| is true, cancels the
  // attempt from within the callback.
  void OnNearbyInitiatorFailure(const DeviceIdPair& device_id_pair,
                                bool should_cancel_attempt_on_failure,
                                NearbyInitiatorFailureType failure_type) {
    nearby_initiator_failures_.emplace_back(device_id_pair, failure_type);
    if (!should_cancel_attempt_on_failure)
      return;

    // Make a copy of the pair before canceling the attempt, since the reference
    // points to memory owned by |manager_| which will be deleted.
    DeviceIdPair device_id_pair_copy = device_id_pair;
    CancelNearbyInitiatorConnectionAttempt(device_id_pair_copy);
  }

  // Records the latest BLE discovery result reported for |device_id_pair|.
  // |error_code| is not recorded.
  void OnBleDiscoveryStateChanged(
      const DeviceIdPair& device_id_pair,
      mojom::DiscoveryResult result,
      std::optional<mojom::DiscoveryErrorCode> error_code) {
    device_discovery_results_[device_id_pair] = result;
  }

  // Records the latest Nearby connection step reported for |device_id_pair|.
  // |result| is not recorded.
  void OnNearbyConnectionStateChanged(
      const DeviceIdPair& device_id_pair,
      mojom::NearbyConnectionStep nearby_connection_step,
      mojom::NearbyConnectionStepResult result) {
    device_nearby_connection_states_[device_id_pair] = nearby_connection_step;
  }

  // Records the latest SecureChannel authentication state reported for
  // |device_id_pair|.
  void OnSecureChannelAuthenticationStateChanged(
      const DeviceIdPair& device_id_pair,
      mojom::SecureChannelState secure_channel_state) {
    device_secure_channel_states_[device_id_pair] = secure_channel_state;
  }

  // Starts tracking |device_id_pair| under its remote device ID.
  void SetInRemoteDeviceIdToMetadataMap(const DeviceIdPair& device_id_pair) {
    remote_device_id_to_id_pairs_map_[device_id_pair.remote_device_id()].insert(
        device_id_pair);
  }

  // Stops tracking |device_id_pair|. The pair is expected to be tracked.
  void RemoveFromRemoteDeviceIdToMetadataMap(
      const DeviceIdPair& device_id_pair) {
    base::flat_set<DeviceIdPair>& set_for_remote_device =
        remote_device_id_to_id_pairs_map_[device_id_pair.remote_device_id()];

    // erase() by key returns the number of removed elements; zero means the
    // pair was never tracked, which indicates a bug in the test.
    if (set_for_remote_device.erase(device_id_pair) == 0u)
      NOTREACHED_IN_MIGRATION();
  }

  const multidevice::RemoteDeviceRefList test_devices_;

  // Attempts currently tracked by the fixture, keyed by remote device ID.
  base::flat_map<std::string, base::flat_set<DeviceIdPair>>
      remote_device_id_to_id_pairs_map_;
  // Latest values received through the state-change callbacks, per attempt.
  base::flat_map<DeviceIdPair, mojom::DiscoveryResult>
      device_discovery_results_;
  base::flat_map<DeviceIdPair, mojom::NearbyConnectionStep>
      device_nearby_connection_states_;
  base::flat_map<DeviceIdPair, mojom::SecureChannelState>
      device_secure_channel_states_;
  // Success and failure callbacks received so far, in invocation order.
  std::vector<std::pair<DeviceIdPair, std::unique_ptr<AuthenticatedChannel>>>
      successful_connections_;
  std::vector<std::pair<DeviceIdPair, NearbyInitiatorFailureType>>
      nearby_initiator_failures_;

  std::unique_ptr<FakeNearbyConnectionFactory> fake_nearby_connection_factory_;
  std::unique_ptr<FakeSecureChannelFactory> fake_secure_channel_factory_;
  std::unique_ptr<FakeAuthenticatedChannelFactory>
      fake_authenticated_channel_factory_;

  std::unique_ptr<FakeBleScanner> fake_ble_scanner_;
  std::unique_ptr<FakeSecureChannelDisconnector>
      fake_secure_channel_disconnector_;
  std::unique_ptr<FakeNearbyConnector> fake_nearby_connector_;

  std::unique_ptr<NearbyConnectionManager> manager_;
};

// Starts an initiator attempt and cancels it before any connection is made.
TEST_F(SecureChannelNearbyConnectionManagerImplTest,
       AttemptAndCancelWithoutConnection) {
  const multidevice::RemoteDeviceRefList& devices = test_devices();
  const DeviceIdPair id_pair(devices[1].GetDeviceId(),
                             devices[0].GetDeviceId());

  // The attempt should register a scan request; the cancellation helper
  // verifies that the request is removed again.
  AttemptNearbyInitiatorConnection(id_pair,
                                   /*expected_to_add_request=*/true,
                                   /*should_cancel_attempt_on_failure=*/false);
  CancelNearbyInitiatorConnectionAttempt(id_pair);
}

// Starts an initiator attempt and then reports a BLE discovery failure for it.
TEST_F(SecureChannelNearbyConnectionManagerImplTest,
       AttemptAndDiscoveryFailed) {
  const multidevice::RemoteDeviceRefList& devices = test_devices();
  const DeviceIdPair id_pair(devices[1].GetDeviceId(),
                             devices[0].GetDeviceId());

  AttemptNearbyInitiatorConnection(id_pair,
                                   /*expected_to_add_request=*/true,
                                   /*should_cancel_attempt_on_failure=*/false);

  // The helper verifies that the failure reaches the discovery callback.
  SimulateBleDisvoceryFailed(id_pair);
}

// Establishes a connection which then disconnects while authenticating; the
// attempt is canceled explicitly after the failure has been reported.
TEST_F(SecureChannelNearbyConnectionManagerImplTest,
       StartConnectionThenDisconnect_CancelAfter) {
  const multidevice::RemoteDeviceRefList& devices = test_devices();
  const DeviceIdPair id_pair(devices[1].GetDeviceId(),
                             devices[0].GetDeviceId());

  AttemptNearbyInitiatorConnection(id_pair,
                                   /*expected_to_add_request=*/true,
                                   /*should_cancel_attempt_on_failure=*/false);

  FakeSecureChannelConnection* const secure_channel =
      SimulateConnectionEstablished(devices[1]);
  SimulateSecureChannelDisconnection(id_pair.remote_device_id(),
                                     /*fail_during_authentication=*/true,
                                     secure_channel);

  CancelNearbyInitiatorConnectionAttempt(id_pair);
}

// Establishes a connection which then disconnects while authenticating; the
// attempt cancels itself from within the failure callback.
TEST_F(SecureChannelNearbyConnectionManagerImplTest,
       StartConnectionThenDisconnect_CancelInCallback) {
  const multidevice::RemoteDeviceRefList& devices = test_devices();
  const DeviceIdPair id_pair(devices[1].GetDeviceId(),
                             devices[0].GetDeviceId());

  AttemptNearbyInitiatorConnection(id_pair,
                                   /*expected_to_add_request=*/true,
                                   /*should_cancel_attempt_on_failure=*/true);

  FakeSecureChannelConnection* const secure_channel =
      SimulateConnectionEstablished(devices[1]);

  // One failure callback is expected from the attempt which cancels itself.
  SimulateSecureChannelDisconnection(
      id_pair.remote_device_id(),
      /*fail_during_authentication=*/true, secure_channel,
      /*num_initiator_attempts_canceled_from_disconnection=*/1u);
}

// Establishes a connection and authenticates it, which should deliver a
// success callback.
TEST_F(SecureChannelNearbyConnectionManagerImplTest, SuccessfulConnection) {
  const multidevice::RemoteDeviceRefList& devices = test_devices();
  const DeviceIdPair id_pair(devices[1].GetDeviceId(),
                             devices[0].GetDeviceId());

  AttemptNearbyInitiatorConnection(id_pair,
                                   /*expected_to_add_request=*/true,
                                   /*should_cancel_attempt_on_failure=*/true);

  FakeSecureChannelConnection* const secure_channel =
      SimulateConnectionEstablished(devices[1]);
  SimulateSecureChannelAuthentication(id_pair.remote_device_id(),
                                      secure_channel);
}

// Starts attempts to two different remote devices, then connects and
// authenticates each of them in turn.
TEST_F(SecureChannelNearbyConnectionManagerImplTest, TwoSimultaneousAttempts) {
  const multidevice::RemoteDeviceRefList& devices = test_devices();
  const DeviceIdPair first_pair(devices[1].GetDeviceId(),
                                devices[0].GetDeviceId());
  const DeviceIdPair second_pair(devices[2].GetDeviceId(),
                                 devices[0].GetDeviceId());

  // Both attempts are registered before either connection is established.
  AttemptNearbyInitiatorConnection(first_pair,
                                   /*expected_to_add_request=*/true,
                                   /*should_cancel_attempt_on_failure=*/true);
  AttemptNearbyInitiatorConnection(second_pair,
                                   /*expected_to_add_request=*/true,
                                   /*should_cancel_attempt_on_failure=*/true);

  FakeSecureChannelConnection* const first_channel =
      SimulateConnectionEstablished(devices[1]);
  SimulateSecureChannelAuthentication(first_pair.remote_device_id(),
                                      first_channel);

  FakeSecureChannelConnection* const second_channel =
      SimulateConnectionEstablished(devices[2]);
  SimulateSecureChannelAuthentication(second_pair.remote_device_id(),
                                      second_channel);
}

// Cancels an attempt after its SecureChannel has been created; the channel
// should then be handed to the disconnector.
TEST_F(SecureChannelNearbyConnectionManagerImplTest,
       CancelWhileAuthenticating) {
  const multidevice::RemoteDeviceRefList& devices = test_devices();
  const DeviceIdPair id_pair(devices[1].GetDeviceId(),
                             devices[0].GetDeviceId());

  AttemptNearbyInitiatorConnection(id_pair,
                                   /*expected_to_add_request=*/true,
                                   /*should_cancel_attempt_on_failure=*/true);

  FakeSecureChannelConnection* const secure_channel =
      SimulateConnectionEstablished(devices[1]);
  CancelNearbyInitiatorConnectionAttempt(id_pair);

  const bool handled_by_disconnector =
      WasChannelHandledByDisconnector(secure_channel);
  EXPECT_TRUE(handled_by_disconnector);
}

}  // namespace ash::secure_channel