chromium/tools/clang/base_bind_rewriters/BaseBindRewriters.cpp

// Copyright 2016 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <assert.h>
#include <stdlib.h>
#include <algorithm>
#include <memory>
#include <string>

#include "clang/AST/ASTContext.h"
#include "clang/AST/ParentMap.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/ASTMatchers/ASTMatchersMacros.h"
#include "clang/Analysis/CFG.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Lex/Lexer.h"
#include "clang/Tooling/CommonOptionsParser.h"
#include "clang/Tooling/Refactoring.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/TargetSelect.h"

using Replacements = std::vector<clang::tooling::Replacement>;
using clang::ASTContext;
using clang::CFG;
using clang::CFGBlock;
using clang::CFGLifetimeEnds;
using clang::CFGStmt;
using clang::CallExpr;
using clang::Decl;
using clang::DeclRefExpr;
using clang::FunctionDecl;
using clang::LambdaExpr;
using clang::Stmt;
using clang::UnaryOperator;
using clang::ast_type_traits::DynTypedNode;
using clang::tooling::CommonOptionsParser;
using namespace clang::ast_matchers;

namespace {

class Rewriter {
 public:
  virtual ~Rewriter() {}
};

// Removes unneeded base::Passed() on a parameter of base::BindOnce().
// Example:
//   // Before
//   base::BindOnce(&Foo, base::Passed(&bar));
//   base::BindOnce(&Foo, base::Passed(std::move(baz)));
//   base::BindOnce(&Foo, base::Passed(qux));
//
//   // After
//   base::BindOnce(&Foo, std::move(bar));
//   base::BindOnce(&Foo, std::move(baz));
//   base::BindOnce(&Foo, std::move(*qux));
class PassedToMoveRewriter : public MatchFinder::MatchCallback,
                             public Rewriter {
 public:
  explicit PassedToMoveRewriter(Replacements* replacements)
      : replacements_(replacements) {}

  StatementMatcher GetMatcher() {
    auto is_passed = namedDecl(hasName("::base::Passed"));
    auto is_bind_once_call = callee(namedDecl(hasName("::base::BindOnce")));

    // Matches base::Passed() call on a base::BindOnce() argument.
    return callExpr(is_bind_once_call,
                    hasAnyArgument(ignoringImplicit(
                        callExpr(callee(is_passed)).bind("target"))));
  }

  void run(const MatchFinder::MatchResult& result) override {
    auto* target = result.Nodes.getNodeAs<CallExpr>("target");
    auto* callee = target->getCallee()->IgnoreImpCasts();

    auto* callee_decl = clang::dyn_cast<DeclRefExpr>(callee)->getDecl();
    auto* passed_decl = clang::dyn_cast<FunctionDecl>(callee_decl);
    auto* param_type = passed_decl->getParamDecl(0)->getType().getTypePtr();

    if (param_type->isRValueReferenceType()) {
      // base::Passed(xxx) -> xxx.
      // The parameter type is already an rvalue reference.
      // Example:
      //   std::unique_ptr<int> foo();
      //   std::unique_ptr<int> bar;
      //   base::Passed(foo());
      //   base::Passed(std::move(bar));
      // In these cases, we can just remove base::Passed.
      auto left = clang::CharSourceRange::getTokenRange(
          result.SourceManager->getSpellingLoc(target->getBeginLoc()),
          result.SourceManager->getSpellingLoc(target->getArg(0)->getExprLoc())
              .getLocWithOffset(-1));
      auto r_paren = clang::CharSourceRange::getTokenRange(
          result.SourceManager->getSpellingLoc(target->getRParenLoc()),
          result.SourceManager->getSpellingLoc(target->getRParenLoc()));
      replacements_->emplace_back(*result.SourceManager, left, " ");
      replacements_->emplace_back(*result.SourceManager, r_paren, " ");
      return;
    }

    if (!param_type->isPointerType())
      return;

    auto* passed_arg = target->getArg(0)->IgnoreImpCasts();
    if (auto* unary = clang::dyn_cast<clang::UnaryOperator>(passed_arg)) {
      if (unary->getOpcode() == clang::UO_AddrOf) {
        // base::Passed(&xxx) -> std::move(xxx).
        auto left = clang::CharSourceRange::getTokenRange(
            result.SourceManager->getSpellingLoc(target->getBeginLoc()),
            result.SourceManager->getSpellingLoc(
                target->getArg(0)->getExprLoc()));
        replacements_->emplace_back(*result.SourceManager, left, "std::move(");
        return;
      }
    }

    // base::Passed(xxx) -> std::move(*xxx)
    auto left = clang::CharSourceRange::getTokenRange(
        result.SourceManager->getSpellingLoc(target->getBeginLoc()),
        result.SourceManager->getSpellingLoc(target->getArg(0)->getExprLoc())
            .getLocWithOffset(-1));
    replacements_->emplace_back(*result.SourceManager, left, "std::move(*");
  }

 private:
  Replacements* replacements_;
};

// Replace base::Bind() and base::BindRepeating() to base::BindOnce() where
// resulting callbacks are implicitly converted into base::OnceCallback.
// Example:
//   // Before
//   base::PostTask(FROM_HERE, base::BindRepeating(&Foo));
//   base::OnceCallback<void()> cb = base::BindRepeating(&Foo);
//
//   // After
//   base::PostTask(FROM_HERE, base::BindOnce(&Foo));
//   base::OnceCallback<void()> cb = base::BindOnce(&Foo);
class BindOnceRewriter : public MatchFinder::MatchCallback, public Rewriter {
 public:
  explicit BindOnceRewriter(Replacements* replacements)
      : replacements_(replacements) {}

