chromium/components/device_signals/core/browser/win/win_signals_collector_unittest.cc

// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/device_signals/core/browser/win/win_signals_collector.h"

#include <array>

#include "base/run_loop.h"
#include "base/test/task_environment.h"
#include "base/values.h"
#include "components/device_signals/core/browser/mock_system_signals_service_host.h"
#include "components/device_signals/core/browser/signals_types.h"
#include "components/device_signals/core/common/signals_constants.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

using testing::_;
using testing::Invoke;
using testing::Return;
using testing::StrictMock;

namespace device_signals {

using GetAntiVirusSignalsCallback =
    MockSystemSignalsService::GetAntiVirusSignalsCallback;
using GetHotfixSignalsCallback =
    MockSystemSignalsService::GetHotfixSignalsCallback;

class WinSignalsCollectorTest : public testing::Test {
 protected:
  WinSignalsCollectorTest() : win_collector_(&service_host_) {
    ON_CALL(service_host_, GetService()).WillByDefault(Return(&service_));
  }

  SignalsAggregationRequest CreateRequest(SignalName signal_name) {
    SignalsAggregationRequest request;
    request.signal_names.emplace(signal_name);
    return request;
  }

  base::test::TaskEnvironment task_environment_;

  StrictMock<MockSystemSignalsServiceHost> service_host_;
  StrictMock<MockSystemSignalsService> service_;
  WinSignalsCollector win_collector_;
};

// Test that runs a sanity check on the set of signals supported by this
// collector. Will need to be updated if new signals become supported.
TEST_F(WinSignalsCollectorTest, SupportedSignalNames) {
  const std::array<SignalName, 2> supported_signals{
      {SignalName::kAntiVirus, SignalName::kHotfixes}};

  const auto names_set = win_collector_.GetSupportedSignalNames();

  EXPECT_EQ(names_set.size(), supported_signals.size());
  for (const auto& signal_name : supported_signals) {
    EXPECT_TRUE(names_set.find(signal_name) != names_set.end());
  }
}

// Tests that an unsupported signal is marked as unsupported.
TEST_F(WinSignalsCollectorTest, GetSignal_Unsupported) {
  SignalName signal_name = SignalName::kFileSystemInfo;
  SignalsAggregationResponse response;
  base::RunLoop run_loop;
  win_collector_.GetSignal(signal_name, CreateRequest(signal_name), response,
                           run_loop.QuitClosure());

  run_loop.Run();

  ASSERT_TRUE(response.top_level_error.has_value());
  EXPECT_EQ(response.top_level_error.value(),
            SignalCollectionError::kUnsupported);
}

// Tests that not being able to retrieve a pointer to the SystemSignalsService
// returns an error.
TEST_F(WinSignalsCollectorTest, GetSignal_AV_MissingSystemSignalsService) {
  EXPECT_CALL(service_host_, GetService()).WillOnce(Return(nullptr));

  SignalName signal_name = SignalName::kAntiVirus;
  SignalsAggregationResponse response;
  base::RunLoop run_loop;
  win_collector_.GetSignal(signal_name, CreateRequest(signal_name), response,
                           run_loop.QuitClosure());

  run_loop.Run();

  ASSERT_FALSE(response.top_level_error.has_value());
  ASSERT_TRUE(response.av_signal_response.has_value());
  ASSERT_TRUE(response.av_signal_response->collection_error.has_value());
  EXPECT_EQ(response.av_signal_response->collection_error.value(),
            SignalCollectionError::kMissingSystemService);
}

// Tests that not being able to retrieve a pointer to the SystemSignalsService
// returns an error.
TEST_F(WinSignalsCollectorTest, GetSignal_Hotfix_MissingSystemSignalsService) {
  EXPECT_CALL(service_host_, GetService()).WillOnce(Return(nullptr));

  SignalName signal_name = SignalName::kHotfixes;
  SignalsAggregationResponse response;
  base::RunLoop run_loop;
  win_collector_.GetSignal(signal_name, CreateRequest(signal_name), response,
                           run_loop.QuitClosure());

  run_loop.Run();

  ASSERT_FALSE(response.top_level_error.has_value());
  ASSERT_TRUE(response.hotfix_signal_response.has_value());
  ASSERT_TRUE(response.hotfix_signal_response->collection_error.has_value());
  EXPECT_EQ(response.hotfix_signal_response->collection_error.value(),
            SignalCollectionError::kMissingSystemService);
}

// Tests a successful AV signal retrieval.
TEST_F(WinSignalsCollectorTest, GetSignal_AV) {
  std::vector<AvProduct> av_products;
  av_products.push_back(
      {"AV Product Name", AvProductState::kOn, "some product id"});

  EXPECT_CALL(service_host_, GetService()).Times(1);
  EXPECT_CALL(service_, GetAntiVirusSignals(_))
      .WillOnce(
          Invoke([&av_products](GetAntiVirusSignalsCallback signal_callback) {
            std::move(signal_callback).Run(av_products);
          }));

  SignalName signal_name = SignalName::kAntiVirus;
  SignalsAggregationResponse response;
  base::RunLoop run_loop;
  win_collector_.GetSignal(signal_name, CreateRequest(signal_name), response,
                           run_loop.QuitClosure());

  run_loop.Run();

  EXPECT_FALSE(response.top_level_error.has_value());
  ASSERT_TRUE(response.av_signal_response.has_value());
  EXPECT_FALSE(response.av_signal_response->collection_error.has_value());
  EXPECT_EQ(response.av_signal_response->av_products.size(),
            av_products.size());
  EXPECT_EQ(response.av_signal_response->av_products[0], av_products[0]);
}

// Tests a successful hotfix signal retrieval.
TEST_F(WinSignalsCollectorTest, GetSignal_Hotfix) {
  std::vector<InstalledHotfix> hotfixes;
  hotfixes.push_back({"hotfix id"});

  EXPECT_CALL(service_host_, GetService()).Times(1);
  EXPECT_CALL(service_, GetHotfixSignals(_))
      .WillOnce(Invoke([&hotfixes](GetHotfixSignalsCallback signal_callback) {
        std::move(signal_callback).Run(hotfixes);
      }));

  SignalName signal_name = SignalName::kHotfixes;
  SignalsAggregationResponse response;
  base::RunLoop run_loop;
  win_collector_.GetSignal(signal_name, CreateRequest(signal_name), response,
                           run_loop.QuitClosure());

  run_loop.Run();

  EXPECT_FALSE(response.top_level_error.has_value());
  ASSERT_TRUE(response.hotfix_signal_response.has_value());
  EXPECT_FALSE(response.hotfix_signal_response->collection_error.has_value());
  EXPECT_EQ(response.hotfix_signal_response->hotfixes.size(), hotfixes.size());
  EXPECT_EQ(response.hotfix_signal_response->hotfixes[0], hotfixes[0]);
}

}  // namespace device_signals