chromium/tools/clang/rewrite_raw_ptr_fields/RewriteRawPtrFields.cpp

// Copyright 2020 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

// This is the implementation of a clang tool that rewrites raw pointer fields
// into raw_ptr<T>:
//     Pointee* field_
// becomes:
//     raw_ptr<Pointee> field_
//
// Note that the tool always emits two kinds of output:
// 1. Fields to exclude:
//    - FilteredExprWriter
// 2. Edit/replacement directives:
//    - FieldDeclRewriter
//    - AffectedExprRewriter
// The rewriter is expected to be used twice, in two passes:
// 1. Output from the 1st pass should be used to generate fields-to-ignore.txt
//    (or to augment the manually created exclusion list file)
// 2. The 2nd pass should use fields-to-ignore.txt from the first pass as input
//    for the --exclude-fields cmdline parameter.  The output from the 2nd pass
//    can be used to perform the actual rewrite via extract_edits.py and
//    apply_edits.py.
//
// For more details, see the doc here:
// https://docs.google.com/document/d/1chTvr3fSofQNV_PDPEHRyUgcJCQBgTDOOBriW9gIm9M

#include <assert.h>

#include <algorithm>
#include <filesystem>
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <regex>
#include <string>
#include <utility>
#include <vector>

#include "RawPtrHelpers.h"
#include "RawPtrManualPathsToIgnore.h"
#include "SeparateRepositoryPaths.h"
#include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/ASTMatchers/ASTMatchersMacros.h"
#include "clang/Basic/CharInfo.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Lex/Lexer.h"
#include "clang/Lex/MacroArgs.h"
#include "clang/Lex/PPCallbacks.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Refactoring.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/TargetSelect.h"

using namespace clang::ast_matchers;

namespace {

// Include path that needs to be added to all the files where raw_ptr<...>
// replaces a raw pointer.
const char kRawPtrIncludePath[] = "base/memory/raw_ptr.h";

// Include path that needs to be added to all the files where raw_ref<...>
// replaces a raw reference.
const char kRawRefIncludePath[] = "base/memory/raw_ref.h";

const char kRawSpanIncludePath[] = "base/memory/raw_span.h";

// Name of a cmdline parameter that can be used to specify a file listing fields
// that should not be rewritten to use raw_ptr<T>.
//
// See also:
// - OutputSectionHelper
// - raw_ptr_plugin::FilterFile
const char kExcludeFieldsParamName[] = "exclude-fields";

// Name of a cmdline parameter that can be used to specify a file listing
// regular expressions describing paths that should be excluded from the
// rewrite.
//
// See also:
// - PathFilterFile
const char kOverrideExcludePathsParamName[] = "override-exclude-paths";

// OutputSectionHelper helps gather and emit a section of output.
//
// The section of output is delimited in a way that makes it easy to extract it
// with sed like so:
//    $ DELIM = ...
//    $ cat ~/scratch/rewriter.out \
//        | sed '/^==== BEGIN $DELIM ====$/,/^==== END $DELIM ====$/{//!b};d' \
//        | sort | uniq > ~/scratch/some-out-of-band-output.txt
//    (For DELIM="EDITS", there is also tools/clang/scripts/extract_edits.py.)
//
// Each output line is deduped and may be followed by optional comment tags:
//        Some filter # tag1, tag2
//        Another filter # tag1, tag2, tag3
//        An output line with no comment tags
//
// The output lines are sorted.  This helps provide deterministic output (even
// if AST matchers start firing in a different order after benign clang
// changes).
//
// See also:
// - raw_ptr_plugin::FilterFile
// - OutputHelper
class OutputSectionHelper {
 public:
  explicit OutputSectionHelper(llvm::StringRef output_delimiter)
      : output_delimiter_(output_delimiter.str()) {}

  OutputSectionHelper(const OutputSectionHelper&) = delete;
  OutputSectionHelper& operator=(const OutputSectionHelper&) = delete;

  void Add(llvm::StringRef output_line,
           llvm::StringRef tag = "",
           llvm::StringRef loc = "") {
    // Look up |tags| associated with |output_line|.  As a side effect of the
    // lookup, |output_line| will be inserted if it wasn't already present in
    // the map.
    llvm::StringSet<>& tags = output_line_to_tags_[output_line];

    if (!tag.empty()) {
      tags.insert(tag);
    }

    // Do the same for source locations.
    llvm::StringSet<>& locs = output_line_to_locs_[output_line];
    if (!loc.empty()) {
      locs.insert(loc);
    }
  }

  void Emit() {
    if (output_line_to_tags_.empty())
      return;

    llvm::outs() << "==== BEGIN " << output_delimiter_ << " ====\n";
    for (const llvm::StringRef& output_line :
         GetSortedKeys(output_line_to_tags_)) {
      llvm::outs() << output_line;

      const llvm::StringSet<>& locs = output_line_to_locs_[output_line];
      if (!locs.empty()) {
        std::vector<llvm::StringRef> sorted_locs = GetSortedKeys(locs);
        std::string locs_comment =
            llvm::join(sorted_locs.begin(), sorted_locs.end(), ", ");
        llvm::outs() << " @ " << locs_comment;
      }

      const llvm::StringSet<>& tags = output_line_to_tags_[output_line];
      if (!tags.empty()) {
        std::vector<llvm::StringRef> sorted_tags = GetSortedKeys(tags);
        std::string tags_comment =
            llvm::join(sorted_tags.begin(), sorted_tags.end(), ", ");
        llvm::outs() << "  # " << tags_comment;
      }

      llvm::outs() << "\n";
    }
    llvm::outs() << "==== END " << output_delimiter_ << " ====\n";
  }

 private:
  template <typename TValue>
  static std::vector<llvm::StringRef> GetSortedKeys(
      const llvm::StringMap<TValue>& map) {
    std::vector<llvm::StringRef> sorted(map.keys().begin(), map.keys().end());
    std::sort(sorted.begin(), sorted.end());
    return sorted;
  }

  std::string output_delimiter_;
  llvm::StringMap<llvm::StringSet<>> output_line_to_tags_;
  llvm::StringMap<llvm::StringSet<>> output_line_to_locs_;
};

// Output format is documented in //docs/clang_tool_refactoring.md
class OutputHelper : public clang::tooling::SourceFileCallbacks {
 public:
  OutputHelper()
      : edits_helper_("EDITS"), field_decl_filter_helper_("FIELD FILTERS") {}
  ~OutputHelper() = default;

  OutputHelper(const OutputHelper&) = delete;
  OutputHelper& operator=(const OutputHelper&) = delete;

  void AddReplacement(const clang::SourceManager& source_manager,
                      const clang::SourceRange& replacement_range,
                      std::string replacement_text,
                      const char* include_path = nullptr) {
    clang::tooling::Replacement replacement(
        source_manager, clang::CharSourceRange::getCharRange(replacement_range),
        replacement_text);
    std::string file_path =
        std::filesystem::proximate(replacement.getFilePath().str());
    if (file_path.empty())
      return;

    std::replace(replacement_text.begin(), replacement_text.end(), '\n', '\0');
    std::string replacement_directive = llvm::formatv(
        "r:::{0}:::{1}:::{2}:::{3}", file_path, replacement.getOffset(),
        replacement.getLength(), replacement_text);
    edits_helper_.Add(replacement_directive);

    if (include_path) {
      std::string include_directive = llvm::formatv(
          "include-user-header:::{0}:::-1:::-1:::{1}", file_path, include_path);
      edits_helper_.Add(include_directive);
    }
  }

  void AddFilteredField(const clang::SourceManager& source_manager,
                        const clang::FieldDecl& field_decl,
                        llvm::StringRef filter_tag) {
    std::string qualified_name = field_decl.getQualifiedNameAsString();

    clang::SourceLocation loc = field_decl.getBeginLoc();
    // Calculate a relative path to the file not to make the output
    // environment-specific.
    std::string loc_str =
        std::filesystem::proximate(source_manager.getFilename(loc).str());
    if (!loc_str.empty()) {
      loc_str +=
          ":" + std::to_string(source_manager.getSpellingLineNumber(loc));
      loc_str +=
          ":" + std::to_string(source_manager.getSpellingColumnNumber(loc));
    }
    field_decl_filter_helper_.Add(qualified_name, filter_tag, loc_str);
  }

