chromium/tools/clang/rewrite_autofill_form_data/FieldToFunction.cpp

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// Clang tool that rewrites accesses to data members into calls to getters or
// setters.
//
// For example, if `foo` and `bar` are of a type whose `member` is supposed to
// be rewritten, the rewriter replaces `foo.member = bar.member` with
// `foo.set_member(bar.member())`.
//
// In particular, this rewriter is used to convert the structs
// `autofill::FormData` and `autofill::FormFieldData` to classes.

#include <algorithm>
#include <cassert>
#include <iostream>
#include <memory>
#include <regex>
#include <string>
#include <string_view>

#include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/ASTMatchers/ASTMatchersMacros.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Lex/Lexer.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Core/Replacement.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/TargetSelect.h"

// Debugging helper: prints a clang::SourceLocation or clang::SourceRange to
// `llvm::errs()`, prefixed with the file and line of the LOG() invocation and
// the stringified argument expression.
//
// The expansion refers to a variable named `result` and dereferences its
// `SourceManager` member, so the macro can only be used where such a variable
// is in scope (e.g. inside the `run()` callbacks below, whose parameter is
// called `result`). Note that the expansion already ends with a semicolon.
#define LOG(e)                                                     \
  llvm::errs() << __FILE__ << ":" << __LINE__ << ": " << #e << " " \
               << (e).printToString(*result.SourceManager) << '\n';

namespace {

// Extra `--help` text: the generic CommonOptionsParser help message followed
// by a short description of this tool.
llvm::cl::extrahelp common_help(
    clang::tooling::CommonOptionsParser::HelpMessage);
llvm::cl::extrahelp more_help(
    "The rewriter turns references to ClassOfInterest::field_of_interest into\n"
    "getter and setter calls.");

// Option category that the tool-specific options below belong to; it is also
// passed to CommonOptionsParser::create() in main().
llvm::cl::OptionCategory rewriter_category("Rewriter Options");

// --classes-of-interest: comma-separated list consumed by
// matchers::IsClassOfInterest(). Defaults to "::autofill::FormFieldData".
// The list is split on "," verbatim, i.e. without trimming whitespace.
llvm::cl::opt<std::string> classes_of_interest_option(
    "classes-of-interest",
    llvm::cl::desc("Comma-separated fully qualified names of the classes whose "
                   "field are to be rewritten"),
    llvm::cl::init("::autofill::FormFieldData"),
    llvm::cl::cat(rewriter_category));

// --fields-of-interest: comma-separated list consumed by
// matchers::IsFieldOfInterest(). Defaults to "value". The list is split on
// "," verbatim, i.e. without trimming whitespace.
llvm::cl::opt<std::string> fields_of_interest_option(
    "fields-of-interest",
    llvm::cl::desc("Comma-separated names of data members of any class of "
                   "interest that are to be rewritten"),
    llvm::cl::init("value"),
    llvm::cl::cat(rewriter_category));

// Generates substitution directives according to the format documented in
// tools/clang/scripts/run_tool.py.
//
// We do not use `clang::tooling::Replacements` because we don't need any
// buffering, and we'd need to implement the serialization of
// `clang::tooling::Replacement` anyway.
class OutputHelper : public clang::tooling::SourceFileCallbacks {
 public:
  OutputHelper() = default;
  ~OutputHelper() = default;

  OutputHelper(const OutputHelper&) = delete;
  OutputHelper& operator=(const OutputHelper&) = delete;

  // Deletes `replacement_range`.
  void Delete(const clang::CharSourceRange& replacement_range,
              const clang::SourceManager& source_manager,
              const clang::LangOptions& lang_opts) {
    Replace(replacement_range, "", source_manager, lang_opts);
  }

  // Replaces `replacement_range` with `replacement_text`.
  void Replace(const clang::CharSourceRange& replacement_range,
               std::string replacement_text,
               const clang::SourceManager& source_manager,
               const clang::LangOptions& lang_opts) {
    clang::tooling::Replacement replacement(source_manager, replacement_range,
                                            replacement_text, lang_opts);
    llvm::StringRef file_path = replacement.getFilePath();
    if (file_path.empty()) {
      return;
    }
    std::replace(replacement_text.begin(), replacement_text.end(), '\n', '\0');
    Add(file_path, replacement.getOffset(), replacement.getLength(),
        replacement_text);
  }

