chromium/components/safe_browsing/android/remote_database_manager_unittest.cc

// Copyright 2015 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/safe_browsing/android/remote_database_manager.h"

#include <map>
#include <memory>

#include "base/metrics/field_trial.h"
#include "base/metrics/field_trial_params.h"
#include "base/run_loop.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/scoped_feature_list.h"
#include "base/time/time.h"
#include "components/safe_browsing/android/safe_browsing_api_handler_bridge.h"
#include "components/safe_browsing/core/browser/db/database_manager.h"
#include "components/safe_browsing/core/browser/db/v4_protocol_manager_util.h"
#include "components/safe_browsing/core/browser/db/v4_test_util.h"
#include "components/safe_browsing/core/common/features.h"
#include "components/variations/variations_associated_data.h"
#include "content/public/test/browser_task_environment.h"
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
#include "services/network/test/test_url_loader_factory.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace safe_browsing {

namespace {

// Used to override response from SafeBrowsingApiHandlerBridge.
class TestUrlCheckInterceptor : public safe_browsing::UrlCheckInterceptor {
 public:
  TestUrlCheckInterceptor() = default;
  ~TestUrlCheckInterceptor() override = default;

  // Checks the threat type of |url| previously set by
  // |SetSafetyNetThreatTypeForUrl|. It crashes if the threat type of |url| is
  // not set in advance.
  void CheckBySafetyNet(SafeBrowsingApiHandlerBridge::ResponseCallback callback,
                        const GURL& gurl) override {
    std::string url = gurl.spec();
    DCHECK(base::Contains(urls_safetynet_threat_type_, url));
    std::move(callback).Run(urls_safetynet_threat_type_[url], ThreatMetadata());
  }

  // Checks the threat type of |url| previously set by
  // |SetSafeBrowsingThreatTypeForUrl|. It crashes if the threat type of |url|
  // is not set in advance.
  void CheckBySafeBrowsing(
      SafeBrowsingApiHandlerBridge::ResponseCallback callback,
      const GURL& gurl) override {
    std::string url = gurl.spec();
    DCHECK(base::Contains(urls_safebrowsing_threat_type_, url));
    std::move(callback).Run(urls_safebrowsing_threat_type_[url],
                            ThreatMetadata());
  }

  void SetSafetyNetThreatTypeForUrl(const GURL& url, SBThreatType threat_type) {
    urls_safetynet_threat_type_[url.spec()] = threat_type;
  }

  void SetSafeBrowsingThreatTypeForUrl(const GURL& url,
                                       SBThreatType threat_type) {
    urls_safebrowsing_threat_type_[url.spec()] = threat_type;
  }

 private:
  base::flat_map<std::string, SBThreatType> urls_safetynet_threat_type_;
  base::flat_map<std::string, SBThreatType> urls_safebrowsing_threat_type_;
};

// Used to verify the result returned from RemoteDatabaseManager is expected.
class TestClient : public SafeBrowsingDatabaseManager::Client {
 public:
  TestClient(scoped_refptr<RemoteSafeBrowsingDatabaseManager> db,
             const GURL& expected_url,
             SBThreatType expected_threat_type)
      : db_(db),
        expected_url_(expected_url),
        expected_threat_type_(expected_threat_type) {}

  ~TestClient() override { db_->CancelCheck(this); }

  void OnCheckBrowseUrlResult(const GURL& url,
                              SBThreatType threat_type,
                              const ThreatMetadata& metadata) override {
    EXPECT_EQ(expected_url_, url);
    EXPECT_EQ(expected_threat_type_, threat_type);
    is_callback_called_ = true;
  }

  void OnCheckDownloadUrlResult(const std::vector<GURL>& url_chain,
                                SBThreatType threat_type) override {
    EXPECT_EQ(expected_threat_type_, threat_type);
    is_callback_called_ = true;
  }

  bool IsCallbackCalled() { return is_callback_called_; }

 private:
  scoped_refptr<RemoteSafeBrowsingDatabaseManager> db_;
  GURL expected_url_;
  SBThreatType expected_threat_type_;
  bool is_callback_called_ = false;
};

}  // namespace

class RemoteDatabaseManagerTest : public testing::Test {
 protected:
  using enum SBThreatType;

  RemoteDatabaseManagerTest() {}

  void SetUp() override {
    test_shared_loader_factory_ =
        base::MakeRefCounted<network::WeakWrapperSharedURLLoaderFactory>(
            &test_url_loader_factory_);
    db_ = new RemoteSafeBrowsingDatabaseManager();
    db_->StartOnUIThread(test_shared_loader_factory_,
                         GetTestV4ProtocolConfig());

    url_interceptor_ = std::make_unique<TestUrlCheckInterceptor>();
    SafeBrowsingApiHandlerBridge::GetInstance().SetInterceptorForTesting(
        url_interceptor_.get());
  }

