chromium/chromeos/ash/services/secure_channel/device_to_device_authenticator_unittest.cc

// Copyright 2015 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/services/secure_channel/device_to_device_authenticator.h"

#include <memory>
#include <utility>
#include <vector>

#include "base/base64url.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/memory/ptr_util.h"
#include "base/memory/raw_ptr.h"
#include "base/rand_util.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "base/timer/mock_timer.h"
#include "chromeos/ash/components/multidevice/fake_secure_message_delegate.h"
#include "chromeos/ash/components/multidevice/remote_device_test_util.h"
#include "chromeos/ash/components/multidevice/secure_message_delegate.h"
#include "chromeos/ash/services/secure_channel/authenticator.h"
#include "chromeos/ash/services/secure_channel/connection.h"
#include "chromeos/ash/services/secure_channel/device_to_device_responder_operations.h"
#include "chromeos/ash/services/secure_channel/file_transfer_update_callback.h"
#include "chromeos/ash/services/secure_channel/public/mojom/secure_channel_types.mojom.h"
#include "chromeos/ash/services/secure_channel/secure_context.h"
#include "chromeos/ash/services/secure_channel/session_keys.h"
#include "chromeos/ash/services/secure_channel/wire_message.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace ash::secure_channel {

namespace {

// The initiator's session public key in padded base64url form. Note that this
// is actually a serialized proto.
constexpr char kInitiatorSessionPublicKeyBase64[] =
    "CAESRQogOlH8DgPMQu7eAt-b6yoTXcazG8mAl6SPC5Ds-LTULIcSIQDZDMqsoYRO4tNMej1FB"
    "El1sTiTiVDqrcGq-CkYCzDThw==";

// The responder's session public key in padded base64url form. Note that this
// is actually a serialized proto.
constexpr char kResponderSessionPublicKeyBase64[] =
    "CAESRgohAN9QYU5HySO14Gi9PDIClacBnC0C8wqPwXsNHUNG_vXlEiEAggzU80ZOd9DWuCBdp"
    "6bzpGcC-oj1yrwdVCHGg_yeaAQ=";

// Result callback for ValidateHelloMessage(): copies the validation outcome
// and the reported public key into the caller-supplied output locations.
void SaveValidateHelloMessageResult(bool* validated_out,
                                    std::string* public_key_out,
                                    bool validated,
                                    const std::string& public_key) {
  public_key_out->assign(public_key);
  *validated_out = validated;
}

// Connection implementation for testing. Connect()/Disconnect() only update the
// connection status, and every message handed to SendMessageImpl() is recorded
// in an in-memory buffer instead of being transmitted.
class FakeConnection : public Connection {
 public:
  explicit FakeConnection(multidevice::RemoteDeviceRef remote_device)
      : Connection(remote_device) {}

  FakeConnection(const FakeConnection&) = delete;
  FakeConnection& operator=(const FakeConnection&) = delete;

  ~FakeConnection() override = default;

  // Connection:
  void Connect() override { SetStatus(Connection::Status::CONNECTED); }
  void Disconnect() override { SetStatus(Connection::Status::DISCONNECTED); }
  std::string GetDeviceAddress() override { return std::string(); }

  // Exposed so tests can inject bytes as if they came from the remote device.
  using Connection::OnBytesReceived;

  // Drops all recorded outgoing messages.
  void ClearMessageBuffer() { message_buffer_.clear(); }

  // Outgoing messages recorded by SendMessageImpl(), in send order.
  const std::vector<std::unique_ptr<WireMessage>>& message_buffer() const {
    return message_buffer_;
  }

  // When true, sends are still recorded but reported as unsuccessful.
  void set_connection_blocked(bool connection_blocked) {
    connection_blocked_ = connection_blocked;
  }

  bool connection_blocked() const { return connection_blocked_; }

 protected:
  // Connection:
  void SendMessageImpl(std::unique_ptr<WireMessage> message) override {
    // The WireMessage is heap-allocated, so this reference stays valid after
    // ownership moves into |message_buffer_|.
    const WireMessage& message_alias = *message;
    message_buffer_.push_back(std::move(message));
    OnDidSendMessage(message_alias, !connection_blocked_);
  }