 private:
  // clang::tooling::SourceFileCallbacks override:
  bool handleBeginSource(clang::CompilerInstance& compiler) override {
    const clang::FrontendOptions& frontend_options = compiler.getFrontendOpts();

    assert((frontend_options.Inputs.size() == 1) &&
           "run_tool.py should invoke the rewriter one file at a time");
    const clang::FrontendInputFile& input_file = frontend_options.Inputs[0];
    assert(input_file.isFile() &&
           "run_tool.py should invoke the rewriter on actual files");

    current_language_ = input_file.getKind().getLanguage();

    return true;  // Report that |handleBeginSource| succeeded.
  }

  // clang::tooling::SourceFileCallbacks override:
  void handleEndSource() override {
    if (ShouldSuppressOutput())
      return;

    edits_helper_.Emit();
    field_decl_filter_helper_.Emit();
  }

  bool ShouldSuppressOutput() {
    switch (current_language_) {
      case clang::Language::Unknown:
      case clang::Language::Asm:
      case clang::Language::LLVM_IR:
      case clang::Language::OpenCL:
      case clang::Language::CUDA:
      case clang::Language::RenderScript:
      case clang::Language::HIP:
      case clang::Language::HLSL:
        // Rewriter can't handle rewriting the current input language.
        return true;

      case clang::Language::C:
      case clang::Language::ObjC:
        // raw_ptr<T> requires C++.  In particular, attempting to #include
        // "base/memory/raw_ptr.h" from C-only compilation units will lead
        // to compilation errors.
        return true;

      case clang::Language::CXX:
      case clang::Language::OpenCLCXX:
      case clang::Language::ObjCXX:
        return false;
    }

    assert(false && "Unrecognized clang::Language");
    return true;
  }

  OutputSectionHelper edits_helper_;
  OutputSectionHelper field_decl_filter_helper_;
  clang::Language current_language_ = clang::Language::Unknown;
};

// Matches CXXRecordDecls that are classified as trivial:
// https://en.cppreference.com/w/cpp/named_req/TrivialType
AST_MATCHER(clang::CXXRecordDecl, isTrivial) {
  return Node.isTrivial();
}

// Returns |true| if and only if:
// 1. |a| and |b| are in the same file (e.g. |false| is returned if any location
//    is within macro scratch space or a similar location;  similarly |false| is
//    returned if |a| and |b| are in different files).
// 2. |a| and |b| overlap.
bool IsOverlapping(const clang::SourceManager& source_manager,
                   const clang::SourceRange& a,
                   const clang::SourceRange& b) {
  clang::FullSourceLoc a1(a.getBegin(), source_manager);
  clang::FullSourceLoc a2(a.getEnd(), source_manager);
  clang::FullSourceLoc b1(b.getBegin(), source_manager);
  clang::FullSourceLoc b2(b.getEnd(), source_manager);

  // Are all locations in a file?
  if (!a1.isFileID() || !a2.isFileID() || !b1.isFileID() || !b2.isFileID())
    return false;

  // Are all locations in the same file?
  if (a1.getFileID() != a2.getFileID() || a2.getFileID() != b1.getFileID() ||
      b1.getFileID() != b2.getFileID()) {
    return false;
  }

  // Check the 2 cases below:
  // 1. A: |============|
  //    B:      |===============|
  //       a1   b1      a2      b2
  // or
  // 2. A: |====================|
  //    B:      |=======|
  //       a1   b1      b2      a2
  bool b1_is_inside_a_range = a1.getFileOffset() <= b1.getFileOffset() &&
                              b1.getFileOffset() <= a2.getFileOffset();

  // Check the 2 cases below:
  // 1. B: |============|
  //    A:      |===============|
  //       b1   a1      b2      a2
  // or
  // 2. B: |====================|
  //    A:      |=======|
  //       b1   a1      a2      b2
  bool a1_is_inside_b_range = b1.getFileOffset() <= a1.getFileOffset() &&
                              a1.getFileOffset() <= b2.getFileOffset();

  return b1_is_inside_a_range || a1_is_inside_b_range;
}

// Matcher for FieldDecl that has a SourceRange that overlaps other declarations
// within the parent RecordDecl.
//
// Given
//   struct MyStruct {
//     int f;
//     int f2, f3;
//     struct S { int x } f4;
//   };
// - doesn't match |f|
// - matches |f2| and |f3| (which overlap each other's location)
// - matches |f4| (which overlaps the location of |S|)
AST_MATCHER(clang::FieldDecl, overlapsOtherDeclsWithinRecordDecl) {
  const clang::FieldDecl& self = Node;
  const clang::SourceManager& source_manager =
      Finder->getASTContext().getSourceManager();

  const clang::RecordDecl* record_decl = self.getParent();
  if (!record_decl)
    return false;

  clang::SourceRange self_range(self.getBeginLoc(), self.getEndLoc());

  auto is_overlapping_sibling = [&](const clang::Decl* other_decl) {
    if (other_decl == &self)
      return false;

    clang::SourceRange other_range(other_decl->getBeginLoc(),
                                   other_decl->getEndLoc());
    return IsOverlapping(source_manager, self_range, other_range);
  };
  bool has_sibling_with_overlapping_location =
      std::any_of(record_decl->decls_begin(), record_decl->decls_end(),
                  is_overlapping_sibling);
  return has_sibling_with_overlapping_location;
}

// Matches clang::Type if
// 1) it represents a RecordDecl with a FieldDecl that matches the InnerMatcher
//    (*all* such FieldDecls will be matched)
// or
// 2) it represents an array or a RecordDecl that nests the case #1
//    (this recurses to any depth).
AST_MATCHER_P(clang::QualType,
              typeWithEmbeddedFieldDecl,
              clang::ast_matchers::internal::Matcher<clang::FieldDecl>,
              InnerMatcher) {
  const clang::Type* type =
      Node.getDesugaredType(Finder->getASTContext()).getTypePtrOrNull();
  if (!type)
    return false;

  if (const clang::CXXRecordDecl* record_decl = type->getAsCXXRecordDecl()) {
    auto matcher =
        recordDecl(forEach(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(anyOf(
            InnerMatcher, hasType(typeWithEmbeddedFieldDecl(InnerMatcher)))))));
    return matcher.matches(*record_decl, Finder, Builder);
  }

  if (type->isArrayType()) {
    const clang::ArrayType* array_type =
        Finder->getASTContext().getAsArrayType(Node);
    auto matcher = typeWithEmbeddedFieldDecl(InnerMatcher);
    return matcher.matches(array_type->getElementType(), Finder, Builder);
  }

  return false;
}

class FieldDeclRewriter : public MatchFinder::MatchCallback {
 public:
  explicit FieldDeclRewriter(OutputHelper* output_helper,
                             const char* format_string,
                             const char* include_path)
      : output_helper_(output_helper),
        format_string_(format_string),
        include_path_(include_path) {}

  FieldDeclRewriter(const FieldDeclRewriter&) = delete;
  FieldDeclRewriter& operator=(const FieldDeclRewriter&) = delete;

  virtual bool earlyExit(const MatchFinder::MatchResult& result) const = 0;

  void run(const MatchFinder::MatchResult& result) override {
    if (earlyExit(result)) {
      return;
    }
    const clang::ASTContext& ast_context = *result.Context;
    const clang::SourceManager& source_manager = *result.SourceManager;

    const clang::FieldDecl* field_decl =
        result.Nodes.getNodeAs<clang::FieldDecl>("affectedFieldDecl");

    assert(field_decl && "matcher should bind 'fieldDecl'");

    const clang::TypeSourceInfo* type_source_info =
        field_decl->getTypeSourceInfo();
    if (auto* ivar_decl = clang::dyn_cast<clang::ObjCIvarDecl>(field_decl)) {
      // Objective-C @synthesize statements should not be rewritten. They
      // return null for getTypeSourceInfo().
      if (ivar_decl->getSynthesize()) {
        assert(!type_source_info);
        return;
      }
    }
    assert(type_source_info && "assuming |type_source_info| is always present");

    clang::QualType pointer_type = type_source_info->getType();

    // Calculate the |replacement_range|.
    //
    // Consider the following example:
    //      const Pointee* const field_name_;
    //      ^--------------------^  = |replacement_range|
    //                           ^  = |field_decl->getLocation()|
    //      ^                       = |field_decl->getBeginLoc()|
    //                   ^          = PointerTypeLoc::getStarLoc
    //            ^------^          = TypeLoc::getSourceRange
    //
    // We get the |replacement_range| in a bit clumsy way, because clang docs
    // for QualifiedTypeLoc explicitly say that these objects "intentionally
    // do not provide source location for type qualifiers".
    clang::SourceRange replacement_range(field_decl->getBeginLoc(),
                                         field_decl->getLocation());

    // Calculate |replacement_text|.
    std::string replacement_text = GenerateNewText(ast_context, pointer_type);
    if (field_decl->isMutable())
      replacement_text.insert(0, "mutable ");

    // Generate and print a replacement.
    output_helper_->AddReplacement(source_manager, replacement_range,
                                   replacement_text, include_path_);
  }

