// Copyright 2019 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ui/ash/assistant/assistant_test_mixin.h"

#include <optional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "ash/assistant/model/ui/assistant_card_element.h"
#include "ash/assistant/ui/assistant_ui_constants.h"
#include "ash/assistant/ui/main_stage/assistant_ui_element_view.h"
#include "ash/constants/ash_switches.h"
#include "ash/public/cpp/assistant/assistant_state.h"
#include "ash/public/cpp/test/assistant_test_api.h"
#include "base/auto_reset.h"
#include "base/check_op.h"
#include "base/containers/to_vector.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/ref_counted.h"
#include "base/ranges/algorithm.h"
#include "base/run_loop.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/string_split.h"
#include "base/strings/string_util.h"
#include "base/task/sequenced_task_runner.h"
#include "base/test/scoped_run_loop_timeout.h"
#include "base/time/time.h"
#include "chrome/browser/ash/login/test/embedded_test_server_setup_mixin.h"
#include "chrome/browser/ash/login/test/login_manager_mixin.h"
#include "chrome/browser/profiles/profile_manager.h"
#include "chrome/browser/ui/ash/assistant/test_support/fake_s3_server.h"
#include "chrome/test/base/fake_gaia_mixin.h"
#include "chrome/test/base/mixin_based_in_process_browser_test.h"
#include "chromeos/ash/components/login/auth/public/user_context.h"
#include "components/account_id/account_id.h"
#include "components/language/core/browser/pref_names.h"
#include "google_apis/gaia/gaia_urls.h"
#include "net/dns/mock_host_resolver.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "ui/events/test/event_generator.h"
#include "ui/views/controls/label.h"
#include "ui/views/view.h"
#include "ui/views/view_observer.h"
namespace ash::assistant {
namespace {
// Identity of the single user every Assistant browser test logs in as.
constexpr char kTestUser[] = "[email protected]";
constexpr char kTestUserGaiaId[] = "test_user_gaia_id";

// Returns the LoginManagerMixin description of the test user built from
// |kTestUser| and |kTestUserGaiaId|.
LoginManagerMixin::TestUserInfo GetTestUserInfo() {
  const AccountId test_account =
      AccountId::FromUserEmailGaiaId(kTestUser, kTestUserGaiaId);
  return LoginManagerMixin::TestUserInfo(test_account);
}
// Waiter that blocks in the |Wait| method until a given |AssistantStatus|
// is reached, or until a timeout is hit.
// On timeout this will abort the test with a useful error message.
class AssistantStatusWaiter : private AssistantStateObserver {
public:
AssistantStatusWaiter(AssistantState* state, AssistantStatus expected_status)
: state_(state), expected_status_(expected_status) {
state_->AddObserver(this);
}
~AssistantStatusWaiter() override { state_->RemoveObserver(this); }
void RunUntilExpectedStatus() {
if (state_->assistant_status() == expected_status_)
return;
// Wait until we're ready or we hit the timeout.
base::RunLoop run_loop;
base::AutoReset<base::OnceClosure> quit_loop(&quit_loop_,
run_loop.QuitClosure());
EXPECT_NO_FATAL_FAILURE(run_loop.Run())
<< "Failed waiting for AssistantStatus |" << expected_status_ << "|. "
<< "Current status is |" << state_->assistant_status() << "|. "
<< "One possible cause is that you're using an expired access token.";
}
private:
void OnAssistantStatusChanged(AssistantStatus status) override {
if (status == expected_status_ && quit_loop_)
std::move(quit_loop_).Run();
}
const raw_ptr<AssistantState> state_;
AssistantStatus const expected_status_;
base::OnceClosure quit_loop_;
};
// Base class that observes all new responses being displayed under the
// |parent_view|, waiting for HasResponse() to return |true| or until a timeout
// is hit. On timeout this will abort the test with a useful error message. By
// default, HasResponse() checks for any non-empty response, but this behavior
// can be overridden by derived classes wishing to assert more specific
// expectations. The derived classes must implement the logic to extract the
// response from a given view.
class ResponseWaiter : private views::ViewObserver {
 public:
  explicit ResponseWaiter(views::View* parent_view)
      : parent_view_(parent_view) {
    parent_view_->AddObserver(this);
  }

  ~ResponseWaiter() override {
    if (parent_view_) {
      parent_view_->RemoveObserver(this);
    }
  }