  // File payloads are not supported by this fake; registration always fails.
  void RegisterPayloadFileImpl(
      int64_t payload_id,
      mojom::PayloadFilesPtr payload_files,
      FileTransferUpdateCallback file_transfer_update_callback,
      base::OnceCallback<void(bool)> registration_result_callback) override {
    std::move(registration_result_callback).Run(/*success=*/false);
  }

 private:
  std::vector<std::unique_ptr<WireMessage>> message_buffer_;

  bool connection_blocked_ = false;
};

// Test harness around DeviceToDeviceAuthenticator. It overrides CreateTimer()
// to hand out a base::MockOneShotTimer, so tests can fire the timeout manually
// through timer()->Fire().
class DeviceToDeviceAuthenticatorForTest : public DeviceToDeviceAuthenticator {
 public:
  DeviceToDeviceAuthenticatorForTest(
      Connection* connection,
      std::unique_ptr<multidevice::SecureMessageDelegate>
          secure_message_delegate)
      : DeviceToDeviceAuthenticator(connection,
                                    std::move(secure_message_delegate)) {}

  DeviceToDeviceAuthenticatorForTest(
      const DeviceToDeviceAuthenticatorForTest&) = delete;
  DeviceToDeviceAuthenticatorForTest& operator=(
      const DeviceToDeviceAuthenticatorForTest&) = delete;

  ~DeviceToDeviceAuthenticatorForTest() override = default;

  // Returns the mock timer most recently created by CreateTimer(), or null if
  // no timer has been created yet.
  base::MockOneShotTimer* timer() { return timer_; }

 private:
  // DeviceToDeviceAuthenticator:
  std::unique_ptr<base::OneShotTimer> CreateTimer() override {
    auto mock_timer = std::make_unique<base::MockOneShotTimer>();
    timer_ = mock_timer.get();
    return mock_timer;
  }

  // Non-owning: the timer returned by CreateTimer() is owned by the super
  // class.
  raw_ptr<base::MockOneShotTimer, DanglingUntriaged> timer_ = nullptr;

  // Task environment for the test. NOTE(review): as a derived-class member it
  // is constructed after, and destroyed before, the DeviceToDeviceAuthenticator
  // base subobject.
  base::test::SingleThreadTaskEnvironment env_;
};

}  // namespace

class SecureChannelDeviceToDeviceAuthenticatorTest : public testing::Test {
 public:
  SecureChannelDeviceToDeviceAuthenticatorTest()
      : remote_device_(multidevice::CreateRemoteDeviceRefForTest()),
        connection_(remote_device_),
        secure_message_delegate_(new multidevice::FakeSecureMessageDelegate),
        authenticator_(&connection_,
                       base::WrapUnique(secure_message_delegate_.get())) {}

  SecureChannelDeviceToDeviceAuthenticatorTest(
      const SecureChannelDeviceToDeviceAuthenticatorTest&) = delete;
  SecureChannelDeviceToDeviceAuthenticatorTest& operator=(
      const SecureChannelDeviceToDeviceAuthenticatorTest&) = delete;

  ~SecureChannelDeviceToDeviceAuthenticatorTest() override {}

  void SetUp() override {
    // Set up the session asymmetric keys for both the local and remote devices.
    ASSERT_TRUE(
        base::Base64UrlDecode(kInitiatorSessionPublicKeyBase64,
                              base::Base64UrlDecodePolicy::REQUIRE_PADDING,
                              &local_session_public_key_));
    ASSERT_TRUE(
        base::Base64UrlDecode(kResponderSessionPublicKeyBase64,
                              base::Base64UrlDecodePolicy::REQUIRE_PADDING,
                              &remote_session_public_key_));
    remote_session_private_key_ =
        secure_message_delegate_->GetPrivateKeyForPublicKey(
            remote_session_public_key_),

    secure_message_delegate_->set_next_public_key(local_session_public_key_);
    connection_.Connect();

    base::test::TestFuture<const std::string&> future;
    secure_message_delegate_->DeriveKey(remote_session_private_key_,
                                        local_session_public_key_,
                                        future.GetCallback());
    session_symmetric_key_ = future.Take();
  }

