// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chromeos/ash/components/local_search_service/inverted_index.h"

#include <algorithm>
#include <cmath>
#include <numeric>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/strings/utf_string_conversions.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "chromeos/ash/components/local_search_service/search_utils.h"
namespace ash::local_search_service {
namespace {
// (document-score, posting-of-all-matching-terms).
using ScoreWithPosting = std::pair<double, Posting>;
// Calculates TF-IDF scores for a term
std::vector<TfidfResult> CalculateTfidf(const std::u16string& term,
const DocLength& doc_length,
const Dictionary& dictionary) {
std::vector<TfidfResult> results;
// We don't apply weights to idf because the effect is likely small.
const float idf =
1.0 + log((1.0 + doc_length.size()) / (1.0 + dictionary.at(term).size()));
for (const auto& item : dictionary.at(term)) {
// If a term has a very low content weight in a doc, its effective number of
// occurrences in the doc should be lower. Strictly speaking, the effective
// length of the doc should be smaller too. However, for performance
// reasons, we only apply the weight to the term occurrences but not doc
// length.
// TODO(jiameng): this is an expensive operation, we will need to monitor
// its performance and optimize it.
const double effective_term_occ = std::accumulate(
item.second.begin(), item.second.end(), 0.0,
[](double sum, const WeightedPosition& weighted_position) {
return sum + weighted_position.weight;
});
const float tf = effective_term_occ / doc_length.at(item.first);
results.push_back({item.first, item.second, tf * idf});
}
return results;
}
// Builds TF-IDF cache given the data. Since this function is expensive, it
// should run on a non-blocking thread that is different than the main thread.
TfidfCache BuildTfidf(uint32_t num_docs_from_last_update,
                      const DocLength& doc_length,
                      const Dictionary& dictionary,
                      const TermSet& terms_to_be_updated,
                      const TfidfCache& tfidf_cache) {
  // TODO(crbug.com/40152719): consider moving the helper functions inside the
  // class so that we can use SequenceChecker.
  TfidfCache updated_cache(tfidf_cache);
  // An unchanged document count means no document was added or removed since
  // the last build, so refreshing only |terms_to_be_updated| suffices.
  // Otherwise every term's score depends on the new document count and the
  // whole cache must be rebuilt.
  const bool incremental = num_docs_from_last_update == doc_length.size();
  if (!incremental) {
    updated_cache.clear();
    for (const auto& entry : dictionary) {
      updated_cache[entry.first] =
          CalculateTfidf(entry.first, doc_length, dictionary);
    }
    return updated_cache;
  }
  for (const auto& term : terms_to_be_updated) {
    const bool term_still_indexed = dictionary.find(term) != dictionary.end();
    if (term_still_indexed) {
      updated_cache[term] = CalculateTfidf(term, doc_length, dictionary);
    } else {
      updated_cache.erase(term);
    }
  }
  return updated_cache;
}
// Removes a document from the document state variables given its ID. Does
// nothing if the ID doesn't exist. Returns true if the document was removed.
bool RemoveDocumentIfExist(const std::string& document_id,
                           DocLength* doc_length,
                           Dictionary* dictionary,
                           TermSet* terms_to_be_updated) {
  CHECK(doc_length);
  CHECK(dictionary);
  CHECK(terms_to_be_updated);
  const auto length_it = doc_length->find(document_id);
  if (length_it == doc_length->end())
    return false;
  doc_length->erase(length_it);
  bool removed = false;
  for (auto term_it = dictionary->begin(); term_it != dictionary->end();) {
    auto& posting_list = term_it->second;
    // erase(key) reports how many entries were removed; non-zero means this
    // term occurred in the deleted document and needs its score refreshed.
    if (posting_list.erase(document_id) != 0u) {
      terms_to_be_updated->insert(term_it->first);
      removed = true;
    }
    // Drop the term from the dictionary once its posting list is empty.
    if (posting_list.empty()) {
      term_it = dictionary->erase(term_it);
    } else {
      ++term_it;
    }
  }
  return removed;
}
// Given a list of documents to update and the document state variables,
// returns new document state variables plus the number of deleted documents.
std::pair<DocumentStateVariables, uint32_t> UpdateDocumentStateVariables(
    DocumentToUpdate&& documents_to_update,
    const DocLength& doc_length,
    Dictionary&& dictionary,
    TermSet&& terms_to_be_updated) {
  DocLength updated_doc_length(doc_length);
  uint32_t num_deleted = 0u;
  for (const auto& [id, tokens] : documents_to_update) {
    const std::string document_id(id);
    // Wipe any stale state for this document first; a non-empty token list
    // below means this is an update rather than a deletion.
    const bool removed =
        RemoveDocumentIfExist(document_id, &updated_doc_length, &dictionary,
                              &terms_to_be_updated);
    if (tokens.empty()) {
      // Empty content requests a deletion; count it only if the document
      // actually existed.
      if (removed)
        ++num_deleted;
      continue;
    }
    for (const auto& token : tokens) {
      dictionary[token.content][document_id] = token.positions;
      updated_doc_length[document_id] += token.positions.size();
      terms_to_be_updated.insert(token.content);
    }
  }
  return std::make_pair(
      std::make_tuple(std::move(updated_doc_length), std::move(dictionary),
                      std::move(terms_to_be_updated)),
      num_deleted);
}
// Given the index variables, clears all the data and returns the emptied
// state so the caller can move it back into the index members.
std::pair<DocumentStateVariables, TfidfCache> ClearData(
    DocumentToUpdate&& documents_to_update,
    const DocLength& doc_length,
    Dictionary&& dictionary,
    TermSet&& terms_to_be_updated,
    TfidfCache&& tfidf_cache) {
  // |doc_length| is const, so a fresh empty instance stands in for it; the
  // remaining containers are cleared in place and handed back empty.
  documents_to_update.clear();
  dictionary.clear();
  terms_to_be_updated.clear();
  tfidf_cache.clear();
  DocumentStateVariables state(DocLength(), std::move(dictionary),
                               std::move(terms_to_be_updated));
  return std::make_pair(std::move(state), std::move(tfidf_cache));
}
} // namespace
InvertedIndex::InvertedIndex() {
task_runner_ = base::ThreadPool::CreateSequencedTaskRunner(
{base::TaskPriority::BEST_EFFORT, base::MayBlock(),
base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN});
}
// Default destructor; pending background replies are bound to weak pointers
// (see the GetWeakPtr() uses below), so they are dropped safely afterwards.
InvertedIndex::~InvertedIndex() = default;
// Returns the posting list for |term|, or an empty list when the term is not
// present in the dictionary.
PostingList InvertedIndex::FindTerm(const std::u16string& term) const {
  const auto entry = dictionary_.find(term);
  return entry == dictionary_.end() ? PostingList() : entry->second;
}
std::vector<Result> InvertedIndex::FindMatchingDocumentsApproximately(
const std::unordered_set<std::u16string>& terms,
double prefix_threshold,
double block_threshold) const {
// For each document, its score is the sum of the scores of its terms that
// match one of more query term. Each term's score is the product of its
// TF-IDF score and its match relevance score.
// The map is keyed by the document id.
std::unordered_map<std::string, ScoreWithPosting> matching_docs;
for (const auto& kv : tfidf_cache_) {
const std::u16string& index_term = kv.first;
const std::vector<TfidfResult>& tfidf_results = kv.second;
for (const auto& term : terms) {
const float relevance = RelevanceCoefficient(
term, index_term, prefix_threshold, block_threshold);
if (relevance > 0) {
// If the |index_term| is relevant, all of the enclosing documents will
// have their ranking scores updated.
for (const auto& docid_tfidf : tfidf_results) {
const std::string& docid = std::get<0>(docid_tfidf);
const Posting& posting = std::get<1>(docid_tfidf);
const float tfidf = std::get<2>(docid_tfidf);
auto it = matching_docs.find(docid);
if (it == matching_docs.end()) {
it = matching_docs.emplace(docid, ScoreWithPosting(0.0, {})).first;
}
auto& score_posting = it->second;
// TODO(jiameng): add position penalty.
score_posting.first += tfidf * relevance;
// Also update matching positions.
auto& existing_posting = score_posting.second;
existing_posting.insert(existing_posting.end(), posting.begin(),
posting.end());
}
// Break out from inner loop, i.e. no need to check other query terms.
break;
}
}
}
std::vector<Result> sorted_matching_docs;
for (const auto& kv : matching_docs) {
// We don't need to include weights in the search results.
std::vector<Position> positions;
for (const auto& weighted_position : kv.second.second) {
positions.emplace_back(weighted_position.position);
}
sorted_matching_docs.emplace_back(
Result(kv.first, kv.second.first, positions));
}
std::sort(sorted_matching_docs.begin(), sorted_matching_docs.end(),
CompareResults);
return sorted_matching_docs;
}
// Adds |documents| to the index, then rebuilds the TF-IDF cache and runs
// |callback| when the rebuild completes.
void InvertedIndex::AddDocuments(const DocumentToUpdate& documents,
                                 base::OnceCallback<void()> callback) {
  // Nothing to index: still notify the caller. The previous code returned
  // without running |callback|, silently dropping it, so a caller waiting on
  // the completion signal would never be notified.
  if (documents.empty()) {
    std::move(callback).Run();
    return;
  }
  task_runner_->PostTaskAndReplyWithResult(
      FROM_HERE,
      base::BindOnce(&UpdateDocumentStateVariables, documents,
                     std::move(doc_length_), std::move(dictionary_),
                     std::move(terms_to_be_updated_)),
      base::BindOnce(&InvertedIndex::OnAddDocumentsComplete,
                     weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
}
void InvertedIndex::RemoveDocuments(
const std::vector<std::string>& document_ids,
base::OnceCallback<void(uint32_t)> callback) {
DocumentToUpdate documents;
for (const auto& id : document_ids) {
documents.push_back({id, std::vector<Token>()});
}
task_runner_->PostTaskAndReplyWithResult(
FROM_HERE,
base::BindOnce(&UpdateDocumentStateVariables, documents,
std::move(doc_length_), std::move(dictionary_),
std::move(terms_to_be_updated_)),
base::BindOnce(&InvertedIndex::OnUpdateDocumentsComplete,
weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
}
// Applies |documents| as updates/deletions; |callback| receives the number of
// documents deleted.
void InvertedIndex::UpdateDocuments(
    const DocumentToUpdate& documents,
    base::OnceCallback<void(uint32_t)> callback) {
  // Index state is moved into the background task and handed back (updated)
  // via OnUpdateDocumentsComplete.
  auto update_task =
      base::BindOnce(&UpdateDocumentStateVariables, documents,
                     std::move(doc_length_), std::move(dictionary_),
                     std::move(terms_to_be_updated_));
  auto reply =
      base::BindOnce(&InvertedIndex::OnUpdateDocumentsComplete,
                     weak_ptr_factory_.GetWeakPtr(), std::move(callback));
  task_runner_->PostTaskAndReplyWithResult(FROM_HERE, std::move(update_task),
                                           std::move(reply));
}
// Returns the cached TF-IDF entries for |term|, or an empty vector when the
// term is not in the cache.
std::vector<TfidfResult> InvertedIndex::GetTfidf(
    const std::u16string& term) const {
  const auto cached = tfidf_cache_.find(term);
  return cached == tfidf_cache_.end() ? std::vector<TfidfResult>()
                                      : cached->second;
}
// Rebuilds the TF-IDF cache on the background sequence and runs |callback|
// when done.
void InvertedIndex::BuildInvertedIndex(base::OnceCallback<void()> callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // |terms_to_be_updated_| is consumed by the rebuild; the other members are
  // copied because they are still needed while the task runs.
  auto build_task =
      base::BindOnce(&BuildTfidf, num_docs_from_last_update_, doc_length_,
                     dictionary_, std::move(terms_to_be_updated_),
                     tfidf_cache_);
  auto reply = base::BindOnce(&InvertedIndex::OnBuildTfidfComplete,
                              weak_ptr_factory_.GetWeakPtr(),
                              std::move(callback));
  task_runner_->PostTaskAndReplyWithResult(FROM_HERE, std::move(build_task),
                                           std::move(reply));
}
// Clears all index data on the background sequence and runs |callback| when
// the emptied state has been moved back in.
void InvertedIndex::ClearInvertedIndex(base::OnceCallback<void()> callback) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // All index state is moved into the background task, cleared there, and
  // handed back empty via OnDataCleared.
  auto clear_task =
      base::BindOnce(&ClearData, std::move(documents_to_update_), doc_length_,
                     std::move(dictionary_), std::move(terms_to_be_updated_),
                     std::move(tfidf_cache_));
  auto reply = base::BindOnce(&InvertedIndex::OnDataCleared,
                              weak_ptr_factory_.GetWeakPtr(),
                              std::move(callback));
  task_runner_->PostTaskAndReplyWithResult(FROM_HERE, std::move(clear_task),
                                           std::move(reply));
}
// Installs the freshly built TF-IDF cache and notifies the caller.
void InvertedIndex::OnBuildTfidfComplete(base::OnceCallback<void()> callback,
                                         TfidfCache&& new_cache) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  tfidf_cache_ = std::move(new_cache);
  // Remember the document count this cache was built from so the next build
  // can choose between an incremental update and a full rebuild.
  num_docs_from_last_update_ = doc_length_.size();
  std::move(callback).Run();
}
// Moves the updated document state back into the members, then rebuilds the
// TF-IDF cache; the deletion count is forwarded to |callback| afterwards.
void InvertedIndex::OnUpdateDocumentsComplete(
    base::OnceCallback<void(uint32_t)> callback,
    std::pair<DocumentStateVariables, uint32_t>&&
        document_state_variables_and_num_deleted) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  auto& [state_variables, num_deleted] =
      document_state_variables_and_num_deleted;
  auto& [new_doc_length, new_dictionary, new_terms] = state_variables;
  doc_length_ = std::move(new_doc_length);
  dictionary_ = std::move(new_dictionary);
  terms_to_be_updated_ = std::move(new_terms);
  BuildInvertedIndex(base::BindOnce(
      [](base::OnceCallback<void(uint32_t)> callback, uint32_t num_deleted) {
        std::move(callback).Run(num_deleted);
      },
      std::move(callback), num_deleted));
}
// Moves the updated document state back into the members after an add, then
// rebuilds the TF-IDF cache and forwards |callback| to the rebuild.
void InvertedIndex::OnAddDocumentsComplete(
    base::OnceCallback<void()> callback,
    std::pair<DocumentStateVariables, uint32_t>&&
        document_state_variables_and_num_deleted) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // Adds never delete documents.
  DCHECK_EQ(document_state_variables_and_num_deleted.second, 0u);
  auto& [new_doc_length, new_dictionary, new_terms] =
      document_state_variables_and_num_deleted.first;
  doc_length_ = std::move(new_doc_length);
  dictionary_ = std::move(new_dictionary);
  terms_to_be_updated_ = std::move(new_terms);
  BuildInvertedIndex(std::move(callback));
}
// Installs the emptied state returned by ClearData and notifies the caller.
void InvertedIndex::OnDataCleared(
    base::OnceCallback<void()> callback,
    std::pair<DocumentStateVariables, TfidfCache>&& inverted_index_data) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  auto& [state_variables, cleared_cache] = inverted_index_data;
  auto& [cleared_doc_length, cleared_dictionary, cleared_terms] =
      state_variables;
  doc_length_ = std::move(cleared_doc_length);
  dictionary_ = std::move(cleared_dictionary);
  terms_to_be_updated_ = std::move(cleared_terms);
  tfidf_cache_ = std::move(cleared_cache);
  // The index is now empty, so the last build effectively saw zero documents.
  num_docs_from_last_update_ = 0;
  std::move(callback).Run();
}
} // namespace ash::local_search_service