 private:
  std::string GenerateNewText(const clang::ASTContext& ast_context,
                              const clang::QualType& pointer_type) {
    std::string result;

    clang::QualType pointee_type = pointer_type->getPointeeType();

    // Preserve qualifiers.
    assert(
        !pointer_type.isRestrictQualified() &&
        "|restrict| is a C-only qualifier and raw_ptr<T>/raw_ref<T> need C++");
    if (pointer_type.isConstQualified())
      result += "const ";
    if (pointer_type.isVolatileQualified())
      result += "volatile ";

    // Convert pointee type to string.
    clang::PrintingPolicy printing_policy(ast_context.getLangOpts());
    printing_policy.SuppressScope = 1;  // s/blink::Pointee/Pointee/
    std::string pointee_type_as_string =
        pointee_type.getAsString(printing_policy);
    result += llvm::formatv(format_string_, pointee_type_as_string);

    return result;
  }

  OutputHelper* const output_helper_;
  const char* format_string_;
  const char* include_path_;
};

class AffectedExprRewriter : public MatchFinder::MatchCallback {
 public:
  explicit AffectedExprRewriter(
      OutputHelper* output_helper,
      std::function<std::pair<clang::SourceRange, std::string>(
          const MatchFinder::MatchResult&)> fct)
      : output_helper_(output_helper), getRangeAndText_(fct) {}

  AffectedExprRewriter(const AffectedExprRewriter&) = delete;
  AffectedExprRewriter& operator=(const AffectedExprRewriter&) = delete;

  void run(const MatchFinder::MatchResult& result) override {
    const clang::SourceManager& source_manager = *result.SourceManager;

    auto [replacement_range, text] = getRangeAndText_(result);
    output_helper_->AddReplacement(source_manager, replacement_range,
                                   text.c_str());
  }

 private:
  OutputHelper* const output_helper_;
  std::function<std::pair<clang::SourceRange, std::string>(
      const MatchFinder::MatchResult&)>
      getRangeAndText_;
};

// Emits problematic fields (matched as "affectedFieldDecl") as filtered fields.
class FilteredExprWriter : public MatchFinder::MatchCallback {
 public:
  FilteredExprWriter(OutputHelper* output_helper, llvm::StringRef filter_tag)
      : output_helper_(output_helper), filter_tag_(filter_tag) {}

  FilteredExprWriter(const FilteredExprWriter&) = delete;
  FilteredExprWriter& operator=(const FilteredExprWriter&) = delete;

  void run(const MatchFinder::MatchResult& result) override {
    const clang::FieldDecl* field_decl =
        result.Nodes.getNodeAs<clang::FieldDecl>("affectedFieldDecl");
    assert(field_decl && "matcher should bind 'affectedFieldDecl'");

    output_helper_->AddFilteredField(*result.SourceManager, *field_decl,
                                     filter_tag_);
  }

 private:
  OutputHelper* const output_helper_;
  llvm::StringRef filter_tag_;
};

class RawPtrRewriter {
 public:
  RawPtrRewriter(
      OutputHelper* output_helper,
      MatchFinder& finder,
      const raw_ptr_plugin::RawPtrAndRefExclusionsOptions& exclusion_options)
      : match_finder(finder),
        field_decl_rewriter(output_helper, "raw_ptr<{0}> ", kRawPtrIncludePath),
        affected_expr_rewriter(output_helper, getRangeAndText_),
        filtered_addr_of_expr_writer(output_helper, "addr-of"),
        filtered_in_out_ref_arg_writer(output_helper, "in-out-param-ref"),
        overlapping_field_decl_writer(output_helper, "overlapping"),
        macro_field_decl_writer(output_helper, "macro"),
        global_scope_rewriter(output_helper, "global-scope"),
        union_field_decl_writer(output_helper, "union"),
        reinterpret_cast_struct_writer(output_helper,
                                       "reinterpret-cast-trivial-type"),
        exclusion_options_(exclusion_options) {}