  // Begins authentication, and returns the [Hello] message sent from the local
  // device to the remote device.
  std::string BeginAuthentication() {
    authenticator_.Authenticate(base::BindOnce(
        &SecureChannelDeviceToDeviceAuthenticatorTest::OnAuthenticationResult,
        base::Unretained(this)));

    EXPECT_EQ(1u, connection_.message_buffer().size());
    std::string hello_message = connection_.message_buffer()[0]->payload();
    connection_.ClearMessageBuffer();

    bool validated = false;
    std::string local_session_public_key;
    DeviceToDeviceResponderOperations::ValidateHelloMessage(
        hello_message, remote_device_.persistent_symmetric_key(),
        secure_message_delegate_,
        base::BindOnce(&SaveValidateHelloMessageResult, &validated,
                       &local_session_public_key));

    EXPECT_TRUE(validated);
    EXPECT_EQ(local_session_public_key_, local_session_public_key);

    return hello_message;
  }

  // Simulate receiving a valid [Responder Auth] message from the remote device.
  std::string SimulateResponderAuth(const std::string& hello_message) {
    std::string remote_device_private_key =
        secure_message_delegate_->GetPrivateKeyForPublicKey(
            multidevice::kTestRemoteDevicePublicKey);

    base::test::TestFuture<const std::string&> future;
    DeviceToDeviceResponderOperations::CreateResponderAuthMessage(
        hello_message, remote_session_public_key_, remote_session_private_key_,
        remote_device_private_key, remote_device_.persistent_symmetric_key(),
        secure_message_delegate_, future.GetCallback());
    std::string responder_auth_message = future.Take();
    EXPECT_FALSE(responder_auth_message.empty());

    WireMessage wire_message(responder_auth_message,
                             Authenticator::kAuthenticationFeature);
    connection_.OnBytesReceived(wire_message.Serialize());

    return responder_auth_message;
  }

  void OnAuthenticationResult(Authenticator::Result result,
                              std::unique_ptr<SecureContext> secure_context) {
    secure_context_ = std::move(secure_context);
    OnAuthenticationResultProxy(result);
  }

  MOCK_METHOD1(OnAuthenticationResultProxy, void(Authenticator::Result result));

  // Contains information about the remote device.
  const multidevice::RemoteDeviceRef remote_device_;

  // Simulates the connection to the remote device.
  FakeConnection connection_;

  // The SecureMessageDelegate used by the authenticator.
  // Owned by |authenticator_|.
  raw_ptr<multidevice::FakeSecureMessageDelegate, DanglingUntriaged>
      secure_message_delegate_;

  // The DeviceToDeviceAuthenticator under test.
  DeviceToDeviceAuthenticatorForTest authenticator_;

  // The session keys in play during authentication.
  std::string local_session_public_key_;
  std::string remote_session_public_key_;
  std::string remote_session_private_key_;
  std::string session_symmetric_key_;

  // Stores the SecureContext returned after authentication succeeds.
  std::unique_ptr<SecureContext> secure_context_;
};

TEST_F(SecureChannelDeviceToDeviceAuthenticatorTest, AuthenticateSucceeds) {
  // Kick off the protocol and capture the [Hello] message that was sent.
  std::string hello = BeginAuthentication();

  // A valid [Responder Auth] from the remote device completes authentication
  // successfully and produces a SecureContext.
  EXPECT_CALL(*this,
              OnAuthenticationResultProxy(Authenticator::Result::SUCCESS));
  std::string responder_auth = SimulateResponderAuth(hello);
  EXPECT_TRUE(secure_context_);

  // The local device must answer with exactly one [Initiator Auth] message,
  // and the responder-side validation must accept it.
  const auto& sent_messages = connection_.message_buffer();
  ASSERT_EQ(1u, sent_messages.size());
  std::string initiator_auth = sent_messages.front()->payload();

  base::test::TestFuture<bool> validation_result;
  DeviceToDeviceResponderOperations::ValidateInitiatorAuthMessage(
      initiator_auth, SessionKeys(session_symmetric_key_),
      remote_device_.persistent_symmetric_key(), responder_auth,
      secure_message_delegate_, validation_result.GetCallback());
  ASSERT_TRUE(validation_result.Get());
}

