chromium/tools/clang/spanify/Spanifier.cpp

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <assert.h>

#include <map>
#include <set>
#include <string>
#include <vector>

#include "RawPtrHelpers.h"
#include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Refactoring.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/TargetSelect.h"

using namespace clang::ast_matchers;

namespace {

// Default include path attached to a replacement when the caller of
// GetReplacementAndIncludeDirectives() does not pass an explicit include path
// (this is the case for the rewrites in this file that produce
// base::span<...>).
const char kBaseSpanIncludePath[] = "base/containers/span.h";

// Include path that needs to be added to all the files where
// base::raw_span<...> replaces a raw_ptr<...>.
const char kBaseRawSpanIncludePath[] = "base/memory/raw_span.h";

// Matches a function declaration when at least one of its parameters matches
// `parm_var_decl_matcher`. The bindings produced by every matching parameter
// are accumulated and handed back through `Builder`; parameters that do not
// match contribute nothing.
AST_MATCHER_P(clang::FunctionDecl,
              forEachParmVarDecl,
              clang::ast_matchers::internal::Matcher<clang::ParmVarDecl>,
              parm_var_decl_matcher) {
  using clang::ast_matchers::internal::BoundNodesTreeBuilder;

  BoundNodesTreeBuilder accumulated_matches;
  bool any_param_matched = false;
  for (const clang::ParmVarDecl* param : Node.parameters()) {
    // Each parameter is matched against a fresh builder so that bindings of
    // a failed match never leak into the accumulated result.
    BoundNodesTreeBuilder bindings_for_param;
    if (!parm_var_decl_matcher.matches(*param, Finder, &bindings_for_param)) {
      continue;
    }
    accumulated_matches.addMatch(bindings_for_param);
    any_param_matched = true;
  }

  // `Builder` is overwritten unconditionally, even when nothing matched.
  *Builder = std::move(accumulated_matches);
  return any_param_matched;
}

// A vertex of the graph emitted by this tool. A node is identified by its
// `replacement` directive: equality and ordering only look at that field.
struct Node {
  // Set on nodes that are emitted for buffer usage expressions.
  bool is_buffer = false;

  // The replacement directive. Format:
  // `r:::<file path>:::<offset>:::<length>:::<replacement text>`
  std::string replacement;

  // The include directive that goes with the replacement. Format:
  // `include-user-header:::<file path>:::-1:::-1:::<include text>`
  std::string include_directive;

  // True when the size of the pointed-to memory is known at this node:
  //  - nullptr => size is zero
  //  - calls to new/new[n] => size is 1/n
  //  - constant arrays buf[1024] => size is 1024
  //  - calls to third_party functions that we can't rewrite (they should
  //    provide a size for the pointer returned)
  bool size_info_available = false;

  // True for dereference expressions.
  // Example: *buf, *fct(), *(buf++), ...
  bool is_deref_expr = false;

  // True for the special node that appends a `.data()` call to the rhs when
  // the rhs gets rewritten but the lhs does not. Example:
  //   ptr[index] = something;  // ptr is used as a buffer => gets spanified.
  //   T* temp = ptr;           // temp is never used as a buffer.
  // The second statement then has to become: T* temp = ptr.data();
  bool is_data_change = false;

  // Two nodes are the same node iff their replacement directives are equal.
  bool operator==(const Node& other) const {
    return replacement.compare(other.replacement) == 0;
  }

  // Orders nodes lexicographically by replacement directive.
  bool operator<(const Node& other) const {
    return replacement.compare(other.replacement) < 0;
  }

  // Serializes the node. The resulting string has the following format:
  // {is_buffer\,r:::<filepath>:::<offset>:::<length>:::<replacement_text>
  // \,include-user-header:::<file path>:::-1:::-1:::<include
  // text>\,size_info_available\,is_deref_expr\,is_data_change}
  // where the booleans are represented as 0 or 1.
  std::string ToString() const {
    std::string serialized = llvm::formatv(
        "{{{0:d}\\,{1}\\,{2}\\,{3:d}\\,{4:d}\\,{5:d}}", is_buffer, replacement,
        include_directive, size_info_available, is_deref_expr, is_data_change);
    return serialized;
  }
};

// Collects the serialized graph (edges and standalone nodes) produced while
// matching, de-duplicates identical lines, and prints them at the end.
class OutputHelper {
 public:
  OutputHelper() = default;

  // Records an edge between `lhs` and `rhs` as the line `{lhs};{rhs}\n`.
  void AddEdge(const Node& lhs, const Node& rhs) {
    std::string line =
        llvm::formatv("{0};{1}\n", lhs.ToString(), rhs.ToString());
    node_pairs_.insert(std::move(line));
  }

  // Records a standalone node as the line `{lhs}\n`.
  void AddSingleNode(const Node& lhs) {
    std::string line = llvm::formatv("{0}\n", lhs.ToString());
    node_pairs_.insert(std::move(line));
  }

  // Writes every recorded line to llvm::outs(), in the set's sorted order.
  void Emit() {
    for (const std::string& line : node_pairs_) {
      llvm::outs() << line;
    }
  }