  // Returns immediately if a response is already present; otherwise spins a
  // run loop until one arrives, |parent_view_| is deleted, or the active run
  // loop timeout fires.
  void RunUntilResponseReceived() {
    if (HasResponse()) {
      return;
    }

    // Wait until we're ready or we hit the timeout.
    base::RunLoop run_loop;
    base::AutoReset<base::OnceClosure> quit_loop(&quit_loop_,
                                                 run_loop.QuitClosure());
    EXPECT_NO_FATAL_FAILURE(run_loop.Run())
        << "Failed waiting for Assistant response.\n"
        << GetFailureMessage();
  }

  // Returns the newline-terminated text of every response view found under
  // |parent_view_|. Returns an empty string once |parent_view_| has been
  // deleted. Callers may read the text after RunUntilResponseReceived()
  // returned because of that deletion, so the pointer must not be used.
  std::string GetResponseText() const {
    if (!parent_view_) {
      return std::string();
    }
    return GetResponseTextRecursive(parent_view_);
  }

 private:
  // views::ViewObserver overrides:
  void OnViewHierarchyChanged(
      views::View* observed_view,
      const views::ViewHierarchyChangedDetails& details) override {
    if (quit_loop_ && HasResponse()) {
      std::move(quit_loop_).Run();
    }
  }

  void OnViewIsDeleting(views::View* observed_view) override {
    DCHECK_EQ(observed_view, parent_view_);

    if (quit_loop_) {
      ADD_FAILURE() << parent_view_->GetClassName() << " is deleted "
                    << "before receiving the Assistant response.\n"
                    << GetFailureMessage();
      std::move(quit_loop_).Run();
    }

    // Stop observing the dying view and drop the pointer so that later calls
    // (e.g. GetResponseText()) never dereference a deleted view.
    parent_view_->RemoveObserver(this);
    parent_view_ = nullptr;
  }

  virtual bool HasResponse() const { return !GetResponseText().empty(); }

  virtual std::string GetFailureMessage() const {
    return "Expected to receive any non-empty response.";
  }

  // Depth-first walk: a view that yields a response contributes its text
  // (plus "\n") and its subtree is not visited; otherwise its children are.
  std::string GetResponseTextRecursive(views::View* view) const {
    if (std::optional<std::string> text = GetResponseTextOfView(view)) {
      return *text + "\n";
    }
    std::string result;
    for (views::View* child : view->children()) {
      result += GetResponseTextRecursive(child);
    }
    return result;
  }

  // Returns the response text represented by |view|, or std::nullopt if
  // |view| is not a response view the subclass cares about.
  virtual std::optional<std::string> GetResponseTextOfView(
      views::View* view) const = 0;

  raw_ptr<views::View> parent_view_;
  base::OnceClosure quit_loop_;
};
// A ResponseWaiter which waits until one of |expected_responses| is received.
// The derived classes must implement the logic to extract the response from a
// given view.
class ExpectedResponseWaiter : public ResponseWaiter {
public:
ExpectedResponseWaiter(views::View* parent_view,
const std::vector<std::string>& expected_responses)
: ResponseWaiter(parent_view), expected_responses_(expected_responses) {}
private:
// ResponseWaiter overrides:
bool HasResponse() const override {
std::string response = GetResponseText();
for (const std::string& expected : expected_responses_) {
if (response.find(expected) != std::string::npos)
return true;
}
return false;
}
std::string GetFailureMessage() const override {
std::stringstream message;
message << "Expected any of " << FormatExpectedResponses() << ".\n";
message << "Got \"" << GetResponseText() << "\"";
return message.str();
}
std::string FormatExpectedResponses() const {
std::stringstream result;
result << "{\n";
for (const std::string& expected : expected_responses_)
result << " \"" << expected << "\",\n";
result << "}";
return result.str();
}
std::vector<std::string> expected_responses_;
};
// A ResponseWaiter that finishes once any non-empty response is shown by a
// view whose class name equals |class_name|.
// NOTE: |class_name| must name a class inheriting from AssistantUiElementView,
// because matching views are static_cast to that type.
class TypedResponseWaiter : public ResponseWaiter {
 public:
  TypedResponseWaiter(const std::string& class_name, views::View* parent_view)
      : ResponseWaiter(parent_view), class_name_(class_name) {}

  ~TypedResponseWaiter() override = default;

 private:
  // ResponseWaiter:
  std::optional<std::string> GetResponseTextOfView(
      views::View* view) const override {
    if (view->GetClassName() != class_name_) {
      return std::nullopt;
    }
    auto* element_view = static_cast<AssistantUiElementView*>(view);
    return element_view->ToStringForTesting();
  }

