chromium/tools/clang/blink_gc_plugin/CheckTraceVisitor.cpp

// Copyright 2015 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "CheckTraceVisitor.h"

#include <vector>

#include "Config.h"
#include "Edge.h"
#include "RecordInfo.h"

using namespace clang;

CheckTraceVisitor::CheckTraceVisitor(CXXMethodDecl* trace,
                                     RecordInfo* info,
                                     RecordCache* cache)
    : trace_(trace), info_(info), cache_(cache) {}

// RecursiveASTVisitor hook for member accesses. It only has an effect in
// weak-callback mode. In a regular trace method, member accesses alone are
// not evidence of tracing, and calls are analyzed by VisitCallExpr instead.
// Always returns true so the traversal continues.
bool CheckTraceVisitor::VisitMemberExpr(MemberExpr* member) {
  if (!IsWeakCallback())
    return true;

  // In weak callbacks, consider any occurrence of a field as a correct usage.
  // TODO: We really want to require that isAlive is checked on manually
  // processed weak fields.
  if (auto* accessed_field = dyn_cast<FieldDecl>(member->getMemberDecl()))
    FoundField(accessed_field);
  return true;
}

// RecursiveASTVisitor hook run on every call expression in the visited body.
// Classifies |call| and records what it traces in |info_|:
//   - Callees that are a DependentScopeDeclRefExpr (e.g. Super<T>::trace,
//     T::Trace, TraceIfNeeded<T>::trace) go to
//     CheckDependentScopeDeclRefExpr.
//   - Callees wrapped in an ImplicitCastExpr are offered to
//     CheckImplicitCastExpr (resolved TraceIfNeeded<T>::trace(visitor, &f)).
//   - Of the rest, only single-argument calls are considered:
//       * unresolved member calls may be base trace calls,
//         registerWeakMembers calls, or trace calls through a pointer base;
//       * resolved member calls may be visitor->trace(field) or
//         registerWeakMembers calls;
//       * anything else is tried as Base::trace(visitor).
// Always returns true so the traversal continues.
bool CheckTraceVisitor::VisitCallExpr(CallExpr* call) {
  // In weak callbacks we don't check calls (see VisitMemberExpr).
  if (IsWeakCallback())
    return true;

  Expr* callee = call->getCallee();

  // Trace calls from a templated derived class result in a
  // DependentScopeDeclRefExpr because the concrete trace call depends on the
  // instantiation of any shared template parameters. In this case the call is
  // "unresolved" and we resort to comparing the syntactic type names.
  if (DependentScopeDeclRefExpr* expr =
          dyn_cast<DependentScopeDeclRefExpr>(callee)) {
    CheckDependentScopeDeclRefExpr(call, expr);
    return true;
  }

  // Qualified, resolved callees appear under an implicit cast; stop here only
  // if CheckImplicitCastExpr recognized the call.
  if (ImplicitCastExpr* expr = dyn_cast<ImplicitCastExpr>(callee)) {
    if (CheckImplicitCastExpr(call, expr))
      return true;
  }

  // A tracing call will have either a |visitor| or a |m_field| argument.
  // A registerWeakMembers call will have a |this| argument.
  if (call->getNumArgs() != 1)
    return true;
  Expr* arg = call->getArg(0);

  if (UnresolvedMemberExpr* expr = dyn_cast<UnresolvedMemberExpr>(callee)) {
    // This could be a trace call of a base class, as explained in the
    // comments of CheckTraceBaseCall().
    if (CheckTraceBaseCall(call))
      return true;

    // An unresolved registerWeakMembers call cannot be inspected, so all weak
    // members are (unsoundly) treated as traced.
    if (expr->getMemberName().getAsString() == kRegisterWeakMembersName)
      MarkAllWeakMembersTraced();

    // Otherwise, when the base is a pointer to a record, let
    // CheckTraceFieldCall decide whether this is visitor->trace(field).
    QualType base = expr->getBaseType();
    if (!base->isPointerType())
      return true;
    CXXRecordDecl* decl = base->getPointeeType()->getAsCXXRecordDecl();
    if (decl)
      CheckTraceFieldCall(expr->getMemberName().getAsString(), decl, arg);
    return true;
  }

  if (CXXMemberCallExpr* expr = dyn_cast<CXXMemberCallExpr>(call)) {
    if (CheckTraceFieldMemberCall(expr) || CheckRegisterWeakMembers(expr))
      return true;
  }

  // Fallback: a qualified Base::trace(visitor) call.
  CheckTraceBaseCall(call);
  return true;
}