 private:
  // One entry per emitted line. An edge between two adjacent nodes is stored
  // as `{lhs};{rhs}\n`, where lhs and rhs are generated using
  // Node::ToString(). A node added on its own (buffer expressions) is stored
  // as `{lhs}\n`. Being a set, identical lines are only kept once.
  std::set<std::string> node_pairs_;
};

// Builds the pair {replacement directive, include directive} describing the
// rewrite of `replacement_range` into `replacement_text`.
//
// - The replacement directive has the format
//   `r:::<file path>:::<offset>:::<length>:::<replacement text>`, with every
//   newline of the replacement text turned into a NUL character.
// - The include directive has the format
//   `include-{user,system}-header:::<file path>:::-1:::-1:::<include path>`.
//
// When `include_path` is null, kBaseSpanIncludePath is used as a user header
// and `is_system_include_path` is ignored. Returns a pair of empty strings
// when no file path can be determined for the range.
static std::pair<std::string, std::string> GetReplacementAndIncludeDirectives(
    const clang::SourceRange replacement_range,
    std::string replacement_text,
    const clang::SourceManager& source_manager,
    const char* include_path = nullptr,
    bool is_system_include_path = false) {
  const clang::tooling::Replacement edit(
      source_manager, clang::CharSourceRange::getCharRange(replacement_range),
      replacement_text);
  const llvm::StringRef path = edit.getFilePath();
  if (path.empty()) {
    return {std::string(), std::string()};
  }

  // The directive is emitted on a single output line, so newlines inside the
  // replacement text are encoded as NUL characters.
  for (char& c : replacement_text) {
    if (c == '\n') {
      c = '\0';
    }
  }
  std::string replacement_directive =
      llvm::formatv("r:::{0}:::{1}:::{2}:::{3}", path, edit.getOffset(),
                    edit.getLength(), replacement_text);

  // Without an explicit include path, fall back to the base::span header,
  // which is always a user header.
  const bool use_system_header =
      include_path != nullptr && is_system_include_path;
  const char* effective_include_path =
      include_path != nullptr ? include_path : kBaseSpanIncludePath;

  std::string include_directive;
  if (use_system_header) {
    include_directive =
        llvm::formatv("include-system-header:::{0}:::-1:::-1:::{1}", path,
                      effective_include_path);
  } else {
    include_directive =
        llvm::formatv("include-user-header:::{0}:::-1:::-1:::{1}", path,
                      effective_include_path);
  }

  return {replacement_directive, include_directive};
}

// Clang doesn't seem to be providing correct begin/end locations for
// clang::MemberExpr and clang::DeclRefExpr, so those two kinds are handled
// explicitly by extending the range by the length of the referenced name:
//  - MemberExpr: [member location, member location + member name length)
//  - DeclRefExpr: [begin, end + name length)
// Any other expression yields [begin, end + 1).
clang::SourceRange getExprRange(const clang::Expr* expr) {
  if (const auto* decl_ref = clang::dyn_cast<clang::DeclRefExpr>(expr)) {
    const std::string referenced_name =
        decl_ref->getNameInfo().getName().getAsString();
    return clang::SourceRange(
        decl_ref->getBeginLoc(),
        decl_ref->getEndLoc().getLocWithOffset(referenced_name.size()));
  }

  if (const auto* member = clang::dyn_cast<clang::MemberExpr>(expr)) {
    const clang::SourceLocation name_begin = member->getMemberLoc();
    const size_t name_length = member->getMemberDecl()->getName().size();
    return clang::SourceRange(name_begin,
                              name_begin.getLocWithOffset(name_length));
  }

  return clang::SourceRange(expr->getBeginLoc(),
                            expr->getEndLoc().getLocWithOffset(1));
}

// Generates the spelling of the base::span type that replaces the raw pointer
// type `pointer_type`. Top-level const/volatile qualifiers of the pointer are
// emitted in front of the span, and the pointee type is printed with a
// printing policy that suppresses scopes and prints canonical types.
// Example shape of the result: `const base::span<T>`.
std::string GenerateSpanType(const clang::ASTContext& ast_context,
                             const clang::QualType& pointer_type) {
  // Qualifiers applying to the pointer itself are carried over to the span.
  std::string qualifiers;
  if (pointer_type.isConstQualified()) {
    qualifiers += "const ";
  }
  if (pointer_type.isVolatileQualified()) {
    qualifiers += "volatile ";
  }

  // Print the pointee type.
  clang::PrintingPolicy policy(ast_context.getLangOpts());
  policy.SuppressScope = 1;
  policy.PrintCanonicalTypes = 1;
  const std::string pointee_spelling =
      pointer_type->getPointeeType().getAsString(policy);

  const std::string span_spelling =
      llvm::formatv("base::span<{0}>", pointee_spelling);
  return qualifiers + span_spelling;
}

// Computes the source range used to attach a `.data()` call (or to wrap an
// expression) for the current match.
//
// It is intentional that this function ignores cast expressions and applies
// the `.data()` addition to the internal expression. If we have:
//   type* ptr = reinterpret_cast<type*>(buf);
// where buf needs to be rewritten to span and ptr doesn't, the `.data()` call
// is added right after buf:
//   type* ptr = reinterpret_cast<type*>(buf.data());
//
// The match ids are checked in this order: "unaryOperator", "binaryOperator",
// "raw_ptr_operator++", and finally the plain "rhs_expr".
static clang::SourceRange getSourceRange(
    const MatchFinder::MatchResult& result) {
  // End of the expression bound to "rhs_expr". Only evaluated in the branches
  // that need it.
  const auto rhs_expr_end = [&result]() {
    return getExprRange(result.Nodes.getNodeAs<clang::Expr>("rhs_expr"))
        .getEnd();
  };

  if (const auto* unary_op =
          result.Nodes.getNodeAs<clang::UnaryOperator>("unaryOperator")) {
    // For a postfix operator the range is extended by 2 characters past the
    // operator's end location; otherwise it ends where the operand ends.
    const clang::SourceLocation end =
        unary_op->isPostfix() ? unary_op->getEndLoc().getLocWithOffset(2)
                              : rhs_expr_end();
    return clang::SourceRange(unary_op->getBeginLoc(), end);
  }

  if (const auto* binary_op =
          result.Nodes.getNodeAs<clang::Expr>("binaryOperator")) {
    const auto* operand = result.Nodes.getNodeAs<clang::Expr>("bin_op_rhs");
    return clang::SourceRange(binary_op->getBeginLoc(),
                              getExprRange(operand).getEnd());
  }

  if (const auto* increment = result.Nodes.getNodeAs<clang::CXXOperatorCallExpr>(
          "raw_ptr_operator++")) {
    // operator++ on raw_ptr: a callee declared without parameters yields an
    // empty range at the end of the operand; otherwise the empty range is
    // placed 2 characters past the end location of the call.
    if (increment->getDirectCallee()->getNumParams() == 0) {
      return clang::SourceRange(rhs_expr_end());
    }
    return clang::SourceRange(increment->getEndLoc().getLocWithOffset(2));
  }

  // Default: an empty range located at the end of the rhs expression.
  return clang::SourceRange(rhs_expr_end());
}

// If `range` starts inside a macro expansion, replaces it with an empty range
// located right after the name of the DeclRefExpr bound to "declRefExpr", using
// spelling locations. `range` is left untouched when it is invalid, when it
// does not start in a macro, or when no "declRefExpr" node is bound.
static void maybeUpdateSourceRangeIfInMacro(
    const clang::SourceManager& source_manager,
    const MatchFinder::MatchResult& result,
    clang::SourceRange& range) {
  if (!range.isValid() || !range.getBegin().isMacroID()) {
    return;
  }
  // We need to find the reference to the object that might be getting
  // accessed and rewritten to find the location to rewrite. SpellingLocation
  // returns a different position if the source was pointing into the macro
  // definition. See clang::SourceManager for details but relevant section:
  //
  // "Spelling locations represent where the bytes corresponding to a token came
  // from and expansion locations represent where the location is in the user's
  // view. In the case of a macro expansion, for example, the spelling location
  // indicates where the expanded token came from and the expansion location
  // specifies where it was expanded."
  auto* rhs_decl_ref =
      result.Nodes.getNodeAs<clang::DeclRefExpr>("declRefExpr");
  if (!rhs_decl_ref) {
    return;
  }
  // We're extracting the spellingLocation's position and then we'll move the
  // location forward by the length of the variable. This will allow us to
  // insert .data() at the end of the decl_ref.
  clang::SourceLocation correct_start =
      source_manager.getSpellingLoc(rhs_decl_ref->getLocation());

  // Both flags are explicitly initialized: `bool a, b = false;` would only
  // initialize `b` and leave `a` indeterminate.
  bool invalid_line = false;
  bool invalid_col = false;
  auto line =
      source_manager.getSpellingLineNumber(correct_start, &invalid_line);
  auto col =
      source_manager.getSpellingColumnNumber(correct_start, &invalid_col);
  assert(correct_start.isValid() && !invalid_line && !invalid_col &&
         "Unable to get SpellingLocation info");
  // Get the name and find the end of the decl_ref.
  std::string name = rhs_decl_ref->getFoundDecl()->getNameAsString();
  clang::SourceLocation correct_end = source_manager.translateLineCol(
      source_manager.getFileID(correct_start), line, col + name.size());
  assert(correct_end.isValid() &&
         "Incorrectly got an End SourceLocation for macro");
  // This returns at the end of the variable being referenced so we can
  // insert .data(), if we wanted it wrapped in params (variable).data()
  // we'd need {correct_start, correct_end} but this doesn't seem needed in
  // macros tested on so far.
  range = clang::SourceRange{correct_end};
}

// Creates the node that rewrites a raw pointer type, as written at
// `type_loc`, into `base::span<...>`. The replacement covers the range from
// the beginning of the type loc to one character past its end location; the
// last character of that source text is dropped and the remainder becomes the
// span's template argument.
static Node getNodeFromPointerTypeLoc(const clang::PointerTypeLoc* type_loc,
                                      const MatchFinder::MatchResult& result) {
  const clang::SourceManager& source_manager = *result.SourceManager;
  const clang::ASTContext& ast_context = *result.Context;
  const auto& lang_opts = ast_context.getLangOpts();
  // We are in the case of a function return type loc.
  // This doesn't always generate the right range since type_loc doesn't
  // account for qualifiers (like const). Didn't find a proper way for now
  // to get the location with type qualifiers taken into account.
  clang::SourceRange replacement_range = {
      type_loc->getBeginLoc(), type_loc->getEndLoc().getLocWithOffset(1)};
  std::string initial_text =
      clang::Lexer::getSourceText(
          clang::CharSourceRange::getCharRange(replacement_range),
          source_manager, lang_opts)
          .str();
  // Drop the trailing character of the pointer type's spelling. The source
  // text can come back empty, and calling pop_back() on an empty string is
  // undefined behavior, hence the guard.
  if (!initial_text.empty()) {
    initial_text.pop_back();
  }
  std::string replacement_text = "base::span<" + initial_text + ">";
  auto replacement_and_include_pair = GetReplacementAndIncludeDirectives(
      replacement_range, replacement_text, source_manager);
  Node n;
  n.replacement = replacement_and_include_pair.first;
  n.include_directive = replacement_and_include_pair.second;
  return n;
}

// Creates the node that rewrites a raw_ptr<...> type into
// base::raw_span<...>. Only the text in front of the template's opening angle
// bracket is replaced, so the template arguments are kept as written. The
// node carries the include directive for kBaseRawSpanIncludePath.
static Node getNodeFromRawPtrTypeLoc(
    const clang::TemplateSpecializationTypeLoc* raw_ptr_type_loc,
    const MatchFinder::MatchResult& result) {
  const clang::SourceRange template_name_range(
      raw_ptr_type_loc->getBeginLoc(), raw_ptr_type_loc->getLAngleLoc());

  auto [replacement, include_directive] = GetReplacementAndIncludeDirectives(
      template_name_range, "base::raw_span", *result.SourceManager,
      kBaseRawSpanIncludePath);

  Node node;
  node.replacement = std::move(replacement);
  node.include_directive = std::move(include_directive);
  return node;
}

// Creates the node that rewrites the type of the declaration `decl` into the
// span type generated by GenerateSpanType(). The replaced range goes from the
// beginning of the declaration up to the declaration's location.
static Node getNodeFromDecl(const clang::DeclaratorDecl* decl,
                            const MatchFinder::MatchResult& result) {
  const clang::SourceRange type_range(decl->getBeginLoc(),
                                      decl->getLocation());
  const std::string span_type =
      GenerateSpanType(*result.Context, decl->getType());

  auto [replacement, include_directive] = GetReplacementAndIncludeDirectives(
      type_range, span_type, *result.SourceManager);

  Node node;
  node.replacement = std::move(replacement);
  node.include_directive = std::move(include_directive);
  return node;
}

// Creates the node that rewrites a dereference expression into an access to
// element 0. The first character of the expression's source text is dropped
// and `[0]` is appended. When the match also binds "unaryOperator" or
// "binaryOperator", the remaining text is wrapped in parentheses first:
// `(<text>)[0]`.
// The returned node is flagged as a deref expression and carries "<empty>" as
// include directive.
static Node getNodeFromDerefExpr(const clang::Expr* deref_expr,
                                 const MatchFinder::MatchResult& result) {
  const clang::SourceManager& source_manager = *result.SourceManager;
  const clang::SourceRange deref_range(deref_expr->getBeginLoc(),
                                       getSourceRange(result).getEnd());
  const std::string original_text =
      clang::Lexer::getSourceText(
          clang::CharSourceRange::getCharRange(deref_range), source_manager,
          result.Context->getLangOpts())
          .str();

  // Everything after the first character of the original spelling.
  const std::string operand_text = original_text.substr(1);
  const bool needs_parentheses =
      result.Nodes.getNodeAs<clang::Expr>("unaryOperator") ||
      result.Nodes.getNodeAs<clang::Expr>("binaryOperator");
  const std::string replacement_text =
      needs_parentheses ? "(" + operand_text + ")[0]" : operand_text + "[0]";

  Node node;
  node.replacement = GetReplacementAndIncludeDirectives(
                         deref_range, replacement_text, source_manager)
                         .first;
  node.include_directive = "<empty>";
  node.is_deref_expr = true;
  return node;
}

// Creates the node that removes a member call from an expression, replacing
// it with a single space. Example:
//   char* ptr = member_.get();
// which is then rewritten to
//   span<char> ptr = member_;
// member_ here is a raw_ptr.
//
// The MemberExpr is looked up in `result` under the id `member_expr_id`. The
// replaced range starts one character before the member name and ends two
// characters after it (in the example above: `.get()`). `get_call` is not
// used to compute the range.
static Node getNodeFromMemberCallExpr(const clang::CXXMemberCallExpr* get_call,
                                      const char* member_expr_id,
                                      const MatchFinder::MatchResult& result) {
  const auto* member =
      result.Nodes.getNodeAs<clang::MemberExpr>(member_expr_id);
  const clang::SourceLocation name_loc = member->getMemberLoc();
  const size_t length_past_name_loc =
      member->getMemberDecl()->getName().size() + 2;
  const clang::SourceRange call_range(
      name_loc.getLocWithOffset(-1),
      name_loc.getLocWithOffset(length_past_name_loc));

  auto [replacement, include_directive] = GetReplacementAndIncludeDirectives(
      call_range, " ", *result.SourceManager);

  Node node;
  node.replacement = std::move(replacement);
  node.include_directive = std::move(include_directive);
  return node;
}

// Creates the node that adds a `.data()` call to the range computed by
// getSourceRange() for this match. If that range contains no text, `.data()`
// is simply inserted; otherwise the text is wrapped: `(<text>).data()`.
// The returned node is flagged as a deref expression and carries "<empty>" as
// include directive.
static Node getNodeFromCallToExternalFunction(
    const MatchFinder::MatchResult& result) {
  const clang::SourceManager& source_manager = *result.SourceManager;
  const clang::SourceRange argument_range = getSourceRange(result);
  const std::string argument_text =
      clang::Lexer::getSourceText(
          clang::CharSourceRange::getCharRange(argument_range), source_manager,
          result.Context->getLangOpts())
          .str();

  std::string replacement_text = ".data()";
  if (!argument_text.empty()) {
    replacement_text = "(" + argument_text + ").data()";
  }

  Node node;
  node.replacement = GetReplacementAndIncludeDirectives(
                         argument_range, replacement_text, source_manager)
                         .first;
  node.include_directive = "<empty>";
  node.is_deref_expr = true;
  return node;
}

// Creates a node for an expression whose size information is available.
//  - If the match binds "nullptr_expr", the `nullptr` literal is replaced by
//    `{}`.
//  - Otherwise an empty range at the beginning of `size_expr` is used with
//    the placeholder text "<empty>", just to keep track of the node's
//    location.
// The returned node has `size_info_available` set.
static Node getNodeFromSizeExpr(const clang::Expr* size_expr,
                                const MatchFinder::MatchResult& result) {
  const auto* nullptr_literal =
      result.Nodes.getNodeAs<clang::CXXNullPtrLiteralExpr>("nullptr_expr");

  const std::string replacement_text = nullptr_literal ? "{}" : "<empty>";
  clang::SourceRange replacement_range;
  if (nullptr_literal) {
    // The hardcoded offset corresponds to the length of "nullptr" keyword.
    const clang::SourceLocation literal_begin = nullptr_literal->getBeginLoc();
    replacement_range =
        clang::SourceRange(literal_begin, literal_begin.getLocWithOffset(7));
  } else {
    const clang::SourceLocation expr_begin =
        size_expr->getSourceRange().getBegin();
    replacement_range = clang::SourceRange(expr_begin, expr_begin);
  }

  auto [replacement, include_directive] = GetReplacementAndIncludeDirectives(
      replacement_range, replacement_text, *result.SourceManager);

  Node node;
  node.size_info_available = true;
  node.replacement = std::move(replacement);
  node.include_directive = std::move(include_directive);
  return node;
}

// Creates the special node that appends `.data()` to the rhs of the current
// match. It is meant for the case where the rhs is rewritten but the lhs is
// not. `lhs_replacement` is the replacement directive (i.e. the key) of the
// lhs node.
static Node getDataChangeNode(const std::string& lhs_replacement,
                              const MatchFinder::MatchResult& result) {
  const clang::SourceManager& source_manager = *result.SourceManager;

  // If we're inside a macro the range computed by getSourceRange() is going
  // to be incorrect because it will point into the file where the macro is
  // defined. We need to get the "SpellingLocation", and then we figure out
  // the end of the parameter so we can insert .data() at the end if needed.
  clang::SourceRange data_range = getSourceRange(result);
  maybeUpdateSourceRangeIfInMacro(source_manager, result, data_range);

  const std::string covered_text =
      clang::Lexer::getSourceText(
          clang::CharSourceRange::getCharRange(data_range), source_manager,
          result.Context->getLangOpts())
          .str();

  // An empty range is a pure insertion; otherwise the covered text is
  // wrapped in parentheses before calling `.data()` on it.
  std::string replacement_text = ".data()";
  if (!covered_text.empty()) {
    replacement_text = "(" + covered_text + ").data()";
  }

  Node data_node;
  data_node.replacement = GetReplacementAndIncludeDirectives(
                              data_range, replacement_text, source_manager)
                              .first;
  // We need a way to check whether the lhs node was rewritten, in which
  // case we don't need to add this change. We achieve this by storing the
  // lhs key (the replacement which is unique) in the data_node's include
  // directive.
  data_node.include_directive = lhs_replacement;
  data_node.is_data_change = true;
  return data_node;
}

// Gets the array size as written in the source code (if possible), otherwise
// relies on the compile time value as seen in the ConstantArrayType.
// Reads the nodes bound to "array_type_loc" and "array_type" from `result`.
// Returns an empty string in case of error.
std::string getArraySize(const MatchFinder::MatchResult& result) {
  const clang::SourceManager& source_manager = *result.SourceManager;
  const clang::ASTContext& ast_context = *result.Context;
  const auto& lang_opts = ast_context.getLangOpts();

  const auto* type_loc =
      result.Nodes.getNodeAs<clang::TypeLoc>("array_type_loc");

  auto array_type_loc = type_loc->getAs<clang::ArrayTypeLoc>();

  // This is the case for arrays where the size expression is omitted. Example:
  // int a[] = {1,2,3,4};
  // For such cases, we rely on getting the compile-time size from the
  // ConstantArrayType below.
  if (array_type_loc.getLBracketLoc() != array_type_loc.getRBracketLoc()) {
    // The text strictly between `[` and `]`.
    auto source_range =
        clang::SourceRange(array_type_loc.getLBracketLoc().getLocWithOffset(1),
                           array_type_loc.getRBracketLoc());
    auto size_text = clang::Lexer::getSourceText(
                         clang::CharSourceRange::getCharRange(source_range),
                         source_manager, lang_opts)
                         .str();
    if (!size_text.empty()) {
      return size_text;
    }
  }

  auto* array_type = result.Nodes.getNodeAs<clang::ArrayType>("array_type");
  // dyn_cast<> must not be handed a null pointer, so guard the lookup result.
  const clang::ConstantArrayType* constant_array_type =
      array_type ? clang::dyn_cast<clang::ConstantArrayType>(array_type)
                 : nullptr;
  if (constant_array_type) {
    return std::to_string(constant_array_type->getSize().getZExtValue());
  }

  assert(false && "Unable to determine array size.");
  // When asserts are compiled out, honor the documented contract instead of
  // flowing off the end of a non-void function (undefined behavior).
  return std::string();
}

// Creates a replacement node for c-style arrays on which we invoke operator[].
// These arrays are rewritten to std::array<Type, Size>.
//
// Reads the nodes bound to "array_type_loc", "array_type" and
// "array_variable" from `result`. The replacement text has the shape
// `std::array<<element type>,<size>><variable name>` and covers the array's
// type loc extended by one character. The node requests the system header
// <array> and has `size_info_available` set.
Node getNodeFromArrayType(const MatchFinder::MatchResult& result) {
  const auto* type_loc =
      result.Nodes.getNodeAs<clang::TypeLoc>("array_type_loc");
  const auto* type = result.Nodes.getNodeAs<clang::ArrayType>("array_type");
  const auto* variable =
      result.Nodes.getNodeAs<clang::VarDecl>("array_variable");

  // Spell the element type without scopes and in canonical form.
  clang::PrintingPolicy policy(result.Context->getLangOpts());
  policy.SuppressScope = 1;
  policy.PrintCanonicalTypes = 1;
  const std::string element_spelling =
      type->getElementType().getAsString(policy);

  const std::string size_spelling = getArraySize(result);
  const std::string replacement_text =
      llvm::formatv("std::array<{0},{1}>{2}", element_spelling, size_spelling,
                    variable->getNameAsString());

  const clang::SourceRange type_loc_range = type_loc->getSourceRange();
  const clang::SourceRange replacement_range(
      type_loc_range.getBegin(), type_loc_range.getEnd().getLocWithOffset(1));

  auto [replacement, include_directive] = GetReplacementAndIncludeDirectives(
      replacement_range, replacement_text, *result.SourceManager, "array",
      /*is_system_include_path=*/true);

  Node node;
  node.replacement = std::move(replacement);
  node.include_directive = std::move(include_directive);
  node.size_info_available = true;
  return node;
}

// Called when the Match registered for it was successfully found in the AST.
// The matches registered represent two categories:
//   1- An adjacency relationship
//      In that case, a node pair is created, using matched node ids, and added
//      to the node_pair list using `OutputHelper::AddEdge`
//   2- A single is_buffer node match
//      In that case, a single node is created and added to the node_pair list
//      using `OutputHelper::AddSingleNode`
class PotentialNodes : public MatchFinder::MatchCallback {
 public:
  explicit PotentialNodes(OutputHelper& helper) : output_helper_(helper) {}