  StatementMatcher GetMatcher() {
    auto is_once_callback = hasType(hasCanonicalType(hasDeclaration(
        classTemplateSpecializationDecl(hasName("::base::OnceCallback")))));
    auto is_repeating_callback =
        hasType(hasCanonicalType(hasDeclaration(classTemplateSpecializationDecl(
            hasName("::base::RepeatingCallback")))));

    auto bind_call =
        callExpr(callee(namedDecl(anyOf(hasName("::base::Bind"),
                                        hasName("::base::BindRepeating")))))
            .bind("target");
    auto parameter_construction =
        cxxConstructExpr(is_repeating_callback, argumentCountIs(1),
                         hasArgument(0, ignoringImplicit(bind_call)));
    auto constructor_conversion = cxxConstructExpr(
        is_once_callback, argumentCountIs(1),
        hasArgument(0, ignoringImplicit(parameter_construction)));
    return implicitCastExpr(is_once_callback,
                            hasSourceExpression(constructor_conversion));
  }

  void run(const MatchFinder::MatchResult& result) override {
    auto* target = result.Nodes.getNodeAs<clang::CallExpr>("target");
    auto* callee = target->getCallee();
    auto range = clang::CharSourceRange::getTokenRange(
        result.SourceManager->getSpellingLoc(callee->getEndLoc()),
        result.SourceManager->getSpellingLoc(callee->getEndLoc()));
    replacements_->emplace_back(*result.SourceManager, range, "BindOnce");
  }

 private:
  Replacements* replacements_;
};

// Converts pass-by-const-ref base::RepeatingCallback's to
// pass-by-value. Example:
//   // Before
//   using BarCallback = base::RepeatingCallback<void(void*)>;
//   void Foo(const base::RepeatingCallback<void(int)>& cb);
//   void Bar(const BarCallback& cb);
//
//   // After
//   using BarCallback = base::RepeatingCallback<void(void*)>;
//   void Foo(base::RepeatingCallback<void(int)> cb);
//   void Bar(BarCallback cb);
class PassByValueRewriter : public MatchFinder::MatchCallback, public Rewriter {
 public:
  explicit PassByValueRewriter(Replacements* replacements)
      : replacements_(replacements) {}

