chromium/tools/clang/raw_ptr_plugin/TypePredicateUtil.h

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef TOOLS_CLANG_RAW_PTR_PLUGIN_TYPEPREDICATEUTIL_H_
#define TOOLS_CLANG_RAW_PTR_PLUGIN_TYPEPREDICATEUTIL_H_

#include <map>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <vector>

#include "clang/AST/Decl.h"
#include "clang/AST/DeclTemplate.h"
#include "clang/AST/Type.h"
#include "llvm/ADT/ScopeExit.h"

// Bit flags naming the relations a |TypePredicate| may follow from one type to
// another. Every enumerator occupies its own bit, so a set of rules is formed
// with operator| and queried with operator&.
enum class InductionRule : unsigned {
  // The empty rule set.
  kNone = 0u,
  // Pointer type -> pointee type.
  kPointerPointee = 1u << 0,
  // Objective-C object pointer type -> pointee type.
  kObjCPointerPointee = 1u << 1,
  // Reference type -> referenced type.
  kReferencePointee = 1u << 2,
  // Array type -> element type.
  kArrayElement = 1u << 3,
  // Any type -> its unqualified desugared type.
  kUnqualifiedDesugaredType = 1u << 4,
  // C++ record -> types listed by CXXRecordDecl::bases().
  kBaseClass = 1u << 5,
  // C++ record -> types listed by CXXRecordDecl::vbases().
  kVirtualBaseClass = 1u << 6,
  // Record -> types of its fields.
  kField = 1u << 7,
  // Class template specialization -> its type template arguments.
  kTemplateArgument = 1u << 8,
};

// Returns the union of the rule sets |a| and |b|.
constexpr InductionRule operator|(InductionRule a, InductionRule b) {
  const unsigned lhs_bits = static_cast<unsigned>(a);
  const unsigned rhs_bits = static_cast<unsigned>(b);
  return static_cast<InductionRule>(lhs_bits | rhs_bits);
}
// Returns the rules present in both |a| and |b|; the result is
// InductionRule::kNone when the two sets share no rule.
constexpr InductionRule operator&(InductionRule a, InductionRule b) {
  const unsigned lhs_bits = static_cast<unsigned>(a);
  const unsigned rhs_bits = static_cast<unsigned>(b);
  return static_cast<InductionRule>(lhs_bits & rhs_bits);
}

// The outcome of matching a predicate against |type|, expressed as a verdict:
// - MatchResult::kNoMatch: no match was found against |type|.
// - MatchResult::kMatch: a match was found against |type|.
// - MatchResult::kUndetermined: the result is not determined yet, due to
//   cross references.
// A result may additionally carry a |source| and a |source_loc| that tell the
// reason for the verdict.
class MatchResult {
 public:
  enum Verdict {
    kMatch,
    kNoMatch,
    // The match status has not been determined yet.
    kUndetermined,
  };

  // Creates a result for |type| with the default verdict, kNoMatch.
  explicit MatchResult(const clang::Type* type) : type_(type) {}

  // Creates a result for |type| with the given |verdict|.
  explicit MatchResult(const clang::Type* type, Verdict verdict)
      : type_(type), verdict_(verdict) {}

  // The type this result was created for.
  const clang::Type* type() const { return type_; }

  // The current verdict for |type()|.
  Verdict verdict() const { return verdict_; }

  // The sub-result recorded as the reason for the verdict; null when no
  // sub-result has been merged in as the reason.
  std::shared_ptr<MatchResult> source() const { return source_; }

  // The location recorded together with |source()|, if any.
  std::optional<clang::SourceLocation> source_loc() const {
    return source_loc_;
  }

 private:
  template <InductionRule Rules>
  friend class TypePredicate;

