chromium/chromeos/ash/services/secure_channel/secure_channel_unittest.cc

// Copyright 2017 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/services/secure_channel/secure_channel.h"

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "base/functional/bind.h"
#include "base/memory/ptr_util.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/test/bind.h"
#include "chromeos/ash/components/multidevice/fake_secure_message_delegate.h"
#include "chromeos/ash/components/multidevice/remote_device_ref.h"
#include "chromeos/ash/components/multidevice/remote_device_test_util.h"
#include "chromeos/ash/components/multidevice/secure_message_delegate_impl.h"
#include "chromeos/ash/services/secure_channel/fake_authenticator.h"
#include "chromeos/ash/services/secure_channel/fake_connection.h"
#include "chromeos/ash/services/secure_channel/fake_secure_context.h"
#include "chromeos/ash/services/secure_channel/file_transfer_update_callback.h"
#include "chromeos/ash/services/secure_channel/public/mojom/secure_channel_types.mojom.h"
#include "chromeos/ash/services/secure_channel/wire_message.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace ash::secure_channel {

namespace {

struct SecureChannelStatusChange {
  SecureChannelStatusChange(const SecureChannel::Status& old_status,
                            const SecureChannel::Status& new_status)
      : old_status(old_status), new_status(new_status) {}

  SecureChannel::Status old_status;
  SecureChannel::Status new_status;
};

// A (feature, payload) pair, as delivered to
// SecureChannel::Observer::OnMessageReceived().
struct ReceivedMessage {
  // Both arguments are sink parameters: they are taken by value and moved
  // into the members, so callers passing temporaries (e.g. string literals in
  // a braced initializer) avoid an extra copy.
  ReceivedMessage(std::string feature, std::string payload)
      : feature(std::move(feature)), payload(std::move(payload)) {}

  // Name of the feature the message was addressed to.
  std::string feature;

  // Message payload.
  std::string payload;
};

class TestObserver final : public SecureChannel::Observer {
 public:
  explicit TestObserver(SecureChannel* secure_channel)
      : secure_channel_(secure_channel) {}

  const std::vector<SecureChannelStatusChange>& connection_status_changes() {
    return connection_status_changes_;
  }

  const std::vector<ReceivedMessage>& received_messages() {
    return received_messages_;
  }

  const std::vector<int>& sent_sequence_numbers() {
    return sent_sequence_numbers_;
  }

  // SecureChannel::Observer:
  void OnSecureChannelStatusChanged(
      SecureChannel* secure_channel,
      const SecureChannel::Status& old_status,
      const SecureChannel::Status& new_status) override {
    DCHECK(secure_channel == secure_channel_);
    connection_status_changes_.push_back(
        SecureChannelStatusChange(old_status, new_status));
  }

  void OnMessageReceived(SecureChannel* secure_channel,
                         const std::string& feature,
                         const std::string& payload) override {
    DCHECK(secure_channel == secure_channel_);
    received_messages_.push_back(ReceivedMessage(feature, payload));
  }

  void OnMessageSent(SecureChannel* secure_channel,
                     int sequence_number) override {
    DCHECK(secure_channel == secure_channel_);
    sent_sequence_numbers_.push_back(sequence_number);
  }

 private:
  raw_ptr<SecureChannel, DanglingUntriaged> secure_channel_;
  std::vector<SecureChannelStatusChange> connection_status_changes_;
  std::vector<ReceivedMessage> received_messages_;
  std::vector<int> sent_sequence_numbers_;
};

// Observer used in the ObserverDeletesChannel test. It destroys the
// SecureChannel, by resetting the unique_ptr it was given, as soon as it
// receives an OnMessageSent() call. The other callbacks are no-ops.
class DeletingObserver final : public SecureChannel::Observer {
 public:
  // |secure_channel| points at the unique_ptr that owns the channel; this
  // observer stores the pointer but not the unique_ptr itself.
  explicit DeletingObserver(std::unique_ptr<SecureChannel>* secure_channel)
      : secure_channel_(secure_channel) {}