  DeclarationMatcher GetMatcher() {
    auto is_repeating_callback =
        namedDecl(hasName("::base::RepeatingCallback"));
    return parmVarDecl(
               hasType(hasCanonicalType(references(is_repeating_callback))))
        .bind("target");
  }

  void run(const MatchFinder::MatchResult& result) override {
    auto* target = result.Nodes.getNodeAs<clang::ParmVarDecl>("target");
    auto qual_type = target->getType();
    auto* ref_type =
        clang::dyn_cast<clang::LValueReferenceType>(qual_type.getTypePtr());
    if (!ref_type || !ref_type->getPointeeType().isLocalConstQualified())
      return;

    // Remove the leading `const` and the following `&`.
    auto type_loc = target->getTypeSourceInfo()->getTypeLoc();
    auto const_keyword = clang::CharSourceRange::getTokenRange(
        result.SourceManager->getSpellingLoc(target->getBeginLoc()),
        result.SourceManager->getSpellingLoc(target->getBeginLoc()));
    auto lvalue_ref = clang::CharSourceRange::getTokenRange(
        result.SourceManager->getSpellingLoc(type_loc.getEndLoc()),
        result.SourceManager->getSpellingLoc(type_loc.getEndLoc()));
    replacements_->emplace_back(*result.SourceManager, const_keyword, " ");
    replacements_->emplace_back(*result.SourceManager, lvalue_ref, " ");
  }

 private:
  Replacements* replacements_;
};

// Adds std::move() to base::RepeatingCallback<> where it looks relevant.
// Example:
//   // Before
//   void Foo(base::RepeatingCallback<void(int)> cb1) {
//     base::RepeatingClosure cb2 = base::BindRepeating(cb1, 42);
//     PostTask(FROM_HERE, cb2);
//   }
//
//   // After
//   void Foo(base::RepeatingCallback<void(int)> cb1) {
//     base::RepeatingClosure cb2 = base::BindRepeating(std::move(cb1), 42);
//     PostTask(FROM_HERE, std::move(cb2));
//   }
class AddStdMoveRewriter : public MatchFinder::MatchCallback, public Rewriter {
 public:
  explicit AddStdMoveRewriter(Replacements* replacements)
      : replacements_(replacements) {}

  StatementMatcher GetMatcher() {
    return declRefExpr(
               hasType(hasCanonicalType(hasDeclaration(
                   namedDecl(hasName("::base::RepeatingCallback"))))),
               anyOf(hasAncestor(cxxConstructorDecl().bind("enclosing_ctor")),
                     hasAncestor(functionDecl().bind("enclosing_func")),
                     hasAncestor(lambdaExpr().bind("enclosing_lambda"))))
        .bind("target");
  }

  // Build Control Flow Graph (CFG) for |stmt| and populate class members with
  // the content of the graph. Returns true if the analysis finished
  // successfully.
  bool ExtractCFGContentToMembers(Stmt* stmt, ASTContext* context) {
    // Try to make a cache entry. The failure implies it's already in the cache.
    auto inserted = cfg_cache_.emplace(stmt, nullptr);
    if (!inserted.second)
      return !!inserted.first->second;

    std::unique_ptr<CFG>& cfg = inserted.first->second;
    CFG::BuildOptions opts;
    opts.AddInitializers = true;
    opts.AddLifetime = true;
    opts.AddStaticInitBranches = true;
    cfg = CFG::buildCFG(nullptr, stmt, context, opts);

    // CFG construction may fail. Report it to the caller.
    if (!cfg)
      return false;
    if (!parent_map_)
      parent_map_ = std::make_unique<clang::ParentMap>(stmt);
    else
      parent_map_->addStmt(stmt);

    // Populate |top_stmts_|, that contains Stmts that is evaluated in its own
    // CFGElement.
    for (auto* block : *cfg) {
      for (auto& elem : *block) {
        if (auto stmt = elem.getAs<CFGStmt>())
          top_stmts_.insert(stmt->getStmt());
      }
    }

    // Populate |enclosing_block_|, that maps a Stmt to a CFGBlock that contains
    // the Stmt.
    std::function<void(const CFGBlock*, const Stmt*)> recursive_set_enclosing =
        [&](const CFGBlock* block, const Stmt* stmt) {
          enclosing_block_[stmt] = block;
          for (auto* c : stmt->children()) {
            if (!c)
              continue;
            if (top_stmts_.find(c) != top_stmts_.end())
              continue;
            recursive_set_enclosing(block, c);
          }
        };
    for (auto* block : *cfg) {
      for (auto& elem : *block) {
        if (auto stmt = elem.getAs<CFGStmt>())
          recursive_set_enclosing(block, stmt->getStmt());
      }
    }

    return true;
  }