  // Folds the verdict of |sub| into this result and returns the updated
  // verdict. The verdict only ever moves towards kMatch:
  //
  // | this   \ sub  | kNoMatch      | kUndetermined | kMatch |
  // +---------------+---------------+---------------+--------+
  // | kNoMatch      | kNoMatch      | kUndetermined | kMatch |
  // | kUndetermined | kUndetermined | kUndetermined | kMatch |
  // | kMatch        | kMatch        | kMatch        | kMatch |
  //
  // Whenever the verdict changes, |sub| and |loc| are recorded as the reason.
  Verdict MergeSubResult(
      std::shared_ptr<MatchResult> sub,
      std::optional<clang::SourceLocation> loc = std::nullopt) {
    const Verdict incoming = sub->verdict_;

    bool adopt = false;
    switch (incoming) {
      case kMatch:
        // A match overrides anything that is not already a match.
        adopt = verdict_ != kMatch;
        break;
      case kUndetermined:
        // An open question only overrides a plain "no match".
        adopt = verdict_ == kNoMatch;
        break;
      case kNoMatch:
        break;
    }

    if (adopt) {
      verdict_ = incoming;
      source_ = std::move(sub);
      source_loc_ = loc;
    }
    return verdict_;
  }

  // |type_| is considered to be |verdict_|.
  // |source_| optionally holds a reason for the verdict. Several reasons may
  // exist (e.g. |type_| has multiple matching members), but only one of them
  // is kept. |source_loc_| optionally shows where |type_| and |source_| are
  // related.
  const clang::Type* type_;
  Verdict verdict_ = kNoMatch;
  std::shared_ptr<MatchResult> source_;
  std::optional<clang::SourceLocation> source_loc_;
};

// Determines there is a match against |type| or not.
// A type is considered match if |IsBaseMatch| returns true or
// reach such |type| by applying InductionRule recursively.
template <InductionRule Rules>
class TypePredicate {
 public:
  virtual ~TypePredicate() = default;
  bool Matches(const clang::Type* type) const {
    return GetMatchResult(type)->verdict_ == MatchResult::kMatch;
  }

  std::shared_ptr<MatchResult> GetMatchResult(
      const clang::Type* type,
      std::set<const clang::Type*>* visited = nullptr) const {
    // Retrieve a "base" type to reduce recursion depth.
    const clang::Type* raw_type = GetBaseType(type);
    if (!raw_type || !raw_type->isRecordType()) {
      // |TypePredicate| does not support followings:
      // - function type
      // - enum type
      // - builtin type
      // - complex type
      // - obj-C types
      // - using type
      // - typeof type
      return std::make_shared<MatchResult>(type);  // No match.
    }

    // Use a memoized result if exists.
    auto iter = cache_.find(type);
    if (iter != cache_.end()) {
      return iter->second;
    }

    // This performs DFS on a directed graph composed of |Type*|.
    // Avoid searching for visited nodes by managing |visited|, as this can lead
    // to infinite loops in the presence of self-references and
    // cross-references. Since finding a match for |Type* x| is equivalent to
    // being able to reach from node |Type* x| to node |Type* y| where
    // |IsBaseCase(y)|, there is no need to look up visited nodes again.
    bool root = visited == nullptr;
    if (root) {
      // Will be deleted as a part of |clean_up()|.
      visited = new std::set<const clang::Type*>();
    } else if (visited->count(type)) {
      // This type is already visited but not memoized,
      // therefore this node is reached by following cross-references from
      // ancestors. The verdict of this node cannot be determined without
      // waiting for computation in its ancestors.
      return std::make_shared<MatchResult>(raw_type,
                                           MatchResult::kUndetermined);
    }
    visited->insert(type);

    auto match = std::make_shared<MatchResult>(raw_type);

    // Clean-up: this lambda is called automatically at the scope exit.
    const auto clean_up =
        llvm::make_scope_exit([this, &visited, &raw_type, &root, &match] {
          if (root) {
            delete visited;
          }
          // Memoize the result if finalized.
          if (match->verdict_ != MatchResult::kUndetermined) {
            this->cache_.insert({raw_type, match});
          }
        });

    // Base case.
    if (IsBaseMatch(raw_type)) {
      match->verdict_ = MatchResult::kMatch;
      return match;
    }

    const clang::RecordDecl* decl = raw_type->getAsRecordDecl();
    assert(decl);

    // Check member fields
    if constexpr ((Rules & InductionRule::kField) != InductionRule::kNone) {
      for (const auto& field : decl->fields()) {
        match->MergeSubResult(
            GetMatchResult(field->getType().getTypePtrOrNull(), visited),
            field->getBeginLoc());

        // Verdict finalized: early return.
        if (match->verdict_ == MatchResult::kMatch) {
          return match;
        }
      }
    }

    const auto* cxx_decl = clang::dyn_cast<clang::CXXRecordDecl>(decl);
    if (cxx_decl && cxx_decl->hasDefinition()) {
      // Check base classes
      if constexpr ((Rules & InductionRule::kBaseClass) !=
                    InductionRule::kNone) {
        for (const auto& base_specifier : cxx_decl->bases()) {
          match->MergeSubResult(
              GetMatchResult(base_specifier.getType().getTypePtr(), visited),
              base_specifier.getBeginLoc());

          // Verdict finalized: early return.
          if (match->verdict_ == MatchResult::kMatch) {
            return match;
          }
        }
      }

      // Check virtual base classes
      if constexpr ((Rules & InductionRule::kVirtualBaseClass) !=
                    InductionRule::kNone) {
        for (const auto& base_specifier : cxx_decl->vbases()) {
          match->MergeSubResult(
              GetMatchResult(base_specifier.getType().getTypePtr(), visited),
              base_specifier.getBeginLoc());

          // Verdict finalized: early return.
          if (match->verdict_ == MatchResult::kMatch) {
            return match;
          }
        }
      }
    }

    // Check template parameters.
    if constexpr ((Rules & InductionRule::kTemplateArgument) !=
                  InductionRule::kNone) {
      if (auto* field_record_template =
              clang::dyn_cast<clang::ClassTemplateSpecializationDecl>(decl)) {
        const auto& template_args = field_record_template->getTemplateArgs();
        for (unsigned i = 0; i < template_args.size(); i++) {
          if (template_args[i].getKind() != clang::TemplateArgument::Type) {
            continue;
          }
          match->MergeSubResult(
              GetMatchResult(template_args[i].getAsType().getTypePtrOrNull(),
                             visited),
              field_record_template->getTemplateKeywordLoc());

          // Verdict finalized: early return.
          if (match->verdict_ == MatchResult::kMatch) {
            return match;
          }
        }
      }
    }

    // All reachable types have been traversed but the root type has not
    // been marked as a match; therefore it must be no match.
    if (root && match->verdict_ == MatchResult::kUndetermined) {
      match->verdict_ = MatchResult::kNoMatch;
    }
    return match;
  }