  void addMatchers() {
    auto field_decl_matcher = AffectedRawPtrFieldDecl(exclusion_options_);

    match_finder.addMatcher(field_decl_matcher, &field_decl_rewriter);

    // Matches expressions that used to return a value of type |SomeClass*|
    // but after the rewrite return an instance of |raw_ptr<SomeClass>|.
    // Many such expressions might need additional changes after the rewrite:
    // - Some expressions (printf args, const_cast args, etc.) might need
    // |.get()|
    //   appended.
    // - Using such expressions in specific contexts (e.g. as in-out arguments
    // or
    //   as a return value of a function returning references) may require
    //   additional work and should cause related fields to be emitted as
    //   candidates for the --field-filter-file parameter.
    auto affected_member_expr_matcher =
        memberExpr(member(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                       field_decl_matcher))))
            .bind("affectedMemberExpr");
    auto affected_expr_matcher = ignoringImplicit(affected_member_expr_matcher);

    // Places where |.get()| needs to be appended =========
    // Given
    //   void foo(const S& s) {
    //     printf("%p", s.y);
    //     const_cast<...>(s.y)
    //     reinterpret_cast<...>(s.y)
    //   }
    // matches the |s.y| expr if it matches the |affected_expr_matcher| above.
    //
    // See also testcases in tests/affected-expr-original.cc
    auto affected_expr_that_needs_fixing_matcher = expr(allOf(
        affected_expr_matcher,
        hasParent(expr(anyOf(callExpr(callee(functionDecl(isVariadic()))),
                             cxxConstCastExpr(), cxxReinterpretCastExpr())))));

    match_finder.addMatcher(affected_expr_that_needs_fixing_matcher,
                            &affected_expr_rewriter);

    // Affected ternary operator args =========
    // Given
    //   void foo(const S& s) {
    //     cond ? s.y : ...
    //   }
    // binds the |s.y| expr if it matches the |affected_expr_matcher| above.
    //
    // See also testcases in tests/affected-expr-original.cc
    auto affected_ternary_operator_arg_matcher =
        conditionalOperator(eachOf(hasTrueExpression(affected_expr_matcher),
                                   hasFalseExpression(affected_expr_matcher)));
    match_finder.addMatcher(affected_ternary_operator_arg_matcher,
                            &affected_expr_rewriter);

    // Affected string binary operator =========
    // Given
    //   struct S { const char* y; }
    //   void foo(const S& s) {
    //     std::string other;
    //     bool v1 = s.y == other;
    //     std::string v2 = s.y + other;
    //   }
    // binds the |s.y| expr if it matches the |affected_expr_matcher| above.
    //
    // See also testcases in tests/affected-expr-original.cc
    auto std_string_expr_matcher =
        expr(hasType(cxxRecordDecl(hasName("::std::basic_string"))));
    auto affected_string_binary_operator_arg_matcher = cxxOperatorCallExpr(
        hasAnyOverloadedOperatorName("+", "==", "!=", "<", "<=", ">", ">="),
        hasAnyArgument(std_string_expr_matcher),
        forEachArgumentWithParam(affected_expr_matcher, parmVarDecl()));
    match_finder.addMatcher(affected_string_binary_operator_arg_matcher,
                            &affected_expr_rewriter);

    // Calls to templated functions =========
    // Given
    //   struct S { int* y; };
    //   template <typename T>
    //   void templatedFunc(T* arg) {}
    //   void foo(const S& s) {
    //     templatedFunc(s.y);
    //   }
    // binds the |s.y| expr if it matches the |affected_expr_matcher| above.
    //
    // See also testcases in tests/affected-expr-original.cc
    auto templated_function_arg_matcher = forEachArgumentWithParam(
        affected_expr_matcher,
        parmVarDecl(allOf(
            hasType(
                qualType(allOf(findAll(qualType(substTemplateTypeParmType())),
                               unless(referenceType())))),
            unless(hasAncestor(functionDecl(hasName("Unretained")))))));
    match_finder.addMatcher(callExpr(templated_function_arg_matcher),
                            &affected_expr_rewriter);
    // TODO(lukasza): It is unclear why |traverse| below is needed.  Maybe it
    // can be removed if https://bugs.llvm.org/show_bug.cgi?id=46287 is fixed.
    match_finder.addMatcher(
        traverse(clang::TraversalKind::TK_AsIs,
                 cxxConstructExpr(templated_function_arg_matcher)),
        &affected_expr_rewriter);

    // Calls to constructors via an implicit cast =========
    // Given
    //   struct I { I(int*) {} };
    //   void bar(I i) {}
    //   struct S { int* y; };
    //   void foo(const S& s) {
    //     bar(s.y);  // implicit cast from |s.y| to I.
    //   }
    // binds the |s.y| expr if it matches the |affected_expr_matcher| above.
    //
    // See also testcases in tests/affected-expr-original.cc
    auto implicit_ctor_expr_matcher = cxxConstructExpr(
        allOf(anyOf(hasParent(materializeTemporaryExpr()),
                    hasParent(implicitCastExpr())),
              hasDeclaration(cxxConstructorDecl(
                  allOf(parameterCountIs(1), unless(isExplicit())))),
              forEachArgumentWithParam(affected_expr_matcher, parmVarDecl())));
    match_finder.addMatcher(implicit_ctor_expr_matcher,
                            &affected_expr_rewriter);

    // |auto| type declarations =========
    // Given
    //   struct S { int* y; };
    //   void foo(const S& s) {
    //     auto* p = s.y;
    //   }
    // binds the |s.y| expr if it matches the |affected_expr_matcher| above.
    //
    // See also testcases in tests/affected-expr-original.cc
    auto auto_var_decl_matcher = declStmt(forEach(
        varDecl(allOf(hasType(pointerType(pointee(autoType()))),
                      hasInitializer(anyOf(
                          affected_expr_matcher,
                          initListExpr(hasInit(0, affected_expr_matcher))))))));
    match_finder.addMatcher(auto_var_decl_matcher, &affected_expr_rewriter);

    // address-of(affected-expr) =========
    // Given
    //   ... &s.y ...
    // matches the |s.y| expr if it matches the |affected_member_expr_matcher|
    // above.
    //
    // See also the testcases in tests/gen-in-out-arg-test.cc.
    auto affected_addr_of_expr_matcher = expr(allOf(
        affected_expr_matcher, hasParent(unaryOperator(hasOperatorName("&")))));

    match_finder.addMatcher(affected_addr_of_expr_matcher,
                            &filtered_addr_of_expr_writer);

    // in-out reference arg =========
    // Given
    //   struct S { SomeClass* ptr_field; };
    //   void f(SomeClass*& in_out_arg) { ... }
    //   template <typename T> void f2(T&& rvalue_ref_arg) { ... }
    //   template <typename... Ts> void f3(Ts&&... rvalue_ref_args) { ... }
    //   void bar() {
    //     S s;
    //     foo(s.ptr_field)
    //   }
    // matches the |s.ptr_field| expr if it matches the
    // |affected_member_expr_matcher| and is passed as a function argument that
    // has |FooBar*&| type (like |f|, but unlike |f2| and |f3|).
    //
    // See also the testcases in tests/gen-in-out-arg-test.cc.
    auto affected_in_out_ref_arg_matcher = callExpr(forEachArgumentWithParam(
        affected_expr_matcher,
        raw_ptr_plugin::hasExplicitParmVarDecl(
            hasType(qualType(allOf(referenceType(pointee(pointerType())),
                                   unless(rValueReferenceType())))))));

    match_finder.addMatcher(affected_in_out_ref_arg_matcher,
                            &filtered_in_out_ref_arg_writer);

    // See the doc comment for the overlapsOtherDeclsWithinRecordDecl matcher
    // and the testcases in tests/gen-overlapping-test.cc.
    auto overlapping_field_decl_matcher = fieldDecl(
        allOf(field_decl_matcher, overlapsOtherDeclsWithinRecordDecl()));

    match_finder.addMatcher(overlapping_field_decl_matcher,
                            &overlapping_field_decl_writer);

    // See the doc comment for the isInMacroLocation matcher
    // and the testcases in tests/gen-macros-test.cc.
    auto macro_field_decl_matcher = fieldDecl(
        allOf(field_decl_matcher, raw_ptr_plugin::isInMacroLocation()));

    match_finder.addMatcher(macro_field_decl_matcher, &macro_field_decl_writer);

    // See the testcases in tests/gen-global-scope-test.cc.
    auto global_scope_matcher =
        varDecl(allOf(hasGlobalStorage(),
                      hasType(typeWithEmbeddedFieldDecl(field_decl_matcher))));

    match_finder.addMatcher(global_scope_matcher, &global_scope_rewriter);

    // This is used to exclude unions from certain files that are known to have
    // safe usage of union (i.e. doesn't cause ref count mismatch), such as
    // std::optional and absl::variant.
    files_with_audited_unions =
        std::make_unique<raw_ptr_plugin::FilterFile>(std::vector<std::string>{
            "third_party/libc++/src/include/optional",
            "third_party/abseil-cpp/absl/types/internal/variant.h",
        });
    // Matches fields in unions (both directly rewritable fields as well as
    // union fields that embed a struct that contains a rewritable field).  See
    // also the testcases in tests/gen-unions-test.cc.
    auto union_field_decl_matcher = recordDecl(allOf(
        isUnion(),
        unless(isInLocationListedInFilterFile(files_with_audited_unions.get())),
        forEach(fieldDecl(
            anyOf(field_decl_matcher,
                  hasType(typeWithEmbeddedFieldDecl(field_decl_matcher)))))));

    match_finder.addMatcher(union_field_decl_matcher, &union_field_decl_writer);

    // Matches rewritable fields of struct `SomeStruct` if that struct happens
    // to be a destination type of a `reinterpret_cast<SomeStruct*>` cast and is
    // a trivial type (otherwise `reinterpret_cast<SomeStruct*>` wouldn't be
    // valid before the rewrite if it skipped non-trivial constructors).
    auto reinterpret_cast_struct_matcher =
        cxxReinterpretCastExpr(hasDestinationType(pointerType(pointee(
            hasUnqualifiedDesugaredType(recordType(hasDeclaration(cxxRecordDecl(
                allOf(forEach(field_decl_matcher), isTrivial())))))))));

    match_finder.addMatcher(reinterpret_cast_struct_matcher,
                            &reinterpret_cast_struct_writer);
  }

 private:
  // Rewrites |SomeClass* field| (matched as "affectedFieldDecl") into
  // |raw_ptr<SomeClass> field| and for each file rewritten in such way adds an
  // |#include "base/memory/raw_ptr.h"|.
  class RawPtrFieldDeclRewriter : public FieldDeclRewriter {
   public:
    explicit RawPtrFieldDeclRewriter(OutputHelper* output_helper,
                                     const char* format_string,
                                     const char* include_path)
        : FieldDeclRewriter(output_helper, format_string, include_path) {}

    bool earlyExit(const MatchFinder::MatchResult& result) const override {
      return false;
    }
  };
  // Rewrites |my_struct.ptr_field| (matched as "affectedMemberExpr") into
  // |my_struct.ptr_field.get()|.
  std::function<std::pair<clang::SourceRange, std::string>(
      const MatchFinder::MatchResult&)>
      getRangeAndText_ = [](const MatchFinder::MatchResult& result)
      -> std::pair<clang::SourceRange, std::string> {
    const clang::MemberExpr* member_expr =
        result.Nodes.getNodeAs<clang::MemberExpr>("affectedMemberExpr");
    assert(member_expr && "matcher should bind 'affectedMemberExpr'");

    clang::SourceLocation member_name_start = member_expr->getMemberLoc();
    size_t member_name_length = member_expr->getMemberDecl()->getName().size();
    clang::SourceLocation insertion_loc =
        member_name_start.getLocWithOffset(member_name_length);

    clang::SourceRange replacement_range(insertion_loc, insertion_loc);
    return {replacement_range, ".get()"};
  };
  MatchFinder& match_finder;
  RawPtrFieldDeclRewriter field_decl_rewriter;
  AffectedExprRewriter affected_expr_rewriter;
  FilteredExprWriter filtered_addr_of_expr_writer;
  FilteredExprWriter filtered_in_out_ref_arg_writer;
  FilteredExprWriter overlapping_field_decl_writer;
  FilteredExprWriter macro_field_decl_writer;
  FilteredExprWriter global_scope_rewriter;
  FilteredExprWriter union_field_decl_writer;
  FilteredExprWriter reinterpret_cast_struct_writer;
  std::unique_ptr<raw_ptr_plugin::FilterFile> files_with_audited_unions;
  const raw_ptr_plugin::RawPtrAndRefExclusionsOptions exclusion_options_;
};

