// Copyright 2019 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chromeos/ash/components/string_matching/fuzzy_tokenized_string_match.h"
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iterator>
#include <optional>
#include <set>
#include <string>
#include <vector>
#include "base/i18n/case_conversion.h"
#include "base/strings/strcat.h"
#include "base/strings/string_util.h"
#include "chromeos/ash/components/string_matching/acronym_matcher.h"
#include "chromeos/ash/components/string_matching/diacritic_utils.h"
#include "chromeos/ash/components/string_matching/prefix_matcher.h"
#include "chromeos/ash/components/string_matching/sequence_matcher.h"
namespace ash::string_matching {
namespace {
using Hits = FuzzyTokenizedStringMatch::Hits;

// Bounds of the scores returned by the ratio helpers in this file.
constexpr double kMinScore{0.0};
constexpr double kMaxScore{1.0};

// Per-character decay used by PartialRatio(): a candidate window is scaled by
// this value raised to the number of characters between the nearest preceding
// space (or the string start) and the start of the window.
constexpr double kPartialMatchPenaltyRate{0.9};

// Queries of at most this many characters receive an extra boost for prefix
// matches in Relevance().
constexpr size_t kMaxBoostSize{2};

// Multiplier applied by Relevance() when the matched character count differs
// from the text length.
constexpr double kNonExactMatchScaleRatio{0.97};
// Returns sorted tokens from a TokenizedString.
// Copies every token of |text| into a new vector and returns the copy sorted
// in ascending (lexicographic) order.
std::vector<std::u16string> ProcessAndSort(const TokenizedString& text) {
  const auto& tokens = text.tokens();
  std::vector<std::u16string> sorted_tokens(tokens.begin(), tokens.end());
  std::sort(sorted_tokens.begin(), sorted_tokens.end());
  return sorted_tokens;
}
// Rescales a raw matcher relevance as 1 - 0.5^relevance: a relevance of 0 maps
// to 0 and each additional unit halves the remaining distance to 1, so for
// relevance >= 0 the result lies in [0, 1).
double ScaledRelevance(const double relevance) {
  const double remaining = std::pow(0.5, relevance);
  return 1.0 - remaining;
}
} // namespace
// The constructor and destructor are both defaulted, out of line, in this
// translation unit.
FuzzyTokenizedStringMatch::FuzzyTokenizedStringMatch() = default;
FuzzyTokenizedStringMatch::~FuzzyTokenizedStringMatch() = default;
// Compares |query| and |text| as sets of tokens, following fuzzywuzzy's
// token_set_ratio.
//
// The distinct tokens are split into the intersection and the tokens unique
// to each side. Each side is rewritten as
// "<sorted intersection> <sorted unique tokens>".
//
// The result is the maximum of three comparisons: intersection vs. rewritten
// query, intersection vs. rewritten text, and rewritten query vs. rewritten
// text. Each comparison uses PartialRatio() when |partial| is true and
// SequenceMatcher::Ratio() otherwise.
double FuzzyTokenizedStringMatch::TokenSetRatio(const TokenizedString& query,
                                                const TokenizedString& text,
                                                bool partial) {
  const std::set<std::u16string> query_token(query.tokens().begin(),
                                             query.tokens().end());
  const std::set<std::u16string> text_token(text.tokens().begin(),
                                            text.tokens().end());
  std::vector<std::u16string> intersection;
  std::vector<std::u16string> query_diff_text;
  std::vector<std::u16string> text_diff_query;
  // Find the set intersection and the set differences between two sets of
  // tokens. std::set iterates in sorted order, which the std::set_*
  // algorithms require, and their outputs are sorted as well.
  std::set_intersection(query_token.begin(), query_token.end(),
                        text_token.begin(), text_token.end(),
                        std::back_inserter(intersection));
  std::set_difference(query_token.begin(), query_token.end(),
                      text_token.begin(), text_token.end(),
                      std::back_inserter(query_diff_text));
  std::set_difference(text_token.begin(), text_token.end(), query_token.begin(),
                      query_token.end(), std::back_inserter(text_diff_query));

  // Joins two fragments with a single space, inserting the separator only
  // when both are non-empty. This keeps the rewritten strings free of leading
  // and trailing spaces, as in fuzzywuzzy (which strips them). Otherwise a
  // side whose tokens all lie in the intersection would be rewritten as
  // "<intersection> " and would not score as an exact match against the
  // intersection string.
  const auto join_fragments =
      [](const std::u16string& head,
         const std::u16string& tail) -> std::u16string {
    if (head.empty()) {
      return tail;
    }
    if (tail.empty()) {
      return head;
    }
    return base::StrCat({head, u" ", tail});
  };

  const std::u16string intersection_string =
      base::JoinString(intersection, u" ");
  const std::u16string query_rewritten = join_fragments(
      intersection_string, base::JoinString(query_diff_text, u" "));
  const std::u16string text_rewritten = join_fragments(
      intersection_string, base::JoinString(text_diff_query, u" "));
  if (partial) {
    return std::max({PartialRatio(intersection_string, query_rewritten),
                     PartialRatio(intersection_string, text_rewritten),
                     PartialRatio(query_rewritten, text_rewritten)});
  }
  return std::max(
      {SequenceMatcher(intersection_string, query_rewritten).Ratio(),
       SequenceMatcher(intersection_string, text_rewritten).Ratio(),
       SequenceMatcher(query_rewritten, text_rewritten).Ratio()});
}
// Compares |query| and |text| after sorting each one's tokens and re-joining
// them with single spaces, so that word order does not affect the score.
// Uses PartialRatio() when |partial| is true, SequenceMatcher::Ratio()
// otherwise.
double FuzzyTokenizedStringMatch::TokenSortRatio(const TokenizedString& query,
                                                 const TokenizedString& text,
                                                 bool partial) {
  const std::u16string sorted_query =
      base::JoinString(ProcessAndSort(query), u" ");
  const std::u16string sorted_text =
      base::JoinString(ProcessAndSort(text), u" ");
  return partial ? PartialRatio(sorted_query, sorted_text)
                 : SequenceMatcher(sorted_query, sorted_text).Ratio();
}
// Returns the best similarity between the shorter of |query| / |text| and a
// window of the same length taken from the longer one, in the spirit of
// fuzzywuzzy's partial_ratio.
//
// Candidate windows are aligned using the matching blocks reported by
// SequenceMatcher. A window is scaled by kPartialMatchPenaltyRate once for
// every character between the nearest preceding space (or the start of the
// string) and the window start.
//
// Returns kMinScore if either input is empty. Returns kMaxScore as soon as
// some window scores above 0.995.
double FuzzyTokenizedStringMatch::PartialRatio(const std::u16string& query,
                                               const std::u16string& text) {
  if (query.empty() || text.empty()) {
    return kMinScore;
  }
  // Bind by reference rather than copying both inputs. When the lengths are
  // equal, |query| is treated as the shorter string.
  const bool query_is_longer = query.size() > text.size();
  const std::u16string& shorter = query_is_longer ? text : query;
  const std::u16string& longer = query_is_longer ? query : text;
  const auto matching_blocks =
      SequenceMatcher(shorter, longer).GetMatchingBlocks();
  double partial_ratio = 0;
  for (const auto& block : matching_blocks) {
    // Position |shorter| inside |longer| so that this block lines up.
    const int long_start =
        block.pos_second_string > block.pos_first_string
            ? block.pos_second_string - block.pos_first_string
            : 0;
    // Penalizes the match if it is not close to the beginning of a token:
    // walk back to the nearest preceding space (ending at -1 if none). A direct
    // character comparison avoids allocating a one-character substring per
    // step.
    int current = long_start - 1;
    while (current >= 0 && longer[current] != u' ') {
      current--;
    }
    const double penalty =
        std::pow(kPartialMatchPenaltyRate, long_start - current - 1);
    // TODO(crbug.com/40638914): currently this part re-calculate the ratio for
    // every pair. Improve this to reduce latency.
    partial_ratio = std::max(
        partial_ratio,
        SequenceMatcher(shorter, longer.substr(long_start, shorter.size()))
                .Ratio() *
            penalty);
    if (partial_ratio > 0.995) {
      return kMaxScore;
    }
  }
  return partial_ratio;
}
// Returns the highest of several fuzzy ratios between |query| and |text|,
// modeled on fuzzywuzzy's WRatio. The candidates are:
//  - the length-agnostic SequenceMatcher ratio of the space-joined tokens;
//  - when the longer joined string is at least 1.5x the length of the shorter
//    one, PartialRatio() scaled by 0.9 (or by 0.6 above 8x);
//  - TokenSortRatio(), run in partial mode under the same length condition;
//  - TokenSetRatio(), never in partial mode.
// Both token-based ratios are scaled by 0.95 and by the same length factor.
double FuzzyTokenizedStringMatch::WeightedRatio(const TokenizedString& query,
                                                const TokenizedString& text) {
  // All token based comparisons are scaled by 0.95 (on top of any partial
  // scalars), as per original implementation:
  // https://github.com/seatgeek/fuzzywuzzy/blob/af443f918eebbccff840b86fa606ac150563f466/fuzzywuzzy/fuzz.py#L245
  constexpr double kUnbaseScale = 0.95;

  // query.text() and text.text() are not normalized, so compare the joined
  // tokens instead.
  const std::u16string joined_query = base::JoinString(query.tokens(), u" ");
  const std::u16string joined_text = base::JoinString(text.tokens(), u" ");

  // Running maximum over all candidate ratios. A later candidate replaces the
  // current best only when it compares strictly greater.
  double best = SequenceMatcher(joined_query, joined_text)
                    .Ratio(/*text_length_agnostic=*/true);

  // NOTE(review): if exactly one joined string is empty this ratio is +inf,
  // and if both are empty it is NaN; callers presumably avoid that. Behavior
  // is left unchanged here.
  const size_t longer_length = std::max(joined_query.size(), joined_text.size());
  const size_t shorter_length =
      std::min(joined_query.size(), joined_text.size());
  const double length_ratio =
      static_cast<double>(longer_length) / shorter_length;

  // Use partial if two strings are quite different in sizes.
  const bool use_partial = length_ratio >= 1.5;
  // TODO(crbug.com/1336160): Consider scaling |partial_scale| smoothly with
  // |length_ratio|, instead of using a step function.
  //
  // If one string is much much shorter than the other, the scale is 0.6,
  // otherwise 0.9; without partial matching it stays 1.
  double length_ratio_scale = 1;
  if (use_partial) {
    length_ratio_scale = length_ratio > 8 ? 0.6 : 0.9;
    best = std::max(
        best, PartialRatio(joined_query, joined_text) * length_ratio_scale);
  }

  best = std::max(best, TokenSortRatio(query, text, use_partial) *
                            kUnbaseScale * length_ratio_scale);

  // Do not use partial match for token set because the match between the
  // intersection string and query/text rewrites will always return an extremely
  // high value.
  best = std::max(best, TokenSetRatio(query, text, /*partial=*/false) *
                            kUnbaseScale * length_ratio_scale);
  return best;
}
double FuzzyTokenizedStringMatch::PrefixMatcher(const TokenizedString& query,
const TokenizedString& text) {
string_matching::PrefixMatcher match(query, text);
match.Match();
return ScaledRelevance(match.relevance());
}
double FuzzyTokenizedStringMatch::AcronymMatcher(const TokenizedString& query,
const TokenizedString& text) {
string_matching::AcronymMatcher match(query, text);
const double relevance = match.CalculateRelevance();
return ScaledRelevance(relevance);
}
double FuzzyTokenizedStringMatch::PrefixMatcher(
const TokenizedString& query,
const TokenizedString& text,
std::vector<Hits>& hits_vector) {
string_matching::PrefixMatcher match(query, text);
match.Match();
hits_vector.emplace_back(match.hits());
return ScaledRelevance(match.relevance());
}
double FuzzyTokenizedStringMatch::AcronymMatcher(
const TokenizedString& query,
const TokenizedString& text,
std::vector<Hits>& hits_vector) {
string_matching::AcronymMatcher match(query, text);
const double relevance = match.CalculateRelevance();
hits_vector.emplace_back(match.hits());
return ScaledRelevance(relevance);
}
double FuzzyTokenizedStringMatch::Relevance(const TokenizedString& query_input,
const TokenizedString& text_input,
bool use_weighted_ratio,
bool strip_diacritics,
bool use_acronym_matcher) {
// If the query is much longer than the text then it's often not a match.
if (query_input.text().size() >= text_input.text().size() * 2) {
return 0.0;
}
std::optional<TokenizedString> stripped_query;
std::optional<TokenizedString> stripped_text;
if (strip_diacritics) {
stripped_query.emplace(RemoveDiacritics(query_input.text()));
stripped_text.emplace(RemoveDiacritics(text_input.text()));
}
const TokenizedString& query =
strip_diacritics ? stripped_query.value() : query_input;
const TokenizedString& text =
strip_diacritics ? stripped_text.value() : text_input;
// If there is an exact match, relevance will be 1.0 and there is only 1
// hit that is the entire text/query.
const auto& query_text = query.text();
const auto& text_text = text.text();
const auto query_size = query_text.size();
const auto text_size = text_text.size();
if (query_size > 0 && query_size == text_size &&
base::EqualsCaseInsensitiveASCII(query_text, text_text)) {
hits_.emplace_back(0, query_size);
return 1.0;
}
// The |relevances| stores the |relevance_scores| calculated from different
// string matching methods. The highest result among them will be returned.
std::vector<double> relevances;
// The |hits_vector| stores the |hits| calculated from different string
// matching methods. The final selected instance corresponds to the hits
// generated by the matching algorithm which yielded the highest relevance
// score. The final selected instance will be assigned to |hits_| then.
std::vector<Hits> hits_vector;
double prefix_score = PrefixMatcher(query, text, hits_vector);
// A scoring boost for short prefix matching queries.
if (query_size <= kMaxBoostSize && prefix_score > kMinScore) {
prefix_score = std::min(
1.0, prefix_score + 2.0 / (query_size * (query_size + text_size)));
}
relevances.emplace_back(prefix_score);
// Find hits using SequenceMatcher on original query and text.
Hits sequence_hits;
size_t match_size = 0;
for (const auto& match :
SequenceMatcher(query_text, text_text).GetMatchingBlocks()) {
if (match.length > 0) {
match_size += match.length;
sequence_hits.emplace_back(match.pos_second_string,
match.pos_second_string + match.length);
}
}
hits_vector.emplace_back(sequence_hits);
relevances.emplace_back(use_weighted_ratio
? WeightedRatio(query, text)
: SequenceMatcher(base::i18n::ToLower(query_text),
base::i18n::ToLower(text_text))
.Ratio(/*text_length_agnostic=*/true));
if (use_acronym_matcher) {
relevances.emplace_back(AcronymMatcher(query, text, hits_vector));
}
size_t best_match_pos =
std::max_element(relevances.begin(), relevances.end()) -
relevances.begin();
hits_ = hits_vector[best_match_pos];
return match_size == text_size
? relevances[best_match_pos]
: relevances[best_match_pos] * kNonExactMatchScaleRatio;
}
} // namespace ash::string_matching