 private:
  const clang::Type* GetBaseType(const clang::Type* type) const {
    using clang::dyn_cast;

    const clang::Type* last_type = nullptr;
    while (type && type != last_type) {
      last_type = type;

      // Unwrap type aliases.
      if constexpr ((Rules & InductionRule::kUnqualifiedDesugaredType) !=
                    InductionRule::kNone) {
        type = type->getUnqualifiedDesugaredType();
      }

      // Unwrap pointers.
      if constexpr ((Rules & InductionRule::kPointerPointee) !=
                    InductionRule::kNone) {
        while (type && type->isPointerType()) {
          type = type->getPointeeType().getTypePtr();
        }
      }

      // Unwrap ObjC pointers.
      if constexpr ((Rules & InductionRule::kObjCPointerPointee) !=
                    InductionRule::kNone) {
        while (type && type->isObjCObjectPointerType()) {
          type = type->getPointeeType().getTypePtr();
        }
      }

      // Unwrap array.
      if constexpr ((Rules & InductionRule::kArrayElement) !=
                    InductionRule::kNone) {
        while (const auto* array_type = dyn_cast<clang::ArrayType>(type)) {
          type = array_type->getElementType().getTypePtr();
        }
      }

      // Unwrap reference.
      if constexpr ((Rules & InductionRule::kReferencePointee) !=
                    InductionRule::kNone) {
        if (const auto* ref_type = dyn_cast<clang::ReferenceType>(type)) {
          type = ref_type->getPointeeType().getTypePtrOrNull();
        }
      }
    }
    return type;
  }

  virtual bool IsBaseMatch(const clang::Type* type) const { return false; }

  // Cache to efficiently determine match.
  mutable std::map<const clang::Type*, std::shared_ptr<MatchResult>> cache_;
};

#endif  // TOOLS_CLANG_RAW_PTR_PLUGIN_TYPEPREDICATEUTIL_H_