chromium/tools/clang/rewrite_templated_container_fields/RewriteTemplatedPtrFields.cpp

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

// This is the implementation of a clang tool that rewrites containers of
// pointer fields into raw_ptr<T>:
//     std::vector<Pointee*> field_
// becomes:
//     std::vector<raw_ptr<Pointee>> field_
//
// Note that the tool emits two kinds of outputs:
//   1- A pair of nodes formatted as {lhs};{rhs}\n representing an edge between
//      two nodes.
//   2- A single node formatted as {lhs}\n
// The concatenated outputs from multiple tool runs are then used to construct
// the graph and emit relevant edits using extract_edits.py
//
// A node (lhs, rhs) has the following format:
// '{is_field,is_excluded,has_auto_type,r:::<file
// path>:::<offset>:::<length>:::<replacement text>,include-user-header:::<file
// path>:::-1:::-1:::<include text>}'
//
// where `is_field`, `is_excluded`, and `has_auto_type` are booleans represented
// as 0 or 1.
//
// For more details, see the doc here:
// https://docs.google.com/document/d/1P8wLVS3xueI4p3EAPO4JJP6d1_zVp5SapQB0EW9iHQI/

#include <assert.h>
#include <algorithm>
#include <cstdio>
#include <fstream>
#include <limits>
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <string_view>
#include <vector>

#include "RawPtrHelpers.h"
#include "RawPtrManualPathsToIgnore.h"
#include "SeparateRepositoryPaths.h"
#include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/ASTMatchers/ASTMatchersMacros.h"
#include "clang/Basic/CharInfo.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Lex/Lexer.h"
#include "clang/Lex/MacroArgs.h"
#include "clang/Lex/PPCallbacks.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Refactoring.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LineIterator.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/TargetSelect.h"

using namespace clang::ast_matchers;

namespace {

// Include path that needs to be added to all the files where raw_ptr<...>
// replaces a raw pointer. It is emitted as the <include text> part of every
// `include-user-header` directive generated in this file.
const char kRawPtrIncludePath[] = "base/memory/raw_ptr.h";

// NOTE(review): presumably the name of the tool parameter that overrides the
// default list of excluded paths; its use is not visible in this part of the
// file -- confirm against the option parsing code.
const char kOverrideExcludePathsParamName[] = "override-exclude-paths";

// Matches a function declaration when at least one of its parameters
// satisfies `parm_var_decl_matcher`. The bindings of every matching parameter
// are accumulated, so one function may produce several results.
AST_MATCHER_P(clang::FunctionDecl,
              forEachParmVarDecl,
              clang::ast_matchers::internal::Matcher<clang::ParmVarDecl>,
              parm_var_decl_matcher) {
  clang::ast_matchers::internal::BoundNodesTreeBuilder accumulated;
  bool any_param_matched = false;

  for (const clang::ParmVarDecl* parm : Node.parameters()) {
    // Each parameter is matched against a fresh set of bindings.
    clang::ast_matchers::internal::BoundNodesTreeBuilder parm_bindings;
    if (!parm_var_decl_matcher.matches(*parm, Finder, &parm_bindings)) {
      continue;
    }
    accumulated.addMatch(parm_bindings);
    any_param_matched = true;
  }

  // The builder is overwritten with the accumulated bindings regardless of
  // whether anything matched.
  *Builder = std::move(accumulated);
  return any_param_matched;
}

// Returns the part of `input` that follows the first occurrence of
// '(anonymous namespace)::'. When the pattern does not occur, `input` is
// returned unchanged.
static llvm::StringRef RemoveAnonymous(llvm::StringRef input) {
  constexpr llvm::StringRef kMarker{"(anonymous namespace)::"};
  const size_t marker_pos = input.find(kMarker);
  return marker_pos == llvm::StringRef::npos
             ? input
             : input.substr(marker_pos + kMarker.size());
}

// Range-based for statements such as `for (auto* i : affected_expr)` have to
// become `for (type_name* i : affected_expr)` so that the pointer type is
// extracted from what is now a raw_ptr.
// The string produced by a type's `getAsString()` may carry noise, e.g.
// 'const struct n1::(anonymous namespace)::n2::type_name', which does not
// compile as written and must be turned into 'const n2::type_name'.
// `RemovePrefix` keeps a leading 'const ', drops a leading 'class ' and/or
// 'struct ' keyword, and then uses `RemoveAnonymous` to discard
// '(anonymous namespace)::' together with everything on its left.
static std::string RemovePrefix(llvm::StringRef input) {
  const size_t original_length = input.size();

  // Strip the recognised prefixes in order; only constness is remembered.
  const bool had_const = input.consume_front("const ");
  input.consume_front("class ");
  input.consume_front("struct ");

  std::string cleaned;
  cleaned.reserve(original_length);
  if (had_const) {
    cleaned.append("const ");
  }

  const llvm::StringRef tail = RemoveAnonymous(input);
  cleaned.append(tail.begin(), tail.end());
  return cleaned;
}

// A vertex of the graph emitted by the tool. Equality and ordering are
// defined by `replacement` alone; the flags and the include directive are
// carried along as payload.
struct Node {
  // Set when the node was created from a field declaration.
  bool is_field = false;
  // Set for fields annotated with RAW_PTR_EXCLUSION.
  bool is_excluded = false;
  // Variables declared with `auto` are never rewritten themselves, but they
  // must exist in the graph so that a rewrite can travel through them to
  // non-auto expressions. Example:
  //   auto temp = member_; vector<T*>::iterator it = temp.begin();
  // The type of `it` has to change whenever the type of `member_` does.
  bool has_auto_type = false;
  // Replacement directive, formatted as:
  // `r:::<file path>:::<offset>:::<length>:::<replacement text>`
  std::string replacement;
  // Include directive, formatted as:
  // `include-user-header:::<file path>:::-1:::-1:::<include text>`
  std::string include_directive;

  // Two nodes are the same node when their replacement directives are equal.
  bool operator==(const Node& other) const {
    return replacement.compare(other.replacement) == 0;
  }

  // Nodes sort lexicographically by replacement directive.
  bool operator<(const Node& other) const {
    return replacement.compare(other.replacement) < 0;
  }

  // Serializes the node as:
  // {is_field\,is_excluded\,has_auto_type\,r:::<file
  // path>:::<offset>:::<length>:::<replacement
  // text>\,include-user-header:::<file path>:::-1:::-1:::<include text>}
  // The three flags are printed as 0 or 1.
  std::string ToString() const {
    return llvm::formatv("{{{0:d}\\,{1:d}\\,{2:d}\\,{3}\\,{4}}", is_field,
                         is_excluded, has_auto_type, replacement,
                         include_directive)
        .str();
  }
};

// Collects the output lines of the tool (edges and stand-alone nodes) and
// prints them once requested.
class OutputHelper {
 public:
  OutputHelper() = default;

  // Records the edge `lhs` -> `rhs` as the line "{lhs};{rhs}\n".
  void AddEdge(const Node& lhs, const Node& rhs) {
    std::string line =
        llvm::formatv("{0};{1}\n", lhs.ToString(), rhs.ToString()).str();
    node_pairs_.insert(std::move(line));
  }

  // Records `lhs` on its own as the line "{lhs}\n".
  void AddSingleNode(const Node& lhs) {
    std::string line = llvm::formatv("{0}\n", lhs.ToString()).str();
    node_pairs_.insert(std::move(line));
  }

  // Writes every recorded line to llvm::outs(), in the set's sorted order.
  void Emit() {
    for (auto it = node_pairs_.begin(); it != node_pairs_.end(); ++it) {
      llvm::outs() << *it;
    }
  }