  const Stmt* EnclosingCxxStatement(const Stmt* stmt) {
    while (true) {
      const Stmt* parent = parent_map_->getParentIgnoreParenCasts(stmt);
      assert(parent);
      switch (parent->getStmtClass()) {
        case Stmt::CompoundStmtClass:
        case Stmt::ForStmtClass:
        case Stmt::CXXForRangeStmtClass:
        case Stmt::WhileStmtClass:
        case Stmt::DoStmtClass:
        case Stmt::IfStmtClass:

          // Other candidates:
          //   Stmt::CXXTryStmtClass
          //   Stmt::CXXCatchStmtClass
          //   Stmt::CapturedStmtClass
          //   Stmt::SwitchStmtClass
          //   Stmt::SwitchCaseClass
          return stmt;
        default:
          stmt = parent;
          break;
      }
    }
  }

  bool WasPointerTaken(const Stmt* stmt, const Decl* decl) {
    std::function<bool(const Stmt*)> visit_stmt = [&](const Stmt* stmt) {
      if (auto* op = clang::dyn_cast<UnaryOperator>(stmt)) {
        if (op->getOpcode() == clang::UO_AddrOf) {
          auto* ref = clang::dyn_cast<DeclRefExpr>(op->getSubExpr());
          // |ref| may be null if the sub-expr has a dependent type.
          if (ref && ref->getDecl() == decl)
            return true;
        }
      }

      for (auto* c : stmt->children()) {
        if (!c)
          continue;
        if (visit_stmt(c))
          return true;
      }
      return false;
    };
    return visit_stmt(stmt);
  }

  bool HasCapturingLambda(const Stmt* stmt, const Decl* decl) {
    std::function<bool(const Stmt*)> visit_stmt = [&](const Stmt* stmt) {
      if (auto* l = clang::dyn_cast<LambdaExpr>(stmt)) {
        for (auto c : l->captures()) {
          if (c.getCapturedVar() == decl)
            return true;
        }
      }

      for (auto* c : stmt->children()) {
        if (!c)
          continue;

        if (visit_stmt(c))
          return true;
      }

      return false;
    };
    return visit_stmt(stmt);
  }

  // Returns true if there are multiple occurrences to |decl| in one of C++
  // statements in |stmt|.
  bool HasUnorderedOccurrences(const Decl* decl, const Stmt* stmt) {
    int count = 0;
    std::function<void(const Stmt*)> visit_stmt = [&](const Stmt* s) {
      if (auto* ref = clang::dyn_cast<DeclRefExpr>(s)) {
        if (ref->getDecl() == decl)
          ++count;
      }
      for (auto* c : s->children()) {
        if (!c)
          continue;
        visit_stmt(c);
      }
    };

    visit_stmt(EnclosingCxxStatement(stmt));
    return count > 1;
  }