// Wires up everything needed to rewrite reference fields (|T& field|) into
// |const raw_ref<T> field|:
//  - a field declaration rewriter (emits the new type spelling),
//  - several expression rewriters that adapt uses of such fields
//    (|.| -> |->|, prepending |*|, wrapping in |(*...)|, wrapping braced
//    initializers in |raw_ref(...)|),
//  - filter writers that report fields under the "global-scope",
//    "overlapping" and "macro" tags.
// Nothing is registered with the MatchFinder until addMatchers() is called.
class RawRefRewriter {
 public:
  // |output_helper| is forwarded to every rewriter/writer owned by this
  // object. |finder| is stored by reference, so it must outlive this object.
  // |exclusion_options| is copied.
  RawRefRewriter(
      OutputHelper* output_helper,
      MatchFinder& finder,
      const raw_ptr_plugin::RawPtrAndRefExclusionsOptions& exclusion_options)
      : match_finder(finder),
        field_decl_rewriter(output_helper,
                            "const raw_ref<{0}> ",
                            kRawRefIncludePath),
        affected_expr_operator_rewriter(output_helper,
                                        affectedMemberExprOperatorFct_),
        affected_expr_rewriter(output_helper, affectedMemberExprFct_),
        affected_expr_rewriter_with_parentheses(
            output_helper,
            affectedMemberExprWithParenFct_),
        affected_initializer_expr_rewriter(output_helper,
                                           affectedInitializerExprFct_),
        global_scope_rewriter(output_helper, "global-scope"),
        overlapping_field_decl_writer(output_helper, "overlapping"),
        macro_field_decl_writer(output_helper, "macro"),
        exclusion_options_(exclusion_options) {}