 private:
  // One entry per output line. An edge line has the form {lhs};{rhs}\n with
  // both sides produced by Node::ToString(). A line holding only {lhs}\n is
  // used in two situations:
  // 1- Fields that are not connected to any other node, so that they still
  //    show up in the graph.
  // 2- Fields annotated with RAW_PTR_EXCLUSION.
  // Being a set, identical lines are stored only once.
  std::set<std::string> node_pairs_;
};

// This visitor is used to extract a FunctionDecl* from a given match: either
// the declaration bound with node id "fct_decl", or the call operator of the
// lambda bound with node id "lambda_expr".
// This is used in the `forEachArg` and `forEachBindArg` matchers below.
//
// After visiting, `fct_decl_` is nullptr if no visited match carried either
// binding; callers must check it before use.
class LocalVisitor
    : public clang::ast_matchers::internal::BoundNodesTreeBuilder::Visitor {
 public:
  void visitMatch(
      const clang::ast_matchers::BoundNodes& BoundNodesView) override {
    if (const auto* fct_decl =
            BoundNodesView.getNodeAs<clang::FunctionDecl>("fct_decl")) {
      fct_decl_ = fct_decl;
      is_lambda_ = false;
      return;
    }
    // Only dereference the lambda when it is actually bound; a match that
    // carries neither binding leaves the visitor state untouched.
    if (const auto* lambda_expr =
            BoundNodesView.getNodeAs<clang::LambdaExpr>("lambda_expr")) {
      fct_decl_ = lambda_expr->getCallOperator();
      is_lambda_ = true;
    }
  }

  // Function (or lambda call operator) extracted from the last visited match
  // that had a binding; nullptr when none was found.
  const clang::FunctionDecl* fct_decl_ = nullptr;
  // True when `fct_decl_` is the call operator of a lambda expression.
  bool is_lambda_ = false;
};

// This is used to map arguments passed to std::make_unique to the underlying
// constructor parameters. For each expr that matches, using `LocalVisitor`, we
// extract the ptr to clang::FunctionDecl which represents here the
// constructorDecl and use it to get the parmVarDecl corresponding to the
// argument.
// This iterates over a callExpression's arguments and matches the ones that
// match expr_matcher. For each argument matched, retrieve the corresponding
// constructor parameter (the parameter at the same index as the argument).
// The constructor parameter is then checked against parm_var_decl_matcher.
// Arguments for which no function declaration was bound, or whose index has
// no corresponding parameter, are skipped.
AST_MATCHER_P2(clang::CallExpr,
               forEachArg,
               clang::ast_matchers::internal::Matcher<clang::Expr>,
               expr_matcher,
               clang::ast_matchers::internal::Matcher<clang::ParmVarDecl>,
               parm_var_decl_matcher) {
  const clang::CallExpr& call_expr = Node;

  const unsigned num_args = call_expr.getNumArgs();
  bool is_matching = false;
  clang::ast_matchers::internal::BoundNodesTreeBuilder result;
  for (unsigned i = 0; i < num_args; i++) {
    const clang::Expr* arg = call_expr.getArg(i);
    clang::ast_matchers::internal::BoundNodesTreeBuilder arg_matches;
    if (!expr_matcher.matches(*arg, Finder, &arg_matches)) {
      continue;
    }

    // Value-initialized so that the visitor's members have a defined value
    // even if no match is visited.
    LocalVisitor l{};
    arg_matches.visitMatches(&l);
    const clang::FunctionDecl* fct_decl = l.fct_decl_;
    // getParamDecl() must only be called with an index below getNumParams();
    // skip arguments that have no parameter at the same position.
    if (!fct_decl || i >= fct_decl->getNumParams()) {
      continue;
    }

    const clang::ParmVarDecl* param = fct_decl->getParamDecl(i);
    // Start from the argument's bindings so that a result carries both the
    // argument and the parameter nodes.
    clang::ast_matchers::internal::BoundNodesTreeBuilder parm_var_matches(
        arg_matches);
    if (parm_var_decl_matcher.matches(*param, Finder, &parm_var_matches)) {
      is_matching = true;
      result.addMatch(parm_var_matches);
    }
  }
  *Builder = std::move(result);
  return is_matching;
}

// This is used to handle expressions of the form:
// base::BindOnce(
//                [](std::vector<raw_ptr<Label>>& message_labels,
//                    Label* message_label) {
//                   message_labels.push_back(message_label);
//                 },
//                 std::ref(message_labels_))))
// This creates a link between the parmVarDecl's in the lambda/functionPointer
// passed as 1st argument and the rest of the arguments passed to the bind call.
// Arguments for which no function declaration was bound, or which map to a
// parameter index the function does not have, are skipped.
AST_MATCHER_P2(clang::CallExpr,
               forEachBindArg,
               clang::ast_matchers::internal::Matcher<clang::Expr>,
               expr_matcher,
               clang::ast_matchers::internal::Matcher<clang::ParmVarDecl>,
               parm_var_decl_matcher) {
  const clang::CallExpr& call_expr = Node;

  const unsigned num_args = call_expr.getNumArgs();
  if (num_args == 1) {
    // No arguments to map to the lambda/fct parmVarDecls.
    return false;
  }

  bool is_matching = false;
  clang::ast_matchers::internal::BoundNodesTreeBuilder result;
  for (unsigned i = 1; i < num_args; i++) {
    const clang::Expr* arg = call_expr.getArg(i);
    clang::ast_matchers::internal::BoundNodesTreeBuilder arg_matches;
    if (!expr_matcher.matches(*arg, Finder, &arg_matches)) {
      continue;
    }

    // Value-initialized so that the visitor's members have a defined value
    // even if no match is visited.
    LocalVisitor l{};
    arg_matches.visitMatches(&l);
    const clang::FunctionDecl* fct_decl = l.fct_decl_;
    if (!fct_decl) {
      // Nothing to map this argument to.
      continue;
    }

    // start_index=1 when we start with second arg for Bind and first arg for
    // lambda/fct
    unsigned start_index = 1;
    // start_index=2 when the second arg is a pointer to the object on which
    // the function is to be invoked. This is the case only for non-static
    // member functions. isCXXInstanceMember() is used rather than
    // !isGlobal() because isGlobal() also returns false for free functions
    // that are `static` or live in an anonymous namespace, which take no
    // receiver object.
    if (!l.is_lambda_ && fct_decl->isCXXInstanceMember()) {
      start_index = 2;
      // Skip the second argument passed to BindOnce/BindRepeating as it is an
      // object pointer unrelated to target function args.
      if (i == 1) {
        continue;
      }
    }

    // getParamDecl() must only be called with an index below getNumParams().
    const unsigned param_index = i - start_index;
    if (param_index >= fct_decl->getNumParams()) {
      continue;
    }

    const clang::ParmVarDecl* param = fct_decl->getParamDecl(param_index);
    // Start from the argument's bindings so that a result carries both the
    // argument and the parameter nodes.
    clang::ast_matchers::internal::BoundNodesTreeBuilder parm_var_decl_matches(
        arg_matches);
    if (parm_var_decl_matcher.matches(*param, Finder,
                                      &parm_var_decl_matches)) {
      is_matching = true;
      result.addMatch(parm_var_decl_matches);
    }
  }
  *Builder = std::move(result);
  return is_matching;
}

// Builds the replacement type text for `pointer_type`: the pointer's own
// const/volatile qualifiers (in that order) followed by `raw_ptr<Pointee>`,
// where the pointee is printed without its enclosing scope.
static std::string GenerateNewType(const clang::ASTContext& ast_context,
                                   const clang::QualType& pointer_type) {
  assert(!pointer_type.isRestrictQualified() &&
         "|restrict| is a C-only qualifier and raw_ptr<T>/raw_ref<T> need C++");

  // Print the pointee with scope suppressed: s/blink::Pointee/Pointee/
  clang::PrintingPolicy policy(ast_context.getLangOpts());
  policy.SuppressScope = 1;
  const std::string pointee_name =
      pointer_type->getPointeeType().getAsString(policy);

  // Qualifiers of the pointer itself are carried over in front of raw_ptr.
  std::string new_type;
  if (pointer_type.isConstQualified()) {
    new_type.append("const ");
  }
  if (pointer_type.isVolatileQualified()) {
    new_type.append("volatile ");
  }
  new_type.append(llvm::formatv("raw_ptr<{0}>", pointee_name).str());
  return new_type;
}

// Produces the pair {replacement directive, include directive} for rewriting
// the template argument spanned by `tst_loc` / `type_loc` into
// `replacement_text`. Returns a pair of empty strings when the replacement
// has no file path.
static std::pair<std::string, std::string> GetReplacementAndIncludeDirectives(
    const clang::PointerTypeLoc* type_loc,
    const clang::TemplateSpecializationTypeLoc* tst_loc,
    std::string replacement_text,
    const clang::SourceManager& source_manager) {
  // The range starts right after the '<' of the template specialization.
  const clang::SourceLocation range_begin =
      tst_loc->getLAngleLoc().getLocWithOffset(1);
  // One character past type_loc's end, to skip the star '*': type_loc's end
  // loc is just before the star position.
  const clang::SourceLocation range_end =
      type_loc->getEndLoc().getLocWithOffset(1);

  const clang::tooling::Replacement edit(
      source_manager,
      clang::CharSourceRange::getCharRange(range_begin, range_end),
      replacement_text);
  const llvm::StringRef path = edit.getFilePath();
  if (path.empty()) {
    return std::make_pair(std::string(), std::string());
  }

  // Newlines in the emitted text are replaced with NUL characters, keeping
  // each directive on a single output line.
  for (char& ch : replacement_text) {
    if (ch == '\n') {
      ch = '\0';
    }
  }

  std::string replacement_directive =
      llvm::formatv("r:::{0}:::{1}:::{2}:::{3}", path, edit.getOffset(),
                    edit.getLength(), replacement_text)
          .str();
  std::string include_directive =
      llvm::formatv("include-user-header:::{0}:::-1:::-1:::{1}", path,
                    kRawPtrIncludePath)
          .str();

  return std::make_pair(std::move(replacement_directive),
                        std::move(include_directive));
}

// Builds a replacement directive for an `auto` type location. The directive
// describes a zero-length range positioned at the beginning of `auto_loc` and
// carries `replacement_text` verbatim. `ast_context` is not used.
std::string GenerateReplacementForAutoLoc(
    const clang::TypeLoc* auto_loc,
    const std::string& replacement_text,
    const clang::SourceManager& source_manager,
    const clang::ASTContext& ast_context) {
  const clang::SourceLocation anchor = auto_loc->getBeginLoc();

  const clang::tooling::Replacement edit(
      source_manager, clang::CharSourceRange::getCharRange(anchor, anchor),
      replacement_text);

  return llvm::formatv("r:::{0}:::{1}:::{2}:::{3}", edit.getFilePath(),
                       edit.getOffset(), edit.getLength(), replacement_text)
      .str();
}

// Called when the Match registered for it was successfully found in the AST.
// The matches registered represent two categories:
//   1- An adjacency relationship
//      In that case, a node pair is created, using matched node ids, and added
//      to the node_pair list using `OutputHelper::AddEdge`
//   2- A single fieldDecl node match
//      In that case, a single node is created and added to the node_pair list
//      using `OutputHelper::AddSingleNode`
//
// Node ids read by `run`:
//   lhs: "lhs_argPointerLoc" + "lhs_tst_loc" (optionally "lhs_field",
//        "field_decl", "excluded_field_decl"), or "lhs_auto_loc".
//   rhs: "rhs_argPointerLoc" + "rhs_tst_loc" (optionally "rhs_field"), or
//        "rhs_auto_loc".
class PotentialNodes : public MatchFinder::MatchCallback {
 public:
  // `helper` is stored by reference and receives every node/edge produced.
  explicit PotentialNodes(OutputHelper& helper) : output_helper_(helper) {}