// Returns true if |name| equals the name of the trace method being checked.
//
// Currently, a manually dispatched class cannot have mixin bases (having
// one would add a vtable which we explicitly check against). This means
// that we can only make calls to a trace method of the same name. Revisit
// this if our mixin/vtable assumption changes.
bool CheckTraceVisitor::IsTraceCallName(const std::string& name) {
  const auto checked_method_name = trace_->getName();
  return checked_method_name == name;
}

// Returns the templated record named by the qualifier of |expr| (e.g. Super
// in Super<T>::trace), as resolved by RecordInfo::GetDependentTemplatedDecl.
// Returns nullptr if |expr| has no qualifier, or the qualifier is not a type.
CXXRecordDecl* CheckTraceVisitor::GetDependentTemplatedDecl(
    DependentScopeDeclRefExpr* expr) {
  NestedNameSpecifier* qualifier = expr->getQualifier();
  const Type* qualifier_type = qualifier ? qualifier->getAsType() : nullptr;
  return qualifier_type
             ? RecordInfo::GetDependentTemplatedDecl(*qualifier_type)
             : nullptr;
}

namespace {

// Traverses an expression and remembers the field named by the first
// MemberExpr (in traversal order) whose member is a FieldDecl, then aborts
// the traversal. A MemberExpr naming a non-field (e.g. a method) is not
// descended into.
class FindFieldVisitor : public RecursiveASTVisitor<FindFieldVisitor> {
 public:
  FindFieldVisitor() = default;

  // The field found, or nullptr if no field was found.
  FieldDecl* field() const { return field_; }

  // Overrides the default traversal of MemberExpr. Returning false stops the
  // whole traversal once a field is found. Returning true without calling the
  // base implementation skips this node's children.
  bool TraverseMemberExpr(MemberExpr* member) {
    auto* named_field = dyn_cast<FieldDecl>(member->getMemberDecl());
    if (!named_field)
      return true;
    field_ = named_field;
    return false;
  }

 private:
  FieldDecl* field_ = nullptr;
};

}  // namespace

// Handles a call whose callee is a DependentScopeDeclRefExpr, i.e. a
// qualified name that depends on a template parameter. Recognized forms:
//   - T::Trace(visitor), where T is a template type parameter: the base of
//     |info_| at T's parameter index is marked traced. The callee's name is
//     not checked for this form.
//   - Super<T>::trace(visitor): every base whose name equals the templated
//     record's name is marked traced.
//   - TraceIfNeeded<T>::trace(visitor, &field) with unresolved T: the field
//     found in the second argument is recorded via FoundField().
void CheckTraceVisitor::CheckDependentScopeDeclRefExpr(
    CallExpr* call,
    DependentScopeDeclRefExpr* expr) {
  const std::string callee_name = expr->getDeclName().getAsString();

  // T::Trace(visitor).
  NestedNameSpecifier* qualifier = expr->getQualifier();
  const Type* qualifier_type = qualifier ? qualifier->getAsType() : nullptr;
  const TemplateTypeParmType* parm_type =
      qualifier_type ? qualifier_type->getAs<TemplateTypeParmType>() : nullptr;
  if (parm_type) {
    const unsigned base_index = parm_type->getIndex();
    // An index with no corresponding base ends the analysis of this call.
    if (base_index >= info_->GetBases().size())
      return;
    info_->GetBases()[base_index].second.MarkTraced();
  }

  CXXRecordDecl* templated = GetDependentTemplatedDecl(expr);
  if (!templated)
    return;

  const unsigned num_args = call->getNumArgs();

  // Super<T>::trace(visitor).
  if (num_args == 1 && IsTraceCallName(callee_name)) {
    for (auto& base : info_->GetBases()) {
      if (base.first->getName() == templated->getName())
        base.second.MarkTraced();
    }
  }

  // TraceIfNeeded<T>::trace(visitor, &field) where T cannot be resolved.
  if (num_args == 2 && callee_name == kTraceName &&
      templated->getName() == kTraceIfNeededName) {
    FindFieldVisitor finder;
    finder.TraverseStmt(call->getArg(1));
    if (FieldDecl* traced_field = finder.field())
      FoundField(traced_field);
  }
}