  void run(const MatchFinder::MatchResult& result) override {
    auto* target = result.Nodes.getNodeAs<clang::DeclRefExpr>("target");
    auto* decl = clang::dyn_cast<clang::VarDecl>(target->getDecl());

    // Other than local variables and parameters are out-of-scope.
    if (!decl || !decl->isLocalVarDeclOrParm())
      return;

    auto qual_type = decl->getType();
    // Qualified variables are out-of-scope. They are likely not movable.
    if (qual_type.getCanonicalType().hasQualifiers())
      return;

    auto* type = qual_type.getTypePtr();
    // References and pointers are out-of-scope.
    if (type->isReferenceType() || type->isPointerType())
      return;

    Stmt* body = nullptr;
    if (auto* ctor = result.Nodes.getNodeAs<LambdaExpr>("enclosing_ctor"))
      return;  // Skip constructor case for now. TBD.
    else if (auto* func =
                 result.Nodes.getNodeAs<FunctionDecl>("enclosing_func"))
      body = func->getBody();
    else if (auto* lambda =
                 result.Nodes.getNodeAs<LambdaExpr>("enclosing_lambda"))
      body = lambda->getBody();
    else
      return;

    // Disable the replacement if there is a lambda that captures |decl|.
    if (HasCapturingLambda(body, decl))
      return;

    // Disable the replacement if the pointer to |decl| is taken in the scope.
    if (WasPointerTaken(body, decl))
      return;

    if (!ExtractCFGContentToMembers(body, result.Context))
      return;

    auto* parent = parent_map_->getParentIgnoreParenCasts(target);
    if (auto* p = clang::dyn_cast<CallExpr>(parent)) {
      auto* callee = p->getCalleeDecl();
      // |callee| may be null if the CallExpr has an unresolved look up.
      if (!callee)
        return;
      auto* callee_decl = clang::dyn_cast<clang::NamedDecl>(callee);
      auto name = callee_decl->getQualifiedNameAsString();

      // Disable the replacement if it's already in std::move() or
      // std::forward().
      if (name == "std::__1::move" || name == "std::__1::forward")
        return;
    } else if (parent->getStmtClass() == Stmt::ReturnStmtClass) {
      // Disable the replacement if it's in a return statement.
      return;
    }

    // If the same C++ statement contains multiple reference to the variable,
    // don't insert std::move() to be conservative.
    if (HasUnorderedOccurrences(decl, target))
      return;

    bool saw_reuse = false;
    ForEachFollowingStmts(target, [&](const Stmt* stmt) {
      if (auto* ref = clang::dyn_cast<DeclRefExpr>(stmt)) {
        if (ref->getDecl() == decl) {
          saw_reuse = true;
          return false;
        }
      }

      // TODO: Detect Reset() and operator=() to stop the traversal.
      return true;
    });
    if (saw_reuse)
      return;

    replacements_->emplace_back(
        *result.SourceManager,
        result.SourceManager->getSpellingLoc(target->getBeginLoc()), 0,
        "std::move(");
    replacements_->emplace_back(
        *result.SourceManager,
        clang::Lexer::getLocForEndOfToken(target->getEndLoc(), 0,
                                          *result.SourceManager,
                                          result.Context->getLangOpts()),
        0, ")");
  }

  // Invokes |handler| for each Stmt that follows |target| until it reaches the
  // end of the lifetime of the variable that |target| references.
  // If |handler| returns false, stops following the current control flow.
  void ForEachFollowingStmts(const DeclRefExpr* target,
                             std::function<bool(const Stmt*)> handler) {
    auto* decl = target->getDecl();
    auto* block = enclosing_block_[target];

    std::set<const clang::CFGBlock*> visited;
    std::vector<const clang::CFGBlock*> stack = {block};

    bool saw_target = false;
    std::function<bool(const Stmt*)> visit_stmt = [&](const Stmt* s) {
      for (auto* t : s->children()) {
        if (!t)
          continue;

        // |t| is evaluated elsewhere if a sub-Stmt is in |top_stmt_|.
        if (top_stmts_.find(t) != top_stmts_.end())
          continue;

        if (!visit_stmt(t))
          return false;
      }

      if (!saw_target) {
        if (s == target)
          saw_target = true;
        return true;
      }

      return handler(s);
    };

    bool visited_initial_block_twice = false;
    while (!stack.empty()) {
      auto* b = stack.back();
      stack.pop_back();
      if (!visited.insert(b).second) {
        if (b != block || visited_initial_block_twice)
          continue;
        visited_initial_block_twice = true;
      }

      bool cont = true;
      for (auto e : *b) {
        if (auto s = e.getAs<CFGStmt>()) {
          if (!visit_stmt(s->getStmt())) {
            cont = false;
            break;
          }
        } else if (auto l = e.getAs<CFGLifetimeEnds>()) {
          if (l->getVarDecl() == decl) {
            cont = false;
            break;
          }
        }
      }

      if (cont) {
        for (auto s : b->succs()) {
          if (!s)
            continue;  // Unreachable block.
          stack.push_back(s);
        }
      }
    }
  }