  PotentialNodes(const PotentialNodes&) = delete;
  PotentialNodes& operator=(const PotentialNodes&) = delete;

  // Builds the lhs node and, unless the match is a single-node match, the rhs
  // node, then records them in the output helper.
  void run(const MatchFinder::MatchResult& result) override {
    const clang::SourceManager& source_manager = *result.SourceManager;
    const clang::ASTContext& ast_context = *result.Context;

    Node lhs;

    // Case A: the lhs is a pointer template argument that gets rewritten.
    if (auto* type_loc = result.Nodes.getNodeAs<clang::PointerTypeLoc>(
            "lhs_argPointerLoc")) {
      std::string replacement_text =
          GenerateNewType(ast_context, type_loc->getType());

      // NOTE(review): "lhs_tst_loc" is dereferenced by the helper below
      // without a null check; presumably every matcher binding
      // "lhs_argPointerLoc" also binds it -- confirm against the matchers.
      const clang::TemplateSpecializationTypeLoc* tst_loc =
          result.Nodes.getNodeAs<clang::TemplateSpecializationTypeLoc>(
              "lhs_tst_loc");

      auto p = GetReplacementAndIncludeDirectives(
          type_loc, tst_loc, replacement_text, source_manager);
      lhs.replacement = p.first;
      lhs.include_directive = p.second;

      // The bound declaration is only tested for presence; its value is not
      // used.
      if (const clang::FieldDecl* field_decl =
              result.Nodes.getNodeAs<clang::FieldDecl>("lhs_field")) {
        lhs.is_field = true;
      }

      // To make sure we add all field decls to the graph.(Specifically those
      // not connected to other nodes)
      // This is a single-node match: no rhs is looked up.
      if (const clang::FieldDecl* field_decl =
              result.Nodes.getNodeAs<clang::FieldDecl>("field_decl")) {
        lhs.is_field = true;
        output_helper_.AddSingleNode(lhs);
        return;
      }

      // RAW_PTR_EXCLUSION is not captured when adding edges between nodes. For
      // that reason, fields annotated with RAW_PTR_EXCLUSION are added as
      // single nodes to the list, this is then used as a starting point to
      // propagate the exclusion to all neighboring nodes.
      if (const clang::FieldDecl* field_decl =
              result.Nodes.getNodeAs<clang::FieldDecl>("excluded_field_decl")) {
        lhs.is_field = true;
        lhs.is_excluded = true;
        output_helper_.AddSingleNode(lhs);
        return;
      }
    } else if (const clang::TypeLoc* auto_loc =
                   result.Nodes.getNodeAs<clang::TypeLoc>("lhs_auto_loc")) {
      // Case B: the lhs is an `auto` typed declaration. The literal string
      // "replacement_text" is used as placeholder text, and the include
      // directive is set to the same string as the replacement directive.
      lhs.replacement = GenerateReplacementForAutoLoc(
          auto_loc, "replacement_text", source_manager, ast_context);
      lhs.include_directive = lhs.replacement;
      // No need to emit a rewrite for auto type variables. They still need to
      // appear in the graph to propagate the rewrite to non-auto type nodes
      // codes connected to them.
      lhs.has_auto_type = true;
    } else {  // Not supposed to get here
      // When assertions are compiled out, execution continues with a
      // default-constructed lhs.
      assert(false);
    }

    // The rhs node is built the same way as the lhs, from the "rhs_*" ids.
    Node rhs;
    if (const clang::FieldDecl* field_decl =
            result.Nodes.getNodeAs<clang::FieldDecl>("rhs_field")) {
      rhs.is_field = true;
    }

    if (auto* type_loc = result.Nodes.getNodeAs<clang::PointerTypeLoc>(
            "rhs_argPointerLoc")) {
      std::string replacement_text =
          GenerateNewType(ast_context, type_loc->getType());

      const clang::TemplateSpecializationTypeLoc* tst_loc =
          result.Nodes.getNodeAs<clang::TemplateSpecializationTypeLoc>(
              "rhs_tst_loc");

      auto p = GetReplacementAndIncludeDirectives(
          type_loc, tst_loc, replacement_text, source_manager);

      rhs.replacement = p.first;
      rhs.include_directive = p.second;
    } else if (const clang::TypeLoc* auto_loc =
                   result.Nodes.getNodeAs<clang::TypeLoc>("rhs_auto_loc")) {
      rhs.replacement = GenerateReplacementForAutoLoc(
          auto_loc, "replacement_text", source_manager, ast_context);
      rhs.include_directive = rhs.replacement;
      // No need to emit a rewrite for auto type variables. They still need to
      // appear in the graph to propagate the rewrite to non-auto type nodes
      // codes connected to them.
      rhs.has_auto_type = true;
    } else {  // Not supposed to get here
      // When assertions are compiled out, execution continues with a
      // default-constructed rhs.
      assert(false);
    }

    output_helper_.AddEdge(lhs, rhs);
  }

 private:
  // Sink for the generated nodes and edges; not owned.
  OutputHelper& output_helper_;
};

// Match callback invoked for every match that pairs a parmVarDecl node or an
// RTNode (function return type node) with its function declaration.
//
// From the function declaration:
//   1- A unique key `current_key` is derived.
//   2- If the function has a previous declaration, or overrides other
//      methods, the keys `prev_key` of those declarations are derived too.
//   3- Every pair (`prev_key`, `current_key`) is appended to `fct_sig_pairs_`.
//
// From the parmVarDecl or RTNode:
//   1- A node is created.
//   2- The node is inserted into `fct_sig_nodes_[current_key]`.
//
// At the end of the tool run for a given translation unit, edges between
// corresponding nodes of two adjacent function signatures are created.
class FunctionSignatureNodes : public MatchFinder::MatchCallback {
 public:
  explicit FunctionSignatureNodes(
      std::map<std::string, std::set<Node>>& sig_nodes,
      std::vector<std::pair<std::string, std::string>>& sig_pairs)
      : fct_sig_nodes_(sig_nodes), fct_sig_pairs_(sig_pairs) {}

  FunctionSignatureNodes(const FunctionSignatureNodes&) = delete;
  FunctionSignatureNodes& operator=(const FunctionSignatureNodes&) = delete;

  // Returns the key of `fct_decl`: a unique string generated from the
  // function's name and the file location where its declaration begins.
  std::string GetKey(const clang::FunctionDecl* fct_decl,
                     const clang::SourceManager& source_manager) {
    const std::string fct_name =
        fct_decl->getNameInfo().getName().getAsString();
    // The file location is used rather than the raw begin location to cope
    // with functions that come from a macro expansion. Without it, several
    // functionDecls would share one location and argument mapping would
    // break. Example:
    // MOCK_METHOD0(GetAllStreams, std::vector<DemuxerStream*>());
    const clang::SourceLocation file_loc =
        source_manager.getFileLoc(fct_decl->getBeginLoc());

    // A zero-length Replacement is built only to resolve the file path and
    // offset of that location.
    const clang::tooling::Replacement location_probe(
        source_manager, clang::CharSourceRange::getCharRange(file_loc, file_loc),
        fct_name.c_str());

    return llvm::formatv("r:::{0}:::{1}:::{2}:::{3}",
                         location_probe.getFilePath(),
                         location_probe.getOffset(), location_probe.getLength(),
                         fct_name.c_str())
        .str();
  }

  void run(const MatchFinder::MatchResult& result) override {
    const clang::SourceManager& source_manager = *result.SourceManager;
    const clang::ASTContext& ast_context = *result.Context;

    const auto* fct_decl =
        result.Nodes.getNodeAs<clang::FunctionDecl>("fct_decl");
    const std::string current_key = GetKey(fct_decl, source_manager);

    // Link this declaration to the one that precedes it, if any.
    if (const clang::FunctionDecl* previous = fct_decl->getPreviousDecl()) {
      fct_sig_pairs_.emplace_back(GetKey(previous, source_manager),
                                  current_key);
    }

    // Link a method to every method it overrides.
    if (const auto* method =
            result.Nodes.getNodeAs<clang::CXXMethodDecl>("fct_decl")) {
      for (const clang::CXXMethodDecl* overridden :
           method->overridden_methods()) {
        fct_sig_pairs_.emplace_back(GetKey(overridden, source_manager),
                                    current_key);
      }
    }

    // Create the node for the matched parameter / return type and file it
    // under the current signature.
    const auto* pointer_loc =
        result.Nodes.getNodeAs<clang::PointerTypeLoc>("rhs_argPointerLoc");
    const auto* template_loc =
        result.Nodes.getNodeAs<clang::TemplateSpecializationTypeLoc>(
            "rhs_tst_loc");
    const std::string new_type =
        GenerateNewType(ast_context, pointer_loc->getType());
    std::pair<std::string, std::string> directives =
        GetReplacementAndIncludeDirectives(pointer_loc, template_loc, new_type,
                                           source_manager);

    Node signature_node;
    signature_node.replacement = std::move(directives.first);
    signature_node.include_directive = std::move(directives.second);
    fct_sig_nodes_[current_key].insert(signature_node);
  }