// Recognizes a call to a base class's trace method, either resolved
// (a qualified MemberExpr callee such as Base::trace(visitor)) or unresolved
// (an UnresolvedMemberExpr whose candidates include a trace method and whose
// single argument is named |kVisitorVarName|). The called name must also
// satisfy IsTraceCallName().
//
// On a match, the corresponding entry of info_->GetBases() is marked traced
// and true is returned. Otherwise returns false and marks nothing.
bool CheckTraceVisitor::CheckTraceBaseCall(CallExpr* call) {
  // Checks for "Base::trace(visitor)"-like calls.

  // Checking code for these two variables is shared among MemberExpr* case
  // and UnresolvedMemberExpr* case below.
  //
  // For example, if we've got "Base::trace(visitor)" as |call|,
  // callee_record will be "Base", and func_name will be "trace".
  CXXRecordDecl* callee_record = nullptr;
  std::string func_name;

  if (MemberExpr* callee = dyn_cast<MemberExpr>(call->getCallee())) {
    // Only qualified calls (Base::trace) are base calls; an unqualified
    // trace() would be a (virtual) call on this object.
    if (!callee->hasQualifier())
      return false;

    FunctionDecl* trace_decl =
        dyn_cast<FunctionDecl>(callee->getMemberDecl());
    if (!trace_decl || !Config::IsTraceMethod(trace_decl))
      return false;

    const Type* type = callee->getQualifier()->getAsType();
    if (!type)
      return false;

    // May be null; checked after the if/else chain.
    callee_record = type->getAsCXXRecordDecl();
    func_name = std::string(trace_decl->getName());
  } else if (UnresolvedMemberExpr* callee =
             dyn_cast<UnresolvedMemberExpr>(call->getCallee())) {
    // Callee part may become unresolved if the type of the argument
    // ("visitor") is a template parameter and the called function is
    // overloaded.
    //
    // Here, we try to find a function that looks like trace() from the
    // candidate overloaded functions, and if we find one, we assume it is
    // called here.

    CXXMethodDecl* trace_decl = nullptr;
    for (NamedDecl* named_decl : callee->decls()) {
      if (CXXMethodDecl* method_decl = dyn_cast<CXXMethodDecl>(named_decl)) {
        if (Config::IsTraceMethod(method_decl)) {
          trace_decl = method_decl;
          break;
        }
      }
    }
    if (!trace_decl)
      return false;

    // Check if the passed argument is named "visitor".
    if (call->getNumArgs() != 1)
      return false;
    DeclRefExpr* arg = dyn_cast<DeclRefExpr>(call->getArg(0));
    if (!arg || arg->getNameInfo().getAsString() != kVisitorVarName)
      return false;

    callee_record = trace_decl->getParent();
    func_name = callee->getMemberName().getAsString();
  }

  if (!callee_record)
    return false;

  if (!IsTraceCallName(func_name))
    return false;

  for (auto& base : info_->GetBases()) {
    // We want to deal with omitted trace() function in an intermediary
    // class in the class hierarchy, e.g.:
    //     class A : public GarbageCollected<A> { trace() { ... } };
    //     class B : public A { /* No trace(); have nothing to trace. */ };
    //     class C : public B { trace() { B::trace(visitor); } }
    // where, B::trace() is actually A::trace(), and in some cases we get
    // A as |callee_record| instead of B. We somehow need to mark B as
    // traced if we find A::trace() call.
    //
    // To solve this, here we keep going up the class hierarchy as long as
    // they are not required to have a trace method. The implementation is
    // a simple DFS, where |base_records| represents the set of base classes
    // we need to visit.

    std::vector<CXXRecordDecl*> base_records;
    base_records.push_back(base.first);

    while (!base_records.empty()) {
      CXXRecordDecl* base_record = base_records.back();
      base_records.pop_back();

      if (base_record == callee_record) {
        // If we find a matching trace method, pretend the user has written
        // a correct trace() method of the base; in the example above, we
        // find A::trace() here and mark B as correctly traced.
        base.second.MarkTraced();
        return true;
      }

      if (RecordInfo* base_info = cache_->Lookup(base_record)) {
        if (!base_info->RequiresTraceMethod()) {
          // If this base class is not required to have a trace method, then
          // the actual trace method may be defined in an ancestor.
          for (auto& inner_base : base_info->GetBases())
            base_records.push_back(inner_base.first);
        }
      }
    }
  }

  return false;
}