  PotentialNodes(const PotentialNodes&) = delete;
  PotentialNodes& operator=(const PotentialNodes&) = delete;

  // Extracts the lhs node from the match result. The bound ids are tried in a
  // fixed order and the first one present decides how the node is built.
  // Asserts if none of the expected ids is bound; when asserts are compiled
  // out, a default-constructed Node is returned in that case.
  Node getLHSNodeFromMatchResult(const MatchFinder::MatchResult& result) {
    if (auto* type_loc =
            result.Nodes.getNodeAs<clang::PointerTypeLoc>("lhs_type_loc")) {
      return getNodeFromPointerTypeLoc(type_loc, result);
    }

    if (auto* raw_ptr_type_loc =
            result.Nodes.getNodeAs<clang::TemplateSpecializationTypeLoc>(
                "lhs_raw_ptr_type_loc")) {
      return getNodeFromRawPtrTypeLoc(raw_ptr_type_loc, result);
    }

    if (auto* lhs_begin =
            result.Nodes.getNodeAs<clang::DeclaratorDecl>("lhs_begin")) {
      return getNodeFromDecl(lhs_begin, result);
    }

    if (auto* deref_op = result.Nodes.getNodeAs<clang::Expr>("deref_expr")) {
      return getNodeFromDerefExpr(deref_op, result);
    }

    if (auto* get_call = result.Nodes.getNodeAs<clang::CXXMemberCallExpr>(
            "raw_ptr_get_call")) {
      // The node removing the `.get()` call is treated like a deref
      // expression and does not carry an include directive.
      Node n = getNodeFromMemberCallExpr(get_call, "get_member_expr", result);
      n.include_directive = "<empty>";
      n.is_deref_expr = true;
      return n;
    }

    if (result.Nodes.getNodeAs<clang::Expr>(
            "passing_a_buffer_to_third_party_function")) {
      return getNodeFromCallToExternalFunction(result);
    }

    if (result.Nodes.getNodeAs<clang::VarDecl>("array_variable")) {
      return getNodeFromArrayType(result);
    }

    // Not supposed to get here.
    assert(false && "No known lhs node id bound in the match result.");
    // Avoid flowing off the end of a non-void function (undefined behavior)
    // when asserts are compiled out.
    return Node();
  }