 private:
  // Maps a function signature, modeled as a string representing its file
  // location, to the graph nodes matched for it (RTNode and ParmVarDecl
  // nodes). `RTNode` denotes a function return type node.
  // A `std::set<Node>` is used so that the stored order does not depend on
  // the order in which nodes are matched in the AST: the set sorts Nodes by
  // their replacement directive, which contains the node's file offset:
  // `r:::<file path>:::<offset>:::<length>:::<replacement text>`
  // The order matters because at the end of a tool run on a translationUnit,
  // for each pair of function signatures, the two sets of Nodes are walked
  // concurrently and edges are created between nodes at the same index:
  // AddEdge(first function's node1, second function's node1)
  // AddEdge(first function's node2, second function's node2)
  // and so on...
  std::map<std::string, std::set<Node>>& fct_sig_nodes_;
  // Pairs of related function signatures, stored as (previous/overridden key,
  // current key). Needed for functions with separate definition and
  // declaration, and for overridden functions.
  std::vector<std::pair<std::string, std::string>>& fct_sig_pairs_;
};

// Called when the Match registered for it was successfully found in the AST.
// The matches registered represent the following categories:
//   1- Range-based for loops of the form:
//        for (auto* i : ctn_expr) => for (type_name* i : ctn_expr)
//
//   2- Expressions of the form:
//      auto* var = ctn_expr.front(); => auto* var = ctn_expr.front().get();
//
//   3- Expressions of the form:
//      auto* var = ctn_expr[index]; => auto* var = ctn_expr[index].get();
//
//   4- Parameters bound as "lambda_parmVarDecl", whose declared type text is
//      replaced by the pointee type of the corresponding parameter of the
//      method bound as "md".
//
// In each of the above cases a node pair is created and added to node_pairs
// using `OutputHelper::AddEdge`
class AffectedPtrExprRewriter : public MatchFinder::MatchCallback {
 public:
  // `helper` is stored by reference and receives every edge produced.
  explicit AffectedPtrExprRewriter(OutputHelper& helper)
      : output_helper_(helper) {}

  AffectedPtrExprRewriter(const AffectedPtrExprRewriter&) = delete;
  AffectedPtrExprRewriter& operator=(const AffectedPtrExprRewriter&) = delete;

  // Builds the lhs node (the edit to apply to the affected expression or
  // declaration) and the rhs node (the container it depends on), then emits
  // the edge lhs -> rhs.
  void run(const MatchFinder::MatchResult& result) override {
    const clang::SourceManager& source_manager = *result.SourceManager;
    const clang::ASTContext& ast_context = *result.Context;

    // For every lhs kind below, include_directive is set to the same string
    // as the replacement directive.
    Node lhs;
    if (const clang::VarDecl* var_decl =
            result.Nodes.getNodeAs<clang::VarDecl>("autoVarDecl")) {
      // NOTE(review): "autoLoc" is dereferenced without a null check;
      // presumably always bound together with "autoVarDecl" -- confirm.
      auto* type_loc = result.Nodes.getNodeAs<clang::TypeLoc>("autoLoc");

      // Character range from the start of the declaration to the end
      // location of the auto type loc.
      clang::SourceRange replacement_range(var_decl->getBeginLoc(),
                                           type_loc->getEndLoc());

      // The new text is the spelled-out pointee type of the variable.
      std::string replacement_text =
          var_decl->getType()->getPointeeType().getAsString();

      // Strip class/struct keywords and anonymous-namespace qualifiers.
      replacement_text = RemovePrefix(replacement_text);
      lhs.replacement = getReplacementDirective(
          replacement_text, replacement_range, source_manager, ast_context);
      lhs.include_directive = lhs.replacement;

    } else if (const clang::CXXMemberCallExpr* member_expr =
                   result.Nodes.getNodeAs<clang::CXXMemberCallExpr>(
                       "affectedMemberExpr")) {
      // Zero-length insertion one character past the call's end location.
      clang::SourceLocation insertion_loc =
          member_expr->getEndLoc().getLocWithOffset(1);

      clang::SourceRange replacement_range(insertion_loc, insertion_loc);
      std::string replacement_text = ".get()";
      lhs.replacement = getReplacementDirective(
          replacement_text, replacement_range, source_manager, ast_context);
      lhs.include_directive = lhs.replacement;
    } else if (const clang::CXXOperatorCallExpr* op_call_expr =
                   result.Nodes.getNodeAs<clang::CXXOperatorCallExpr>(
                       "affectedOpCall")) {
      // Same as above, for an overloaded operator call.
      clang::SourceLocation insertion_loc =
          op_call_expr->getEndLoc().getLocWithOffset(1);
      clang::SourceRange replacement_range(insertion_loc, insertion_loc);
      std::string replacement_text = ".get()";
      lhs.replacement = getReplacementDirective(
          replacement_text, replacement_range, source_manager, ast_context);
      lhs.include_directive = lhs.replacement;
    } else if (const clang::ParmVarDecl* var_decl =
                   result.Nodes.getNodeAs<clang::ParmVarDecl>(
                       "lambda_parmVarDecl")) {
      // NOTE(review): "template_type_param_loc" and "md" are dereferenced
      // without null checks; presumably always bound together with
      // "lambda_parmVarDecl" -- confirm against the matcher.
      auto* type_loc =
          result.Nodes.getNodeAs<clang::TypeLoc>("template_type_param_loc");

      auto* md = result.Nodes.getNodeAs<clang::CXXMethodDecl>("md");

      clang::SourceRange replacement_range(var_decl->getBeginLoc(),
                                           type_loc->getEndLoc());

      // The new text is the pointee type of the parameter of `md` located at
      // the same function-scope index as `var_decl`.
      std::string replacement_text =
          (md->getParamDecl(var_decl->getFunctionScopeIndex()))
              ->getType()
              ->getPointeeType()
              .getAsString();

      replacement_text = RemovePrefix(replacement_text);
      lhs.replacement = getReplacementDirective(
          replacement_text, replacement_range, source_manager, ast_context);
      lhs.include_directive = lhs.replacement;
    }
    // If none of the ids above is bound, lhs stays default-constructed.

    Node rhs;
    if (const clang::FieldDecl* field_decl =
            result.Nodes.getNodeAs<clang::FieldDecl>("rhs_field")) {
      rhs.is_field = true;
    }

    if (auto* type_loc = result.Nodes.getNodeAs<clang::PointerTypeLoc>(
            "rhs_argPointerLoc")) {
      // The rhs is a pointer template argument that gets rewritten.
      std::string replacement_text =
          GenerateNewType(ast_context, type_loc->getType());

      const clang::TemplateSpecializationTypeLoc* tst_loc =
          result.Nodes.getNodeAs<clang::TemplateSpecializationTypeLoc>(
              "rhs_tst_loc");

      auto p = GetReplacementAndIncludeDirectives(
          type_loc, tst_loc, replacement_text, source_manager);

      rhs.replacement = p.first;
      rhs.include_directive = p.second;
    } else if (const clang::TypeLoc* auto_loc =
                   result.Nodes.getNodeAs<clang::TypeLoc>("rhs_auto_loc")) {
      // The rhs is an `auto` typed declaration: it is only a graph node, with
      // the literal "replacement_text" as placeholder text.
      rhs.replacement = GenerateReplacementForAutoLoc(
          auto_loc, "replacement_text", source_manager, ast_context);
      rhs.include_directive = rhs.replacement;
      rhs.has_auto_type = true;
    } else {
      // Should not get here.
      // When assertions are compiled out, execution continues with a
      // default-constructed rhs.
      assert(false);
    }
    output_helper_.AddEdge(lhs, rhs);
  }

 private:
  // Formats a replacement directive
  // `r:::<file path>:::<offset>:::<length>:::<replacement text>` for
  // `replacement_range`, interpreted as a character range.
  // Note: `replacement_text` is taken by non-const reference and modified in
  // place -- every '\n' in it is replaced with '\0' before formatting.
  // `ast_context` is not used.
  std::string getReplacementDirective(
      std::string& replacement_text,
      clang::SourceRange replacement_range,
      const clang::SourceManager& source_manager,
      const clang::ASTContext& ast_context) {
    clang::tooling::Replacement replacement(
        source_manager, clang::CharSourceRange::getCharRange(replacement_range),
        replacement_text);
    llvm::StringRef file_path = replacement.getFilePath();

    std::replace(replacement_text.begin(), replacement_text.end(), '\n', '\0');
    return llvm::formatv("r:::{0}:::{1}:::{2}:::{3}", file_path,
                         replacement.getOffset(), replacement.getLength(),
                         replacement_text);
  }