  // SecureChannel::Observer:
  void OnSecureChannelStatusChanged(
      SecureChannel* secure_channel,
      const SecureChannel::Status& old_status,
      const SecureChannel::Status& new_status) override {}

  void OnMessageReceived(SecureChannel* secure_channel,
                         const std::string& feature,
                         const std::string& payload) override {}

  void OnMessageSent(SecureChannel* secure_channel,
                     int sequence_number) override {
    // A gtest expectation (rather than DCHECK) so the check is also active in
    // builds where DCHECKs are compiled out and is reported as a test failure.
    EXPECT_EQ(secure_channel_->get(), secure_channel);
    // Delete the channel when an OnMessageSent() call occurs.
    secure_channel_->reset();
  }

 private:
  // Points at the owner of the channel to delete. Not owned.
  raw_ptr<std::unique_ptr<SecureChannel>> secure_channel_;
};

// DeviceToDeviceAuthenticator::Factory that hands out FakeAuthenticator
// instances and remembers the most recently created one.
class TestAuthenticatorFactory final
    : public DeviceToDeviceAuthenticator::Factory {
 public:
  TestAuthenticatorFactory() = default;

  // Creates a FakeAuthenticator; both arguments are ignored. Ownership goes to
  // the caller, while a non-owning pointer is kept for last_instance().
  std::unique_ptr<Authenticator> CreateInstance(
      Connection* connection,
      std::unique_ptr<multidevice::SecureMessageDelegate>
          secure_message_delegate) override {
    auto authenticator = std::make_unique<FakeAuthenticator>();
    last_instance_ = authenticator.get();
    return authenticator;
  }

  // The authenticator most recently returned by CreateInstance(), or nullptr
  // if none has been created yet.
  Authenticator* last_instance() { return last_instance_; }

 private:
  raw_ptr<Authenticator, DanglingUntriaged> last_instance_ = nullptr;
};

// Returns the first (and only) entry of a one-element test device list.
multidevice::RemoteDeviceRef CreateTestRemoteDevice() {
  const auto test_devices = multidevice::CreateRemoteDeviceRefListForTest(1);
  return test_devices[0];
}

}  // namespace

class SecureChannelConnectionTest : public testing::Test {
 public:
  SecureChannelConnectionTest(const SecureChannelConnectionTest&) = delete;
  SecureChannelConnectionTest& operator=(const SecureChannelConnectionTest&) =
      delete;

 protected:
  SecureChannelConnectionTest() : test_device_(CreateTestRemoteDevice()) {}

  void SetUp() override {
    test_authenticator_factory_ = std::make_unique<TestAuthenticatorFactory>();
    DeviceToDeviceAuthenticator::Factory::SetFactoryForTesting(
        test_authenticator_factory_.get());

    fake_secure_message_delegate_factory_ =
        std::make_unique<multidevice::FakeSecureMessageDelegateFactory>();
    multidevice::SecureMessageDelegateImpl::Factory::SetFactoryForTesting(
        fake_secure_message_delegate_factory_.get());

    fake_secure_context_ = nullptr;

    fake_connection_ =
        new FakeConnection(test_device_, /* should_auto_connect */ false);

    EXPECT_FALSE(fake_connection_->observers().size());
    secure_channel_ = base::WrapUnique(
        new SecureChannel(base::WrapUnique(fake_connection_.get())));
    EXPECT_EQ(static_cast<size_t>(1), fake_connection_->observers().size());
    EXPECT_EQ(secure_channel_.get(), fake_connection_->observers()[0]);

    test_observer_ = std::make_unique<TestObserver>(secure_channel_.get());
    secure_channel_->AddObserver(test_observer_.get());
  }