  // Extracts the rhs node from the match result. The bound ids are tried in a
  // fixed order and the first one present decides how the node is built.
  // Asserts if none of the expected ids is bound; when asserts are compiled
  // out, a default-constructed Node is returned in that case.
  Node getRHSNodeFromMatchResult(const MatchFinder::MatchResult& result) {
    if (auto* type_loc =
            result.Nodes.getNodeAs<clang::PointerTypeLoc>("rhs_type_loc")) {
      return getNodeFromPointerTypeLoc(type_loc, result);
    }

    if (auto* raw_ptr_type_loc =
            result.Nodes.getNodeAs<clang::TemplateSpecializationTypeLoc>(
                "rhs_raw_ptr_type_loc")) {
      return getNodeFromRawPtrTypeLoc(raw_ptr_type_loc, result);
    }

    if (auto* rhs_begin =
            result.Nodes.getNodeAs<clang::DeclaratorDecl>("rhs_begin")) {
      return getNodeFromDecl(rhs_begin, result);
    }

    if (const clang::CXXMemberCallExpr* data_call =
            result.Nodes.getNodeAs<clang::CXXMemberCallExpr>(
                "member_data_call")) {
      // The node built for a "member_data_call" match is marked as having
      // size information.
      auto node =
          getNodeFromMemberCallExpr(data_call, "data_member_expr", result);
      node.size_info_available = true;
      return node;
    }

    if (const clang::Expr* size_expr =
            result.Nodes.getNodeAs<clang::Expr>("size_node")) {
      return getNodeFromSizeExpr(size_expr, result);
    }

    // Not supposed to get here.
    assert(false && "No known rhs node id bound in the match result.");
    // Avoid flowing off the end of a non-void function (undefined behavior)
    // when asserts are compiled out.
    return Node();
  }