// Handles a resolved member call that might be visitor->trace(field), by
// forwarding the method name, the record of the object expression, and the
// first argument to CheckTraceFieldCall().
//
// Returns what CheckTraceFieldCall() returns (true if the call was a trace
// call on a visitor). Returns false, without checking anything, when the
// call has no method/record declaration or no arguments.
bool CheckTraceVisitor::CheckTraceFieldMemberCall(CXXMemberCallExpr* call) {
  // getMethodDecl() and getRecordDecl() are null for calls through a pointer
  // to member function, e.g. (this->*fn)(visitor). Such calls cannot be
  // visitor->trace(field), and dereferencing the null decls would crash.
  CXXMethodDecl* method = call->getMethodDecl();
  CXXRecordDecl* record = call->getRecordDecl();
  if (!method || !record || call->getNumArgs() < 1)
    return false;

  return CheckTraceFieldCall(method->getNameAsString(), record,
                             call->getArg(0));
}

bool CheckTraceVisitor::CheckTraceFieldCall(
    const std::string& name,
    CXXRecordDecl* callee,
    Expr* arg) {
  if (name != kTraceName || !Config::IsVisitor(callee->getName()))
    return false;

  FindFieldVisitor finder;
  finder.TraverseStmt(arg);
  if (finder.field())
    FoundField(finder.field());

  return true;
}

// Handles a resolved member call to a method named registerWeakMembers, e.g.
// visitor->registerWeakMembers<Class, &Class::callback>(this).
//
// Returns false when |call| is not such a call, so the caller can keep
// classifying it, and true otherwise. If the callee is a function template
// instantiation whose second template argument is a declaration of a function
// with an available body, that body is traversed by a nested
// CheckTraceVisitor constructed with a null trace method (weak-callback
// mode). In that mode, the weak member fields the callback references are
// marked traced.
bool CheckTraceVisitor::CheckRegisterWeakMembers(CXXMemberCallExpr* call) {
  // getMethodDecl() is null for calls through a pointer to member function,
  // e.g. (this->*fn)(arg); such a call is not a registerWeakMembers call.
  CXXMethodDecl* fn = call->getMethodDecl();
  if (!fn)
    return false;

  // NamedDecl::getName() requires a plain identifier. Operators and
  // conversion functions (e.g. an explicit obj.operator=(other)) have none.
  if (!fn->getIdentifier() || fn->getName() != kRegisterWeakMembersName)
    return false;

  // isTemplateInstantiation() is also true for ordinary member functions of
  // an instantiated class template. Those have no function template
  // specialization info (getTemplateSpecializationInfo() returns null), so
  // query the specialization arguments directly and null-check them.
  const TemplateArgumentList* template_args =
      fn->getTemplateSpecializationArgs();
  if (!fn->isTemplateInstantiation() || !template_args)
    return true;

  const TemplateArgumentList& args = *template_args;
  // The second template argument is the callback method.
  if (args.size() > 1 && args[1].getKind() == TemplateArgument::Declaration) {
    if (FunctionDecl* callback = dyn_cast<FunctionDecl>(args[1].getAsDecl())) {
      if (callback->hasBody()) {
        CheckTraceVisitor nested_visitor(nullptr, info_, nullptr);
        nested_visitor.TraverseStmt(callback->getBody());
      }
    }
    // TODO: mark all WeakMember<>s as traced even if
    // the body isn't available?
  }
  return true;
}

// A visitor constructed without a trace method is checking the body of a
// weak callback (see CheckRegisterWeakMembers) rather than a trace method.
bool CheckTraceVisitor::IsWeakCallback() const {
  return trace_ == nullptr;
}

// Marks the field at |it| as traced, except that a weak callback may only
// account for fields whose edge is a WeakMember. In a weak callback we can't
// mark strong fields as traced.
void CheckTraceVisitor::MarkTraced(RecordInfo::Fields::iterator it) {
  const bool strong_field_in_weak_callback =
      IsWeakCallback() && !it->second.edge()->IsWeakMember();
  if (!strong_field_in_weak_callback)
    it->second.MarkTraced();
}

namespace {
// Looks up |field| among the fields of |info|. Returns the matching iterator,
// or info->GetFields().end() if there is no match.
RecordInfo::Fields::iterator FindField(RecordInfo* info, FieldDecl* field) {
  // For ordinary records the FieldDecl pointer itself is the key.
  if (!Config::IsTemplateInstantiation(info->record()))
    return info->GetFields().find(field);

  // Pointer equality on fields does not work for template instantiations.
  // The trace method refers to fields of the template definition which
  // are different from the instantiated fields that need to be traced, so
  // match by name instead (first match wins).
  const std::string wanted_name = field->getNameAsString();
  RecordInfo::Fields::iterator candidate = info->GetFields().begin();
  const RecordInfo::Fields::iterator fields_end = info->GetFields().end();
  while (candidate != fields_end &&
         candidate->first->getNameAsString() != wanted_name) {
    ++candidate;
  }
  return candidate;
}
}  // namespace