  void TearDown() override {
    // All state changes should have already been verified. This ensures that
    // no test has missed one.
    VerifyConnectionStateChanges(std::vector<SecureChannelStatusChange>());

    // Same with received messages.
    VerifyReceivedMessages(std::vector<ReceivedMessage>());

    // Same with messages being sent.
    if (secure_channel_)
      VerifyNoMessageBeingSent();

    multidevice::SecureMessageDelegateImpl::Factory::SetFactoryForTesting(
        nullptr);
  }

  void VerifyConnectionStateChanges(
      const std::vector<SecureChannelStatusChange>& expected_changes) {
    verified_status_changes_.insert(verified_status_changes_.end(),
                                    expected_changes.begin(),
                                    expected_changes.end());

    ASSERT_EQ(verified_status_changes_.size(),
              test_observer_->connection_status_changes().size());

    for (size_t i = 0; i < verified_status_changes_.size(); i++) {
      EXPECT_EQ(verified_status_changes_[i].old_status,
                test_observer_->connection_status_changes()[i].old_status);
      EXPECT_EQ(verified_status_changes_[i].new_status,
                test_observer_->connection_status_changes()[i].new_status);
    }
  }

  void VerifyReceivedMessages(
      const std::vector<ReceivedMessage>& expected_messages) {
    verified_received_messages_.insert(verified_received_messages_.end(),
                                       expected_messages.begin(),
                                       expected_messages.end());

    ASSERT_EQ(verified_received_messages_.size(),
              test_observer_->received_messages().size());

    for (size_t i = 0; i < verified_received_messages_.size(); i++) {
      EXPECT_EQ(verified_received_messages_[i].feature,
                test_observer_->received_messages()[i].feature);
      EXPECT_EQ(verified_received_messages_[i].payload,
                test_observer_->received_messages()[i].payload);
    }
  }

  void FailAuthentication(Authenticator::Result result) {
    ASSERT_NE(result, Authenticator::Result::SUCCESS);

    FakeAuthenticator* authenticator = static_cast<FakeAuthenticator*>(
        test_authenticator_factory_->last_instance());
    authenticator->last_callback().Run(result, nullptr);
  }

  void AuthenticateSuccessfully() {
    FakeAuthenticator* authenticator = static_cast<FakeAuthenticator*>(
        test_authenticator_factory_->last_instance());

    fake_secure_context_ = new FakeSecureContext();
    authenticator->last_callback().Run(
        Authenticator::Result::SUCCESS,
        base::WrapUnique(fake_secure_context_.get()));
  }

  void ConnectAndAuthenticate() {
    secure_channel_->Initialize();
    VerifyConnectionStateChanges(std::vector<SecureChannelStatusChange>{
        {SecureChannel::Status::DISCONNECTED,
         SecureChannel::Status::CONNECTING}});

    fake_connection_->CompleteInProgressConnection(/* success */ true);
    VerifyConnectionStateChanges(std::vector<SecureChannelStatusChange>{
        {SecureChannel::Status::CONNECTING, SecureChannel::Status::CONNECTED},
        {SecureChannel::Status::CONNECTED,
         SecureChannel::Status::AUTHENTICATING}});

    AuthenticateSuccessfully();
    VerifyConnectionStateChanges(std::vector<SecureChannelStatusChange>{
        {SecureChannel::Status::AUTHENTICATING,
         SecureChannel::Status::AUTHENTICATED}});
  }

  // Starts sending the message and returns the sequence number.
  int StartSendingMessage(const std::string& feature,
                          const std::string& payload) {
    int sequence_number = secure_channel_->SendMessage(feature, payload);
    VerifyMessageBeingSent(feature, payload);
    return sequence_number;
  }

  void FinishSendingMessage(int sequence_number, bool success) {
    std::vector<int> sent_sequence_numbers_before_send =
        test_observer_->sent_sequence_numbers();

    fake_connection_->FinishSendingMessageWithSuccess(success);

    if (success) {
      std::vector<int> sent_sequence_numbers_after_send =
          test_observer_->sent_sequence_numbers();
      EXPECT_EQ(sent_sequence_numbers_before_send.size() + 1u,
                sent_sequence_numbers_after_send.size());
      EXPECT_EQ(sequence_number, sent_sequence_numbers_after_send.back());
    }
  }

