chromium/chromeos/ash/components/local_search_service/inverted_index_search.cc

// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/components/local_search_service/inverted_index_search.h"

#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/i18n/rtl.h"
#include "base/strings/string_split.h"
#include "base/strings/string_util.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "base/time/time.h"
#include "chromeos/ash/components/local_search_service/content_extraction_utils.h"
#include "chromeos/ash/components/local_search_service/inverted_index.h"
#include "chromeos/ash/components/string_matching/tokenized_string.h"

namespace ash::local_search_service {

namespace {

// Tokenizer used below to split a search query into individual terms.
using string_matching::TokenizedString;
// Output of content extraction: one entry per document, pairing the document
// id with the tokens extracted from that document.
using ExtractedContent =
    std::vector<std::pair<std::string, std::vector<Token>>>;

// Produces the token list for one document: tokens are extracted from every
// content entry of |data|, gathered into a single list, and that list is then
// passed through ConsolidateToken().
std::vector<Token> ExtractDocumentTokens(const Data& data) {
  // Prefer the locale supplied with the document; fall back to the configured
  // locale when none was given.
  std::string locale = data.locale;
  if (locale.empty()) {
    locale = base::i18n::GetConfiguredLocale();
  }

  std::vector<Token> all_tokens;
  for (const Content& entry : data.contents) {
    // Content weights are expected to lie within [0, 1].
    DCHECK_GE(entry.weight, 0);
    DCHECK_LE(entry.weight, 1);
    const std::vector<Token> entry_tokens =
        ExtractContent(entry.id, entry.content, entry.weight, locale);
    for (const Token& token : entry_tokens) {
      all_tokens.push_back(token);
    }
  }
  return ConsolidateToken(all_tokens);
}

// Extracts tokens for every document in |data|. The result holds one
// (document id, tokens) entry per input document, in input order.
ExtractedContent ExtractDocumentsContent(const std::vector<Data>& data) {
  ExtractedContent documents;
  // Exactly one entry is produced per input document, so size the output once
  // instead of growing it repeatedly.
  documents.reserve(data.size());
  for (const Data& d : data) {
    // Build the entry in place from the temporary token vector so the tokens
    // are moved into |documents| rather than deep-copied from a const local.
    documents.emplace_back(d.id, ExtractDocumentTokens(d));
  }

  return documents;
}

std::unordered_set<std::u16string> GetTokenizedQuery(
    const std::u16string& query) {
  // TODO(jiameng): actual input query may not be the same as default locale.
  // Need another way to determine actual language of the query.
  const TokenizedString::Mode mode =
      IsNonLatinLocale(base::i18n::GetConfiguredLocale())
          ? TokenizedString::Mode::kCamelCase
          : TokenizedString::Mode::kWords;

  const TokenizedString tokenized_query(query, mode);
  std::unordered_set<std::u16string> tokens;
  for (const auto& token : tokenized_query.tokens()) {
    // TODO(jiameng): we are not removing stopword because they shouldn't exist
    // in the index. However, for performance reason, it may be worth to be
    // removed.
    tokens.insert(token);
  }
  return tokens;
}

}  // namespace

// Creates a search index for |index_id| that uses the inverted-index backend.
// Allocates an empty InvertedIndex and a sequenced thread-pool task runner on
// which this class posts its background work (see AddOrUpdate(), Delete() and
// UpdateDocuments()).
InvertedIndexSearch::InvertedIndexSearch(IndexId index_id)
    : Index(index_id, Backend::kInvertedIndex),
      inverted_index_(std::make_unique<InvertedIndex>()),
      // The runner is created with best-effort priority, is allowed to block,
      // and uses the CONTINUE_ON_SHUTDOWN shutdown behavior.
      blocking_task_runner_(base::ThreadPool::CreateSequencedTaskRunner(
          {base::TaskPriority::BEST_EFFORT, base::MayBlock(),
           base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN})) {}

// Destruction is expected to happen on the sequence tracked by
// |sequence_checker_|; this is asserted via DCHECK.
InvertedIndexSearch::~InvertedIndexSearch() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}