  // Sink for the generated edges; not owned.
  OutputHelper& output_helper_;
};

// Visitor that extracts the clang::Expr bound with node id "expr" from a set
// of matches. When several matches are visited, the value from the last one
// wins. `expr_` is nullptr when nothing was visited or when the last visited
// match has no "expr" binding.
class ExprVisitor
    : public clang::ast_matchers::internal::BoundNodesTreeBuilder::Visitor {
 public:
  void visitMatch(
      const clang::ast_matchers::BoundNodes& BoundNodesView) override {
    expr_ = BoundNodesView.getNodeAs<clang::Expr>("expr");
  }
  // Initialized so that reading it is well defined even if visitMatch() is
  // never invoked.
  const clang::Expr* expr_ = nullptr;
};

// Runs an ExprVisitor over `matches` and returns the expression it recorded
// under the node id "expr".
const clang::Expr* getExpr(
    clang::ast_matchers::internal::BoundNodesTreeBuilder& matches) {
  ExprVisitor expr_collector;
  matches.visitMatches(&expr_collector);
  const clang::Expr* bound_expr = expr_collector.expr_;
  return bound_expr;
}

// Matches `InnerMatcher` against the expression obtained by recursively
// unwrapping `Node`. At each step the wrapper forms below are tried in order;
// the first one that matches yields the sub-expression bound to "expr", and
// the unwrapping continues on that sub-expression. When no wrapper matches,
// `InnerMatcher` is applied to the current node.
//
// This allows handling cases like the following:
// std::map<int, std::vector<S*>> member;
// std::vector<S*>::iterator it = member.begin()->second;
AST_MATCHER_P(clang::Expr,
              expr_variations,
              clang::ast_matchers::internal::Matcher<clang::Expr>,
              InnerMatcher) {
  // Builds a matcher for a call to the overloaded operator `op_name`; binds
  // the child expression that is not the reference to the operator method
  // itself.
  const auto operator_call = [](const char* op_name) {
    return cxxOperatorCallExpr(
        has(declRefExpr(to(cxxMethodDecl(hasName(op_name))))),
        has(expr(unless(declRefExpr(to(cxxMethodDecl(hasName(op_name))))))
                .bind("expr")));
  };

  // obj.begin(), obj.find(...), obj.Get(), ... => unwrap to `obj`.
  auto member_call = cxxMemberCallExpr(
      callee(functionDecl(
          anyOf(hasName("begin"), hasName("cbegin"), hasName("rbegin"),
                hasName("crbegin"), hasName("end"), hasName("cend"),
                hasName("rend"), hasName("crend"), hasName("find"),
                hasName("upper_bound"), hasName("lower_bound"),
                hasName("equal_range"), hasName("emplace"), hasName("Get")))),
      has(memberExpr(has(expr().bind("expr")))));

  // Calls to functions whose name matches "find" => unwrap to argument 0.
  auto find_call = callExpr(callee(functionDecl(matchesName("find"))),
                            hasArgument(0, expr().bind("expr")));

  // Any unary operator => unwrap to its operand.
  auto unary = unaryOperator(has(expr().bind("expr")));

  // base::Reversed(x) => unwrap to `x`.
  auto reversed = callExpr(callee(functionDecl(hasName("base::Reversed"))),
                           hasArgument(0, expr().bind("expr")));

  // x.second => unwrap to `x`.
  auto second =
      memberExpr(member(hasName("second")), has(expr().bind("expr")));

  // The order of this list is the order in which the wrappers are tried.
  const auto wrappers = {member_call,
                         find_call,
                         unary,
                         reversed,
                         operator_call("operator[]"),
                         operator_call("operator->"),
                         operator_call("operator*"),
                         second};

  clang::ast_matchers::internal::BoundNodesTreeBuilder bound;
  const clang::Expr* inner = nullptr;
  for (const auto& wrapper : wrappers) {
    if (wrapper.matches(Node, Finder, &bound)) {
      inner = getExpr(bound);
      break;
    }
  }

  if (!inner) {
    // Nothing left to unwrap: test the node itself.
    return InnerMatcher.matches(Node, Finder, Builder);
  }
  return expr_variations(InnerMatcher).matches(*inner, Finder, Builder);
}

// Visitor that extracts the node bound to the id "decl" from the matches
// stored in a BoundNodesTreeBuilder.
//
// visitMatch() overwrites `decl_` on every call, so after visiting several
// matches `decl_` holds the value from the last one visited.
class DeclVisitor
    : public clang::ast_matchers::internal::BoundNodesTreeBuilder::Visitor {
 public:
  void visitMatch(
      const clang::ast_matchers::BoundNodes& BoundNodesView) override {
    decl_ = BoundNodesView.getNodeAs<clang::TypedefNameDecl>("decl");
  }
  // Initialized to nullptr so that reading it is well-defined even when
  // visitMatch() was never invoked (i.e. the builder held no matches).
  const clang::TypedefNameDecl* decl_ = nullptr;
};
// Runs a DeclVisitor over every match stored in `matches` and returns the
// typedef-name declaration the visitor recorded (the node bound to "decl" in
// the last match visited).
const clang::TypedefNameDecl* getDecl(
    clang::ast_matchers::internal::BoundNodesTreeBuilder& matches) {
  DeclVisitor collector;
  matches.visitMatches(&collector);
  const clang::TypedefNameDecl* bound_decl = collector.decl_;
  return bound_decl;
}
// Unpacks typedefs recursively until reaching the node that should be tested
// against InnerMatcher: as long as the current typedef (outside system
// headers) has a descendant type that is itself declared by a typedef outside
// system headers, the matcher recurses into that nested typedef. Once no such
// nested typedef is found, InnerMatcher is applied to the current node.
// Example:
// using VECTOR = std::vector<S*>;
// using MAP = std::map<int, VECTOR>;
// MAP member; => this will lead to VECTOR being rewritten.
AST_MATCHER_P(clang::TypedefNameDecl,
              type_def_name_decl,
              clang::ast_matchers::internal::Matcher<clang::TypedefNameDecl>,
              InnerMatcher) {
  auto nested_typedef = typedefNameDecl(
      hasDescendant(loc(qualType(hasDeclaration(
          typedefNameDecl(unless(isExpansionInSystemHeader())).bind("decl"))))),
      unless(isExpansionInSystemHeader()));

  clang::ast_matchers::internal::BoundNodesTreeBuilder bound;
  if (!nested_typedef.matches(Node, Finder, &bound)) {
    // No nested typedef to descend into: test this node directly.
    return InnerMatcher.matches(Node, Finder, Builder);
  }

  // Continue the unpacking on the typedef that was bound to "decl".
  const clang::TypedefNameDecl* inner_decl = getDecl(bound);
  return type_def_name_decl(InnerMatcher).matches(*inner_decl, Finder, Builder);
}

// Builds all the AST matchers used by the tool and registers them on a
// MatchFinder. Each matcher is attached to one of three match callbacks owned
// by this class:
//   - `affected_ptr_expr_rewriter_`: expressions/statements that need to be
//     adjusted once a container is rewritten (e.g. `auto*` loop variables).
//   - `potentail_nodes_`: fields and the lhs/rhs relationships (assignments,
//     constructions, calls, ...) linking container-typed entities.
//   - `fct_sig_nodes_`: function signatures (parameters / return types).
class ContainerRewriter {
 public:
  // `finder` receives the matchers added by addMatchers(). `output_helper` is
  // handed to the expression rewriter and the potential-nodes callback;
  // `sig_nodes` and `sig_pairs` are handed to the function-signature callback.
  // `excluded_paths` is the filter used to detect fields located in excluded
  // paths; the pointer is stored as-is (not owned here).
  explicit ContainerRewriter(
      MatchFinder& finder,
      OutputHelper& output_helper,
      std::map<std::string, std::set<Node>>& sig_nodes,
      std::vector<std::pair<std::string, std::string>>& sig_pairs,
      const raw_ptr_plugin::FilterFile* excluded_paths)
      : match_finder_(finder),
        affected_ptr_expr_rewriter_(output_helper),
        potentail_nodes_(output_helper),
        fct_sig_nodes_(sig_nodes, sig_pairs),
        paths_to_exclude(excluded_paths) {}