  void StartAndFinishSendingMessage(const std::string& feature,
                                    const std::string& payload,
                                    bool success) {
    int sequence_number = StartSendingMessage(feature, payload);
    FinishSendingMessage(sequence_number, success);
  }

  void VerifyNoMessageBeingSent() {
    EXPECT_FALSE(fake_connection_->current_message());
  }

  void VerifyMessageBeingSent(const std::string& feature,
                              const std::string& payload) {
    WireMessage* message_being_sent = fake_connection_->current_message();
    // Note that despite the fact that |Encode()| has an asynchronous interface,
    // the implementation will call |VerifyWireMessageContents()| synchronously.
    fake_secure_context_->Encode(
        payload,
        base::BindOnce(&SecureChannelConnectionTest::VerifyWireMessageContents,
                       weak_ptr_factory_.GetWeakPtr(), message_being_sent,
                       feature));
  }

  void VerifyWireMessageContents(WireMessage* wire_message,
                                 const std::string& expected_feature,
                                 const std::string& expected_payload) {
    EXPECT_EQ(expected_feature, wire_message->feature());
    EXPECT_EQ(expected_payload, wire_message->payload());
  }

  void VerifyRssi(std::optional<int32_t> expected_rssi) {
    fake_connection_->set_rssi_to_return(expected_rssi);

    secure_channel_->GetConnectionRssi(
        base::BindOnce(&SecureChannelConnectionTest::OnConnectionRssi,
                       base::Unretained(this)));

    std::optional<int32_t> rssi = rssi_;
    rssi_.reset();

    EXPECT_EQ(expected_rssi, rssi);
  }

  void OnConnectionRssi(std::optional<int32_t> rssi) { rssi_ = rssi; }

  // Owned by secure_channel_.
  raw_ptr<FakeConnection, DanglingUntriaged> fake_connection_;

  std::unique_ptr<multidevice::FakeSecureMessageDelegateFactory>
      fake_secure_message_delegate_factory_;

  // Owned by secure_channel_ once authentication has completed successfully.
  raw_ptr<FakeSecureContext, DanglingUntriaged> fake_secure_context_;

  std::vector<SecureChannelStatusChange> verified_status_changes_;

  std::vector<ReceivedMessage> verified_received_messages_;

  std::unique_ptr<SecureChannel> secure_channel_;

  std::unique_ptr<TestObserver> test_observer_;

  std::unique_ptr<TestAuthenticatorFactory> test_authenticator_factory_;

  const multidevice::RemoteDeviceRef test_device_;

  std::optional<int32_t> rssi_;

  base::WeakPtrFactory<SecureChannelConnectionTest> weak_ptr_factory_{this};
};

// A connection attempt that fails takes the channel from CONNECTING straight
// back to DISCONNECTED.
TEST_F(SecureChannelConnectionTest, ConnectionAttemptFails) {
  using Status = SecureChannel::Status;
  using StatusChanges = std::vector<SecureChannelStatusChange>;

  secure_channel_->Initialize();
  VerifyConnectionStateChanges(
      StatusChanges{{Status::DISCONNECTED, Status::CONNECTING}});

  fake_connection_->CompleteInProgressConnection(/* success */ false);
  VerifyConnectionStateChanges(
      StatusChanges{{Status::CONNECTING, Status::DISCONNECTED}});
}

// Disconnecting the underlying connection while still CONNECTING is reported
// as a single CONNECTING -> DISCONNECTED change.
TEST_F(SecureChannelConnectionTest, DisconnectBeforeAuthentication) {
  using Status = SecureChannel::Status;
  using StatusChanges = std::vector<SecureChannelStatusChange>;

  secure_channel_->Initialize();
  VerifyConnectionStateChanges(
      StatusChanges{{Status::DISCONNECTED, Status::CONNECTING}});

  fake_connection_->Disconnect();
  VerifyConnectionStateChanges(
      StatusChanges{{Status::CONNECTING, Status::DISCONNECTED}});
}