  // MatchFinder::MatchCallback:
  // Emits either a single buffer node, or an lhs/rhs edge (plus, at a span
  // frontier, an extra edge from a `.data()` node to the rhs).
  void run(const MatchFinder::MatchResult& result) override {
    Node lhs = getLHSNodeFromMatchResult(result);

    // Buffer usage expressions are added as a single node, return
    // early in this case.
    if (result.Nodes.getNodeAs<clang::Expr>("buffer_expr")) {
      lhs.is_buffer = true;
      output_helper_.AddSingleNode(lhs);
      return;
    }

    Node rhs = getRHSNodeFromMatchResult(result);

    auto* expr = result.Nodes.getNodeAs<clang::Expr>("span_frontier");
    if (expr && !lhs.is_deref_expr && !rhs.size_info_available) {
      // Node to add `.data()`;
      // This is needed in the case where rhs is rewritten and lhs is not.
      // Adding `.data()` is thus needed to extract the pointer since lhs and
      // rhs no longer have the same type.
      Node data_node = getDataChangeNode(lhs.replacement, result);
      output_helper_.AddEdge(data_node, rhs);
    }

    output_helper_.AddEdge(lhs, rhs);
  }

 private:
  // Sink for the emitted edges and nodes; not owned.
  OutputHelper& output_helper_;
};

// Called when the registered Match is found in the AST.
//
// The match includes:
// - A parmVarDecl or RTNode
// - Corresponding function declaration
//
// Using the function declaration, this:
// 1. Create a unique key for the current function: `current_key`
// 2. If the function has previous declarations or is overridden:
//    - Retrieve previous declarations
//    - Create keys for each previous declaration: `prev_key`
//    - For each `prev_key`, add the pair (`current_key`, `prev_key`) to
//      `fct_sig_pairs_`
//
// Using the parmVarDecl or RTNode, this:
// 1. Create a node
// 2. Insert the node into `fct_sig_nodes_[current_key]`
//
// At the end of the tool run for a given translation unit, edges between
// corresponding nodes of two adjacent function signatures are created.
class FunctionSignatureNodes : public MatchFinder::MatchCallback {
 public:
  explicit FunctionSignatureNodes(
      std::map<std::string, std::set<Node>>& sig_nodes,
      std::vector<std::pair<std::string, std::string>>& sig_pairs)
      : fct_sig_nodes_(sig_nodes), fct_sig_pairs_(sig_pairs) {}

  FunctionSignatureNodes(const FunctionSignatureNodes&) = delete;
  FunctionSignatureNodes& operator=(const FunctionSignatureNodes&) = delete;

  // Key here means a unique string generated from a function signature.
  // The key has the format `r:::<file path>:::<offset>:::<length>:::<name>`,
  // computed from an empty range at the beginning of the declaration.
  std::string GetKey(const clang::FunctionDecl* fct_decl,
                     const clang::SourceManager& source_manager) {
    auto name = fct_decl->getNameInfo().getName().getAsString();
    clang::SourceLocation start_loc = fct_decl->getBeginLoc();
    // The begin location is mapped to a file location. This is needed to
    // handle cases where the function is in a Macro Expansion.
    clang::SourceRange replacement_range(source_manager.getFileLoc(start_loc),
                                         source_manager.getFileLoc(start_loc));
    clang::tooling::Replacement replacement(
        source_manager, clang::CharSourceRange::getCharRange(replacement_range),
        name.c_str());
    llvm::StringRef file_path = replacement.getFilePath();

    return llvm::formatv("r:::{0}:::{1}:::{2}:::{3}", file_path,
                         replacement.getOffset(), replacement.getLength(),
                         name.c_str());
  }

  // Builds the node for the parameter or return type bound in `result`.
  // Asserts if none of the expected ids is bound; when asserts are compiled
  // out, a default-constructed Node is returned in that case.
  Node getNodeFromMatchResult(const MatchFinder::MatchResult& result) {
    if (auto* type_loc =
            result.Nodes.getNodeAs<clang::PointerTypeLoc>("rhs_type_loc")) {
      return getNodeFromPointerTypeLoc(type_loc, result);
    }

    if (auto* raw_ptr_type_loc =
            result.Nodes.getNodeAs<clang::TemplateSpecializationTypeLoc>(
                "rhs_raw_ptr_type_loc")) {
      return getNodeFromRawPtrTypeLoc(raw_ptr_type_loc, result);
    }

    // "rhs_begin" match id could refer to a declaration that has a raw_ptr
    // type. Those are handled in getNodeFromRawPtrTypeLoc. We
    // should always check for a "rhs_raw_ptr_type_loc" match id and call
    // getNodeFromRawPtrTypeLoc first.
    if (auto* rhs_begin =
            result.Nodes.getNodeAs<clang::DeclaratorDecl>("rhs_begin")) {
      return getNodeFromDecl(rhs_begin, result);
    }

    // Shouldn't get here.
    assert(false && "No known node id bound in the match result.");
    // Avoid flowing off the end of a non-void function (undefined behavior)
    // when asserts are compiled out.
    return Node();
  }

  // MatchFinder::MatchCallback:
  // Records the relationships between the matched function and its related
  // signatures, then stores the matched node under the function's key.
  void run(const MatchFinder::MatchResult& result) override {
    const clang::SourceManager& source_manager = *result.SourceManager;
    const clang::FunctionDecl* fct_decl =
        result.Nodes.getNodeAs<clang::FunctionDecl>("fct_decl");
    // Null when the matched function is not a C++ method.
    const clang::CXXMethodDecl* method_decl =
        result.Nodes.getNodeAs<clang::CXXMethodDecl>("fct_decl");

    const std::string current_key = GetKey(fct_decl, source_manager);

    // Function related by separate declaration and definition:
    {
      for (auto* previous_decl = fct_decl->getPreviousDecl(); previous_decl;
           previous_decl = previous_decl->getPreviousDecl()) {
        // TODO(356666773): The `previous_decl` might be part of third_party/.
        // Then it won't be matched by the matcher. So only one of the pair
        // would have a node.
        const std::string previous_key = GetKey(previous_decl, source_manager);
        fct_sig_pairs_.push_back({
            current_key,
            previous_key,
        });
      }
    }

    // Function related by overriding:
    if (method_decl) {
      for (auto* m : method_decl->overridden_methods()) {
        const std::string previous_key = GetKey(m, source_manager);
        fct_sig_pairs_.push_back({
            current_key,
            previous_key,
        });
      }
    }

    Node n = getNodeFromMatchResult(result);
    fct_sig_nodes_[current_key].insert(n);
  }

 private:
  // Map a function signature, which is modeled as a string representing file
  // location, to its matched graph nodes (RTNode and ParmVarDecl nodes).
  // Note: `RTNode` represents a function return type node.
  // In order to avoid relying on the order with which nodes are matched in
  // the AST, and to guarantee that nodes are stored in the file declaration
  // order, we use a `std::set<Node>` which sorts Nodes based on the replacement
  // directive which contains the file offset of a given node.
  // Note that a replacement directive has the following format:
  // `r:::<file path>:::<offset>:::<length>:::<replacement text>`
  // The order is important because at the end of a tool run on a
  // translationUnit, for each pair of function signatures, we iterate
  // concurrently through the two sets of Nodes creating edges between nodes
  // that appear at the same index.
  // AddEdge(first function's node1, second function's node1)
  // AddEdge(first function's node2, second function's node2)
  // and so on...
  std::map<std::string, std::set<Node>>& fct_sig_nodes_;