  // Constructs every matcher and registers it on `match_finder_` together with
  // the callback that handles it. Throughout this function, `lhs_*` and
  // `rhs_*` matchers are structurally identical and differ only in the ids
  // they bind, so that both sides of a relationship can be bound in a single
  // match.
  void addMatchers() {
    // A type is treated as a container if it declares one of the following
    // method triples: {push_back, pop_back, size}, {insert, erase, size} or
    // {push, pop, size}.
    auto container_methods =
        anyOf(allOf(hasMethod(hasName("push_back")),
                    hasMethod(hasName("pop_back")), hasMethod(hasName("size"))),
              allOf(hasMethod(hasName("insert")), hasMethod(hasName("erase")),
                    hasMethod(hasName("size"))),
              allOf(hasMethod(hasName("push")), hasMethod(hasName("pop")),
                    hasMethod(hasName("size"))));

    // Exclude maps as they need special handling to be rewritten.
    // NOTE(review): matchesName() is a regex search on the qualified name, so
    // this presumably also excludes any type whose qualified name merely
    // contains "map" -- confirm this is intended.
    // TODO: handle rewriting maps.
    auto excluded_containers = matchesName("map");

    // Either a class template specialization that looks like a container, or
    // an alias template whose aliased type is declared by such a class
    // template.
    auto supported_containers = anyOf(
        hasDeclaration(classTemplateSpecializationDecl(
            container_methods, unless(excluded_containers))),
        hasDeclaration(typeAliasTemplateDecl(has(typeAliasDecl(
            hasType(qualType(hasDeclaration(classTemplateDecl(has(cxxRecordDecl(
                container_methods, unless(excluded_containers))))))))))));

    // A supported container whose first template argument is a supported
    // pointer type, excluding the type matched by
    // const_char_pointer_type(false).
    auto tst_type_loc = templateSpecializationTypeLoc(
        loc(qualType(supported_containers)),
        hasTemplateArgumentLoc(
            0, hasTypeLoc(loc(qualType(allOf(
                   raw_ptr_plugin::supported_pointer_type(),
                   unless(raw_ptr_plugin::const_char_pointer_type(false))))))));

    // Same as `tst_type_loc`, additionally binding the container type loc and
    // the pointer type loc of its first template argument (lhs ids).
    auto lhs_location =
        templateSpecializationTypeLoc(
            tst_type_loc,
            hasTemplateArgumentLoc(
                0, hasTypeLoc(pointerTypeLoc().bind("lhs_argPointerLoc"))))
            .bind("lhs_tst_loc");

    // Same as `lhs_location`, with rhs ids.
    auto rhs_location =
        templateSpecializationTypeLoc(
            tst_type_loc,
            hasTemplateArgumentLoc(
                0, hasTypeLoc(pointerTypeLoc().bind("rhs_argPointerLoc"))))
            .bind("rhs_tst_loc");

    // Fields whose type is base::RepeatingCallback/base::OnceCallback, either
    // directly or through a typedef.
    auto exclude_callbacks = anyOf(
        hasType(typedefNameDecl(hasType(qualType(hasDeclaration(
            recordDecl(anyOf(hasName("base::RepeatingCallback"),
                             hasName("base::OnceCallback")))))))),
        hasType(qualType(
            hasDeclaration(recordDecl(anyOf(hasName("base::RepeatingCallback"),
                                            hasName("base::OnceCallback")))))));

    auto field_exclusions =
        anyOf(isExpansionInSystemHeader(), raw_ptr_plugin::isInExternCContext(),
              raw_ptr_plugin::isInThirdPartyLocation(),
              raw_ptr_plugin::isInGeneratedLocation(),
              raw_ptr_plugin::ImplicitFieldDeclaration(), exclude_callbacks,
              // Exclude fieldDecls in macros.
              // `raw_ptr_plugin::isInMacroLocation()` is also true for fields
              // annotated with RAW_PTR_EXCLUSION. The annotated fields are not
              // included in `field_exclusions` as they are handled differently
              // by the `excluded_field_decl` matcher.
              allOf(raw_ptr_plugin::isInMacroLocation(),
                    unless(raw_ptr_plugin::isRawPtrExclusionAnnotated())));

    // Supports typedefs as well.
    auto lhs_type_loc =
        anyOf(hasDescendant(loc(qualType(hasDeclaration(typedefNameDecl(
                  type_def_name_decl(hasDescendant(lhs_location))))))),
              hasDescendant(lhs_location));

    // Supports typedefs as well.
    auto rhs_type_loc =
        anyOf(hasDescendant(loc(qualType(hasDeclaration(typedefNameDecl(
                  type_def_name_decl(hasDescendant(rhs_location))))))),
              hasDescendant(rhs_location));

    // Non-excluded fields whose explicit declaration contains a matching
    // container type loc.
    auto lhs_field =
        fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(lhs_type_loc),
                  unless(field_exclusions))
            .bind("lhs_field");
    auto rhs_field =
        fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(rhs_type_loc),
                  unless(field_exclusions))
            .bind("rhs_field");

    // Variables declared either with `auto` (binding the auto type loc) or
    // with an explicit matching container type.
    auto lhs_var = anyOf(
        varDecl(hasDescendant(loc(qualType(autoType())).bind("lhs_auto_loc"))),
        varDecl(lhs_type_loc).bind("lhs_var"));

    auto rhs_var = anyOf(
        varDecl(hasDescendant(loc(qualType(autoType())).bind("rhs_auto_loc"))),
        varDecl(rhs_type_loc).bind("rhs_var"));

    auto lhs_param =
        parmVarDecl(raw_ptr_plugin::hasExplicitParmVarDecl(lhs_type_loc))
            .bind("lhs_param");

    auto rhs_param =
        parmVarDecl(raw_ptr_plugin::hasExplicitParmVarDecl(rhs_type_loc))
            .bind("rhs_param");

    // Calls to functions whose return type loc contains a matching container.
    auto rhs_call_expr =
        callExpr(callee(functionDecl(hasReturnTypeLoc(rhs_type_loc))));

    auto lhs_call_expr =
        callExpr(callee(functionDecl(hasReturnTypeLoc(lhs_type_loc))));

    // A reference to a var/param, a member access to a field, or a call
    // returning a container (implicit casts ignored).
    auto lhs_expr = expr(
        ignoringImpCasts(anyOf(declRefExpr(to(anyOf(lhs_var, lhs_param))),
                               memberExpr(member(lhs_field)), lhs_call_expr)));

    auto rhs_expr = expr(
        ignoringImpCasts(anyOf(declRefExpr(to(anyOf(rhs_var, rhs_param))),
                               memberExpr(member(rhs_field)), rhs_call_expr)));