// An authentication attempt that ends with Result::DISCONNECTED makes the
// channel go through DISCONNECTING and end up DISCONNECTED.
TEST_F(SecureChannelConnectionTest, AuthenticationFails_Disconnect) {
  using Status = SecureChannel::Status;
  using StatusChanges = std::vector<SecureChannelStatusChange>;

  secure_channel_->Initialize();
  VerifyConnectionStateChanges(
      StatusChanges{{Status::DISCONNECTED, Status::CONNECTING}});

  fake_connection_->CompleteInProgressConnection(/* success */ true);
  VerifyConnectionStateChanges(
      StatusChanges{{Status::CONNECTING, Status::CONNECTED},
                    {Status::CONNECTED, Status::AUTHENTICATING}});

  FailAuthentication(Authenticator::Result::DISCONNECTED);
  VerifyConnectionStateChanges(
      StatusChanges{{Status::AUTHENTICATING, Status::DISCONNECTING},
                    {Status::DISCONNECTING, Status::DISCONNECTED}});
}

// An authentication attempt that ends with Result::FAILURE makes the channel
// go through DISCONNECTING and end up DISCONNECTED.
TEST_F(SecureChannelConnectionTest, AuthenticationFails_Failure) {
  using Status = SecureChannel::Status;
  using StatusChanges = std::vector<SecureChannelStatusChange>;

  secure_channel_->Initialize();
  VerifyConnectionStateChanges(
      StatusChanges{{Status::DISCONNECTED, Status::CONNECTING}});

  fake_connection_->CompleteInProgressConnection(/* success */ true);
  VerifyConnectionStateChanges(
      StatusChanges{{Status::CONNECTING, Status::CONNECTED},
                    {Status::CONNECTED, Status::AUTHENTICATING}});

  FailAuthentication(Authenticator::Result::FAILURE);
  VerifyConnectionStateChanges(
      StatusChanges{{Status::AUTHENTICATING, Status::DISCONNECTING},
                    {Status::DISCONNECTING, Status::DISCONNECTED}});
}

// Regression test for crbug.com/765810: receiving an unexpected message while
// authentication is still in progress must not crash, and authentication must
// still be able to complete afterwards.
TEST_F(SecureChannelConnectionTest, ReceiveMessageBeforeAuth) {
  using Status = SecureChannel::Status;
  using StatusChanges = std::vector<SecureChannelStatusChange>;

  secure_channel_->Initialize();
  VerifyConnectionStateChanges(
      StatusChanges{{Status::DISCONNECTED, Status::CONNECTING}});

  fake_connection_->CompleteInProgressConnection(/* success */ true);
  VerifyConnectionStateChanges(
      StatusChanges{{Status::CONNECTING, Status::CONNECTED},
                    {Status::CONNECTED, Status::AUTHENTICATING}});

  // A non-auth message arrives before authentication has finished.
  fake_connection_->ReceiveMessage("feature", "payload, but encoded");

  // Authentication can still be completed.
  AuthenticateSuccessfully();
  VerifyConnectionStateChanges(
      StatusChanges{{Status::AUTHENTICATING, Status::AUTHENTICATED}});
}

// The connection drops while a message is in flight, and the send then
// completes unsuccessfully.
TEST_F(SecureChannelConnectionTest, SendMessage_DisconnectWhileSending) {
  using Status = SecureChannel::Status;
  using StatusChanges = std::vector<SecureChannelStatusChange>;

  ConnectAndAuthenticate();
  const int in_flight_sequence_number =
      StartSendingMessage("feature", "payload");

  fake_connection_->Disconnect();
  VerifyConnectionStateChanges(
      StatusChanges{{Status::AUTHENTICATED, Status::DISCONNECTED}});

  FinishSendingMessage(in_flight_sequence_number, /* success */ false);

  // The failed send must not produce any additional status change.
  VerifyConnectionStateChanges(StatusChanges());
}

