chromium/chromeos/ash/components/string_matching/fuzzy_tokenized_string_match.cc

// Copyright 2019 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/components/string_matching/fuzzy_tokenized_string_match.h"

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iterator>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "base/i18n/case_conversion.h"
#include "base/strings/strcat.h"
#include "base/strings/string_util.h"
#include "chromeos/ash/components/string_matching/acronym_matcher.h"
#include "chromeos/ash/components/string_matching/diacritic_utils.h"
#include "chromeos/ash/components/string_matching/prefix_matcher.h"
#include "chromeos/ash/components/string_matching/sequence_matcher.h"

namespace ash::string_matching {

namespace {

// Shorthand for the hit-list type exposed by FuzzyTokenizedStringMatch.
using Hits = FuzzyTokenizedStringMatch::Hits;

// Base of the penalty applied in PartialRatio(): the candidate score is
// multiplied by this value once per character between the start of the token
// containing the match and the start of the match itself.
constexpr double kPartialMatchPenaltyRate{0.9};

// Lowest and highest scores returned by the ratio helpers below.
constexpr double kMinScore{0.0};
constexpr double kMaxScore{1.0};

// Queries of at most this many characters get the short-prefix scoring boost
// in Relevance().
constexpr size_t kMaxBoostSize{2};

// Multiplier applied to the final relevance when the matched characters do not
// cover the whole text.
constexpr double kNonExactMatchScaleRatio{0.97};

// Copies the tokens of |text| into a new vector and sorts them in ascending
// order. Duplicate tokens are kept.
std::vector<std::u16string> ProcessAndSort(const TokenizedString& text) {
  const auto& tokens = text.tokens();
  std::vector<std::u16string> sorted_tokens(tokens.begin(), tokens.end());
  std::sort(sorted_tokens.begin(), sorted_tokens.end());
  return sorted_tokens;
}

// Maps a raw matcher relevance r to 1 - 0.5^r: r == 0 maps to 0, and every
// additional unit of relevance halves the remaining distance to 1.
double ScaledRelevance(const double relevance) {
  const double remaining = std::pow(0.5, relevance);
  return 1.0 - remaining;
}

}  // namespace

FuzzyTokenizedStringMatch::FuzzyTokenizedStringMatch() = default;
FuzzyTokenizedStringMatch::~FuzzyTokenizedStringMatch() = default;

// Computes a token-set similarity between |query| and |text|, modeled on
// fuzzywuzzy's token_set_ratio.
//
// The tokens of each input are deduplicated and sorted (via std::set), then
// split into their intersection and the two one-sided differences. The strings
//   S = intersection
//   Q = intersection followed by the query-only tokens
//   T = intersection followed by the text-only tokens
// are compared pairwise (S/Q, S/T, Q/T), and the best ratio is returned. When
// |partial| is true, PartialRatio() is used for each pair; otherwise the full
// SequenceMatcher ratio is used.
double FuzzyTokenizedStringMatch::TokenSetRatio(const TokenizedString& query,
                                                const TokenizedString& text,
                                                bool partial) {
  const std::set<std::u16string> query_token(query.tokens().begin(),
                                             query.tokens().end());
  const std::set<std::u16string> text_token(text.tokens().begin(),
                                            text.tokens().end());

  std::vector<std::u16string> intersection;
  std::vector<std::u16string> query_diff_text;
  std::vector<std::u16string> text_diff_query;

  // Find the set intersection and the set differences between the two sets of
  // tokens. Every output is sorted because both inputs are sorted sets.
  std::set_intersection(query_token.begin(), query_token.end(),
                        text_token.begin(), text_token.end(),
                        std::back_inserter(intersection));
  std::set_difference(query_token.begin(), query_token.end(),
                      text_token.begin(), text_token.end(),
                      std::back_inserter(query_diff_text));
  std::set_difference(text_token.begin(), text_token.end(), query_token.begin(),
                      query_token.end(), std::back_inserter(text_diff_query));

  // Joins the intersection followed by |difference| with single spaces.
  // Joining one combined token list, instead of concatenating
  // `intersection + " " + difference`, never leaves a dangling separator when
  // either part is empty. fuzzywuzzy strips that separator as well. Keeping it
  // would lower the score whenever one token set is a subset of the other.
  const auto join_with_intersection =
      [&intersection](const std::vector<std::u16string>& difference) {
        std::vector<std::u16string> tokens(intersection);
        tokens.insert(tokens.end(), difference.begin(), difference.end());
        return base::JoinString(tokens, u" ");
      };

  const std::u16string intersection_string =
      base::JoinString(intersection, u" ");
  const std::u16string query_rewritten =
      join_with_intersection(query_diff_text);
  const std::u16string text_rewritten = join_with_intersection(text_diff_query);

  if (partial) {
    return std::max({PartialRatio(intersection_string, query_rewritten),
                     PartialRatio(intersection_string, text_rewritten),
                     PartialRatio(query_rewritten, text_rewritten)});
  }

  return std::max(
      {SequenceMatcher(intersection_string, query_rewritten).Ratio(),
       SequenceMatcher(intersection_string, text_rewritten).Ratio(),
       SequenceMatcher(query_rewritten, text_rewritten).Ratio()});
}

// Computes a token-sort similarity (fuzzywuzzy's token_sort_ratio). The tokens
// of each input are sorted and re-joined with single spaces, so token order
// does not affect the result. When |partial| is true, PartialRatio() compares
// the two sorted strings; otherwise the full SequenceMatcher ratio does.
double FuzzyTokenizedStringMatch::TokenSortRatio(const TokenizedString& query,
                                                 const TokenizedString& text,
                                                 bool partial) {
  const std::u16string sorted_query =
      base::JoinString(ProcessAndSort(query), u" ");
  const std::u16string sorted_text =
      base::JoinString(ProcessAndSort(text), u" ");

  return partial ? PartialRatio(sorted_query, sorted_text)
                 : SequenceMatcher(sorted_query, sorted_text).Ratio();
}

// Computes a partial similarity (fuzzywuzzy's partial_ratio) between |query|
// and |text|.
//
// For each SequenceMatcher matching block between the shorter and the longer
// string, a window of the longer string is aligned to that block. The window
// has the shorter string's length. It is scored with the SequenceMatcher ratio
// and multiplied by kPartialMatchPenaltyRate once per character between the
// start of the window's token (the character after the preceding space) and
// the window start. The best score is returned.
//
// Returns kMinScore if either input is empty. Returns kMaxScore as soon as a
// candidate scores above 0.995. On a length tie, |query| is treated as the
// shorter string.
double FuzzyTokenizedStringMatch::PartialRatio(const std::u16string& query,
                                               const std::u16string& text) {
  if (query.empty() || text.empty()) {
    return kMinScore;
  }

  // Bind references rather than copying both inputs.
  const bool query_is_shorter = query.size() <= text.size();
  const std::u16string& shorter = query_is_shorter ? query : text;
  const std::u16string& longer = query_is_shorter ? text : query;

  const auto matching_blocks =
      SequenceMatcher(shorter, longer).GetMatchingBlocks();
  double partial_ratio = kMinScore;

  for (const auto& block : matching_blocks) {
    const int long_start =
        block.pos_second_string > block.pos_first_string
            ? block.pos_second_string - block.pos_first_string
            : 0;

    // Penalizes the match if it is not close to the beginning of a token:
    // walk back to the nearest preceding space (or to -1 if there is none).
    // A direct character comparison is equivalent to the former
    // case-insensitive one-character compare, because only ' ' lower-cases
    // to ' '. It also avoids building a substring for every character.
    int current = long_start - 1;
    while (current >= 0 && longer[current] != u' ') {
      --current;
    }
    const double penalty =
        std::pow(kPartialMatchPenaltyRate, long_start - current - 1);

    // TODO(crbug.com/40638914): this recalculates the ratio for every block.
    // Improve this to reduce latency.
    partial_ratio = std::max(
        partial_ratio,
        SequenceMatcher(shorter, longer.substr(long_start, shorter.size()))
                .Ratio() *
            penalty);

    if (partial_ratio > 0.995) {
      return kMaxScore;
    }
  }
  return partial_ratio;
}

double FuzzyTokenizedStringMatch::WeightedRatio(const TokenizedString& query,
                                                const TokenizedString& text) {
  // All token based comparisons are scaled by 0.95 (on top of any partial
  // scalars), as per original implementation:
  // https://github.com/seatgeek/fuzzywuzzy/blob/af443f918eebbccff840b86fa606ac150563f466/fuzzywuzzy/fuzz.py#L245
  const double unbase_scale = 0.95;

  // Since query.text() and text.text() is not normalized, we use query.tokens()
  // and text.tokens() instead.
  const std::u16string query_normalized(base::JoinString(query.tokens(), u" "));
  const std::u16string text_normalized(base::JoinString(text.tokens(), u" "));

  std::vector<double> weighted_ratios;
  weighted_ratios.emplace_back(
      SequenceMatcher(query_normalized, text_normalized)
          .Ratio(/*text_length_agnostic=*/true));

  const double length_ratio =
      static_cast<double>(
          std::max(query_normalized.size(), text_normalized.size())) /
      std::min(query_normalized.size(), text_normalized.size());

  // Use partial if two strings are quite different in sizes.
  const bool use_partial = length_ratio >= 1.5;
  double length_ratio_scale = 1;

  if (use_partial) {
    // TODO(crbug.com/1336160): Consider scaling |partial_scale| smoothly with
    // |length_ratio|, instead of using a step function.
    //
    // If one string is much much shorter than the other, set |partial_scale| to
    // be 0.6, otherwise set it to be 0.9.
    length_ratio_scale = length_ratio > 8 ? 0.6 : 0.9;
    weighted_ratios.emplace_back(
        PartialRatio(query_normalized, text_normalized) * length_ratio_scale);
  }
  weighted_ratios.emplace_back(TokenSortRatio(query, text, use_partial) *
                               unbase_scale * length_ratio_scale);

  // Do not use partial match for token set because the match between the
  // intersection string and query/text rewrites will always return an extremely
  // high value.
  weighted_ratios.emplace_back(TokenSetRatio(query, text, false /*partial*/) *
                               unbase_scale * length_ratio_scale);

  // Return the maximum of all included weighted ratios
  return *std::max_element(weighted_ratios.begin(), weighted_ratios.end());
}

double FuzzyTokenizedStringMatch::PrefixMatcher(const TokenizedString& query,
                                                const TokenizedString& text) {
  string_matching::PrefixMatcher match(query, text);
  match.Match();
  return ScaledRelevance(match.relevance());
}

double FuzzyTokenizedStringMatch::AcronymMatcher(const TokenizedString& query,
                                                 const TokenizedString& text) {
  string_matching::AcronymMatcher match(query, text);
  const double relevance = match.CalculateRelevance();
  return ScaledRelevance(relevance);
}

double FuzzyTokenizedStringMatch::PrefixMatcher(
    const TokenizedString& query,
    const TokenizedString& text,
    std::vector<Hits>& hits_vector) {
  string_matching::PrefixMatcher match(query, text);
  match.Match();

  hits_vector.emplace_back(match.hits());
  return ScaledRelevance(match.relevance());
}

double FuzzyTokenizedStringMatch::AcronymMatcher(
    const TokenizedString& query,
    const TokenizedString& text,
    std::vector<Hits>& hits_vector) {
  string_matching::AcronymMatcher match(query, text);
  const double relevance = match.CalculateRelevance();

  hits_vector.emplace_back(match.hits());
  return ScaledRelevance(relevance);
}

double FuzzyTokenizedStringMatch::Relevance(const TokenizedString& query_input,
                                            const TokenizedString& text_input,
                                            bool use_weighted_ratio,
                                            bool strip_diacritics,
                                            bool use_acronym_matcher) {
  // If the query is much longer than the text then it's often not a match.
  if (query_input.text().size() >= text_input.text().size() * 2) {
    return 0.0;
  }

  std::optional<TokenizedString> stripped_query;
  std::optional<TokenizedString> stripped_text;
  if (strip_diacritics) {
    stripped_query.emplace(RemoveDiacritics(query_input.text()));
    stripped_text.emplace(RemoveDiacritics(text_input.text()));
  }

  const TokenizedString& query =
      strip_diacritics ? stripped_query.value() : query_input;
  const TokenizedString& text =
      strip_diacritics ? stripped_text.value() : text_input;

  // If there is an exact match, relevance will be 1.0 and there is only 1
  // hit that is the entire text/query.
  const auto& query_text = query.text();
  const auto& text_text = text.text();
  const auto query_size = query_text.size();
  const auto text_size = text_text.size();
  if (query_size > 0 && query_size == text_size &&
      base::EqualsCaseInsensitiveASCII(query_text, text_text)) {
    hits_.emplace_back(0, query_size);
    return 1.0;
  }

  // The |relevances| stores the |relevance_scores| calculated from different
  // string matching methods. The highest result among them will be returned.
  std::vector<double> relevances;
  // The |hits_vector| stores the |hits| calculated from different string
  // matching methods. The final selected instance corresponds to the hits
  // generated by the matching algorithm which yielded the highest relevance
  // score. The final selected instance will be assigned to |hits_| then.
  std::vector<Hits> hits_vector;

  double prefix_score = PrefixMatcher(query, text, hits_vector);
  // A scoring boost for short prefix matching queries.
  if (query_size <= kMaxBoostSize && prefix_score > kMinScore) {
    prefix_score = std::min(
        1.0, prefix_score + 2.0 / (query_size * (query_size + text_size)));
  }
  relevances.emplace_back(prefix_score);

  // Find hits using SequenceMatcher on original query and text.
  Hits sequence_hits;
  size_t match_size = 0;
  for (const auto& match :
       SequenceMatcher(query_text, text_text).GetMatchingBlocks()) {
    if (match.length > 0) {
      match_size += match.length;
      sequence_hits.emplace_back(match.pos_second_string,
                                 match.pos_second_string + match.length);
    }
  }
  hits_vector.emplace_back(sequence_hits);

  relevances.emplace_back(use_weighted_ratio
                              ? WeightedRatio(query, text)
                              : SequenceMatcher(base::i18n::ToLower(query_text),
                                                base::i18n::ToLower(text_text))
                                    .Ratio(/*text_length_agnostic=*/true));
  if (use_acronym_matcher) {
    relevances.emplace_back(AcronymMatcher(query, text, hits_vector));
  }

  size_t best_match_pos =
      std::max_element(relevances.begin(), relevances.end()) -
      relevances.begin();
  hits_ = hits_vector[best_match_pos];
  return match_size == text_size
             ? relevances[best_match_pos]
             : relevances[best_match_pos] * kNonExactMatchScaleRatio;
}

}  // namespace ash::string_matching