  // Map related function signatures to each other, this is needed for
  // functions with separate definition and declaration, and for overridden
  // functions.
  std::vector<std::pair<std::string, std::string>>& fct_sig_pairs_;
};

// Builds the AST matchers used by the tool and registers them on a
// MatchFinder.
//
// The class owns the two match callbacks that the matchers dispatch to:
//  - `potential_nodes_`, constructed from the OutputHelper, receives every
//    expression/declaration level match (buffer usages, assignments,
//    constructions, returns, call arguments, ...).
//  - `fct_sig_nodes_`, constructed from the two signature containers, receives
//    the function declaration matches and fills those containers.
// The MatchFinder itself is only referenced, so it must outlive this object.
class Spanifier {
 public:
  // `finder`: the MatchFinder on which addMatchers() registers the matchers.
  // `output_helper`: forwarded to the PotentialNodes callback.
  // `sig_nodes` / `sig_pairs`: forwarded to the FunctionSignatureNodes
  // callback, which stores references to them.
  explicit Spanifier(
      MatchFinder& finder,
      OutputHelper& output_helper,
      std::map<std::string, std::set<Node>>& sig_nodes,
      std::vector<std::pair<std::string, std::string>>& sig_pairs)
      : match_finder_(finder),
        potential_nodes_(output_helper),
        fct_sig_nodes_(sig_nodes, sig_pairs) {}

  // Builds every matcher and registers it on `match_finder_`.
  //
  // Naming convention used throughout: the `lhs_*` and `rhs_*` matchers are
  // structurally identical and differ only by the ids they bind
  // ("lhs_begin" / "lhs_type_loc" / "lhs_raw_ptr_type_loc" versus the "rhs_"
  // equivalents), so that a single match can carry both sides of a
  // relationship (e.g. both operands of an assignment).
  //
  // NOTE: the bound id strings are consumed by the callbacks, which are
  // defined elsewhere in this file; keep both in sync when renaming.
  void addMatchers() {
    // Nodes matching any of these are filtered out of the declaration matchers
    // below through `unless(exclusions)`.
    auto exclusions = anyOf(
        isExpansionInSystemHeader(), raw_ptr_plugin::isInExternCContext(),
        raw_ptr_plugin::isInThirdPartyLocation(),
        raw_ptr_plugin::isInGeneratedLocation(),
        raw_ptr_plugin::ImplicitFieldDeclaration(),
        raw_ptr_plugin::isInMacroLocation(),
        hasAncestor(cxxRecordDecl(anyOf(hasName("raw_ptr"), hasName("span")))));

    // Exclude literal strings as these need to become string_view
    // Pointers to anonymous structs/unions, to functions and to members are
    // excluded as well.
    auto pointer_type = pointerType(pointee(qualType(unless(anyOf(
        qualType(hasDeclaration(
            cxxRecordDecl(raw_ptr_plugin::isAnonymousStructOrUnion()))),
        hasUnqualifiedDesugaredType(anyOf(functionType(), memberPointerType())),
        hasCanonicalType(
            anyOf(asString("const char"), asString("const wchar_t"),
                  asString("const char8_t"), asString("const char16_t"),
                  asString("const char32_t"))))))));

    // A type declared by a class template specialization named `raw_ptr`, and
    // the corresponding type location.
    auto raw_ptr_type = qualType(
        hasDeclaration(classTemplateSpecializationDecl(hasName("raw_ptr"))));
    auto raw_ptr_type_loc = templateSpecializationTypeLoc(loc(raw_ptr_type));

    // A declaration is of interest when its type is either a raw pointer
    // accepted by `pointer_type`, or a raw_ptr whose type location can be
    // found and bound.
    auto lhs_type_loc = anyOf(
        hasType(pointer_type),
        allOf(hasType(raw_ptr_type),
              hasDescendant(raw_ptr_type_loc.bind("lhs_raw_ptr_type_loc"))));
    auto rhs_type_loc = anyOf(
        hasType(pointer_type),
        allOf(hasType(raw_ptr_type),
              hasDescendant(raw_ptr_type_loc.bind("rhs_raw_ptr_type_loc"))));

    // Fields, variables and parameters of such types. Each is bound to
    // "lhs_begin" or "rhs_begin" depending on the side it is used for.
    auto lhs_field =
        fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(lhs_type_loc),
                  unless(exclusions),
                  unless(hasParent(cxxRecordDecl(hasName("raw_ptr")))))
            .bind("lhs_begin");
    auto rhs_field =
        fieldDecl(raw_ptr_plugin::hasExplicitFieldDecl(rhs_type_loc),
                  unless(exclusions),
                  unless(hasParent(cxxRecordDecl(hasName("raw_ptr")))))
            .bind("rhs_begin");

    auto lhs_var = varDecl(lhs_type_loc, unless(exclusions)).bind("lhs_begin");
    auto rhs_var = varDecl(rhs_type_loc, unless(exclusions)).bind("rhs_begin");

    auto lhs_param =
        parmVarDecl(lhs_type_loc, unless(exclusions)).bind("lhs_begin");

    auto rhs_param =
        parmVarDecl(rhs_type_loc, unless(exclusions)).bind("rhs_begin");

    // Exclude functions returning literal strings as these need to become
    // string_view.
    auto exclude_literal_strings =
        unless(returns(qualType(pointsTo(qualType(hasCanonicalType(
            anyOf(asString("const char"), asString("const wchar_t"),
                  asString("const char8_t"), asString("const char16_t"),
                  asString("const char32_t"))))))));

    // Calls to non-excluded functions whose return type is spelled as a
    // pointer; the return type location is what gets bound.
    auto rhs_call_expr = callExpr(callee(
        functionDecl(hasReturnTypeLoc(pointerTypeLoc().bind("rhs_type_loc")),
                     exclude_literal_strings, unless(exclusions))));

    auto lhs_call_expr = callExpr(callee(
        functionDecl(hasReturnTypeLoc(pointerTypeLoc().bind("lhs_type_loc")),
                     exclude_literal_strings, unless(exclusions))));

    // Left-hand side expression: a reference to a matched variable/parameter,
    // a member access to a matched field, or a matched call.
    auto lhs_expr = expr(anyOf(declRefExpr(to(anyOf(lhs_var, lhs_param))),
                               memberExpr(member(lhs_field)), lhs_call_expr));

    // References to variables, parameters or fields of constant array type.
    auto constant_array_exprs =
        declRefExpr(to(anyOf(varDecl(hasType(constantArrayType())),
                             parmVarDecl(hasType(constantArrayType())),
                             fieldDecl(hasType(constantArrayType())))));

    // Matches statements of the form: &buf[n] where buf is a container type
    // (span, vector, array,...).
    // A "container" is recognized here as a record that has a `size` method.
    auto buff_address_from_container = unaryOperator(
        hasOperatorName("&"),
        hasUnaryOperand(cxxOperatorCallExpr(callee(functionDecl(
            hasName("operator[]"),
            hasParent(cxxRecordDecl(hasMethod(hasName("size")))))))));

    // t* a = buf.data();
    // Same container heuristic as above: `data` must belong to a record that
    // also has a `size` method.
    auto member_data_call =
        cxxMemberCallExpr(
            callee(functionDecl(
                hasName("data"),
                hasParent(cxxRecordDecl(hasMethod(hasName("size")))))),
            has(memberExpr().bind("data_member_expr")))
            .bind("member_data_call");

    // Defines nodes that contain size information, these include:
    //  - nullptr => size is zero
    //  - calls to new/new[n] => size is 1/n
    //  - constant arrays buf[1024] => size is 1024
    //  - calls to third_party functions that we can't rewrite (they should
    //    provide a size for the pointer returned)
    // TODO(353710304): Consider handling functions taking in/out args ex:
    //                  void alloc(**ptr);
    // TODO(353710304): Consider making member_data_call and size_node mutually
    //                  exclusive. We rely here on the ordering of expressions
    //                  in the anyOf matcher to first match member_data_call
    //                  which is a subset of size_node.
    auto size_node_matcher = expr(anyOf(
        member_data_call,
        expr(anyOf(callExpr(callee(functionDecl(
                       hasReturnTypeLoc(pointerTypeLoc()),
                       anyOf(raw_ptr_plugin::isInThirdPartyLocation(),
                             isExpansionInSystemHeader(),
                             raw_ptr_plugin::isInExternCContext())))),
                   cxxNullPtrLiteralExpr().bind("nullptr_expr"), cxxNewExpr(),
                   constant_array_exprs, buff_address_from_container))
            .bind("size_node")));