// The connection drops while a message is in flight, but the send is then
// reported as successful anyway.
TEST_F(SecureChannelConnectionTest,
       SendMessage_DisconnectWhileSending_ThenSendCompletedOccurs) {
  using Status = SecureChannel::Status;
  using StatusChanges = std::vector<SecureChannelStatusChange>;

  ConnectAndAuthenticate();
  StartSendingMessage("feature", "payload");

  fake_connection_->Disconnect();
  VerifyConnectionStateChanges(
      StatusChanges{{Status::AUTHENTICATED, Status::DISCONNECTED}});

  // A race can cause |SendCompleted()| to report success after the
  // disconnection has already happened; that late completion must be ignored.
  fake_connection_->FinishSendingMessageWithSuccess(true);

  // In particular, it must not produce any additional status change.
  VerifyConnectionStateChanges(StatusChanges());
}

// A send that fails on an authenticated channel causes the channel to
// disconnect.
TEST_F(SecureChannelConnectionTest, SendMessage_Failure) {
  using Status = SecureChannel::Status;
  using StatusChanges = std::vector<SecureChannelStatusChange>;

  ConnectAndAuthenticate();

  const int sequence_number = StartSendingMessage("feature", "payload");
  FinishSendingMessage(sequence_number, /* success */ false);

  VerifyConnectionStateChanges(
      StatusChanges{{Status::AUTHENTICATED, Status::DISCONNECTING},
                    {Status::DISCONNECTING, Status::DISCONNECTED}});
}

// A single message is sent successfully on an authenticated channel.
TEST_F(SecureChannelConnectionTest, SendMessage_Success) {
  ConnectAndAuthenticate();

  const int sequence_number = StartSendingMessage("feature", "payload");
  FinishSendingMessage(sequence_number, /* success */ true);
}

// Two messages are queued back to back; they are expected to go out one at a
// time, in the order they were queued.
TEST_F(SecureChannelConnectionTest, SendMessage_MultipleMessages_Success) {
  ConnectAndAuthenticate();

  // Queue the second message while the first one has not completed yet.
  const int first_sequence_number =
      secure_channel_->SendMessage("feature1", "payload1");
  const int second_sequence_number =
      secure_channel_->SendMessage("feature2", "payload2");

  // Only the first message is in flight at this point.
  VerifyMessageBeingSent("feature1", "payload1");
  FinishSendingMessage(first_sequence_number, /* success */ true);

  // With the first one done, the second message is in flight.
  VerifyMessageBeingSent("feature2", "payload2");
  FinishSendingMessage(second_sequence_number, /* success */ true);
}

// Two messages are queued back to back and the first one fails; the channel is
// expected to disconnect without attempting the second.
TEST_F(SecureChannelConnectionTest, SendMessage_MultipleMessages_FirstFails) {
  using Status = SecureChannel::Status;
  using StatusChanges = std::vector<SecureChannelStatusChange>;

  ConnectAndAuthenticate();

  // Queue the second message while the first one has not completed yet.
  const int first_sequence_number =
      secure_channel_->SendMessage("feature1", "payload1");
  secure_channel_->SendMessage("feature2", "payload2");

  // Only the first message is in flight at this point.
  VerifyMessageBeingSent("feature1", "payload1");

  // Make the first send fail.
  FinishSendingMessage(first_sequence_number, /* success */ false);

  // The failure disconnects the channel.
  VerifyConnectionStateChanges(
      StatusChanges{{Status::AUTHENTICATED, Status::DISCONNECTING},
                    {Status::DISCONNECTING, Status::DISCONNECTED}});

  // The queued second message must not have been started.
  VerifyNoMessageBeingSent();
}