  // Inserts `lhs` and `rhs` to the left and right of `replacement_range`.
  void Wrap(const clang::CharSourceRange& replacement_range,
            std::string_view lhs,
            std::string_view rhs,
            const clang::SourceManager& source_manager,
            const clang::LangOptions& lang_opts) {
    clang::tooling::Replacement replacement(source_manager, replacement_range,
                                            "", lang_opts);
    llvm::StringRef file_path = replacement.getFilePath();
    if (file_path.empty()) {
      return;
    }
    Add(file_path, replacement.getOffset(), 0, lhs);
    Add(file_path, replacement.getOffset() + replacement.getLength(), 0, rhs);
  }

 private:
  // clang::tooling::SourceFileCallbacks:
  bool handleBeginSource(clang::CompilerInstance& compiler) override {
    const clang::FrontendOptions& frontend_options = compiler.getFrontendOpts();

    assert((frontend_options.Inputs.size() == 1) &&
           "run_tool.py should invoke the rewriter one file at a time");
    const clang::FrontendInputFile& input_file = frontend_options.Inputs[0];
    assert(input_file.isFile() &&
           "run_tool.py should invoke the rewriter on actual files");

    current_language_ = input_file.getKind().getLanguage();

    if (ShouldOutput()) {
      llvm::outs() << "==== BEGIN EDITS ====\n";
    }
    return true;  // Report that |handleBeginSource| succeeded.
  }

  void handleEndSource() override {
    if (ShouldOutput()) {
      llvm::outs() << "==== END EDITS ====\n";
    }
  }

  void Add(llvm::StringRef file_path,
           unsigned offset,
           unsigned length,
           std::string_view replacement_text) {
    if (ShouldOutput()) {
      llvm::outs() << "r:::" << file_path << ":::" << offset << ":::" << length
                   << ":::" << replacement_text << "\n";
    }
  }

  bool ShouldOutput() {
    switch (current_language_) {
      case clang::Language::CXX:
      case clang::Language::OpenCLCXX:
      case clang::Language::ObjCXX:
        return true;
      case clang::Language::Unknown:
      case clang::Language::Asm:
      case clang::Language::LLVM_IR:
      case clang::Language::CIR:
      case clang::Language::OpenCL:
      case clang::Language::CUDA:
      case clang::Language::RenderScript:
      case clang::Language::HIP:
      case clang::Language::HLSL:
      case clang::Language::C:
      case clang::Language::ObjC:
        return false;
    }
    assert(false && "Unrecognized clang::Language");
    return false;
  }

  clang::Language current_language_ = clang::Language::Unknown;
};

namespace matchers {

auto IsClassOfInterest() {
  using namespace clang::ast_matchers;
  static auto names = [] {
    llvm::SmallVector<llvm::StringRef> names;
    llvm::StringRef(classes_of_interest_option).split(names, ",");
    return names;
  }();
  return cxxRecordDecl(hasAnyName(names));
}

auto IsFieldOfInterest() {
  using namespace clang::ast_matchers;
  static auto names = [] {
    llvm::SmallVector<llvm::StringRef> names;
    llvm::StringRef(fields_of_interest_option).split(names, ",");
    return names;
  }();
  return fieldDecl(hasAnyName(names));
}

// Matches `object.member` and `object->member` where `member` is a
// field-of-interest whose parent declaration is a class-of-interest. The
// matched expression is bound to "member_expr".
auto IsMemberExprOfInterest() {
  using namespace clang::ast_matchers;
  // The field's name is of interest, and so is the class declaring it.
  auto field_of_class_of_interest =
      allOf(IsFieldOfInterest(), fieldDecl(hasParent(IsClassOfInterest())));
  // Avoids matching memberExprs in defaulted constructors and operators.
  auto not_in_defaulted_function =
      unless(hasAncestor(functionDecl(isDefaulted())));
  return memberExpr(
             allOf(member(field_of_class_of_interest),
                   not_in_defaulted_function))
      .bind("member_expr");
}

// Matches `f.member = foo`, i.e. an `=` operation whose left-hand side is
// (modulo parentheses and implicit casts) a member-expression of interest.
//
// Bindings: "assignment" is the whole operation, "member_expr" the left-hand
// side and "rhs" the right-hand side.
auto IsWriteExprOfInterest() {
  using namespace clang::ast_matchers;
  auto writes_member =
      hasLHS(ignoringParenImpCasts(IsMemberExprOfInterest()));
  auto binds_assigned_value = hasRHS(expr().bind("rhs"));
  return expr(binaryOperation(hasOperatorName("="), writes_member,
                              binds_assigned_value))
      .bind("assignment");
}

// Matches `f.member` and `f->member`, except on the left-hand side of
// assignments.
auto IsReadExprOfInterest() {
  using namespace clang::ast_matchers;
  return expr(
      anyOf(allOf(IsMemberExprOfInterest(),
                  unless(hasAncestor(IsWriteExprOfInterest()))),
            expr(binaryOperation(
                hasOperatorName("="),
                hasRHS(anyOf(IsMemberExprOfInterest(),
                             hasDescendant(IsMemberExprOfInterest())))))));
}

}  // namespace matchers

// Common base of all rewriters: holds a reference to the OutputHelper through
// which the derived classes emit their edits.
class BaseRewriter {
 public:
  // `output_helper` must not be null and must outlive this object; only a
  // reference to it is stored.
  explicit BaseRewriter(OutputHelper* output_helper)
      : output_helper_((assert(output_helper && "output_helper must be set"),
                        *output_helper)) {}

