chromium/chrome/browser/enterprise/platform_auth/platform_auth_navigation_throttle_unittest.cc

// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/enterprise/platform_auth/platform_auth_navigation_throttle.h"

#include <memory>

#include "base/memory/raw_ptr.h"
#include "base/run_loop.h"
#include "base/task/thread_pool.h"
#include "base/test/bind.h"
#include "base/test/mock_callback.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "base/time/time.h"
#include "base/values.h"
#include "build/build_config.h"
#include "chrome/browser/enterprise/platform_auth/mock_platform_auth_provider.h"
#include "chrome/browser/enterprise/platform_auth/platform_auth_provider_manager.h"
#include "chrome/browser/enterprise/platform_auth/scoped_set_provider_for_testing.h"
#include "chrome/test/base/testing_browser_process.h"
#include "chrome/test/base/testing_profile.h"
#include "components/policy/core/common/policy_pref_names.h"
#include "components/sync_preferences/testing_pref_service_syncable.h"
#include "content/public/browser/navigation_throttle.h"
#include "content/public/browser/web_contents.h"
#include "content/public/test/browser_task_environment.h"
#include "content/public/test/mock_navigation_handle.h"
#include "content/public/test/test_renderer_host.h"
#include "content/public/test/web_contents_tester.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_util.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

using content::NavigationThrottle;
using ::testing::_;
using ::testing::Invoke;
using ::testing::Return;

namespace {

void EnableManager(enterprise_auth::PlatformAuthProviderManager& manager,
                   bool enabled) {
  base::MockCallback<base::OnceClosure> callback;
  base::RunLoop run_loop;
  EXPECT_CALL(callback, Run()).WillOnce([&run_loop]() {
    run_loop.QuitWhenIdle();
  });
  manager.SetEnabled(enabled, callback.Get());
  EXPECT_EQ(manager.IsEnabled(), enabled);
  run_loop.Run();
  EXPECT_EQ(manager.IsEnabled(), enabled);
}

}  // namespace

namespace enterprise_auth {

class PlatformAuthNavigationThrottleTest : public testing::Test {
 public:
  PlatformAuthNavigationThrottleTest() : mock_provider_(owned_provider_.get()) {
    // Set up a default that will clear the raw pointer upon destruction.
    ON_CALL(*mock_provider_, Die()).WillByDefault([this]() {
      this->mock_provider_ = nullptr;
    });
    // Expect the provider to be destroyed at some point.
    EXPECT_CALL(*mock_provider_, Die());
  }

  PlatformAuthProviderManager& manager() {
    return PlatformAuthProviderManager::GetInstance();
  }

  std::unique_ptr<PlatformAuthNavigationThrottle> CreateThrottle(
      content::NavigationHandle* navigation_handle) {
    return std::make_unique<PlatformAuthNavigationThrottle>(navigation_handle);
  }

  content::WebContents* web_contents() const { return web_contents_.get(); }
  content::RenderFrameHost* main_frame() const {
    return web_contents()->GetPrimaryMainFrame();
  }

  std::unique_ptr<PlatformAuthProvider> TakeProvider() {
    return std::move(owned_provider_);
  }

  virtual void SetupOriginFilteringExpectations() {
    EXPECT_CALL(*mock_provider_, SupportsOriginFiltering())
        .WillOnce(::testing::Return(true));
  }

  ::testing::StrictMock<MockPlatformAuthProvider>* mock_provider() {
    return mock_provider_;
  }

 protected:
  void SetUp() override {
    web_contents_ =
        content::WebContentsTester::CreateTestWebContents(&profile_, nullptr);
    SetupOriginFilteringExpectations();
  }