  const std::string class_name_;
};
// An ExpectedResponseWaiter which waits until one of |expected_responses| is
// received for a response of the type indicated by the specified |class_name|.
// NOTE: |class_name| must name a class inheriting from AssistantUiElementView.
class TypedExpectedResponseWaiter : public ExpectedResponseWaiter {
public:
TypedExpectedResponseWaiter(
const std::string& class_name,
views::View* parent_view,
const std::vector<std::string>& expected_responses)
: ExpectedResponseWaiter(parent_view, expected_responses),
class_name_(class_name) {}
~TypedExpectedResponseWaiter() override = default;
private:
// ExpectedResponseWaiter overrides:
std::optional<std::string> GetResponseTextOfView(
views::View* view) const override {
if (view->GetClassName() == class_name_)
return static_cast<AssistantUiElementView*>(view)->ToStringForTesting();
return std::nullopt;
}
const std::string class_name_;
};
// Calls a callback when the view hierarchy changes.
class CallbackViewHierarchyChangedObserver : views::ViewObserver {
public:
explicit CallbackViewHierarchyChangedObserver(
views::View* parent_view,
base::RepeatingCallback<void(const views::ViewHierarchyChangedDetails&)>
callback)
: callback_(callback), parent_view_(parent_view) {
parent_view_->AddObserver(this);
}
~CallbackViewHierarchyChangedObserver() override {
if (parent_view_)
parent_view_->RemoveObserver(this);
}
// ViewObserver:
void OnViewHierarchyChanged(
views::View* observed_view,
const views::ViewHierarchyChangedDetails& details) override {
callback_.Run(details);
}
void OnViewIsDeleting(views::View* view) override {
DCHECK_EQ(view, parent_view_);
if (parent_view_)
parent_view_->RemoveObserver(this);
parent_view_ = nullptr;
}
private:
base::RepeatingCallback<void(const views::ViewHierarchyChangedDetails&)>
callback_;
raw_ptr<views::View> parent_view_;
};
} // namespace
// Test mixin for the browser tests that logs in the given user and issues
// refresh and access tokens for this user through FakeGaia.
class LoggedInUserMixin : public InProcessBrowserTestMixin {
 public:
  LoggedInUserMixin(InProcessBrowserTestMixinHost* host,
                    InProcessBrowserTest* test_base,
                    LoginManagerMixin::TestUserInfo user,
                    net::EmbeddedTestServer* embedded_test_server)
      : InProcessBrowserTestMixin(host),
        login_manager_(host, {user}),
        test_server_(host, embedded_test_server),
        fake_gaia_(host),
        user_(user),
        test_base_(test_base),
        user_context_(LoginManagerMixin::CreateDefaultUserContext(user)) {
    // Tell LoginManagerMixin to launch the browser when the user is logged in.
    login_manager_.SetShouldLaunchBrowser(true);
  }

  ~LoggedInUserMixin() override = default;

  // Replaces the access token that SetupFakeGaia() registers with FakeGaia.
  // Defaults to FakeGaiaMixin::kFakeAllScopeAccessToken.
  void SetAccessToken(std::string token) { access_token_ = std::move(token); }

  // InProcessBrowserTestMixin:
  void SetUpOnMainThread() override {
    // By default, browser tests block anything that doesn't go to localhost,
    // so accounts.google.com requests would never reach the fake GAIA server
    // without this.
    test_base_->host_resolver()->AddRule("*", "127.0.0.1");
    LogIn();
    SetupFakeGaia();
    // Ensure test_base_->browser() returns the browser of the logged in user
    // session.
    test_base_->SelectFirstBrowser();
  }

  // Logs in with |user_context_| (carrying |kRefreshToken|) and waits until
  // the user session is active.
  void LogIn() {
    user_context_.SetRefreshToken(kRefreshToken);
    bool success = login_manager_.LoginAndWaitForActiveSession(user_context_);
    EXPECT_TRUE(success) << "Failed to log in as test user.";
  }

  // Maps the user's email to its Gaia id in FakeGaia and makes FakeGaia
  // exchange |kRefreshToken| for |access_token_| (any scope).
  void SetupFakeGaia() {
    FakeGaia::AccessTokenInfo token_info;
    token_info.token = access_token_;
    token_info.audience = GaiaUrls::GetInstance()->oauth2_chrome_client_id();
    token_info.email = user_context_.GetAccountId().GetUserEmail();
    token_info.any_scope = true;
    token_info.expires_in = kAccessTokenExpiration;
    fake_gaia_.fake_gaia()->MapEmailToGaiaId(user_.account_id.GetUserEmail(),
                                             user_.account_id.GetGaiaId());
    fake_gaia_.fake_gaia()->IssueOAuthToken(kRefreshToken, token_info);
  }