 protected:
  OutputHelper& output_helper_;
};

// Replaces `object.member` with `object.member()`.
//
// Only the `member` token is touched: it is replaced with the name of the
// referenced field followed by `()`.
class FieldReadAccessRewriter
    : public BaseRewriter,
      public clang::ast_matchers::MatchFinder::MatchCallback {
 public:
  using BaseRewriter::BaseRewriter;

  // Registers this rewriter for every read access matched by
  // matchers::IsReadExprOfInterest().
  void AddMatchers(clang::ast_matchers::MatchFinder& match_finder) {
    match_finder.addMatcher(matchers::IsReadExprOfInterest(), this);
  }

 private:
  void run(
      const clang::ast_matchers::MatchFinder::MatchResult& result) override {
    // object.member
    // ^-----------^  "member_expr"
    const auto* member_expr =
        result.Nodes.getNodeAs<clang::MemberExpr>("member_expr");
    assert(member_expr);
    clang::CharSourceRange range = clang::CharSourceRange::getTokenRange(
        clang::SourceRange(member_expr->getMemberLoc()));
    // Build the replacement from the declared name of the field instead of
    // from the source text of `range`: clang::Lexer::getSourceText() returns
    // an empty string for ranges it cannot map to a file range (e.g. a token
    // that stems from a macro body), which would replace `member` with a bare
    // `()`.
    std::string replacement_text =
        member_expr->getMemberDecl()->getNameAsString() + "()";
    output_helper_.Replace(range, replacement_text, *result.SourceManager,
                           result.Context->getLangOpts());
  }
};

// Replaces `object.member = something` with `object.set_member(something)`
// (modulo whitespace).
//
// The rewrite is split into three small edits: `set_` is inserted in front of
// `member`, the `=` token is deleted, and `something` is wrapped in
// parentheses. Keeping the edits this small avoids overlapping substitutions,
// because `something` might need substitutions on its own.
class FieldWriteAccessRewriter
    : public BaseRewriter,
      public clang::ast_matchers::MatchFinder::MatchCallback {
 public:
  using BaseRewriter::BaseRewriter;

  // Registers this rewriter for every assignment matched by
  // matchers::IsWriteExprOfInterest().
  void AddMatchers(clang::ast_matchers::MatchFinder& match_finder) {
    match_finder.addMatcher(matchers::IsWriteExprOfInterest(), this);
  }

 private:
  void run(
      const clang::ast_matchers::MatchFinder::MatchResult& result) override {
    const clang::SourceManager& source_manager = *result.SourceManager;
    const clang::LangOptions& lang_opts = result.Context->getLangOpts();

    // object.member = something;
    // ^-----------^              "member_expr"
    //                 ^-------^  "rhs"
    // ^-----------------------^  "assignment"
    const auto* assignment = result.Nodes.getNodeAs<clang::Expr>("assignment");
    const auto* member_expr =
        result.Nodes.getNodeAs<clang::MemberExpr>("member_expr");
    const auto* rhs = result.Nodes.getNodeAs<clang::Expr>("rhs");
    assert(assignment);
    assert(member_expr);
    assert(rhs);

    // `member` -> `set_member`.
    const clang::CharSourceRange member_token =
        clang::CharSourceRange::getTokenRange(
            clang::SourceRange(member_expr->getMemberLoc()));
    output_helper_.Wrap(member_token, "set_", "", source_manager, lang_opts);

    // `=` -> ``.
    const clang::CharSourceRange equals_token =
        clang::CharSourceRange::getTokenRange(assignment->getExprLoc());
    output_helper_.Delete(equals_token, source_manager, lang_opts);

    // `something` -> `(something)`.
    const clang::CharSourceRange assigned_value =
        clang::CharSourceRange::getTokenRange(rhs->getSourceRange());
    output_helper_.Wrap(assigned_value, "(", ")", source_manager, lang_opts);
  }
};

// Replaces `::testing::Field(..., &ClassOfInterest::field_of_interest, ...)`
// with `::testing::Property(..., &ClassOfInterest::field_of_interest, ...)`.
//
// Only the callee is edited; the arguments of the call are left as they are.
class TestingFieldRewriter
    : public BaseRewriter,
      public clang::ast_matchers::MatchFinder::MatchCallback {
 public:
  using BaseRewriter::BaseRewriter;

  void AddMatchers(clang::ast_matchers::MatchFinder& match_finder) {
    using namespace clang::ast_matchers;
    using namespace matchers;
    // Matches calls to `::testing::Field` that have
    // `&ClassOfInterest::field_of_interest` as one of their arguments.
    auto testing_field =
        callExpr(
            callee(functionDecl(hasName("::testing::Field"))),
            hasAnyArgument(expr(unaryOperator(
                hasOperatorName("&"), hasUnaryOperand(declRefExpr(to(allOf(
                                          hasParent(IsClassOfInterest()),
                                          fieldDecl(IsFieldOfInterest())))))))))
            .bind("call_expr");
    match_finder.addMatcher(testing_field, this);
  }

 private:
  void run(
      const clang::ast_matchers::MatchFinder::MatchResult& result) override {
    const auto* call_expr =
        result.Nodes.getNodeAs<clang::CallExpr>("call_expr");
    assert(call_expr);
    // Replace `::testing::Field` with `::testing::Property`.
    clang::CharSourceRange range = clang::CharSourceRange::getTokenRange(
        call_expr->getCallee()->getSourceRange());
    llvm::StringRef source_text = clang::Lexer::getSourceText(
        range, *result.SourceManager, result.Context->getLangOpts());
    if (source_text.empty()) {
      // The callee text could not be retrieved (e.g. the range cannot be
      // mapped to a file range). Emitting an edit would replace the callee
      // with nothing, so leave the call alone.
      return;
    }
    // Rename only the first stand-alone identifier `Field`. A plain substring
    // replacement would also rewrite identifiers that merely contain "Field"
    // in the callee's source range, for example `FormFieldData` inside
    // explicit template arguments.
    static const std::regex kFieldIdentifier(R"(\bField\b)");
    const std::string original_text = source_text.str();
    const std::string replacement_text =
        std::regex_replace(original_text, kFieldIdentifier, "Property",
                           std::regex_constants::format_first_only);
    if (replacement_text == original_text) {
      return;  // Nothing to rename.
    }
    output_helper_.Replace(range, replacement_text, *result.SourceManager,
                           result.Context->getLangOpts());
  }
};

class FieldToFunctionRewriter : public BaseRewriter {
 public:
  using BaseRewriter::BaseRewriter;