// Reports the number of documents currently held by the inverted index by
// running |callback| synchronously with that count.
void InvertedIndexSearch::GetSize(GetSizeCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  const auto document_count = inverted_index_->NumberDocuments();
  std::move(callback).Run(document_count);
}

// Adds or updates the documents in |data|. Token extraction is posted to
// |blocking_task_runner_|; the extracted content is then handed to
// FinalizeAddOrUpdate() as the reply. |data| must not be empty.
void InvertedIndexSearch::AddOrUpdate(const std::vector<Data>& data,
                                      AddOrUpdateCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(!data.empty());

  // Pair the caller's callback with the time at which the request was issued.
  auto timed_callback =
      base::BindOnce(&InvertedIndexSearch::AddOrUpdateCallbackWithTime,
                     weak_ptr_factory_.GetWeakPtr(), std::move(callback),
                     base::Time::Now());

  // Reply step: receives the extraction result. It is bound to a weak pointer,
  // so it does not run if this object has been destroyed in the meantime.
  auto finalize = base::BindOnce(&InvertedIndexSearch::FinalizeAddOrUpdate,
                                 weak_ptr_factory_.GetWeakPtr(),
                                 std::move(timed_callback));

  // Task step: |data| is copied into the bound task, so it does not need to
  // outlive this call.
  auto extract = base::BindOnce(&ExtractDocumentsContent, data);

  blocking_task_runner_->PostTaskAndReplyWithResult(
      FROM_HERE, std::move(extract), std::move(finalize));
}

// Removes the documents identified by |ids|. The actual removal happens in
// FinalizeDelete(), which runs as the reply to a no-op task posted to
// |blocking_task_runner_|. |ids| must not be empty.
void InvertedIndexSearch::Delete(const std::vector<std::string>& ids,
                                 DeleteCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(!ids.empty());

  // Pair the caller's callback with the time at which the request was issued.
  auto timed_callback =
      base::BindOnce(&InvertedIndexSearch::DeleteCallbackWithTime,
                     weak_ptr_factory_.GetWeakPtr(), std::move(callback),
                     base::Time::Now());

  // |ids| is copied into the bound reply, so it does not need to outlive this
  // call.
  auto finalize = base::BindOnce(&InvertedIndexSearch::FinalizeDelete,
                                 weak_ptr_factory_.GetWeakPtr(),
                                 std::move(timed_callback), ids);

  // The posted task itself does nothing; only the reply performs work.
  // NOTE(review): presumably the round trip through the sequenced runner keeps
  // this deletion ordered after previously posted extraction tasks — confirm.
  blocking_task_runner_->PostTaskAndReply(FROM_HERE, base::DoNothing(),
                                          std::move(finalize));
}

// Updates the documents in |data|. Token extraction is posted to
// |blocking_task_runner_|; the extracted content is then handed to
// FinalizeUpdateDocuments() as the reply. |data| must not be empty.
void InvertedIndexSearch::UpdateDocuments(const std::vector<Data>& data,
                                          UpdateDocumentsCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(!data.empty());

  // Pair the caller's callback with the time at which the request was issued.
  auto timed_callback =
      base::BindOnce(&InvertedIndexSearch::UpdateDocumentsCallbackWithTime,
                     weak_ptr_factory_.GetWeakPtr(), std::move(callback),
                     base::Time::Now());

  // Reply step: receives the extraction result. It is bound to a weak pointer,
  // so it does not run if this object has been destroyed in the meantime.
  auto finalize = base::BindOnce(&InvertedIndexSearch::FinalizeUpdateDocuments,
                                 weak_ptr_factory_.GetWeakPtr(),
                                 std::move(timed_callback));

  // Task step: |data| is copied into the bound task, so it does not need to
  // outlive this call.
  auto extract = base::BindOnce(&ExtractDocumentsContent, data);

  blocking_task_runner_->PostTaskAndReplyWithResult(
      FROM_HERE, std::move(extract), std::move(finalize));
}