    // Right-hand side counterpart of `lhs_expr`. In addition to "rhs_expr",
    // the concrete form that matched is bound to "declRefExpr", "memberExpr"
    // or "callExpr".
    auto rhs_expr =
        expr(ignoringParenCasts(anyOf(
                 declRefExpr(to(anyOf(rhs_var, rhs_param))).bind("declRefExpr"),
                 memberExpr(member(rhs_field)).bind("memberExpr"),
                 rhs_call_expr.bind("callExpr"))))
            .bind("rhs_expr");

    // `rhs_expr.get()` where `get` is a method of raw_ptr.
    auto get_calls_on_raw_ptr = cxxMemberCallExpr(
        callee(cxxMethodDecl(hasName("get"), ofClass(hasName("raw_ptr")))),
        has(memberExpr(has(rhs_expr))));

    // Every form built on top of `rhs_expr` that does not itself carry size
    // information: the bare expression, `rhs_expr + n`, `++`/`rhs_expr++`
    // (builtin or raw_ptr's overloaded operator) and `rhs_expr.get()`.
    // The whole expression is bound to "span_frontier".
    auto rhs_exprs_without_size_nodes =
        expr(ignoringParenCasts(anyOf(
                 rhs_expr,
                 binaryOperation(hasOperatorName("+"), hasLHS(rhs_expr),
                                 hasRHS(expr().bind("bin_op_rhs")))
                     .bind("binaryOperator"),
                 unaryOperator(hasOperatorName("++"), hasUnaryOperand(rhs_expr))
                     .bind("unaryOperator"),
                 cxxOperatorCallExpr(
                     callee(cxxMethodDecl(ofClass(hasName("raw_ptr")))),
                     hasOperatorName("++"), hasArgument(0, rhs_expr))
                     .bind("raw_ptr_operator++"),
                 get_calls_on_raw_ptr)))
            .bind("span_frontier");

    // This represents the forms under which an expr could appear on the right
    // hand side of an assignment operation, var construction, or an expr passed
    // as callExpr argument. Examples:
    // rhs_expr, rhs_expr++, ++rhs_expr, rhs_expr + n, cast(rhs_expr);
    // `size_node_matcher` is listed first, so it is tried before the other
    // forms.
    auto rhs_expr_variations = expr(ignoringParenCasts(
        anyOf(size_node_matcher, rhs_exprs_without_size_nodes)));

    auto lhs_expr_variations = expr(ignoringParenCasts(lhs_expr));

    // Expressions used to decide the pointer is used as a buffer include:
    // expr[n], expr++, ++expr, expr + n, expr += n
    auto buffer_expr1 = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        expr(ignoringParenCasts(anyOf(
                 arraySubscriptExpr(hasLHS(lhs_expr_variations)),
                 binaryOperation(
                     anyOf(hasOperatorName("+="), hasOperatorName("+")),
                     hasLHS(lhs_expr_variations)),
                 unaryOperator(hasOperatorName("++"),
                               hasUnaryOperand(lhs_expr_variations)),
                 // for raw_ptr ops
                 cxxOperatorCallExpr(anyOf(hasOverloadedOperatorName("[]"),
                                           hasOperatorName("++")),
                                     hasArgument(0, lhs_expr_variations)))))
            .bind("buffer_expr"));
    match_finder_.addMatcher(buffer_expr1, &potential_nodes_);

