chromium/chromeos/ash/services/device_sync/cryptauth_device_notifier_impl_unittest.cc

// Copyright 2019 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/services/device_sync/cryptauth_device_notifier_impl.h"

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "base/containers/flat_set.h"
#include "base/containers/queue.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/memory/raw_ptr.h"
#include "base/no_destructor.h"
#include "base/timer/mock_timer.h"
#include "chromeos/ash/services/device_sync/cryptauth_client.h"
#include "chromeos/ash/services/device_sync/cryptauth_device_notifier.h"
#include "chromeos/ash/services/device_sync/cryptauth_feature_type.h"
#include "chromeos/ash/services/device_sync/cryptauth_key_bundle.h"
#include "chromeos/ash/services/device_sync/mock_cryptauth_client.h"
#include "chromeos/ash/services/device_sync/network_request_error.h"
#include "chromeos/ash/services/device_sync/proto/cryptauth_common.pb.h"
#include "chromeos/ash/services/device_sync/proto/cryptauth_devicesync.pb.h"
#include "chromeos/ash/services/device_sync/proto/cryptauth_v2_test_util.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace ash {

namespace device_sync {

namespace {

// Value returned by every mock CryptAuthClient's GetAccessTokenUsed().
constexpr char kAccessTokenUsed[] = "access token used by CryptAuthClient";

// Builds a ClientMetadata whose only populated field is the invocation reason,
// set to INVOCATION_REASON_UNSPECIFIED.
cryptauthv2::ClientMetadata MakeUnspecifiedClientMetadata() {
  cryptauthv2::ClientMetadata metadata;
  metadata.set_invocation_reason(
      cryptauthv2::ClientMetadata::INVOCATION_REASON_UNSPECIFIED);
  return metadata;
}

// Returns the ClientMetadata used when building the expected request context.
// It is created on first use and held in a base::NoDestructor, so it is never
// destroyed.
const cryptauthv2::ClientMetadata& GetClientMetadata() {
  static const base::NoDestructor<cryptauthv2::ClientMetadata> kMetadata(
      MakeUnspecifiedClientMetadata());
  return *kMetadata;
}

// Returns the RequestContext used to build the expected notify requests. It is
// built from the DeviceSync BetterTogether key bundle name, GetClientMetadata(),
// and the test Instance ID and Instance ID token. It is created on first use
// and held in a base::NoDestructor, so it is never destroyed.
const cryptauthv2::RequestContext& GetRequestContext() {
  static const base::NoDestructor<cryptauthv2::RequestContext> kContext([] {
    const auto key_bundle_name = CryptAuthKeyBundle::KeyBundleNameEnumToString(
        CryptAuthKeyBundle::Name::kDeviceSyncBetterTogether);
    return cryptauthv2::BuildRequestContext(
        key_bundle_name, GetClientMetadata(), cryptauthv2::kTestInstanceId,
        cryptauthv2::kTestInstanceIdToken);
  }());
  return *kContext;
}

// Builds the BatchNotifyGroupDevicesRequest expected for a notification of
// |device_ids| with |target_service| and |feature_type|. It uses the shared
// test request context, and the device IDs are copied in |device_ids|'s
// iteration order.
cryptauthv2::BatchNotifyGroupDevicesRequest BuildNotifyGroupDevicesRequest(
    const base::flat_set<std::string>& device_ids,
    cryptauthv2::TargetService target_service,
    CryptAuthFeatureType feature_type) {
  cryptauthv2::BatchNotifyGroupDevicesRequest request;
  request.mutable_context()->CopyFrom(GetRequestContext());
  *request.mutable_notify_device_ids() = {device_ids.begin(), device_ids.end()};
  request.set_target_service(target_service);
  request.set_feature_type(CryptAuthFeatureTypeToString(feature_type));
  return request;
}

// Request with |device_ids|, Enrollment as the target service, and
// BetterTogether host (enabled) as the feature type.
cryptauthv2::BatchNotifyGroupDevicesRequest
NotifyEnrollmentBetterTogetherHostEnabledRequest(
    const base::flat_set<std::string>& device_ids) {
  return BuildNotifyGroupDevicesRequest(
      device_ids, cryptauthv2::TargetService::ENROLLMENT,
      CryptAuthFeatureType::kBetterTogetherHostEnabled);
}

// Request with |device_ids|, DeviceSync as the target service, and
// MagicTether client (supported) as the feature type.
cryptauthv2::BatchNotifyGroupDevicesRequest
NotifyDeviceSyncMagicTetherSupportedRequest(
    const base::flat_set<std::string>& device_ids) {
  return BuildNotifyGroupDevicesRequest(
      device_ids, cryptauthv2::TargetService::DEVICE_SYNC,
      CryptAuthFeatureType::kMagicTetherClientSupported);
}

}  // namespace

// Fixture that drives CryptAuthDeviceNotifierImpl against mock
// CryptAuthClients. Every BatchNotifyGroupDevices call the notifier makes is
// captured: the request and both of its callbacks are queued. Tests then
// inspect and resolve the calls one at a time, in FIFO order, and finally
// compare the results reported to NotifyDevices() callers.
class DeviceSyncCryptAuthDeviceNotifierImplTest
    : public testing::Test,
      public MockCryptAuthClientFactory::Observer {
 protected:
  // How HandleNextBatchNotifyGroupDevicesRequest() resolves the pending call:
  // - kSucceed runs its success callback.
  // - kFail runs its error callback.
  // - kTimeout fires the notifier's timer.
  enum class RequestAction { kSucceed, kFail, kTimeout };

  DeviceSyncCryptAuthDeviceNotifierImplTest()
      : mock_client_factory_(
            MockCryptAuthClientFactory::MockType::MAKE_NICE_MOCKS) {
    mock_client_factory_.AddObserver(this);
  }

  DeviceSyncCryptAuthDeviceNotifierImplTest(
      const DeviceSyncCryptAuthDeviceNotifierImplTest&) = delete;
  DeviceSyncCryptAuthDeviceNotifierImplTest& operator=(
      const DeviceSyncCryptAuthDeviceNotifierImplTest&) = delete;

  ~DeviceSyncCryptAuthDeviceNotifierImplTest() override {
    mock_client_factory_.RemoveObserver(this);
  }

  // testing::Test:
  void SetUp() override {
    auto mock_timer = std::make_unique<base::MockOneShotTimer>();
    // Keep a non-owning handle; ownership moves into the notifier below.
    mock_timer_ = mock_timer.get();

    device_notifier_ = CryptAuthDeviceNotifierImpl::Factory::Create(
        cryptauthv2::kTestInstanceId, cryptauthv2::kTestInstanceIdToken,
        &mock_client_factory_, std::move(mock_timer));
  }

  // MockCryptAuthClientFactory::Observer:
  // Routes each new client's BatchNotifyGroupDevices() calls into the queues
  // below and stubs the access token it reports.
  void OnCryptAuthClientCreated(MockCryptAuthClient* client) override {
    ON_CALL(*client,
            BatchNotifyGroupDevices(testing::_, testing::_, testing::_))
        .WillByDefault(testing::Invoke(
            this, &DeviceSyncCryptAuthDeviceNotifierImplTest::
                      OnBatchNotifyGroupDevices));

    ON_CALL(*client, GetAccessTokenUsed())
        .WillByDefault(testing::Return(kAccessTokenUsed));
  }

  // Asks the notifier to notify |device_ids|. Whichever completion callback
  // the notifier runs is recorded in |results_|.
  void NotifyDevices(const base::flat_set<std::string>& device_ids,
                     cryptauthv2::TargetService target_service,
                     CryptAuthFeatureType feature_type) {
    device_notifier_->NotifyDevices(
        device_ids, target_service, feature_type,
        base::BindOnce(
            &DeviceSyncCryptAuthDeviceNotifierImplTest::OnNotifyDevicesSuccess,
            base::Unretained(this)),
        base::BindOnce(
            &DeviceSyncCryptAuthDeviceNotifierImplTest::OnNotifyDevicesFailure,
            base::Unretained(this)));
  }

  // Pops the oldest captured BatchNotifyGroupDevices call and checks that its
  // request equals |expected_request|. It then resolves the call according to
  // |request_action|. |error| is required for, and only used by, kFail.
  void HandleNextBatchNotifyGroupDevicesRequest(
      const cryptauthv2::BatchNotifyGroupDevicesRequest& expected_request,
      RequestAction request_action,
      std::optional<NetworkRequestError> error = std::nullopt) {
    ASSERT_FALSE(batch_notify_group_devices_requests_.empty());
    ASSERT_FALSE(batch_notify_group_devices_success_callbacks_.empty());
    ASSERT_FALSE(batch_notify_group_devices_failure_callbacks_.empty());
    // Check preconditions before consuming anything from the queues.
    if (request_action == RequestAction::kFail) {
      ASSERT_TRUE(error);
    }

    cryptauthv2::BatchNotifyGroupDevicesRequest current_request =
        std::move(batch_notify_group_devices_requests_.front());
    batch_notify_group_devices_requests_.pop();

    CryptAuthClient::BatchNotifyGroupDevicesCallback current_success_callback =
        std::move(batch_notify_group_devices_success_callbacks_.front());
    batch_notify_group_devices_success_callbacks_.pop();

    CryptAuthClient::ErrorCallback current_failure_callback =
        std::move(batch_notify_group_devices_failure_callbacks_.front());
    batch_notify_group_devices_failure_callbacks_.pop();

    EXPECT_EQ(expected_request.SerializeAsString(),
              current_request.SerializeAsString());

    switch (request_action) {
      case RequestAction::kSucceed:
        std::move(current_success_callback)
            .Run(cryptauthv2::BatchNotifyGroupDevicesResponse());
        break;
      case RequestAction::kFail:
        std::move(current_failure_callback).Run(*error);
        break;
      case RequestAction::kTimeout:
        // Fail the test if the notifier never armed its timer. Otherwise the
        // timeout path would not actually be exercised.
        ASSERT_TRUE(mock_timer_->IsRunning());
        mock_timer_->Fire();
        break;
    }
  }

  // Checks that every captured call was handled, and that the per-call
  // outcomes match |expected_results| in order. std::nullopt means success.
  void VerifyResults(
      const std::vector<std::optional<NetworkRequestError>>& expected_results) {
    // Verify that all requests were processed.
    EXPECT_TRUE(batch_notify_group_devices_requests_.empty());
    EXPECT_TRUE(batch_notify_group_devices_success_callbacks_.empty());
    EXPECT_TRUE(batch_notify_group_devices_failure_callbacks_.empty());

    EXPECT_EQ(expected_results, results_);
  }

 private:
  // Default action for MockCryptAuthClient::BatchNotifyGroupDevices(): queue
  // the call so that a test can resolve it later.
  void OnBatchNotifyGroupDevices(
      const cryptauthv2::BatchNotifyGroupDevicesRequest& request,
      CryptAuthClient::BatchNotifyGroupDevicesCallback callback,
      CryptAuthClient::ErrorCallback error_callback) {
    batch_notify_group_devices_requests_.push(request);
    batch_notify_group_devices_success_callbacks_.push(std::move(callback));
    batch_notify_group_devices_failure_callbacks_.push(
        std::move(error_callback));
  }

  void OnNotifyDevicesSuccess() { results_.push_back(std::nullopt); }

  void OnNotifyDevicesFailure(NetworkRequestError error) {
    results_.push_back(error);
  }

  base::queue<cryptauthv2::BatchNotifyGroupDevicesRequest>
      batch_notify_group_devices_requests_;
  base::queue<CryptAuthClient::BatchNotifyGroupDevicesCallback>
      batch_notify_group_devices_success_callbacks_;
  base::queue<CryptAuthClient::ErrorCallback>
      batch_notify_group_devices_failure_callbacks_;

  // std::nullopt indicates a success.
  std::vector<std::optional<NetworkRequestError>> results_;

  MockCryptAuthClientFactory mock_client_factory_;
  std::unique_ptr<CryptAuthDeviceNotifier> device_notifier_;

  // Owned by |device_notifier_|. Declared after it so that this pointer is
  // destroyed before the timer it points to. It therefore never dangles.
  raw_ptr<base::MockOneShotTimer> mock_timer_ = nullptr;
};

// Queues four notifications up front, then resolves them one at a time. Checks
// that they are sent sequentially and that each outcome (timeout, failure,
// success) is reported back to the matching caller.
TEST_F(DeviceSyncCryptAuthDeviceNotifierImplTest, Test) {
  // All four are queued before any of them finishes. The notifier should
  // handle them one after another, in this order.
  NotifyDevices({"device_id_1"}, cryptauthv2::TargetService::ENROLLMENT,
                CryptAuthFeatureType::kBetterTogetherHostEnabled);
  NotifyDevices({"device_id_2", "device_id_3"},
                cryptauthv2::TargetService::DEVICE_SYNC,
                CryptAuthFeatureType::kMagicTetherClientSupported);
  NotifyDevices({"device_id_4", "device_id_5"},
                cryptauthv2::TargetService::ENROLLMENT,
                CryptAuthFeatureType::kBetterTogetherHostEnabled);
  NotifyDevices({"device_id_6"}, cryptauthv2::TargetService::DEVICE_SYNC,
                CryptAuthFeatureType::kMagicTetherClientSupported);

  // First call: no response arrives before the timer fires.
  HandleNextBatchNotifyGroupDevicesRequest(
      NotifyEnrollmentBetterTogetherHostEnabledRequest({"device_id_1"}),
      RequestAction::kTimeout);

  // Second call: rejected with "Bad Request".
  HandleNextBatchNotifyGroupDevicesRequest(
      NotifyDeviceSyncMagicTetherSupportedRequest(
          {"device_id_2", "device_id_3"}),
      RequestAction::kFail, NetworkRequestError::kBadRequest);

  // Third and fourth calls: both succeed.
  HandleNextBatchNotifyGroupDevicesRequest(
      NotifyEnrollmentBetterTogetherHostEnabledRequest(
          {"device_id_4", "device_id_5"}),
      RequestAction::kSucceed);
  HandleNextBatchNotifyGroupDevicesRequest(
      NotifyDeviceSyncMagicTetherSupportedRequest({"device_id_6"}),
      RequestAction::kSucceed);

  // The timeout is reported as kUnknown; std::nullopt marks a success.
  VerifyResults({NetworkRequestError::kUnknown,
                 NetworkRequestError::kBadRequest, std::nullopt,
                 std::nullopt});
}

}  // namespace device_sync

}  // namespace ash