  void TearDown() override {
    db_->StopOnUIThread(/*shutdown=*/false);
    db_ = nullptr;
  }

  content::BrowserTaskEnvironment task_environment_;
  std::unique_ptr<TestUrlCheckInterceptor> url_interceptor_;
  network::TestURLLoaderFactory test_url_loader_factory_;
  scoped_refptr<network::SharedURLLoaderFactory> test_shared_loader_factory_;
  scoped_refptr<RemoteSafeBrowsingDatabaseManager> db_;
  base::HistogramTester histogram_tester_;
};

TEST_F(RemoteDatabaseManagerTest, CheckBrowseUrl_HashDatabase) {
  GURL url("https://example.com");
  url_interceptor_->SetSafeBrowsingThreatTypeForUrl(
      url, SB_THREAT_TYPE_URL_PHISHING);
  TestClient client(db_, /*expected_url=*/url,
                    /*expected_threat_type=*/SB_THREAT_TYPE_URL_PHISHING);

  db_->CheckBrowseUrl(url, {SB_THREAT_TYPE_URL_PHISHING}, &client,
                      CheckBrowseUrlType::kHashDatabase);

  task_environment_.RunUntilIdle();
  EXPECT_TRUE(client.IsCallbackCalled());
  histogram_tester_.ExpectUniqueSample("SB2.RemoteCall.CanCheckUrl",
                                       /*sample=*/true,
                                       /*expected_bucket_count=*/1);
  histogram_tester_.ExpectUniqueSample(
      "SB2.RemoteCall.CanCheckUrl.HashDatabase",
      /*sample=*/true,
      /*expected_bucket_count=*/1);
  histogram_tester_.ExpectTotalCount("SB2.RemoteCall.CanCheckUrl.HashRealTime",
                                     /*expected_count=*/0);
}

TEST_F(RemoteDatabaseManagerTest, CheckBrowseUrl_HashRealtime) {
  GURL url("https://example.com");
  url_interceptor_->SetSafeBrowsingThreatTypeForUrl(
      url, SB_THREAT_TYPE_URL_PHISHING);
  TestClient client(db_, /*expected_url=*/url,
                    /*expected_threat_type=*/SB_THREAT_TYPE_URL_PHISHING);

  db_->CheckBrowseUrl(url, {SB_THREAT_TYPE_URL_PHISHING}, &client,
                      CheckBrowseUrlType::kHashRealTime);

  task_environment_.RunUntilIdle();
  EXPECT_TRUE(client.IsCallbackCalled());
  histogram_tester_.ExpectUniqueSample("SB2.RemoteCall.CanCheckUrl",
                                       /*sample=*/true,
                                       /*expected_bucket_count=*/1);
  histogram_tester_.ExpectUniqueSample(
      "SB2.RemoteCall.CanCheckUrl.HashRealTime",
      /*sample=*/true,
      /*expected_bucket_count=*/1);
  histogram_tester_.ExpectTotalCount("SB2.RemoteCall.CanCheckUrl.HashDatabase",
                                     /*expected_count=*/0);
}

TEST_F(RemoteDatabaseManagerTest, ThreatSource) {
  base::test::ScopedFeatureList scoped_feature_list;
  EXPECT_EQ(ThreatSource::ANDROID_SAFEBROWSING,
            db_->GetBrowseUrlThreatSource(CheckBrowseUrlType::kHashDatabase));
  EXPECT_EQ(ThreatSource::ANDROID_SAFEBROWSING_REAL_TIME,
            db_->GetBrowseUrlThreatSource(CheckBrowseUrlType::kHashRealTime));
  EXPECT_EQ(ThreatSource::REMOTE, db_->GetNonBrowseUrlThreatSource());
}

TEST_F(RemoteDatabaseManagerTest, CheckDownloadUrl) {
  GURL url("https://example.com");
  GURL referrer_url("https://unrelated.com");
  url_interceptor_->SetSafeBrowsingThreatTypeForUrl(url,
                                                    SB_THREAT_TYPE_URL_MALWARE);
  url_interceptor_->SetSafeBrowsingThreatTypeForUrl(referrer_url,
                                                    SB_THREAT_TYPE_SAFE);

  TestClient client(db_, /*expected_url=*/url,
                    /*expected_threat_type=*/SB_THREAT_TYPE_URL_MALWARE);

  db_->CheckDownloadUrl({GURL("https://unrelated.com/"), url}, &client);

  task_environment_.RunUntilIdle();
  EXPECT_TRUE(client.IsCallbackCalled());
}

}  // namespace safe_browsing