 private:
  // Function body to CFG.
  std::map<const Stmt*, std::unique_ptr<CFG>> cfg_cache_;

  // Statement to the enclosing CFGBlock.
  std::map<const Stmt*, const CFGBlock*> enclosing_block_;

  // Stmt to its parent Stmt.
  std::unique_ptr<clang::ParentMap> parent_map_;

  // A set of Stmt that a CFGElement has it directly.
  std::set<const Stmt*> top_stmts_;

  Replacements* replacements_;
};

llvm::cl::extrahelp common_help(CommonOptionsParser::HelpMessage);
llvm::cl::OptionCategory rewriter_category("Rewriter Options");

llvm::cl::opt<std::string> rewriter_option(
    "rewriter",
    llvm::cl::desc(R"(One of the name of rewriter to apply.
Available rewriters are:
    remove_unneeded_passed
    bind_to_bind_once
    pass_by_value
    add_std_move
The default is remove_unneeded_passed.
)"),
    llvm::cl::init("remove_unneeded_passed"),
    llvm::cl::cat(rewriter_category));

}  // namespace.

int main(int argc, const char* argv[]) {
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmParser();
  CommonOptionsParser options(argc, argv, rewriter_category);
  clang::tooling::ClangTool tool(options.getCompilations(),
                                 options.getSourcePathList());

  MatchFinder match_finder;
  std::vector<clang::tooling::Replacement> replacements;

  std::unique_ptr<Rewriter> rewriter;
  if (rewriter_option == "remove_unneeded_passed") {
    auto passed_to_move = std::make_unique<PassedToMoveRewriter>(&replacements);
    match_finder.addMatcher(passed_to_move->GetMatcher(), passed_to_move.get());
    rewriter = std::move(passed_to_move);
  } else if (rewriter_option == "bind_to_bind_once") {
    auto bind_once = std::make_unique<BindOnceRewriter>(&replacements);
    match_finder.addMatcher(bind_once->GetMatcher(), bind_once.get());
    rewriter = std::move(bind_once);
  } else if (rewriter_option == "pass_by_value") {
    auto pass_by_value = std::make_unique<PassByValueRewriter>(&replacements);
    match_finder.addMatcher(pass_by_value->GetMatcher(), pass_by_value.get());
    rewriter = std::move(pass_by_value);
  } else if (rewriter_option == "add_std_move") {
    auto add_std_move = std::make_unique<AddStdMoveRewriter>(&replacements);
    match_finder.addMatcher(add_std_move->GetMatcher(), add_std_move.get());
    rewriter = std::move(add_std_move);
  } else {
    abort();
  }

  std::unique_ptr<clang::tooling::FrontendActionFactory> factory =
      clang::tooling::newFrontendActionFactory(&match_finder);
  int result = tool.run(factory.get());
  if (result != 0)
    return result;

  if (replacements.empty())
    return 0;

  // Serialization format is documented in tools/clang/scripts/run_tool.py
  llvm::outs() << "==== BEGIN EDITS ====\n";
  for (const auto& r : replacements) {
    std::string replacement_text = r.getReplacementText().str();
    std::replace(replacement_text.begin(), replacement_text.end(), '\n', '\0');
    llvm::outs() << "r:::" << r.getFilePath() << ":::" << r.getOffset()
                 << ":::" << r.getLength() << ":::" << replacement_text << "\n";
  }
  llvm::outs() << "==== END EDITS ====\n";

  return 0;
}