llvm/mlir/include/mlir/Tools/PDLL/AST/Nodes.h

//===- Nodes.h --------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TOOLS_PDLL_AST_NODES_H_
#define MLIR_TOOLS_PDLL_AST_NODES_H_

#include "mlir/Support/LLVM.h"
#include "mlir/Tools/PDLL/AST/Types.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TrailingObjects.h"
#include <optional>

namespace mlir {
namespace pdll {
namespace ast {
class Context;
class Decl;
class Expr;
class NamedAttributeDecl;
class OpNameDecl;
class VariableDecl;

//===----------------------------------------------------------------------===//
// Name
//===----------------------------------------------------------------------===//

/// This struct provides a convenient API for interacting with source names. It
/// holds a string name together with the source location at which that name
/// appears.
struct Name {};

//===----------------------------------------------------------------------===//
// DeclScope
//===----------------------------------------------------------------------===//

/// This class represents a scope for named AST decls. A scope determines the
/// visibility and lifetime of a named declaration. Note that DeclScope is not
/// itself an AST `Node` (it has no base class).
class DeclScope {};

//===----------------------------------------------------------------------===//
// Node
//===----------------------------------------------------------------------===//

/// This class represents a base AST node. All AST nodes are derived from this
/// class, and it provides much of the base functionality for interacting with
/// nodes. Concrete node classes in this header derive from it through the
/// nested `Node::NodeBase<Derived, Base>` helper template.
class Node {};

//===----------------------------------------------------------------------===//
// Stmt
//===----------------------------------------------------------------------===//

/// This class represents a base AST Statement node. It is the common base of
/// the statement classes in this header (CompoundStmt, LetStmt, OpRewriteStmt,
/// ReturnStmt) and of all expressions (Expr). `Stmt::classof` is defined out of
/// line at the end of this header, after all node classes are declared.
class Stmt : public Node {};

//===----------------------------------------------------------------------===//
// CompoundStmt
//===----------------------------------------------------------------------===//

/// This statement represents a compound statement, which contains a collection
/// of other statements. The child statements are stored as trailing `Stmt *`
/// objects via llvm::TrailingObjects.
class CompoundStmt final : public Node::NodeBase<CompoundStmt, Stmt>,
                           private llvm::TrailingObjects<CompoundStmt, Stmt *> {};

//===----------------------------------------------------------------------===//
// LetStmt
//===----------------------------------------------------------------------===//

/// This statement represents a `let` statement in PDLL. This statement is used
/// to define variables (see VariableDecl).
class LetStmt final : public Node::NodeBase<LetStmt, Stmt> {};

//===----------------------------------------------------------------------===//
// OpRewriteStmt
//===----------------------------------------------------------------------===//

/// This class represents a base operation rewrite statement. Operation rewrite
/// statements perform a set of transformations on a given root operation.
/// The concrete rewrite statements in this header are EraseStmt, ReplaceStmt
/// and RewriteStmt. `OpRewriteStmt::classof` is defined out of line at the end
/// of this header.
class OpRewriteStmt : public Stmt {};

//===----------------------------------------------------------------------===//
// EraseStmt

/// This statement represents the `erase` statement in PDLL. This statement
/// erases the given root operation, corresponding roughly to the
/// PatternRewriter::eraseOp API.
class EraseStmt final : public Node::NodeBase<EraseStmt, OpRewriteStmt> {};

//===----------------------------------------------------------------------===//
// ReplaceStmt

/// This statement represents the `replace` statement in PDLL. This statement
/// replaces the given root operation with a set of values, corresponding
/// roughly to the PatternRewriter::replaceOp API. The statement carries
/// trailing `Expr *` objects via llvm::TrailingObjects (presumably the
/// replacement value expressions -- confirm against the full definition).
class ReplaceStmt final : public Node::NodeBase<ReplaceStmt, OpRewriteStmt>,
                          private llvm::TrailingObjects<ReplaceStmt, Expr *> {};

//===----------------------------------------------------------------------===//
// RewriteStmt

/// This statement represents an operation rewrite that contains a block of
/// nested rewrite commands. This allows for building more complex operation
/// rewrites that span across multiple statements, which may be unconnected.
class RewriteStmt final : public Node::NodeBase<RewriteStmt, OpRewriteStmt> {};

//===----------------------------------------------------------------------===//
// ReturnStmt
//===----------------------------------------------------------------------===//

/// This statement represents a return from a "callable" like decl, e.g. a
/// Constraint or a Rewrite (see also CallableDecl).
class ReturnStmt final : public Node::NodeBase<ReturnStmt, Stmt> {};

//===----------------------------------------------------------------------===//
// Expr
//===----------------------------------------------------------------------===//

/// This class represents a base AST Expression node. Expressions are
/// statements (Expr derives from Stmt). The concrete expressions in this header
/// are AttributeExpr, CallExpr, DeclRefExpr, MemberAccessExpr, OperationExpr,
/// RangeExpr, TupleExpr and TypeExpr. `Expr::classof` is defined out of line at
/// the end of this header.
class Expr : public Stmt {};

//===----------------------------------------------------------------------===//
// AttributeExpr
//===----------------------------------------------------------------------===//

/// This expression represents a literal MLIR Attribute, and contains the
/// textual assembly format of that attribute (i.e. the attribute is kept in
/// its unparsed, textual form).
class AttributeExpr : public Node::NodeBase<AttributeExpr, Expr> {};

//===----------------------------------------------------------------------===//
// CallExpr
//===----------------------------------------------------------------------===//

/// This expression represents a call to a decl, such as a
/// UserConstraintDecl/UserRewriteDecl. The expression carries trailing
/// `Expr *` objects via llvm::TrailingObjects (presumably the call arguments --
/// confirm against the full definition).
class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
                       private llvm::TrailingObjects<CallExpr, Expr *> {};

//===----------------------------------------------------------------------===//
// DeclRefExpr
//===----------------------------------------------------------------------===//

/// This expression represents a reference to a Decl node, i.e. a use of a
/// previously declared entity within an expression.
class DeclRefExpr : public Node::NodeBase<DeclRefExpr, Expr> {};

//===----------------------------------------------------------------------===//
// MemberAccessExpr
//===----------------------------------------------------------------------===//

/// This expression represents a named member or field access of a given parent
/// expression. It is also the base of AllResultsMemberAccessExpr.
class MemberAccessExpr : public Node::NodeBase<MemberAccessExpr, Expr> {};

//===----------------------------------------------------------------------===//
// AllResultsMemberAccessExpr

/// This class represents an instance of MemberAccessExpr that references all
/// results of an operation. Unlike most nodes here it derives directly from
/// MemberAccessExpr rather than through Node::NodeBase.
class AllResultsMemberAccessExpr : public MemberAccessExpr {};

//===----------------------------------------------------------------------===//
// OperationExpr
//===----------------------------------------------------------------------===//

/// This expression represents the structural form of an MLIR Operation. It
/// represents either an input operation to match, or an operation to create
/// within a rewrite. Two kinds of trailing objects are stored via
/// llvm::TrailingObjects: `Expr *` and `NamedAttributeDecl *` (presumably the
/// operand/result-type expressions and the attributes -- confirm against the
/// full definition).
class OperationExpr final
    : public Node::NodeBase<OperationExpr, Expr>,
      private llvm::TrailingObjects<OperationExpr, Expr *,
                                    NamedAttributeDecl *> {};

//===----------------------------------------------------------------------===//
// RangeExpr
//===----------------------------------------------------------------------===//

/// This expression builds a range from a set of element values (which may be
/// ranges themselves). The element expressions are stored as trailing
/// `Expr *` objects via llvm::TrailingObjects.
class RangeExpr final : public Node::NodeBase<RangeExpr, Expr>,
                        private llvm::TrailingObjects<RangeExpr, Expr *> {};

//===----------------------------------------------------------------------===//
// TupleExpr
//===----------------------------------------------------------------------===//

/// This expression builds a tuple from a set of element values. The element
/// expressions are stored as trailing `Expr *` objects via
/// llvm::TrailingObjects.
class TupleExpr final : public Node::NodeBase<TupleExpr, Expr>,
                        private llvm::TrailingObjects<TupleExpr, Expr *> {};

//===----------------------------------------------------------------------===//
// TypeExpr
//===----------------------------------------------------------------------===//

/// This expression represents a literal MLIR Type, and contains the textual
/// assembly format of that type (i.e. the type is kept in its unparsed,
/// textual form).
class TypeExpr : public Node::NodeBase<TypeExpr, Expr> {};

//===----------------------------------------------------------------------===//
// Decl
//===----------------------------------------------------------------------===//

/// This class represents the base Decl node. The decl classes in this header
/// deriving from it are ConstraintDecl, NamedAttributeDecl, OpNameDecl,
/// PatternDecl, UserRewriteDecl, CallableDecl and VariableDecl.
/// `Decl::classof` is defined out of line at the end of this header.
class Decl : public Node {};

//===----------------------------------------------------------------------===//
// ConstraintDecl
//===----------------------------------------------------------------------===//

/// This class represents the base of all AST Constraint decls. Constraints
/// apply matcher conditions to, and define the type of PDLL variables. Its
/// subclasses in this header are CoreConstraintDecl and UserConstraintDecl.
/// `ConstraintDecl::classof` is defined out of line at the end of this header.
class ConstraintDecl : public Decl {};

/// This struct represents a reference to a constraint, and contains a
/// constraint and the location of the reference. It is a plain value type, not
/// an AST `Node`. VariableDecl stores these as trailing objects.
struct ConstraintRef {};

//===----------------------------------------------------------------------===//
// CoreConstraintDecl
//===----------------------------------------------------------------------===//

/// This class represents the base of all "core" constraints. Core constraints
/// are those that generally represent a concrete IR construct, such as
/// `Type`s or `Value`s. Its subclasses in this header are AttrConstraintDecl,
/// OpConstraintDecl, TypeConstraintDecl, TypeRangeConstraintDecl,
/// ValueConstraintDecl and ValueRangeConstraintDecl.
/// `CoreConstraintDecl::classof` is defined out of line at the end of this
/// header.
class CoreConstraintDecl : public ConstraintDecl {};

//===----------------------------------------------------------------------===//
// AttrConstraintDecl

/// This class represents an Attribute constraint, and constrains a variable to
/// be an Attribute.
class AttrConstraintDecl
    : public Node::NodeBase<AttrConstraintDecl, CoreConstraintDecl> {};

//===----------------------------------------------------------------------===//
// OpConstraintDecl

/// This class represents an Operation constraint, and constrains a variable to
/// be an Operation.
class OpConstraintDecl
    : public Node::NodeBase<OpConstraintDecl, CoreConstraintDecl> {};

//===----------------------------------------------------------------------===//
// TypeConstraintDecl

/// This class represents a Type constraint, and constrains a variable to be a
/// Type.
class TypeConstraintDecl
    : public Node::NodeBase<TypeConstraintDecl, CoreConstraintDecl> {};

//===----------------------------------------------------------------------===//
// TypeRangeConstraintDecl

/// This class represents a TypeRange constraint, and constrains a variable to
/// be a TypeRange.
class TypeRangeConstraintDecl
    : public Node::NodeBase<TypeRangeConstraintDecl, CoreConstraintDecl> {};

//===----------------------------------------------------------------------===//
// ValueConstraintDecl

/// This class represents a Value constraint, and constrains a variable to be a
/// Value.
class ValueConstraintDecl
    : public Node::NodeBase<ValueConstraintDecl, CoreConstraintDecl> {};

//===----------------------------------------------------------------------===//
// ValueRangeConstraintDecl

/// This class represents a ValueRange constraint, and constrains a variable to
/// be a ValueRange.
class ValueRangeConstraintDecl
    : public Node::NodeBase<ValueRangeConstraintDecl, CoreConstraintDecl> {};

//===----------------------------------------------------------------------===//
// UserConstraintDecl
//===----------------------------------------------------------------------===//

/// This decl represents a user defined constraint. This is either:
///   * an imported native constraint
///     - Similar to an external function declaration. This is a native
///       constraint defined externally, and imported into PDLL via a
///       declaration.
///   * a native constraint defined in PDLL
///     - This is a native constraint, i.e. a constraint whose implementation is
///       defined in C++ (or potentially some other non-PDLL language). The
///       implementation of this constraint is specified as a string code block
///       in PDLL.
///   * a PDLL constraint
///     - This is a constraint which is defined using only PDLL constructs.
///
/// Trailing storage (llvm::TrailingObjects) holds `VariableDecl *` and
/// `StringRef` objects. The TrailingObjects base is spelled `private`
/// explicitly, matching the other TrailingObjects users in this header; this is
/// the same access a `class` gets by default, so the semantics are unchanged.
class UserConstraintDecl final
    : public Node::NodeBase<UserConstraintDecl, ConstraintDecl>,
      private llvm::TrailingObjects<UserConstraintDecl, VariableDecl *,
                                    StringRef> {};

//===----------------------------------------------------------------------===//
// NamedAttributeDecl
//===----------------------------------------------------------------------===//

/// This Decl represents a NamedAttribute, and contains a string name and
/// attribute value. OperationExpr stores pointers to these as trailing objects.
class NamedAttributeDecl : public Node::NodeBase<NamedAttributeDecl, Decl> {};

//===----------------------------------------------------------------------===//
// OpNameDecl
//===----------------------------------------------------------------------===//

/// This Decl represents an OperationName, i.e. the name of an MLIR operation
/// as referenced from PDLL.
class OpNameDecl : public Node::NodeBase<OpNameDecl, Decl> {};

//===----------------------------------------------------------------------===//
// PatternDecl
//===----------------------------------------------------------------------===//

/// This Decl represents a single Pattern.
/// NOTE(review): the class body is empty in this view; presumably its members
/// were elided -- confirm against the full header.
class PatternDecl : public Node::NodeBase<PatternDecl, Decl> {};

//===----------------------------------------------------------------------===//
// UserRewriteDecl
//===----------------------------------------------------------------------===//

/// This decl represents a user defined rewrite. This is either:
///   * an imported native rewrite
///     - Similar to an external function declaration. This is a native
///       rewrite defined externally, and imported into PDLL via a declaration.
///   * a native rewrite defined in PDLL
///     - This is a native rewrite, i.e. a rewrite whose implementation is
///       defined in C++ (or potentially some other non-PDLL language). The
///       implementation of this rewrite is specified as a string code block
///       in PDLL.
///   * a PDLL rewrite
///     - This is a rewrite which is defined using only PDLL constructs.
///
/// Trailing storage (llvm::TrailingObjects) holds `VariableDecl *` objects. The
/// TrailingObjects base is spelled `private` explicitly, matching the other
/// TrailingObjects users in this header; this is the same access a `class` gets
/// by default, so the semantics are unchanged.
class UserRewriteDecl final
    : public Node::NodeBase<UserRewriteDecl, Decl>,
      private llvm::TrailingObjects<UserRewriteDecl, VariableDecl *> {};

//===----------------------------------------------------------------------===//
// CallableDecl
//===----------------------------------------------------------------------===//

/// This decl represents a shared interface for all callable decls. Unlike the
/// concrete decls in this header, it derives directly from Decl rather than
/// through Node::NodeBase (presumably it wraps callable decls such as
/// UserConstraintDecl/UserRewriteDecl rather than being a distinct node kind --
/// confirm against the full definition).
class CallableDecl : public Decl {};

//===----------------------------------------------------------------------===//
// VariableDecl
//===----------------------------------------------------------------------===//

/// This Decl represents the definition of a PDLL variable. Trailing
/// `ConstraintRef` objects are stored via llvm::TrailingObjects (presumably the
/// constraints applied to the variable -- confirm against the full
/// definition).
class VariableDecl final
    : public Node::NodeBase<VariableDecl, Decl>,
      private llvm::TrailingObjects<VariableDecl, ConstraintRef> {};

//===----------------------------------------------------------------------===//
// Module
//===----------------------------------------------------------------------===//

/// This class represents a top-level AST module. It derives directly from Node
/// (it is neither a Stmt nor a Decl) and stores trailing `Decl *` objects via
/// llvm::TrailingObjects (presumably the module's top-level decls).
class Module final : public Node::NodeBase<Module, Node>,
                     private llvm::TrailingObjects<Module, Decl *> {};

//===----------------------------------------------------------------------===//
// Deferred Method Definitions
//===----------------------------------------------------------------------===//

/// Returns whether `node` is a Decl; presumably the LLVM-style RTTI hook used
/// by `isa`/`dyn_cast`. It should accept the Decl subclasses declared above
/// (ConstraintDecl, NamedAttributeDecl, OpNameDecl, PatternDecl,
/// UserRewriteDecl, VariableDecl).
/// NOTE(review): the body is empty in this view. Falling off the end of a
/// non-void function is undefined behavior if it is called, so presumably the
/// implementation was elided -- confirm.
inline bool Decl::classof(const Node *node) {}

/// Returns whether `node` is a ConstraintDecl; presumably the LLVM-style RTTI
/// hook used by `isa`/`dyn_cast`. It should accept the subclasses declared
/// above (CoreConstraintDecl and its subclasses, UserConstraintDecl).
/// NOTE(review): the body is empty in this view. Falling off the end of a
/// non-void function is undefined behavior if it is called, so presumably the
/// implementation was elided -- confirm.
inline bool ConstraintDecl::classof(const Node *node) {}

/// Returns whether `node` is a CoreConstraintDecl; presumably the LLVM-style
/// RTTI hook used by `isa`/`dyn_cast`. It should accept the core constraint
/// decls declared above (Attr/Op/Type/TypeRange/Value/ValueRange
/// ConstraintDecl).
/// NOTE(review): the body is empty in this view. Falling off the end of a
/// non-void function is undefined behavior if it is called, so presumably the
/// implementation was elided -- confirm.
inline bool CoreConstraintDecl::classof(const Node *node) {}

/// Returns whether `node` is an Expr; presumably the LLVM-style RTTI hook used
/// by `isa`/`dyn_cast`. It should accept the expressions declared above
/// (AttributeExpr, CallExpr, DeclRefExpr, MemberAccessExpr and its subclass,
/// OperationExpr, RangeExpr, TupleExpr, TypeExpr).
/// NOTE(review): the body is empty in this view. Falling off the end of a
/// non-void function is undefined behavior if it is called, so presumably the
/// implementation was elided -- confirm.
inline bool Expr::classof(const Node *node) {}

/// Returns whether `node` is an OpRewriteStmt; presumably the LLVM-style RTTI
/// hook used by `isa`/`dyn_cast`. It should accept EraseStmt, ReplaceStmt and
/// RewriteStmt.
/// NOTE(review): the body is empty in this view. Falling off the end of a
/// non-void function is undefined behavior if it is called, so presumably the
/// implementation was elided -- confirm.
inline bool OpRewriteStmt::classof(const Node *node) {}

/// Returns whether `node` is a Stmt; presumably the LLVM-style RTTI hook used
/// by `isa`/`dyn_cast`. It should accept the statements declared above
/// (CompoundStmt, LetStmt, OpRewriteStmt and its subclasses, ReturnStmt) as
/// well as every Expr.
/// NOTE(review): the body is empty in this view. Falling off the end of a
/// non-void function is undefined behavior if it is called, so presumably the
/// implementation was elided -- confirm, and check that ReturnStmt is covered.
inline bool Stmt::classof(const Node *node) {}

} // namespace ast
} // namespace pdll
} // namespace mlir

#endif // MLIR_TOOLS_PDLL_AST_NODES_H_