  // Builds all matchers for the raw_ref rewrite and registers each of them,
  // together with its callback, on |match_finder|.
  void addMatchers() {
    auto field_decl_matcher = AffectedRawRefFieldDecl(exclusion_options_);

    match_finder.addMatcher(field_decl_matcher, &field_decl_rewriter);

    // Matches expressions of the form |someClass.ref_field.sub_member| which
    // should be rewritten as |someClass.ref_field->sub_member| as we can't
    // overload `operator.` in C++.
    auto affected_member_expr_operator_matcher =
        expr(anyOf(memberExpr(has(memberExpr(
                       member(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                           field_decl_matcher)))))),
                   memberExpr(has(implicitCastExpr(has(memberExpr(
                       member(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                           field_decl_matcher)))))))),
                   cxxDependentScopeMemberExpr(has(memberExpr(
                       member(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                           field_decl_matcher))))))))
            .bind("affectedMemberExprOperator");

    match_finder.addMatcher(affected_member_expr_operator_matcher,
                            &affected_expr_operator_rewriter);

    // Matches expressions that used to have |SomeType&| as return type and
    // became |const raw_ref<SomeType>| after the rewrite.
    //
    // The inner |unless| drops uses that are handled by one of the other
    // matchers registered in this function (member access through the field,
    // uses needing parentheses) as well as a few contexts that are left
    // untouched (e.g. non-|auto&| variable initializers, defaulted
    // constructors).
    auto affected_member_expr = memberExpr(
        memberExpr(
            member(fieldDecl(
                raw_ptr_plugin::hasExplicitFieldDecl(field_decl_matcher))),
            unless(
                anyOf(hasParent(memberExpr()),
                      hasParent(implicitCastExpr(hasParent(memberExpr()))),
                      hasParent(cxxDependentScopeMemberExpr()),
                      hasParent(varDecl(unless(anyOf(
                          hasType(referenceType(pointee(autoType()))),
                          hasParent(declStmt(hasParent(cxxForRangeStmt()))))))),
                      hasAncestor(cxxConstructorDecl(isDefaulted())),
                      hasParent(cxxOperatorCallExpr()),
                      hasParent(unaryOperator(
                          anyOf(hasOperatorName("--"), hasOperatorName("++")))),
                      hasParent(arraySubscriptExpr()),
                      hasParent(callExpr(
                          callee(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                              field_decl_matcher))))))))
            .bind("affectedMemberExpr"),

        unless(anyOf(
            // Exclude memberExpressions appearing inside a constructor
            // initializer of a reference field where we should NOT add
            // operator*.
            hasParent(cxxConstructorDecl(hasAnyConstructorInitializer(
                allOf(withInitializer(
                          memberExpr(equalsBoundNode("affectedMemberExpr"))),
                      forField(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                          field_decl_matcher))))))),
            // Exclude memberExpressions, in initializer lists, that are
            // initializing a reference field that will be rewritten into
            // raw_ref.
            hasParent(initListExpr(raw_ptr_plugin::forEachInitExprWithFieldDecl(
                memberExpr(equalsBoundNode("affectedMemberExpr")),
                raw_ptr_plugin::hasExplicitFieldDecl(field_decl_matcher)))))));

    match_finder.addMatcher(affected_member_expr, &affected_expr_rewriter);

    // Unfiltered form: any member expression naming a rewritten reference
    // field. It is reused by the two context-specific matchers below.
    auto affected_member_expr_matcher =
        memberExpr(member(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                       field_decl_matcher))))
            .bind("affectedMemberExpr");

    // Calls to constructors via an implicit cast =========
    // Given
    //   struct I { I(int&) {} };
    //   void bar(I i) {}
    //   struct S { int& y; };
    //   void foo(const S& s) {
    //     bar(s.y);  // implicit cast from |s.y| to I.
    //   }
    // binds the |s.y| expr if it matches the |affected_expr_matcher| above.
    //
    // See also testcases in tests/affected-expr-original.cc
    auto implicit_ctor_expr_matcher = cxxConstructExpr(allOf(
        anyOf(hasParent(materializeTemporaryExpr()),
              hasParent(implicitCastExpr())),
        hasDeclaration(cxxConstructorDecl(
            allOf(parameterCountIs(1), unless(isExplicit())))),
        forEachArgumentWithParam(affected_member_expr_matcher, parmVarDecl())));
    match_finder.addMatcher(implicit_ctor_expr_matcher,
                            &affected_expr_rewriter);

    // |auto| type declarations =========
    // Given
    //   struct S { int& y; };
    //   void foo(const S& s) {
    //     auto& p = s.y;
    //   }
    // binds the |s.y| expr if it matches the |affected_expr_matcher| above.
    //
    // See also testcases in tests/affected-expr-original.cc
    auto auto_var_decl_matcher = declStmt(forEach(varDecl(
        allOf(hasType(referenceType(pointee(autoType()))),
              hasInitializer(anyOf(
                  affected_member_expr_matcher,
                  initListExpr(hasInit(0, affected_member_expr_matcher))))))));
    match_finder.addMatcher(auto_var_decl_matcher, &affected_expr_rewriter);

    // Matches affected member expressions that need parenthesization.
    // The parent kinds listed here are the same ones excluded from
    // |affected_member_expr| above (operator call, ++/--, subscript, call
    // through the field).
    auto affected_member_expr_with_parentheses =
        memberExpr(member(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                       field_decl_matcher))),
                   anyOf(hasParent(cxxOperatorCallExpr()),
                         hasParent(unaryOperator(anyOf(hasOperatorName("--"),
                                                       hasOperatorName("++")))),
                         hasParent(arraySubscriptExpr()),
                         hasParent(callExpr(callee(
                             fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                                 field_decl_matcher)))))))
            .bind("affectedMemberExprWithParentheses");

    match_finder.addMatcher(affected_member_expr_with_parentheses,
                            &affected_expr_rewriter_with_parentheses);

    // for structs/class that don't define a constructor and are initialized
    // using braced list initialization, we need to add raw_ref around the
    // initializing expression since raw_ref's constructor is explicit.
    // Example:
    // struct A{ int& member; }; => struct A{ const raw_ref<int> member;};
    // int num = x;
    // A a{num}; => A a{raw_ref(num)};
    auto init_list_expr_with_raw_ref = initListExpr(
        raw_ptr_plugin::forEachInitExprWithFieldDecl(
            expr(unless(anyOf(
                     materializeTemporaryExpr(),
                     // Exclude member expressions where the member is a
                     // reference field that will be rewritten into raw_ref.
                     memberExpr(
                         member(fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(
                             field_decl_matcher)))))))
                .bind("initializer_expr"),
            raw_ptr_plugin::hasExplicitFieldDecl(field_decl_matcher)),
        unless(hasParent(cxxConstructExpr())));

    match_finder.addMatcher(init_list_expr_with_raw_ref,
                            &affected_initializer_expr_rewriter);

    // See the doc comment for the overlapsOtherDeclsWithinRecordDecl
    // matcher and the testcases in tests/gen-overlapping-test.cc.
    auto overlapping_field_decl_matcher = fieldDecl(
        allOf(field_decl_matcher, overlapsOtherDeclsWithinRecordDecl()));

    match_finder.addMatcher(overlapping_field_decl_matcher,
                            &overlapping_field_decl_writer);

    // See the doc comment for the isInMacroLocation matcher
    // and the testcases in tests/gen-macros-test.cc.
    auto macro_field_decl_matcher = fieldDecl(
        allOf(field_decl_matcher, raw_ptr_plugin::isInMacroLocation()));

    match_finder.addMatcher(macro_field_decl_matcher, &macro_field_decl_writer);

    // See the testcases in tests/gen-global-scope-test.cc.
    auto global_scope_matcher =
        varDecl(allOf(hasGlobalStorage(),
                      hasType(typeWithEmbeddedFieldDecl(field_decl_matcher))));

    match_finder.addMatcher(global_scope_matcher, &global_scope_rewriter);
  }

 private:
  // Rewrites |SomeClass& field| (matched as "affectedFieldDecl") as
  // |const raw_ref<SomeClass> field| and for each file rewritten in such way
  // adds an
  // |#include "base/memory/raw_ref.h"|.
  class RawRefFieldDeclRewriter : public FieldDeclRewriter {
   public:
    // Forwards all arguments unchanged to the FieldDeclRewriter base.
    explicit RawRefFieldDeclRewriter(OutputHelper* output_helper,
                                     const char* format_string,
                                     const char* include_path)
        : FieldDeclRewriter(output_helper, format_string, include_path) {}

    // Returns true when the match carries no "affectedFieldDeclType" node of
    // type LValueReferenceTypeLoc, i.e. when the field is not an lvalue
    // reference.
    bool earlyExit(const MatchFinder::MatchResult& result) const override {
      auto* type = result.Nodes.getNodeAs<clang::LValueReferenceTypeLoc>(
          "affectedFieldDeclType");
      // in this case, it's not an lvalue reference type member => DO NOTHING
      return !type;
    }
  };

  // NOTE: the std::function members below are declared ahead of the rewriter
  // members that receive them in the constructor, so they are initialized
  // first (members are initialized in declaration order).

  // Rewrites |my_struct.ref_field| (matched as "affectedMemberExpr") as
  // |*my_struct.ref_field|.
  // Returns an empty range at the start of the expression plus the text "*".
  std::function<std::pair<clang::SourceRange, std::string>(
      const MatchFinder::MatchResult&)>
      affectedMemberExprFct_ = [](const MatchFinder::MatchResult& result)
      -> std::pair<clang::SourceRange, std::string> {
    const clang::MemberExpr* member_expr =
        result.Nodes.getNodeAs<clang::MemberExpr>("affectedMemberExpr");
    assert(member_expr && "matcher should bind 'affectedMemberExpr'");
    clang::SourceRange replacement_range(member_expr->getBeginLoc(),
                                         member_expr->getBeginLoc());
    return {replacement_range, "*"};
  };

  // Rewrites |my_struct.ref_field| (matched as
  // "affectedMemberExprWithParentheses") as
  // |(*my_struct.ref_field)|.
  // Examples on why this is needed:
  //  1- std::vector<T>& v; => const raw_ref<std::vector<T>> v;
  //     v[0] => needs to be rewritten as (*v)[0] after the rewrite.
  //  2- key_compare& comp_; => const raw_ref<key_compare> comp_;
  //     comp_(a, b) => needs to be rewritten as (*comp_)(a,b) after the
  //     rewrite.
  // The returned range spans from the start of the expression to just past
  // the member name; the returned text is the original spelling wrapped in
  // "(*" and ")".
  std::function<std::pair<clang::SourceRange, std::string>(
      const MatchFinder::MatchResult&)>
      affectedMemberExprWithParenFct_ =
          [](const MatchFinder::MatchResult& result)
      -> std::pair<clang::SourceRange, std::string> {
    const clang::SourceManager& source_manager = *result.SourceManager;

    const clang::MemberExpr* member_expr =
        result.Nodes.getNodeAs<clang::MemberExpr>(
            "affectedMemberExprWithParentheses");
    assert(member_expr &&
           "matcher should bind 'affectedMemberExprWithParentheses'");

    clang::SourceLocation member_name_start = member_expr->getMemberLoc();
    clang::SourceLocation endLoc = member_name_start.getLocWithOffset(
        member_expr->getMemberDecl()->getName().size());

    clang::SourceRange replacement_range(member_expr->getBeginLoc(), endLoc);

    auto source_text = clang::Lexer::getSourceText(
        clang::CharSourceRange::getTokenRange(member_expr->getSourceRange()),
        source_manager, result.Context->getLangOpts());
    return {replacement_range,
            llvm::formatv("(*{0})",
                          std::string(source_text.begin(), source_text.end()))};
  };

  // Rewrites |my_struct.ref_field.sub_field| (matched as
  // "affectedMemberExprOperator") into |my_struct.ref_field->sub_field|.
  // The bound node is either a MemberExpr or a CXXDependentScopeMemberExpr;
  // in both cases the range [operator location, member location) is replaced
  // with "->".
  std::function<std::pair<clang::SourceRange, std::string>(
      const MatchFinder::MatchResult&)>
      affectedMemberExprOperatorFct_ =
          [](const MatchFinder::MatchResult& result)
      -> std::pair<clang::SourceRange, std::string> {
    const clang::MemberExpr* member_expr =
        result.Nodes.getNodeAs<clang::MemberExpr>("affectedMemberExprOperator");
    const clang::CXXDependentScopeMemberExpr* cxx_dependent_scope_member_expr =
        result.Nodes.getNodeAs<clang::CXXDependentScopeMemberExpr>(
            "affectedMemberExprOperator");
    assert((member_expr || cxx_dependent_scope_member_expr) &&
           "matcher should bind 'affectedMemberExprOperator'");
    if (member_expr) {
      clang::SourceRange replacement_range(member_expr->getOperatorLoc(),
                                           member_expr->getMemberLoc());
      return {replacement_range, "->"};
    }
    clang::SourceRange replacement_range(
        cxx_dependent_scope_member_expr->getOperatorLoc(),
        cxx_dependent_scope_member_expr->getMemberLoc());
    return {replacement_range, "->"};
  };

  // Rewrites |expr| (matched as "initializer_expr") as |raw_ref(expr)|.
  // The returned range starts at the beginning of the expression and extends
  // by the length of its spelled source text.
  std::function<std::pair<clang::SourceRange, std::string>(
      const MatchFinder::MatchResult&)>
      affectedInitializerExprFct_ = [](const MatchFinder::MatchResult& result)
      -> std::pair<clang::SourceRange, std::string> {
    const clang::SourceManager& source_manager = *result.SourceManager;

    const clang::Expr* initializer_expr =
        result.Nodes.getNodeAs<clang::Expr>("initializer_expr");
    // Same guard as the sibling callbacks: the node is dereferenced below.
    assert(initializer_expr && "matcher should bind 'initializer_expr'");

    auto source_text = clang::Lexer::getSourceText(
        clang::CharSourceRange::getTokenRange(
            initializer_expr->getSourceRange()),
        source_manager, result.Context->getLangOpts());

    clang::SourceLocation endLoc =
        initializer_expr->getBeginLoc().getLocWithOffset(source_text.size());

    clang::SourceRange replacement_range(initializer_expr->getBeginLoc(),
                                         endLoc);

    return {replacement_range,
            llvm::formatv("raw_ref({0})",
                          std::string(source_text.begin(), source_text.end()))};
  };

  MatchFinder& match_finder;
  RawRefFieldDeclRewriter field_decl_rewriter;
  AffectedExprRewriter affected_expr_operator_rewriter;
  AffectedExprRewriter affected_expr_rewriter;
  AffectedExprRewriter affected_expr_rewriter_with_parentheses;
  AffectedExprRewriter affected_initializer_expr_rewriter;
  FilteredExprWriter global_scope_rewriter;
  FilteredExprWriter overlapping_field_decl_writer;
  FilteredExprWriter macro_field_decl_writer;
  const raw_ptr_plugin::RawPtrAndRefExclusionsOptions exclusion_options_;
};

// Match callback that rewrites the spelled type of a span field (bound as
// "affectedFieldDecl") by inserting |raw_| in front of |span<|, e.g.
//     base::span<T> field_;  =>  base::raw_span<T> field_;
// When the match also binds "container_type", only the container's template
// arguments bound as "template_arg0" / "template_arg1" are rewritten.
// Replacements are reported through |output_helper| together with
// |include_path|.
class SpanFieldDeclRewriter : public MatchFinder::MatchCallback {
 public:
  // Both pointers are stored as-is; no ownership is taken.
  explicit SpanFieldDeclRewriter(OutputHelper* output_helper,
                                 const char* include_path)
      : output_helper_(output_helper), include_path_(include_path) {}