 private:
  const char* const kRefreshToken = FakeGaiaMixin::kFakeRefreshToken;
  const int kAccessTokenExpiration = FakeGaiaMixin::kFakeAccessTokenExpiration;

  LoginManagerMixin login_manager_;
  EmbeddedTestServerSetupMixin test_server_;
  FakeGaiaMixin fake_gaia_;
  LoginManagerMixin::TestUserInfo user_;
  const raw_ptr<InProcessBrowserTest> test_base_;
  UserContext user_context_;
  std::string access_token_{FakeGaiaMixin::kFakeAllScopeAccessToken};
};
// Creates the fake S3 server for |test_data_version|, stores |mode| for use in
// SetUpOnMainThread(), creates the Assistant test API, and creates a
// LoggedInUserMixin (registered with |host|) for the user returned by
// GetTestUserInfo().
// NOTE(review): the initializer list must follow the member declaration order
// in the header, which is not visible here.
AssistantTestMixin::AssistantTestMixin(
InProcessBrowserTestMixinHost* host,
InProcessBrowserTest* test_base,
net::EmbeddedTestServer* embedded_test_server,
FakeS3Mode mode,
int test_data_version)
: InProcessBrowserTestMixin(host),
fake_s3_server_(test_data_version),
mode_(mode),
test_api_(AssistantTestApi::Create()),
user_mixin_(std::make_unique<LoggedInUserMixin>(host,
test_base,
GetTestUserInfo(),
embedded_test_server)) {}
// Defined out of line, presumably because LoggedInUserMixin is only
// forward-declared in the header (user_mixin_ is a unique_ptr) — TODO confirm.
AssistantTestMixin::~AssistantTestMixin() = default;
void AssistantTestMixin::SetUpCommandLine(base::CommandLine* command_line) {
  // Skip the post-login OOBE flow. Without this the Assistant setup flow
  // dialog pops up as soon as the user session starts, and the Assistant
  // cannot be started.
  command_line->AppendSwitch(ash::switches::kOobeSkipPostLogin);
}
// Starts the fake S3 server in |mode_|, hands its access token to the
// LoggedInUserMixin, and disables Assistant UI animations.
void AssistantTestMixin::SetUpOnMainThread() {
// The server must be set up before it is asked for its access token.
fake_s3_server_.Setup(mode_);
// LoggedInUserMixin::SetupFakeGaia() registers this token with FakeGaia;
// presumably this relies on running before that mixin's SetUpOnMainThread().
user_mixin_->SetAccessToken(fake_s3_server_.GetAccessToken());
test_api_->DisableAnimations();
}
// Disables the Assistant (waiting until it reports NOT_READY) and only then
// tears down the fake S3 server.
void AssistantTestMixin::TearDownOnMainThread() {
DisableAssistant();
DisableFakeS3Server();
}
// Tears down the fake S3 server started in SetUpOnMainThread().
void AssistantTestMixin::DisableFakeS3Server() {
fake_s3_server_.Teardown();
}
void AssistantTestMixin::StartAssistantAndWaitForReady(
    base::TimeDelta wait_timeout) {
  // Bounds how long the status wait below may spin before the test fails.
  const base::test::ScopedRunLoopTimeout scoped_timeout(FROM_HERE,
                                                        wait_timeout);

  // Note: You might be tempted to call this function from SetUpOnMainThread(),
  // but that will not work as the Assistant service can not start until
  // |BrowserTestBase| calls InitializeNetworkProcess(), which it only does
  // after SetUpOnMainThread() finishes.
  test_api_->SetAssistantEnabled(true);
  SetPreferVoice(false);

  // The temporary waiter lives until the end of this statement.
  AssistantStatusWaiter(test_api_->GetAssistantState(), AssistantStatus::READY)
      .RunUntilExpectedStatus();
}
// Forwards |enabled| to the Assistant test API. This does not wait for any
// status change; see StartAssistantAndWaitForReady() / DisableAssistant().
void AssistantTestMixin::SetAssistantEnabled(bool enabled) {
test_api_->SetAssistantEnabled(enabled);
}
// Forwards |prefer_voice| to the Assistant test API.
void AssistantTestMixin::SetPreferVoice(bool prefer_voice) {
test_api_->SetPreferVoice(prefer_voice);
}
// Submits |query| as a text query through the Assistant test API. Use one of
// the Expect*Response() helpers to wait for the answer.
void AssistantTestMixin::SendTextQuery(const std::string& query) {
test_api_->SendTextQuery(query);
}
// Invokes |func| with a callback and blocks (in a nestable run loop bounded by
// kDefaultWaitTimeout) until that callback delivers a value, which is then
// returned. On timeout a test failure is recorded and a default-constructed
// T is returned.
template <typename T>
T AssistantTestMixin::SyncCall(
    base::OnceCallback<void(base::OnceCallback<void(T)>)> func) {
  const base::test::ScopedRunLoopTimeout run_timeout(FROM_HERE,
                                                     kDefaultWaitTimeout);
  base::RunLoop run_loop(base::RunLoop::Type::kNestableTasksAllowed);

  // The result lives in ref-counted storage co-owned by the callback rather
  // than on this stack frame. If the run loop times out we return early, and
  // a callback that fires afterwards must not write into a dead stack slot.
  // The quit closure is safe to run after |run_loop| is gone.
  auto result = base::MakeRefCounted<base::RefCountedData<T>>();
  auto callback = base::BindOnce(
      [](scoped_refptr<base::RefCountedData<T>> result_holder,
         base::OnceClosure quit_closure, T result_value) {
        result_holder->data = std::move(result_value);
        std::move(quit_closure).Run();
      },
      result, run_loop.QuitClosure());

  std::move(func).Run(std::move(callback));
  EXPECT_NO_FATAL_FAILURE(run_loop.Run())
      << "Failed waiting for async callback to return.\n";
  return std::move(result->data);
}
// SyncCall() is defined in this .cc file, so every T used by callers must be
// explicitly instantiated here. Currently only std::optional<double> is.
template std::optional<double> AssistantTestMixin::SyncCall(
base::OnceCallback<void(base::OnceCallback<void(std::optional<double>)>)>
func);
void AssistantTestMixin::ExpectCardResponse(
const std::string& expected_response,
base::TimeDelta wait_timeout) {
const base::test::ScopedRunLoopTimeout run_timeout(FROM_HERE, wait_timeout);
TypedExpectedResponseWaiter waiter("AssistantCardElementView",
test_api_->ui_element_container(),
{expected_response});
waiter.RunUntilResponseReceived();
}
// Single-response convenience wrapper around ExpectAnyOfTheseTextResponses().
void AssistantTestMixin::ExpectTextResponse(
const std::string& expected_response,
base::TimeDelta wait_timeout) {
ExpectAnyOfTheseTextResponses({expected_response}, wait_timeout);
}
void AssistantTestMixin::ExpectAnyOfTheseTextResponses(
const std::vector<std::string>& expected_responses,
base::TimeDelta wait_timeout) {
const base::test::ScopedRunLoopTimeout run_timeout(FROM_HERE, wait_timeout);
TypedExpectedResponseWaiter waiter("AssistantTextElementView",
test_api_->ui_element_container(),
expected_responses);
waiter.RunUntilResponseReceived();
}
void AssistantTestMixin::ExpectErrorResponse(
const std::string& expected_response,
base::TimeDelta wait_timeout) {
const base::test::ScopedRunLoopTimeout run_timeout(FROM_HERE, wait_timeout);
TypedExpectedResponseWaiter waiter("AssistantErrorElementView",
test_api_->ui_element_container(),
{expected_response});
waiter.RunUntilResponseReceived();
}
// Waits (up to |wait_timeout|) until the AssistantTimersElementView text
// contains the remaining seconds of |timers|, one value per line, in order.
// NOTE(review): matching is by substring, so e.g. an expected "5\n" also
// matches "15\n", and an empty |timers| list matches immediately.
void AssistantTestMixin::ExpectTimersResponse(
    const std::vector<base::TimeDelta>& timers,
    base::TimeDelta wait_timeout) {
  // We expect the textual representation of a timers response to be of the form
  // "<timer1 remaining time in seconds>\n<timer2 remaining time in seconds>..."
  std::stringstream expected_text;
  for (base::TimeDelta remaining : timers) {
    expected_text << remaining.InSeconds() << '\n';
  }

  const base::test::ScopedRunLoopTimeout scoped_timeout(FROM_HERE,
                                                        wait_timeout);
  TypedExpectedResponseWaiter timers_waiter("AssistantTimersElementView",
                                            test_api_->ui_element_container(),
                                            {expected_text.str()});
  timers_waiter.RunUntilResponseReceived();
}
// Waits (up to |wait_timeout|) for any non-empty AssistantTimersElementView
// response and returns the remaining time of each timer it lists. Lines that
// cannot be parsed as an integer number of seconds record a test failure
// instead of being silently turned into bogus timers.
std::vector<base::TimeDelta> AssistantTestMixin::ExpectAndReturnTimersResponse(
    base::TimeDelta wait_timeout) {
  const base::test::ScopedRunLoopTimeout run_timeout(FROM_HERE, wait_timeout);
  TypedResponseWaiter waiter("AssistantTimersElementView",
                             test_api_->ui_element_container());
  waiter.RunUntilResponseReceived();

  // We expect the textual representation of a timers response to be of the form
  // "<timer1 remaining time in seconds>\n<timer2 remaining time in seconds>..."
  std::vector<std::string> timers_as_strings =
      base::SplitString(base::TrimString(waiter.GetResponseText(), "\n",
                                         base::TrimPositions::TRIM_TRAILING),
                        "\n", base::WhitespaceHandling::KEEP_WHITESPACE,
                        base::SplitResult::SPLIT_WANT_ALL);

  // Transform the textual representation of our timers into TimeDelta objects.
  return base::ToVector(
      timers_as_strings, [](const std::string& timer_as_string) {
        int seconds_remaining = 0;
        EXPECT_TRUE(base::StringToInt(timer_as_string, &seconds_remaining))
            << "Unable to parse timer response line \"" << timer_as_string
            << "\" as a number of seconds.";
        return base::Seconds(seconds_remaining);
      });
}
// Simulates pressing the dedicated Assistant key (ui::VKEY_ASSISTANT).
void AssistantTestMixin::PressAssistantKey() {
SendKeyPress(::ui::VKEY_ASSISTANT);
}
// Returns the Assistant UI visibility as reported by the Assistant test API.
bool AssistantTestMixin::IsVisible() {
return test_api_->IsVisible();
}
// Spins the run loop for |wait_timeout| and fails the test if the view
// hierarchy under the UI element container changes in the meantime. On the
// first change the wait ends immediately instead of running out the timer.
void AssistantTestMixin::ExpectNoChange(base::TimeDelta wait_timeout) {
  // This helper deliberately waits for the whole |wait_timeout|, so the
  // default run loop timeout must not cut it short.
  base::test::ScopedDisableRunLoopTimeout disable_timeout;
  base::RunLoop run_loop;

  // Exit the run loop after |wait_timeout| if nothing changed. QuitClosure()
  // is safe to run even after |run_loop| has been destroyed.
  base::SequencedTaskRunner::GetCurrentDefault()->PostDelayedTask(
      FROM_HERE, run_loop.QuitClosure(), wait_timeout);

  // On any hierarchy change: stop waiting and record a fatal failure, which
  // EXPECT_NO_FATAL_FAILURE below reports.
  CallbackViewHierarchyChangedObserver observer(
      test_api_->ui_element_container(),
      base::BindRepeating(
          [](base::RepeatingClosure quit,
             const views::ViewHierarchyChangedDetails&) {
            quit.Run();
            FAIL();
          },
          run_loop.QuitClosure()));

  EXPECT_NO_FATAL_FAILURE(run_loop.Run())
      << "View hierarchy changed during ExpectNoChange.";
}
// Returns the PrefService of the primary user's profile, as looked up through
// ProfileManager.
PrefService* AssistantTestMixin::GetUserPreferences() {
return ProfileManager::GetPrimaryUserProfile()->GetPrefs();
}
// Generates a press of |key| (no modifier flags) on the Assistant test API's
// root window.
void AssistantTestMixin::SendKeyPress(::ui::KeyboardCode key) {
  ::ui::test::EventGenerator(test_api_->root_window())
      .PressKey(key, /*flags=*/::ui::EF_NONE);
}
void AssistantTestMixin::DisableAssistant() {
// First disable Assistant in the settings.
test_api_->SetAssistantEnabled(false);
// Then wait for the Service to shutdown.
AssistantStatusWaiter waiter(test_api_->GetAssistantState(),
AssistantStatus::NOT_READY);
waiter.RunUntilExpectedStatus();
}
} // namespace ash::assistant