// Searches the index for |query| and runs |callback| synchronously with the
// outcome:
//   - kEmptyQuery with no results if |query| is empty;
//   - kEmptyIndex with no results if the index holds no documents;
//   - kSuccess with the matching documents otherwise.
// A |max_results| of 0 means "no limit"; any positive value caps the number of
// returned results. Search statistics are logged for every outcome.
void InvertedIndexSearch::Find(const std::u16string& query,
                               uint32_t max_results,
                               FindCallback callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  const base::TimeTicks start = base::TimeTicks::Now();
  if (query.empty()) {
    const ResponseStatus status = ResponseStatus::kEmptyQuery;
    // Failed lookups are logged with zero results and zero latency.
    MaybeLogSearchResultsStats(status, 0u, base::TimeDelta());
    std::move(callback).Run(status, std::nullopt);
    return;
  }
  if (inverted_index_->NumberDocuments() == 0u) {
    const ResponseStatus status = ResponseStatus::kEmptyIndex;
    MaybeLogSearchResultsStats(status, 0u, base::TimeDelta());
    std::move(callback).Run(status, std::nullopt);
    return;
  }

  std::vector<Result> results =
      inverted_index_->FindMatchingDocumentsApproximately(
          GetTokenizedQuery(query), search_params_.prefix_threshold,
          search_params_.fuzzy_threshold);

  // Only truncate when a positive limit was requested.
  if (max_results > 0u && results.size() > max_results) {
    results.resize(max_results);
  }

  const ResponseStatus status = ResponseStatus::kSuccess;
  const base::TimeTicks end = base::TimeTicks::Now();
  MaybeLogSearchResultsStats(status, results.size(), end - start);
  // |results| is not used after this point, so move it into the callback
  // argument instead of copying the whole vector.
  std::move(callback).Run(status, std::move(results));
}

// Removes all content from the inverted index. |callback| is wrapped together
// with the request start time and passed to the index, which is responsible
// for running it.
void InvertedIndexSearch::ClearIndex(ClearIndexCallback callback) {
  // Like every other entry point that touches |inverted_index_|, this must be
  // called on the owning sequence.
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  inverted_index_->ClearInvertedIndex(base::BindOnce(
      &InvertedIndexSearch::ClearIndexCallbackWithTime,
      weak_ptr_factory_.GetWeakPtr(), std::move(callback), base::Time::Now()));
}

// Returns the number of documents currently stored in the inverted index.
uint32_t InvertedIndexSearch::GetIndexSize() const {
  const uint32_t document_count = inverted_index_->NumberDocuments();
  return document_count;
}

std::vector<std::pair<std::string, uint32_t>>
InvertedIndexSearch::FindTermForTesting(const std::u16string& term) const {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  const PostingList posting_list = inverted_index_->FindTerm(term);
  std::vector<std::pair<std::string, uint32_t>> doc_with_freq;
  for (const auto& kv : posting_list) {
    doc_with_freq.push_back({kv.first, kv.second.size()});
  }

  return doc_with_freq;
}

// Reply half of AddOrUpdate(): receives the content extracted on
// |blocking_task_runner_| and adds it to the inverted index. |callback| is
// forwarded to the index rather than run here.
void InvertedIndexSearch::FinalizeAddOrUpdate(
    AddOrUpdateCallback callback,
    const ExtractedContent& documents) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  inverted_index_->AddDocuments(documents, std::move(callback));
}

// Reply half of Delete(): removes the documents identified by |ids| from the
// inverted index. |callback| is forwarded to the index rather than run here.
void InvertedIndexSearch::FinalizeDelete(DeleteCallback callback,
                                         const std::vector<std::string>& ids) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  inverted_index_->RemoveDocuments(ids, std::move(callback));
}

// Reply half of UpdateDocuments(): receives the content extracted on
// |blocking_task_runner_| and applies it to the inverted index. |callback| is
// forwarded to the index rather than run here.
void InvertedIndexSearch::FinalizeUpdateDocuments(
    UpdateDocumentsCallback callback,
    const ExtractedContent& documents) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  inverted_index_->UpdateDocuments(documents, std::move(callback));
}

}  // namespace ash::local_search_service