  SpanFieldDeclRewriter(const SpanFieldDeclRewriter&) = delete;
  SpanFieldDeclRewriter& operator=(const SpanFieldDeclRewriter&) = delete;

  // Entry point invoked for every match of the registered matcher.
  void run(const MatchFinder::MatchResult& result) override {
    const clang::ASTContext& ast_context = *result.Context;
    const clang::SourceManager& source_manager = *result.SourceManager;
    const auto& lang_opts = ast_context.getLangOpts();
    const clang::FieldDecl* field_decl =
        result.Nodes.getNodeAs<clang::FieldDecl>("affectedFieldDecl");

    assert(field_decl && "matcher should bind 'fieldDecl'");

    const clang::TypeSourceInfo* type_source_info =
        field_decl->getTypeSourceInfo();
    if (auto* ivar_decl = clang::dyn_cast<clang::ObjCIvarDecl>(field_decl)) {
      // Objective-C @synthesize statements should not be rewritten. They
      // return null for getTypeSourceInfo().
      if (ivar_decl->getSynthesize()) {
        assert(!type_source_info);
        return;
      }
    }

    assert(type_source_info && "assuming |type_source_info| is always present");

    // Container-of-span fields are rewritten one template argument at a time.
    if (result.Nodes.getNodeAs<clang::QualType>("container_type")) {
      HandleContainerArguments(field_decl, result);
      return;
    }

    // Calculate the |replacement_range|.
    //
    // Consider the following example:
    //      const span<> const   field_name_;
    //      ^--------------------^  = |replacement_range|
    //                           ^  = |field_decl->getLocation()|
    //      ^                       = |field_decl->getBeginLoc()|
    //
    // We get the |replacement_range| in a bit clumsy way, because clang docs
    // for QualifiedTypeLoc explicitly say that these objects "intentionally
    // do not provide source location for type qualifiers".
    clang::SourceRange replacement_range(field_decl->getBeginLoc(),
                                         field_decl->getLocation());

    GenerateReplacement(replacement_range, source_manager, lang_opts);
  }

 private:
  // Returns the source range of the |i|-th spelled template argument of
  // |tst_tl|. |i| must be smaller than |tst_tl.getNumArgs()|.
  clang::SourceRange GetTemplateArgumentSourceRange(
      const clang::TemplateSpecializationTypeLoc& tst_tl,
      unsigned i) {
    // For some reason, the last template argument's end location is marked as
    // being in scratch space. This leads to a wrong size for the replacement.
    // Work around this by using the RAngle's ('>') Location.
    if (i == (tst_tl.getNumArgs() - 1)) {
      return clang::SourceRange(tst_tl.getArgLoc(i).getLocation(),
                                tst_tl.getRAngleLoc());
    }

    return tst_tl.getArgLoc(i).getSourceRange();
  }

  // Extracts the TemplateSpecializationTypeLoc from |loc|, looking through
  // one ElaboratedTypeLoc if needed. Returns an empty optional otherwise.
  std::optional<clang::TemplateSpecializationTypeLoc>
  GetTemplateSpecializationTypeLoc(clang::TypeLoc loc) {
    // We can have a TemplateSpecializationTypeLoc directly.
    // Example: span<some_type> member;
    if (auto specialization =
            loc.getAs<clang::TemplateSpecializationTypeLoc>()) {
      return specialization;
    }

    // Or an elaboratedTypeLoc, which has a namedTypeLoc (the
    // TemplateSpecializationTypeLoc)
    // Example:
    // base::span<some_type> member;
    //       ^-------------^ => templateSpecializationTypeLoc
    // ^-------------------^ => elaboratedTypeLoc
    if (auto elaborated = loc.getAs<clang::ElaboratedTypeLoc>()) {
      if (auto specialization =
              elaborated.getNamedTypeLoc()
                  .getAs<clang::TemplateSpecializationTypeLoc>()) {
        return specialization;
      }
    }
    return {};
  }

  // Rewrites the template arguments of a container field |decl| that were
  // bound as "template_arg0" and/or "template_arg1" in |result|.
  void HandleContainerArguments(const clang::FieldDecl* decl,
                                const MatchFinder::MatchResult& result) {
    const clang::SourceManager& source_manager = *result.SourceManager;
    const clang::LangOptions& lang_opts = result.Context->getLangOpts();
    const auto tstl =
        GetTemplateSpecializationTypeLoc(decl->getTypeSourceInfo()->getTypeLoc());

    // This means that the field type is a typedef to a span type. This is not
    // handled by the rewriter.
    if (!tstl) {
      return;
    }

    const auto rewrite_argument = [&](const char* binding, unsigned index) {
      if (!result.Nodes.getNodeAs<clang::TemplateArgument>(binding)) {
        return;
      }
      // |tstl| only describes the arguments spelled in the field's type. If
      // the bound argument is not among them (presumably possible for
      // defaulted template arguments, since the binding is made on the
      // canonical type), there is no source text to rewrite and calling
      // getArgLoc(index) would be out of range.
      if (index >= tstl->getNumArgs()) {
        return;
      }
      GenerateReplacement(GetTemplateArgumentSourceRange(*tstl, index),
                          source_manager, lang_opts);
    };

    rewrite_argument("template_arg0", 0);
    rewrite_argument("template_arg1", 1);
  }

  // Reads the source text covered by |source_range| (interpreted as a
  // character range), inserts |raw_| before each |span<| and, if that changed
  // anything, reports the replacement through |output_helper_|.
  void GenerateReplacement(const clang::SourceRange& source_range,
                           const clang::SourceManager& source_manager,
                           const clang::LangOptions& lang_opts) {
    std::string initial_text =
        clang::Lexer::getSourceText(
            clang::CharSourceRange::getCharRange(source_range), source_manager,
            lang_opts)
            .str();

    // The span type to rewrite could appear as follows:
    // 1- span<some_type> (used within base namespace)
    // 2- base::span<some_type> or container<base::span<some_type>>
    // 3- container<span<some_type>> (used within base namespace)
    // The statement below inserts `raw_` before the second matched group
    // `span`. std::span is banned in chromium code, so it's not taken into
    // account in the regex.
    std::string replacement_text =
        std::regex_replace(initial_text, span_pattern_, "$1raw_$2");

    // No need to add a replacement if the replacement text is empty or is the
    // same as the initial text. |initial_text| is the same as
    // |replacement_text| when the field's type is an alias of base::span,
    // meaning span<T> does not appear in |initial_text| and thus the regex
    // replace does nothing.
    if (replacement_text.empty() || (initial_text == replacement_text)) {
      return;
    }
    // Generate and print a replacement.
    output_helper_->AddReplacement(source_manager, source_range,
                                   replacement_text, include_path_);
  }

  OutputHelper* const output_helper_;
  const char* include_path_;
  // Compiled once per rewriter instead of once per GenerateReplacement()
  // call. Group 1 is the optional |<| or |base::| prefix, group 2 is |span<|.
  const std::regex span_pattern_{"(<|base::)?(span<)"};
};

class SpanRewriter {
 public:
  SpanRewriter(
      OutputHelper* output_helper,
      MatchFinder& finder,
      const raw_ptr_plugin::RawPtrAndRefExclusionsOptions& exclusion_options)
      : match_finder(finder),
        field_decl_rewriter(output_helper, kRawSpanIncludePath),
        global_scope_rewriter(output_helper, "global-scope"),
        overlapping_field_decl_writer(output_helper, "overlapping"),
        macro_field_decl_writer(output_helper, "macro"),
        exclusion_options_(exclusion_options) {}