  void AddMatchers(clang::ast_matchers::MatchFinder& match_finder) {
    frar_.AddMatchers(match_finder);
    fwar_.AddMatchers(match_finder);
    tfr_.AddMatchers(match_finder);
  }

 private:
  FieldReadAccessRewriter frar_{&output_helper_};
  FieldWriteAccessRewriter fwar_{&output_helper_};
  TestingFieldRewriter tfr_{&output_helper_};
};

}  // namespace

// Entry point: parses the command line, registers the rewriters' matchers and
// runs them over the given source files. Edits are printed by OutputHelper.
//
// Returns 1 if the command line cannot be parsed; otherwise returns the
// result of `clang::tooling::ClangTool::run()`.
int main(int argc, const char* argv[]) {
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmParser();

  llvm::Expected<clang::tooling::CommonOptionsParser> options =
      clang::tooling::CommonOptionsParser::create(argc, argv,
                                                  rewriter_category);
  // Check the result explicitly rather than with assert(): with NDEBUG an
  // assert() disappears, and `options` would be dereferenced below even
  // though it holds an error instead of a parser.
  if (!options) {
    llvm::errs() << llvm::toString(options.takeError()) << '\n';
    return 1;
  }
  clang::tooling::ClangTool tool(options->getCompilations(),
                                 options->getSourcePathList());

  OutputHelper output_helper;
  clang::ast_matchers::MatchFinder match_finder;

  FieldToFunctionRewriter ftf_rewriter(&output_helper);
  ftf_rewriter.AddMatchers(match_finder);

  // `output_helper` doubles as the SourceFileCallbacks that prints the
  // begin/end markers around the edits of each source file.
  std::unique_ptr<clang::tooling::FrontendActionFactory> factory =
      clang::tooling::newFrontendActionFactory(&match_finder, &output_helper);
  return tool.run(factory.get());
}