    // Subscripting of a variable that has an array type. Variables with
    // external formal linkage are left out. The variable, its array type and
    // its type location are bound separately.
    auto buffer_expr2 = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        expr(ignoringParenCasts(arraySubscriptExpr(hasLHS(declRefExpr(to(
                 varDecl(hasType(arrayType().bind("array_type")),
                         hasTypeLoc(
                             loc(qualType(anything())).bind("array_type_loc")),
                         unless(exclusions), unless(hasExternalFormalLinkage()))
                     .bind("array_variable")))))))
            .bind("buffer_expr"));
    match_finder_.addMatcher(buffer_expr2, &potential_nodes_);

    // Dereferences (`*expr`, builtin or overloaded) of any size-less rhs form,
    // as long as the dereference is not in a macro location.
    auto deref_expression = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        expr(anyOf(unaryOperator(hasOperatorName("*"),
                                 hasUnaryOperand(rhs_exprs_without_size_nodes)),
                   cxxOperatorCallExpr(
                       hasOverloadedOperatorName("*"),
                       hasArgument(0, rhs_exprs_without_size_nodes))),
             unless(raw_ptr_plugin::isInMacroLocation()))
            .bind("deref_expr"));
    match_finder_.addMatcher(deref_expression, &potential_nodes_);

    // This is needed to remove the `.get()` call on raw_ptr from rewritten
    // expressions. Example: raw_ptr<T> member; auto* temp = member.get(); if
    // member's type is rewritten to a raw_span<T>, this matcher is used to
    // remove the `.get()` call.
    auto raw_ptr_get_call = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        cxxMemberCallExpr(
            callee(cxxMethodDecl(hasName("get"), ofClass(hasName("raw_ptr")))),
            has(memberExpr(has(rhs_expr)).bind("get_member_expr")))
            .bind("raw_ptr_get_call"));
    match_finder_.addMatcher(raw_ptr_get_call, &potential_nodes_);

    // When passing now-span buffers to third_party functions as parameters, we
    // need to add `.data()` to extract the pointer and keep things compiling.
    // Arguments that are size nodes (possibly behind a cast) are skipped.
    auto passing_a_buffer_to_external_functions = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        callExpr(callee(functionDecl(
                     anyOf(isExpansionInSystemHeader(),
                           raw_ptr_plugin::isInExternCContext(),
                           raw_ptr_plugin::isInThirdPartyLocation()))),
                 forEachArgumentWithParam(
                     expr(rhs_expr_variations,
                          unless(anyOf(
                              castExpr(hasSourceExpression(size_node_matcher)),
                              size_node_matcher)))
                         .bind("passing_a_buffer_to_third_party_function"),
                     parmVarDecl())));
    match_finder_.addMatcher(passing_a_buffer_to_external_functions,
                             &potential_nodes_);

    // Handles assignment:
    // a = b;
    // a = fct();
    // a = reinterpret_cast<>(b);
    // a = (cond) ? expr1 : expr2;
    // For the ternary form only the true branch is handled here; the false
    // branch is handled by `assignement_relationship2`. The same split into
    // two matchers is repeated for most relationships below.
    auto assignement_relationship = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        binaryOperation(hasOperatorName("="),
                        hasOperands(lhs_expr_variations,
                                    anyOf(rhs_expr_variations,
                                          conditionalOperator(hasTrueExpression(
                                              rhs_expr_variations)))),
                        unless(isExpansionInSystemHeader())));
    match_finder_.addMatcher(assignement_relationship, &potential_nodes_);

    // Creates the edge from lhs to false_expr in a ternary conditional
    // operator.
    auto assignement_relationship2 = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        binaryOperation(hasOperatorName("="),
                        hasOperands(lhs_expr_variations,
                                    conditionalOperator(hasFalseExpression(
                                        rhs_expr_variations))),
                        unless(isExpansionInSystemHeader())));
    match_finder_.addMatcher(assignement_relationship2, &potential_nodes_);

    // Supports:
    // T* temp = member;
    // T* temp = init();
    // T* temp = (cond) ? expr1 : expr2;
    // T* temp = reinterpret_cast<>(b);
    // The initializer is accepted either directly or wrapped in a
    // cxxConstructExpr.
    auto var_construction = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        varDecl(
            lhs_var,
            has(expr(anyOf(
                rhs_expr_variations,
                conditionalOperator(hasTrueExpression(rhs_expr_variations)),
                cxxConstructExpr(has(expr(anyOf(
                    rhs_expr_variations, conditionalOperator(hasTrueExpression(
                                             rhs_expr_variations))))))))),
            unless(isExpansionInSystemHeader())));
    match_finder_.addMatcher(var_construction, &potential_nodes_);

    // Creates the edge from lhs to false_expr in a ternary conditional
    // operator.
    auto var_construction2 = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        varDecl(
            lhs_var,
            has(expr(anyOf(
                conditionalOperator(hasFalseExpression(rhs_expr_variations)),
                cxxConstructExpr(has(expr(conditionalOperator(
                    hasFalseExpression(rhs_expr_variations)))))))),
            unless(isExpansionInSystemHeader())));
    match_finder_.addMatcher(var_construction2, &potential_nodes_);

    // Supports:
    // return member;
    // return fct();
    // return reinterpret_cast(expr);
    // return (cond) ? expr1 : expr2;
    // Here the "lhs" is the return type location of the enclosing function,
    // and the return statement itself is bound to "lhs_stmt".
    auto returned_var_or_member = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        returnStmt(
            hasReturnValue(expr(anyOf(
                rhs_expr_variations,
                conditionalOperator(hasTrueExpression(rhs_expr_variations))))),
            unless(isExpansionInSystemHeader()),
            forFunction(functionDecl(
                hasReturnTypeLoc(pointerTypeLoc().bind("lhs_type_loc")),
                unless(exclusions))))
            .bind("lhs_stmt"));
    match_finder_.addMatcher(returned_var_or_member, &potential_nodes_);

    // Creates the edge from lhs to false_expr in a ternary conditional
    // operator.
    auto returned_var_or_member2 = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        returnStmt(hasReturnValue(conditionalOperator(
                       hasFalseExpression(rhs_expr_variations))),
                   unless(isExpansionInSystemHeader()),
                   forFunction(functionDecl(
                       hasReturnTypeLoc(pointerTypeLoc().bind("lhs_type_loc")),
                       unless(exclusions))))
            .bind("lhs_stmt"));
    match_finder_.addMatcher(returned_var_or_member2, &potential_nodes_);

    // Handles expressions of the form member(arg).
    // A(const T* arg): member(arg){}
    // member(init());
    // member(fct());
    // NOTE(review): unlike the relationships above, no conditional operator
    // form is matched for constructor initializers -- confirm this is intended.
    auto ctor_initilizer = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        cxxCtorInitializer(withInitializer(anyOf(
                               cxxConstructExpr(has(expr(rhs_expr_variations))),
                               rhs_expr_variations)),
                           forField(lhs_field)));

    match_finder_.addMatcher(ctor_initilizer, &potential_nodes_);

    // Supports:
    // S* temp;
    // Obj o(temp); Obj o{temp};
    // This links temp to the parameter in Obj's constructor.
    auto var_passed_in_constructor = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        cxxConstructExpr(forEachArgumentWithParam(
            expr(anyOf(
                rhs_expr_variations,
                conditionalOperator(hasTrueExpression(rhs_expr_variations)))),
            lhs_param)));
    match_finder_.addMatcher(var_passed_in_constructor, &potential_nodes_);

    // Creates the edge from lhs to false_expr in a ternary conditional
    // operator.
    auto var_passed_in_constructor2 = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        cxxConstructExpr(forEachArgumentWithParam(
            expr(conditionalOperator(hasFalseExpression(rhs_expr_variations))),
            lhs_param)));
    match_finder_.addMatcher(var_passed_in_constructor2, &potential_nodes_);

    // handles Obj o{temp} when Obj has no constructor.
    // This creates a link between the expr and the underlying field.
    auto var_passed_in_initlistExpr = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        initListExpr(raw_ptr_plugin::forEachInitExprWithFieldDecl(
            expr(anyOf(
                rhs_expr_variations,
                conditionalOperator(hasTrueExpression(rhs_expr_variations)))),
            lhs_field)));
    match_finder_.addMatcher(var_passed_in_initlistExpr, &potential_nodes_);

    // Same as above for the false branch of a ternary conditional operator.
    auto var_passed_in_initlistExpr2 = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        initListExpr(raw_ptr_plugin::forEachInitExprWithFieldDecl(
            expr(conditionalOperator(hasFalseExpression(rhs_expr_variations))),
            lhs_field)));
    match_finder_.addMatcher(var_passed_in_initlistExpr2, &potential_nodes_);

    // Link var/field passed as function arguments to function parameter
    // This handles func(var/member/param), func(func2())
    // cxxOpCallExprs excluded here since operator= can be invoked as a call
    // expr for classes/structs.
    // NOTE(review): there is no companion matcher for the false branch of a
    // conditional operator passed as an argument, although the other
    // relationships have one -- confirm this is intended.
    auto call_expr = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        callExpr(forEachArgumentWithParam(
                     expr(anyOf(rhs_expr_variations,
                                conditionalOperator(
                                    hasTrueExpression(rhs_expr_variations)))),
                     lhs_param),
                 unless(isExpansionInSystemHeader()),
                 unless(cxxOperatorCallExpr(hasOperatorName("=")))));
    match_finder_.addMatcher(call_expr, &potential_nodes_);

    // Map function declaration signature to function definition signature;
    // This is problematic in the case of callbacks defined in function.
    // These two matchers are dispatched to `fct_sig_nodes_` instead of
    // `potential_nodes_`: one match per matching parameter, and one match for
    // a pointer return type. Both bind the function to "fct_decl".
    auto fct_decls_params =
        traverse(clang::TK_IgnoreUnlessSpelledInSource,
                 functionDecl(forEachParmVarDecl(rhs_param), unless(exclusions))
                     .bind("fct_decl"));
    match_finder_.addMatcher(fct_decls_params, &fct_sig_nodes_);

    auto fct_decls_returns = traverse(
        clang::TK_IgnoreUnlessSpelledInSource,
        functionDecl(hasReturnTypeLoc(pointerTypeLoc().bind("rhs_type_loc")),
                     unless(exclusions))
            .bind("fct_decl"));
    match_finder_.addMatcher(fct_decls_returns, &fct_sig_nodes_);
  }

 private:
  // Finder on which the matchers are registered. Not owned.
  MatchFinder& match_finder_;
  // Callback receiving all the matches except the function signature ones.
  PotentialNodes potential_nodes_;
  // Callback receiving the function signature matches ("fct_decl").
  FunctionSignatureNodes fct_sig_nodes_;
};

}  // namespace

int main(int argc, const char* argv[]) {
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmParser();
  llvm::cl::OptionCategory category(
      "spanifier: changes"
      " 1- |T* var| to |base::span<T> var|."
      " 2- |raw_ptr<T> var| to |base::raw_span<T> var|");

  llvm::Expected<clang::tooling::CommonOptionsParser> options =
      clang::tooling::CommonOptionsParser::create(argc, argv, category);
  assert(static_cast<bool>(options));  // Should not return an error.
  clang::tooling::ClangTool tool(options->getCompilations(),
                                 options->getSourcePathList());

  // Map a function signature, which is modeled as a string representing file
  // location, to it's graph nodes (RTNode and ParmVarDecl nodes).
  // RTNode represents a function return type.
  std::map<std::string, std::set<Node>> fct_sig_nodes;
  // Map related function signatures to each other, this is needed for functions
  // with separate definition and declaration, and for overridden functions.
  std::vector<std::pair<std::string, std::string>> fct_sig_pairs;
  OutputHelper output_helper;
  MatchFinder match_finder;
  Spanifier rewriter(match_finder, output_helper, fct_sig_nodes, fct_sig_pairs);
  rewriter.addMatchers();

  // Prepare and run the tool.
  std::unique_ptr<clang::tooling::FrontendActionFactory> factory =
      clang::tooling::newFrontendActionFactory(&match_finder);
  int result = tool.run(factory.get());

  // Establish connections between corresponding parameters of adjacent function
  // signatures. Two functions are considered adjacent if one overrides the
  // other or if one is a function declaration while the other is its
  // corresponding definition.
  for (auto& [l, r] : fct_sig_pairs) {
    // By construction, only the left side of the pair is guaranteed to have a
    // matching set of nodes.
    assert(fct_sig_nodes.find(l) != fct_sig_nodes.end());

    // TODO(356666773): Handle the case where both side of the pair haven't
    // been matched. This happens when a function is declared in third_party/,
    // but implemented in first party.
    if (fct_sig_nodes.find(r) == fct_sig_nodes.end()) {
      continue;
    }

    auto& s1 = fct_sig_nodes[l];
    auto& s2 = fct_sig_nodes[r];
    assert(s1.size() == s2.size());
    auto i1 = s1.begin();
    auto i2 = s2.begin();
    while (i1 != s1.end()) {
      output_helper.AddEdge(*i1, *i2);
      output_helper.AddEdge(*i2, *i1);
      i1++;
      i2++;
    }
  }

  // Emits the list of edges.
  output_helper.Emit();
  return result;
}