  void addMatchers() {
    auto raw_span = hasTemplateArgument(
        2, refersToType(qualType(hasCanonicalType(qualType(hasDeclaration(
               mapAnyOf(classTemplateSpecializationDecl, classTemplateDecl)
                   .with(hasName("raw_ptr"))))))));

    auto string_literals_span = hasTemplateArgument(
        0, refersToType(qualType(hasCanonicalType(
               anyOf(asString("const char"), asString("const wchar_t"),
                     asString("const char8_t"), asString("const char16_t"),
                     asString("const char32_t"))))));

    auto excluded_spans = anyOf(raw_span, string_literals_span);

    auto span_type = anyOf(
        qualType(hasCanonicalType(
            qualType(hasDeclaration(classTemplateSpecializationDecl(
                hasName("base::span"), unless(excluded_spans)))))),
        // This part of the matcher is needed to handle templates.
        // Example:
        // template<typename T>struct S{ base::span<T> member; };
        // |member| has canonical type templateSpecializationType.
        qualType(hasCanonicalType(qualType(type(templateSpecializationType(
            hasDeclaration(classTemplateDecl(hasName("base::span"))),
            unless(excluded_spans)))))));

    auto optional_span_type = anyOf(
        qualType(
            hasCanonicalType(hasDeclaration(classTemplateSpecializationDecl(
                hasName("optional"),
                hasTemplateArgument(0, refersToType(span_type)))))),
        qualType(hasCanonicalType(qualType(type(templateSpecializationType(
            hasDeclaration(classTemplateDecl(hasName("optional"))),
            hasTemplateArgument(0, refersToType(span_type))))))));

    auto container_methods =
        anyOf(allOf(hasMethod(hasName("push_back")),
                    hasMethod(hasName("pop_back")), hasMethod(hasName("size"))),
              allOf(hasMethod(hasName("insert")), hasMethod(hasName("erase")),
                    hasMethod(hasName("size"))),
              allOf(hasMethod(hasName("push")), hasMethod(hasName("pop")),
                    hasMethod(hasName("size"))));

    auto template_arg0 = hasTemplateArgument(
        0, templateArgument(refersToType(anyOf(span_type, optional_span_type)))
               .bind("template_arg0"));
    auto template_arg1 = hasTemplateArgument(
        1, templateArgument(refersToType(anyOf(span_type, optional_span_type)))
               .bind("template_arg1"));
    // template_arg0 and template_arg1 are necessary to locate the container
    // template arguments that need to be rewritten. The use of allOf is to
    // force the matching of both arguments if both need to be rewritten.
    // Did not use a forEachTemplateArgument instead as we need the template
    // argument's index to get its location using the field's
    // templateSpecializationTypeLoc.
    auto template_arguments = anyOf(allOf(template_arg0, template_arg1),
                                    template_arg0, template_arg1);

    auto container_of_span_type =
        qualType(hasCanonicalType(anyOf(
                     qualType(hasDeclaration(classTemplateSpecializationDecl(
                         container_methods, template_arguments))),
                     qualType(type(templateSpecializationType(
                         hasDeclaration(classTemplateDecl(
                             has(cxxRecordDecl(container_methods)))),
                         template_arguments))))))
            .bind("container_type");

    auto field_decl_matcher =
        traverse(clang::TK_IgnoreUnlessSpelledInSource,
                 fieldDecl(hasType(qualType(anyOf(span_type, optional_span_type,
                                                  container_of_span_type))),
                           unless(PtrAndRefExclusions(exclusion_options_)))
                     .bind("affectedFieldDecl"));

    match_finder.addMatcher(field_decl_matcher, &field_decl_rewriter);

    // See the testcases in tests/gen-global-scope-test.cc.
    auto global_scope_matcher =
        varDecl(allOf(hasGlobalStorage(),
                      hasType(typeWithEmbeddedFieldDecl(field_decl_matcher))));

    match_finder.addMatcher(global_scope_matcher, &global_scope_rewriter);

    // See the doc comment for the isInMacroLocation matcher
    // and the testcases in tests/gen-macros-test.cc.
    auto macro_field_decl_matcher = fieldDecl(
        allOf(field_decl_matcher, raw_ptr_plugin::isInMacroLocation()));

    match_finder.addMatcher(macro_field_decl_matcher, &macro_field_decl_writer);
  }

 private:
  MatchFinder& match_finder;
  SpanFieldDeclRewriter field_decl_rewriter;
  FilteredExprWriter global_scope_rewriter;
  FilteredExprWriter overlapping_field_decl_writer;
  FilteredExprWriter macro_field_decl_writer;
  const raw_ptr_plugin::RawPtrAndRefExclusionsOptions exclusion_options_;
};

}  // namespace

// Tool entry point: parses the command line, configures the exclusion
// options, registers the raw_ptr / raw_ref / span matchers and runs the clang
// tool over the requested sources.
//
// Returns 1 if the command line cannot be parsed, otherwise the value
// returned by ClangTool::run().
int main(int argc, const char* argv[]) {
  // TODO(dcheng): Clang tooling should do this itself.
  // http://llvm.org/bugs/show_bug.cgi?id=21627
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmParser();
  llvm::cl::OptionCategory category(
      "rewrite_raw_ptr_fields: changes |T* field_| to |raw_ptr<T> field_|.");
  llvm::cl::opt<std::string> exclude_fields_param(
      kExcludeFieldsParamName, llvm::cl::value_desc("filepath"),
      llvm::cl::desc("file listing fields to be blocked (not rewritten)"));
  llvm::cl::opt<std::string> override_exclude_paths_param(
      kOverrideExcludePathsParamName, llvm::cl::value_desc("filepath"),
      llvm::cl::desc(
          "override file listing paths to be blocked (not rewritten)"));

  llvm::cl::opt<bool> enable_raw_ref_rewrite(
      "enable_raw_ref_rewrite", llvm::cl::init(false),
      llvm::cl::desc("Rewrite T& into const raw_ref<T>"));

  llvm::cl::opt<bool> enable_raw_ptr_rewrite(
      "enable_raw_ptr_rewrite", llvm::cl::init(false),
      llvm::cl::desc("Rewrite T* into raw_ptr<T>"));

  llvm::cl::opt<bool> exclude_stack_allocated(
      "exclude_stack_allocated", llvm::cl::init(true),
      llvm::cl::desc("Exclude pointers/references to `STACK_ALLOCATED` objects "
                     "from the rewrite"));

  llvm::Expected<clang::tooling::CommonOptionsParser> options =
      clang::tooling::CommonOptionsParser::create(argc, argv, category);
  // An assert alone would disappear in NDEBUG builds and the errored
  // Expected would then be dereferenced below; report the problem and bail
  // out instead. llvm::toString() consumes the error.
  if (!options) {
    llvm::errs() << llvm::toString(options.takeError()) << "\n";
    return 1;
  }
  clang::tooling::ClangTool tool(options->getCompilations(),
                                 options->getSourcePathList());

  // Rewrite both T& and T* into const raw_ref<T> and raw_ptr<T> respectively if
  // no argument is provided.
  bool rewrite_raw_ref_and_ptr =
      !enable_raw_ref_rewrite && !enable_raw_ptr_rewrite;
  MatchFinder match_finder;
  OutputHelper output_helper;
  raw_ptr_plugin::FilterFile fields_to_exclude(
      exclude_fields_param, exclude_fields_param.ArgStr.str());

  // Without an override file, the excluded paths are the two built-in lists.
  std::unique_ptr<raw_ptr_plugin::FilterFile> paths_to_exclude;
  if (override_exclude_paths_param == "") {
    std::vector<std::string> paths_to_exclude_lines;
    for (auto* const line : kRawPtrManualPathsToIgnore) {
      paths_to_exclude_lines.push_back(line);
    }
    for (auto* const line : kSeparateRepositoryPaths) {
      paths_to_exclude_lines.push_back(line);
    }
    paths_to_exclude =
        std::make_unique<raw_ptr_plugin::FilterFile>(paths_to_exclude_lines);
  } else {
    paths_to_exclude = std::make_unique<raw_ptr_plugin::FilterFile>(
        override_exclude_paths_param,
        override_exclude_paths_param.ArgStr.str());
  }

  raw_ptr_plugin::StackAllocatedPredicate stack_allocated_checker;
  raw_ptr_plugin::RawPtrAndRefExclusionsOptions exclusion_options{
      &fields_to_exclude, paths_to_exclude.get(), exclude_stack_allocated,
      &stack_allocated_checker, true};

  // The rewriters are always constructed; their matchers are only registered
  // when the corresponding rewrite is enabled.
  RawPtrRewriter raw_ptr_rewriter(&output_helper, match_finder,
                                  exclusion_options);
  if (rewrite_raw_ref_and_ptr || enable_raw_ptr_rewrite) {
    raw_ptr_rewriter.addMatchers();
  }

  RawRefRewriter raw_ref_rewriter(&output_helper, match_finder,
                                  exclusion_options);
  if (rewrite_raw_ref_and_ptr || enable_raw_ref_rewrite) {
    raw_ref_rewriter.addMatchers();
  }

  // The span matchers are registered regardless of the two flags above.
  SpanRewriter span_rewriter(&output_helper, match_finder, exclusion_options);
  span_rewriter.addMatchers();

  // Prepare and run the tool.
  std::unique_ptr<clang::tooling::FrontendActionFactory> factory =
      clang::tooling::newFrontendActionFactory(&match_finder, &output_helper);
  return tool.run(factory.get());
}