  content::BrowserTaskEnvironment task_environment_{
      base::test::TaskEnvironment::TimeSource::MOCK_TIME};
  content::RenderViewHostTestEnabler rvh_test_enabler_;
  TestingProfile profile_;
  std::unique_ptr<content::WebContents> web_contents_;
  std::unique_ptr<::testing::StrictMock<MockPlatformAuthProvider>>
      owned_provider_{
          std::make_unique<::testing::StrictMock<MockPlatformAuthProvider>>()};
  raw_ptr<::testing::StrictMock<MockPlatformAuthProvider>> mock_provider_ =
      nullptr;
};

// The manager is disabled, so no origins or data are fetched.
TEST_F(PlatformAuthNavigationThrottleTest, ManagerDisabled) {
  EXPECT_CALL(*mock_provider(), FetchOrigins(_)).Times(0);
  ScopedSetProviderForTesting set_provider(TakeProvider());

  content::MockNavigationHandle test_handle(GURL("https://www.example.test/"),
                                            main_frame());
  auto throttle = CreateThrottle(&test_handle);
  EXPECT_CALL(test_handle, SetAllowCookiesFromBrowser(false));
  EXPECT_CALL(test_handle, SetRequestHeader(_, _)).Times(0);

  EXPECT_FALSE(manager().IsEnabled());
  EXPECT_EQ(manager().GetOriginsForTesting(), std::vector<url::Origin>());
  EXPECT_EQ(NavigationThrottle::PROCEED, throttle->WillStartRequest().action());
  EXPECT_STREQ("PlatformAuthNavigationThrottle", throttle->GetNameForLogging());
}

// The manager is enabled, so an origin fetch happens. No data is fetched when
// an empty set of origins is returned.
TEST_F(PlatformAuthNavigationThrottleTest, EmptyOrigins) {
  ScopedSetProviderForTesting set_provider(TakeProvider());

  EXPECT_CALL(*mock_provider(), FetchOrigins(_))
      .WillOnce([](PlatformAuthProvider::FetchOriginsCallback callback) {
        std::move(callback).Run(nullptr);
      });
  EnableManager(manager(), true);
  EXPECT_TRUE(manager().IsEnabled());

  content::MockNavigationHandle test_handle(GURL("https://www.example.test/"),
                                            main_frame());
  auto throttle = CreateThrottle(&test_handle);
  EXPECT_CALL(test_handle, SetAllowCookiesFromBrowser(true));
  EXPECT_CALL(test_handle, SetRequestHeader(_, _)).Times(0);

  EXPECT_EQ(NavigationThrottle::PROCEED, throttle->WillStartRequest().action());
  EXPECT_EQ(manager().GetOriginsForTesting(), std::vector<url::Origin>());
}

// The manager is enabled and a set of origins is returned, so the throttle
// tries to fetch data. The fetched data is empty, so the throttle does not
// attach any headers to the request.
TEST_F(PlatformAuthNavigationThrottleTest, EmptyData) {
  auto url = GURL("https://www.example.test/");
  ScopedSetProviderForTesting set_provider(TakeProvider());
  EXPECT_CALL(*mock_provider(), FetchOrigins(_))
      .WillOnce([&url](PlatformAuthProvider::FetchOriginsCallback callback) {
        std::move(callback).Run(std::make_unique<std::vector<url::Origin>>(
            std::vector<url::Origin>{url::Origin::Create(url)}));
      });
  EnableManager(manager(), true);
  EXPECT_TRUE(manager().IsEnabled());

  content::MockNavigationHandle test_handle(url, main_frame());
  auto throttle = CreateThrottle(&test_handle);
  EXPECT_CALL(test_handle, SetAllowCookiesFromBrowser(true));
  EXPECT_CALL(test_handle, SetRequestHeader(_, _)).Times(0);

  // The provider returns an empty set of authentication headers.
  EXPECT_CALL(*mock_provider(), GetData(_, _))
      .WillOnce([](const GURL& url,
                   PlatformAuthProviderManager::GetDataCallback callback) {
        std::move(callback).Run(net::HttpRequestHeaders());
      });

  EXPECT_EQ(manager().GetOriginsForTesting(),
            std::vector<url::Origin>({url::Origin::Create(url)}));
  EXPECT_EQ(NavigationThrottle::PROCEED, throttle->WillStartRequest().action());
}

// The manager is enabled and a set of origins is returned, so the throttle
// fetches data and attaches all received headers to the request. On redirect to
// another origin, all previously added headers are removed from the request.
TEST_F(PlatformAuthNavigationThrottleTest, DataReceived) {
  const char kOldCookieValue[] = "old-cookie=old-data";
  const char kCookieValue[] = "cookie-1=cookie-1-data; cookie-2=cookie-2-data";
  const char kOldHeaderName[] = "old-header";
  const char kOldHeaderValue[] = "old-header-data";
  const char kHeader1Name[] = "header-1";
  const char kHeader1Value[] = "header-1-data";
  const char kHeader2Name[] = "header-2";
  const char kHeader2Value[] = "header-2-data";

  auto url = GURL("https://www.example.test/");
  ScopedSetProviderForTesting set_provider(TakeProvider());
  EXPECT_CALL(*mock_provider(), FetchOrigins(_))
      .WillOnce([&url](PlatformAuthProvider::FetchOriginsCallback callback) {
        std::move(callback).Run(std::make_unique<std::vector<url::Origin>>(
            std::vector<url::Origin>{url::Origin::Create(url)}));
      });
  EnableManager(manager(), true);
  EXPECT_TRUE(manager().IsEnabled());

  content::MockNavigationHandle test_handle(url, main_frame());
  // Set some existing request headers.
  net::HttpRequestHeaders headers;
  headers.SetHeader(net::HttpRequestHeaders::kCookie,
                    base::JoinString({kOldCookieValue, kCookieValue}, "; "));
  headers.SetHeader(kOldHeaderName, kOldHeaderValue);
  headers.SetHeader(kHeader1Name, kHeader1Value);
  test_handle.set_request_headers(std::move(headers));

  auto throttle = CreateThrottle(&test_handle);
  throttle->set_resume_callback_for_testing(base::DoNothing());

  // Headers are added to the request whose origin matches the origin stored in
  // the manager.
  EXPECT_CALL(test_handle, SetAllowCookiesFromBrowser(true));
  EXPECT_CALL(test_handle,
              SetRequestHeader(net::HttpRequestHeaders::kCookie, kCookieValue));
  EXPECT_CALL(test_handle, SetRequestHeader(kHeader1Name, kHeader1Value));
  EXPECT_CALL(test_handle, SetRequestHeader(kHeader2Name, kHeader2Value));

  // The provider returns a non-empty set of authentication headers.
  EXPECT_CALL(*mock_provider(), GetData(_, _))
      .WillOnce([&kCookieValue, &kHeader1Name, &kHeader1Value, &kHeader2Name,
                 &kHeader2Value](
                    const GURL& url,
                    PlatformAuthProviderManager::GetDataCallback callback) {
        net::HttpRequestHeaders auth_headers;
        auth_headers.SetHeader(net::HttpRequestHeaders::kCookie, kCookieValue);
        auth_headers.SetHeader(kHeader1Name, kHeader1Value);
        auth_headers.SetHeader(kHeader2Name, kHeader2Value);
        base::ThreadPool::PostTask(base::BindOnce(
            [](net::HttpRequestHeaders auth_headers,
               PlatformAuthProviderManager::GetDataCallback callback) {
              // Simulates a data fetch delay to cause the throttle to be
              // deferred.
              base::PlatformThread::Sleep(base::Milliseconds(10));
              std::move(callback).Run(std::move(auth_headers));
            },
            std::move(auth_headers), std::move(callback)));
      });

  EXPECT_EQ(manager().GetOriginsForTesting(),
            std::vector<url::Origin>({url::Origin::Create(url)}));
  EXPECT_EQ(NavigationThrottle::DEFER, throttle->WillStartRequest().action());

  // Ensure the async data fetch completes.
  task_environment_.RunUntilIdle();

  ::testing::Mock::VerifyAndClearExpectations(&test_handle);

  // The redirect is to another origin, so cookies and headers set by the
  // throttle are removed.
  EXPECT_CALL(test_handle,
              RemoveRequestHeader(net::HttpRequestHeaders::kCookie));
  EXPECT_CALL(test_handle, RemoveRequestHeader(kHeader1Name));
  EXPECT_CALL(test_handle, RemoveRequestHeader(kHeader2Name));
  auto redirect_url = GURL("https://www.redirect.test/");
  test_handle.set_url(redirect_url);
  EXPECT_EQ(NavigationThrottle::PROCEED,
            throttle->WillRedirectRequest().action());
}

class PlatformAuthNavigationNoOriginFilteringThrottleTest
    : public PlatformAuthNavigationThrottleTest {
  void SetupOriginFilteringExpectations() override {
    EXPECT_CALL(*mock_provider_, SupportsOriginFiltering())
        .WillOnce(::testing::Return(false));
  }
};

// The manager is enabled and a set of origins is returned, so the throttle
// fetches data and attaches all received headers to the request. On redirect to
// another origin, all previously added headers are removed from the request.
TEST_F(PlatformAuthNavigationNoOriginFilteringThrottleTest,
       DataReceivedOriginFiltering) {
  const char kOldCookieValue[] = "old-cookie=old-data";
  const char kCookieValue[] = "cookie-1=cookie-1-data; cookie-2=cookie-2-data";
  const char kOldHeaderName[] = "old-header";
  const char kOldHeaderValue[] = "old-header-data";
  const char kHeader1Name[] = "header-1";
  const char kHeader1Value[] = "header-1-data";
  const char kHeader2Name[] = "header-2";
  const char kHeader2Value[] = "header-2-data";

  auto url = GURL("https://www.example.test/");
  ScopedSetProviderForTesting set_provider(TakeProvider());
  EXPECT_CALL(*mock_provider(), FetchOrigins(_)).Times(0);
  EnableManager(manager(), true);
  EXPECT_TRUE(manager().IsEnabled());

  content::MockNavigationHandle test_handle(url, main_frame());
  // Set some existing request headers.
  net::HttpRequestHeaders headers;
  headers.SetHeader(net::HttpRequestHeaders::kCookie,
                    base::JoinString({kOldCookieValue, kCookieValue}, "; "));
  headers.SetHeader(kOldHeaderName, kOldHeaderValue);
  headers.SetHeader(kHeader1Name, kHeader1Value);
  test_handle.set_request_headers(std::move(headers));

  auto throttle = CreateThrottle(&test_handle);
  throttle->set_resume_callback_for_testing(base::DoNothing());

  // Headers are added to the request whose origin matches the origin stored in
  // the manager.
  EXPECT_CALL(test_handle, SetAllowCookiesFromBrowser(true));
  EXPECT_CALL(test_handle,
              SetRequestHeader(net::HttpRequestHeaders::kCookie, kCookieValue));
  EXPECT_CALL(test_handle, SetRequestHeader(kHeader1Name, kHeader1Value));
  EXPECT_CALL(test_handle, SetRequestHeader(kHeader2Name, kHeader2Value));

  // The provider returns a non-empty set of authentication headers.
  EXPECT_CALL(*mock_provider(), GetData(_, _))
      .WillOnce([&kCookieValue, &kHeader1Name, &kHeader1Value, &kHeader2Name,
                 &kHeader2Value](
                    const GURL& url,
                    PlatformAuthProviderManager::GetDataCallback callback) {
        net::HttpRequestHeaders auth_headers;
        auth_headers.SetHeader(net::HttpRequestHeaders::kCookie, kCookieValue);
        auth_headers.SetHeader(kHeader1Name, kHeader1Value);
        auth_headers.SetHeader(kHeader2Name, kHeader2Value);
        base::ThreadPool::PostTask(base::BindOnce(
            [](net::HttpRequestHeaders auth_headers,
               PlatformAuthProviderManager::GetDataCallback callback) {
              // Simulates a data fetch delay to cause the throttle to be
              // deferred.
              base::PlatformThread::Sleep(base::Milliseconds(10));
              std::move(callback).Run(std::move(auth_headers));
            },
            std::move(auth_headers), std::move(callback)));
      });

  EXPECT_EQ(manager().GetOriginsForTesting(), std::vector<url::Origin>());
  EXPECT_EQ(NavigationThrottle::DEFER, throttle->WillStartRequest().action());

  // Ensure the async data fetch completes.
  task_environment_.RunUntilIdle();

  ::testing::Mock::VerifyAndClearExpectations(&test_handle);

  // The redirect is to another origin, so cookies and headers set by the
  // throttle are removed.
  EXPECT_CALL(test_handle,
              RemoveRequestHeader(net::HttpRequestHeaders::kCookie));
  EXPECT_CALL(test_handle, RemoveRequestHeader(kHeader1Name));
  EXPECT_CALL(test_handle, RemoveRequestHeader(kHeader2Name));

  // The provider returns an empty set of authentication headers.
  EXPECT_CALL(*mock_provider(), GetData(_, _))
      .WillOnce([](const GURL& url,
                   PlatformAuthProviderManager::GetDataCallback callback) {
        std::move(callback).Run(net::HttpRequestHeaders());
      });

  auto redirect_url = GURL("https://www.redirect.test/");
  test_handle.set_url(redirect_url);
  EXPECT_EQ(NavigationThrottle::PROCEED,
            throttle->WillRedirectRequest().action());
}

}  // namespace enterprise_auth