// RegisterPayloadFile() calls are expected to reach the underlying connection
// in order, with their payload ids intact.
TEST_F(SecureChannelConnectionTest, RegisterPayloadFile) {
  ConnectAndAuthenticate();

  // Register one payload with a positive id and one with a negative id.
  secure_channel_->RegisterPayloadFile(
      /*payload_id=*/1234, mojom::PayloadFiles::New(),
      FileTransferUpdateCallback(),
      base::BindLambdaForTesting([&](bool success) { EXPECT_TRUE(success); }));
  secure_channel_->RegisterPayloadFile(
      /*payload_id=*/-5678, mojom::PayloadFiles::New(),
      FileTransferUpdateCallback(),
      base::BindLambdaForTesting([&](bool success) { EXPECT_TRUE(success); }));

  // The fake connection records every registration request it receives.
  const auto& recorded_requests =
      fake_connection_->reigster_payload_file_requests();
  EXPECT_EQ(2ul, recorded_requests.size());
  EXPECT_EQ(1234, recorded_requests.at(0).payload_id);
  EXPECT_EQ(-5678, recorded_requests.at(1).payload_id);
}

// A message received on an authenticated channel is delivered to observers
// with its payload decoded.
TEST_F(SecureChannelConnectionTest, ReceiveMessage) {
  ConnectAndAuthenticate();

  // Note: FakeSecureContext's Encode() function simply adds ", but encoded" to
  // the end of the message, so this is the encoded form of "payload".
  fake_connection_->ReceiveMessage("feature", "payload, but encoded");

  const std::vector<ReceivedMessage> expected_messages{{"feature", "payload"}};
  VerifyReceivedMessages(expected_messages);
}

// Alternates sends and receives: two request/response round trips on an
// authenticated channel.
TEST_F(SecureChannelConnectionTest, SendAndReceiveMessages) {
  using Messages = std::vector<ReceivedMessage>;

  ConnectAndAuthenticate();

  // First round trip.
  StartAndFinishSendingMessage("feature", "request1", /* success */ true);

  // Note: FakeSecureContext's Encode() function simply adds ", but encoded" to
  // the end of the message, so incoming payloads carry that suffix.
  fake_connection_->ReceiveMessage("feature", "response1, but encoded");
  VerifyReceivedMessages(Messages{{"feature", "response1"}});

  // Second round trip.
  StartAndFinishSendingMessage("feature", "request2", /* success */ true);

  fake_connection_->ReceiveMessage("feature", "response2, but encoded");
  VerifyReceivedMessages(Messages{{"feature", "response2"}});
}

// An observer may delete the channel from inside OnMessageSent(). This would
// have caused a crash before the fix for crbug.com/751884.
TEST_F(SecureChannelConnectionTest, ObserverDeletesChannel) {
  // Register an additional observer that resets |secure_channel_| as soon as
  // it receives an OnMessageSent() call.
  auto deleting_observer = std::make_unique<DeletingObserver>(&secure_channel_);
  secure_channel_->AddObserver(deleting_observer.get());

  ConnectAndAuthenticate();

  // A successful send triggers OnMessageSent(), and with it the deletion.
  const int sequence_number = StartSendingMessage("feature", "request1");
  FinishSendingMessage(sequence_number, /* success */ true);

  EXPECT_FALSE(secure_channel_);
}

// GetConnectionRssi() reports whatever RSSI the connection currently returns.
TEST_F(SecureChannelConnectionTest, GetRssi) {
  // Test a few different values.
  for (const int expected_rssi : {-50, -40, -30}) {
    VerifyRssi(expected_rssi);
  }
}

// GetChannelBindingData() returns the value set on the secure context.
TEST_F(SecureChannelConnectionTest, GetChannelBindingData) {
  static constexpr char kChannelBindingData[] = "channel_binding_data";

  ConnectAndAuthenticate();

  fake_secure_context_->set_channel_binding_data(kChannelBindingData);
  EXPECT_EQ(kChannelBindingData, secure_channel_->GetChannelBindingData());
}

}  // namespace ash::secure_channel