// Records that |field| was traced, provided it is one of the fields tracked
// in |info_|; otherwise the call has no effect.
void CheckTraceVisitor::FoundField(FieldDecl* field) {
  RecordInfo::Fields::iterator match = FindField(info_, field);
  if (match == info_->GetFields().end())
    return;
  MarkTraced(match);
}

// Marks every WeakMember field of |info_| as traced.
//
// If we find a call to registerWeakMembers which is unresolved we
// unsoundly consider all weak members as traced.
// TODO: Find out how to validate weak member tracing for unresolved call.
void CheckTraceVisitor::MarkAllWeakMembersTraced() {
  for (auto& entry : info_->GetFields()) {
    auto& point = entry.second;
    if (!point.edge()->IsWeakMember())
      continue;
    point.MarkTraced();
  }
}

// Recognizes TraceIfNeeded<T>::trace(visitor, &field) where T can be
// resolved. The callee is |expr|, an implicit cast whose operand is a
// DeclRefExpr qualified by the TraceIfNeeded record.
//
// Returns true (after recording the field found in the second argument, if
// any) when the call matches; false otherwise.
bool CheckTraceVisitor::CheckImplicitCastExpr(CallExpr* call,
                                              ImplicitCastExpr* expr) {
  auto* decl_ref = dyn_cast<DeclRefExpr>(expr->getSubExpr());
  NestedNameSpecifier* qualifier =
      decl_ref ? decl_ref->getQualifier() : nullptr;
  CXXRecordDecl* qualifier_record =
      qualifier ? qualifier->getAsRecordDecl() : nullptr;
  if (!qualifier_record)
    return false;

  const bool is_trace_if_needed =
      call->getNumArgs() == 2 &&
      decl_ref->getFoundDecl()->getNameAsString() == kTraceName &&
      qualifier_record->getName() == kTraceIfNeededName;
  if (!is_trace_if_needed)
    return false;

  FindFieldVisitor field_finder;
  field_finder.TraverseStmt(call->getArg(1));
  if (FieldDecl* traced_field = field_finder.field())
    FoundField(traced_field);
  return true;
}

namespace {
// Returns the field directly named by the range of |for_range_stmt|, i.e.
// when the range statement declares a single variable whose initializer is a
// MemberExpr referring to a FieldDecl. Returns nullptr in every other case.
FieldDecl* GetRangeField(CXXForRangeStmt* for_range_stmt) {
  DeclStmt* range_decl_stmt = for_range_stmt->getRangeStmt();
  if (range_decl_stmt->isSingleDecl()) {
    if (auto* range_var = dyn_cast<VarDecl>(range_decl_stmt->getSingleDecl())) {
      if (auto* member = dyn_cast<MemberExpr>(range_var->getInit()))
        return dyn_cast<FieldDecl>(member->getMemberDecl());
    }
  }
  return nullptr;
}
}  // namespace

// RecursiveASTVisitor hook for every statement. Array tracing could be
// phrased as a for-range statement over the array. When a range-for
// iterates directly over a field of |info_| whose edge is a C array or a
// std::array collection, that field is marked traced. The loop body itself
// is not inspected. Always returns true so the traversal continues.
bool CheckTraceVisitor::VisitStmt(Stmt* stmt) {
  auto* for_range = dyn_cast<CXXForRangeStmt>(stmt);
  FieldDecl* range_field = for_range ? GetRangeField(for_range) : nullptr;
  if (!range_field)
    return true;

  RecordInfo::Fields::iterator field_it = FindField(info_, range_field);
  if (field_it == info_->GetFields().end())
    return true;

  Edge* edge = field_it->second.edge();
  if (edge->IsArray())
    MarkTraced(field_it);
  if (edge->IsCollection()) {
    auto* collection = static_cast<Collection*>(edge);
    const bool is_std_array = collection->IsSTDCollection() &&
                              (collection->GetCollectionName() == "array");
    if (is_std_array)
      MarkTraced(field_it);
  }

  return true;
}