TEST_F(SecureChannelDeviceToDeviceAuthenticatorTest, ResponderRejectsHello) {
  // Only the authenticator's state after sending [Hello] matters here, so the
  // returned message itself is not needed.
  BeginAuthentication();

  // If the responder cannot validate the [Hello] message, it essentially sends
  // random bytes back for privacy reasons; authentication must then fail.
  WireMessage wire_message(base::RandBytesAsString(300u),
                           Authenticator::kAuthenticationFeature);
  EXPECT_CALL(*this,
              OnAuthenticationResultProxy(Authenticator::Result::FAILURE));
  connection_.OnBytesReceived(wire_message.Serialize());
  EXPECT_FALSE(secure_context_);
}

TEST_F(SecureChannelDeviceToDeviceAuthenticatorTest, ResponderAuthTimesOut) {
  // Start the authentication protocol; the [Hello] message is not needed.
  BeginAuthentication();
  ASSERT_TRUE(authenticator_.timer());

  // Firing the timer before any [Responder Auth] arrives fails authentication.
  EXPECT_CALL(*this,
              OnAuthenticationResultProxy(Authenticator::Result::FAILURE));
  authenticator_.timer()->Fire();
  EXPECT_FALSE(secure_context_);
}

TEST_F(SecureChannelDeviceToDeviceAuthenticatorTest,
       DisconnectsWaitingForResponderAuth) {
  // Start the authentication protocol; the [Hello] message is not needed.
  BeginAuthentication();

  // Losing the connection while waiting for [Responder Auth] reports
  // DISCONNECTED.
  EXPECT_CALL(*this,
              OnAuthenticationResultProxy(Authenticator::Result::DISCONNECTED));
  connection_.Disconnect();
  EXPECT_FALSE(secure_context_);
}

TEST_F(SecureChannelDeviceToDeviceAuthenticatorTest, NotConnectedInitially) {
  // Authenticating over a connection that is already down reports
  // DISCONNECTED and yields no SecureContext.
  connection_.Disconnect();
  EXPECT_CALL(*this,
              OnAuthenticationResultProxy(Authenticator::Result::DISCONNECTED));

  auto on_result = base::BindOnce(
      &SecureChannelDeviceToDeviceAuthenticatorTest::OnAuthenticationResult,
      base::Unretained(this));
  authenticator_.Authenticate(std::move(on_result));
  EXPECT_FALSE(secure_context_);
}

TEST_F(SecureChannelDeviceToDeviceAuthenticatorTest, FailToSendHello) {
  // With the connection blocked, sending the [Hello] message fails, so the
  // whole authentication fails immediately.
  connection_.set_connection_blocked(true);
  EXPECT_CALL(*this,
              OnAuthenticationResultProxy(Authenticator::Result::FAILURE));

  auto on_result = base::BindOnce(
      &SecureChannelDeviceToDeviceAuthenticatorTest::OnAuthenticationResult,
      base::Unretained(this));
  authenticator_.Authenticate(std::move(on_result));
  EXPECT_FALSE(secure_context_);
}

TEST_F(SecureChannelDeviceToDeviceAuthenticatorTest, FailToSendInitiatorAuth) {
  std::string hello = BeginAuthentication();

  // Block the connection after [Hello] went out, so the [Initiator Auth] sent
  // in reply to a valid [Responder Auth] fails and authentication fails.
  connection_.set_connection_blocked(true);
  EXPECT_CALL(*this,
              OnAuthenticationResultProxy(Authenticator::Result::FAILURE));
  SimulateResponderAuth(hello);
  EXPECT_FALSE(secure_context_);
}

TEST_F(SecureChannelDeviceToDeviceAuthenticatorTest,
       SendMessagesAfterAuthenticationSuccess) {
  std::string hello = BeginAuthentication();
  EXPECT_CALL(*this,
              OnAuthenticationResultProxy(Authenticator::Result::SUCCESS));
  SimulateResponderAuth(hello);

  // Test that the authenticator is properly cleaned up after authentication
  // completes: further sends and received bytes must not crash or produce
  // another result (the EXPECT_CALL above allows exactly one).
  WireMessage incoming(base::RandBytesAsString(300u),
                       Authenticator::kAuthenticationFeature);
  for (int round = 0; round < 2; ++round) {
    connection_.SendMessage(std::make_unique<WireMessage>(
        base::RandBytesAsString(300u), Authenticator::kAuthenticationFeature));
    connection_.OnBytesReceived(incoming.Serialize());
  }
}

}  // namespace ash::secure_channel