    // To make sure we add all field decls to the graph (specifically those
    // not connected to other nodes).
    auto field_decl =
        fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(lhs_type_loc),
                  unless(anyOf(field_exclusions,
                               raw_ptr_plugin::isRawPtrExclusionAnnotated())))
            .bind("field_decl");
    match_finder_.addMatcher(field_decl, &potentail_nodes_);

    // Fields annotated with RAW_PTR_EXCLUSION (as well as fields in excluded
    // paths) cannot be filtered using field exclusions. They need to appear in
    // the graph so that we can properly propagate the exclusion to reachable
    // nodes. For this reason, and in order to capture this information,
    // RAW_PTR_EXCLUSION fields are added as single nodes to the list and then
    // used as a starting point to propagate the exclusion before running dfs on
    // the graph.
    auto excluded_field_decl =
        fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(lhs_type_loc),
                  anyOf(raw_ptr_plugin::isRawPtrExclusionAnnotated(),
                        isInLocationListedInFilterFile(paths_to_exclude)))
            .bind("excluded_field_decl");
    match_finder_.addMatcher(excluded_field_decl, &potentail_nodes_);

    // std::move(x), std::ref(x) and std::cref(x) are all treated the same
    // way: as a wrapper around their first argument.
    auto ref_cref_move =
        anyOf(hasName("std::move"), hasName("std::ref"), hasName("std::cref"));
    auto rhs_move_call =
        callExpr(callee(functionDecl(ref_cref_move)), hasArgument(0, rhs_expr));

    // This is needed for ternary cond operator true_expr. (cond) ? true_expr :
    // false_expr;
    auto lhs_move_call =
        callExpr(callee(functionDecl(ref_cref_move)), hasArgument(0, lhs_expr));

    // Explicitly constructed temporaries, e.g. std::vector<S*>().
    auto rhs_cxx_temp_expr = cxxTemporaryObjectExpr(rhs_type_loc);

    auto lhs_cxx_temp_expr = cxxTemporaryObjectExpr(lhs_type_loc);

    // This represents the forms under which an expr could appear on the right
    // hand side of an assignment operation, var construction, or an expr passed
    // as callExpr argument. Examples: rhs_expr, &rhs_expr, *rhs_expr,
    // fct_call(),*fct_call(), &fct_call(), std::move(), .begin();
    auto rhs_expr_variations =
        expr_variations(anyOf(rhs_expr, rhs_move_call, rhs_cxx_temp_expr));

    auto lhs_expr_variations =
        expr_variations(anyOf(lhs_expr, lhs_move_call, lhs_cxx_temp_expr));

    // rewrite affected expressions
    {
      // This is needed to handle container-like types that implement a begin()
      // method. Range-based for loops over such types also need to be
      // rewritten.
      auto ctn_like_type =
          expr(hasType(cxxRecordDecl(has(functionDecl(
                   hasName("begin"), hasReturnTypeLoc(rhs_type_loc))))),
               unless(isExpansionInSystemHeader()));

      auto reversed_expr =
          callExpr(callee(functionDecl(hasName("base::Reversed"))),
                   hasArgument(0, rhs_expr_variations));

      // handles statements of the form: for (auto* i : member/var/param/fct())
      // that should be modified after rewriting the container.
      auto auto_star_in_range_stmt = traverse(
          clang::TK_IgnoreUnlessSpelledInSource,
          cxxForRangeStmt(
              has(varDecl(hasDescendant(loc(qualType(pointsTo(autoType())))
                                            .bind("autoLoc")))
                      .bind("autoVarDecl")),
              has(expr(
                  anyOf(rhs_expr_variations, reversed_expr, ctn_like_type)))));
      match_finder_.addMatcher(auto_star_in_range_stmt,
                               &affected_ptr_expr_rewriter_);

      // handles expressions of the form: auto* var = member.front();
      // This becomes: auto* var = member.front().get();
      auto affected_expr = traverse(
          clang::TK_IgnoreUnlessSpelledInSource,
          declStmt(has(varDecl(
              hasType(pointsTo(autoType())),
              has(cxxMemberCallExpr(
                      callee(
                          functionDecl(anyOf(hasName("front"), hasName("back"),
                                             hasName("at"), hasName("top")))),
                      has(memberExpr(has(expr(rhs_expr_variations)))))
                      .bind("affectedMemberExpr"))))));
      match_finder_.addMatcher(affected_expr, &affected_ptr_expr_rewriter_);

      // handles expressions of the form: auto* var = member[0];
      // This becomes: auto* var = member[0].get();
      auto affected_op_call =
          traverse(clang::TK_IgnoreUnlessSpelledInSource,
                   declStmt(has(varDecl(
                       hasType(pointsTo(autoType())),
                       has(cxxOperatorCallExpr(has(expr(rhs_expr_variations)))
                               .bind("affectedOpCall"))))));
      match_finder_.addMatcher(affected_op_call, &affected_ptr_expr_rewriter_);

      // handles expressions of the form:
      // base::ranges::any_of(view->children(), [](const auto* v) {
      //     ...
      //   });
      // where auto* needs to be rewritten into type_name*.
      auto range_exprs = callExpr(
          callee(functionDecl(anyOf(
              matchesName("find"), matchesName("any_of"), matchesName("all_of"),
              matchesName("transform"), matchesName("copy"),
              matchesName("accumulate"), matchesName("count")))),
          hasArgument(0, traverse(clang::TK_IgnoreUnlessSpelledInSource,
                                  expr(anyOf(rhs_expr_variations, reversed_expr,
                                             ctn_like_type)))),
          hasAnyArgument(expr(allOf(
              traverse(
                  clang::TK_IgnoreUnlessSpelledInSource,
                  lambdaExpr(
                      has(parmVarDecl(
                              hasTypeLoc(loc(qualType(anything()))
                                             .bind("template_type_param_loc")),
                              hasType(pointsTo(templateTypeParmType())))
                              .bind("lambda_parmVarDecl")))
                      .bind("lambda_expr")),
              lambdaExpr(has(cxxRecordDecl(has(functionTemplateDecl(has(
                  cxxMethodDecl(isTemplateInstantiation()).bind("md")))))))))));
      match_finder_.addMatcher(range_exprs, &affected_ptr_expr_rewriter_);
    }

    // needed for ternary operator expr: (cond) ? true_expr : false_expr;
    // true_expr => lhs; false_expr => rhs;
    // creates a link between false_expr and true_expr of a ternary conditional
    // operator;
    // handles:
    // (cond) ? (a/&a/*a/std::move(a)/fct()/*fct()/&fct()/a.begin()) :
    // (b/&b/*b/std::move(b)/fct()/*fct()/&fct()/b.begin())
    auto ternary_cond_expr =
        traverse(clang::TK_IgnoreUnlessSpelledInSource,
                 conditionalOperator(hasTrueExpression(lhs_expr_variations),
                                     hasFalseExpression(rhs_expr_variations),
                                     unless(isExpansionInSystemHeader())));
    match_finder_.addMatcher(ternary_cond_expr, &potentail_nodes_);

    // Handles assignment:
    // a = b;
    // a = &b;
    // *a = b;
    // *a = *b;
    // a = fct();
    // a = *fct();
    // a = std::move(b);
    // a = &fct();
    // it = member.begin();
    // a = vector<S*>();
    // a = (cond) ? expr1 : expr2;
    // For the ternary form only the true expression is linked here; the link
    // between the true and false expressions comes from `ternary_cond_expr`.
    auto assignement_relationship = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        binaryOperation(hasOperatorName("="),
                        hasOperands(lhs_expr_variations,
                                    anyOf(rhs_expr_variations,
                                          conditionalOperator(hasTrueExpression(
                                              rhs_expr_variations)))),
                        unless(isExpansionInSystemHeader())));
    match_finder_.addMatcher(assignement_relationship, &potentail_nodes_);

    // Supports:
    // std::vector<T*>* temp = &member;
    // std::vector<T*>& temp = member;
    // std::vector<T*> temp = *member;
    // std::vector<T*> temp = member;  and other similar stmts.
    // std::vector<T*> temp = init();
    // std::vector<T*> temp = *fct();
    // std::vector<T*>::iterator it = member.begin();
    // std::vector<T*> temp = (cond) ? expr1 : expr2;
    auto var_construction = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        varDecl(
            lhs_var,
            has(expr(anyOf(
                rhs_expr_variations,
                conditionalOperator(hasTrueExpression(rhs_expr_variations)),
                cxxConstructExpr(has(expr(anyOf(
                    rhs_expr_variations, conditionalOperator(hasTrueExpression(
                                             rhs_expr_variations))))))))),
            unless(isExpansionInSystemHeader())));
    match_finder_.addMatcher(var_construction, &potentail_nodes_);

    // Supports:
    // return member;
    // return *member;
    // return &member;
    // return fct();
    // return *fct();
    // return std::move(member);
    // return member.begin();
    // return (cond) ? expr1 : expr2;
    // The lhs side is the return type of the enclosing function.
    auto returned_var_or_member = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        returnStmt(
            hasReturnValue(expr(anyOf(
                rhs_expr_variations,
                conditionalOperator(hasTrueExpression(rhs_expr_variations))))),
            unless(isExpansionInSystemHeader()),
            forFunction(functionDecl(hasReturnTypeLoc(lhs_type_loc))
                            .bind("lhs_fct_return")))
            .bind("lhs_stmt"));
    match_finder_.addMatcher(returned_var_or_member, &potentail_nodes_);

    // Handles expressions of the form member(arg).
    // A(const std::vector<T*>& arg): member(arg){}
    // member(init());
    // member(*fct());
    // member2(&member1);
    auto ctor_initilizer = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        cxxCtorInitializer(withInitializer(anyOf(
                               cxxConstructExpr(has(expr(rhs_expr_variations))),
                               rhs_expr_variations)),
                           forField(lhs_field)));

    match_finder_.addMatcher(ctor_initilizer, &potentail_nodes_);

    // link var/field passed as function arguments to function parameter
    // This handles func(var/member/param), func(&var/member/param),
    // func(*var/member/param), func(func2()), func(&func2())
    // cxxOpCallExprs excluded here since operator= can be invoked as a call
    // expr for classes/structs.
    auto call_expr = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        callExpr(forEachArgumentWithParam(
                     expr(anyOf(rhs_expr_variations,
                                conditionalOperator(
                                    hasTrueExpression(rhs_expr_variations)))),
                     lhs_param),
                 unless(isExpansionInSystemHeader()),
                 unless(cxxOperatorCallExpr(hasOperatorName("=")))));
    match_finder_.addMatcher(call_expr, &potentail_nodes_);

    // Handles: member.swap(temp); temp.swap(member);
    auto member_swap_call = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        cxxMemberCallExpr(callee(functionDecl(hasName("swap"))),
                          hasArgument(0, rhs_expr_variations),
                          has(memberExpr(has(expr(lhs_expr_variations)))),
                          unless(isExpansionInSystemHeader())));
    match_finder_.addMatcher(member_swap_call, &potentail_nodes_);

    // Handles: std::swap(member, temp); std::swap(temp, member);
    auto std_swap_call =
        traverse(clang::TK_IgnoreUnlessSpelledInSource,
                 callExpr(callee(functionDecl(hasName("std::swap"))),
                          hasArgument(0, lhs_expr_variations),
                          hasArgument(1, rhs_expr_variations),
                          unless(isExpansionInSystemHeader())));
    match_finder_.addMatcher(std_swap_call, &potentail_nodes_);

    // Handles calls expanded from EXPECT_EQ/ASSERT_EQ: links the expressions
    // found at argument indices 2 and 3 of the expanded call (presumably the
    // two compared values -- confirm against the gtest macro expansion).
    auto assert_expect_eq =
        traverse(clang::TK_IgnoreUnlessSpelledInSource,
                 callExpr(anyOf(isExpandedFromMacro("EXPECT_EQ"),
                                isExpandedFromMacro("ASSERT_EQ")),
                          hasArgument(2, lhs_expr_variations),
                          hasArgument(3, rhs_expr_variations),
                          unless(isExpansionInSystemHeader())));
    match_finder_.addMatcher(assert_expect_eq, &potentail_nodes_);

    // Supports:
    // std::vector<S*> temp;
    // Obj o(temp); Obj o{temp};
    // This links temp to the parameter in Obj's constructor.
    auto var_passed_in_constructor = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        cxxConstructExpr(forEachArgumentWithParam(
            expr(anyOf(
                rhs_expr_variations,
                conditionalOperator(hasTrueExpression(rhs_expr_variations)))),
            lhs_param)));
    match_finder_.addMatcher(var_passed_in_constructor, &potentail_nodes_);

    // handles Obj o{temp} when Obj has no constructor.
    // This creates a link between the expr and the underlying field.
    auto var_passed_in_initlistExpr = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        initListExpr(raw_ptr_plugin::forEachInitExprWithFieldDecl(
            expr(anyOf(
                rhs_expr_variations,
                conditionalOperator(hasTrueExpression(rhs_expr_variations)))),
            lhs_field)));
    match_finder_.addMatcher(var_passed_in_initlistExpr, &potentail_nodes_);

    // This creates a link between each argument passed to the make_unique call
    // and the corresponding constructor parameter.
    auto make_unique_call = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        callExpr(
            callee(functionDecl(hasName("std::make_unique"))),
            forEachArg(
                expr(rhs_expr_variations,
                     hasParent(callExpr(
                         callee(functionDecl(hasDescendant(cxxConstructExpr(
                             has(cxxNewExpr(has(cxxConstructExpr(hasDeclaration(
                                 functionDecl().bind("fct_decl"))))))))))))),
                lhs_param)));
    match_finder_.addMatcher(make_unique_call, &potentail_nodes_);

    // This creates a link between lhs and an argument passed to emplace.
    // Example:
    // std::map<int, std::vector<S*>> m;
    // m.emplace(index, o.member);
    // where member has type std::vector<S*>;
    auto emplace_call_with_arg =
        traverse(clang::TK_IgnoreUnlessSpelledInSource,
                 cxxMemberCallExpr(callee(functionDecl(hasName("emplace"))),
                                   has(memberExpr(has(rhs_expr_variations))),
                                   hasAnyArgument(lhs_expr_variations)));
    match_finder_.addMatcher(emplace_call_with_arg, &potentail_nodes_);

    // Handle BindOnce/BindRepeating;
    // `first_arg` binds the bound callable (argument 0 of the enclosing call):
    // either a lambda, or a unary operator applied to a function reference
    // (e.g. &Function).
    auto first_arg = hasParent(callExpr(hasArgument(
        0, anyOf(lambdaExpr().bind("lambda_expr"),
                 unaryOperator(
                     has(declRefExpr(to(functionDecl().bind("fct_decl")))))))));

    auto bind_args = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        callExpr(
            callee(functionDecl(
                anyOf(hasName("BindOnce"), hasName("BindRepeating")))),
            forEachBindArg(expr(rhs_expr_variations, first_arg), lhs_param)));
    match_finder_.addMatcher(bind_args, &potentail_nodes_);

    // This is useful to handle iteration over maps with vector as value.
    // Example:
    // std::map<int, std::vector<S*>> m;
    // for (auto& p : m){
    // ...
    // }
    // This creates a link between p and m.
    auto for_range_stmts = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        cxxForRangeStmt(
            hasLoopVariable(decl(lhs_var, unless(has(pointerTypeLoc())))),
            hasRangeInit(rhs_expr_variations)));
    match_finder_.addMatcher(for_range_stmts, &potentail_nodes_);

    // Map function declaration signature to function definition signature;
    // This is problematic in the case of callbacks defined in function.
    auto fct_decls_params = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        functionDecl(forEachParmVarDecl(rhs_param),
                     unless(anyOf(isExpansionInSystemHeader(),
                                  raw_ptr_plugin::isInMacroLocation())))
            .bind("fct_decl"));
    match_finder_.addMatcher(fct_decls_params, &fct_sig_nodes_);

    // Same as above, for functions whose return type is a matching container.
    auto fct_decls_returns = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        functionDecl(hasReturnTypeLoc(rhs_type_loc),
                     unless(anyOf(isExpansionInSystemHeader(),
                                  raw_ptr_plugin::isInMacroLocation())))
            .bind("fct_decl"));
    match_finder_.addMatcher(fct_decls_returns, &fct_sig_nodes_);

    // Functions in macro locations are skipped by the two matchers above.
    // Methods expanded from the gmock MOCK_METHOD* / MOCK_CONST_METHOD* macros
    // are handled here instead, by matching container type locs that have
    // such a method as an ancestor.
    auto macro_fct_signatures = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        templateSpecializationTypeLoc(
            rhs_location,
            hasAncestor(
                cxxMethodDecl(raw_ptr_plugin::isInMacroLocation(),
                              anyOf(isExpandedFromMacro("MOCK_METHOD"),
                                    isExpandedFromMacro("MOCK_METHOD0"),
                                    isExpandedFromMacro("MOCK_METHOD1"),
                                    isExpandedFromMacro("MOCK_METHOD2"),
                                    isExpandedFromMacro("MOCK_METHOD3"),
                                    isExpandedFromMacro("MOCK_METHOD4"),
                                    isExpandedFromMacro("MOCK_METHOD5"),
                                    isExpandedFromMacro("MOCK_METHOD6"),
                                    isExpandedFromMacro("MOCK_METHOD7"),
                                    isExpandedFromMacro("MOCK_METHOD8"),
                                    isExpandedFromMacro("MOCK_METHOD9"),
                                    isExpandedFromMacro("MOCK_METHOD10"),
                                    isExpandedFromMacro("MOCK_CONST_METHOD0"),
                                    isExpandedFromMacro("MOCK_CONST_METHOD1"),
                                    isExpandedFromMacro("MOCK_CONST_METHOD2"),
                                    isExpandedFromMacro("MOCK_CONST_METHOD3"),
                                    isExpandedFromMacro("MOCK_CONST_METHOD4"),
                                    isExpandedFromMacro("MOCK_CONST_METHOD5"),
                                    isExpandedFromMacro("MOCK_CONST_METHOD6"),
                                    isExpandedFromMacro("MOCK_CONST_METHOD7"),
                                    isExpandedFromMacro("MOCK_CONST_METHOD8"),
                                    isExpandedFromMacro("MOCK_CONST_METHOD9"),
                                    isExpandedFromMacro("MOCK_CONST_METHOD10")),
                              unless(isExpansionInSystemHeader()))
                    .bind("fct_decl"))));
    match_finder_.addMatcher(macro_fct_signatures, &fct_sig_nodes_);

    // TODO: handle calls to templated functions
  }

 private:
  // Finder on which all matchers are registered.
  MatchFinder& match_finder_;
  // Callback for expressions that must be adjusted after a rewrite.
  AffectedPtrExprRewriter affected_ptr_expr_rewriter_;
  // Callback for fields and lhs/rhs relationships.
  PotentialNodes potentail_nodes_;
  // Callback for function signatures.
  FunctionSignatureNodes fct_sig_nodes_;
  // Filter describing the paths to exclude; not owned.
  const raw_ptr_plugin::FilterFile* paths_to_exclude;
};

}  // namespace

