// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "ash/picker/search/picker_search_request.h"
#include <algorithm>
#include <cstddef>
#include <functional>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <string_view>
#include <utility>
#include <variant>
#include <vector>
#include "ash/picker/picker_clipboard_history_provider.h"
#include "ash/picker/search/picker_action_search.h"
#include "ash/picker/search/picker_date_search.h"
#include "ash/picker/search/picker_editor_search.h"
#include "ash/picker/search/picker_math_search.h"
#include "ash/picker/search/picker_search_source.h"
#include "ash/public/cpp/app_list/app_list_types.h"
#include "ash/public/cpp/picker/picker_category.h"
#include "ash/public/cpp/picker/picker_client.h"
#include "ash/public/cpp/picker/picker_search_result.h"
#include "base/check.h"
#include "base/check_deref.h"
#include "base/containers/contains.h"
#include "base/containers/fixed_flat_set.h"
#include "base/containers/flat_set.h"
#include "base/containers/span.h"
#include "base/functional/bind.h"
#include "base/functional/callback_forward.h"
#include "base/functional/callback_helpers.h"
#include "base/logging.h"
#include "base/metrics/histogram_functions.h"
#include "base/notreached.h"
#include "base/parameter_pack.h"
#include "base/ranges/algorithm.h"
#include "base/strings/utf_string_conversions.h"
#include "base/time/time.h"
#include "base/types/cxx23_to_underlying.h"
#include "url/gurl.h"
namespace ash {
namespace {
// Hosts treated as Google corp "goto" link hosts. Omnibox results whose URL
// host is in this set are deduplicated by URL path in
// `DeduplicateGoogleCorpGotoDomains`.
// TODO: b/330936766 - Prioritise "earlier" domains in this list.
constexpr auto kGoogleCorpGotoHosts = base::MakeFixedFlatSet<std::string_view>(
{"goto2.corp.google.com", "goto.corp.google.com", "goto.google.com", "go"});
// Returns the name of the histogram under which the query time of `source` is
// recorded. Both editor sources (write and rewrite) share a single histogram.
// NOTE(review): `kAction` is recorded under "CategoryProvider" - presumably a
// historical name kept for histogram continuity; confirm before renaming.
const char* SearchSourceToHistogram(PickerSearchSource source) {
switch (source) {
case PickerSearchSource::kOmnibox:
return "Ash.Picker.Search.OmniboxProvider.QueryTime";
case PickerSearchSource::kDate:
return "Ash.Picker.Search.DateProvider.QueryTime";
case PickerSearchSource::kAction:
return "Ash.Picker.Search.CategoryProvider.QueryTime";
case PickerSearchSource::kLocalFile:
return "Ash.Picker.Search.FileProvider.QueryTime";
case PickerSearchSource::kDrive:
return "Ash.Picker.Search.DriveProvider.QueryTime";
case PickerSearchSource::kMath:
return "Ash.Picker.Search.MathProvider.QueryTime";
case PickerSearchSource::kClipboard:
return "Ash.Picker.Search.ClipboardProvider.QueryTime";
case PickerSearchSource::kEditorWrite:
case PickerSearchSource::kEditorRewrite:
return "Ash.Picker.Search.EditorProvider.QueryTime";
}
// Only reached when `source` holds a value not handled by the switch above.
NOTREACHED() << "Unexpected search source " << base::to_underlying(source);
}
[[nodiscard]] std::vector<PickerSearchResult> DeduplicateGoogleCorpGotoDomains(
std::vector<PickerSearchResult> omnibox_results) {
std::set<std::string, std::less<>> seen;
std::vector<PickerSearchResult> deduped_results;
std::vector<PickerSearchResult*> results_to_remove;
for (PickerSearchResult& link : omnibox_results) {
auto* link_data = std::get_if<PickerBrowsingHistoryResult>(&link);
if (link_data == nullptr) {
deduped_results.push_back(std::move(link));
continue;
}
const GURL& url = link_data->url;
if (!url.has_host() || !url.has_path() ||
!kGoogleCorpGotoHosts.contains(url.host_piece())) {
deduped_results.push_back(std::move(link));
continue;
}
auto [it, inserted] = seen.emplace(url.path_piece());
if (inserted) {
deduped_results.push_back(std::move(link));
}
}
return deduped_results;
}
} // namespace
// Starts every search that applies to `query` and reports results through
// `callback`.
//
// A source is searched when its category is listed in
// `options.available_categories` and either `category` is unset or `category`
// equals that source's category. The action and editor sources have no
// category-specific search and are only run when `category` is unset.
//
// `callback` and `done_callback` must be non-null and `client` must not be
// nullptr (all three are CHECKed). `done_callback` is run with
// `interrupted=false` once no started search is still pending, which may
// happen before this constructor returns.
PickerSearchRequest::PickerSearchRequest(std::u16string_view query,
                                         std::optional<PickerCategory> category,
                                         SearchResultsCallback callback,
                                         DoneCallback done_callback,
                                         PickerClient* client,
                                         const Options& options)
    : is_category_specific_search_(category.has_value()),
      client_(CHECK_DEREF(client)),
      current_callback_(std::move(callback)),
      done_callback_(std::move(done_callback)) {
  CHECK(!current_callback_.is_null());
  CHECK(!done_callback_.is_null());

  base::span<const PickerCategory> available_categories =
      options.available_categories;

  // True when `candidate` is available and the request is either
  // category-agnostic or targets exactly `candidate`.
  const auto should_search = [&](PickerCategory candidate) {
    return (!category.has_value() || category == candidate) &&
           base::Contains(available_categories, candidate);
  };

  // Sources that are all served by the single `StartCrosSearch` call below.
  std::vector<PickerSearchSource> cros_search_sources;
  cros_search_sources.reserve(3);
  if (should_search(PickerCategory::kLinks)) {
    cros_search_sources.push_back(PickerSearchSource::kOmnibox);
  }
  if (should_search(PickerCategory::kLocalFiles)) {
    cros_search_sources.push_back(PickerSearchSource::kLocalFile);
  }
  if (should_search(PickerCategory::kDriveFiles)) {
    cros_search_sources.push_back(PickerSearchSource::kDrive);
  }

  if (!cros_search_sources.empty()) {
    // TODO: b/326166751 - Use `available_categories_` to decide what searches
    // to do.
    for (PickerSearchSource source : cros_search_sources) {
      MarkSearchStarted(source);
    }
    client_->StartCrosSearch(
        std::u16string(query), category,
        base::BindRepeating(&PickerSearchRequest::HandleCrosSearchResults,
                            weak_ptr_factory_.GetWeakPtr()));
  }

  if (should_search(PickerCategory::kClipboard)) {
    clipboard_provider_ = std::make_unique<PickerClipboardHistoryProvider>();
    MarkSearchStarted(PickerSearchSource::kClipboard);
    clipboard_provider_->FetchResults(
        base::BindOnce(&PickerSearchRequest::HandleClipboardSearchResults,
                       weak_ptr_factory_.GetWeakPtr()),
        query);
  }

  if (should_search(PickerCategory::kDatesTimes)) {
    MarkSearchStarted(PickerSearchSource::kDate);
    // Date results is currently synchronous.
    HandleDateSearchResults(PickerDateSearch(base::Time::Now(), query));
  }

  if (should_search(PickerCategory::kUnitsMaths)) {
    MarkSearchStarted(PickerSearchSource::kMath);
    // Math results is currently synchronous.
    HandleMathSearchResults(PickerMathSearch(query));
  }

  // These searches do not have category-specific search.
  if (!category.has_value()) {
    MarkSearchStarted(PickerSearchSource::kAction);
    // Action results are currently synchronous.
    HandleActionSearchResults(PickerActionSearch(
        {.available_categories = available_categories,
         .caps_lock_state_to_search = options.caps_lock_state_to_search,
         .search_case_transforms = options.search_case_transforms},
        query));

    if (base::Contains(available_categories, PickerCategory::kEditorWrite)) {
      // Editor results are currently synchronous.
      MarkSearchStarted(PickerSearchSource::kEditorWrite);
      HandleEditorSearchResults(
          PickerSearchSource::kEditorWrite,
          PickerEditorSearch(PickerEditorResult::Mode::kWrite, query));
    }
    if (base::Contains(available_categories, PickerCategory::kEditorRewrite)) {
      // Editor results are currently synchronous.
      MarkSearchStarted(PickerSearchSource::kEditorRewrite);
      HandleEditorSearchResults(
          PickerSearchSource::kEditorRewrite,
          PickerEditorSearch(PickerEditorResult::Mode::kRewrite, query));
    }
  }

  // The synchronous handlers above already went through
  // `MaybeCallDoneClosure`, which returns early while this flag is unset
  // (assumes the flag is initialised to false in the header - confirm). Now
  // that every search has been started, completion may be reported.
  can_call_done_closure_ = true;
  MaybeCallDoneClosure();
}
// Tears the request down: detaches pending result callbacks, reports an
// interrupted request if `done_callback_` has not been consumed yet, and asks
// the client to stop its CrOS query.
PickerSearchRequest::~PickerSearchRequest() {
  // Invalidate first so that stopping searches below cannot re-enter any
  // `Handle*SearchResults` method bound to a weak pointer - and therefore
  // cannot run `current_callback_`.
  weak_ptr_factory_.InvalidateWeakPtrs();

  if (done_callback_) {
    // Still holding the done callback: tell the owner that the request ended
    // early, then drop the results callback.
    std::move(done_callback_).Run(/*interrupted=*/true);
    current_callback_.Reset();
  }

  client_->StopCrosQuery();
}
// Common tail of every `Handle*SearchResults` method: records the end of the
// search for `source`, forwards `results` and `has_more_results` to
// `current_callback_`, and then reports completion if nothing is pending.
void PickerSearchRequest::HandleSearchSourceResults(
    PickerSearchSource source,
    std::vector<PickerSearchResult> results,
    bool has_more_results) {
  MarkSearchEnded(source);

  // `current_callback_` is reset in two places only: the destructor and
  // `MaybeCallDoneClosure`. The destructor invalidates every bound
  // `Handle*SearchResults` callback before resetting it, so it cannot lead
  // here. Arriving here after the done callback has run means a source
  // reported results after completion - a bug that should crash noisily.
  CHECK(current_callback_)
      << "Current callback is null in HandleSearchSourceResults";
  current_callback_.Run(source, std::move(results), has_more_results);

  MaybeCallDoneClosure();
}
// Forwards the action search results to `HandleSearchSourceResults`,
// attributed to `PickerSearchSource::kAction`.
void PickerSearchRequest::HandleActionSearchResults(
    std::vector<PickerSearchResult> results) {
  // Action results are reported in full, so there are never more to fetch.
  HandleSearchSourceResults(PickerSearchSource::kAction, std::move(results),
                            /*has_more_results=*/false);
}
// Receives results of the search started via `StartCrosSearch` for one result
// `type`, and forwards them under the matching `PickerSearchSource`.
//
// For requests without a specific category, each source is capped at its
// first three results and `has_more_results` reports whether any were
// dropped. Omnibox results are deduplicated before the cap is applied.
// Unexpected result types are logged and otherwise ignored.
void PickerSearchRequest::HandleCrosSearchResults(
    ash::AppListSearchResultType type,
    std::vector<PickerSearchResult> results) {
  // Number of results kept per source when no category was requested.
  static constexpr size_t kMaxResultsPerSource = 3;

  // Trims `items` to the cap when applicable and returns whether anything was
  // removed.
  const auto truncate = [this](std::vector<PickerSearchResult>& items) {
    if (is_category_specific_search_ ||
        items.size() <= kMaxResultsPerSource) {
      return false;
    }
    items.erase(items.begin() + kMaxResultsPerSource, items.end());
    return true;
  };

  switch (type) {
    case AppListSearchResultType::kOmnibox: {
      results = DeduplicateGoogleCorpGotoDomains(std::move(results));
      const bool truncated = truncate(results);
      HandleSearchSourceResults(PickerSearchSource::kOmnibox,
                                std::move(results),
                                /*has_more_results=*/truncated);
      break;
    }
    case AppListSearchResultType::kDriveSearch: {
      const bool truncated = truncate(results);
      HandleSearchSourceResults(PickerSearchSource::kDrive, std::move(results),
                                /*has_more_results=*/truncated);
      break;
    }
    case AppListSearchResultType::kFileSearch: {
      const bool truncated = truncate(results);
      HandleSearchSourceResults(PickerSearchSource::kLocalFile,
                                std::move(results),
                                /*has_more_results=*/truncated);
      break;
    }
    default:
      LOG(DFATAL) << "Got unexpected search result type "
                  << static_cast<int>(type);
      break;
  }
}
// Forwards the date search results to `HandleSearchSourceResults`, attributed
// to `PickerSearchSource::kDate`.
void PickerSearchRequest::HandleDateSearchResults(
std::vector<PickerSearchResult> results) {
// Date results are never truncated.
HandleSearchSourceResults(PickerSearchSource::kDate, std::move(results),
/*has_more_results=*/false);
}
void PickerSearchRequest::HandleMathSearchResults(
std::optional<PickerSearchResult> result) {
std::vector<PickerSearchResult> results;
if (result.has_value()) {
results.push_back(*std::move(result));
}
// Math results are never truncated.
HandleSearchSourceResults(PickerSearchSource::kMath, std::move(results),
/*has_more_results=*/false);
}
// Forwards the clipboard search results to `HandleSearchSourceResults`,
// attributed to `PickerSearchSource::kClipboard`. Bound as the completion
// callback of `clipboard_provider_->FetchResults` in the constructor.
void PickerSearchRequest::HandleClipboardSearchResults(
std::vector<PickerSearchResult> results) {
// Clipboard results are never truncated.
HandleSearchSourceResults(PickerSearchSource::kClipboard, std::move(results),
/*has_more_results=*/false);
}
// Forwards the editor search outcome to `HandleSearchSourceResults` under
// `source` (the constructor passes either the editor-write or editor-rewrite
// source). An empty `result` is reported as an empty result list.
void PickerSearchRequest::HandleEditorSearchResults(
    PickerSearchSource source,
    std::optional<PickerSearchResult> result) {
  // Turn the optional into a list holding zero or one element.
  std::vector<PickerSearchResult> editor_results;
  if (result) {
    editor_results.emplace_back(*std::move(result));
  }

  // A single optional result can never be truncated.
  HandleSearchSourceResults(source, std::move(editor_results),
                            /*has_more_results=*/false);
}
// Records the current time as the start of the search for `source`. Crashes
// if a start time is already recorded for that source.
void PickerSearchRequest::MarkSearchStarted(PickerSearchSource source) {
  const std::optional<base::TimeTicks> previous_start =
      SwapSearchStart(source, base::TimeTicks::Now());
  CHECK(!previous_start.has_value())
      << "search_starts_ enum " << base::to_underlying(source)
      << " was already set";
}
// Clears the recorded start time of `source` and logs the time elapsed since
// then to the histogram for that source. Crashes if no start time was
// recorded.
void PickerSearchRequest::MarkSearchEnded(PickerSearchSource source) {
  const std::optional<base::TimeTicks> started_at =
      SwapSearchStart(source, std::nullopt);
  CHECK(started_at.has_value())
      << "search_starts_ enum " << base::to_underlying(source)
      << " was not set";

  const base::TimeDelta query_time = base::TimeTicks::Now() - *started_at;
  base::UmaHistogramTimes(SearchSourceToHistogram(source), query_time);
}
// Stores `new_value` as the start time of `source` in `search_starts_` and
// returns the value that was stored before.
std::optional<base::TimeTicks> PickerSearchRequest::SwapSearchStart(
    PickerSearchSource source,
    std::optional<base::TimeTicks> new_value) {
  // `search_starts_` is indexed by the underlying value of the source enum.
  std::optional<base::TimeTicks>& slot =
      search_starts_[base::to_underlying(source)];
  return std::exchange(slot, std::move(new_value));
}
// Runs `done_callback_` with `interrupted=false` and drops
// `current_callback_`, but only once `can_call_done_closure_` is set and no
// entry of `search_starts_` still holds a start time.
void PickerSearchRequest::MaybeCallDoneClosure() {
  // A source is pending while its start time is still recorded.
  const auto is_pending = [](std::optional<base::TimeTicks>& start) {
    return start.has_value();
  };

  if (!can_call_done_closure_ ||
      base::ranges::any_of(search_starts_, is_pending)) {
    return;
  }

  std::move(done_callback_).Run(/*interrupted=*/false);
  current_callback_.Reset();
}
} // namespace ash