// Entry point of the tool.
//
// Parses the command line, builds the list of excluded paths (either the
// built-in lists or the override file given on the command line), registers
// all matchers, runs the tool over the requested sources, links the nodes of
// related function signatures, and finally emits the collected output.
//
// Returns 1 if the command line cannot be parsed; otherwise returns the result
// of ClangTool::run().
int main(int argc, const char* argv[]) {
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmParser();
  llvm::cl::OptionCategory category(
      "rewrite_templated_container_fields: changes |vector<T*> field_| to "
      "|vector<raw_ptr<T>> field_|.");

  llvm::cl::opt<std::string> override_exclude_paths_param(
      kOverrideExcludePathsParamName, llvm::cl::value_desc("filepath"),
      llvm::cl::desc(
          "override file listing paths to be blocked (not rewritten)"));
  llvm::Expected<clang::tooling::CommonOptionsParser> options =
      clang::tooling::CommonOptionsParser::create(argc, argv, category);
  if (!options) {
    // An assert would be compiled out in NDEBUG builds, after which `options`
    // would be dereferenced while holding an error. Report and bail out
    // instead.
    llvm::errs() << llvm::toString(options.takeError()) << "\n";
    return 1;
  }
  clang::tooling::ClangTool tool(options->getCompilations(),
                                 options->getSourcePathList());

  std::unique_ptr<raw_ptr_plugin::FilterFile> paths_to_exclude;
  if (override_exclude_paths_param.getValue().empty()) {
    // No override given: use the built-in lists.
    std::vector<std::string> paths_to_exclude_lines;
    for (auto* const line : kRawPtrManualPathsToIgnore) {
      paths_to_exclude_lines.push_back(line);
    }
    for (auto* const line : kSeparateRepositoryPaths) {
      paths_to_exclude_lines.push_back(line);
    }
    paths_to_exclude =
        std::make_unique<raw_ptr_plugin::FilterFile>(paths_to_exclude_lines);
  } else {
    paths_to_exclude = std::make_unique<raw_ptr_plugin::FilterFile>(
        override_exclude_paths_param,
        override_exclude_paths_param.ArgStr.str());
  }

  // Map a function signature, which is modeled as a string representing file
  // location, to its graph nodes (RTNode and ParmVarDecl nodes).
  // RTNode represents a function return type.
  std::map<std::string, std::set<Node>> fct_sig_nodes;
  // Map related function signatures to each other, this is needed for functions
  // with separate definition and declaration, and for overridden functions.
  std::vector<std::pair<std::string, std::string>> fct_sig_pairs;
  OutputHelper output_helper;
  MatchFinder match_finder;
  ContainerRewriter rewriter(match_finder, output_helper, fct_sig_nodes,
                             fct_sig_pairs, paths_to_exclude.get());
  rewriter.addMatchers();

  // Prepare and run the tool.
  std::unique_ptr<clang::tooling::FrontendActionFactory> factory =
      clang::tooling::newFrontendActionFactory(&match_finder);
  int result = tool.run(factory.get());

  // For each pair of adjacent function signatures, create a link between
  // corresponding parameters.
  // 2 functions are said to be adjacent if one overrides the other, or if one
  // is a function definition and the other is that function's declaration.
  for (const auto& [l, r] : fct_sig_pairs) {
    const auto lhs_entry = fct_sig_nodes.find(l);
    const auto rhs_entry = fct_sig_nodes.find(r);
    // Skip pairs for which either signature has no recorded nodes.
    if (lhs_entry == fct_sig_nodes.end() || rhs_entry == fct_sig_nodes.end()) {
      continue;
    }
    const std::set<Node>& s1 = lhs_entry->second;
    const std::set<Node>& s2 = rhs_entry->second;
    assert(s1.size() == s2.size());
    // Both iterators are bounds-checked so that a size mismatch cannot walk
    // past the end of the shorter set when asserts are compiled out.
    for (auto i1 = s1.begin(), i2 = s2.begin();
         i1 != s1.end() && i2 != s2.end(); ++i1, ++i2) {
      output_helper.AddEdge(*i1, *i2);
    }
  }

  // Emits the list of edges.
  output_helper.Emit();

  return result;
}